Coverage for src/srunx/models.py: 81%
168 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:10 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:10 +0000
1"""Data models for SLURM job management."""
3import os
4import subprocess
5import time
6from enum import Enum
7from pathlib import Path
8from typing import Self
10import jinja2
11from pydantic import BaseModel, Field, PrivateAttr, model_validator
13from srunx.exceptions import WorkflowValidationError
14from srunx.logging import get_logger
16logger = get_logger(__name__)
class JobStatus(Enum):
    """Lifecycle states used for SLURM jobs and for workflow tasks alike.

    Every member's value is the upper-case string identical to its name,
    so ``JobStatus("RUNNING")`` resolves a member from its textual form.
    """

    # Fallback used when no state could be obtained for a job.
    UNKNOWN = "UNKNOWN"

    # States of a job that has not finished yet.
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # Terminal states: ``BaseJob.status`` stops refreshing once one is reached.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"
class JobResource(BaseModel):
    """Compute resources requested for a SLURM job.

    The integer counts carry a lower bound that is enforced on
    validation; memory and time limit are free-form strings that stay
    ``None`` when they are not given.
    """

    nodes: int = Field(
        default=1,
        ge=1,
        description="Number of compute nodes",
    )
    gpus_per_node: int = Field(
        default=0,
        ge=0,
        description="Number of GPUs per node",
    )
    ntasks_per_node: int = Field(
        default=1,
        ge=1,
        description="Number of tasks per node",
    )
    cpus_per_task: int = Field(
        default=1,
        ge=1,
        description="Number of CPUs per task",
    )
    memory_per_node: str | None = Field(
        default=None,
        description="Memory per node (e.g., '32GB')",
    )
    time_limit: str | None = Field(
        default=None,
        description="Time limit (e.g., '1:00:00')",
    )
class JobEnvironment(BaseModel):
    """How the runtime environment of a job is prepared.

    Exactly one of ``conda``, ``venv`` and ``sqsh`` has to be provided;
    ``env_vars`` may be supplied in addition to it.
    """

    conda: str | None = Field(
        default=None,
        description="Conda environment name",
    )
    venv: str | None = Field(
        default=None,
        description="Virtual environment path",
    )
    sqsh: str | None = Field(
        default=None,
        description="SquashFS image path",
    )
    env_vars: dict[str, str] = Field(
        default_factory=dict,
        description="Environment variables",
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Accept the model only when a single environment backend is chosen.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: If none, or more than one, of ``conda``, ``venv``
                and ``sqsh`` is not ``None``.
        """
        # Only ``None`` counts as "not chosen"; an empty string still counts.
        chosen = [
            backend
            for backend in (self.conda, self.venv, self.sqsh)
            if backend is not None
        ]
        if len(chosen) == 1:
            return self
        raise ValueError("Exactly one of 'conda', 'venv', or 'sqsh' must be set")
class BaseJob(BaseModel):
    """Fields and status tracking shared by every job type.

    The last known status is cached in the private ``_status`` attribute
    (initially ``PENDING``) and is updated by :meth:`refresh`, which
    queries ``sacct`` for the job's current state.
    """

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """Return the job status, refreshing it first when that is useful.

        A refresh is only performed when the job has a ``job_id`` and the
        cached status is not terminal (COMPLETED, FAILED, CANCELLED or
        TIMEOUT); otherwise the cached value is returned as is.
        """
        terminal_states = {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }
        if self.job_id is not None and self._status not in terminal_states:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        """Overwrite the cached status without contacting SLURM."""
        self._status = value

    @staticmethod
    def _parse_sacct_state(raw_state: str) -> JobStatus:
        """Translate the ``State`` column reported by ``sacct``.

        Args:
            raw_state: Text of the State column, possibly decorated.

        Returns:
            The matching ``JobStatus``; ``UNKNOWN`` when the text is empty
            or names a state this module has no member for.
        """
        # sacct may decorate the state, e.g. "CANCELLED by 1000" or a
        # trailing "+", so only the first word without the marker counts.
        words = raw_state.split()
        if not words:
            return JobStatus.UNKNOWN
        token = words[0].rstrip("+").upper()

        # SLURM states that mean the job ended unsuccessfully but have no
        # dedicated JobStatus member.
        if token in {"BOOT_FAIL", "DEADLINE", "NODE_FAIL", "OUT_OF_MEMORY"}:
            return JobStatus.FAILED

        try:
            return JobStatus(token)
        except ValueError:
            logger.warning(f"Unrecognized SLURM job state '{raw_state}'")
            return JobStatus.UNKNOWN

    def refresh(self, retries: int = 3) -> Self:
        """Query sacct and update ``_status`` in-place.

        Args:
            retries: Maximum number of ``sacct`` invocations when it
                prints nothing; one second is slept between attempts.
                Values below 1 are treated as 1.

        Returns:
            This job, to allow call chaining. Without a ``job_id`` nothing
            is queried and the status is left untouched.

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits with a
                non-zero status (logged before being re-raised).
        """
        if self.job_id is None:
            return self

        command = [
            "sacct",
            "-j",
            str(self.job_id),
            "--format",
            "JobID,State",
            "--noheader",
            "--parsable2",
        ]

        # Always query at least once so the result below is well defined
        # even for ``retries <= 0``.
        attempts = max(retries, 1)
        first_line = ""
        for attempt in range(attempts):
            try:
                result = subprocess.run(
                    command,
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            output = result.stdout.strip()
            if output:
                # Only the first record is inspected.
                first_line = output.split("\n")[0]
                break
            if attempt < attempts - 1:
                time.sleep(1)

        if not first_line:
            self._status = JobStatus.UNKNOWN
            return self

        # --parsable2 separates columns with "|"; a record without the
        # separator yields an empty state and therefore UNKNOWN.
        _, _, state = first_line.partition("|")
        self._status = self._parse_sacct_state(state)
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """All dependencies are completed & this job is still pending.

        Args:
            completed_job_names: Names of the jobs that have completed.

        Returns:
            ``True`` when the (possibly refreshed) status is ``PENDING``
            and every name in ``depends_on`` is in ``completed_job_names``.
        """
        if self.status != JobStatus.PENDING:
            return False
        completed = set(completed_job_names)
        return all(dep in completed for dep in self.depends_on)
class Job(BaseJob):
    """Represents a SLURM job with complete configuration."""

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        default_factory=JobResource, description="Resource requirements"
    )
    # NOTE(review): ``JobEnvironment()`` without arguments is rejected by
    # JobEnvironment.validate_environment (exactly one of conda/venv/sqsh
    # must be set), so this default looks unusable and callers presumably
    # always pass ``environment`` explicitly -- confirm.
    environment: JobEnvironment = Field(
        default_factory=JobEnvironment, description="Environment setup"
    )
    # Resolved when a Job is created (like ``work_dir``), not once at import
    # time, so a SLURM_LOG_DIR set after this module was imported is honoured.
    log_dir: str = Field(
        default_factory=lambda: os.getenv("SLURM_LOG_DIR", "logs"),
        description="Directory for log files",
    )
    work_dir: str = Field(default_factory=os.getcwd, description="Working directory")
class ShellJob(BaseJob):
    """A job described by the path of a shell script."""

    path: str = Field(
        description="Shell script path to execute",
    )
161type JobType = BaseJob | Job | ShellJob
162type RunableJobType = Job | ShellJob
class Workflow(BaseModel):
    """Represents a workflow containing multiple jobs with dependencies."""

    name: str = Field(description="Workflow name")
    jobs: list[RunableJobType] = Field(description="List of jobs in the workflow")

    def get(self, name: str) -> RunableJobType | None:
        """Get a job by name.

        The first job whose name matches is refreshed (see
        ``BaseJob.refresh``) and returned; ``None`` is returned when no
        job has that name.
        """
        for job in self.jobs:
            if job.name == name:
                return job.refresh()
        return None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get dependencies for a specific job.

        Returns the ``depends_on`` list of the first job called
        ``job_name``, or an empty list when there is no such job. Only
        static configuration is read, so the job status is not refreshed.
        """
        for job in self.jobs:
            if job.name == job_name:
                return job.depends_on
        return []

    def show(self) -> None:
        """Print a human-readable plan of the workflow to stdout."""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        parts = [
            f"{' PLAN ':=^80}\n",
            f"Workflow: {self.name}\n",
            f"Jobs: {len(self.jobs)}\n",
        ]

        for job in self.jobs:
            parts.append(add_indent(1, f"Job: {job.name}\n"))
            if isinstance(job, Job):
                parts.append(
                    add_indent(2, f"{'Command:': <13} {' '.join(job.command or [])}\n")
                )
                parts.append(
                    add_indent(
                        2,
                        f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                    )
                )
                if job.environment.conda:
                    parts.append(
                        add_indent(2, f"{'Conda env:': <13} {job.environment.conda}\n")
                    )
                if job.environment.sqsh:
                    parts.append(
                        add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                    )
                if job.environment.venv:
                    parts.append(
                        add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
                    )
            elif isinstance(job, ShellJob):
                parts.append(add_indent(2, f"{'Path:': <13} {job.path}\n"))
            if job.depends_on:
                parts.append(
                    add_indent(
                        2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                    )
                )

        parts.append(f"{'=' * 80}\n")
        print("".join(parts))

    # NOTE(review): pydantic's BaseModel also provides a (deprecated)
    # ``validate`` classmethod that this instance method shadows -- confirm
    # that this is intended.
    def validate(self) -> None:
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: If two jobs share a name, a job
                depends on a name that is not part of the workflow, or
                the dependencies form a cycle.
        """
        # Plain name -> job mapping: validation only needs the static
        # configuration, so it must not go through ``get()``, which
        # refreshes the job status on every lookup.
        jobs_by_name = {job.name: job for job in self.jobs}

        if len(jobs_by_name) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in jobs_by_name:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Depth-first search: ``rec_stack`` holds the jobs on the path that
        # is currently being explored, ``visited`` every job seen so far.
        visited: set[str] = set()
        rec_stack: set[str] = set()

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            for dependency in jobs_by_name[job_name].depends_on:
                if has_cycle(dependency):
                    return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )
def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render a SLURM job script from a template.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
            It is created (including parents) when it does not exist.
        verbose: Whether to print the rendered content.

    Returns:
        Path to the generated SLURM batch script.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    # StrictUndefined turns a template variable that is not supplied below
    # into a rendering error instead of an empty string.
    template = jinja2.Template(
        template_file.read_text(encoding="utf-8"),
        undefined=jinja2.StrictUndefined,
    )

    # Prepare template variables
    template_vars = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    rendered_content = template.render(template_vars)

    if verbose:
        print(rendered_content)

    # Generate output file; make sure the target directory exists so that
    # writing does not fail on a missing directory.
    output_directory = Path(output_dir)
    output_directory.mkdir(parents=True, exist_ok=True)
    output_path = output_directory / f"{job.name}.slurm"
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)
def _build_environment_setup(environment: JobEnvironment) -> str:
    """Return the shell snippet that prepares a job's environment.

    One ``export`` line per entry of ``env_vars`` comes first, followed by
    the activation commands of the first backend that is set, checked in
    the order conda, venv, sqsh. The lines are joined with newlines and
    there is no trailing newline.
    """
    exports = [
        f"export {name}={val}" for name, val in environment.env_vars.items()
    ]

    if environment.conda:
        activation = [
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        activation = [f"source {environment.venv}/bin/activate"]
    elif environment.sqsh:
        # Keeps an IMAGE value that is already set; otherwise uses the sqsh path.
        activation = [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]
    else:
        activation = []

    return "\n".join([*exports, *activation])