Coverage for src/clients/databricks.py: 44%

336 statements  

coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

"""
Reusable Databricks API client for authentication and requests.
"""

import base64
import json
import logging
import os
import requests
import time
import urllib.parse
from datetime import datetime
from src.config import get_warehouse_id
from src.clients.amperity import get_amperity_url
from src.databricks.url_utils import (
    detect_cloud_provider,
    normalize_workspace_url,
)


class DatabricksAPIClient:
    """Reusable Databricks API client for authentication and requests."""

    def __init__(self, workspace_url, token):
        """
        Initialize the API client.

        Args:
            workspace_url: Databricks workspace URL (with or without protocol/domain)
            token: Databricks API token
        """
        self.original_url = workspace_url
        self.workspace_url = self._normalize_workspace_url(workspace_url)
        self.cloud_provider = detect_cloud_provider(workspace_url)
        self.base_domain = self._get_base_domain()
        self.token = token
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "User-Agent": "amperity",
        }

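    # Example (illustrative sketch): the client accepts either a bare workspace ID
    # or a full URL, since __init__ routes the value through
    # _normalize_workspace_url() and detect_cloud_provider(). The workspace ID and
    # token below are made-up placeholders; the resulting attribute values depend
    # on DATABRICKS_DOMAIN_MAP.
    #
    #     client = DatabricksAPIClient("dbc-1234567890abcdef", "dapi-example-token")
    #     print(client.cloud_provider, client.base_domain)
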

    def _normalize_workspace_url(self, url):
        """
        Normalize the workspace URL to the format needed for API calls.

        Args:
            url: Input workspace URL that might be in various formats

        Returns:
            Cleaned workspace URL (workspace ID only)
        """
        return normalize_workspace_url(url)

    def _get_base_domain(self):
        """
        Get the appropriate base domain based on the cloud provider.

        Returns:
            Base domain string for the detected cloud provider
        """
        from src.databricks.url_utils import DATABRICKS_DOMAIN_MAP

        return DATABRICKS_DOMAIN_MAP.get(
            self.cloud_provider, DATABRICKS_DOMAIN_MAP["AWS"]
        )


    def get_compute_node_type(self):
        """
        Get the appropriate compute node type based on the cloud provider.

        Returns:
            Node type string for the detected cloud provider
        """
        node_type_map = {
            "AWS": "r5d.4xlarge",
            "Azure": "Standard_E16ds_v4",
            "GCP": "n2-standard-16",  # Default GCP node type
            "Generic": "r5d.4xlarge",  # Default to AWS
        }
        return node_type_map.get(self.cloud_provider, "r5d.4xlarge")

    def get_cloud_attributes(self):
        """
        Get cloud-specific attributes for cluster configuration.

        Returns:
            Dictionary containing cloud-specific attributes
        """
        if self.cloud_provider == "AWS":
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }
        elif self.cloud_provider == "Azure":
            return {
                "azure_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK_AZURE",
                    "spot_bid_max_price": -1,
                }
            }
        elif self.cloud_provider == "GCP":
            return {
                "gcp_attributes": {
                    "use_preemptible_executors": True,
                    "google_service_account": None,
                }
            }
        else:
            # Default to AWS
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }

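    # Example (illustrative sketch): the two helpers above are meant to be merged
    # into a cluster spec, mirroring what submit_job_run() does further down.
    # "client" is an already-constructed DatabricksAPIClient; the dict literal is
    # an example value only.
    #
    #     cluster_config = {"node_type_id": client.get_compute_node_type()}
    #     cluster_config.update(client.get_cloud_attributes())
    #     # -> now contains aws_attributes / azure_attributes / gcp_attributes
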

    #
    # Base API request methods
    #

    def get(self, endpoint):
        """
        Send a GET request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request to: {url}")

        try:
            response = requests.get(url, headers=self.headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_with_params(self, endpoint, params=None):
        """
        Send a GET request to the Databricks API with query parameters.

        Args:
            endpoint: API endpoint (starting with /)
            params: Dictionary of query parameters

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request with params to: {url}")

        try:
            response = requests.get(url, headers=self.headers, params=params)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def post(self, endpoint, data):
        """
        Send a POST request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)
            data: JSON data to send in the request body

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"POST request to: {url}")

        try:
            response = requests.post(url, headers=self.headers, json=data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

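    # Example (illustrative sketch): every higher-level method in this class is a
    # thin wrapper over get(), get_with_params(), and post(). The endpoints below
    # are the same ones used elsewhere in this file; the run_id value is a
    # placeholder.
    #
    #     me = client.get("/api/2.0/preview/scim/v2/Me")
    #     run = client.get_with_params("/api/2.2/jobs/runs/get", {"run_id": "123"})
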

    #
    # Authentication methods
    #

    def validate_token(self):
        """
        Validate the current token by calling the SCIM Me endpoint.

        Returns:
            True if the token is valid, False otherwise
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            return bool(response)
        except Exception as e:
            logging.debug(f"Token validation failed: {e}")
            return False

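    # Example (illustrative sketch): validate_token() swallows exceptions and
    # returns a bool, so it can gate work up front.
    #
    #     if not client.validate_token():
    #         raise SystemExit("Databricks token is invalid or expired")
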

    #
    # Unity Catalog methods
    #

    def list_catalogs(self, include_browse=False, max_results=None, page_token=None):
        """
        Gets an array of catalogs in the metastore.

        Args:
            include_browse: Whether to include catalogs for which the principal can only access selective metadata
            max_results: Maximum number of catalogs to return (optional)
            page_token: Opaque pagination token to go to next page (optional)

        Returns:
            Dictionary containing:
            - catalogs: List of catalogs
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        if params:
            return self.get_with_params("/api/2.1/unity-catalog/catalogs", params)
        return self.get("/api/2.1/unity-catalog/catalogs")

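    # Example (illustrative sketch): the list_* methods surface "next_page_token"
    # when more results exist, so callers are expected to loop until it is absent.
    # The page size of 50 is an arbitrary example value.
    #
    #     catalogs, token = [], None
    #     while True:
    #         page = client.list_catalogs(max_results=50, page_token=token)
    #         catalogs.extend(page.get("catalogs", []))
    #         token = page.get("next_page_token")
    #         if not token:
    #             break
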

    def get_catalog(self, catalog_name):
        """
        Gets a catalog from Unity Catalog.

        Args:
            catalog_name: Name of the catalog

        Returns:
            Catalog information
        """
        return self.get(f"/api/2.1/unity-catalog/catalogs/{catalog_name}")

    def list_schemas(
        self, catalog_name, include_browse=False, max_results=None, page_token=None
    ):
        """
        Gets an array of schemas for a catalog in the metastore.

        Args:
            catalog_name: Parent catalog for schemas of interest (required)
            include_browse: Whether to include schemas for which the principal can only access selective metadata
            max_results: Maximum number of schemas to return (optional)
            page_token: Opaque pagination token to go to next page (optional)

        Returns:
            Dictionary containing:
            - schemas: List of schemas
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        return self.get_with_params("/api/2.1/unity-catalog/schemas", params)

    def get_schema(self, full_name):
        """
        Gets a schema from Unity Catalog.

        Args:
            full_name: Full name of the schema in the format 'catalog_name.schema_name'

        Returns:
            Schema information
        """
        return self.get(f"/api/2.1/unity-catalog/schemas/{full_name}")


    def list_tables(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_delta_metadata=False,
        omit_columns=False,
        omit_properties=False,
        omit_username=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets an array of all tables for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog for tables of interest (required)
            schema_name: Parent schema of tables (required)
            max_results: Maximum number of tables to return (optional)
            page_token: Opaque token to send for the next page of results (optional)
            include_delta_metadata: Whether delta metadata should be included (optional)
            omit_columns: Whether to omit columns from the response (optional)
            omit_properties: Whether to omit properties from the response (optional)
            omit_username: Whether to omit username from the response (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Dictionary containing:
            - tables: List of tables
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}

        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if omit_columns:
            params["omit_columns"] = "true"
        if omit_properties:
            params["omit_properties"] = "true"
        if omit_username:
            params["omit_username"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"
        return self.get_with_params("/api/2.1/unity-catalog/tables", params)

    def get_table(
        self,
        full_name,
        include_delta_metadata=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets a table from the metastore for a specific catalog and schema.

        Args:
            full_name: Full name of the table in format 'catalog_name.schema_name.table_name'
            include_delta_metadata: Whether delta metadata should be included (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Table information
        """
        params = {}
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"

        if params:
            return self.get_with_params(
                f"/api/2.1/unity-catalog/tables/{full_name}", params
            )
        return self.get(f"/api/2.1/unity-catalog/tables/{full_name}")


    def list_volumes(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_browse=False,
    ):
        """
        Gets an array of volumes for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog (required)
            schema_name: Name of parent schema (required)
            max_results: Maximum number of volumes to return (optional)
            page_token: Opaque token for pagination (optional)
            include_browse: Whether to include volumes with selective metadata access (optional)

        Returns:
            Dictionary containing:
            - volumes: List of volumes
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_browse:
            params["include_browse"] = "true"

        return self.get_with_params("/api/2.1/unity-catalog/volumes", params)

    def create_volume(self, catalog_name, schema_name, name, volume_type="MANAGED"):
        """
        Create a new volume in Unity Catalog.

        Args:
            catalog_name: The name of the catalog where the volume will be created
            schema_name: The name of the schema where the volume will be created
            name: The name of the volume to create
            volume_type: The type of volume to create (default: "MANAGED")

        Returns:
            Dict containing the created volume information
        """
        data = {
            "catalog_name": catalog_name,
            "schema_name": schema_name,
            "name": name,
            "volume_type": volume_type,
        }
        return self.post("/api/2.1/unity-catalog/volumes", data)


    #
    # Models and Serving methods
    #

    def list_models(self):
        """
        Fetch a list of models from the Databricks Serving API.

        Returns:
            List of available model endpoints
        """
        response = self.get("/api/2.0/serving-endpoints")
        return response.get("endpoints", [])

    def get_model(self, model_name):
        """
        Get details of a specific model from Databricks Serving API.

        Args:
            model_name: Name of the model to retrieve

        Returns:
            Model details if found, or None if the endpoint does not exist
        """
        try:
            return self.get(f"/api/2.0/serving-endpoints/{model_name}")
        except ValueError as e:
            if "404" in str(e):
                logging.warning(f"Model '{model_name}' not found")
                return None
            raise


    #
    # Warehouse methods
    #

    def list_warehouses(self):
        """
        Lists all SQL warehouses in the Databricks workspace.

        Returns:
            List of warehouses
        """
        response = self.get("/api/2.0/sql/warehouses")
        return response.get("warehouses", [])

    def get_warehouse(self, warehouse_id):
        """
        Gets information about a specific SQL warehouse.

        Args:
            warehouse_id: ID of the SQL warehouse

        Returns:
            Warehouse information
        """
        return self.get(f"/api/2.0/sql/warehouses/{warehouse_id}")

    def create_warehouse(self, opts):
        """
        Creates a new SQL warehouse.

        Args:
            opts: Dictionary containing warehouse configuration options

        Returns:
            Created warehouse information
        """
        return self.post("/api/2.0/sql/warehouses", opts)


    def submit_sql_statement(
        self,
        sql_text,
        warehouse_id,
        catalog=None,
        wait_timeout="30s",
        on_wait_timeout="CONTINUE",
    ):
        """
        Submit a SQL statement to Databricks SQL warehouse and wait for completion.

        Args:
            sql_text: SQL statement to execute
            warehouse_id: ID of the SQL warehouse
            catalog: Optional catalog name
            wait_timeout: How long to wait for query completion (default "30s")
            on_wait_timeout: What to do on timeout ("CONTINUE" or "CANCEL")

        Returns:
            Dictionary containing the SQL statement execution result
        """
        data = {
            "statement": sql_text,
            "warehouse_id": warehouse_id,
            "wait_timeout": wait_timeout,
            "on_wait_timeout": on_wait_timeout,
        }

        if catalog:
            data["catalog"] = catalog

        # Submit the SQL statement
        response = self.post("/api/2.0/sql/statements", data)
        statement_id = response.get("statement_id")

        # Poll until complete
        while True:
            status = self.get(f"/api/2.0/sql/statements/{statement_id}")
            state = status.get("status", {}).get("state", status.get("state"))
            if state not in ["PENDING", "RUNNING"]:
                break
            time.sleep(1)

        return status

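    # Example (illustrative sketch): reading rows from a completed statement. The
    # warehouse ID and table name are placeholders; "result" / "data_array" are
    # the same response fields this module reads in create_stitch_notebook().
    #
    #     result = client.submit_sql_statement(
    #         "SELECT COUNT(*) FROM main.default.events", warehouse_id="abc123"
    #     )
    #     rows = result.get("result", {}).get("data_array", [])
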

    #
    # Jobs methods
    #

    def submit_job_run(self, config_path, init_script_path, run_name=None):
        """
        Submit a one-time Databricks job run using the /runs/submit endpoint.

        Args:
            config_path: Path to the configuration file for the job in the Volume
            init_script_path: Path to the initialization script
            run_name: Optional name for the run. If None, a default name will be generated.

        Returns:
            Dict containing the job run information (including run_id)
        """
        if not run_name:
            run_name = (
                f"Chuck AI One-Time Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            )

        # Define the task and cluster for the one-time run
        # Create base cluster configuration
        cluster_config = {
            "cluster_name": "",
            "spark_version": "16.0.x-cpu-ml-scala2.12",
            "init_scripts": [
                {
                    "volumes": {
                        "destination": init_script_path,
                    }
                }
            ],
            "node_type_id": self.get_compute_node_type(),
            "custom_tags": {
                "stack": "aws-dev",
                "sys": "chuck",
                "tenant": "amperity",
            },
            "spark_env_vars": {
                "JNAME": "zulu17-ca-amd64",
                "CHUCK_API_URL": f"https://{get_amperity_url()}",
                "DEBUG_INIT_SRIPT_URL": init_script_path,
                "DEBUG_CONFIG_PATH": config_path,
            },
            "enable_elastic_disk": False,
            "data_security_mode": "SINGLE_USER",
            "runtime_engine": "STANDARD",
            "autoscale": {"min_workers": 10, "max_workers": 50},
        }

        # Add cloud-specific attributes
        cluster_config.update(self.get_cloud_attributes())

        run_payload = {
            "run_name": run_name,
            "tasks": [
                {
                    "task_key": "Run_Stitch",
                    "run_if": "ALL_SUCCESS",
                    "spark_jar_task": {
                        "jar_uri": "",
                        "main_class_name": os.environ.get(
                            "MAIN_CLASS", "amperity.stitch_standalone.chuck_main"
                        ),
                        "parameters": [
                            "",
                            config_path,
                        ],
                        "run_as_repl": True,
                    },
                    "libraries": [{"jar": "file:///opt/amperity/job.jar"}],
                    "timeout_seconds": 0,
                    "email_notifications": {},
                    "webhook_notifications": {},
                    "new_cluster": cluster_config,
                },
            ],
            "timeout_seconds": 0,
        }

        return self.post("/api/2.2/jobs/runs/submit", run_payload)


    def get_job_run_status(self, run_id):
        """
        Get the status of a Databricks job run.

        Args:
            run_id: The job run ID (as str or int)

        Returns:
            Dict containing the job run status information
        """
        params = {"run_id": run_id}
        return self.get_with_params("/api/2.2/jobs/runs/get", params)

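    # Example (illustrative sketch): submitting a one-time run and checking on it.
    # The Volume paths are placeholders, and the exact shape of the returned
    # status payload is an assumption about the Jobs 2.2 runs/get response.
    #
    #     run = client.submit_job_run(
    #         "/Volumes/cat/schema/vol/config.json",
    #         "/Volumes/cat/schema/vol/init.sh",
    #     )
    #     status = client.get_job_run_status(run["run_id"])
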

    #
    # File system methods
    #

    def upload_file(self, path, file_path=None, content=None, overwrite=False):
        """
        Upload a file using the /api/2.0/fs/files endpoint.

        Args:
            path: The destination path (e.g., "/Volumes/my-catalog/my-schema/my-volume/file.txt")
            file_path: Local file path to upload (mutually exclusive with content)
            content: String content to upload (mutually exclusive with file_path)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful (API returns no content on success)

        Raises:
            ValueError: If both file_path and content are provided or neither is provided
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        if (file_path and content) or (not file_path and not content):
            raise ValueError("Exactly one of file_path or content must be provided")

        # URL encode the path and make sure it starts with a slash
        if not path.startswith("/"):
            path = f"/{path}"

        # Remove duplicate slashes if any
        while "//" in path:
            path = path.replace("//", "/")

        # URL encode path components but preserve the slashes
        encoded_path = "/".join(
            urllib.parse.quote(component) for component in path.split("/") if component
        )
        encoded_path = f"/{encoded_path}"

        url = f"https://{self.workspace_url}.{self.base_domain}/api/2.0/fs/files{encoded_path}"

        if overwrite:
            url += "?overwrite=true"

        logging.debug(f"File upload request to: {url}")

        headers = self.headers.copy()
        headers.update({"Content-Type": "application/octet-stream"})

        # Get binary data to upload
        if file_path:
            with open(file_path, "rb") as f:
                binary_data = f.read()
        else:
            # Convert string content to bytes
            binary_data = content.encode("utf-8")

        try:
            # Use PUT request with raw binary data in the body
            response = requests.put(url, headers=headers, data=binary_data)
            response.raise_for_status()
            # API returns 204 No Content on success
            return True
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

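    # Example (illustrative sketch): the two mutually exclusive ways to call
    # upload_file(). The Volume paths and local file name are placeholders.
    #
    #     client.upload_file("/Volumes/cat/schema/vol/config.json",
    #                        content='{"key": "value"}', overwrite=True)
    #     client.upload_file("/Volumes/cat/schema/vol/init.sh",
    #                        file_path="local/init.sh")
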

    def store_dbfs_file(self, path, contents, overwrite=True):
        """
        Store content to DBFS using the /api/2.0/dbfs/put endpoint.

        Args:
            path: Path in DBFS to store the file
            contents: String or bytes content to store (will be base64-encoded for the API)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful
        """
        # Encode the content as base64
        encoded_contents = (
            base64.b64encode(contents.encode()).decode()
            if isinstance(contents, str)
            else base64.b64encode(contents).decode()
        )

        # Prepare the request with file content and path
        request_data = {
            "path": path,
            "contents": encoded_contents,
            "overwrite": overwrite,
        }

        # Call DBFS API
        self.post("/api/2.0/dbfs/put", request_data)
        return True


    #
    # Amperity-specific methods
    #

    def fetch_amperity_job_init(self, token, api_url: str | None = None):
        """
        Fetch initialization script for Amperity jobs.

        Args:
            token: Amperity authentication token
            api_url: Optional override for the job init endpoint

        Returns:
            Dict containing the initialization script data
        """
        try:
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }

            if not api_url:
                api_url = f"https://{get_amperity_url()}/api/job/launch"

            response = requests.post(api_url, headers=headers, json={})
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            response = e.response
            resp_text = response.text if response else ""
            logging.debug(f"HTTP error: {e}, Response: {resp_text}")
            if response is not None:
                try:
                    message = response.json().get("message", resp_text)
                except ValueError:
                    message = resp_text
                raise ValueError(
                    f"{response.status_code} Error: {message}. Please /logout and /login again"
                )
            raise ValueError(
                f"HTTP error occurred: {e}. Please /logout and /login again"
            )
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")


    def get_current_user(self):
        """
        Get the current user's username from Databricks API.

        Returns:
            Username string from the current user
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            username = response.get("userName")
            if not username:
                logging.debug("Username not found in response")
                raise ValueError("Username not found in API response")
            return username
        except Exception as e:
            logging.debug(f"Error getting current user: {e}")
            raise


    def create_stitch_notebook(
        self, table_path, notebook_name=None, stitch_config=None, datasources=None
    ):
        """
        Create a stitch notebook for the given table path.

        This function will:
        1. Load the stitch notebook template
        2. Extract the notebook name from metadata (or use provided name)
        3. Get datasources either from provided list, stitch_config, or query the table
        4. Replace template fields with appropriate values
        5. Import the notebook to Databricks

        Args:
            table_path: Full path to the unified table in format catalog.schema.table
            notebook_name: Optional name for the notebook. If not provided, the name will be
                extracted from the template's metadata.
            stitch_config: Optional stitch configuration dictionary. If provided, datasources will be
                extracted from stitch_config["tables"][*]["path"] values.
            datasources: Optional list of datasource values. If provided, these will be used
                instead of querying the table.

        Returns:
            Dictionary containing the notebook path and status
        """
        # 1. Load the template notebook
        template_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "assets",
            "stitch_notebook_template.ipynb",
        )

        try:
            with open(template_path, "r") as f:
                notebook_content = json.load(f)

            # 2. Extract notebook name from metadata (if not provided)
            extracted_name = "Stitch Results"
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "markdown":
                    source = cell.get("source", [])
                    if source and source[0].startswith("# "):
                        extracted_name = source[0].replace("# ", "").strip()
                        break

            # Use the provided notebook name if available, otherwise use the extracted name
            final_notebook_name = notebook_name if notebook_name else extracted_name

            # 3. Get distinct datasource values
            # First check if datasources were directly provided
            if not datasources:
                # Check if we should extract datasources from stitch_config
                if (
                    stitch_config
                    and "tables" in stitch_config
                    and stitch_config["tables"]
                ):
                    # Extract unique datasources from the path values in stitch_config tables
                    datasource_set = set()
                    for table in stitch_config["tables"]:
                        if "path" in table:
                            datasource_set.add(table["path"])

                    datasources = list(datasource_set)
                    # Datasources were extracted from the stitch config

                # If we still don't have datasources, query the table directly
                if not datasources:
                    # Get the configured warehouse ID
                    warehouse_id = get_warehouse_id()

                    # If no warehouse ID is configured, try to find a default one
                    if not warehouse_id:
                        warehouses = self.list_warehouses()
                        if not warehouses:
                            raise ValueError(
                                "No SQL warehouses found and no warehouse configured. Please select a warehouse using /warehouse_selection."
                            )

                        # Use the first available warehouse
                        warehouse_id = warehouses[0]["id"]
                        logging.warning(
                            f"No warehouse configured. Using first available warehouse: {warehouse_id}"
                        )

                    # Query for distinct datasource values
                    sql_query = f"SELECT DISTINCT datasource FROM {table_path}"
                    # Execute SQL query to get datasources
                    result = self.submit_sql_statement(sql_query, warehouse_id)

                    # Extract the results from the query response
                    if (
                        result
                        and result.get("result")
                        and result["result"].get("data_array")
                    ):
                        datasources = [row[0] for row in result["result"]["data_array"]]
                        # Datasources were extracted from the query results

                # If we still don't have datasources, use a default value
                if not datasources:
                    logging.warning(f"No datasources found for {table_path}")
                    datasources = ["default_source"]
                    # Use default datasource as a fallback


            # 4. Create the source names JSON mapping
            source_names_json = {}
            for source in datasources:
                source_names_json[source] = source

            # Convert to JSON string (formatted nicely)
            source_names_str = json.dumps(source_names_json, indent=4)
            # Source mapping created for template

            # 5. Replace the template fields
            # Replace template placeholders with actual values

            # Need to directly modify the notebook cells rather than doing string replacement on the JSON
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    for i, line in enumerate(source_lines):
                        if (
                            '"{UNIFIED_PATH}"' in line
                            or "unified_coalesced_path = " in line
                        ):
                            # Found placeholder in notebook
                            # Replace the line with our table path
                            source_lines[i] = (
                                f'unified_coalesced_path = "{table_path}"\n'
                            )
                            break

            # Replace source semantic mapping in the cells
            replaced_mapping = False
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    # Look for the source_semantic_mapping definition
                    for i, line in enumerate(source_lines):
                        if "source_semantic_mapping =" in line and not replaced_mapping:
                            # Found source names placeholder

                            # Find the closing brace
                            closing_index = None
                            opening_count = 0
                            for j in range(i, len(source_lines)):
                                if "{" in source_lines[j]:
                                    opening_count += 1
                                if "}" in source_lines[j]:
                                    opening_count -= 1
                                    if opening_count == 0:
                                        closing_index = j
                                        break

                            if closing_index is not None:
                                # Replace the mapping with our custom mapping
                                mapping_start = i
                                mapping_end = closing_index + 1
                                new_line = (
                                    f"source_semantic_mapping = {source_names_str}\n"
                                )
                                source_lines[mapping_start:mapping_end] = [new_line]
                                replaced_mapping = True
                                break

            if not replaced_mapping:
                logging.warning(
                    "Could not find source_semantic_mapping in the notebook template to replace"
                )

            # 6. Get the current user's username
            username = self.get_current_user()

            # 7. Construct the notebook path
            notebook_path = f"/Workspace/Users/{username}/{final_notebook_name}"

            # 8. Convert the notebook to base64 for API call
            notebook_content_str = json.dumps(notebook_content)
            encoded_content = base64.b64encode(notebook_content_str.encode()).decode()

            # 9. Import the notebook to Databricks
            import_data = {
                "path": notebook_path,
                "content": encoded_content,
                "format": "JUPYTER",
                "overwrite": True,
            }

            self.post("/api/2.0/workspace/import", import_data)

            # 10. Return the notebook path and status
            return {"notebook_path": notebook_path, "status": "success"}

        except Exception as e:
            logging.debug(f"Error creating stitch notebook: {e}")
            raise
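

# Example (illustrative end-to-end sketch): the typical call sequence the client
# supports, with placeholder workspace, token, and table values.
#
#     client = DatabricksAPIClient("dbc-1234567890abcdef", "dapi-example-token")
#     if client.validate_token():
#         client.create_stitch_notebook("main.default.unified_coalesced")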