Coverage for src/srunx/workflows/runner.py: 27%
230 statements
coverage.py v7.9.1, created at 2025-06-21 03:27 +0900
1"""Workflow runner for executing YAML-defined workflows with SLURM"""
3import threading
4import time
5from collections import defaultdict
6from concurrent.futures import ThreadPoolExecutor
7from pathlib import Path
8from typing import Any
10import yaml
12from srunx.client import Slurm
13from srunx.logging import get_logger
14from srunx.models import (
15 BaseJob,
16 Job,
17 JobEnvironment,
18 JobResource,
19 JobStatus,
20 ShellJob,
21 Workflow,
22)
23from srunx.workflows.tasks import submit_and_monitor_job
25logger = get_logger(__name__)
28class WorkflowRunner:
29 """Runner for executing workflows defined in YAML with dynamic task scheduling.
31 Tasks are executed as soon as their dependencies are satisfied,
32 rather than waiting for entire dependency levels to complete.
33 """
35 def __init__(self) -> None:
36 """Initialize workflow runner."""
37 self.executed_tasks: dict[str, Job | ShellJob] = {}
38 self.slurm = Slurm()
40 def load_from_yaml(self, yaml_path: str | Path) -> Workflow:
41 """Load and validate a workflow from a YAML file.
43 Args:
44 yaml_path: Path to the YAML workflow definition file.
46 Returns:
47 Validated Workflow object.
49 Raises:
50 FileNotFoundError: If the YAML file doesn't exist.
51 yaml.YAMLError: If the YAML is malformed.
52 ValidationError: If the workflow structure is invalid.
53 """
54 yaml_file = Path(yaml_path)
55 if not yaml_file.exists():
56 raise FileNotFoundError(f"Workflow file not found: {yaml_path}")
58 with open(yaml_file, encoding="utf-8") as f:
59 data = yaml.safe_load(f)
61 return self._parse_workflow_data(data)
63 def _parse_workflow_data(self, data: dict) -> Workflow:
64 """Parse workflow data from dictionary."""
65 workflow_name = data.get("name", "unnamed_workflow")
66 tasks_data = data.get("tasks", [])
68 tasks = []
69 for task_data in tasks_data:
70 task = self._parse_task_data(task_data)
71 tasks.append(task)
73 return Workflow(name=workflow_name, tasks=tasks)
75 def _parse_task_data(self, task_data: dict) -> BaseJob:
76 """Parse a single task from dictionary using Pydantic model_validate."""
77 # Basic task properties
78 name = task_data["name"]
79 path = task_data.get("path")
80 depends_on = task_data.get("depends_on", [])
82 job_data: dict[str, Any] = {"name": name, "depends_on": depends_on}
84 job: Job | ShellJob
85 if path:
86 job_data["path"] = path
87 job = ShellJob.model_validate(job_data)
88 else:
89 command = task_data.get("command")
90 if command is None:
91 raise ValueError(f"Task '{name}' must have either 'command' or 'path'")
93 job_data["command"] = command
95 # Optional fields with defaults handled by Pydantic
96 if task_data.get("log_dir") is not None:
97 job_data["log_dir"] = task_data["log_dir"]
98 if task_data.get("work_dir") is not None:
99 job_data["work_dir"] = task_data["work_dir"]
101 # Resource configuration - use model_validate for type safety
102 resource_data = {
103 "nodes": task_data.get("nodes", 1),
104 "gpus_per_node": task_data.get("gpus_per_node", 0),
105 "ntasks_per_node": task_data.get("ntasks_per_node", 1),
106 "cpus_per_task": task_data.get("cpus_per_task", 1),
107 }
108 if task_data.get("memory_per_node") is not None:
109 resource_data["memory_per_node"] = task_data["memory_per_node"]
110 if task_data.get("time_limit") is not None:
111 resource_data["time_limit"] = task_data["time_limit"]
113 job_data["resources"] = JobResource.model_validate(resource_data)
115 # Environment configuration - use model_validate for type safety
116 env_data = {
117 "env_vars": task_data.get("env_vars", {}),
118 }
119 if task_data.get("conda") is not None:
120 env_data["conda"] = task_data["conda"]
121 if task_data.get("venv") is not None:
122 env_data["venv"] = task_data["venv"]
123 # Handle 'container' as alias for 'sqsh'
124 sqsh_value = task_data.get("sqsh") or task_data.get("container")
125 if sqsh_value is not None:
126 env_data["sqsh"] = sqsh_value
128 job_data["environment"] = JobEnvironment.model_validate(env_data)
130 # Create job using model_validate
131 job = Job.model_validate(job_data)
133 return job
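
    # A minimal sketch of a task entry this parser accepts, inferred from the
    # keys read above; the task name, command, and field values are
    # illustrative assumptions, not taken from the srunx documentation.
    #
    #   name: preprocess
    #   command: ["python", "preprocess.py"]
    #   depends_on: []
    #   nodes: 1
    #   gpus_per_node: 0
    #   memory_per_node: "32G"
    #   time_limit: "01:00:00"
    #   conda: my-env
    #   env_vars:
    #     DATA_DIR: /data
    #
    # A script-based task would instead set `path: scripts/preprocess.sh`,
    # which routes it to ShellJob rather than Job.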
    def run(self, workflow: Workflow) -> dict[str, Job | ShellJob]:
        """Run a workflow with dynamic task scheduling.

        Tasks are executed as soon as their dependencies are satisfied,
        rather than waiting for entire levels to complete.

        Args:
            workflow: Workflow to execute.

        Returns:
            Dictionary mapping task names to Job instances.
        """
        task_map = {task.name: task for task in workflow.tasks}

        # Track task states using type-safe JobStatus enum
        task_states = {task.name: JobStatus.PENDING for task in workflow.tasks}

        # Build reverse dependency map: task -> tasks that depend on it
        reverse_deps = defaultdict(set)
        for task in workflow.tasks:
            for dep in task.depends_on:
                reverse_deps[dep].add(task.name)

        # Results and futures tracking
        results: dict[str, Job | ShellJob] = {}
        running_futures: dict[str, Any] = {}

        # Thread-safe lock for state updates
        state_lock = threading.Lock()

        def get_ready_tasks() -> list[str]:
            """Get all tasks that are ready to run using type-safe status checking."""
            ready_tasks = []
            for task_name, task in task_map.items():
                if task.can_start(task_states):
                    ready_tasks.append(task_name)
            return ready_tasks

        def execute_task(task_name: str) -> Job | ShellJob:
            """Execute a single task and wait for completion."""
            logger.info(f"🚀 Starting task: {task_name}")
            task = task_map[task_name]

            # Update task status to running
            with state_lock:
                task_states[task_name] = JobStatus.RUNNING
                task.update_status(JobStatus.RUNNING)

            # Type narrow the job to the expected union type
            if not isinstance(task, Job | ShellJob):
                raise TypeError(f"Unexpected job type: {type(task)}")

            try:
                job_result = submit_and_monitor_job(task)
                logger.success(f"✅ Completed task: {task_name}")

                # Update task status to completed
                with state_lock:
                    task_states[task_name] = JobStatus.COMPLETED
                    task.update_status(JobStatus.COMPLETED)

                return job_result
            except Exception as e:
                logger.error(f"❌ Task {task_name} failed: {e}")
                with state_lock:
                    task_states[task_name] = JobStatus.FAILED
                    task.update_status(JobStatus.FAILED)
                raise

        def on_task_complete(task_name: str, result: Job | ShellJob) -> list[str]:
            """Handle task completion and schedule dependent tasks.

            Returns:
                List of newly ready task names.
            """
            with state_lock:
                # Status is already updated in execute_task
                results[task_name] = result
                self.executed_tasks[task_name] = result

                # Check if any dependent tasks are now ready using type-safe status checking
                newly_ready = []
                for dependent_task_name in reverse_deps[task_name]:
                    dependent_task = task_map[dependent_task_name]
                    if dependent_task.can_start(task_states):
                        newly_ready.append(dependent_task_name)

            logger.info(
                f"📋 Task {task_name} completed. Ready to start: {newly_ready}"
            )
            return newly_ready

        # Use ThreadPoolExecutor for parallel execution
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial tasks (those with no dependencies)
            initial_tasks = get_ready_tasks()
            logger.info(f"🌋 Starting initial tasks: {initial_tasks}")

            for task_name in initial_tasks:
                future = executor.submit(execute_task, task_name)
                running_futures[task_name] = future

            # Process completed tasks and schedule new ones
            while running_futures:
                # Wait for at least one task to complete
                completed_futures = []
                for task_name, future in list(running_futures.items()):
                    if future.done():
                        completed_futures.append((task_name, future))
                        del running_futures[task_name]

                if not completed_futures:
                    # Sleep briefly to avoid busy waiting
                    time.sleep(0.1)
                    continue

                # Handle completed tasks
                for task_name, future in completed_futures:
                    try:
                        result = future.result()
                        newly_ready = on_task_complete(task_name, result)

                        # Schedule newly ready tasks
                        for ready_task in newly_ready:
                            if ready_task not in running_futures:
                                new_future = executor.submit(execute_task, ready_task)
                                running_futures[ready_task] = new_future

                    except Exception as e:
                        logger.error(f"❌ Task {task_name} failed: {e}")
                        # Mark as failed to avoid infinite loop
                        with state_lock:
                            task_states[task_name] = JobStatus.FAILED
                            task_map[task_name].update_status(JobStatus.FAILED)
                        raise

        # Verify all tasks completed
        incomplete_tasks = [
            name for name, state in task_states.items() if state != JobStatus.COMPLETED
        ]
        if incomplete_tasks:
            failed_tasks = [
                name
                for name, state in task_states.items()
                if state == JobStatus.FAILED
            ]
            if failed_tasks:
                logger.error(f"❌ Tasks failed: {failed_tasks}")
                raise RuntimeError(f"Workflow execution failed: {failed_tasks}")
            else:
                logger.error(f"❌ Some tasks did not complete: {incomplete_tasks}")
                raise RuntimeError(f"Workflow execution incomplete: {incomplete_tasks}")

        return results
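
    # Scheduling sketch for run() above (illustrative, hypothetical task
    # names): for a workflow with tasks A, B (depends_on: [A]),
    # C (depends_on: [A]) and D (depends_on: [B, C]), run() submits A first,
    # starts B and C in parallel once A completes, and submits D only after
    # both B and C have finished.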
    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, Job | ShellJob]:
        """Load and execute a workflow from YAML file.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping task names to Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        workflow = self.load_from_yaml(yaml_path)

        logger.info(
            f"Executing workflow '{workflow.name}' with {len(workflow.tasks)} tasks"
        )
        results = self.run(workflow)

        logger.success("🎉 Workflow completed successfully")
        return results
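
    # Illustrative usage (hypothetical file name, not from the source):
    #
    #     runner = WorkflowRunner()
    #     results = runner.execute_from_yaml("workflow.yaml")
    #     # `results` maps each task name to its finished Job/ShellJob.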
    def _build_execution_levels(self, workflow: Workflow) -> dict[int, list[str]]:
        """Build execution levels for parallel task execution.

        Tasks in the same level can be executed in parallel.
        This method is kept for backward compatibility but is no longer used
        in the new dynamic scheduling approach.

        Args:
            workflow: Workflow to analyze.

        Returns:
            Dictionary mapping level numbers to lists of task names.
        """
        task_map = {task.name: task for task in workflow.tasks}
        levels: dict[int, list[str]] = defaultdict(list)
        task_levels: dict[str, int] = {}

        # Calculate the maximum depth for each task
        def calculate_depth(task_name: str, visited: set[str]) -> int:
            if task_name in visited:
                raise ValueError(
                    f"Circular dependency detected involving task '{task_name}'"
                )

            if task_name in task_levels:
                return task_levels[task_name]

            task = task_map[task_name]
            if not task.depends_on:
                # No dependencies, can execute at level 0
                task_levels[task_name] = 0
                return 0

            visited.add(task_name)
            max_dep_level = -1

            for dep in task.depends_on:
                dep_level = calculate_depth(dep, visited)
                max_dep_level = max(max_dep_level, dep_level)

            visited.remove(task_name)

            # This task executes after all its dependencies
            task_level = max_dep_level + 1
            task_levels[task_name] = task_level
            return task_level

        # Calculate levels for all tasks
        for task in workflow.tasks:
            level = calculate_depth(task.name, set())
            levels[level].append(task.name)

        return dict(levels)
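
    # Illustrative result of _build_execution_levels() (hypothetical task
    # names): for A, B (depends_on: [A]), C (depends_on: [A]) and
    # D (depends_on: [B, C]) it returns {0: ["A"], 1: ["B", "C"], 2: ["D"]}.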

def run_workflow_from_file(yaml_path: str | Path) -> dict[str, Job | ShellJob]:
    """Convenience function to run workflow from YAML file.

    Args:
        yaml_path: Path to YAML workflow file.

    Returns:
        Dictionary mapping task names to Job instances.
    """
    runner = WorkflowRunner()
    return runner.execute_from_yaml(yaml_path)
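
# Illustrative one-liner (hypothetical path, not from the source):
#
#     results = run_workflow_from_file("pipelines/train.yaml")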

def validate_workflow_dependencies(workflow: Workflow) -> None:
    """Validate workflow task dependencies."""
    task_names = {task.name for task in workflow.tasks}

    for task in workflow.tasks:
        for dependency in task.depends_on:
            if dependency not in task_names:
                raise ValueError(
                    f"Task '{task.name}' depends on unknown task '{dependency}'"
                )

    # Check for circular dependencies (simple check)
    visited = set()
    rec_stack = set()

    def has_cycle(task_name: str) -> bool:
        if task_name in rec_stack:
            return True
        if task_name in visited:
            return False

        visited.add(task_name)
        rec_stack.add(task_name)

        task = workflow.get_task(task_name)
        if task:
            for dependency in task.depends_on:
                if has_cycle(dependency):
                    return True

        rec_stack.remove(task_name)
        return False

    for task in workflow.tasks:
        if has_cycle(task.name):
            raise ValueError(
                f"Circular dependency detected involving task '{task.name}'"
            )
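
# Illustrative failures (hypothetical task names): a task that lists an
# undeclared dependency raises ValueError, as does a cycle such as
# A -> B -> A detected by has_cycle() above.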

def show_workflow_plan(workflow: Workflow) -> None:
    """Show workflow execution plan."""
    msg = f"""\
{" PLAN ":=^80}
Workflow: {workflow.name}
Tasks: {len(workflow.tasks)}
Execution: Parallel with dependency-based scheduling
"""

    for task in workflow.tasks:
        msg += f" Task: {task.name}\n"
        if isinstance(task.job, Job):
            msg += f"{' Command:': <21} {' '.join(task.job.command or [])}\n"
            msg += f"{' Resources:': <21} {task.job.resources.nodes} nodes, {task.job.resources.gpus_per_node} GPUs/node\n"
            if task.job.environment.conda:
                msg += f"{' Conda env:': <21} {task.job.environment.conda}\n"
            if task.job.environment.sqsh:
                msg += f"{' Sqsh:': <21} {task.job.environment.sqsh}\n"
            if task.job.environment.venv:
                msg += f"{' Venv:': <21} {task.job.environment.venv}\n"
        elif isinstance(task.job, ShellJob):
            msg += f"{' Path:': <21} {task.job.path}\n"
        if task.depends_on:
            msg += f"{' Dependencies:': <21} {', '.join(task.depends_on)}\n"

    msg += f"{'=' * 80}\n"
    print(msg)
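
# Roughly what show_workflow_plan() prints for a one-task workflow
# (illustrative values, spacing approximate):
#
#   ===================================== PLAN =====================================
#   Workflow: example
#   Tasks: 1
#   Execution: Parallel with dependency-based scheduling
#    Task: train
#    Command:            python train.py
#    Resources:          1 nodes, 0 GPUs/node
#   ================================================================================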