Coverage for src/profiler.py: 52%
117 statements
coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
import logging
import datetime
import time
import base64
import json


def list_tables(client, warehouse_id):
    """
    Lists all tables in Unity Catalog using a SQL query.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse

    Returns:
        List of dictionaries with table_name, catalog_name, and schema_name
    """
    logging.info(f"Listing tables in Unity Catalog (warehouse {warehouse_id})")
    sql_text = "SELECT table_name, catalog_name, schema_name FROM system.information_schema.tables;"
    data = {
        "on_wait_timeout": "CONTINUE",
        "statement": sql_text,
        "wait_timeout": "30s",
        "warehouse_id": warehouse_id,
    }

    # Submit the SQL statement
    response = client.post("/api/2.0/sql/statements", data)
    statement_id = response.get("statement_id")

    # Poll until complete
    state = None
    while True:
        status = client.get(f"/api/2.0/sql/statements/{statement_id}")
        state = status.get("status", {}).get("state", status.get("state"))
        if state not in ["PENDING", "RUNNING"]:
            break
        time.sleep(1)
    # Get results
    if state != "SUCCEEDED":
        return []

    # Parse results (the SQL Statement Execution API returns rows under
    # result.data_array; "data" is kept as a fallback for other shapes)
    result = status.get("result", {})
    data_array = result.get("data_array", result.get("data", []))
    tables = []
    for row in data_array:
        tables.append(
            {"table_name": row[0], "catalog_name": row[1], "schema_name": row[2]}
        )

    return tables
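
# Usage sketch (hedged): client construction and the warehouse ID below are
# hypothetical; DatabricksAPIClient is assumed to expose get/post methods that
# return parsed JSON, as the calls above imply.
#
#     client = DatabricksAPIClient(...)  # hypothetical constructor
#     tables = list_tables(client, warehouse_id="1234567890abcdef")
#     for t in tables:
#         print(f"{t['catalog_name']}.{t['schema_name']}.{t['table_name']}")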


def get_table_schema(client, warehouse_id, catalog_name, schema_name, table_name):
    """
    Retrieves the extended schema of a specified table.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse
        catalog_name: Catalog name
        schema_name: Schema name
        table_name: Table name

    Returns:
        List of schema details
    """
    sql_text = f"DESCRIBE EXTENDED {catalog_name}.{schema_name}.{table_name}"
    # The statements API expects the query under "statement" (as in list_tables)
    data = {"warehouse_id": warehouse_id, "catalog": catalog_name, "statement": sql_text}
    # Submit the SQL statement
    response = client.post("/api/2.0/sql/statements", data)
    statement_id = response.get("statement_id")

    # Poll until complete
    state = None
    while True:
        status = client.get(f"/api/2.0/sql/statements/{statement_id}")
        state = status.get("status", {}).get("state", status.get("state"))
        if state not in ["PENDING", "RUNNING"]:
            break
        time.sleep(1)
    # Get results
    if state != "SUCCEEDED":
        return []

    # Parse results (data_array first, with "data" as a fallback)
    result = status.get("result", {})
    data_array = result.get("data_array", result.get("data", []))
    # Format schema info
    schema = []
    for row in data_array:
        schema.append(
            {
                "col_name": row[0],
                "data_type": row[1],
                "comment": row[2] if len(row) > 2 else "",
            }
        )

    return schema
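
# Usage sketch (hypothetical names). Note that DESCRIBE EXTENDED also emits
# metadata rows (e.g. "# Detailed Table Information") after the column list,
# so callers may want to filter the returned entries.
#
#     schema = get_table_schema(client, warehouse_id, "main", "default", "users")
#     for col in schema:
#         print(col["col_name"], col["data_type"], col["comment"])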


def get_sample_data(client, warehouse_id, catalog_name, schema_name, table_name):
    """
    Retrieves sample data (first 10 rows) from the specified table.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse
        catalog_name: Catalog name
        schema_name: Schema name
        table_name: Table name

    Returns:
        Dictionary with column_names and rows
    """
    sql_text = f"SELECT * FROM {catalog_name}.{schema_name}.{table_name} LIMIT 10"
    # As above, the statements API expects the query under "statement"
    data = {"warehouse_id": warehouse_id, "catalog": catalog_name, "statement": sql_text}
    # Submit the SQL statement
    response = client.post("/api/2.0/sql/statements", data)
    statement_id = response.get("statement_id")

    # Poll until complete
    state = None
    while True:
        status = client.get(f"/api/2.0/sql/statements/{statement_id}")
        state = status.get("status", {}).get("state", status.get("state"))
        if state not in ["PENDING", "RUNNING"]:
            break
        time.sleep(1)
    # Get results; return an empty dict on failure to match the documented
    # return type
    if state != "SUCCEEDED":
        return {}

    # Parse results; column metadata lives under manifest.schema.columns in the
    # Statement Execution API, with result["schema"] kept as a fallback
    result = status.get("result", {})
    columns = status.get("manifest", {}).get("schema", {}).get("columns", [])
    column_names = [col["name"] for col in columns or result.get("schema", [])]
    sample_rows = []
    for row in result.get("data_array", result.get("data", [])):
        sample_row = {}
        for i, col_name in enumerate(column_names):
            sample_row[col_name] = row[i] if i < len(row) else None
        sample_rows.append(sample_row)

    return {"column_names": column_names, "rows": sample_rows}
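
# Usage sketch (hypothetical table). Because the query uses LIMIT without
# ORDER BY, the sampled rows carry no ordering guarantee between runs.
#
#     sample = get_sample_data(client, warehouse_id, "main", "default", "users")
#     print(sample["column_names"])
#     for row in sample["rows"]:
#         print(row)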


def query_llm(client, endpoint_name, input_data):
    """
    Queries the LLM via the Serving Endpoints API.

    Args:
        client: DatabricksAPIClient instance
        endpoint_name: Name of the serving endpoint
        input_data: Data to send to the LLM

    Returns:
        Response from the LLM
    """
    endpoint = f"/api/2.0/serving-endpoints/{endpoint_name}/invocations"

    # Format input for the LLM
    # This format may need adjustment based on model requirements
    request_data = {
        "inputs": [
            {"schema": input_data["schema"], "sample_data": input_data["sample_data"]}
        ]
    }

    response = client.post(endpoint, request_data)
    return response
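
# Usage sketch (hedged): the endpoint name is hypothetical, and both the
# "inputs" payload and the response shape are model-dependent assumptions,
# as the comment above notes.
#
#     response = query_llm(client, "pii-detector",
#                          {"schema": schema, "sample_data": sample})
#     # e.g. {"predictions": [{"pii_tags": [...]}]} -- the shape that
#     # profile_table below expects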


def generate_manifest(table_info, schema, sample_data, pii_tags):
    """
    Generates a JSON manifest with profiling results.

    Args:
        table_info: Dictionary with table_name, catalog_name, schema_name
        schema: Table schema information
        sample_data: Sample data from the table
        pii_tags: PII tags from LLM response

    Returns:
        Dictionary representing the manifest
    """
    manifest = {
        "table": {
            "catalog_name": table_info["catalog_name"],
            "schema_name": table_info["schema_name"],
            "table_name": table_info["table_name"],
        },
        "schema": schema,
        "pii_tags": pii_tags,
        "profiling_timestamp": datetime.datetime.now().isoformat(),
    }

    return manifest
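
# Usage sketch (hypothetical values). Note that sample_data is accepted but
# not written into the manifest body above, so raw rows are not persisted
# alongside the PII tags.
#
#     manifest = generate_manifest(
#         {"catalog_name": "main", "schema_name": "default", "table_name": "users"},
#         schema, sample, pii_tags=["email", "phone"],
#     )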


def store_manifest(client, manifest_path, manifest):
    """
    Stores the manifest in DBFS.

    Args:
        client: DatabricksAPIClient instance
        manifest_path: Path in DBFS
        manifest: Dictionary to store

    Returns:
        True if successful, False otherwise
    """
    # DBFS API endpoint for file upload
    endpoint = "/api/2.0/dbfs/put"

    # Convert manifest to JSON string
    manifest_json = json.dumps(manifest, indent=2)

    # Prepare the request with file content and path
    request_data = {
        "path": manifest_path,
        "contents": base64.b64encode(manifest_json.encode()).decode(),
        "overwrite": True,
    }

    try:
        client.post(endpoint, request_data)
        return True
    except Exception as e:
        logging.error(f"Failed to store manifest: {e}")
        return False
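
# Usage sketch (hypothetical path). The dbfs/put endpoint expects
# base64-encoded contents, which is why the JSON is encoded above.
#
#     ok = store_manifest(client, "/chuck/manifests/users_manifest.json", manifest)
#     if not ok:
#         ...  # the failure reason is logged by store_manifest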


def profile_table(client, warehouse_id, endpoint_name, table_info=None):
    """
    Main function to orchestrate the profiling process.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse
        endpoint_name: Name of the serving endpoint
        table_info: Optional dictionary with table_name, catalog_name, schema_name
            If None, the first table will be used

    Returns:
        Path to the stored manifest, or None if profiling failed
    """
    try:
        # Step 1: List tables if no specific table provided
        if table_info is None:
            tables = list_tables(client, warehouse_id)
            if not tables:
                logging.error("No tables found")
                return None

            # Select the first table
            table_info = tables[0]

        # Step 2: Get schema and sample data
        schema = get_table_schema(
            client,
            warehouse_id,
            table_info["catalog_name"],
            table_info["schema_name"],
            table_info["table_name"],
        )

        if not schema:
            logging.error("Failed to retrieve schema")
            return None

        sample_data = get_sample_data(
            client,
            warehouse_id,
            table_info["catalog_name"],
            table_info["schema_name"],
            table_info["table_name"],
        )

        if not sample_data:
            logging.error("Failed to retrieve sample data")
            return None
        # Step 3: Prepare input for LLM
        input_data = {"schema": schema, "sample_data": sample_data}

        # Step 4: Query LLM
        llm_response = query_llm(client, endpoint_name, input_data)

        # Extract PII tags from response (format may vary based on LLM).
        # This is a simplified extraction - adjust based on actual response
        # format; guard against an empty predictions list
        predictions = llm_response.get("predictions") or [{}]
        pii_tags = predictions[0].get("pii_tags", [])
        # Step 5: Generate manifest
        manifest = generate_manifest(table_info, schema, sample_data, pii_tags)

        # Step 6: Store manifest in DBFS
        manifest_path = f"/chuck/manifests/{table_info['table_name']}_manifest.json"
        success = store_manifest(client, manifest_path, manifest)

        if success:
            return manifest_path
        else:
            return None

    except Exception as e:
        logging.error(f"Error during profiling: {e}")
        return None
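
# End-to-end usage sketch (hypothetical IDs): with table_info omitted, the
# first table returned by list_tables is profiled, and the DBFS path of the
# stored manifest (or None on failure) comes back.
#
#     client = DatabricksAPIClient(...)  # hypothetical constructor
#     path = profile_table(
#         client, warehouse_id="1234567890abcdef", endpoint_name="pii-detector"
#     )
#     if path:
#         print(f"Manifest stored at {path}")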