Coverage for src/srunx/cli/main.py: 70%
222 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:10 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-22 15:10 +0000
1"""Main CLI interface for srunx."""
3import argparse
4import os
5import sys
6from pathlib import Path
8from rich.console import Console
9from rich.table import Table
11from srunx.callbacks import SlackCallback
12from srunx.client import Slurm
13from srunx.logging import (
14 configure_cli_logging,
15 configure_workflow_logging,
16 get_logger,
17)
18from srunx.models import Job, JobEnvironment, JobResource
19from srunx.runner import WorkflowRunner
21logger = get_logger(__name__)
def create_job_parser() -> argparse.ArgumentParser:
    """Build the argument parser used for direct job submission.

    Returns:
        An :class:`argparse.ArgumentParser` covering the command to run plus
        job, resource, environment, execution, logging, notification, and
        misc options.
    """
    parser = argparse.ArgumentParser(
        description="Submit SLURM jobs with various configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # The command to execute is the only required (positional) argument.
    parser.add_argument("command", nargs="+", help="Command to execute in the SLURM job")

    # Basic job configuration.
    parser.add_argument(
        "--name", "--job-name", type=str, default="job",
        help="Job name (default: %(default)s)",
    )
    parser.add_argument(
        "--log-dir", type=str, default=os.getenv("SLURM_LOG_DIR", "logs"),
        help="Log directory (default: %(default)s)",
    )
    parser.add_argument(
        "--work-dir", "--chdir", type=str,
        help="Working directory for the job",
    )

    # Compute-resource options.
    resources = parser.add_argument_group("Resource Options")
    resources.add_argument(
        "-N", "--nodes", type=int, default=1,
        help="Number of nodes (default: %(default)s)",
    )
    resources.add_argument(
        "--gpus-per-node", type=int, default=0,
        help="Number of GPUs per node (default: %(default)s)",
    )
    resources.add_argument(
        "--ntasks-per-node", type=int, default=1,
        help="Number of tasks per node (default: %(default)s)",
    )
    resources.add_argument(
        "--cpus-per-task", type=int, default=1,
        help="Number of CPUs per task (default: %(default)s)",
    )
    resources.add_argument(
        "--memory", "--mem", type=str,
        help="Memory per node (e.g., '32GB', '1TB')",
    )
    resources.add_argument(
        "--time", "--time-limit", type=str,
        help="Time limit (e.g., '1:00:00', '30:00', '1-12:00:00')",
    )

    # Software-environment options.
    environment = parser.add_argument_group("Environment Options")
    environment.add_argument("--conda", type=str, help="Conda environment name")
    environment.add_argument("--venv", type=str, help="Virtual environment path")
    environment.add_argument("--sqsh", type=str, help="SquashFS image path")
    environment.add_argument(
        "--env", action="append", dest="env_vars",
        help="Environment variable KEY=VALUE (can be used multiple times)",
    )

    # How the submission itself should behave.
    execution = parser.add_argument_group("Execution Options")
    execution.add_argument("--template", type=str, help="Path to custom SLURM template file")
    execution.add_argument("--wait", action="store_true", help="Wait for job completion")
    execution.add_argument(
        "--poll-interval", type=int, default=5,
        help="Polling interval in seconds when waiting (default: %(default)s)",
    )

    # Logging verbosity.
    logging_opts = parser.add_argument_group("Logging Options")
    logging_opts.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: %(default)s)",
    )
    logging_opts.add_argument(
        "--quiet", "-q", action="store_true",
        help="Only show warnings and errors",
    )

    # Notifications.
    notifications = parser.add_argument_group("Notification Options")
    notifications.add_argument(
        "--slack", action="store_true",
        help="Send notifications to Slack",
    )

    # Miscellaneous.
    misc = parser.add_argument_group("Misc Options")
    misc.add_argument("--verbose", action="store_true", help="Print the rendered content")

    return parser
def create_status_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the job-status command.

    Returns:
        A parser taking a single required integer ``job_id`` positional.
    """
    parser = argparse.ArgumentParser(
        description="Check SLURM job status",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("job_id", type=int, help="SLURM job ID to check")
    return parser
def create_queue_parser() -> argparse.ArgumentParser:
    """Build the argument parser for listing queued jobs.

    Returns:
        A parser with an optional ``--user``/``-u`` filter (defaults to the
        current user when omitted).
    """
    parser = argparse.ArgumentParser(
        description="Queue SLURM jobs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--user", "-u",
        type=str,
        help="Queue jobs for specific user (default: current user)",
    )
    return parser
def create_cancel_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the job-cancellation command.

    Returns:
        A parser taking a single required integer ``job_id`` positional.
    """
    parser = argparse.ArgumentParser(
        description="Cancel SLURM job",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("job_id", type=int, help="SLURM job ID to cancel")
    return parser
def create_main_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with all subcommands.

    Wires up ``submit``/``status``/``queue``/``cancel`` (each borrowing its
    arguments from the corresponding standalone parser) plus the nested
    ``flow run`` / ``flow validate`` subcommands.

    Returns:
        The fully configured top-level :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        description="srunx - Python library for SLURM job management",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Global options available before any subcommand.
    parser.add_argument(
        "--log-level", "-l",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: %(default)s)",
    )
    parser.add_argument(
        "--quiet", "-q",
        action="store_true",
        help="Only show warnings and errors",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Simple subcommands: each reuses the arguments of a standalone parser.
    for cmd_name, help_text, handler, make_source in (
        ("submit", "Submit a SLURM job", cmd_submit, create_job_parser),
        ("status", "Check job status", cmd_status, create_status_parser),
        ("queue", "Queue jobs", cmd_queue, create_queue_parser),
        ("cancel", "Cancel job", cmd_cancel, create_cancel_parser),
    ):
        sub = subparsers.add_parser(cmd_name, help=help_text)
        sub.set_defaults(func=handler)
        _copy_parser_args(make_source(), sub)

    # "flow" has its own nested subcommands.
    flow_parser = subparsers.add_parser("flow", help="Workflow management")
    flow_parser.set_defaults(func=None)  # Will be overridden by subcommands
    flow_subparsers = flow_parser.add_subparsers(
        dest="flow_command", help="Flow commands"
    )

    flow_run = flow_subparsers.add_parser("run", help="Execute workflow")
    flow_run.set_defaults(func=cmd_flow_run)
    flow_run.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )
    flow_run.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be executed without running jobs",
    )
    flow_run.add_argument(
        "--slack",
        action="store_true",
        help="Send notifications to Slack",
    )

    flow_validate = flow_subparsers.add_parser("validate", help="Validate workflow")
    flow_validate.set_defaults(func=cmd_flow_validate)
    flow_validate.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )

    return parser
311def _copy_parser_args(
312 source_parser: argparse.ArgumentParser, target_parser: argparse.ArgumentParser
313) -> None:
314 """Copy arguments from source parser to target parser."""
315 for action in source_parser._actions:
316 if action.dest == "help":
317 continue
318 target_parser._add_action(action)
321def _parse_env_vars(env_var_list: list[str] | None) -> dict[str, str]:
322 """Parse environment variables from list of KEY=VALUE strings."""
323 env_vars = {}
324 if env_var_list:
325 for env_var in env_var_list:
326 if "=" in env_var:
327 key, value = env_var.split("=", 1)
328 env_vars[key] = value
329 else:
330 logger.warning(f"Invalid environment variable format: {env_var}")
331 return env_vars
def cmd_submit(args: argparse.Namespace) -> None:
    """Handle the ``submit`` command: build a Job and send it to SLURM.

    Builds resource/environment models from the parsed arguments, submits
    the job, and optionally blocks until completion. Exits with status 1
    on any error.
    """
    try:
        # KEY=VALUE pairs from --env become a mapping.
        env_vars = _parse_env_vars(getattr(args, "env_vars", None))

        resources = JobResource(
            nodes=args.nodes,
            gpus_per_node=args.gpus_per_node,
            ntasks_per_node=args.ntasks_per_node,
            cpus_per_task=args.cpus_per_task,
            memory_per_node=getattr(args, "memory", None),
            time_limit=getattr(args, "time", None),
        )
        environment = JobEnvironment(
            conda=getattr(args, "conda", None),
            venv=getattr(args, "venv", None),
            sqsh=getattr(args, "sqsh", None),
            env_vars=env_vars,
        )

        job_data = {
            "name": args.name,
            "command": args.command,
            "resources": resources,
            "environment": environment,
            "log_dir": args.log_dir,
        }
        # work_dir is only included when explicitly given so the model's
        # own default applies otherwise.
        if args.work_dir is not None:
            job_data["work_dir"] = args.work_dir
        job = Job.model_validate(job_data)

        # Slack notifications need a webhook URL from the environment.
        callbacks = []
        if args.slack:
            webhook_url = os.getenv("SLACK_WEBHOOK_URL")
            if not webhook_url:
                raise ValueError("SLACK_WEBHOOK_URL is not set")
            callbacks = [SlackCallback(webhook_url=webhook_url)]

        client = Slurm(callbacks=callbacks)
        submitted_job = client.submit(
            job, getattr(args, "template", None), verbose=args.verbose
        )
        logger.info(f"Submitted job {submitted_job.job_id}: {submitted_job.name}")

        # Optionally poll until the job reaches a terminal state.
        if getattr(args, "wait", False):
            logger.info(f"Waiting for job {submitted_job.job_id} to complete...")
            completed_job = client.monitor(
                submitted_job, poll_interval=args.poll_interval
            )
            status_str = (
                completed_job.status.value if completed_job.status else "Unknown"
            )
            logger.info(
                f"Job {submitted_job.job_id} completed with status: {status_str}"
            )

    except Exception as e:
        logger.error(f"Error submitting job: {e}")
        sys.exit(1)
def cmd_status(args: argparse.Namespace) -> None:
    """Handle the ``status`` command: look up and log one job's state.

    Exits with status 1 if the lookup fails.
    """
    try:
        job = Slurm().retrieve(args.job_id)

        logger.info(f"Job ID: {job.job_id}")
        logger.info(f"Name: {job.name}")
        # Status may be unset on freshly retrieved jobs.
        status = job.status
        if status:
            logger.info(f"Status: {status.value}")
        else:
            logger.info("Status: Unknown")

    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        sys.exit(1)
def cmd_queue(args: argparse.Namespace) -> None:
    """Handle the ``queue`` command: list queued jobs in a fixed-width table.

    Args:
        args: Parsed CLI namespace; ``args.user`` optionally filters by user.

    Exits with status 1 on failure.
    """
    try:
        client = Slurm()
        jobs = client.queue(getattr(args, "user", None))

        if not jobs:
            logger.info("No jobs found")
            return

        logger.info(f"{'Job ID':<12} {'Name':<20} {'Status':<12}")
        logger.info("-" * 45)
        for job in jobs:
            status_str = job.status.value if job.status else "Unknown"
            # str() coercion: a None job_id or name would raise TypeError under
            # a width format spec (NoneType rejects non-empty format specs),
            # and the status guard above shows unset fields are expected.
            logger.info(f"{str(job.job_id):<12} {str(job.name):<20} {status_str:<12}")

    except Exception as e:
        logger.error(f"Error queueing jobs: {e}")
        sys.exit(1)
def cmd_cancel(args: argparse.Namespace) -> None:
    """Handle the ``cancel`` command: cancel one job by ID.

    Exits with status 1 if cancellation fails.
    """
    try:
        Slurm().cancel(args.job_id)
        logger.info(f"Cancelled job {args.job_id}")
    except Exception as e:
        logger.error(f"Error cancelling job: {e}")
        sys.exit(1)
def cmd_flow_run(args: argparse.Namespace) -> None:
    """Handle the ``flow run`` command.

    Loads a YAML workflow definition, validates its dependency graph, and
    either shows the plan (``--dry-run``) or executes it and prints a
    summary table. Exits with status 1 on any failure.
    """
    # Configure logging for workflow execution
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        yaml_file = Path(args.yaml_file)
        if not yaml_file.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        # Setup callbacks if requested
        callbacks = []
        if getattr(args, "slack", False):
            webhook_url = os.getenv("SLACK_WEBHOOK_URL")
            if not webhook_url:
                raise ValueError("SLACK_WEBHOOK_URL environment variable is not set")
            callbacks.append(SlackCallback(webhook_url=webhook_url))

        runner = WorkflowRunner.from_yaml(yaml_file, callbacks=callbacks)

        # Validate dependencies before executing anything.
        runner.workflow.validate()

        if args.dry_run:
            runner.workflow.show()
            return

        # Execute workflow
        results = runner.run()

        logger.success(f"🎉 Workflow {runner.workflow.name} completed!!")
        table = Table(title=f"Workflow {runner.workflow.name} Summary")
        table.add_column("Job", justify="left", style="cyan", no_wrap=True)
        table.add_column("Status", justify="left", style="cyan", no_wrap=True)
        table.add_column("ID", justify="left", style="cyan", no_wrap=True)
        for job in results.values():
            # Guard against an unset status, consistent with cmd_status and
            # cmd_queue; job.status.value would raise AttributeError on None.
            status_str = job.status.value if job.status else "Unknown"
            table.add_row(job.name, status_str, str(job.job_id))

        console = Console()
        console.print(table)

    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        sys.exit(1)
def cmd_flow_validate(args: argparse.Namespace) -> None:
    """Handle the ``flow validate`` command.

    Parses the YAML workflow definition and validates its dependency
    graph without executing anything. Exits with status 1 on failure.
    """
    # Configure logging for workflow validation
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        yaml_path = Path(args.yaml_file)
        if not yaml_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        # Parse the workflow, then check its dependency graph.
        WorkflowRunner.from_yaml(yaml_path).workflow.validate()

        logger.info("Workflow validation successful")

    except Exception as e:
        logger.error(f"Workflow validation failed: {e}")
        sys.exit(1)
def main() -> None:
    """Main entry point for the CLI.

    Parses arguments, configures logging, and dispatches to the selected
    subcommand handler. With no recognized subcommand, falls back to
    treating the whole command line as a direct job submission for
    backward compatibility.
    """
    parser = create_main_parser()
    args = parser.parse_args()

    # Configure logging from the global options.
    configure_cli_logging(
        level=getattr(args, "log_level", "INFO"),
        quiet=getattr(args, "quiet", False),
    )

    # Normal path: a subcommand resolved to a handler.
    if getattr(args, "func", None) is not None:
        args.func(args)
        return

    if getattr(args, "command", None) == "flow":
        # "flow" was given without a subcommand.
        if getattr(args, "flow_command", None) is None:
            logger.error("Flow command requires a subcommand (run or validate)")
            parser.print_help()
            sys.exit(1)
    else:
        # Backward compatibility: retry the command line as a bare submit.
        submit_parser = create_job_parser()
        try:
            submit_args = submit_parser.parse_args()
            cmd_submit(submit_args)
        except SystemExit:
            parser.print_help()
            sys.exit(1)
556if __name__ == "__main__":
557 main()