Coverage for src/chuck_data/clients/databricks.py: 0%

336 statements  

coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

"""
Reusable Databricks API client for authentication and requests.
"""

import base64
import json
import logging
import os
import requests
import time
import urllib.parse
from datetime import datetime
from ..config import get_warehouse_id
from ..clients.amperity import get_amperity_url
from ..databricks.url_utils import (
    detect_cloud_provider,
    normalize_workspace_url,
)


class DatabricksAPIClient:
    """Reusable Databricks API client for authentication and requests."""

    def __init__(self, workspace_url, token):
        """
        Initialize the API client.

        Args:
            workspace_url: Databricks workspace URL (with or without protocol/domain)
            token: Databricks API token
        """
        self.original_url = workspace_url
        self.workspace_url = self._normalize_workspace_url(workspace_url)
        self.cloud_provider = detect_cloud_provider(workspace_url)
        self.base_domain = self._get_base_domain()
        self.token = token
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "User-Agent": "amperity",
        }
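
    # Minimal construction sketch (illustrative only; the workspace name and
    # token below are placeholders, not values taken from this module):
    #
    #     client = DatabricksAPIClient("my-workspace", "dapi-example-token")
    #     if client.validate_token():
    #         print("token accepted by the SCIM Me endpoint")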

    def _normalize_workspace_url(self, url):
        """
        Normalize the workspace URL to the format needed for API calls.

        Args:
            url: Input workspace URL that might be in various formats

        Returns:
            Cleaned workspace URL (workspace ID only)
        """
        return normalize_workspace_url(url)

    def _get_base_domain(self):
        """
        Get the appropriate base domain based on the cloud provider.

        Returns:
            Base domain string for the detected cloud provider
        """
        from ..databricks.url_utils import DATABRICKS_DOMAIN_MAP

        return DATABRICKS_DOMAIN_MAP.get(
            self.cloud_provider, DATABRICKS_DOMAIN_MAP["AWS"]
        )

    def get_compute_node_type(self):
        """
        Get the appropriate compute node type based on the cloud provider.

        Returns:
            Node type string for the detected cloud provider
        """
        node_type_map = {
            "AWS": "r5d.4xlarge",
            "Azure": "Standard_E16ds_v4",
            "GCP": "n2-standard-16",  # Default GCP node type
            "Generic": "r5d.4xlarge",  # Default to AWS
        }
        return node_type_map.get(self.cloud_provider, "r5d.4xlarge")

    def get_cloud_attributes(self):
        """
        Get cloud-specific attributes for cluster configuration.

        Returns:
            Dictionary containing cloud-specific attributes
        """
        if self.cloud_provider == "AWS":
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }
        elif self.cloud_provider == "Azure":
            return {
                "azure_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK_AZURE",
                    "spot_bid_max_price": -1,
                }
            }
        elif self.cloud_provider == "GCP":
            return {
                "gcp_attributes": {
                    "use_preemptible_executors": True,
                    "google_service_account": None,
                }
            }
        else:
            # Default to AWS
            return {
                "aws_attributes": {
                    "first_on_demand": 1,
                    "availability": "SPOT_WITH_FALLBACK",
                    "zone_id": "us-west-2b",
                    "spot_bid_price_percent": 100,
                    "ebs_volume_count": 0,
                }
            }

    #
    # Base API request methods
    #

    def get(self, endpoint):
        """
        Send a GET request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request to: {url}")

        try:
            response = requests.get(url, headers=self.headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_with_params(self, endpoint, params=None):
        """
        Send a GET request to the Databricks API with query parameters.

        Args:
            endpoint: API endpoint (starting with /)
            params: Dictionary of query parameters

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"GET request with params to: {url}")

        try:
            response = requests.get(url, headers=self.headers, params=params)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def post(self, endpoint, data):
        """
        Send a POST request to the Databricks API.

        Args:
            endpoint: API endpoint (starting with /)
            data: JSON data to send in the request body

        Returns:
            JSON response from the API

        Raises:
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        url = f"https://{self.workspace_url}.{self.base_domain}{endpoint}"
        logging.debug(f"POST request to: {url}")

        try:
            response = requests.post(url, headers=self.headers, json=data)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")
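
    # Sketch of the low-level request helpers (illustrative only; the catalog,
    # schema, and volume names are placeholders). Both helpers raise ValueError
    # on HTTP errors and ConnectionError on network failures, so callers
    # usually wrap them in try/except:
    #
    #     me = client.get("/api/2.0/preview/scim/v2/Me")
    #     created = client.post(
    #         "/api/2.1/unity-catalog/volumes",
    #         {"catalog_name": "main", "schema_name": "default",
    #          "name": "chuck", "volume_type": "MANAGED"},
    #     )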

    #
    # Authentication methods
    #

    def validate_token(self):
        """
        Validate the current token by calling the SCIM Me endpoint.

        Returns:
            True if the token is valid, False otherwise
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            return True if response else False
        except Exception as e:
            logging.debug(f"Token validation failed: {e}")
            return False

    #
    # Unity Catalog methods
    #

    def list_catalogs(self, include_browse=False, max_results=None, page_token=None):
        """
        Gets an array of catalogs in the metastore.

        Args:
            include_browse: Whether to include catalogs for which the principal can only access selective metadata
            max_results: Maximum number of catalogs to return (optional)
            page_token: Opaque pagination token to go to the next page (optional)

        Returns:
            Dictionary containing:
            - catalogs: List of catalogs
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        if params:
            return self.get_with_params("/api/2.1/unity-catalog/catalogs", params)
        return self.get("/api/2.1/unity-catalog/catalogs")

    def get_catalog(self, catalog_name):
        """
        Gets a catalog from Unity Catalog.

        Args:
            catalog_name: Name of the catalog

        Returns:
            Catalog information
        """
        return self.get(f"/api/2.1/unity-catalog/catalogs/{catalog_name}")

    def list_schemas(
        self, catalog_name, include_browse=False, max_results=None, page_token=None
    ):
        """
        Gets an array of schemas for a catalog in the metastore.

        Args:
            catalog_name: Parent catalog for schemas of interest (required)
            include_browse: Whether to include schemas for which the principal can only access selective metadata
            max_results: Maximum number of schemas to return (optional)
            page_token: Opaque pagination token to go to the next page (optional)

        Returns:
            Dictionary containing:
            - schemas: List of schemas
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name}
        if include_browse:
            params["include_browse"] = "true"
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token

        return self.get_with_params("/api/2.1/unity-catalog/schemas", params)

    def get_schema(self, full_name):
        """
        Gets a schema from Unity Catalog.

        Args:
            full_name: Full name of the schema in the format 'catalog_name.schema_name'

        Returns:
            Schema information
        """
        return self.get(f"/api/2.1/unity-catalog/schemas/{full_name}")

    def list_tables(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_delta_metadata=False,
        omit_columns=False,
        omit_properties=False,
        omit_username=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets an array of all tables for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog for tables of interest (required)
            schema_name: Parent schema of tables (required)
            max_results: Maximum number of tables to return (optional)
            page_token: Opaque token to send for the next page of results (optional)
            include_delta_metadata: Whether delta metadata should be included (optional)
            omit_columns: Whether to omit columns from the response (optional)
            omit_properties: Whether to omit properties from the response (optional)
            omit_username: Whether to omit username from the response (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Dictionary containing:
            - tables: List of tables
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}

        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if omit_columns:
            params["omit_columns"] = "true"
        if omit_properties:
            params["omit_properties"] = "true"
        if omit_username:
            params["omit_username"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"

        return self.get_with_params("/api/2.1/unity-catalog/tables", params)
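
    # Pagination sketch for the Unity Catalog list helpers (illustrative only;
    # "main" and "default" are placeholder catalog and schema names):
    #
    #     tables, token = [], None
    #     while True:
    #         page = client.list_tables("main", "default", max_results=50, page_token=token)
    #         tables.extend(page.get("tables", []))
    #         token = page.get("next_page_token")
    #         if not token:
    #             break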

    def get_table(
        self,
        full_name,
        include_delta_metadata=False,
        include_browse=False,
        include_manifest_capabilities=False,
    ):
        """
        Gets a table from the metastore for a specific catalog and schema.

        Args:
            full_name: Full name of the table in format 'catalog_name.schema_name.table_name'
            include_delta_metadata: Whether delta metadata should be included (optional)
            include_browse: Whether to include tables with selective metadata access (optional)
            include_manifest_capabilities: Whether to include table capabilities (optional)

        Returns:
            Table information
        """
        params = {}
        if include_delta_metadata:
            params["include_delta_metadata"] = "true"
        if include_browse:
            params["include_browse"] = "true"
        if include_manifest_capabilities:
            params["include_manifest_capabilities"] = "true"

        if params:
            return self.get_with_params(
                f"/api/2.1/unity-catalog/tables/{full_name}", params
            )
        return self.get(f"/api/2.1/unity-catalog/tables/{full_name}")

    def list_volumes(
        self,
        catalog_name,
        schema_name,
        max_results=None,
        page_token=None,
        include_browse=False,
    ):
        """
        Gets an array of volumes for the current metastore under the parent catalog and schema.

        Args:
            catalog_name: Name of parent catalog (required)
            schema_name: Name of parent schema (required)
            max_results: Maximum number of volumes to return (optional)
            page_token: Opaque token for pagination (optional)
            include_browse: Whether to include volumes with selective metadata access (optional)

        Returns:
            Dictionary containing:
            - volumes: List of volumes
            - next_page_token: Token for retrieving the next page (if available)
        """
        params = {"catalog_name": catalog_name, "schema_name": schema_name}
        if max_results is not None:
            params["max_results"] = str(max_results)
        if page_token:
            params["page_token"] = page_token
        if include_browse:
            params["include_browse"] = "true"

        return self.get_with_params("/api/2.1/unity-catalog/volumes", params)

    def create_volume(self, catalog_name, schema_name, name, volume_type="MANAGED"):
        """
        Create a new volume in Unity Catalog.

        Args:
            catalog_name: The name of the catalog where the volume will be created
            schema_name: The name of the schema where the volume will be created
            name: The name of the volume to create
            volume_type: The type of volume to create (default: "MANAGED")

        Returns:
            Dict containing the created volume information
        """
        data = {
            "catalog_name": catalog_name,
            "schema_name": schema_name,
            "name": name,
            "volume_type": volume_type,
        }
        return self.post("/api/2.1/unity-catalog/volumes", data)

    #
    # Models and Serving methods
    #

    def list_models(self):
        """
        Fetch a list of models from the Databricks Serving API.

        Returns:
            List of available model endpoints
        """
        response = self.get("/api/2.0/serving-endpoints")
        return response.get("endpoints", [])

    def get_model(self, model_name):
        """
        Get details of a specific model from the Databricks Serving API.

        Args:
            model_name: Name of the model to retrieve

        Returns:
            Model details if found
        """
        try:
            return self.get(f"/api/2.0/serving-endpoints/{model_name}")
        except ValueError as e:
            if "404" in str(e):
                logging.warning(f"Model '{model_name}' not found")
                return None
            raise
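
    # Serving-endpoint lookup sketch (illustrative only; the endpoint name is a
    # placeholder). get_model() returns None when the API reports a 404:
    #
    #     for endpoint in client.list_models():
    #         print(endpoint.get("name"))
    #     if client.get_model("my-serving-endpoint") is None:
    #         print("endpoint not found")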

    #
    # Warehouse methods
    #

    def list_warehouses(self):
        """
        Lists all SQL warehouses in the Databricks workspace.

        Returns:
            List of warehouses
        """
        response = self.get("/api/2.0/sql/warehouses")
        return response.get("warehouses", [])

    def get_warehouse(self, warehouse_id):
        """
        Gets information about a specific SQL warehouse.

        Args:
            warehouse_id: ID of the SQL warehouse

        Returns:
            Warehouse information
        """
        return self.get(f"/api/2.0/sql/warehouses/{warehouse_id}")

    def create_warehouse(self, opts):
        """
        Creates a new SQL warehouse.

        Args:
            opts: Dictionary containing warehouse configuration options

        Returns:
            Created warehouse information
        """
        return self.post("/api/2.0/sql/warehouses", opts)

    def submit_sql_statement(
        self,
        sql_text,
        warehouse_id,
        catalog=None,
        wait_timeout="30s",
        on_wait_timeout="CONTINUE",
    ):
        """
        Submit a SQL statement to a Databricks SQL warehouse and wait for completion.

        Args:
            sql_text: SQL statement to execute
            warehouse_id: ID of the SQL warehouse
            catalog: Optional catalog name
            wait_timeout: How long to wait for query completion (default "30s")
            on_wait_timeout: What to do on timeout ("CONTINUE" or "CANCEL")

        Returns:
            Dictionary containing the SQL statement execution result
        """
        data = {
            "statement": sql_text,
            "warehouse_id": warehouse_id,
            "wait_timeout": wait_timeout,
            "on_wait_timeout": on_wait_timeout,
        }

        if catalog:
            data["catalog"] = catalog

        # Submit the SQL statement
        response = self.post("/api/2.0/sql/statements", data)
        statement_id = response.get("statement_id")

        # Poll until complete
        while True:
            status = self.get(f"/api/2.0/sql/statements/{statement_id}")
            state = status.get("status", {}).get("state", status.get("state"))
            if state not in ["PENDING", "RUNNING"]:
                break
            time.sleep(1)

        return status
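
    # SQL execution sketch (illustrative only; picks the first warehouse rather
    # than the configured one, and assumes at least one warehouse exists):
    #
    #     warehouse_id = client.list_warehouses()[0]["id"]
    #     result = client.submit_sql_statement("SELECT 1", warehouse_id)
    #     state = result.get("status", {}).get("state")
    #     rows = result.get("result", {}).get("data_array", [])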

    #
    # Jobs methods
    #

    def submit_job_run(self, config_path, init_script_path, run_name=None):
        """
        Submit a one-time Databricks job run using the /runs/submit endpoint.

        Args:
            config_path: Path to the configuration file for the job in the Volume
            init_script_path: Path to the initialization script
            run_name: Optional name for the run. If None, a default name will be generated.

        Returns:
            Dict containing the job run information (including run_id)
        """
        if not run_name:
            run_name = (
                f"Chuck AI One-Time Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            )

        # Define the task and cluster for the one-time run
        # Create base cluster configuration
        cluster_config = {
            "cluster_name": "",
            "spark_version": "16.0.x-cpu-ml-scala2.12",
            "init_scripts": [
                {
                    "volumes": {
                        "destination": init_script_path,
                    }
                }
            ],
            "node_type_id": self.get_compute_node_type(),
            "custom_tags": {
                "stack": "aws-dev",
                "sys": "chuck",
                "tenant": "amperity",
            },
            "spark_env_vars": {
                "JNAME": "zulu17-ca-amd64",
                "CHUCK_API_URL": f"https://{get_amperity_url()}",
                "DEBUG_INIT_SRIPT_URL": init_script_path,
                "DEBUG_CONFIG_PATH": config_path,
            },
            "enable_elastic_disk": False,
            "data_security_mode": "SINGLE_USER",
            "runtime_engine": "STANDARD",
            "autoscale": {"min_workers": 10, "max_workers": 50},
        }

        # Add cloud-specific attributes
        cluster_config.update(self.get_cloud_attributes())

        run_payload = {
            "run_name": run_name,
            "tasks": [
                {
                    "task_key": "Run_Stitch",
                    "run_if": "ALL_SUCCESS",
                    "spark_jar_task": {
                        "jar_uri": "",
                        "main_class_name": os.environ.get(
                            "MAIN_CLASS", "amperity.stitch_standalone.chuck_main"
                        ),
                        "parameters": [
                            "",
                            config_path,
                        ],
                        "run_as_repl": True,
                    },
                    "libraries": [{"jar": "file:///opt/amperity/job.jar"}],
                    "timeout_seconds": 0,
                    "email_notifications": {},
                    "webhook_notifications": {},
                    "new_cluster": cluster_config,
                },
            ],
            "timeout_seconds": 0,
        }

        return self.post("/api/2.2/jobs/runs/submit", run_payload)

    def get_job_run_status(self, run_id):
        """
        Get the status of a Databricks job run.

        Args:
            run_id: The job run ID (as str or int)

        Returns:
            Dict containing the job run status information
        """
        params = {"run_id": run_id}
        return self.get_with_params("/api/2.2/jobs/runs/get", params)
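
    # Job submission sketch (illustrative only; the Volume paths are
    # placeholders for a previously uploaded config file and init script):
    #
    #     run = client.submit_job_run(
    #         config_path="/Volumes/main/default/chuck/config.json",
    #         init_script_path="/Volumes/main/default/chuck/init.sh",
    #     )
    #     status = client.get_job_run_status(run.get("run_id"))
    #     print(status)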

    #
    # File system methods
    #

    def upload_file(self, path, file_path=None, content=None, overwrite=False):
        """
        Upload a file using the /api/2.0/fs/files endpoint.

        Args:
            path: The destination path (e.g., "/Volumes/my-catalog/my-schema/my-volume/file.txt")
            file_path: Local file path to upload (mutually exclusive with content)
            content: String content to upload (mutually exclusive with file_path)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful (API returns no content on success)

        Raises:
            ValueError: If both file_path and content are provided or neither is provided
            ValueError: If an HTTP error occurs
            ConnectionError: If a connection error occurs
        """
        if (file_path and content) or (not file_path and not content):
            raise ValueError("Exactly one of file_path or content must be provided")

        # URL encode the path and make sure it starts with a slash
        if not path.startswith("/"):
            path = f"/{path}"

        # Remove duplicate slashes if any
        while "//" in path:
            path = path.replace("//", "/")

        # URL encode path components but preserve the slashes
        encoded_path = "/".join(
            urllib.parse.quote(component) for component in path.split("/") if component
        )
        encoded_path = f"/{encoded_path}"

        url = f"https://{self.workspace_url}.{self.base_domain}/api/2.0/fs/files{encoded_path}"

        if overwrite:
            url += "?overwrite=true"

        logging.debug(f"File upload request to: {url}")

        headers = self.headers.copy()
        headers.update({"Content-Type": "application/octet-stream"})

        # Get binary data to upload
        if file_path:
            with open(file_path, "rb") as f:
                binary_data = f.read()
        else:
            # Convert string content to bytes
            binary_data = content.encode("utf-8")

        try:
            # Use PUT request with raw binary data in the body
            response = requests.put(url, headers=headers, data=binary_data)
            response.raise_for_status()
            # API returns 204 No Content on success
            return True
        except requests.exceptions.HTTPError as e:
            logging.debug(f"HTTP error: {e}, Response: {response.text}")
            raise ValueError(f"HTTP error occurred: {e}, Response: {response.text}")
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")
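
    # Volume upload sketch (illustrative only; the destination path is a
    # placeholder). Exactly one of file_path or content may be supplied:
    #
    #     client.upload_file(
    #         path="/Volumes/main/default/chuck/config.json",
    #         content='{"dry_run": true}',
    #         overwrite=True,
    #     )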

    def store_dbfs_file(self, path, contents, overwrite=True):
        """
        Store content to DBFS using the /api/2.0/dbfs/put endpoint.

        Args:
            path: Path in DBFS to store the file
            contents: String or bytes content to store (will be base64 encoded)
            overwrite: Whether to overwrite an existing file

        Returns:
            True if successful
        """
        # Encode the content as base64
        encoded_contents = (
            base64.b64encode(contents.encode()).decode()
            if isinstance(contents, str)
            else base64.b64encode(contents).decode()
        )

        # Prepare the request with file content and path
        request_data = {
            "path": path,
            "contents": encoded_contents,
            "overwrite": overwrite,
        }

        # Call DBFS API
        self.post("/api/2.0/dbfs/put", request_data)
        return True
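
    # DBFS write sketch (illustrative only; the DBFS path is a placeholder):
    #
    #     client.store_dbfs_file("/tmp/chuck/config.json", '{"dry_run": true}')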

    #
    # Amperity-specific methods
    #

    def fetch_amperity_job_init(self, token, api_url: str | None = None):
        """
        Fetch initialization script for Amperity jobs.

        Args:
            token: Amperity authentication token
            api_url: Optional override for the job init endpoint

        Returns:
            Dict containing the initialization script data
        """
        try:
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }

            if not api_url:
                api_url = f"https://{get_amperity_url()}/api/job/launch"

            response = requests.post(api_url, headers=headers, json={})
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            response = e.response
            resp_text = response.text if response is not None else ""
            logging.debug(f"HTTP error: {e}, Response: {resp_text}")
            if response is not None:
                try:
                    message = response.json().get("message", resp_text)
                except ValueError:
                    message = resp_text
                raise ValueError(
                    f"{response.status_code} Error: {message}. Please /logout and /login again"
                )
            raise ValueError(
                f"HTTP error occurred: {e}. Please /logout and /login again"
            )
        except requests.RequestException as e:
            logging.debug(f"Connection error: {e}")
            raise ConnectionError(f"Connection error occurred: {e}")

    def get_current_user(self):
        """
        Get the current user's username from Databricks API.

        Returns:
            Username string from the current user
        """
        try:
            response = self.get("/api/2.0/preview/scim/v2/Me")
            username = response.get("userName")
            if not username:
                logging.debug("Username not found in response")
                raise ValueError("Username not found in API response")
            return username
        except Exception as e:
            logging.debug(f"Error getting current user: {e}")
            raise

    def create_stitch_notebook(
        self, table_path, notebook_name=None, stitch_config=None, datasources=None
    ):
        """
        Create a stitch notebook for the given table path.

        This function will:
        1. Load the stitch notebook template
        2. Extract the notebook name from metadata (or use provided name)
        3. Get datasources either from provided list, stitch_config, or query the table
        4. Replace template fields with appropriate values
        5. Import the notebook to Databricks

        Args:
            table_path: Full path to the unified table in format catalog.schema.table
            notebook_name: Optional name for the notebook. If not provided, the name will be
                extracted from the template's metadata.
            stitch_config: Optional stitch configuration dictionary. If provided, datasources will be
                extracted from stitch_config["tables"][*]["path"] values.
            datasources: Optional list of datasource values. If provided, these will be used
                instead of querying the table.

        Returns:
            Dictionary containing the notebook path and status
        """
        # 1. Load the template notebook
        template_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "assets",
            "stitch_notebook_template.ipynb",
        )

        try:
            with open(template_path, "r") as f:
                notebook_content = json.load(f)

            # 2. Extract notebook name from metadata (if not provided)
            extracted_name = "Stitch Results"
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "markdown":
                    source = cell.get("source", [])
                    if source and source[0].startswith("# "):
                        extracted_name = source[0].replace("# ", "").strip()
                        break

            # Use the provided notebook name if available, otherwise use the extracted name
            final_notebook_name = notebook_name if notebook_name else extracted_name

            # 3. Get distinct datasource values
            # First check if datasources were directly provided
            if not datasources:
                # Check if we should extract datasources from stitch_config
                if (
                    stitch_config
                    and "tables" in stitch_config
                    and stitch_config["tables"]
                ):
                    # Extract unique datasources from the path values in stitch_config tables
                    datasource_set = set()
                    for table in stitch_config["tables"]:
                        if "path" in table:
                            datasource_set.add(table["path"])

                    datasources = list(datasource_set)
                    # Datasources successfully extracted from stitch config

            # If we still don't have datasources, query the table directly
            if not datasources:
                # Get the configured warehouse ID
                warehouse_id = get_warehouse_id()

                # If no warehouse ID is configured, try to find a default one
                if not warehouse_id:
                    warehouses = self.list_warehouses()
                    if not warehouses:
                        raise ValueError(
                            "No SQL warehouses found and no warehouse configured. Please select a warehouse using /warehouse_selection."
                        )

                    # Use the first available warehouse
                    warehouse_id = warehouses[0]["id"]
                    logging.warning(
                        f"No warehouse configured. Using first available warehouse: {warehouse_id}"
                    )

                # Query for distinct datasource values
                sql_query = f"SELECT DISTINCT datasource FROM {table_path}"
                # Execute SQL query to get datasources
                result = self.submit_sql_statement(sql_query, warehouse_id)

                # Extract the results from the query response
                if (
                    result
                    and result.get("result")
                    and result["result"].get("data_array")
                ):
                    datasources = [row[0] for row in result["result"]["data_array"]]
                    # Successfully extracted datasources from query

            # If we still don't have datasources, use a default value
            if not datasources:
                logging.warning(f"No datasources found for {table_path}")
                datasources = ["default_source"]
                # Use default datasource as a fallback

            # 4. Create the source names JSON mapping
            source_names_json = {}
            for source in datasources:
                source_names_json[source] = source

            # Convert to JSON string (formatted nicely)
            source_names_str = json.dumps(source_names_json, indent=4)
            # Source mapping created for template

            # 5. Replace the template fields
            # Replace template placeholders with actual values

            # Need to directly modify the notebook cells rather than doing string replacement on the JSON
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    for i, line in enumerate(source_lines):
                        if (
                            '"{UNIFIED_PATH}"' in line
                            or "unified_coalesced_path = " in line
                        ):
                            # Found placeholder in notebook
                            # Replace the line with our table path
                            source_lines[i] = (
                                f'unified_coalesced_path = "{table_path}"\n'
                            )
                            break

            # Replace source semantic mapping in the cells
            replaced_mapping = False
            for cell in notebook_content.get("cells", []):
                if cell.get("cell_type") == "code":
                    source_lines = cell.get("source", [])
                    # Look for the source_semantic_mapping definition
                    for i, line in enumerate(source_lines):
                        if "source_semantic_mapping =" in line and not replaced_mapping:
                            # Found source names placeholder

                            # Find the closing brace
                            closing_index = None
                            opening_count = 0
                            for j in range(i, len(source_lines)):
                                if "{" in source_lines[j]:
                                    opening_count += 1
                                if "}" in source_lines[j]:
                                    opening_count -= 1
                                    if opening_count == 0:
                                        closing_index = j
                                        break

                            if closing_index is not None:
                                # Replace the mapping with our custom mapping
                                mapping_start = i
                                mapping_end = closing_index + 1
                                new_line = (
                                    f"source_semantic_mapping = {source_names_str}\n"
                                )
                                source_lines[mapping_start:mapping_end] = [new_line]
                                replaced_mapping = True
                                break

            if not replaced_mapping:
                logging.warning(
                    "Could not find source_semantic_mapping in the notebook template to replace"
                )

            # 6. Get the current user's username
            username = self.get_current_user()

            # 7. Construct the notebook path
            notebook_path = f"/Workspace/Users/{username}/{final_notebook_name}"

            # 8. Convert the notebook to base64 for API call
            notebook_content_str = json.dumps(notebook_content)
            encoded_content = base64.b64encode(notebook_content_str.encode()).decode()

            # 9. Import the notebook to Databricks
            import_data = {
                "path": notebook_path,
                "content": encoded_content,
                "format": "JUPYTER",
                "overwrite": True,
            }

            self.post("/api/2.0/workspace/import", import_data)

            # 10. Log success and return the path
            # Notebook created successfully

            return {"notebook_path": notebook_path, "status": "success"}

        except Exception as e:
            logging.debug(f"Error creating stitch notebook: {e}")
            raise
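
Taken together, a typical end-to-end use of this client might look like the
following sketch (illustrative only: the workspace name, token, catalog, schema,
volume, and table names are placeholders, and error handling is omitted):

    client = DatabricksAPIClient("my-workspace", "dapi-example-token")
    if not client.validate_token():
        raise SystemExit("Databricks token was rejected")

    client.create_volume("main", "default", "chuck")
    client.upload_file(
        path="/Volumes/main/default/chuck/config.json",
        content='{"tables": []}',
        overwrite=True,
    )

    result = client.create_stitch_notebook("main.default.unified_coalesced")
    print(result["notebook_path"])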