Coverage for src/chuck_data/commands/run_sql.py: 0%

164 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

1""" 

2Command for executing SQL queries on a Databricks warehouse. 

3""" 

4 

5from typing import Optional, Any, Dict 

6from ..clients.databricks import DatabricksAPIClient 

7from ..command_registry import CommandDefinition 

8from ..commands.base import CommandResult 

9from ..config import get_warehouse_id, get_active_catalog 

10import logging 

11 

12 

def handle_command(
    client: Optional[DatabricksAPIClient], **kwargs: Any
) -> CommandResult:
    """
    Execute a SQL query on a Databricks SQL warehouse.

    Args:
        client: DatabricksAPIClient instance for API calls
        **kwargs: Command parameters
            - warehouse_id: ID of the warehouse to run the query on
              (falls back to the configured warehouse)
            - query: SQL query to execute
            - catalog: Optional catalog name to use
              (falls back to the active catalog)
            - wait_timeout: How long to wait for query completion (default "30s")

    Returns:
        CommandResult. On success ``data`` is a dict with ``columns``,
        ``execution_time_ms`` and ``is_paginated``; inline results add
        ``rows``/``row_count``, while results delivered through external
        links add ``external_links``/``manifest``/``total_row_count``/``chunks``.
        Any exception raised while submitting or decoding the statement is
        logged and returned as a failed CommandResult instead of propagating.
    """
    if not client:
        return CommandResult(
            False,
            message="No Databricks client available. Please set up your workspace first.",
        )

    # Explicit argument wins; otherwise fall back to the configured warehouse.
    warehouse_id = kwargs.get("warehouse_id") or get_warehouse_id()
    if not warehouse_id:
        return CommandResult(
            False,
            message="No warehouse ID specified and no active warehouse selected. Please provide a warehouse_id or select a warehouse first using /select-warehouse.",
        )

    # Fail fast instead of sending an empty statement to the warehouse.
    query = kwargs.get("query")
    if not query:
        return CommandResult(
            False,
            message="No SQL query provided. Please provide a query to execute.",
        )

    # It's fine if catalog ends up None, the API will handle it.
    catalog = kwargs.get("catalog") or get_active_catalog()
    wait_timeout = kwargs.get("wait_timeout", "30s")

    try:
        result = client.submit_sql_statement(
            sql_text=query,
            warehouse_id=warehouse_id,
            catalog=catalog,
            wait_timeout=wait_timeout,
        )

        # The state is read from result.status.state, else from result.state.
        state = result.get("status", {}).get("state", result.get("state", ""))

        if state == "FAILED":
            error_message = (
                result.get("status", {})
                .get("error", {})
                .get("message", result.get("error", {}).get("message", "Unknown error"))
            )
            return CommandResult(
                False, message=f"Query execution failed: {error_message}"
            )
        if state == "CANCELED":
            return CommandResult(False, message="Query execution was canceled.")
        if state != "SUCCEEDED":
            return CommandResult(
                False,
                message=f"Query did not complete successfully. Final state: {state}",
            )

        # `or` guards against keys that are present but null.
        result_data = result.get("result") or {}
        manifest = result.get("manifest") or {}
        data_array = result_data.get("data_array") or []
        external_links = result_data.get("external_links") or []

        column_infos, schema_location = _find_column_infos(
            result, result_data, manifest
        )
        columns = []
        if column_infos:
            columns = [
                col.get("name") for col in column_infos if isinstance(col, dict)
            ]
            logging.debug(
                "Found column schema at %s: %s", schema_location, columns
            )

        if external_links:
            # Large result set - rows are fetched later through external links.
            total_row_count = manifest.get("total_row_count", 0)
            chunks = manifest.get("chunks", [])

            logging.info(
                "Large SQL result set detected: %s total rows, %s chunks",
                total_row_count,
                len(external_links),
            )

            # No names available: fall back to generic ones sized by the schema.
            if not columns:
                column_count = (
                    _nested_dict(manifest, "schema").get("column_count") or 0
                )
                if column_count > 0:
                    columns = [f"column_{i+1}" for i in range(column_count)]
                    logging.warning(
                        "No column names found, generated %s generic column names",
                        len(columns),
                    )

            formatted_results = {
                "columns": columns,
                "external_links": external_links,
                "manifest": manifest,
                "total_row_count": total_row_count,
                "chunks": chunks,
                "execution_time_ms": result.get("execution_time_ms"),
                "is_paginated": True,
            }
            # data_array is empty for external-link results, so report the
            # manifest's row count rather than len(data_array).
            result_count = total_row_count
        else:
            # Small result set - rows are delivered inline in data_array.
            if not columns and data_array:
                columns = [f"column_{i+1}" for i in range(len(data_array[0]))]
                logging.warning(
                    "No column schema found in SQL result, generated %s generic column names",
                    len(columns),
                )

            formatted_results = {
                "columns": columns,
                "rows": data_array,
                "row_count": len(data_array),
                "execution_time_ms": result.get("execution_time_ms"),
                "is_paginated": False,
            }
            result_count = len(data_array)

        return CommandResult(
            True,
            data=formatted_results,
            message=f"Query executed successfully with {result_count} result(s).",
        )
    except Exception as e:
        # Command boundary: report the failure to the caller instead of raising.
        logging.error("Error executing SQL query: %s", e)
        return CommandResult(
            False, message=f"Failed to execute SQL query: {str(e)}", error=e
        )


def _nested_dict(container: Any, *keys: str) -> Dict[str, Any]:
    """Follow ``keys`` through nested dicts; return {} if a level is missing or not a dict."""
    node = container
    for key in keys:
        if not isinstance(node, dict):
            return {}
        node = node.get(key)
    return node if isinstance(node, dict) else {}


def _find_column_infos(
    result: Dict[str, Any], result_data: Dict[str, Any], manifest: Dict[str, Any]
) -> tuple:
    """Return ``(column_infos, location)`` from the first place that holds a column schema.

    Locations are probed in a fixed priority order. Each probe is type-safe,
    so a ``schema`` value that is a plain list reaches the "direct list" probe
    instead of raising on ``.get``. Returns ``([], None)`` when nothing is found.
    """
    direct_schema = result_data.get("schema")
    candidates = (
        (
            "result.manifest.schema.columns",
            _nested_dict(manifest, "schema").get("columns"),
        ),
        (
            "result_data.schema.columns",
            _nested_dict(result_data, "schema").get("columns"),
        ),
        (
            "result_data.manifest.schema.columns",
            _nested_dict(result_data, "manifest", "schema").get("columns"),
        ),
        (
            "result_data.schema (direct list)",
            direct_schema if isinstance(direct_schema, list) else None,
        ),
        (
            "result.schema.columns",
            _nested_dict(result, "schema").get("columns"),
        ),
    )
    for location, column_infos in candidates:
        if column_infos:
            return column_infos, location
    return [], None

183 

184 

def format_sql_results_for_agent(result: CommandResult) -> Dict[str, Any]:
    """
    Custom formatter for SQL results that displays them in a table format for the agent.

    Args:
        result: CommandResult containing SQL query results

    Returns:
        Dictionary with formatted results for agent consumption:
        - ``{"error": ...}`` when the command failed;
        - a "No data returned" payload when ``result.data`` is empty;
        - the paginated formatter's output when ``data["is_paginated"]`` is set;
        - otherwise ``success``/``message``/``results_table``/``summary`` and,
          for result sets of at most 50 rows, ``raw_data``.
    """
    preview_limit = 10  # rows rendered in the text table
    max_col_width = 25  # cap on rendered column width, for readability
    raw_data_limit = 50  # largest result set echoed back verbatim

    if not result.success:
        return {"error": result.message or "SQL query failed"}

    if not result.data:
        return {
            "success": True,
            "message": result.message or "Query completed successfully",
            "results": "No data returned",
        }

    # Check if this is a paginated result set
    if result.data.get("is_paginated", False):
        return _format_paginated_results_for_agent(result)

    # `or []` also covers keys that are present but set to None.
    columns = result.data.get("columns") or []
    rows = result.data.get("rows") or []
    execution_time = result.data.get("execution_time_ms")

    # Trust an explicit row_count; otherwise count the rows actually present.
    row_count = result.data.get("row_count")
    if row_count is None:
        row_count = len(rows)

    # No schema but data present: name the columns generically from row one.
    if not columns and rows:
        first_row = rows[0] if isinstance(rows[0], list) else []
        columns = [f"column_{i+1}" for i in range(len(first_row))]

    def cell_text(value: Any) -> str:
        # None is rendered as an empty cell rather than the string "None".
        return "" if value is None else str(value)

    preview_rows = rows[:preview_limit]

    # Size each column from its header and the previewed values.
    col_widths = []
    for index, col in enumerate(columns):
        widest = len(str(col))
        for row in preview_rows:
            if isinstance(row, list) and index < len(row):
                widest = max(widest, len(cell_text(row[index])))
        col_widths.append(min(widest + 2, max_col_width))

    table_lines = []

    if columns:
        header = " | ".join(
            str(col).ljust(col_widths[i]) for i, col in enumerate(columns)
        )
        table_lines.append(header)
        table_lines.append("-" * len(header))

    for row in preview_rows:
        if not isinstance(row, list):
            continue  # only list rows are rendered
        formatted_cells = []
        # Cells beyond the known columns are dropped.
        for i, val in enumerate(row[: len(columns)]):
            val_str = cell_text(val)
            # Truncate if too long, leaving room for the ellipsis and padding.
            if len(val_str) > col_widths[i] - 2:
                val_str = val_str[: col_widths[i] - 5] + "..."
            formatted_cells.append(val_str.ljust(col_widths[i]))
        table_lines.append(" | ".join(formatted_cells))

    if len(rows) > preview_limit:
        table_lines.append(f"\n... and {len(rows) - preview_limit} more rows")

    response = {
        "success": True,
        "message": result.message or "Query executed successfully",
        "results_table": "\n".join(table_lines),
        "summary": {
            "total_rows": row_count,
            "columns": columns,
            "execution_time_ms": execution_time,
        },
    }

    # Also include raw data for programmatic access, for smaller result sets only.
    if len(rows) <= raw_data_limit:
        response["raw_data"] = {"columns": columns, "rows": rows}

    return response

280 

281 

def _format_paginated_results_for_agent(result: CommandResult) -> Dict[str, Any]:
    """
    Format paginated SQL results for agent consumption.

    For paginated results, we fetch the first page to show a sample,
    but inform the agent about the full result set size.

    Args:
        result: Successful CommandResult whose ``data`` carries ``columns``,
            ``external_links``, ``total_row_count``, ``chunks`` and
            ``execution_time_ms``.

    Returns:
        Dictionary with ``success``, ``message``, ``results_table`` and
        ``summary``; ``raw_data`` (the fetched sample) is added when the first
        page is non-empty. If fetching or rendering the sample fails, the error
        is logged and a reduced payload carrying the error text is returned.
    """
    # Imported here rather than at module import time.
    from ..commands.sql_external_data import PaginatedSQLResult

    preview_limit = 10  # rows rendered in the text table
    max_col_width = 25  # cap on rendered column width, for readability

    data = result.data
    columns = data.get("columns", [])
    external_links = data.get("external_links", [])
    total_row_count = data.get("total_row_count", 0)
    chunks = data.get("chunks", [])
    execution_time = data.get("execution_time_ms")

    try:
        paginated_result = PaginatedSQLResult(
            columns=columns,
            external_links=external_links,
            total_row_count=total_row_count,
            chunks=chunks,
        )

        # Fetch first page as a sample; the "has more" flag is not needed here.
        sample_rows, _ = paginated_result.get_next_page()
        display_rows = sample_rows[:preview_limit]

        table_lines = []
        col_widths = []

        if columns:
            # Size each column from its header and the displayed values.
            for i, col in enumerate(columns):
                max_width = len(str(col))
                for row in display_rows:
                    if isinstance(row, list) and i < len(row):
                        val_width = len(str(row[i] if row[i] is not None else ""))
                        max_width = max(max_width, val_width)
                col_widths.append(min(max_width + 2, max_col_width))

            header = " | ".join(
                str(col).ljust(col_widths[i]) for i, col in enumerate(columns)
            )
            table_lines.append(header)
            table_lines.append("-" * len(header))

        for row in display_rows:
            if isinstance(row, list):
                formatted_cells = []
                for i, val in enumerate(row[: len(columns)]):
                    val_str = str(val if val is not None else "")
                    # Truncate if too long, leaving room for the ellipsis.
                    if len(val_str) > col_widths[i] - 2:
                        val_str = val_str[: col_widths[i] - 5] + "..."
                    formatted_cells.append(val_str.ljust(col_widths[i]))
                table_lines.append(" | ".join(formatted_cells))

        # Count against the rows actually rendered above, not the whole
        # fetched page, so table rows plus the trailer add up to the total.
        hidden_row_count = total_row_count - len(display_rows)
        if hidden_row_count > 0:
            table_lines.append(
                f"\n... and {hidden_row_count} more rows (use interactive display to see all)"
            )

        table_output = "\n".join(table_lines)

        response = {
            "success": True,
            "message": result.message or "Large result set query executed successfully",
            "results_table": table_output,
            "summary": {
                "total_rows": total_row_count,
                # Size of the fetched sample (also exposed via raw_data below).
                "sample_rows_shown": len(sample_rows),
                "columns": columns,
                "execution_time_ms": execution_time,
                "is_paginated": True,
                "note": "This is a large result set. Full results available in interactive display.",
            },
        }

        # Include sample data for programmatic access
        if sample_rows:
            response["raw_data"] = {
                "columns": columns,
                "sample_rows": sample_rows,
                "total_row_count": total_row_count,
            }

        return response

    except Exception as e:
        # Best effort: the query itself succeeded, so still report success and
        # surface the sampling error alongside the known metadata.
        logging.error("Error formatting paginated results for agent: %s", e)
        return {
            "success": True,
            "message": result.message or "Large result set query executed successfully",
            "results_table": f"Large result set with {total_row_count} rows available.\nError fetching sample: {str(e)}",
            "summary": {
                "total_rows": total_row_count,
                "columns": columns,
                "execution_time_ms": execution_time,
                "is_paginated": True,
                "error": str(e),
            },
        }

389 

390 

# JSON-schema style description of the arguments accepted by handle_command.
_RUN_SQL_PARAMETERS = {
    "warehouse_id": {
        "type": "string",
        "description": "ID of the warehouse to run the query on.",
    },
    "query": {
        "type": "string",
        "description": "SQL query to execute.",
    },
    "catalog": {
        "type": "string",
        "description": "Optional catalog name to use for the query.",
    },
    "wait_timeout": {
        "type": "string",
        "description": "How long to wait for query completion (e.g., '30s', '1m').",
        "default": "30s",
    },
}

# Help text for the command; adjacent literals concatenate into one string.
_RUN_SQL_USAGE_HINT = (
    'Usage: /run-sql --query "SELECT * FROM my_table" '
    "[--warehouse_id <warehouse_id>] [--catalog <catalog>]\n"
    "(Uses active warehouse and catalog if not specified)"
)

# Registry entry for the run-sql command: available to both the user (under
# the listed TUI aliases) and the agent, with agent output rendered by
# format_sql_results_for_agent.
DEFINITION = CommandDefinition(
    name="run-sql",
    description="Execute a SQL query on a Databricks SQL warehouse.",
    handler=handle_command,
    parameters=_RUN_SQL_PARAMETERS,
    # Only query is required; warehouse_id can come from config
    required_params=["query"],
    tui_aliases=["/run-sql", "/sql"],
    needs_api_client=True,
    visible_to_user=True,
    visible_to_agent=True,
    agent_display="full",
    condensed_action="Running sql",
    output_formatter=format_sql_results_for_agent,
    usage_hint=_RUN_SQL_USAGE_HINT,
)