Coverage for src/commands/agent.py: 93%
75 statements
coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
1"""
2Command for processing natural language queries using an LLM agent.
4This module contains the handler for the agent command which processes
5natural language queries and interacts with Databricks resources.
6"""
8import logging
9from typing import Optional, Any
11from src.clients.databricks import DatabricksAPIClient
12from src.command_registry import CommandDefinition
13from src.commands.base import CommandResult
14from src.metrics_collector import get_metrics_collector
17def handle_command(
18 client: Optional[DatabricksAPIClient], **kwargs: Any
19) -> CommandResult:
20 """
21 Process a natural language query using the LLM agent.
23 Args:
24 client: DatabricksAPIClient instance for API calls (optional)
25 **kwargs: Command parameters
26 - query: The natural language query from the user
27 - mode: Optional agent mode (general, pii, bulk_pii, stitch)
28 - rest: Any additional text input provided after the command
29 - raw_args: Unparsed arguments (fallback when command parser fails)
30 - catalog_name: Optional catalog name for context
31 - schema_name: Optional schema name for context
33 Returns:
34 CommandResult with agent response
35 """
36 # First check for different ways the query might be provided
37 # Priority: 1. query parameter, 2. rest parameter, 3. raw_args
38 query = kwargs.get("query")
39 rest = kwargs.get("rest")
40 raw_args = kwargs.get("raw_args")
42 # If query wasn't provided but we have rest or raw_args, use that as the query
43 if not query: # This checks if the initial kwargs.get("query") was empty/None
44 if rest:
45 query = rest
46 elif raw_args:
47 if isinstance(raw_args, (list, tuple)):
48 query = " ".join(str(arg) for arg in raw_args)
49 else:
50 # Handle case where raw_args is a single string
51 query = str(raw_args)
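
    # Illustrative fallback behavior (comment-only sketch; the exact kwargs shapes
    # produced by the TUI parser are assumptions, not confirmed by this file):
    #   "/ask list all catalogs"  -> kwargs={"rest": "list all catalogs"}
    #                                -> query = "list all catalogs"
    #   parser fallback           -> kwargs={"raw_args": ["list", "all", "catalogs"]}
    #                                -> query = "list all catalogs"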

    # At this point, query might be a string from 'query', 'rest', 'raw_args', or still None.
    # Strip whitespace if query is a string. This handles cases like "   ".
    # If query is None, .strip() would error, so we check isinstance.
    if isinstance(query, str):
        query = query.strip()

    # Now, check if the (potentially stripped) query is truly empty or None.
    if not query:
        return CommandResult(
            False, message="Please provide a query. Usage: /ask Your question here"
        )

    # Get optional parameters
    mode = kwargs.get("mode", "general").lower()
    catalog_name = kwargs.get("catalog_name")
    schema_name = kwargs.get("schema_name")
    tool_output_callback = kwargs.get("tool_output_callback")

    try:
        from src.agent import AgentManager
        from src.config import get_agent_history, set_agent_history

        # Get metrics collector
        metrics_collector = get_metrics_collector()

        # Create agent manager with the API client and tool output callback
        agent = AgentManager(client, tool_output_callback=tool_output_callback)

        # Load conversation history
        try:
            history = get_agent_history()
        except Exception:
            history = []

        if history:
            agent.conversation_history = history

        # Process the query based on the selected mode
        if mode == "pii":
            # PII detection mode for a single table
            response = agent.process_pii_detection(
                table_name=query, catalog_name=catalog_name, schema_name=schema_name
            )
        elif mode == "bulk_pii":
            # Bulk PII scanning mode for a schema
            response = agent.process_bulk_pii_scan(
                catalog_name=catalog_name, schema_name=schema_name
            )
        elif mode == "stitch":
            # Stitch setup mode
            response = agent.process_setup_stitch(
                catalog_name=catalog_name, schema_name=schema_name
            )
        else:
            # Default general query mode
            response = agent.process_query(query)

        # Save conversation history
        set_agent_history(agent.conversation_history)

        # Track the agent interaction event
        if mode == "pii":
            # For PII detection mode
            processed_tools = [{"name": "pii_detection", "arguments": {"table": query}}]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}
        elif mode == "bulk_pii":
            # For bulk PII scanning mode
            processed_tools = [
                {
                    "name": "bulk_pii_scan",
                    "arguments": {"catalog": catalog_name, "schema": schema_name},
                }
            ]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}
        elif mode == "stitch":
            # For Stitch setup mode
            processed_tools = [
                {
                    "name": "setup_stitch",
                    "arguments": {"catalog": catalog_name, "schema": schema_name},
                }
            ]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}
        else:
            # For general query mode
            processed_tools = [{"name": "general_query", "arguments": {"query": query}}]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}

        # Get the last AI response from the conversation history
        last_ai_response = None
        if agent.conversation_history and len(agent.conversation_history) > 0:
            for msg in reversed(agent.conversation_history):
                # Handle both dict messages and ChatCompletionMessage objects
                role = (
                    msg.get("role")
                    if hasattr(msg, "get")
                    else getattr(msg, "role", None)
                )
                if role == "assistant":
                    last_ai_response = msg
                    break
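
        # Comment-only illustration of the two message shapes handled in the loop
        # above (the object form assumes an OpenAI-style ChatCompletionMessage):
        #   {"role": "assistant", "content": "..."}       -> dict branch via msg.get("role")
        #   ChatCompletionMessage(role="assistant", ...)  -> attribute branch via getattr()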

        # Track the event
        metrics_collector.track_event(
            prompt=query,
            tools=processed_tools,
            conversation_history=[last_ai_response] if last_ai_response else None,
            additional_data=additional_data,
        )

        return CommandResult(
            True,
            data={"response": response, "conversation": agent.conversation_history},
        )

    except Exception as e:
        # Handle pagination cancellation specially - let it bubble up
        from src.exceptions import PaginationCancelled

        if isinstance(e, PaginationCancelled):
            raise  # Re-raise to bubble up to main TUI loop

        logging.error(f"Agent error: {e}", exc_info=True)
        return CommandResult(
            False, message=f"Failed to process query: {str(e)}", error=e
        )


DEFINITION = CommandDefinition(
    name="agent",
    description="Process natural language queries using an LLM agent",
    handler=handle_command,
    parameters={
        "query": {
            "type": "string",
            "description": "Natural language query to process",
        },
        "mode": {
            "type": "string",
            "description": "Agent mode (general, pii, bulk_pii, stitch)",
            "default": "general",
        },
        "catalog_name": {
            "type": "string",
            "description": "Optional catalog name for context (uses active catalog if not provided)",
        },
        "schema_name": {
            "type": "string",
            "description": "Optional schema name for context (uses active schema if not provided)",
        },
        "rest": {
            "type": "string",
            "description": "Additional text after the command to use as query",
        },
        "raw_args": {
            "type": ["array", "string"],
            "description": "Raw unparsed arguments in case parsing fails",
        },
    },
    # Not requiring query since we handle combining raw_args and rest in the handler
    required_params=[],
    tui_aliases=["/agent", "/ask"],
    needs_api_client=True,
    visible_to_user=True,
    visible_to_agent=False,  # Don't let the agent use itself
    usage_hint='Usage: /ask Your natural language question here\n       /agent --query "Your question here" [--mode general|pii|bulk_pii|stitch]',
)
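

# Illustrative usage (hedged, comment-only sketch): invoking the handler directly,
# e.g. from a unit test. The CommandResult attribute names (`success`, `data`,
# `message`) are assumptions inferred from the constructor calls above, not
# confirmed by this file.
#
#     # assuming `client` is a configured DatabricksAPIClient instance
#     result = handle_command(client, query="which tables mention customer email?")
#     if result.success:
#         print(result.data["response"])
#     else:
#         print(result.message)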