Coverage for src/srunx/runner.py: 95%

125 statements  

coverage.py v7.9.1, created at 2025-06-22 15:10 +0000

1"""Workflow runner for executing YAML-defined workflows with SLURM""" 

2 

3import time 

4from collections import defaultdict 

5from collections.abc import Sequence 

6from concurrent.futures import ThreadPoolExecutor 

7from pathlib import Path 

8from typing import Any, Self 

9 

10import yaml 

11 

12from srunx.callbacks import Callback 

13from srunx.client import Slurm 

14from srunx.exceptions import WorkflowValidationError 

15from srunx.logging import get_logger 

16from srunx.models import ( 

17 Job, 

18 JobEnvironment, 

19 JobResource, 

20 JobStatus, 

21 RunableJobType, 

22 ShellJob, 

23 Workflow, 

24) 

25 

26logger = get_logger(__name__) 


class WorkflowRunner:
    """Runner for executing workflows defined in YAML with dynamic job scheduling.

    Jobs are executed as soon as their dependencies are satisfied,
    rather than waiting for entire dependency levels to complete.
    """

    def __init__(
        self, workflow: Workflow, callbacks: Sequence[Callback] | None = None
    ) -> None:
        """Initialize workflow runner.

        Args:
            workflow: Workflow to execute.
            callbacks: List of callbacks for job notifications.
        """
        self.workflow = workflow
        self.slurm = Slurm(callbacks=callbacks)
        self.callbacks = callbacks or []

    @classmethod
    def from_yaml(
        cls, yaml_path: str | Path, callbacks: Sequence[Callback] | None = None
    ) -> Self:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.
            callbacks: List of callbacks for job notifications.

        Returns:
            WorkflowRunner instance with the loaded workflow.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        name = data.get("name", "unnamed")
        jobs_data = data.get("jobs", [])

        jobs = []
        for job_data in jobs_data:
            job = cls.parse_job(job_data)
            jobs.append(job)
        return cls(workflow=Workflow(name=name, jobs=jobs), callbacks=callbacks)
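    # Expected YAML shape, as read here and by parse_job below (a minimal
    # sketch; the workflow and job names are hypothetical, and whether
    # 'command' is a string or a list is determined by srunx.models.Job):
    #
    #   name: my-workflow
    #   jobs:
    #     - name: prep
    #       command: python prep.py
    #     - name: train
    #       command: python train.py
    #       depends_on: [prep]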

    def get_independent_jobs(self) -> list[RunableJobType]:
        """Get all jobs that are independent of any other job."""
        independent_jobs = []
        for job in self.workflow.jobs:
            if not job.depends_on:
                independent_jobs.append(job)
        return independent_jobs

    def run(self) -> dict[str, RunableJobType]:
        """Run a workflow with dynamic job scheduling.

        Jobs are executed as soon as their dependencies are satisfied.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(
            f"🚀 Starting Workflow {self.workflow.name} with {len(self.workflow.jobs)} jobs"
        )
        for callback in self.callbacks:
            callback.on_workflow_started(self.workflow)

        # Track all jobs and results
        all_jobs = self.workflow.jobs.copy()
        results: dict[str, RunableJobType] = {}
        running_futures: dict[str, Any] = {}

        # Build reverse dependency map for efficient lookups
        dependents = defaultdict(set)
        for job in all_jobs:
            for dep in job.depends_on:
                dependents[dep].add(job.name)
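        # e.g. if "train" and "eval" (hypothetical names) both depend on
        # "prep", this yields dependents == {"prep": {"train", "eval"}}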

        def execute_job(job: RunableJobType) -> RunableJobType:
            """Execute a single job."""
            logger.info(f"🌋 {'SUBMITTED':<12} Job {job.name:<12}")
            return self.slurm.run(job)

        def on_job_complete(job_name: str, result: RunableJobType) -> list[str]:
            """Handle job completion and return newly ready job names."""
            results[job_name] = result
            completed_job_names = list(results.keys())

            # Find dependents whose dependencies are now all satisfied
            newly_ready = []
            for dependent_name in dependents[job_name]:
                dependent_job = next(j for j in all_jobs if j.name == dependent_name)
                if (
                    dependent_job.status == JobStatus.PENDING
                    and dependent_job.dependencies_satisfied(completed_job_names)
                ):
                    newly_ready.append(dependent_name)

            return newly_ready

        # Execute workflow with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial ready jobs
            initial_jobs = self.get_independent_jobs()

            for job in initial_jobs:
                future = executor.submit(execute_job, job)
                running_futures[job.name] = future

            # Process completed jobs and schedule new ones
            while running_futures:
                # Check for completed futures
                completed = []
                for job_name, future in list(running_futures.items()):
                    if future.done():
                        completed.append((job_name, future))
                        del running_futures[job_name]

                if not completed:
                    time.sleep(0.1)  # Brief sleep to avoid busy waiting
                    continue

                # Handle completed jobs
                for job_name, future in completed:
                    try:
                        result = future.result()
                        newly_ready_names = on_job_complete(job_name, result)

                        # Schedule newly ready jobs
                        for ready_name in newly_ready_names:
                            if ready_name not in running_futures:
                                ready_job = next(
                                    j for j in all_jobs if j.name == ready_name
                                )
                                new_future = executor.submit(execute_job, ready_job)
                                running_futures[ready_name] = new_future

                    except Exception as e:
                        logger.error(f"❌ Job {job_name} failed: {e}")
                        raise

        # Verify all jobs completed successfully
        failed_jobs = [j.name for j in all_jobs if j.status == JobStatus.FAILED]
        incomplete_jobs = [
            j.name
            for j in all_jobs
            if j.status not in [JobStatus.COMPLETED, JobStatus.FAILED]
        ]

        if failed_jobs:
            logger.error(f"❌ Jobs failed: {failed_jobs}")
            raise RuntimeError(f"Workflow execution failed: {failed_jobs}")

        if incomplete_jobs:
            logger.error(f"❌ Jobs did not complete: {incomplete_jobs}")
            raise RuntimeError(f"Workflow execution incomplete: {incomplete_jobs}")

        logger.success(f"🎉 Workflow {self.workflow.name} completed!!")

        for callback in self.callbacks:
            callback.on_workflow_completed(self.workflow)

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, RunableJobType]:
        """Load and execute a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow file.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        runner = self.from_yaml(yaml_path)
        return runner.run()

    @staticmethod
    def parse_job(data: dict[str, Any]) -> RunableJobType:
        """Parse a single job definition into a ShellJob or a regular Job."""
        if data.get("path") and data.get("command"):
            raise WorkflowValidationError("Job cannot have both 'path' and 'command'")

        base = {"name": data["name"], "depends_on": data.get("depends_on", [])}

        # A 'path' entry yields a ShellJob that runs an existing script
        if data.get("path"):
            return ShellJob.model_validate({**base, "path": data["path"]})

        # Otherwise build a regular Job with its resources and environment
        resource = JobResource.model_validate(data.get("resources", {}))
        environment = JobEnvironment.model_validate(data.get("environment", {}))

        job_data = {
            **base,
            "command": data["command"],
            "resources": resource,
            "environment": environment,
        }
        if data.get("log_dir"):
            job_data["log_dir"] = data["log_dir"]
        if data.get("work_dir"):
            job_data["work_dir"] = data["work_dir"]

        return Job.model_validate(job_data)
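    # parse_job accepts two job shapes in YAML (sketch; values hypothetical,
    # nested 'resources'/'environment' fields defined by srunx.models):
    #
    #   - name: fetch             # 'path' -> ShellJob running an existing script
    #     path: scripts/fetch.sh
    #   - name: train             # 'command' -> Job, with optional extras
    #     command: python train.py
    #     depends_on: [fetch]
    #     log_dir: logs           # optional
    #     work_dir: runs/train    # optional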


def run_workflow_from_file(yaml_path: str | Path) -> dict[str, RunableJobType]:
    """Convenience function to run a workflow from a YAML file.

    Args:
        yaml_path: Path to the YAML workflow file.

    Returns:
        Dictionary mapping job names to completed Job instances.
    """
    runner = WorkflowRunner.from_yaml(yaml_path)
    return runner.run()

257 return runner.run()