Coverage for src/commands/stitch_tools.py: 42%

191 statements  

coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

1""" 

2Stitch integration helper functions for command handlers. 

3 

4This module contains utilities for setting up Stitch integration 

5with Databricks catalogs and schemas. 

6""" 

7 

8import logging 

9import json 

10import datetime 

11from typing import Dict, Any 

12 

13from src.clients.databricks import DatabricksAPIClient 

14from src.llm.client import LLMClient 

15from src.config import get_amperity_token 

16from .pii_tools import _helper_scan_schema_for_pii_logic 

17from .cluster_init_tools import _helper_upload_cluster_init_logic 

18 

19UNSUPPORTED_TYPES = [ 

20 "INTERVAL", 

21 "VOID", 

22 "ARRAY", 

23 "MAP", 

24 "STRUCT", 

25 "VARIANT", 

26 "OBJECT", 

27 "GEOGRAPHY", 

28 "GEOMETRY", 

29] 
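# These appear to be Databricks SQL column types that the Stitch config format
# does not accept; columns whose scanned type matches one of these entries are
# excluded from the generated config and reported back via "unsupported_columns"
# (for example, a column whose PII scan reports type "ARRAY" would be skipped).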



def _helper_setup_stitch_logic(
    client: DatabricksAPIClient,
    llm_client_instance: LLMClient,
    target_catalog: str,
    target_schema: str,
) -> Dict[str, Any]:
    """Legacy function for backward compatibility. Calls prepare phase only.

    IMPORTANT: This has been modified to only run the preparation phase and not
    automatically launch the job, which is now handled by the interactive flow.
    """
    # Phase 1: Prepare config only
    prep_result = _helper_prepare_stitch_config(
        client, llm_client_instance, target_catalog, target_schema
    )
    if prep_result.get("error"):
        return prep_result

    # Return the prepared config for further processing
    # No longer automatically launching the job
    return prep_result
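# Illustrative call sequence (a sketch only; "client", "llm", and the
# catalog/schema names below are placeholders, and the interactive flow that
# actually drives these phases lives outside this module):
#
#     prep = _helper_prepare_stitch_config(client, llm, "main", "crm")
#     if not prep.get("error"):
#         cfg, meta = prep["stitch_config"], prep["metadata"]
#         # Optional Phase 2 edit driven by a user request:
#         # cfg = _helper_modify_stitch_config(cfg, "remove the orders table", llm, meta)["stitch_config"]
#         result = _helper_launch_stitch_job(client, cfg, meta)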



def _helper_prepare_stitch_config(
    client: DatabricksAPIClient,
    llm_client_instance: LLMClient,
    target_catalog: str,
    target_schema: str,
) -> Dict[str, Any]:
    """Phase 1: Prepare Stitch configuration without launching job."""
    if not target_catalog or not target_schema:
        return {"error": "Target catalog and schema are required for Stitch setup."}

    # Step 1: Scan for PII data (using the helper for this logic)
    pii_scan_output = _helper_scan_schema_for_pii_logic(
        client, llm_client_instance, target_catalog, target_schema
    )
    if pii_scan_output.get("error"):
        return {
            "error": f"PII Scan failed during Stitch setup: {pii_scan_output['error']}"
        }

    # Step 2: Check/Create "chuck" volume
    volume_name = "chuck"
    volume_exists = False

    # Check if volume exists - direct API call
    try:
        volumes_response = client.list_volumes(
            catalog_name=target_catalog, schema_name=target_schema
        )
        for volume_info in volumes_response.get("volumes", []):
            if volume_info.get("name") == volume_name:
                volume_exists = True
                break
    except Exception as e:
        return {"error": f"Failed to list volumes: {str(e)}"}

    if not volume_exists:
        logging.debug(
            f"Volume '{volume_name}' not found in {target_catalog}.{target_schema}. Attempting to create."
        )
        try:
            # Direct API call to create volume
            volume_response = client.create_volume(
                catalog_name=target_catalog, schema_name=target_schema, name=volume_name
            )
            if not volume_response:
                return {"error": f"Failed to create volume '{volume_name}'"}
            logging.debug(f"Volume '{volume_name}' created successfully.")
        except Exception as e:
            return {"error": f"Failed to create volume '{volume_name}': {str(e)}"}

    # Step 3: Generate Stitch configuration
    current_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    stitch_job_name = f"stitch-{current_datetime}"
    stitch_config = {
        "name": stitch_job_name,
        "tables": [],
        "settings": {
            "output_catalog_name": target_catalog,
            "output_schema_name": "stitch_outputs",
        },
    }
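    # For illustration only (hypothetical table and column names), a populated
    # config produced by the loop below might look roughly like:
    #
    #     {
    #         "name": "stitch-2025-01-01_12-00",
    #         "tables": [
    #             {
    #                 "path": "main.crm.customers",
    #                 "fields": [
    #                     {"field-name": "email", "type": "STRING", "semantics": ["email"]}
    #                 ],
    #             }
    #         ],
    #         "settings": {
    #             "output_catalog_name": "main",
    #             "output_schema_name": "stitch_outputs",
    #         },
    #     }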


    # Track unsupported columns for user feedback
    unsupported_columns = []

    for table_pii_data in pii_scan_output.get("results_detail", []):
        if (
            table_pii_data.get("error")
            or table_pii_data.get("skipped")
            or not table_pii_data.get("has_pii")
        ):
            continue  # Only include successfully scanned tables with PII

        table_cfg = {"path": table_pii_data["full_name"], "fields": []}
        table_unsupported = []

        for col_data in table_pii_data.get("columns", []):
            if col_data["type"] not in UNSUPPORTED_TYPES:
                field_cfg = {
                    "field-name": col_data["name"],
                    "type": col_data["type"],
                    "semantics": [],
                }
                if col_data.get("semantic"):  # Only add non-null/empty semantics
                    field_cfg["semantics"].append(col_data["semantic"])
                table_cfg["fields"].append(field_cfg)
            else:
                # Track unsupported column
                table_unsupported.append(
                    {
                        "column": col_data["name"],
                        "type": col_data["type"],
                        "semantic": col_data.get("semantic"),
                    }
                )

        # Add unsupported columns for this table if any
        if table_unsupported:
            unsupported_columns.append(
                {"table": table_pii_data["full_name"], "columns": table_unsupported}
            )

        # Only add table if it has at least one supported field
        if table_cfg["fields"]:
            stitch_config["tables"].append(table_cfg)

    if not stitch_config["tables"]:
        return {
            "error": "No tables with PII found to include in Stitch configuration.",
            "pii_scan_output": pii_scan_output,
        }

    # Step 4: Prepare file paths and get Amperity token
    config_file_path = f"/Volumes/{target_catalog}/{target_schema}/{volume_name}/{stitch_job_name}.json"
    init_script_volume_path = (
        f"/Volumes/{target_catalog}/{target_schema}/{volume_name}/cluster_init.sh"
    )

    amperity_token = get_amperity_token()
    if not amperity_token:
        return {"error": "Amperity token not found. Please run /amp_login first."}

    # Fetch init script content but don't write it yet
    try:
        init_script_data = client.fetch_amperity_job_init(amperity_token)
        init_script_content = init_script_data.get("cluster-init")
        if not init_script_content:
            return {"error": "Failed to get cluster init script from Amperity API."}
    except Exception as e_fetch_init:
        logging.error(
            f"Error fetching Amperity init script: {e_fetch_init}", exc_info=True
        )
        return {"error": f"Error fetching Amperity init script: {str(e_fetch_init)}"}

    # Upload cluster init script with versioning
    upload_result = _helper_upload_cluster_init_logic(
        client=client,
        target_catalog=target_catalog,
        target_schema=target_schema,
        init_script_content=init_script_content,
    )
    if upload_result.get("error"):
        return upload_result

    # Use the versioned init script path
    init_script_volume_path = upload_result["volume_path"]
    logging.debug(
        f"Versioned cluster init script uploaded to {init_script_volume_path}"
    )

    return {
        "success": True,
        "stitch_config": stitch_config,
        "metadata": {
            "target_catalog": target_catalog,
            "target_schema": target_schema,
            "volume_name": volume_name,
            "stitch_job_name": stitch_job_name,
            "config_file_path": config_file_path,
            "init_script_path": init_script_volume_path,
            "init_script_content": init_script_content,
            "amperity_token": amperity_token,
            "pii_scan_output": pii_scan_output,
            "unsupported_columns": unsupported_columns,
        },
    }
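# Note: the "metadata" dictionary returned above is the same object that
# _helper_launch_stitch_job (Phase 3 below) expects as its "metadata" argument.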



def _helper_modify_stitch_config(
    current_config: Dict[str, Any],
    modification_request: str,
    llm_client_instance: LLMClient,
    metadata: Dict[str, Any],
) -> Dict[str, Any]:
    """Phase 2: Modify Stitch configuration based on user request using LLM."""
    try:
        # Create a prompt for the LLM to modify the config
        prompt = f"""You are helping modify a Stitch integration configuration based on user requests.

Current configuration:
{json.dumps(current_config, indent=2)}

User modification request: "{modification_request}"

Please modify the configuration according to the user's request and return ONLY the updated JSON configuration.
Ensure the JSON is valid and follows the same structure.

Important rules:
- Keep the same overall structure (name, tables, settings)
- Each table should have "path" and "fields" arrays
- Each field should have "field-name", "type", and "semantics" arrays
- Only include tables and fields that make sense based on the original PII scan data
- If removing tables/fields, just omit them from the output
- If adding semantics, use standard PII types like "email", "name", "phone", "ssn", etc.
"""


        # Call LLM to get modified config
        llm_response = llm_client_instance.chat(
            messages=[{"role": "user", "content": prompt}]
        )

        if not llm_response or not llm_response.choices:
            return {"error": "Failed to get LLM response for config modification"}

        # Parse the LLM response as JSON
        try:
            response_text = llm_response.choices[0].message.content
            if not response_text or not isinstance(response_text, str):
                return {"error": "LLM returned invalid response format"}

            # Clean up response text (remove code blocks if present)
            response_text = response_text.strip()
            if response_text.startswith("```json"):
                response_text = response_text[7:-3].strip()
            elif response_text.startswith("```"):
                response_text = response_text[3:-3].strip()

            modified_config = json.loads(response_text)
        except json.JSONDecodeError as e:
            return {"error": f"LLM returned invalid JSON: {str(e)}"}

        # Basic validation of the modified config
        if not isinstance(modified_config, dict):
            return {"error": "Modified config must be a JSON object"}

        required_keys = ["name", "tables", "settings"]
        for key in required_keys:
            if key not in modified_config:
                return {"error": f"Modified config missing required key: {key}"}

        if not isinstance(modified_config["tables"], list):
            return {"error": "Modified config 'tables' must be an array"}

        # Validate each table structure
        for table in modified_config["tables"]:
            if (
                not isinstance(table, dict)
                or "path" not in table
                or "fields" not in table
            ):
                return {"error": "Each table must have 'path' and 'fields' properties"}

            if not isinstance(table["fields"], list):
                return {"error": "Table 'fields' must be an array"}

            for field in table["fields"]:
                if not isinstance(field, dict):
                    return {"error": "Each field must be an object"}
                required_field_keys = ["field-name", "type", "semantics"]
                for fkey in required_field_keys:
                    if fkey not in field:
                        return {"error": f"Field missing required key: {fkey}"}

        return {
            "success": True,
            "stitch_config": modified_config,
            "modification_summary": f"Configuration modified based on request: {modification_request}",
        }

    except Exception as e:
        logging.error(f"Error modifying Stitch config: {e}", exc_info=True)
        return {"error": f"Error modifying configuration: {str(e)}"}



def _create_stitch_report_notebook(
    client: DatabricksAPIClient,
    stitch_config: Dict[str, Any],
    target_catalog: str,
    target_schema: str,
    stitch_job_name: str,
) -> Dict[str, Any]:
    """Helper function to create a Stitch report notebook automatically.

    This uses the DatabricksAPIClient.create_stitch_notebook method but with datasources
    extracted from the stitch_config tables' paths and a table_path constructed from the
    target catalog and schema.

    Args:
        client: DatabricksAPIClient instance
        stitch_config: The Stitch configuration dictionary
        target_catalog: Target catalog name
        target_schema: Target schema name
        stitch_job_name: Name of the Stitch job (accepted for context; not
            currently used when constructing the notebook name)

    Returns:
        Dictionary with success/error status and notebook path if successful
    """
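    # For illustration (placeholder names): with target_catalog="main" and
    # target_schema="crm", the notebook reads from the table
    # "main.stitch_outputs.unified_coalesced" and is named "Stitch Report: main.crm".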

    try:
        # Construct table path in the required format
        table_path = f"{target_catalog}.stitch_outputs.unified_coalesced"

        # Construct a descriptive notebook name
        notebook_name = f"Stitch Report: {target_catalog}.{target_schema}"

        # Call the create_stitch_notebook method with our parameters
        try:
            result = client.create_stitch_notebook(
                table_path=table_path,
                notebook_name=notebook_name,
                stitch_config=stitch_config,
            )

            # If we get here, the notebook was created successfully, even if result doesn't have notebook_path
            notebook_path = result.get(
                "notebook_path", f"/Workspace/Users/unknown/{notebook_name}"
            )

            return {
                "success": True,
                "notebook_path": notebook_path,
                "message": f"Successfully created Stitch report notebook at {notebook_path}",
            }
        except Exception as e:
            # Only return an error if there was an actual exception
            return {"success": False, "error": str(e)}
    except Exception as e:
        logging.error(f"Error creating Stitch report notebook: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e)}



def _helper_launch_stitch_job(
    client: DatabricksAPIClient, stitch_config: Dict[str, Any], metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Phase 3: Write final config and launch Stitch job."""
    try:
        # Extract metadata
        target_catalog = metadata["target_catalog"]
        target_schema = metadata["target_schema"]
        stitch_job_name = metadata["stitch_job_name"]
        config_file_path = metadata["config_file_path"]
        init_script_path = metadata["init_script_path"]
        init_script_content = metadata["init_script_content"]
        pii_scan_output = metadata["pii_scan_output"]
        unsupported_columns = metadata["unsupported_columns"]

        # Write final config file to volume
        config_content_json = json.dumps(stitch_config, indent=2)
        try:
            upload_success = client.upload_file(
                path=config_file_path, content=config_content_json, overwrite=True
            )
            if not upload_success:
                return {
                    "error": f"Failed to write Stitch config to '{config_file_path}'"
                }
            logging.debug(f"Stitch config written to {config_file_path}")
        except Exception as e:
            return {
                "error": f"Failed to write Stitch config '{config_file_path}': {str(e)}"
            }

        # Write init script to volume
        try:
            upload_init_success = client.upload_file(
                path=init_script_path, content=init_script_content, overwrite=True
            )
            if not upload_init_success:
                return {"error": f"Failed to write init script to '{init_script_path}'"}
            logging.debug(f"Cluster init script written to {init_script_path}")
        except Exception as e:
            return {
                "error": f"Failed to write init script '{init_script_path}': {str(e)}"
            }

        # Launch the Stitch job
        try:
            job_run_data = client.submit_job_run(
                config_path=config_file_path,
                init_script_path=init_script_path,
                run_name=f"Stitch Setup: {stitch_job_name}",
            )
            run_id = job_run_data.get("run_id")
            if not run_id:
                return {"error": "Failed to launch job (no run_id returned)"}
        except Exception as e:
            return {"error": f"Failed to launch Stitch job: {str(e)}"}

        # Build success message
        summary_msg_lines = [
            f"Stitch setup for {target_catalog}.{target_schema} initiated."
        ]
        summary_msg_lines.append(f"Config: {config_file_path}")
        summary_msg_lines.append(f"Databricks Job Run ID: {run_id}")

        # Add unsupported columns information if any
        if unsupported_columns:
            summary_msg_lines.append("")
            summary_msg_lines.append(
                "Note: Some columns were excluded due to unsupported data types:"
            )
            for table_info in unsupported_columns:
                summary_msg_lines.append(f" Table: {table_info['table']}")
                for col_info in table_info["columns"]:
                    semantic_info = (
                        f" (semantic: {col_info['semantic']})"
                        if col_info["semantic"]
                        else ""
                    )
                    summary_msg_lines.append(
                        f" - {col_info['column']} ({col_info['type']}){semantic_info}"
                    )

        # Automatically create stitch report notebook
        notebook_result = _create_stitch_report_notebook(
            client=client,
            stitch_config=stitch_config,
            target_catalog=target_catalog,
            target_schema=target_schema,
            stitch_job_name=stitch_job_name,
        )

        # Add notebook creation information to the summary
        if notebook_result.get("success"):
            summary_msg_lines.append("\nCreated Stitch Report notebook:")
            summary_msg_lines.append(
                f"Notebook Path: {notebook_result.get('notebook_path', 'Unknown')}"
            )
        else:
            # If notebook creation failed, log the error but don't fail the overall job
            error_msg = notebook_result.get("error", "Unknown error")
            summary_msg_lines.append(
                f"\nNote: Could not create Stitch Report notebook: {error_msg}"
            )
            logging.warning(f"Failed to create Stitch Report notebook: {error_msg}")

        final_summary = "\n".join(summary_msg_lines)
        return {
            "success": True,
            "message": final_summary,
            "stitch_job_name": stitch_job_name,
            "run_id": run_id,
            "config_path": config_file_path,
            "init_script_path": init_script_path,
            "pii_scan_summary": pii_scan_output.get("message", "PII scan performed."),
            "unsupported_columns": unsupported_columns,
            "notebook_result": (
                notebook_result if "notebook_result" in locals() else None
            ),
        }

    except Exception as e:
        logging.error(f"Error launching Stitch job: {e}", exc_info=True)
        return {"error": f"Error launching Stitch job: {str(e)}"}