Coverage for src/srunx/workflows/runner.py: 27%

230 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-21 03:27 +0900

1"""Workflow runner for executing YAML-defined workflows with SLURM""" 

2 

3import threading 

4import time 

5from collections import defaultdict 

6from concurrent.futures import ThreadPoolExecutor 

7from pathlib import Path 

8from typing import Any 

9 

10import yaml 

11 

12from srunx.client import Slurm 

13from srunx.logging import get_logger 

14from srunx.models import ( 

15 BaseJob, 

16 Job, 

17 JobEnvironment, 

18 JobResource, 

19 JobStatus, 

20 ShellJob, 

21 Workflow, 

22) 

23from srunx.workflows.tasks import submit_and_monitor_job 

24 

25logger = get_logger(__name__) 

26 

27 

class WorkflowRunner:
    """Runner for executing workflows defined in YAML with dynamic task scheduling.

    Tasks are executed as soon as their dependencies are satisfied,
    rather than waiting for entire dependency levels to complete.
    """

    def __init__(self) -> None:
        """Initialize workflow runner."""
        # Jobs that finished successfully, keyed by task name.
        # Populated by run() as tasks complete.
        self.executed_tasks: dict[str, Job | ShellJob] = {}
        # Shared SLURM client instance.
        # NOTE(review): not referenced inside this class as shown —
        # submit_and_monitor_job appears to handle submission; confirm whether
        # this attribute is used by callers before removing.
        self.slurm = Slurm()

    def load_from_yaml(self, yaml_path: str | Path) -> Workflow:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.

        Returns:
            Validated Workflow object.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        # safe_load: never construct arbitrary Python objects from the YAML.
        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        return self._parse_workflow_data(data)

    def _parse_workflow_data(self, data: dict) -> Workflow:
        """Parse workflow data from dictionary.

        Expects an optional top-level "name" (defaults to
        "unnamed_workflow") and a "tasks" list; each entry is parsed by
        :meth:`_parse_task_data`.
        """
        workflow_name = data.get("name", "unnamed_workflow")
        tasks_data = data.get("tasks", [])

        tasks = []
        for task_data in tasks_data:
            task = self._parse_task_data(task_data)
            tasks.append(task)

        return Workflow(name=workflow_name, tasks=tasks)

    def _parse_task_data(self, task_data: dict) -> BaseJob:
        """Parse a single task from dictionary using Pydantic model_validate.

        A task that carries "path" becomes a :class:`ShellJob`; otherwise
        "command" is required and the task becomes a full :class:`Job` with
        resource and environment sub-models.

        Args:
            task_data: One task mapping from the YAML "tasks" list.

        Returns:
            The validated ShellJob or Job.

        Raises:
            KeyError: If "name" is missing.
            ValueError: If the task has neither "command" nor "path".
        """
        # Basic task properties
        name = task_data["name"]
        path = task_data.get("path")
        depends_on = task_data.get("depends_on", [])

        job_data: dict[str, Any] = {"name": name, "depends_on": depends_on}

        job: Job | ShellJob
        if path:
            # "path" wins: a shell-script task ignores command/resources/env.
            job_data["path"] = path
            job = ShellJob.model_validate(job_data)
        else:
            command = task_data.get("command")
            if command is None:
                raise ValueError(f"Task '{name}' must have either 'command' or 'path'")

            job_data["command"] = command

            # Optional fields with defaults handled by Pydantic
            if task_data.get("log_dir") is not None:
                job_data["log_dir"] = task_data["log_dir"]
            if task_data.get("work_dir") is not None:
                job_data["work_dir"] = task_data["work_dir"]

            # Resource configuration - use model_validate for type safety
            resource_data = {
                "nodes": task_data.get("nodes", 1),
                "gpus_per_node": task_data.get("gpus_per_node", 0),
                "ntasks_per_node": task_data.get("ntasks_per_node", 1),
                "cpus_per_task": task_data.get("cpus_per_task", 1),
            }
            if task_data.get("memory_per_node") is not None:
                resource_data["memory_per_node"] = task_data["memory_per_node"]
            if task_data.get("time_limit") is not None:
                resource_data["time_limit"] = task_data["time_limit"]

            job_data["resources"] = JobResource.model_validate(resource_data)

            # Environment configuration - use model_validate for type safety
            env_data = {
                "env_vars": task_data.get("env_vars", {}),
            }
            if task_data.get("conda") is not None:
                env_data["conda"] = task_data["conda"]
            if task_data.get("venv") is not None:
                env_data["venv"] = task_data["venv"]
            # Handle 'container' as alias for 'sqsh'
            # NOTE(review): `or` also skips falsy-but-present values ("" etc.);
            # presumably intentional for these path-like fields — confirm.
            sqsh_value = task_data.get("sqsh") or task_data.get("container")
            if sqsh_value is not None:
                env_data["sqsh"] = sqsh_value

            job_data["environment"] = JobEnvironment.model_validate(env_data)

            # Create job using model_validate
            job = Job.model_validate(job_data)

        return job

    def run(self, workflow: Workflow) -> dict[str, Job | ShellJob]:
        """Run a workflow with dynamic task scheduling.

        Tasks are executed as soon as their dependencies are satisfied,
        rather than waiting for entire levels to complete.

        Args:
            workflow: Workflow to execute.

        Returns:
            Dictionary mapping task names to Job instances.

        Raises:
            RuntimeError: If any task failed or did not complete.
        """
        task_map = {task.name: task for task in workflow.tasks}

        # Track task states using type-safe JobStatus enum
        task_states = {task.name: JobStatus.PENDING for task in workflow.tasks}

        # Build reverse dependency map: task -> tasks that depend on it
        reverse_deps = defaultdict(set)
        for task in workflow.tasks:
            for dep in task.depends_on:
                reverse_deps[dep].add(task.name)

        # Results and futures tracking
        results: dict[str, Job | ShellJob] = {}
        running_futures: dict[str, Any] = {}

        # Thread-safe lock for state updates.  Guards task_states, the tasks'
        # own status fields, results, and self.executed_tasks — all of which
        # are touched from worker threads.
        state_lock = threading.Lock()

        def get_ready_tasks() -> list[str]:
            """Get all tasks that are ready to run using type-safe status checking."""
            ready_tasks = []
            for task_name, task in task_map.items():
                # can_start presumably checks PENDING + all deps COMPLETED —
                # confirm against the BaseJob implementation.
                if task.can_start(task_states):
                    ready_tasks.append(task_name)
            return ready_tasks

        def execute_task(task_name: str) -> Job | ShellJob:
            """Execute a single task and wait for completion.

            Runs in a worker thread; blocks until the SLURM job finishes
            (success is logged only after submit_and_monitor_job returns).
            """
            logger.info(f"🚀 Starting task: {task_name}")
            task = task_map[task_name]

            # Update task status to running
            with state_lock:
                task_states[task_name] = JobStatus.RUNNING
                task.update_status(JobStatus.RUNNING)

            # Type narrow the job to the expected union type
            if not isinstance(task, Job | ShellJob):
                raise TypeError(f"Unexpected job type: {type(task)}")

            try:
                job_result = submit_and_monitor_job(task)
                logger.success(f"✅ Completed task: {task_name}")

                # Update task status to completed
                with state_lock:
                    task_states[task_name] = JobStatus.COMPLETED
                    task.update_status(JobStatus.COMPLETED)

                return job_result
            except Exception as e:
                logger.error(f"❌ Task {task_name} failed: {e}")
                with state_lock:
                    task_states[task_name] = JobStatus.FAILED
                    task.update_status(JobStatus.FAILED)
                # Re-raise so the scheduler loop sees the failure via
                # future.result().
                raise

        def on_task_complete(task_name: str, result: Job | ShellJob) -> list[str]:
            """Handle task completion and schedule dependent tasks.

            Returns:
                List of newly ready task names.
            """
            with state_lock:
                # Status is already updated in execute_task
                results[task_name] = result
                self.executed_tasks[task_name] = result

                # Check if any dependent tasks are now ready using type-safe status checking
                newly_ready = []
                for dependent_task_name in reverse_deps[task_name]:
                    dependent_task = task_map[dependent_task_name]
                    if dependent_task.can_start(task_states):
                        newly_ready.append(dependent_task_name)

            logger.info(
                f"📋 Task {task_name} completed. Ready to start: {newly_ready}"
            )
            return newly_ready

        # Use ThreadPoolExecutor for parallel execution
        # NOTE(review): max_workers=8 caps concurrent tasks; consider making
        # it configurable.
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial tasks (those with no dependencies)
            initial_tasks = get_ready_tasks()
            logger.info(f"🌋 Starting initial tasks: {initial_tasks}")

            for task_name in initial_tasks:
                future = executor.submit(execute_task, task_name)
                running_futures[task_name] = future

            # Process completed tasks and schedule new ones
            # NOTE(review): 0.1s polling loop; concurrent.futures.wait(...,
            # return_when=FIRST_COMPLETED) would avoid the busy-wait — left
            # as-is to preserve exact scheduling behavior.
            while running_futures:
                # Wait for at least one task to complete
                completed_futures = []
                for task_name, future in list(running_futures.items()):
                    if future.done():
                        completed_futures.append((task_name, future))
                        del running_futures[task_name]

                if not completed_futures:
                    # Sleep briefly to avoid busy waiting
                    time.sleep(0.1)
                    continue

                # Handle completed tasks
                for task_name, future in completed_futures:
                    try:
                        result = future.result()
                        newly_ready = on_task_complete(task_name, result)

                        # Schedule newly ready tasks
                        for ready_task in newly_ready:
                            if ready_task not in running_futures:
                                new_future = executor.submit(execute_task, ready_task)
                                running_futures[ready_task] = new_future

                    except Exception as e:
                        logger.error(f"❌ Task {task_name} failed: {e}")
                        # Mark as failed to avoid infinite loop
                        with state_lock:
                            task_states[task_name] = JobStatus.FAILED
                            task_map[task_name].update_status(JobStatus.FAILED)
                        # Aborts scheduling on first failure; the executor's
                        # __exit__ still waits for in-flight futures to finish.
                        raise

        # Verify all tasks completed
        incomplete_tasks = [
            name for name, state in task_states.items() if state != JobStatus.COMPLETED
        ]
        if incomplete_tasks:
            failed_tasks = [
                name
                for name, state in task_states.items()
                if state == JobStatus.FAILED
            ]
            if failed_tasks:
                logger.error(f"❌ Tasks failed: {failed_tasks}")
                raise RuntimeError(f"Workflow execution failed: {failed_tasks}")
            else:
                logger.error(f"❌ Some tasks did not complete: {incomplete_tasks}")
                raise RuntimeError(f"Workflow execution incomplete: {incomplete_tasks}")

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, Job | ShellJob]:
        """Load and execute a workflow from YAML file.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping task names to Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        workflow = self.load_from_yaml(yaml_path)

        logger.info(
            f"Executing workflow '{workflow.name}' with {len(workflow.tasks)} tasks"
        )
        results = self.run(workflow)

        logger.success("🎉 Workflow completed successfully")
        return results

    def _build_execution_levels(self, workflow: Workflow) -> dict[int, list[str]]:
        """Build execution levels for parallel task execution.

        Tasks in the same level can be executed in parallel.
        This method is kept for backward compatibility but is no longer used
        in the new dynamic scheduling approach.

        Args:
            workflow: Workflow to analyze.

        Returns:
            Dictionary mapping level numbers to lists of task names.

        Raises:
            ValueError: If a circular dependency is detected.
        """
        task_map = {task.name: task for task in workflow.tasks}
        levels: dict[int, list[str]] = defaultdict(list)
        # Memo of each task's computed level (also serves as a "done" set).
        task_levels: dict[str, int] = {}

        # Calculate the maximum depth for each task
        def calculate_depth(task_name: str, visited: set[str]) -> int:
            # `visited` holds the current DFS path only; revisiting a task on
            # the same path means a dependency cycle.
            if task_name in visited:
                raise ValueError(
                    f"Circular dependency detected involving task '{task_name}'"
                )

            if task_name in task_levels:
                return task_levels[task_name]

            task = task_map[task_name]
            if not task.depends_on:
                # No dependencies, can execute at level 0
                task_levels[task_name] = 0
                return 0

            visited.add(task_name)
            max_dep_level = -1

            for dep in task.depends_on:
                dep_level = calculate_depth(dep, visited)
                max_dep_level = max(max_dep_level, dep_level)

            visited.remove(task_name)

            # This task executes after all its dependencies
            task_level = max_dep_level + 1
            task_levels[task_name] = task_level
            return task_level

        # Calculate levels for all tasks
        for task in workflow.tasks:
            level = calculate_depth(task.name, set())
            levels[level].append(task.name)

        return dict(levels)

363 

364 

def run_workflow_from_file(yaml_path: str | Path) -> dict[str, Job | ShellJob]:
    """Convenience function to run workflow from YAML file.

    Creates a fresh :class:`WorkflowRunner` and delegates to its
    ``execute_from_yaml`` method.

    Args:
        yaml_path: Path to YAML workflow file.

    Returns:
        Dictionary mapping task names to Job instances.
    """
    return WorkflowRunner().execute_from_yaml(yaml_path)

376 

377 

def validate_workflow_dependencies(workflow: Workflow) -> None:
    """Validate workflow task dependencies.

    Ensures every declared dependency names a task in the workflow, then
    runs a depth-first search to reject dependency cycles.

    Raises:
        ValueError: On an unknown dependency or a circular dependency.
    """
    known_names = {task.name for task in workflow.tasks}

    # Every dependency must refer to a task defined in this workflow.
    for task in workflow.tasks:
        for dependency in task.depends_on:
            if dependency in known_names:
                continue
            raise ValueError(
                f"Task '{task.name}' depends on unknown task '{dependency}'"
            )

    # Check for circular dependencies (simple check)
    seen: set[str] = set()
    on_path: set[str] = set()

    def has_cycle(task_name: str) -> bool:
        # Back-edge to a node on the current DFS path => cycle.
        if task_name in on_path:
            return True
        # Already explored from a previous root: no cycle through here.
        if task_name in seen:
            return False

        seen.add(task_name)
        on_path.add(task_name)

        node = workflow.get_task(task_name)
        if node:
            if any(has_cycle(dependency) for dependency in node.depends_on):
                return True

        on_path.remove(task_name)
        return False

    for task in workflow.tasks:
        if has_cycle(task.name):
            raise ValueError(
                f"Circular dependency detected involving task '{task.name}'"
            )

416 

417 

def show_workflow_plan(workflow: Workflow) -> None:
    """Show workflow execution plan.

    Prints a human-readable summary of the workflow: header, then one
    section per task with its command/path, resources, environment, and
    dependencies.  Purely informational — nothing is submitted.
    """
    # NOTE(review): header text says "Sequential" but the runner schedules
    # ready tasks in parallel — confirm whether the wording is intentional.
    msg = f"""\
{" PLAN ":=^80}
Workflow: {workflow.name}
Tasks: {len(workflow.tasks)}
Execution: Sequential with dependency-based scheduling
"""

    # NOTE(review): this accesses `task.job`, while WorkflowRunner treats the
    # workflow's tasks as Job/ShellJob instances directly — verify the
    # Workflow task model actually exposes a `.job` attribute.
    for task in workflow.tasks:
        msg += f" Task: {task.name}\n"
        if isinstance(task.job, Job):
            msg += f"{' Command:': <21} {' '.join(task.job.command or [])}\n"
            msg += f"{' Resources:': <21} {task.job.resources.nodes} nodes, {task.job.resources.gpus_per_node} GPUs/node\n"
            if task.job.environment.conda:
                msg += f"{' Conda env:': <21} {task.job.environment.conda}\n"
            if task.job.environment.sqsh:
                msg += f"{' Sqsh:': <21} {task.job.environment.sqsh}\n"
            if task.job.environment.venv:
                msg += f"{' Venv:': <21} {task.job.environment.venv}\n"
        elif isinstance(task.job, ShellJob):
            msg += f"{' Path:': <21} {task.job.path}\n"
        if task.depends_on:
            msg += f"{' Dependencies:': <21} {', '.join(task.depends_on)}\n"

    msg += f"{'=' * 80}\n"
    print(msg)