Coverage for src/srunx/models.py: 81%

168 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-22 21:22 +0900

1"""Data models for SLURM job management.""" 

2 

3import os 

4import subprocess 

5import time 

6from enum import Enum 

7from pathlib import Path 

8from typing import Self 

9 

10import jinja2 

11from pydantic import BaseModel, Field, PrivateAttr, model_validator 

12 

13from srunx.exceptions import WorkflowValidationError 

14from srunx.logging import get_logger 

15 

# Module logger obtained from the project's logging helper; used below to
# report failed ``sacct`` queries in ``BaseJob.refresh``.
logger = get_logger(__name__)

17 

18 

class JobStatus(Enum):
    """Lifecycle states shared by SLURM jobs and workflow tasks.

    Member values are the literal state strings, so ``JobStatus("RUNNING")``
    resolves a state name to its member.
    """

    # Not (yet) determinable.
    UNKNOWN = "UNKNOWN"

    # Active states.
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # Terminal states.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"

29 

30 

class JobResource(BaseModel):
    """Resources requested from SLURM for one job.

    The integer counts carry lower bounds enforced by pydantic; the memory and
    time-limit values are free-form strings that are not validated here.
    """

    # Integer counts (bounds enforced on construction).
    nodes: int = Field(1, ge=1, description="Number of compute nodes")
    gpus_per_node: int = Field(0, ge=0, description="Number of GPUs per node")
    ntasks_per_node: int = Field(1, ge=1, description="Number of tasks per node")
    cpus_per_task: int = Field(1, ge=1, description="Number of CPUs per task")

    # Optional free-form strings; None when not given.
    memory_per_node: str | None = Field(
        None, description="Memory per node (e.g., '32GB')"
    )
    time_limit: str | None = Field(
        None, description="Time limit (e.g., '1:00:00')"
    )

46 

47 

class JobEnvironment(BaseModel):
    """Job environment configuration.

    At most one of ``conda``, ``venv`` or ``sqsh`` selects how the job's
    software environment is activated. All three may be left unset, in which
    case no activation is performed. ``env_vars`` is independent of that
    choice.
    """

    conda: str | None = Field(default=None, description="Conda environment name")
    venv: str | None = Field(default=None, description="Virtual environment path")
    sqsh: str | None = Field(default=None, description="SquashFS image path")
    env_vars: dict[str, str] = Field(
        default_factory=dict, description="Environment variables"
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject configurations that select more than one environment kind.

        An environment with none of the three set must stay valid. ``Job``
        declares ``default_factory=JobEnvironment``, so the bare
        ``JobEnvironment()`` has to validate for ``Job`` to be constructible
        without an explicit environment.

        Raises:
            ValueError: If two or more of ``conda``/``venv``/``sqsh`` are set.
        """
        selected = [
            kind
            for kind, value in (
                ("conda", self.conda),
                ("venv", self.venv),
                ("sqsh", self.sqsh),
            )
            if value is not None
        ]
        if len(selected) > 1:
            raise ValueError(
                "At most one of 'conda', 'venv', or 'sqsh' may be set "
                f"(got: {', '.join(selected)})"
            )
        return self

65 

66 

class BaseJob(BaseModel):
    """Fields and status tracking shared by every job kind.

    ``status`` is backed by the private ``_status`` attribute. While the job
    has a ``job_id`` and is not in a terminal state, reading ``status``
    re-queries SLURM's ``sacct``.
    """

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    # Last status observed via ``refresh`` or assigned via the setter.
    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """
        Accessing ``job.status`` always triggers a lightweight refresh
        (only if we have a ``job_id`` and the status isn't terminal).
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        self._status = value

    def refresh(self, retries: int = 3) -> Self:
        """Query ``sacct`` and update ``_status`` in place.

        Args:
            retries: How many times to run ``sacct`` while it returns no rows
                for the job (a just-submitted job may not be listed yet).
                Values below 1 are treated as 1.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits non-zero.
        """
        if self.job_id is None:
            return self

        # Guard against retries <= 0, which previously skipped the loop and
        # left ``line`` unbound.
        attempts = max(retries, 1)
        line = ""
        for attempt in range(attempts):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            output = result.stdout.strip()
            line = output.split("\n")[0] if output else ""
            if line:
                break
            if attempt < attempts - 1:
                # No rows yet: give sacct a moment before asking again.
                time.sleep(1)

        if not line:
            self._status = JobStatus.UNKNOWN
            return self

        _, state = line.split("|", 1)
        self._status = self._parse_state(state)
        return self

    @staticmethod
    def _parse_state(state: str) -> JobStatus:
        """Map a raw sacct ``State`` value onto a ``JobStatus`` member.

        sacct may decorate a state (e.g. ``"CANCELLED by 1234"``), so only the
        first word is used. Terminal failure codes that ``JobStatus`` does not
        model (boot/node failure, deadline, out-of-memory) become ``FAILED``.
        Anything else unrecognised becomes ``UNKNOWN`` instead of raising.
        """
        words = state.split()
        token = words[0] if words else ""
        try:
            return JobStatus(token)
        except ValueError:
            pass
        if token in {"BOOT_FAIL", "DEADLINE", "NODE_FAIL", "OUT_OF_MEMORY"}:
            return JobStatus.FAILED
        logger.warning(f"Unrecognized SLURM job state {state!r}; treating as UNKNOWN")
        return JobStatus.UNKNOWN

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """Return True if this job is still pending and all its dependencies completed.

        Args:
            completed_job_names: Names of jobs that have already completed.

        Note:
            Reading ``status`` may run ``sacct`` (see ``status``).
        """
        return self.status == JobStatus.PENDING and set(self.depends_on).issubset(
            completed_job_names
        )

138 

139 

class Job(BaseJob):
    """Represents a SLURM job with complete configuration."""

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        default_factory=JobResource, description="Resource requirements"
    )
    environment: JobEnvironment = Field(
        default_factory=JobEnvironment, description="Environment setup"
    )
    # Resolved per instance, like ``work_dir``. A plain ``default=os.getenv(...)``
    # is evaluated once at import time, so SLURM_LOG_DIR set later was ignored.
    log_dir: str = Field(
        default_factory=lambda: os.getenv("SLURM_LOG_DIR", "logs"),
        description="Directory for log files",
    )
    work_dir: str = Field(default_factory=os.getcwd, description="Working directory")

155 

156 

class ShellJob(BaseJob):
    """A job defined by the path of an existing shell script to execute."""

    # Required: Field(...) marks the field as having no default.
    path: str = Field(..., description="Shell script path to execute")

159 

160 

# Any job model, including the bare ``BaseJob``.
type JobType = BaseJob | Job | ShellJob
# Job kinds that carry something to execute (a command or a script path);
# this is the element type of ``Workflow.jobs``.
type RunableJobType = Job | ShellJob

163 

164 

class Workflow(BaseModel):
    """Represents a workflow containing multiple jobs with dependencies."""

    name: str = Field(description="Workflow name")
    jobs: list[RunableJobType] = Field(description="List of jobs in the workflow")

    def get(self, name: str) -> RunableJobType | None:
        """Get a job by name, with its status refreshed; None if absent.

        The refresh runs ``sacct`` when the job has a ``job_id``.
        """
        job = self._find(name)
        return job.refresh() if job is not None else None

    def _find(self, name: str) -> RunableJobType | None:
        """Return the first job called ``name`` without querying SLURM."""
        return next((job for job in self.jobs if job.name == name), None)

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get dependencies for a specific job (empty if the job is unknown)."""
        # Dependencies are static config, so skip the sacct refresh ``get`` does.
        job = self._find(job_name)
        return job.depends_on if job is not None else []

    def show(self) -> None:
        """Print a plan: workflow name, job count, and each job's settings."""
        msg = f"""\
{" PLAN ":=^80}
Workflow: {self.name}
Jobs: {len(self.jobs)}
"""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        for job in self.jobs:
            msg += add_indent(1, f"Job: {job.name}\n")
            if isinstance(job, Job):
                msg += add_indent(
                    2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                )
                msg += add_indent(
                    2,
                    f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                )
                if job.environment.conda:
                    msg += add_indent(
                        2, f"{'Conda env:': <13} {job.environment.conda}\n"
                    )
                if job.environment.sqsh:
                    msg += add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                if job.environment.venv:
                    msg += add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
            elif isinstance(job, ShellJob):
                msg += add_indent(2, f"{'Path:': <13} {job.path}\n")
            if job.depends_on:
                msg += add_indent(
                    2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                )

        msg += f"{'=' * 80}\n"
        print(msg)

    def validate(self) -> None:
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: If job names are duplicated, a job depends
                on an unknown job, or the dependency graph contains a cycle.
        """
        # Local lookup table: validation must not go through ``get``, which
        # would spawn a sacct subprocess per visited job.
        jobs_by_name = {job.name: job for job in self.jobs}

        if len(jobs_by_name) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in jobs_by_name:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Depth-first search; a node met again while still on the current
        # path (``rec_stack``) closes a cycle.
        visited: set[str] = set()
        rec_stack: set[str] = set()

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            # Every dependency name was checked against jobs_by_name above.
            for dependency in jobs_by_name[job_name].depends_on:
                if has_cycle(dependency):
                    return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )

261 ) 

262 

263 

def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render ``job`` into a SLURM batch script from a Jinja template.

    The template sees ``job_name``, ``command`` (``job.command`` joined with
    spaces), ``log_dir``, ``work_dir``, ``environment_setup``, and every
    field of ``job.resources``. Undefined template variables are errors.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory the script is written into as ``<job.name>.slurm``.
        verbose: If true, print the rendered script.

    Returns:
        Path to the generated SLURM batch script, as a string.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    source = Path(template_path)
    if not source.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    template = jinja2.Template(
        source.read_text(encoding="utf-8"),
        undefined=jinja2.StrictUndefined,
    )

    context = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
    }
    # Resource fields are merged last, so they win on any key clash.
    context.update(job.resources.model_dump())

    rendered = template.render(context)
    if verbose:
        print(rendered)

    script_path = Path(output_dir) / f"{job.name}.slurm"
    script_path.write_text(rendered, encoding="utf-8")
    return str(script_path)

315 

316 

def _build_environment_setup(environment: JobEnvironment) -> str:
    """Return the shell snippet that prepares a job's environment.

    First comes one ``export`` line per entry of ``environment.env_vars``.
    Then come the activation lines for the first of conda / venv / sqsh that
    is set, checked in that order. Lines are newline-joined, without a
    trailing newline.
    """
    exports = [f"export {name}={val}" for name, val in environment.env_vars.items()]

    if environment.conda:
        activation = ["conda deactivate", f"conda activate {environment.conda}"]
    elif environment.venv:
        activation = [f"source {environment.venv}/bin/activate"]
    elif environment.sqsh:
        # Default IMAGE to the sqsh path unless already set, then expose it
        # through a CONTAINER_ARGS bash array.
        activation = [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]
    else:
        activation = []

    return "\n".join([*exports, *activation])