Coverage for src/commands/run_sql.py: 6%
164 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
1"""
2Command for executing SQL queries on a Databricks warehouse.
3"""
5from typing import Optional, Any, Dict
6from src.clients.databricks import DatabricksAPIClient
7from src.command_registry import CommandDefinition
8from src.commands.base import CommandResult
9from src.config import get_warehouse_id, get_active_catalog
10import logging
def handle_command(
    client: Optional[DatabricksAPIClient], **kwargs: Any
) -> CommandResult:
    """
    Execute a SQL query on a Databricks SQL warehouse.

    Args:
        client: DatabricksAPIClient instance for API calls
        **kwargs: Command parameters
            - warehouse_id: ID of the warehouse to run the query on; falls
              back to get_warehouse_id() when omitted
            - query: SQL query to execute (required, must not be blank)
            - catalog: Optional catalog name to use; falls back to
              get_active_catalog() when omitted
            - wait_timeout: How long to wait for query completion (default "30s")

    Returns:
        CommandResult with query results if successful. ``data`` is either an
        inline result (``is_paginated`` False, with ``rows``/``row_count``) or
        an external-link result (``is_paginated`` True, with
        ``external_links``/``manifest``/``total_row_count``/``chunks``).
        Exceptions raised while submitting or processing the statement are
        logged and reported as an unsuccessful CommandResult, not re-raised.
    """
    if not client:
        return CommandResult(
            False,
            message="No Databricks client available. Please set up your workspace first.",
        )

    # Extract parameters
    warehouse_id = kwargs.get("warehouse_id")
    query = kwargs.get("query")
    catalog = kwargs.get("catalog")
    wait_timeout = kwargs.get("wait_timeout", "30s")

    # If warehouse_id not provided, try to use the configured warehouse
    if not warehouse_id:
        warehouse_id = get_warehouse_id()
        if not warehouse_id:
            return CommandResult(
                False,
                message="No warehouse ID specified and no active warehouse selected. Please provide a warehouse_id or select a warehouse first using /select-warehouse.",
            )

    # Reject a missing or blank query here instead of sending it to the API
    # and surfacing an opaque remote error.
    if not query or (isinstance(query, str) and not query.strip()):
        return CommandResult(
            False,
            message="No SQL query specified. Please provide a query to execute.",
        )

    # If catalog not provided, try to use active catalog.
    # It's fine if catalog is None, the API will handle it.
    if not catalog:
        catalog = get_active_catalog()

    try:
        # Execute the SQL query
        result = client.submit_sql_statement(
            sql_text=query,
            warehouse_id=warehouse_id,
            catalog=catalog,
            wait_timeout=wait_timeout,
        )

        # The state is read from result.status.state, falling back to a
        # top-level "state" key. "or {}" guards against an explicit null.
        status = result.get("status") or {}
        state = status.get("state", result.get("state", ""))

        if state == "SUCCEEDED":
            return _build_success_result(result)

        if state == "FAILED":
            fallback_message = (result.get("error") or {}).get(
                "message", "Unknown error"
            )
            error_message = (status.get("error") or {}).get(
                "message", fallback_message
            )
            return CommandResult(
                False, message=f"Query execution failed: {error_message}"
            )

        if state == "CANCELED":
            return CommandResult(False, message="Query execution was canceled.")

        return CommandResult(
            False,
            message=f"Query did not complete successfully. Final state: {state}",
        )
    except Exception as e:
        logging.error(f"Error executing SQL query: {str(e)}")
        return CommandResult(
            False, message=f"Failed to execute SQL query: {str(e)}", error=e
        )


def _schema_columns(container: Any) -> list:
    """Return container["schema"]["columns"] when both levels are dicts, else []."""
    if not isinstance(container, dict):
        return []
    schema = container.get("schema")
    if not isinstance(schema, dict):
        return []
    return schema.get("columns") or []


def _extract_column_names(
    result: Dict[str, Any], result_data: Dict[str, Any], manifest: Dict[str, Any]
) -> list:
    """Find the column schema in the response and return the column names."""
    direct_schema = result_data.get("schema")
    # Candidate locations, in priority order. Each lookup is type-checked so
    # a list-valued "schema" reaches the direct-list case instead of raising
    # on ".get" in an earlier lookup.
    candidates = (
        ("result.manifest.schema.columns", _schema_columns(manifest)),
        ("result_data.schema.columns", _schema_columns(result_data)),
        (
            "result_data.manifest.schema.columns",
            _schema_columns(result_data.get("manifest")),
        ),
        (
            "result_data.schema (direct list)",
            direct_schema if isinstance(direct_schema, list) else [],
        ),
        ("result.schema.columns", _schema_columns(result)),
    )
    for schema_location, column_infos in candidates:
        if column_infos:
            columns = [
                col.get("name") for col in column_infos if isinstance(col, dict)
            ]
            logging.debug(f"Found column schema at {schema_location}: {columns}")
            return columns
    return []


def _build_success_result(result: Dict[str, Any]) -> CommandResult:
    """Convert a SUCCEEDED statement response into a successful CommandResult."""
    result_data = result.get("result") or {}
    manifest = result.get("manifest") or {}
    data_array = result_data.get("data_array") or []
    external_links = result_data.get("external_links") or []

    columns = _extract_column_names(result, result_data, manifest)

    if external_links:
        # Large result set - use external links for pagination
        total_row_count = manifest.get("total_row_count") or 0
        chunks = manifest.get("chunks") or []

        logging.info(
            f"Large SQL result set detected: {total_row_count} total rows, {len(external_links)} chunks"
        )

        # If still no columns, create generic column names based on schema
        if not columns:
            schema = manifest.get("schema")
            column_count = (
                schema.get("column_count") if isinstance(schema, dict) else 0
            ) or 0
            if column_count > 0:
                columns = [f"column_{i + 1}" for i in range(column_count)]
                logging.warning(
                    f"No column names found, generated {len(columns)} generic column names"
                )

        formatted_results = {
            "columns": columns,
            "external_links": external_links,
            "manifest": manifest,
            "total_row_count": total_row_count,
            "chunks": chunks,
            "execution_time_ms": result.get("execution_time_ms"),
            "is_paginated": True,
        }
        # Rows are not inline for external-link results, so report the
        # manifest's total rather than the (empty) data_array length.
        result_count = total_row_count
    else:
        # Small result set - traditional display with data_array.
        # If still no columns but we have data, create generic column names.
        if not columns and data_array:
            columns = [f"column_{i + 1}" for i in range(len(data_array[0]))]
            logging.warning(
                f"No column schema found in SQL result, generated {len(columns)} generic column names"
            )

        formatted_results = {
            "columns": columns,
            "rows": data_array,
            "row_count": len(data_array),
            "execution_time_ms": result.get("execution_time_ms"),
            "is_paginated": False,
        }
        result_count = len(data_array)

    return CommandResult(
        True,
        data=formatted_results,
        message=f"Query executed successfully with {result_count} result(s).",
    )
def format_sql_results_for_agent(result: CommandResult) -> Dict[str, Any]:
    """
    Render SQL results as a plain-text table plus summary for the agent.

    Failed results become an ``{"error": ...}`` dictionary, results without
    data become a "No data returned" response, and paginated results are
    delegated to _format_paginated_results_for_agent. Otherwise at most the
    first ten rows are rendered, each column padded to at most 25 characters,
    and the raw rows are attached when there are 50 or fewer of them.

    Args:
        result: CommandResult containing SQL query results

    Returns:
        Dictionary with formatted results for agent consumption
    """
    if not result.success:
        return {"error": result.message or "SQL query failed"}

    payload = result.data
    if not payload:
        return {
            "success": True,
            "message": result.message or "Query completed successfully",
            "results": "No data returned",
        }

    # Large result sets have their own formatter
    if payload.get("is_paginated", False):
        return _format_paginated_results_for_agent(result)

    headers = payload.get("columns", [])
    records = payload.get("rows", [])

    # Without a schema, name the columns after the shape of the first row
    if records and not headers:
        headers = [f"column_{pos + 1}" for pos in range(len(records[0]))]

    # The same leading slice drives both width measurement and display
    preview = records[:10]

    # Column width = widest of header/previewed cells, padded by 2, capped at 25
    widths = []
    for idx, name in enumerate(headers):
        widest = len(str(name))
        for rec in preview:
            if isinstance(rec, list) and idx < len(rec):
                cell = rec[idx]
                widest = max(widest, len("" if cell is None else str(cell)))
        widths.append(min(widest + 2, 25))

    lines = []
    if headers:
        title = " | ".join(
            str(name).ljust(width) for name, width in zip(headers, widths)
        )
        lines.extend([title, "-" * len(title)])

    for rec in preview:
        if not isinstance(rec, list):
            # Rows that are not lists are left out of the table
            continue
        cells = []
        for cell, width in zip(rec[: len(headers)], widths):
            text = "" if cell is None else str(cell)
            if len(text) > width - 2:
                # Too long for the column: cut and mark with an ellipsis
                text = text[: width - 5] + "..."
            cells.append(text.ljust(width))
        lines.append(" | ".join(cells))

    if len(records) > 10:
        lines.append(f"\n... and {len(records) - 10} more rows")

    formatted = {
        "success": True,
        "message": result.message or "Query executed successfully",
        "results_table": "\n".join(lines),
        "summary": {
            "total_rows": payload.get("row_count", 0),
            "columns": headers,
            "execution_time_ms": payload.get("execution_time_ms"),
        },
    }

    # Raw rows are only attached for smaller result sets
    if len(records) <= 50:
        formatted["raw_data"] = {"columns": headers, "rows": records}

    return formatted
def _format_paginated_results_for_agent(result: CommandResult) -> Dict[str, Any]:
    """
    Build the agent-facing view of a paginated (external-link) SQL result.

    Only the first page is fetched, through PaginatedSQLResult, and at most
    ten of its rows are rendered as a text table; the summary reports the
    total row count carried in the result data. If fetching or rendering the
    sample raises, the error is logged and a response without a sample table
    is returned, still marked ``"success": True`` and carrying the error text.

    Args:
        result: CommandResult whose data holds the paginated result fields

    Returns:
        Dictionary with formatted results for agent consumption
    """
    from src.commands.sql_external_data import PaginatedSQLResult

    payload = result.data
    headers = payload.get("columns", [])
    links = payload.get("external_links", [])
    total = payload.get("total_row_count", 0)
    chunk_list = payload.get("chunks", [])
    elapsed_ms = payload.get("execution_time_ms")

    try:
        pager = PaginatedSQLResult(
            columns=headers,
            external_links=links,
            total_row_count=total,
            chunks=chunk_list,
        )

        # Only the rows of the first page are used; the second value is ignored
        first_page, _ = pager.get_next_page()
        preview = first_page[:10]

        lines = []
        widths = []
        if headers:
            # Width = widest of header/previewed cells, padded by 2, capped at 25
            for idx, name in enumerate(headers):
                widest = len(str(name))
                for rec in preview:
                    if isinstance(rec, list) and idx < len(rec):
                        cell = rec[idx]
                        widest = max(widest, len("" if cell is None else str(cell)))
                widths.append(min(widest + 2, 25))

            title = " | ".join(
                str(name).ljust(width) for name, width in zip(headers, widths)
            )
            lines.extend([title, "-" * len(title)])

        for rec in preview:
            if not isinstance(rec, list):
                continue
            cells = []
            for cell, width in zip(rec[: len(headers)], widths):
                text = "" if cell is None else str(cell)
                if len(text) > width - 2:
                    text = text[: width - 5] + "..."
                cells.append(text.ljust(width))
            lines.append(" | ".join(cells))

        if total > len(first_page):
            lines.append(
                f"\n... and {total - len(first_page)} more rows (use interactive display to see all)"
            )

        formatted = {
            "success": True,
            "message": result.message or "Large result set query executed successfully",
            "results_table": "\n".join(lines),
            "summary": {
                "total_rows": total,
                "sample_rows_shown": len(first_page),
                "columns": headers,
                "execution_time_ms": elapsed_ms,
                "is_paginated": True,
                "note": "This is a large result set. Full results available in interactive display.",
            },
        }

        # Attach the fetched page for programmatic access when it has rows
        if first_page:
            formatted["raw_data"] = {
                "columns": headers,
                "sample_rows": first_page,
                "total_row_count": total,
            }

        return formatted

    except Exception as e:
        logging.error(f"Error formatting paginated results for agent: {e}")
        return {
            "success": True,
            "message": result.message or "Large result set query executed successfully",
            "results_table": f"Large result set with {total} rows available.\nError fetching sample: {str(e)}",
            "summary": {
                "total_rows": total,
                "columns": headers,
                "execution_time_ms": elapsed_ms,
                "is_paginated": True,
                "error": str(e),
            },
        }
# Registration entry for the run-sql command: binds the handler, its
# parameter schema, the TUI aliases and the agent output formatter.
DEFINITION = CommandDefinition(
    name="run-sql",
    description="Execute a SQL query on a Databricks SQL warehouse.",
    handler=handle_command,
    parameters={
        "warehouse_id": {
            "type": "string",
            "description": "ID of the warehouse to run the query on.",
        },
        "query": {
            "type": "string",
            "description": "SQL query to execute.",
        },
        "catalog": {
            "type": "string",
            "description": "Optional catalog name to use for the query.",
        },
        "wait_timeout": {
            "type": "string",
            "description": "How long to wait for query completion (e.g., '30s', '1m').",
            "default": "30s",
        },
    },
    # Only query is required; handle_command falls back to the configured
    # warehouse when warehouse_id is omitted.
    required_params=["query"],
    tui_aliases=["/run-sql", "/sql"],
    needs_api_client=True,
    visible_to_user=True,
    visible_to_agent=True,
    agent_display="full",
    condensed_action="Running sql",
    # Agent-facing output is rendered as a text table by this formatter
    output_formatter=format_sql_results_for_agent,
    usage_hint='Usage: /run-sql --query "SELECT * FROM my_table" [--warehouse_id <warehouse_id>] [--catalog <catalog>]\n(Uses active warehouse and catalog if not specified)',
)