Coverage for src/clients/databricks.py: 44% of 336 statements (coverage.py v7.8.0, created 2025-06-05 22:56 -0700)
1"""
2Reusable Databricks API client for authentication and requests.
3"""
5import base64
6import json
7import logging
8import os
9import requests
10import time
11import urllib.parse
12from datetime import datetime
13from src.config import get_warehouse_id
14from src.clients.amperity import get_amperity_url
15from src.databricks.url_utils import (
16 detect_cloud_provider,
17 normalize_workspace_url,
18)


class DatabricksAPIClient:
    """Reusable Databricks API client for authentication and requests."""

    def __init__(self, workspace_url, token):
        """
        Initialize the API client.

        Args:
            workspace_url: Databricks workspace URL (with or without protocol/domain)
            token: Databricks API token
        """
        self.original_url = workspace_url
        self.workspace_url = self._normalize_workspace_url(workspace_url)
        self.cloud_provider = detect_cloud_provider(workspace_url)
        self.base_domain = self._get_base_domain()
        self.token = token
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "User-Agent": "amperity",
        }
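
    # Usage sketch (purely illustrative; the workspace ID and token below are
    # hypothetical placeholders, not values from this codebase):
    #
    #   client = DatabricksAPIClient("dbc-1234abcd-5678", "dapi-example-token")
    #   client.workspace_url   # normalized workspace identifier
    #   client.base_domain     # e.g. "cloud.databricks.com" for an AWS workspace
    #   client.headers         # bearer-token headers reused by every request below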

    def _normalize_workspace_url(self, url):
        """
        Normalize the workspace URL to the format needed for API calls.

        Args:
            url: Input workspace URL that might be in various formats

        Returns:
            Cleaned workspace URL (workspace ID only)
        """
        return normalize_workspace_url(url)

    def _get_base_domain(self):
        """
        Get the appropriate base domain based on the cloud provider.

        Returns:
            Base domain string for the detected cloud provider
        """
        from src.databricks.url_utils import DATABRICKS_DOMAIN_MAP

        return DATABRICKS_DOMAIN_MAP.get(
            self.cloud_provider, DATABRICKS_DOMAIN_MAP["AWS"]
        )

    def get_compute_node_type(self):
        """
        Get the appropriate compute node type based on the cloud provider.

        Returns:
            Node type string for the detected cloud provider
        """
        node_type_map = {
            "AWS": "r5d.4xlarge",
            "Azure": "Standard_E16ds_v4",
            "GCP": "n2-standard-16",  # Default GCP node type
            "Generic": "r5d.4xlarge",  # Default to AWS
        }
        return node_type_map.get(self.cloud_provider, "r5d.4xlarge")

    def get_cloud_attributes(self):
        """
        Get cloud-specific attributes for cluster configuration.

        Returns:
            Dictionary containing cloud-specific attributes
        """
        if self.cloud_provider == "AWS":
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }
        elif self.cloud_provider == "Azure":
            return {
                "azure_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK_AZURE",
                    "spot_bid_max_price": -1,
                }
            }
        elif self.cloud_provider == "GCP":
            return {
                "gcp_attributes": {
                    "use_preemptible_executors": True,
                    "google_service_account": None,
                }
            }
        else:
            # Default to AWS
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }
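
    # Sketch of how these two helpers feed a cluster spec (this mirrors the way
    # submit_job_run below assembles its cluster_config; on an AWS workspace it
    # yields "r5d.4xlarge" plus the aws_attributes block):
    #
    #   cluster_config = {"node_type_id": client.get_compute_node_type()}
    #   cluster_config.update(client.get_cloud_attributes())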

    #
    # Base API request methods
    #

    def get(self, endpoint):
        """
        Send a GET request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request to: {url}")

        try:
            response = requests.get(url, headers=self.headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_with_params(self, endpoint, params=None):
        """
        Send a GET request to the Databricks API with query parameters.

        Args:
            endpoint: API endpoint (starting with /)
            params: Dictionary of query parameters

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request with params to: {url}")

        try:
            response = requests.get(url, headers=self.headers, params=params)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def post(self, endpoint, data):
        """
        Send a POST request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)
            data: JSON data to send in the request body

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"POST request to: {url}")

        try:
            response = requests.post(url, headers=self.headers, json=data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")
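
    # Minimal sketch of calling the base helpers directly (the endpoints shown
    # are ones used elsewhere in this module; error handling omitted):
    #
    #   me = client.get("/api/2.0/preview/scim/v2/Me")
    #   catalogs = client.get_with_params(
    #       "/api/2.1/unity-catalog/catalogs", {"max_results": "10"}
    #   )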

    #
    # Authentication methods
    #

    def validate_token(self):
        """
        Validate the current token by calling the SCIM Me endpoint.

        Returns:
            True if the token is valid, False otherwise
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            return True if response else False
        except Exception as e:
            logging.debug(f"Token validation failed: {e}")
            return False

    #
    # Unity Catalog methods
    #

    def list_catalogs(self, include_browse=False, max_results=None, page_token=None):
        """
        Gets an array of catalogs in the metastore.

        Args:
            include_browse: Whether to include catalogs for which the principal can only access selective metadata
            max_results: Maximum number of catalogs to return (optional)
            page_token: Opaque pagination token to go to next page (optional)

        Returns:
            Dictionary containing:
                - catalogs: List of catalogs
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        if params:
            return self.get_with_params("/api/2.1/unity-catalog/catalogs", params)
        return self.get("/api/2.1/unity-catalog/catalogs")
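
    # Pagination sketch for list_catalogs (the same next_page_token pattern
    # applies to list_schemas, list_tables, and list_volumes below):
    #
    #   catalogs, token = [], None
    #   while True:
    #       page = client.list_catalogs(max_results=100, page_token=token)
    #       catalogs.extend(page.get("catalogs", []))
    #       token = page.get("next_page_token")
    #       if not token:
    #           break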

    def get_catalog(self, catalog_name):
        """
        Gets a catalog from Unity Catalog.

        Args:
            catalog_name: Name of the catalog

        Returns:
            Catalog information
        """
        return self.get(f"/api/2.1/unity-catalog/catalogs/{catalog_name}")

    def list_schemas(
        self, catalog_name, include_browse=False, max_results=None, page_token=None
    ):
        """
        Gets an array of schemas for a catalog in the metastore.

        Args:
            catalog_name: Parent catalog for schemas of interest (required)
            include_browse: Whether to include schemas for which the principal can only access selective metadata
            max_results: Maximum number of schemas to return (optional)
            page_token: Opaque pagination token to go to next page (optional)

        Returns:
            Dictionary containing:
                - schemas: List of schemas
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        return self.get_with_params("/api/2.1/unity-catalog/schemas", params)

    def get_schema(self, full_name):
        """
        Gets a schema from Unity Catalog.

        Args:
            full_name: Full name of the schema in the format 'catalog_name.schema_name'

        Returns:
            Schema information
        """
        return self.get(f"/api/2.1/unity-catalog/schemas/{full_name}")

    def list_tables(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_delta_metadata=False,
        omit_columns=False,
        omit_properties=False,
        omit_username=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets an array of all tables for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog for tables of interest (required)
            schema_name: Parent schema of tables (required)
            max_results: Maximum number of tables to return (optional)
            page_token: Opaque token to send for the next page of results (optional)
            include_delta_metadata: Whether delta metadata should be included (optional)
            omit_columns: Whether to omit columns from the response (optional)
            omit_properties: Whether to omit properties from the response (optional)
            omit_username: Whether to omit username from the response (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Dictionary containing:
                - tables: List of tables
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}

        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if omit_columns:
            params["omit_columns"] = "true"
        if omit_properties:
            params["omit_properties"] = "true"
        if omit_username:
            params["omit_username"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"
        return self.get_with_params("/api/2.1/unity-catalog/tables", params)
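
    # Sketch: listing tables in one schema while trimming the payload (the
    # catalog/schema names and the "name" field access are assumptions for
    # illustration, not taken from this module):
    #
    #   page = client.list_tables("main", "default", max_results=50, omit_columns=True)
    #   names = [t["name"] for t in page.get("tables", [])]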

    def get_table(
        self,
        full_name,
        include_delta_metadata=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets a table from the metastore for a specific catalog and schema.

        Args:
            full_name: Full name of the table in format 'catalog_name.schema_name.table_name'
            include_delta_metadata: Whether delta metadata should be included (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Table information
        """
        params = {}
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"

        if params:
            return self.get_with_params(
                f"/api/2.1/unity-catalog/tables/{full_name}", params
            )
        return self.get(f"/api/2.1/unity-catalog/tables/{full_name}")

    def list_volumes(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_browse=False,
    ):
        """
        Gets an array of volumes for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog (required)
            schema_name: Name of parent schema (required)
            max_results: Maximum number of volumes to return (optional)
            page_token: Opaque token for pagination (optional)
            include_browse: Whether to include volumes with selective metadata access (optional)

        Returns:
            Dictionary containing:
                - volumes: List of volumes
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_browse:
            params["include_browse"] = "true"

        return self.get_with_params("/api/2.1/unity-catalog/volumes", params)

    def create_volume(self, catalog_name, schema_name, name, volume_type="MANAGED"):
        """
        Create a new volume in Unity Catalog.

        Args:
            catalog_name: The name of the catalog where the volume will be created
            schema_name: The name of the schema where the volume will be created
            name: The name of the volume to create
            volume_type: The type of volume to create (default: "MANAGED")

        Returns:
            Dict containing the created volume information
        """
        data = {
            "catalog_name": catalog_name,
            "schema_name": schema_name,
            "name": name,
            "volume_type": volume_type,
        }
        return self.post("/api/2.1/unity-catalog/volumes", data)

    #
    # Models and Serving methods
    #

    def list_models(self):
        """
        Fetch a list of models from the Databricks Serving API.

        Returns:
            List of available model endpoints
        """
        response = self.get("/api/2.0/serving-endpoints")
        return response.get("endpoints", [])

    def get_model(self, model_name):
        """
        Get details of a specific model from Databricks Serving API.

        Args:
            model_name: Name of the model to retrieve

        Returns:
            Model details if found, or None if the endpoint does not exist
        """
        try:
            return self.get(f"/api/2.0/serving-endpoints/{model_name}")
        except ValueError as e:
            if "404" in str(e):
                logging.warning(f"Model '{model_name}' not found")
                return None
            raise

    #
    # Warehouse methods
    #

    def list_warehouses(self):
        """
        Lists all SQL warehouses in the Databricks workspace.

        Returns:
            List of warehouses
        """
        response = self.get("/api/2.0/sql/warehouses")
        return response.get("warehouses", [])

    def get_warehouse(self, warehouse_id):
        """
        Gets information about a specific SQL warehouse.

        Args:
            warehouse_id: ID of the SQL warehouse

        Returns:
            Warehouse information
        """
        return self.get(f"/api/2.0/sql/warehouses/{warehouse_id}")

    def create_warehouse(self, opts):
        """
        Creates a new SQL warehouse.

        Args:
            opts: Dictionary containing warehouse configuration options

        Returns:
            Created warehouse information
        """
        return self.post("/api/2.0/sql/warehouses", opts)

    def submit_sql_statement(
        self,
        sql_text,
        warehouse_id,
        catalog=None,
        wait_timeout="30s",
        on_wait_timeout="CONTINUE",
    ):
        """
        Submit a SQL statement to Databricks SQL warehouse and wait for completion.

        Args:
            sql_text: SQL statement to execute
            warehouse_id: ID of the SQL warehouse
            catalog: Optional catalog name
            wait_timeout: How long to wait for query completion (default "30s")
            on_wait_timeout: What to do on timeout ("CONTINUE" or "CANCEL")

        Returns:
            Dictionary containing the SQL statement execution result
        """
        data = {
            "statement": sql_text,
            "warehouse_id": warehouse_id,
            "wait_timeout": wait_timeout,
            "on_wait_timeout": on_wait_timeout,
        }

        if catalog:
            data["catalog"] = catalog

        # Submit the SQL statement
        response = self.post("/api/2.0/sql/statements", data)
        statement_id = response.get("statement_id")

        # Poll until complete
        while True:
            status = self.get(f"/api/2.0/sql/statements/{statement_id}")
            state = status.get("status", {}).get("state", status.get("state"))
            if state not in ["PENDING", "RUNNING"]:
                break
            time.sleep(1)

        return status
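
    # SQL usage sketch (the warehouse ID and table name are hypothetical; the
    # result shape mirrors how create_stitch_notebook below reads data_array):
    #
    #   result = client.submit_sql_statement(
    #       "SELECT DISTINCT datasource FROM main.default.unified", "abc123def456"
    #   )
    #   rows = result.get("result", {}).get("data_array", [])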

    #
    # Jobs methods
    #

    def submit_job_run(self, config_path, init_script_path, run_name=None):
        """
        Submit a one-time Databricks job run using the /runs/submit endpoint.

        Args:
            config_path: Path to the configuration file for the job in the Volume
            init_script_path: Path to the initialization script
            run_name: Optional name for the run. If None, a default name will be generated.

        Returns:
            Dict containing the job run information (including run_id)
        """
        if not run_name:
            run_name = (
                f"Chuck AI One-Time Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            )

        # Define the task and cluster for the one-time run
        # Create base cluster configuration
        cluster_config = {
            "cluster_name": "",
            "spark_version": "16.0.x-cpu-ml-scala2.12",
            "init_scripts": [
                {
                    "volumes": {
                        "destination": init_script_path,
                    }
                }
            ],
            "node_type_id": self.get_compute_node_type(),
            "custom_tags": {
                "stack": "aws-dev",
                "sys": "chuck",
                "tenant": "amperity",
            },
            "spark_env_vars": {
                "JNAME": "zulu17-ca-amd64",
                "CHUCK_API_URL": f"https://{get_amperity_url()}",
                "DEBUG_INIT_SRIPT_URL": init_script_path,
                "DEBUG_CONFIG_PATH": config_path,
            },
            "enable_elastic_disk": False,
            "data_security_mode": "SINGLE_USER",
            "runtime_engine": "STANDARD",
            "autoscale": {"min_workers": 10, "max_workers": 50},
        }

        # Add cloud-specific attributes
        cluster_config.update(self.get_cloud_attributes())

        run_payload = {
            "run_name": run_name,
            "tasks": [
                {
                    "task_key": "Run_Stitch",
                    "run_if": "ALL_SUCCESS",
                    "spark_jar_task": {
                        "jar_uri": "",
                        "main_class_name": os.environ.get(
                            "MAIN_CLASS", "amperity.stitch_standalone.chuck_main"
                        ),
                        "parameters": [
                            "",
                            config_path,
                        ],
                        "run_as_repl": True,
                    },
                    "libraries": [{"jar": "file:///opt/amperity/job.jar"}],
                    "timeout_seconds": 0,
                    "email_notifications": {},
                    "webhook_notifications": {},
                    "new_cluster": cluster_config,
                },
            ],
            "timeout_seconds": 0,
        }

        return self.post("/api/2.2/jobs/runs/submit", run_payload)

    def get_job_run_status(self, run_id):
        """
        Get the status of a Databricks job run.

        Args:
            run_id: The job run ID (as str or int)

        Returns:
            Dict containing the job run status information
        """
        params = {"run_id": run_id}
        return self.get_with_params("/api/2.2/jobs/runs/get", params)
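
    # Job-run sketch combining the two methods above (the Volume paths are
    # hypothetical placeholders):
    #
    #   run = client.submit_job_run(
    #       "/Volumes/main/default/chuck/config.json",
    #       "/Volumes/main/default/chuck/init.sh",
    #   )
    #   status = client.get_job_run_status(run["run_id"])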

    #
    # File system methods
    #

    def upload_file(self, path, file_path=None, content=None, overwrite=False):
        """
        Upload a file using the /api/2.0/fs/files endpoint.

        Args:
            path: The destination path (e.g., "/Volumes/my-catalog/my-schema/my-volume/file.txt")
            file_path: Local file path to upload (mutually exclusive with content)
            content: String content to upload (mutually exclusive with file_path)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful (API returns no content on success)

        Raises:
            ValueError: If both file_path and content are provided or neither is provided
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        if (file_path and content) or (not file_path and not content):
            raise ValueError("Exactly one of file_path or content must be provided")

        # URL encode the path and make sure it starts with a slash
        if not path.startswith("/"):
            path = f"/{path}"

        # Remove duplicate slashes if any
        while "//" in path:
            path = path.replace("//", "/")

        # URL encode path components but preserve the slashes
        encoded_path = "/".join(
            urllib.parse.quote(component) for component in path.split("/") if component
        )
        encoded_path = f"/{encoded_path}"

        url = f"https://{self.workspace_url}.{self.base_domain}/api/2.0/fs/files{encoded_path}"

        if overwrite:
            url += "?overwrite=true"

        logging.debug(f"File upload request to: {url}")

        headers = self.headers.copy()
        headers.update({"Content-Type": "application/octet-stream"})

        # Get binary data to upload
        if file_path:
            with open(file_path, "rb") as f:
                binary_data = f.read()
        else:
            # Convert string content to bytes
            binary_data = content.encode("utf-8")

        try:
            # Use PUT request with raw binary data in the body
            response = requests.put(url, headers=headers, data=binary_data)
            response.raise_for_status()
            # API returns 204 No Content on success
            return True
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")
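
    # Upload sketch: create a managed volume, then drop a small config file into
    # it (catalog/schema/volume names are hypothetical):
    #
    #   client.create_volume("main", "default", "chuck")
    #   client.upload_file(
    #       "/Volumes/main/default/chuck/config.json",
    #       content='{"key": "value"}',
    #       overwrite=True,
    #   )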

    def store_dbfs_file(self, path, contents, overwrite=True):
        """
        Store content to DBFS using the /api/2.0/dbfs/put endpoint.

        Args:
            path: Path in DBFS to store the file
            contents: String or bytes content to store (will be base64 encoded)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful
        """
        # Encode the content as base64
        encoded_contents = (
            base64.b64encode(contents.encode()).decode()
            if isinstance(contents, str)
            else base64.b64encode(contents).decode()
        )

        # Prepare the request with file content and path
        request_data = {
            "path": path,
            "contents": encoded_contents,
            "overwrite": overwrite,
        }

        # Call DBFS API
        self.post("/api/2.0/dbfs/put", request_data)
        return True

    #
    # Amperity-specific methods
    #

    def fetch_amperity_job_init(self, token, api_url: str | None = None):
        """
        Fetch initialization script for Amperity jobs.

        Args:
            token: Amperity authentication token
            api_url: Optional override for the job init endpoint

        Returns:
            Dict containing the initialization script data
        """
        try:
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }

            if not api_url:
                api_url = f"https://{get_amperity_url()}/api/job/launch"

            response = requests.post(api_url, headers=headers, json={})
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            response = e.response
            # Compare against None explicitly: a non-2xx Response is falsy, so a
            # plain truthiness check would always discard the error body here.
            resp_text = response.text if response is not None else ""
            logging.debug(f"HTTP error: {e}, Response: {resp_text}")
            if response is not None:
                try:
                    message = response.json().get("message", resp_text)
                except ValueError:
                    message = resp_text
                raise ValueError(
                    f"{response.status_code} Error: {message}. Please /logout and /login again"
                )
            raise ValueError(
                f"HTTP error occurred: {e}. Please /logout and /login again"
            )
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_current_user(self):
        """
        Get the current user's username from Databricks API.

        Returns:
            Username string from the current user
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            username = response.get("userName")
            if not username:
                logging.debug("Username not found in response")
                raise ValueError("Username not found in API response")
            return username
        except Exception as e:
            logging.debug(f"Error getting current user: {e}")
            raise

    def create_stitch_notebook(
        self, table_path, notebook_name=None, stitch_config=None, datasources=None
    ):
        """
        Create a stitch notebook for the given table path.

        This function will:
        1. Load the stitch notebook template
        2. Extract the notebook name from metadata (or use provided name)
        3. Get datasources either from provided list, stitch_config, or query the table
        4. Replace template fields with appropriate values
        5. Import the notebook to Databricks

        Args:
            table_path: Full path to the unified table in format catalog.schema.table
            notebook_name: Optional name for the notebook. If not provided, the name will be
                extracted from the template's metadata.
            stitch_config: Optional stitch configuration dictionary. If provided, datasources will be
                extracted from stitch_config["tables"][*]["path"] values.
            datasources: Optional list of datasource values. If provided, these will be used
                instead of querying the table.

        Returns:
            Dictionary containing the notebook path and status
        """
        # 1. Load the template notebook
        template_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "assets",
            "stitch_notebook_template.ipynb",
        )

        try:
            with open(template_path, "r") as f:
                notebook_content = json.load(f)

            # 2. Extract notebook name from metadata (if not provided)
            extracted_name = "Stitch Results"
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "markdown":
                    source = cell.get("source", [])
                    if source and source[0].startswith("# "):
                        extracted_name = source[0].replace("# ", "").strip()
                        break

            # Use the provided notebook name if available, otherwise use the extracted name
            final_notebook_name = notebook_name if notebook_name else extracted_name

            # 3. Get distinct datasource values
            # First check if datasources were directly provided
            if not datasources:
                # Check if we should extract datasources from stitch_config
                if (
                    stitch_config
                    and "tables" in stitch_config
                    and stitch_config["tables"]
                ):
                    # Extract unique datasources from the path values in stitch_config tables
                    datasource_set = set()
                    for table in stitch_config["tables"]:
                        if "path" in table:
                            datasource_set.add(table["path"])

                    datasources = list(datasource_set)
                    # Datasources extracted from the stitch config

            # If we still don't have datasources, query the table directly
            if not datasources:
                # Get the configured warehouse ID
                warehouse_id = get_warehouse_id()

                # If no warehouse ID is configured, try to find a default one
                if not warehouse_id:
                    warehouses = self.list_warehouses()
                    if not warehouses:
                        raise ValueError(
                            "No SQL warehouses found and no warehouse configured. Please select a warehouse using /warehouse_selection."
                        )

                    # Use the first available warehouse
                    warehouse_id = warehouses[0]["id"]
                    logging.warning(
                        f"No warehouse configured. Using first available warehouse: {warehouse_id}"
                    )

                # Query for distinct datasource values
                sql_query = f"SELECT DISTINCT datasource FROM {table_path}"
                # Execute SQL query to get datasources
                result = self.submit_sql_statement(sql_query, warehouse_id)

                # Extract the results from the query response
                if (
                    result
                    and result.get("result")
                    and result["result"].get("data_array")
                ):
                    datasources = [row[0] for row in result["result"]["data_array"]]
                    # Successfully extracted datasources from query

            # If we still don't have datasources, use a default value
            if not datasources:
                logging.warning(f"No datasources found for {table_path}")
                datasources = ["default_source"]
                # Use default datasource as a fallback

            # 4. Create the source names JSON mapping
            source_names_json = {}
            for source in datasources:
                source_names_json[source] = source

            # Convert to JSON string (formatted nicely)
            source_names_str = json.dumps(source_names_json, indent=4)
            # Source mapping created for template

            # 5. Replace the template fields
            # Replace template placeholders with actual values

            # Need to directly modify the notebook cells rather than doing string replacement on the JSON
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    for i, line in enumerate(source_lines):
                        if (
                            '"{UNIFIED_PATH}"' in line
                            or "unified_coalesced_path = " in line
                        ):
                            # Found placeholder in notebook
                            # Replace the line with our table path
                            source_lines[i] = (
                                f'unified_coalesced_path = "{table_path}"\n'
                            )
                            break

            # Replace source semantic mapping in the cells
            replaced_mapping = False
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    # Look for the source_semantic_mapping definition
                    for i, line in enumerate(source_lines):
                        if "source_semantic_mapping =" in line and not replaced_mapping:
                            # Found source names placeholder

                            # Find the closing brace
                            closing_index = None
                            opening_count = 0
                            for j in range(i, len(source_lines)):
                                if "{" in source_lines[j]:
                                    opening_count += 1
                                if "}" in source_lines[j]:
                                    opening_count -= 1
                                    if opening_count == 0:
                                        closing_index = j
                                        break

                            if closing_index is not None:
                                # Replace the mapping with our custom mapping
                                mapping_start = i
                                mapping_end = closing_index + 1
                                new_line = (
                                    f"source_semantic_mapping = {source_names_str}\n"
                                )
                                source_lines[mapping_start:mapping_end] = [new_line]
                                replaced_mapping = True
                            break

            if not replaced_mapping:
                logging.warning(
                    "Could not find source_semantic_mapping in the notebook template to replace"
                )

            # 6. Get the current user's username
            username = self.get_current_user()

            # 7. Construct the notebook path
            notebook_path = f"/Workspace/Users/{username}/{final_notebook_name}"

            # 8. Convert the notebook to base64 for API call
            notebook_content_str = json.dumps(notebook_content)
            encoded_content = base64.b64encode(notebook_content_str.encode()).decode()

            # 9. Import the notebook to Databricks
            import_data = {
                "path": notebook_path,
                "content": encoded_content,
                "format": "JUPYTER",
                "overwrite": True,
            }

            self.post("/api/2.0/workspace/import", import_data)

            # 10. Return the notebook path
            return {"notebook_path": notebook_path, "status": "success"}

        except Exception as e:
            logging.debug(f"Error creating stitch notebook: {e}")
            raise
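

# End-to-end sketch of the stitch flow this client supports (the workspace ID,
# token, and table path are hypothetical placeholders, not values from this
# codebase):
#
#   client = DatabricksAPIClient("dbc-1234abcd-5678", "dapi-example-token")
#   if client.validate_token():
#       result = client.create_stitch_notebook("main.default.unified_coalesced")
#       print(result["notebook_path"])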