Coverage for src/srunx/client.py: 87%

155 statements  


1"""SLURM client for job submission and management.""" 

2 

3import subprocess 

4import tempfile 

5import time 

6from collections.abc import Sequence 

7from importlib.resources import files 

8from pathlib import Path 

9 

10from srunx.callbacks import Callback 

11from srunx.logging import get_logger 

12from srunx.models import ( 

13 BaseJob, 

14 Job, 

15 JobStatus, 

16 JobType, 

17 RunableJobType, 

18 ShellJob, 

19 render_job_script, 

20) 

21from srunx.utils import get_job_status, job_status_msg 

22 

23logger = get_logger(__name__) 

24 

25 

26class Slurm: 

27 """Client for interacting with SLURM workload manager.""" 

28 

29 def __init__( 

30 self, 

31 default_template: str | None = None, 

32 callbacks: Sequence[Callback] | None = None, 

33 ): 

34 """Initialize SLURM client. 

35 

36 Args: 

37 default_template: Path to default job template. 

38 callbacks: List of callbacks. 

39 """ 

40 self.default_template = default_template or self._get_default_template() 

41 self.callbacks = list(callbacks) if callbacks else [] 

42 

43 def submit( 

44 self, 

45 job: RunableJobType, 

46 template_path: str | None = None, 

47 callbacks: Sequence[Callback] | None = None, 

48 verbose: bool = False, 

49 ) -> RunableJobType: 

50 """Submit a job to SLURM. 

51 

52 Args: 

53 job: Job configuration. 

54 template_path: Optional template path (uses default if not provided). 

55 callbacks: List of callbacks. 

56 verbose: Whether to print the rendered content. 

57 

58 Returns: 

59 Job instance with updated job_id and status. 

60 

61 Raises: 

62 subprocess.CalledProcessError: If job submission fails. 

63 """ 

64 

65 if isinstance(job, Job): 

66 template = template_path or self.default_template 

67 

68 with tempfile.TemporaryDirectory() as temp_dir: 

69 script_path = render_job_script(template, job, temp_dir, verbose) 

70 logger.debug(f"Generated SLURM script at: {script_path}") 

71 

72 # Handle container execution 

73 sbatch_cmd = ["sbatch"] 

74 if job.environment.sqsh: 

75 sbatch_cmd.extend(["--sqsh", job.environment.sqsh]) 

76 logger.debug(f"Using sqsh: {job.environment.sqsh}") 

77 

78 sbatch_cmd.append(script_path) 

79 logger.debug(f"Executing command: {' '.join(sbatch_cmd)}") 

80 

81 try: 

82 result = subprocess.run( 

83 sbatch_cmd, 

84 capture_output=True, 

85 text=True, 

86 check=True, 

87 ) 

88 except subprocess.CalledProcessError as e: 

89 logger.error(f"Failed to submit job '{job.name}': {e}") 

90 logger.error(f"Command: {' '.join(e.cmd)}") 

91 logger.error(f"Return code: {e.returncode}") 

92 logger.error(f"Stdout: {e.stdout}") 

93 logger.error(f"Stderr: {e.stderr}") 

94 raise 

95 

96 elif isinstance(job, ShellJob): 

97 try: 

98 result = subprocess.run( 

99 ["sbatch", job.path], 

100 capture_output=True, 

101 text=True, 

102 check=True, 

103 ) 

104 except subprocess.CalledProcessError as e: 

105 logger.error(f"Failed to submit job '{job.name}': {e}") 

106 logger.error(f"Command: {' '.join(e.cmd)}") 

107 logger.error(f"Return code: {e.returncode}") 

108 logger.error(f"Stdout: {e.stdout}") 

109 logger.error(f"Stderr: {e.stderr}") 

110 raise 

111 

112 else: 

113 raise ValueError("Either 'command' or 'file' must be set") 

114 

115 time.sleep(3) 

116 job_id = int(result.stdout.split()[-1]) 

117 job.job_id = job_id 

118 job.status = JobStatus.PENDING 

119 

120 logger.debug(f"Successfully submitted job '{job.name}' with ID {job_id}") 

121 

122 all_callbacks = self.callbacks[:] 

123 if callbacks: 

124 all_callbacks.extend(callbacks) 

125 for callback in all_callbacks: 

126 callback.on_job_submitted(job) 

127 

128 return job 

129 

130 @staticmethod 

131 def retrieve(job_id: int) -> BaseJob: 

132 """Retrieve job information from SLURM. 

133 

134 Args: 

135 job_id: SLURM job ID. 

136 

137 Returns: 

138 Job object with current status. 

139 """ 

140 return get_job_status(job_id) 

141 

142 def cancel(self, job_id: int) -> None: 

143 """Cancel a SLURM job. 

144 

145 Args: 

146 job_id: SLURM job ID to cancel. 

147 

148 Raises: 

149 subprocess.CalledProcessError: If job cancellation fails. 

150 """ 

151 logger.info(f"Cancelling job {job_id}") 

152 

153 try: 

154 subprocess.run( 

155 ["scancel", str(job_id)], 

156 check=True, 

157 ) 

158 logger.info(f"Successfully cancelled job {job_id}") 

159 except subprocess.CalledProcessError as e: 

160 logger.error(f"Failed to cancel job {job_id}: {e}") 

161 raise 

162 

163 def queue(self, user: str | None = None) -> list[BaseJob]: 

164 """List jobs for a user. 

165 

166 Args: 

167 user: Username (defaults to current user). 

168 

169 Returns: 

170 List of Job objects. 

171 """ 

172 cmd = [ 

173 "squeue", 

174 "--format", 

175 "%.18i %.9P %.15j %.8u %.8T %.10M %.9l %.6D %R", 

176 "--noheader", 

177 ] 

178 if user: 

179 cmd.extend(["--user", user]) 

180 

181 result = subprocess.run(cmd, capture_output=True, text=True, check=True) 

182 

183 jobs = [] 

184 for line in result.stdout.strip().split("\n"): 

185 if not line.strip(): 

186 continue 

187 

188 parts = line.split() 

189 if len(parts) >= 5: 

190 job_id = int(parts[0]) 

191 job_name = parts[2] 

192 status_str = parts[4] 

193 

194 try: 

195 status = JobStatus(status_str) 

196 except ValueError: 

197 status = JobStatus.PENDING # Default for unknown status 

198 

199 job = BaseJob( 

200 name=job_name, 

201 job_id=job_id, 

202 ) 

203 job.status = status 

204 jobs.append(job) 

205 

206 return jobs 

207 

208 def monitor( 

209 self, 

210 job_obj_or_id: JobType | int, 

211 poll_interval: int = 5, 

212 callbacks: Sequence[Callback] | None = None, 

213 ) -> JobType: 

214 """Wait for a job to complete. 

215 

216 Args: 

217 job_obj_or_id: Job object or job ID. 

218 poll_interval: Polling interval in seconds. 

219 callbacks: List of callbacks. 

220 

221 Returns: 

222 Completed job object. 

223 

224 Raises: 

225 RuntimeError: If job fails. 

226 """ 

227 if isinstance(job_obj_or_id, int): 

228 job = self.retrieve(job_obj_or_id) 

229 else: 

230 job = job_obj_or_id 

231 

232 all_callbacks = self.callbacks[:] 

233 if callbacks: 

234 all_callbacks.extend(callbacks) 

235 

236 msg = f"👀 {'MONITORING':<12} Job {job.name:<12} (ID: {job.job_id})" 

237 logger.info(msg) 

238 

239 previous_status = None 

240 

241 while True: 

242 job.refresh() 

243 

244 # Log status changes 

245 if job.status != previous_status: 

246 status_str = job.status.value if job.status else "Unknown" 

247 logger.debug(f"Job(name={job.name}, id={job.job_id}) is {status_str}") 

248 previous_status = job.status 

249 

250 match job.status: 

251 case JobStatus.COMPLETED: 

252 logger.info(job_status_msg(job)) 

253 for callback in all_callbacks: 

254 callback.on_job_completed(job) 

255 return job 

256 case JobStatus.FAILED: 

257 err_msg = job_status_msg(job) + "\n" 

258 if isinstance(job, Job): 

259 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out" 

260 if log_file.exists(): 

261 with open(log_file) as f: 

262 err_msg += f.read() 

263 err_msg += f"\nLog file: {log_file}" 

264 else: 

265 err_msg += f"Log file not found: {log_file}" 

266 for callback in all_callbacks: 

267 callback.on_job_failed(job) 

268 raise RuntimeError(err_msg) 

269 case JobStatus.CANCELLED | JobStatus.TIMEOUT: 

270 err_msg = job_status_msg(job) + "\n" 

271 if isinstance(job, Job): 

272 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out" 

273 if log_file.exists(): 

274 with open(log_file) as f: 

275 err_msg += f.read() 

276 err_msg += f"\nLog file: {log_file}" 

277 else: 

278 err_msg += f"Log file not found: {log_file}" 

279 for callback in all_callbacks: 

280 callback.on_job_cancelled(job) 

281 raise RuntimeError(err_msg) 

282 time.sleep(poll_interval) 

283 

284 def run( 

285 self, 

286 job: RunableJobType, 

287 template_path: str | None = None, 

288 callbacks: Sequence[Callback] | None = None, 

289 poll_interval: int = 5, 

290 verbose: bool = False, 

291 ) -> RunableJobType: 

292 """Submit a job and wait for completion.""" 

293 submitted_job = self.submit( 

294 job, template_path=template_path, callbacks=callbacks, verbose=verbose 

295 ) 

296 monitored_job = self.monitor( 

297 submitted_job, poll_interval=poll_interval, callbacks=callbacks 

298 ) 

299 

300 # Ensure the return type matches the expected type 

301 if isinstance(monitored_job, Job | ShellJob): 

302 return monitored_job 

303 else: 

304 # This should not happen in practice, but needed for type safety 

305 return submitted_job 

306 

307 def _get_default_template(self) -> str: 

308 """Get the default job template path.""" 

309 return str(files("srunx.templates").joinpath("base.slurm.jinja")) 

310 

311 

312# Convenience functions for backward compatibility 

313def submit_job( 

314 job: RunableJobType, 

315 template_path: str | None = None, 

316 callbacks: Sequence[Callback] | None = None, 

317 verbose: bool = False, 

318) -> RunableJobType: 

319 """Submit a job to SLURM (convenience function). 

320 

321 Args: 

322 job: Job configuration. 

323 template_path: Optional template path (uses default if not provided). 

324 callbacks: List of callbacks. 

325 verbose: Whether to print the rendered content. 

326 """ 

327 client = Slurm() 

328 return client.submit( 

329 job, template_path=template_path, callbacks=callbacks, verbose=verbose 

330 ) 

331 

332 

333def retrieve_job(job_id: int) -> BaseJob: 

334 """Get job status (convenience function). 

335 

336 Args: 

337 job_id: SLURM job ID. 

338 """ 

339 client = Slurm() 

340 return client.retrieve(job_id) 

341 

342 

343def cancel_job(job_id: int) -> None: 

344 """Cancel a job (convenience function). 

345 

346 Args: 

347 job_id: SLURM job ID. 

348 """ 

349 client = Slurm() 

350 client.cancel(job_id)
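
For reference, a minimal usage sketch of the client above (not part of the covered module). It assumes ShellJob can be constructed with name and path keyword arguments and that a submission script exists at the illustrated path; both are assumptions, not facts shown in this file.

# Hypothetical usage sketch; ShellJob(name=..., path=...) and "train.sbatch"
# are assumptions not confirmed by this module.
from srunx.client import Slurm, cancel_job
from srunx.models import ShellJob

client = Slurm()
job = ShellJob(name="train", path="train.sbatch")  # assumed constructor fields

# Submit, then block until the job reaches a terminal state (raises on failure).
finished = client.run(job, poll_interval=10)
print(finished.job_id, finished.status)

# The module-level helpers wrap a fresh Slurm() instance, e.g.:
# cancel_job(12345)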