Coverage for src/chuck_data/clients/databricks.py: 0%
336 statements
coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

"""
Reusable Databricks API client for authentication and requests.
"""

import base64
import json
import logging
import os
import requests
import time
import urllib.parse
from datetime import datetime
from ..config import get_warehouse_id
from ..clients.amperity import get_amperity_url
from ..databricks.url_utils import (
    detect_cloud_provider,
    normalize_workspace_url,
)


class DatabricksAPIClient:
    """Reusable Databricks API client for authentication and requests."""

    def __init__(self, workspace_url, token):
        """
        Initialize the API client.

        Args:
            workspace_url: Databricks workspace URL (with or without protocol/domain)
            token: Databricks API token
        """
        self.original_url = workspace_url
        self.workspace_url = self._normalize_workspace_url(workspace_url)
        self.cloud_provider = detect_cloud_provider(workspace_url)
        self.base_domain = self._get_base_domain()
        self.token = token
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "User-Agent": "amperity",
        }
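
    # Usage sketch (illustrative, not executed): constructing a client. The
    # workspace name and token below are placeholders, not values from this
    # module.
    #
    #   client = DatabricksAPIClient("my-workspace", token="dapi-...")
    #   client.validate_token()  # True if the token can call the SCIM Me endpoint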

    def _normalize_workspace_url(self, url):
        """
        Normalize the workspace URL to the format needed for API calls.

        Args:
            url: Input workspace URL that might be in various formats

        Returns:
            Cleaned workspace URL (workspace ID only)
        """
        return normalize_workspace_url(url)

    def _get_base_domain(self):
        """
        Get the appropriate base domain based on the cloud provider.

        Returns:
            Base domain string for the detected cloud provider
        """
        from ..databricks.url_utils import DATABRICKS_DOMAIN_MAP

        return DATABRICKS_DOMAIN_MAP.get(
            self.cloud_provider, DATABRICKS_DOMAIN_MAP["AWS"]
        )

    def get_compute_node_type(self):
        """
        Get the appropriate compute node type based on the cloud provider.

        Returns:
            Node type string for the detected cloud provider
        """
        node_type_map = {
            "AWS": "r5d.4xlarge",
            "Azure": "Standard_E16ds_v4",
            "GCP": "n2-standard-16",  # Default GCP node type
            "Generic": "r5d.4xlarge",  # Default to AWS
        }
        return node_type_map.get(self.cloud_provider, "r5d.4xlarge")

    def get_cloud_attributes(self):
        """
        Get cloud-specific attributes for cluster configuration.

        Returns:
            Dictionary containing cloud-specific attributes
        """
        if self.cloud_provider == "AWS":
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }
        elif self.cloud_provider == "Azure":
            return {
                "azure_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK_AZURE",
                    "spot_bid_max_price": -1,
                }
            }
        elif self.cloud_provider == "GCP":
            return {
                "gcp_attributes": {
                    "use_preemptible_executors": True,
                    "google_service_account": None,
                }
            }
        else:
            # Default to AWS
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }
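
    # Sketch (illustrative): how the two helpers above feed a cluster spec,
    # given a configured client. The autoscale numbers are placeholders, not
    # values from this module.
    #
    #   spec = {"node_type_id": client.get_compute_node_type(),
    #           "autoscale": {"min_workers": 2, "max_workers": 8}}
    #   spec.update(client.get_cloud_attributes())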

    #
    # Base API request methods
    #

    def get(self, endpoint):
        """
        Send a GET request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request to: {url}")

        try:
            response = requests.get(url, headers=self.headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_with_params(self, endpoint, params=None):
        """
        Send a GET request to the Databricks API with query parameters.

        Args:
            endpoint: API endpoint (starting with /)
            params: Dictionary of query parameters

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request with params to: {url}")

        try:
            response = requests.get(url, headers=self.headers, params=params)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def post(self, endpoint, data):
        """
        Send a POST request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)
            data: JSON data to send in the request body

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"POST request to: {url}")

        try:
            response = requests.post(url, headers=self.headers, json=data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")
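
    # Sketch (illustrative): the raw request helpers take workspace-relative
    # endpoints; the full URL is assembled from workspace_url and base_domain.
    # The catalog/schema/volume names below are placeholders.
    #
    #   me = client.get("/api/2.0/preview/scim/v2/Me")
    #   vol = client.post("/api/2.1/unity-catalog/volumes",
    #                     {"catalog_name": "main", "schema_name": "chuck",
    #                      "name": "scratch", "volume_type": "MANAGED"})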

    #
    # Authentication methods
    #

    def validate_token(self):
        """
        Validate the current token by calling the SCIM Me endpoint.

        Returns:
            True if the token is valid, False otherwise
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            return True if response else False
        except Exception as e:
            logging.debug(f"Token validation failed: {e}")
            return False

    #
    # Unity Catalog methods
    #

    def list_catalogs(self, include_browse=False, max_results=None, page_token=None):
        """
        Gets an array of catalogs in the metastore.

        Args:
            include_browse: Whether to include catalogs for which the principal can only access selective metadata
            max_results: Maximum number of catalogs to return (optional)
            page_token: Opaque pagination token to go to next page (optional)

        Returns:
            Dictionary containing:
                - catalogs: List of catalogs
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        if params:
            return self.get_with_params("/api/2.1/unity-catalog/catalogs", params)
        return self.get("/api/2.1/unity-catalog/catalogs")
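
    # Sketch (illustrative): paging through catalogs with the token the API
    # returns. The page size is a placeholder.
    #
    #   page = client.list_catalogs(max_results=100)
    #   catalogs = page.get("catalogs", [])
    #   while page.get("next_page_token"):
    #       page = client.list_catalogs(max_results=100,
    #                                   page_token=page["next_page_token"])
    #       catalogs.extend(page.get("catalogs", []))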

    def get_catalog(self, catalog_name):
        """
        Gets a catalog from Unity Catalog.

        Args:
            catalog_name: Name of the catalog

        Returns:
            Catalog information
        """
        return self.get(f"/api/2.1/unity-catalog/catalogs/{catalog_name}")

    def list_schemas(
        self, catalog_name, include_browse=False, max_results=None, page_token=None
    ):
        """
        Gets an array of schemas for a catalog in the metastore.

        Args:
            catalog_name: Parent catalog for schemas of interest (required)
            include_browse: Whether to include schemas for which the principal can only access selective metadata
            max_results: Maximum number of schemas to return (optional)
            page_token: Opaque pagination token to go to next page (optional)

        Returns:
            Dictionary containing:
                - schemas: List of schemas
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        return self.get_with_params("/api/2.1/unity-catalog/schemas", params)

    def get_schema(self, full_name):
        """
        Gets a schema from Unity Catalog.

        Args:
            full_name: Full name of the schema in the format 'catalog_name.schema_name'

        Returns:
            Schema information
        """
        return self.get(f"/api/2.1/unity-catalog/schemas/{full_name}")

    def list_tables(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_delta_metadata=False,
        omit_columns=False,
        omit_properties=False,
        omit_username=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets an array of all tables for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog for tables of interest (required)
            schema_name: Parent schema of tables (required)
            max_results: Maximum number of tables to return (optional)
            page_token: Opaque token to send for the next page of results (optional)
            include_delta_metadata: Whether delta metadata should be included (optional)
            omit_columns: Whether to omit columns from the response (optional)
            omit_properties: Whether to omit properties from the response (optional)
            omit_username: Whether to omit username from the response (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Dictionary containing:
                - tables: List of tables
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}

        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if omit_columns:
            params["omit_columns"] = "true"
        if omit_properties:
            params["omit_properties"] = "true"
        if omit_username:
            params["omit_username"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"

        return self.get_with_params("/api/2.1/unity-catalog/tables", params)

    def get_table(
        self,
        full_name,
        include_delta_metadata=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets a table from the metastore for a specific catalog and schema.

        Args:
            full_name: Full name of the table in format 'catalog_name.schema_name.table_name'
            include_delta_metadata: Whether delta metadata should be included (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Table information
        """
        params = {}
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"

        if params:
            return self.get_with_params(
                f"/api/2.1/unity-catalog/tables/{full_name}", params
            )
        return self.get(f"/api/2.1/unity-catalog/tables/{full_name}")

    def list_volumes(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_browse=False,
    ):
        """
        Gets an array of volumes for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog (required)
            schema_name: Name of parent schema (required)
            max_results: Maximum number of volumes to return (optional)
            page_token: Opaque token for pagination (optional)
            include_browse: Whether to include volumes with selective metadata access (optional)

        Returns:
            Dictionary containing:
                - volumes: List of volumes
                - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_browse:
            params["include_browse"] = "true"

        return self.get_with_params("/api/2.1/unity-catalog/volumes", params)

    def create_volume(self, catalog_name, schema_name, name, volume_type="MANAGED"):
        """
        Create a new volume in Unity Catalog.

        Args:
            catalog_name: The name of the catalog where the volume will be created
            schema_name: The name of the schema where the volume will be created
            name: The name of the volume to create
            volume_type: The type of volume to create (default: "MANAGED")

        Returns:
            Dict containing the created volume information
        """
        data = {
            "catalog_name": catalog_name,
            "schema_name": schema_name,
            "name": name,
            "volume_type": volume_type,
        }
        return self.post("/api/2.1/unity-catalog/volumes", data)

    #
    # Models and Serving methods
    #

    def list_models(self):
        """
        Fetch a list of models from the Databricks Serving API.

        Returns:
            List of available model endpoints
        """
        response = self.get("/api/2.0/serving-endpoints")
        return response.get("endpoints", [])

    def get_model(self, model_name):
        """
        Get details of a specific model from Databricks Serving API.

        Args:
            model_name: Name of the model to retrieve

        Returns:
            Model details if found
        """
        try:
            return self.get(f"/api/2.0/serving-endpoints/{model_name}")
        except ValueError as e:
            if "404" in str(e):
                logging.warning(f"Model '{model_name}' not found")
                return None
            raise

    #
    # Warehouse methods
    #

    def list_warehouses(self):
        """
        Lists all SQL warehouses in the Databricks workspace.

        Returns:
            List of warehouses
        """
        response = self.get("/api/2.0/sql/warehouses")
        return response.get("warehouses", [])

    def get_warehouse(self, warehouse_id):
        """
        Gets information about a specific SQL warehouse.

        Args:
            warehouse_id: ID of the SQL warehouse

        Returns:
            Warehouse information
        """
        return self.get(f"/api/2.0/sql/warehouses/{warehouse_id}")

    def create_warehouse(self, opts):
        """
        Creates a new SQL warehouse.

        Args:
            opts: Dictionary containing warehouse configuration options

        Returns:
            Created warehouse information
        """
        return self.post("/api/2.0/sql/warehouses", opts)

    def submit_sql_statement(
        self,
        sql_text,
        warehouse_id,
        catalog=None,
        wait_timeout="30s",
        on_wait_timeout="CONTINUE",
    ):
        """
        Submit a SQL statement to Databricks SQL warehouse and wait for completion.

        Args:
            sql_text: SQL statement to execute
            warehouse_id: ID of the SQL warehouse
            catalog: Optional catalog name
            wait_timeout: How long to wait for query completion (default "30s")
            on_wait_timeout: What to do on timeout ("CONTINUE" or "CANCEL")

        Returns:
            Dictionary containing the SQL statement execution result
        """
        data = {
            "statement": sql_text,
            "warehouse_id": warehouse_id,
            "wait_timeout": wait_timeout,
            "on_wait_timeout": on_wait_timeout,
        }

        if catalog:
            data["catalog"] = catalog

        # Submit the SQL statement
        response = self.post("/api/2.0/sql/statements", data)
        statement_id = response.get("statement_id")

        # Poll until complete
        while True:
            status = self.get(f"/api/2.0/sql/statements/{statement_id}")
            state = status.get("status", {}).get("state", status.get("state"))
            if state not in ["PENDING", "RUNNING"]:
                break
            time.sleep(1)

        return status
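
    # Sketch (illustrative): running a query and reading rows from the result.
    # The warehouse id is a placeholder.
    #
    #   result = client.submit_sql_statement("SELECT 1 AS one", warehouse_id="abc123")
    #   rows = result.get("result", {}).get("data_array", [])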

    #
    # Jobs methods
    #

    def submit_job_run(self, config_path, init_script_path, run_name=None):
        """
        Submit a one-time Databricks job run using the /runs/submit endpoint.

        Args:
            config_path: Path to the configuration file for the job in the Volume
            init_script_path: Path to the initialization script
            run_name: Optional name for the run. If None, a default name will be generated.

        Returns:
            Dict containing the job run information (including run_id)
        """
        if not run_name:
            run_name = (
                f"Chuck AI One-Time Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            )

        # Define the task and cluster for the one-time run
        # Create base cluster configuration
        cluster_config = {
            "cluster_name": "",
            "spark_version": "16.0.x-cpu-ml-scala2.12",
            "init_scripts": [
                {
                    "volumes": {
                        "destination": init_script_path,
                    }
                }
            ],
            "node_type_id": self.get_compute_node_type(),
            "custom_tags": {
                "stack": "aws-dev",
                "sys": "chuck",
                "tenant": "amperity",
            },
            "spark_env_vars": {
                "JNAME": "zulu17-ca-amd64",
                "CHUCK_API_URL": f"https://{get_amperity_url()}",
                "DEBUG_INIT_SRIPT_URL": init_script_path,
                "DEBUG_CONFIG_PATH": config_path,
            },
            "enable_elastic_disk": False,
            "data_security_mode": "SINGLE_USER",
            "runtime_engine": "STANDARD",
            "autoscale": {"min_workers": 10, "max_workers": 50},
        }

        # Add cloud-specific attributes
        cluster_config.update(self.get_cloud_attributes())

        run_payload = {
            "run_name": run_name,
            "tasks": [
                {
                    "task_key": "Run_Stitch",
                    "run_if": "ALL_SUCCESS",
                    "spark_jar_task": {
                        "jar_uri": "",
                        "main_class_name": os.environ.get(
                            "MAIN_CLASS", "amperity.stitch_standalone.chuck_main"
                        ),
                        "parameters": [
                            "",
                            config_path,
                        ],
                        "run_as_repl": True,
                    },
                    "libraries": [{"jar": "file:///opt/amperity/job.jar"}],
                    "timeout_seconds": 0,
                    "email_notifications": {},
                    "webhook_notifications": {},
                    "new_cluster": cluster_config,
                },
            ],
            "timeout_seconds": 0,
        }

        return self.post("/api/2.2/jobs/runs/submit", run_payload)

    def get_job_run_status(self, run_id):
        """
        Get the status of a Databricks job run.

        Args:
            run_id: The job run ID (as str or int)

        Returns:
            Dict containing the job run status information
        """
        params = {"run_id": run_id}
        return self.get_with_params("/api/2.2/jobs/runs/get", params)
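
    # Sketch (illustrative): submitting a one-time run and checking on it.
    # The Volume paths are placeholders.
    #
    #   run = client.submit_job_run("/Volumes/cat/sch/vol/config.json",
    #                               "/Volumes/cat/sch/vol/init.sh")
    #   status = client.get_job_run_status(run["run_id"])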

    #
    # File system methods
    #

    def upload_file(self, path, file_path=None, content=None, overwrite=False):
        """
        Upload a file using the /api/2.0/fs/files endpoint.

        Args:
            path: The destination path (e.g., "/Volumes/my-catalog/my-schema/my-volume/file.txt")
            file_path: Local file path to upload (mutually exclusive with content)
            content: String content to upload (mutually exclusive with file_path)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful (API returns no content on success)

        Raises:
            ValueError: If both file_path and content are provided or neither is provided
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        if (file_path and content) or (not file_path and not content):
            raise ValueError("Exactly one of file_path or content must be provided")

        # URL encode the path and make sure it starts with a slash
        if not path.startswith("/"):
            path = f"/{path}"

        # Remove duplicate slashes if any
        while "//" in path:
            path = path.replace("//", "/")

        # URL encode path components but preserve the slashes
        encoded_path = "/".join(
            urllib.parse.quote(component) for component in path.split("/") if component
        )
        encoded_path = f"/{encoded_path}"

        url = f"https://{self.workspace_url}.{self.base_domain}/api/2.0/fs/files{encoded_path}"

        if overwrite:
            url += "?overwrite=true"

        logging.debug(f"File upload request to: {url}")

        headers = self.headers.copy()
        headers.update({"Content-Type": "application/octet-stream"})

        # Get binary data to upload
        if file_path:
            with open(file_path, "rb") as f:
                binary_data = f.read()
        else:
            # Convert string content to bytes
            binary_data = content.encode("utf-8")

        try:
            # Use PUT request with raw binary data in the body
            response = requests.put(url, headers=headers, data=binary_data)
            response.raise_for_status()
            # API returns 204 No Content on success
            return True
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")
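
    # Sketch (illustrative): writing string content into a Unity Catalog
    # Volume. The Volume path and JSON payload are placeholders.
    #
    #   client.upload_file("/Volumes/my-catalog/my-schema/my-volume/config.json",
    #                      content='{"key": "value"}', overwrite=True)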

    def store_dbfs_file(self, path, contents, overwrite=True):
        """
        Store content to DBFS using the /api/2.0/dbfs/put endpoint.

        Args:
            path: Path in DBFS to store the file
            contents: String or bytes content to store (will be base64 encoded)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful
        """
        # Encode the content as base64
        encoded_contents = (
            base64.b64encode(contents.encode()).decode()
            if isinstance(contents, str)
            else base64.b64encode(contents).decode()
        )

        # Prepare the request with file content and path
        request_data = {
            "path": path,
            "contents": encoded_contents,
            "overwrite": overwrite,
        }

        # Call DBFS API
        self.post("/api/2.0/dbfs/put", request_data)
        return True

    #
    # Amperity-specific methods
    #

    def fetch_amperity_job_init(self, token, api_url: str | None = None):
        """
        Fetch initialization script for Amperity jobs.

        Args:
            token: Amperity authentication token
            api_url: Optional override for the job init endpoint

        Returns:
            Dict containing the initialization script data
        """
        try:
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }

            if not api_url:
                api_url = f"https://{get_amperity_url()}/api/job/launch"

            response = requests.post(api_url, headers=headers, json={})
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            response = e.response
            # Compare against None explicitly: a Response with an error status
            # is falsy, so a plain truthiness test would discard the body here.
            resp_text = response.text if response is not None else ""
            logging.debug(f"HTTP error: {e}, Response: {resp_text}")
            if response is not None:
                try:
                    message = response.json().get("message", resp_text)
                except ValueError:
                    message = resp_text
                raise ValueError(
                    f"{response.status_code} Error: {message}. Please /logout and /login again"
                )
            raise ValueError(
                f"HTTP error occurred: {e}. Please /logout and /login again"
            )
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_current_user(self):
        """
        Get the current user's username from Databricks API.

        Returns:
            Username string from the current user
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            username = response.get("userName")
            if not username:
                logging.debug("Username not found in response")
                raise ValueError("Username not found in API response")
            return username
        except Exception as e:
            logging.debug(f"Error getting current user: {e}")
            raise

    def create_stitch_notebook(
        self, table_path, notebook_name=None, stitch_config=None, datasources=None
    ):
        """
        Create a stitch notebook for the given table path.

        This function will:
        1. Load the stitch notebook template
        2. Extract the notebook name from metadata (or use provided name)
        3. Get datasources either from provided list, stitch_config, or query the table
        4. Replace template fields with appropriate values
        5. Import the notebook to Databricks

        Args:
            table_path: Full path to the unified table in format catalog.schema.table
            notebook_name: Optional name for the notebook. If not provided, the name will be
                extracted from the template's metadata.
            stitch_config: Optional stitch configuration dictionary. If provided, datasources will be
                extracted from stitch_config["tables"][*]["path"] values.
            datasources: Optional list of datasource values. If provided, these will be used
                instead of querying the table.

        Returns:
            Dictionary containing the notebook path and status
        """
        # 1. Load the template notebook
        template_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "assets",
            "stitch_notebook_template.ipynb",
        )

        try:
            with open(template_path, "r") as f:
                notebook_content = json.load(f)

            # 2. Extract notebook name from metadata (if not provided)
            extracted_name = "Stitch Results"
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "markdown":
                    source = cell.get("source", [])
                    if source and source[0].startswith("# "):
                        extracted_name = source[0].replace("# ", "").strip()
                        break

            # Use the provided notebook name if available, otherwise use the extracted name
            final_notebook_name = notebook_name if notebook_name else extracted_name

            # 3. Get distinct datasource values
            # First check if datasources were directly provided
            if not datasources:
                # Check if we should extract datasources from stitch_config
                if (
                    stitch_config
                    and "tables" in stitch_config
                    and stitch_config["tables"]
                ):
                    # Extract unique datasources from the path values in stitch_config tables
                    datasource_set = set()
                    for table in stitch_config["tables"]:
                        if "path" in table:
                            datasource_set.add(table["path"])

                    datasources = list(datasource_set)
                    # Extract datasources successfully from stitch config

            # If we still don't have datasources, query the table directly
            if not datasources:
                # Get the configured warehouse ID
                warehouse_id = get_warehouse_id()

                # If no warehouse ID is configured, try to find a default one
                if not warehouse_id:
                    warehouses = self.list_warehouses()
                    if not warehouses:
                        raise ValueError(
                            "No SQL warehouses found and no warehouse configured. Please select a warehouse using /warehouse_selection."
                        )

                    # Use the first available warehouse
                    warehouse_id = warehouses[0]["id"]
                    logging.warning(
                        f"No warehouse configured. Using first available warehouse: {warehouse_id}"
                    )

                # Query for distinct datasource values
                sql_query = f"SELECT DISTINCT datasource FROM {table_path}"
                # Execute SQL query to get datasources
                result = self.submit_sql_statement(sql_query, warehouse_id)

                # Extract the results from the query response
                if (
                    result
                    and result.get("result")
                    and result["result"].get("data_array")
                ):
                    datasources = [row[0] for row in result["result"]["data_array"]]
                    # Successfully extracted datasources from query

            # If we still don't have datasources, use a default value
            if not datasources:
                logging.warning(f"No datasources found for {table_path}")
                datasources = ["default_source"]
                # Use default datasource as a fallback

            # 4. Create the source names JSON mapping
            source_names_json = {}
            for source in datasources:
                source_names_json[source] = source

            # Convert to JSON string (formatted nicely)
            source_names_str = json.dumps(source_names_json, indent=4)
            # Source mapping created for template

            # 5. Replace the template fields
            # Replace template placeholders with actual values

            # Need to directly modify the notebook cells rather than doing string replacement on the JSON
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    for i, line in enumerate(source_lines):
                        if (
                            '"{UNIFIED_PATH}"' in line
                            or "unified_coalesced_path = " in line
                        ):
                            # Found placeholder in notebook
                            # Replace the line with our table path
                            source_lines[i] = (
                                f'unified_coalesced_path = "{table_path}"\n'
                            )
                            break

            # Replace source semantic mapping in the cells
            replaced_mapping = False
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    # Look for the source_semantic_mapping definition
                    for i, line in enumerate(source_lines):
                        if "source_semantic_mapping =" in line and not replaced_mapping:
                            # Found source names placeholder

                            # Find the closing brace
                            closing_index = None
                            opening_count = 0
                            for j in range(i, len(source_lines)):
                                if "{" in source_lines[j]:
                                    opening_count += 1
                                if "}" in source_lines[j]:
                                    opening_count -= 1
                                    if opening_count == 0:
                                        closing_index = j
                                        break

                            if closing_index is not None:
                                # Replace the mapping with our custom mapping
                                mapping_start = i
                                mapping_end = closing_index + 1
                                new_line = (
                                    f"source_semantic_mapping = {source_names_str}\n"
                                )
                                source_lines[mapping_start:mapping_end] = [new_line]
                                replaced_mapping = True
                            break

            if not replaced_mapping:
                logging.warning(
                    "Could not find source_semantic_mapping in the notebook template to replace"
                )

            # 6. Get the current user's username
            username = self.get_current_user()

            # 7. Construct the notebook path
            notebook_path = f"/Workspace/Users/{username}/{final_notebook_name}"

            # 8. Convert the notebook to base64 for API call
            notebook_content_str = json.dumps(notebook_content)
            encoded_content = base64.b64encode(notebook_content_str.encode()).decode()

            # 9. Import the notebook to Databricks
            import_data = {
                "path": notebook_path,
                "content": encoded_content,
                "format": "JUPYTER",
                "overwrite": True,
            }

            self.post("/api/2.0/workspace/import", import_data)

            # 10. Log success and return the path
            # Notebook created successfully
            return {"notebook_path": notebook_path, "status": "success"}

        except Exception as e:
            logging.debug(f"Error creating stitch notebook: {e}")
            raise
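

# Usage sketch (illustrative, not executed): an end-to-end flow with assumed
# names; the workspace, token, and table path below are placeholders.
#
#   client = DatabricksAPIClient("my-workspace", token="dapi-...")
#   if client.validate_token():
#       client.create_stitch_notebook("my_catalog.my_schema.unified_coalesced")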