Coverage for src/commands/stitch_tools.py: 42%
191 statements
coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

1"""
2Stitch integration helper functions for command handlers.
4This module contains utilities for setting up Stitch integration
5with Databricks catalogs and schemas.
6"""

import logging
import json
import datetime
from typing import Dict, Any

from src.clients.databricks import DatabricksAPIClient
from src.llm.client import LLMClient
from src.config import get_amperity_token
from .pii_tools import _helper_scan_schema_for_pii_logic
from .cluster_init_tools import _helper_upload_cluster_init_logic

UNSUPPORTED_TYPES = [
    "INTERVAL",
    "VOID",
    "ARRAY",
    "MAP",
    "STRUCT",
    "VARIANT",
    "OBJECT",
    "GEOGRAPHY",
    "GEOMETRY",
]
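
# Note: columns whose Databricks type appears in UNSUPPORTED_TYPES are dropped
# when the Stitch config is assembled in _helper_prepare_stitch_config below,
# and are reported back to the user as "unsupported_columns".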


def _helper_setup_stitch_logic(
    client: DatabricksAPIClient,
    llm_client_instance: LLMClient,
    target_catalog: str,
    target_schema: str,
) -> Dict[str, Any]:
    """Legacy function kept for backward compatibility; runs the prepare phase only.

    IMPORTANT: This has been modified to run only the preparation phase and not
    to launch the job automatically; launching is now handled by the interactive flow.
    """
    # Phase 1: Prepare config only
    prep_result = _helper_prepare_stitch_config(
        client, llm_client_instance, target_catalog, target_schema
    )
    if prep_result.get("error"):
        return prep_result

    # Return the prepared config for further processing;
    # the job is no longer launched automatically.
    return prep_result
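
# The interactive flow referenced above is split into three phases implemented
# in this module: _helper_prepare_stitch_config (Phase 1),
# _helper_modify_stitch_config (Phase 2), and _helper_launch_stitch_job (Phase 3).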


def _helper_prepare_stitch_config(
    client: DatabricksAPIClient,
    llm_client_instance: LLMClient,
    target_catalog: str,
    target_schema: str,
) -> Dict[str, Any]:
    """Phase 1: Prepare Stitch configuration without launching job."""
    if not target_catalog or not target_schema:
        return {"error": "Target catalog and schema are required for Stitch setup."}

    # Step 1: Scan for PII data (using the helper for this logic)
    pii_scan_output = _helper_scan_schema_for_pii_logic(
        client, llm_client_instance, target_catalog, target_schema
    )
    if pii_scan_output.get("error"):
        return {
            "error": f"PII Scan failed during Stitch setup: {pii_scan_output['error']}"
        }

    # Step 2: Check/Create "chuck" volume
    volume_name = "chuck"
    volume_exists = False

    # Check if volume exists - direct API call
    try:
        volumes_response = client.list_volumes(
            catalog_name=target_catalog, schema_name=target_schema
        )
        for volume_info in volumes_response.get("volumes", []):
            if volume_info.get("name") == volume_name:
                volume_exists = True
                break
    except Exception as e:
        return {"error": f"Failed to list volumes: {str(e)}"}

    if not volume_exists:
        logging.debug(
            f"Volume '{volume_name}' not found in {target_catalog}.{target_schema}. Attempting to create."
        )
        try:
            # Direct API call to create volume
            volume_response = client.create_volume(
                catalog_name=target_catalog, schema_name=target_schema, name=volume_name
            )
            if not volume_response:
                return {"error": f"Failed to create volume '{volume_name}'"}
            logging.debug(f"Volume '{volume_name}' created successfully.")
        except Exception as e:
            return {"error": f"Failed to create volume '{volume_name}': {str(e)}"}

    # Step 3: Generate Stitch configuration
    current_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    stitch_job_name = f"stitch-{current_datetime}"
    stitch_config = {
        "name": stitch_job_name,
        "tables": [],
        "settings": {
            "output_catalog_name": target_catalog,
            "output_schema_name": "stitch_outputs",
        },
    }
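
    # Illustrative shape of the finished config once tables are appended below
    # (catalog/table/column names here are hypothetical):
    # {
    #     "name": "stitch-2025-06-05_22-56",
    #     "tables": [
    #         {
    #             "path": "my_catalog.my_schema.customers",
    #             "fields": [
    #                 {"field-name": "email", "type": "STRING", "semantics": ["email"]}
    #             ],
    #         }
    #     ],
    #     "settings": {
    #         "output_catalog_name": "my_catalog",
    #         "output_schema_name": "stitch_outputs",
    #     },
    # }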

    # Track unsupported columns for user feedback
    unsupported_columns = []

    for table_pii_data in pii_scan_output.get("results_detail", []):
        if (
            table_pii_data.get("error")
            or table_pii_data.get("skipped")
            or not table_pii_data.get("has_pii")
        ):
            continue  # Only include successfully scanned tables with PII

        table_cfg = {"path": table_pii_data["full_name"], "fields": []}
        table_unsupported = []

        for col_data in table_pii_data.get("columns", []):
            if col_data["type"] not in UNSUPPORTED_TYPES:
                field_cfg = {
                    "field-name": col_data["name"],
                    "type": col_data["type"],
                    "semantics": [],
                }
                if col_data.get("semantic"):  # Only add non-null/empty semantics
                    field_cfg["semantics"].append(col_data["semantic"])
                table_cfg["fields"].append(field_cfg)
            else:
                # Track unsupported column
                table_unsupported.append(
                    {
                        "column": col_data["name"],
                        "type": col_data["type"],
                        "semantic": col_data.get("semantic"),
                    }
                )

        # Add unsupported columns for this table if any
        if table_unsupported:
            unsupported_columns.append(
                {"table": table_pii_data["full_name"], "columns": table_unsupported}
            )

        # Only add table if it has at least one supported field
        if table_cfg["fields"]:
            stitch_config["tables"].append(table_cfg)

    if not stitch_config["tables"]:
        return {
            "error": "No tables with PII found to include in Stitch configuration.",
            "pii_scan_output": pii_scan_output,
        }

    # Step 4: Prepare file paths and get Amperity token
    config_file_path = f"/Volumes/{target_catalog}/{target_schema}/{volume_name}/{stitch_job_name}.json"
    init_script_volume_path = (
        f"/Volumes/{target_catalog}/{target_schema}/{volume_name}/cluster_init.sh"
    )
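
    # With a hypothetical catalog "my_catalog" and schema "my_schema", these resolve to:
    #   /Volumes/my_catalog/my_schema/chuck/stitch-2025-06-05_22-56.json
    #   /Volumes/my_catalog/my_schema/chuck/cluster_init.sh
    # The init script path is replaced further down with the versioned path
    # returned by _helper_upload_cluster_init_logic.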

    amperity_token = get_amperity_token()
    if not amperity_token:
        return {"error": "Amperity token not found. Please run /amp_login first."}

    # Fetch init script content but don't write it yet
    try:
        init_script_data = client.fetch_amperity_job_init(amperity_token)
        init_script_content = init_script_data.get("cluster-init")
        if not init_script_content:
            return {"error": "Failed to get cluster init script from Amperity API."}
    except Exception as e_fetch_init:
        logging.error(
            f"Error fetching Amperity init script: {e_fetch_init}", exc_info=True
        )
        return {"error": f"Error fetching Amperity init script: {str(e_fetch_init)}"}

    # Upload cluster init script with versioning
    upload_result = _helper_upload_cluster_init_logic(
        client=client,
        target_catalog=target_catalog,
        target_schema=target_schema,
        init_script_content=init_script_content,
    )
    if upload_result.get("error"):
        return upload_result

    # Use the versioned init script path
    init_script_volume_path = upload_result["volume_path"]
    logging.debug(
        f"Versioned cluster init script uploaded to {init_script_volume_path}"
    )

    return {
        "success": True,
        "stitch_config": stitch_config,
        "metadata": {
            "target_catalog": target_catalog,
            "target_schema": target_schema,
            "volume_name": volume_name,
            "stitch_job_name": stitch_job_name,
            "config_file_path": config_file_path,
            "init_script_path": init_script_volume_path,
            "init_script_content": init_script_content,
            "amperity_token": amperity_token,
            "pii_scan_output": pii_scan_output,
            "unsupported_columns": unsupported_columns,
        },
    }


def _helper_modify_stitch_config(
    current_config: Dict[str, Any],
    modification_request: str,
    llm_client_instance: LLMClient,
    metadata: Dict[str, Any],
) -> Dict[str, Any]:
    """Phase 2: Modify Stitch configuration based on user request using LLM."""
    try:
        # Create a prompt for the LLM to modify the config
        prompt = f"""You are helping modify a Stitch integration configuration based on user requests.

Current configuration:
{json.dumps(current_config, indent=2)}

User modification request: "{modification_request}"

Please modify the configuration according to the user's request and return ONLY the updated JSON configuration.
Ensure the JSON is valid and follows the same structure.

Important rules:
- Keep the same overall structure (name, tables, settings)
- Each table should have "path" and "fields" arrays
- Each field should have "field-name", "type", and "semantics" arrays
- Only include tables and fields that make sense based on the original PII scan data
- If removing tables/fields, just omit them from the output
- If adding semantics, use standard PII types like "email", "name", "phone", "ssn", etc.
"""

        # Call LLM to get modified config
        llm_response = llm_client_instance.chat(
            messages=[{"role": "user", "content": prompt}]
        )

        if not llm_response or not llm_response.choices:
            return {"error": "Failed to get LLM response for config modification"}

        # Parse the LLM response as JSON
        try:
            response_text = llm_response.choices[0].message.content
            if not response_text or not isinstance(response_text, str):
                return {"error": "LLM returned invalid response format"}

            # Clean up response text (remove code blocks if present)
            response_text = response_text.strip()
            if response_text.startswith("```json"):
                response_text = response_text[7:-3].strip()
            elif response_text.startswith("```"):
                response_text = response_text[3:-3].strip()

            modified_config = json.loads(response_text)
        except json.JSONDecodeError as e:
            return {"error": f"LLM returned invalid JSON: {str(e)}"}

        # Basic validation of the modified config
        if not isinstance(modified_config, dict):
            return {"error": "Modified config must be a JSON object"}

        required_keys = ["name", "tables", "settings"]
        for key in required_keys:
            if key not in modified_config:
                return {"error": f"Modified config missing required key: {key}"}

        if not isinstance(modified_config["tables"], list):
            return {"error": "Modified config 'tables' must be an array"}

        # Validate each table structure
        for table in modified_config["tables"]:
            if (
                not isinstance(table, dict)
                or "path" not in table
                or "fields" not in table
            ):
                return {"error": "Each table must have 'path' and 'fields' properties"}

            if not isinstance(table["fields"], list):
                return {"error": "Table 'fields' must be an array"}

            for field in table["fields"]:
                if not isinstance(field, dict):
                    return {"error": "Each field must be an object"}
                required_field_keys = ["field-name", "type", "semantics"]
                for fkey in required_field_keys:
                    if fkey not in field:
                        return {"error": f"Field missing required key: {fkey}"}

        return {
            "success": True,
            "stitch_config": modified_config,
            "modification_summary": f"Configuration modified based on request: {modification_request}",
        }

    except Exception as e:
        logging.error(f"Error modifying Stitch config: {e}", exc_info=True)
        return {"error": f"Error modifying configuration: {str(e)}"}


def _create_stitch_report_notebook(
    client: DatabricksAPIClient,
    stitch_config: Dict[str, Any],
    target_catalog: str,
    target_schema: str,
    stitch_job_name: str,
) -> Dict[str, Any]:
    """Helper function to create a Stitch report notebook automatically.

    This uses the DatabricksAPIClient.create_stitch_notebook method but with datasources
    extracted from the stitch_config tables' paths and a table_path constructed from the
    target catalog and schema.

    Args:
        client: DatabricksAPIClient instance
        stitch_config: The Stitch configuration dictionary
        target_catalog: Target catalog name
        target_schema: Target schema name
        stitch_job_name: Name of the Stitch job (used for notebook naming)

    Returns:
        Dictionary with success/error status and notebook path if successful
    """
    try:
        # Construct table path in the required format
        table_path = f"{target_catalog}.stitch_outputs.unified_coalesced"
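        # e.g. "my_catalog.stitch_outputs.unified_coalesced" for a hypothetical
        # catalog "my_catalog"; the schema segment matches the output_schema_name
        # ("stitch_outputs") set in _helper_prepare_stitch_config.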

        # Construct a descriptive notebook name
        notebook_name = f"Stitch Report: {target_catalog}.{target_schema}"

        # Call the create_stitch_notebook method with our parameters
        try:
            result = client.create_stitch_notebook(
                table_path=table_path,
                notebook_name=notebook_name,
                stitch_config=stitch_config,
            )

            # If we get here, the notebook was created successfully, even if result doesn't have notebook_path
            notebook_path = result.get(
                "notebook_path", f"/Workspace/Users/unknown/{notebook_name}"
            )

            return {
                "success": True,
                "notebook_path": notebook_path,
                "message": f"Successfully created Stitch report notebook at {notebook_path}",
            }
        except Exception as e:
            # Only return an error if there was an actual exception
            return {"success": False, "error": str(e)}
    except Exception as e:
        logging.error(f"Error creating Stitch report notebook: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e)}


def _helper_launch_stitch_job(
    client: DatabricksAPIClient, stitch_config: Dict[str, Any], metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Phase 3: Write final config and launch Stitch job."""
    try:
        # Extract metadata
        target_catalog = metadata["target_catalog"]
        target_schema = metadata["target_schema"]
        stitch_job_name = metadata["stitch_job_name"]
        config_file_path = metadata["config_file_path"]
        init_script_path = metadata["init_script_path"]
        init_script_content = metadata["init_script_content"]
        pii_scan_output = metadata["pii_scan_output"]
        unsupported_columns = metadata["unsupported_columns"]

        # Write final config file to volume
        config_content_json = json.dumps(stitch_config, indent=2)
        try:
            upload_success = client.upload_file(
                path=config_file_path, content=config_content_json, overwrite=True
            )
            if not upload_success:
                return {
                    "error": f"Failed to write Stitch config to '{config_file_path}'"
                }
            logging.debug(f"Stitch config written to {config_file_path}")
        except Exception as e:
            return {
                "error": f"Failed to write Stitch config '{config_file_path}': {str(e)}"
            }

        # Write init script to volume
        try:
            upload_init_success = client.upload_file(
                path=init_script_path, content=init_script_content, overwrite=True
            )
            if not upload_init_success:
                return {"error": f"Failed to write init script to '{init_script_path}'"}
            logging.debug(f"Cluster init script written to {init_script_path}")
        except Exception as e:
            return {
                "error": f"Failed to write init script '{init_script_path}': {str(e)}"
            }

        # Launch the Stitch job
        try:
            job_run_data = client.submit_job_run(
                config_path=config_file_path,
                init_script_path=init_script_path,
                run_name=f"Stitch Setup: {stitch_job_name}",
            )
            run_id = job_run_data.get("run_id")
            if not run_id:
                return {"error": "Failed to launch job (no run_id returned)"}
        except Exception as e:
            return {"error": f"Failed to launch Stitch job: {str(e)}"}

        # Build success message
        summary_msg_lines = [
            f"Stitch setup for {target_catalog}.{target_schema} initiated."
        ]
        summary_msg_lines.append(f"Config: {config_file_path}")
        summary_msg_lines.append(f"Databricks Job Run ID: {run_id}")

        # Add unsupported columns information if any
        if unsupported_columns:
            summary_msg_lines.append("")
            summary_msg_lines.append(
                "Note: Some columns were excluded due to unsupported data types:"
            )
            for table_info in unsupported_columns:
                summary_msg_lines.append(f"  Table: {table_info['table']}")
                for col_info in table_info["columns"]:
                    semantic_info = (
                        f" (semantic: {col_info['semantic']})"
                        if col_info["semantic"]
                        else ""
                    )
                    summary_msg_lines.append(
                        f"    - {col_info['column']} ({col_info['type']}){semantic_info}"
                    )

        # Automatically create stitch report notebook
        notebook_result = _create_stitch_report_notebook(
            client=client,
            stitch_config=stitch_config,
            target_catalog=target_catalog,
            target_schema=target_schema,
            stitch_job_name=stitch_job_name,
        )

        # Add notebook creation information to the summary
        if notebook_result.get("success"):
            summary_msg_lines.append("\nCreated Stitch Report notebook:")
            summary_msg_lines.append(
                f"Notebook Path: {notebook_result.get('notebook_path', 'Unknown')}"
            )
        else:
            # If notebook creation failed, log the error but don't fail the overall job
            error_msg = notebook_result.get("error", "Unknown error")
            summary_msg_lines.append(
                f"\nNote: Could not create Stitch Report notebook: {error_msg}"
            )
            logging.warning(f"Failed to create Stitch Report notebook: {error_msg}")

        final_summary = "\n".join(summary_msg_lines)
        return {
            "success": True,
            "message": final_summary,
            "stitch_job_name": stitch_job_name,
            "run_id": run_id,
            "config_path": config_file_path,
            "init_script_path": init_script_path,
            "pii_scan_summary": pii_scan_output.get("message", "PII scan performed."),
            "unsupported_columns": unsupported_columns,
            "notebook_result": (
                notebook_result if "notebook_result" in locals() else None
            ),
        }

    except Exception as e:
        logging.error(f"Error launching Stitch job: {e}", exc_info=True)
        return {"error": f"Error launching Stitch job: {str(e)}"}
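

# Illustrative end-to-end usage (a minimal sketch; the real orchestration lives
# in the interactive command handlers, and the client/LLM construction shown
# here is hypothetical):
#
#     client = DatabricksAPIClient(...)  # hypothetical constructor arguments
#     llm = LLMClient(...)               # hypothetical constructor arguments
#
#     prep = _helper_prepare_stitch_config(client, llm, "my_catalog", "my_schema")
#     if prep.get("error"):
#         raise RuntimeError(prep["error"])
#
#     config = prep["stitch_config"]
#     # Optionally revise the config from a natural-language request:
#     modified = _helper_modify_stitch_config(
#         config, "remove the orders table", llm, prep["metadata"]
#     )
#     if modified.get("success"):
#         config = modified["stitch_config"]
#
#     launch = _helper_launch_stitch_job(client, config, prep["metadata"])
#     print(launch.get("message") or launch.get("error"))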