Coverage for src/srunx/client.py: 87%
155 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:10 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:10 +0000
1"""SLURM client for job submission and management."""
3import subprocess
4import tempfile
5import time
6from collections.abc import Sequence
7from importlib.resources import files
8from pathlib import Path
10from srunx.callbacks import Callback
11from srunx.logging import get_logger
12from srunx.models import (
13 BaseJob,
14 Job,
15 JobStatus,
16 JobType,
17 RunableJobType,
18 ShellJob,
19 render_job_script,
20)
21from srunx.utils import get_job_status, job_status_msg
23logger = get_logger(__name__)
26class Slurm:
27 """Client for interacting with SLURM workload manager."""
29 def __init__(
30 self,
31 default_template: str | None = None,
32 callbacks: Sequence[Callback] | None = None,
33 ):
34 """Initialize SLURM client.
36 Args:
37 default_template: Path to default job template.
38 callbacks: List of callbacks.
39 """
40 self.default_template = default_template or self._get_default_template()
41 self.callbacks = list(callbacks) if callbacks else []
43 def submit(
44 self,
45 job: RunableJobType,
46 template_path: str | None = None,
47 callbacks: Sequence[Callback] | None = None,
48 verbose: bool = False,
49 ) -> RunableJobType:
50 """Submit a job to SLURM.
52 Args:
53 job: Job configuration.
54 template_path: Optional template path (uses default if not provided).
55 callbacks: List of callbacks.
56 verbose: Whether to print the rendered content.
58 Returns:
59 Job instance with updated job_id and status.
61 Raises:
62 subprocess.CalledProcessError: If job submission fails.
63 """
65 if isinstance(job, Job):
66 template = template_path or self.default_template
68 with tempfile.TemporaryDirectory() as temp_dir:
69 script_path = render_job_script(template, job, temp_dir, verbose)
70 logger.debug(f"Generated SLURM script at: {script_path}")
72 # Handle container execution
73 sbatch_cmd = ["sbatch"]
74 if job.environment.sqsh:
75 sbatch_cmd.extend(["--sqsh", job.environment.sqsh])
76 logger.debug(f"Using sqsh: {job.environment.sqsh}")
78 sbatch_cmd.append(script_path)
79 logger.debug(f"Executing command: {' '.join(sbatch_cmd)}")
81 try:
82 result = subprocess.run(
83 sbatch_cmd,
84 capture_output=True,
85 text=True,
86 check=True,
87 )
88 except subprocess.CalledProcessError as e:
89 logger.error(f"Failed to submit job '{job.name}': {e}")
90 logger.error(f"Command: {' '.join(e.cmd)}")
91 logger.error(f"Return code: {e.returncode}")
92 logger.error(f"Stdout: {e.stdout}")
93 logger.error(f"Stderr: {e.stderr}")
94 raise
96 elif isinstance(job, ShellJob):
97 try:
98 result = subprocess.run(
99 ["sbatch", job.path],
100 capture_output=True,
101 text=True,
102 check=True,
103 )
104 except subprocess.CalledProcessError as e:
105 logger.error(f"Failed to submit job '{job.name}': {e}")
106 logger.error(f"Command: {' '.join(e.cmd)}")
107 logger.error(f"Return code: {e.returncode}")
108 logger.error(f"Stdout: {e.stdout}")
109 logger.error(f"Stderr: {e.stderr}")
110 raise
112 else:
113 raise ValueError("Either 'command' or 'file' must be set")
115 time.sleep(3)
116 job_id = int(result.stdout.split()[-1])
117 job.job_id = job_id
118 job.status = JobStatus.PENDING
120 logger.debug(f"Successfully submitted job '{job.name}' with ID {job_id}")
122 all_callbacks = self.callbacks[:]
123 if callbacks:
124 all_callbacks.extend(callbacks)
125 for callback in all_callbacks:
126 callback.on_job_submitted(job)
128 return job
130 @staticmethod
131 def retrieve(job_id: int) -> BaseJob:
132 """Retrieve job information from SLURM.
134 Args:
135 job_id: SLURM job ID.
137 Returns:
138 Job object with current status.
139 """
140 return get_job_status(job_id)
142 def cancel(self, job_id: int) -> None:
143 """Cancel a SLURM job.
145 Args:
146 job_id: SLURM job ID to cancel.
148 Raises:
149 subprocess.CalledProcessError: If job cancellation fails.
150 """
151 logger.info(f"Cancelling job {job_id}")
153 try:
154 subprocess.run(
155 ["scancel", str(job_id)],
156 check=True,
157 )
158 logger.info(f"Successfully cancelled job {job_id}")
159 except subprocess.CalledProcessError as e:
160 logger.error(f"Failed to cancel job {job_id}: {e}")
161 raise
163 def queue(self, user: str | None = None) -> list[BaseJob]:
164 """List jobs for a user.
166 Args:
167 user: Username (defaults to current user).
169 Returns:
170 List of Job objects.
171 """
172 cmd = [
173 "squeue",
174 "--format",
175 "%.18i %.9P %.15j %.8u %.8T %.10M %.9l %.6D %R",
176 "--noheader",
177 ]
178 if user:
179 cmd.extend(["--user", user])
181 result = subprocess.run(cmd, capture_output=True, text=True, check=True)
183 jobs = []
184 for line in result.stdout.strip().split("\n"):
185 if not line.strip():
186 continue
188 parts = line.split()
189 if len(parts) >= 5:
190 job_id = int(parts[0])
191 job_name = parts[2]
192 status_str = parts[4]
194 try:
195 status = JobStatus(status_str)
196 except ValueError:
197 status = JobStatus.PENDING # Default for unknown status
199 job = BaseJob(
200 name=job_name,
201 job_id=job_id,
202 )
203 job.status = status
204 jobs.append(job)
206 return jobs
208 def monitor(
209 self,
210 job_obj_or_id: JobType | int,
211 poll_interval: int = 5,
212 callbacks: Sequence[Callback] | None = None,
213 ) -> JobType:
214 """Wait for a job to complete.
216 Args:
217 job_obj_or_id: Job object or job ID.
218 poll_interval: Polling interval in seconds.
219 callbacks: List of callbacks.
221 Returns:
222 Completed job object.
224 Raises:
225 RuntimeError: If job fails.
226 """
227 if isinstance(job_obj_or_id, int):
228 job = self.retrieve(job_obj_or_id)
229 else:
230 job = job_obj_or_id
232 all_callbacks = self.callbacks[:]
233 if callbacks:
234 all_callbacks.extend(callbacks)
236 msg = f"👀 {'MONITORING':<12} Job {job.name:<12} (ID: {job.job_id})"
237 logger.info(msg)
239 previous_status = None
241 while True:
242 job.refresh()
244 # Log status changes
245 if job.status != previous_status:
246 status_str = job.status.value if job.status else "Unknown"
247 logger.debug(f"Job(name={job.name}, id={job.job_id}) is {status_str}")
248 previous_status = job.status
250 match job.status:
251 case JobStatus.COMPLETED:
252 logger.info(job_status_msg(job))
253 for callback in all_callbacks:
254 callback.on_job_completed(job)
255 return job
256 case JobStatus.FAILED:
257 err_msg = job_status_msg(job) + "\n"
258 if isinstance(job, Job):
259 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
260 if log_file.exists():
261 with open(log_file) as f:
262 err_msg += f.read()
263 err_msg += f"\nLog file: {log_file}"
264 else:
265 err_msg += f"Log file not found: {log_file}"
266 for callback in all_callbacks:
267 callback.on_job_failed(job)
268 raise RuntimeError(err_msg)
269 case JobStatus.CANCELLED | JobStatus.TIMEOUT:
270 err_msg = job_status_msg(job) + "\n"
271 if isinstance(job, Job):
272 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
273 if log_file.exists():
274 with open(log_file) as f:
275 err_msg += f.read()
276 err_msg += f"\nLog file: {log_file}"
277 else:
278 err_msg += f"Log file not found: {log_file}"
279 for callback in all_callbacks:
280 callback.on_job_cancelled(job)
281 raise RuntimeError(err_msg)
282 time.sleep(poll_interval)
284 def run(
285 self,
286 job: RunableJobType,
287 template_path: str | None = None,
288 callbacks: Sequence[Callback] | None = None,
289 poll_interval: int = 5,
290 verbose: bool = False,
291 ) -> RunableJobType:
292 """Submit a job and wait for completion."""
293 submitted_job = self.submit(
294 job, template_path=template_path, callbacks=callbacks, verbose=verbose
295 )
296 monitored_job = self.monitor(
297 submitted_job, poll_interval=poll_interval, callbacks=callbacks
298 )
300 # Ensure the return type matches the expected type
301 if isinstance(monitored_job, Job | ShellJob):
302 return monitored_job
303 else:
304 # This should not happen in practice, but needed for type safety
305 return submitted_job
307 def _get_default_template(self) -> str:
308 """Get the default job template path."""
309 return str(files("srunx.templates").joinpath("base.slurm.jinja"))
312# Convenience functions for backward compatibility
def submit_job(
    job: RunableJobType,
    template_path: str | None = None,
    callbacks: Sequence[Callback] | None = None,
    verbose: bool = False,
) -> RunableJobType:
    """Submit a job through a freshly created ``Slurm`` client.

    Args:
        job: Job configuration.
        template_path: Optional template path (uses default if not provided).
        callbacks: List of callbacks.
        verbose: Whether to print the rendered content.

    Returns:
        Whatever ``Slurm.submit`` returns for the given job.
    """
    return Slurm().submit(
        job,
        template_path=template_path,
        callbacks=callbacks,
        verbose=verbose,
    )
def retrieve_job(job_id: int) -> BaseJob:
    """Look up a job through a freshly created ``Slurm`` client.

    Args:
        job_id: SLURM job ID.

    Returns:
        Whatever ``Slurm.retrieve`` returns for the given ID.
    """
    return Slurm().retrieve(job_id)
def cancel_job(job_id: int) -> None:
    """Cancel a job through a freshly created ``Slurm`` client.

    Args:
        job_id: SLURM job ID.
    """
    Slurm().cancel(job_id)