Coverage for src/chuck_data/profiler.py: 0%

117 statements  

coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

import logging
import datetime
import time
import base64
import json


def list_tables(client, warehouse_id):
    """
    Lists all tables in Unity Catalog using a SQL query.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse

    Returns:
        List of dictionaries with table_name, catalog_name, and schema_name
    """
    logging.info("Listing tables in Unity Catalog (warehouse %s)", warehouse_id)
    sql_text = "SELECT table_name, catalog_name, schema_name FROM system.information_schema.tables;"
    data = {
        "on_wait_timeout": "CONTINUE",
        "statement": sql_text,
        "wait_timeout": "30s",
        "warehouse_id": warehouse_id,
    }

    # Submit the SQL statement
    response = client.post("/api/2.0/sql/statements", data)
    statement_id = response.get("statement_id")

    # Poll until the statement reaches a terminal state
    state = None
    while True:
        status = client.get(f"/api/2.0/sql/statements/{statement_id}")
        state = status.get("status", {}).get("state", status.get("state"))
        if state not in ["PENDING", "RUNNING"]:
            break
        time.sleep(1)

    # Anything other than SUCCEEDED (e.g. FAILED, CANCELED) yields no rows
    if state != "SUCCEEDED":
        return []

    # Parse results
    result = status.get("result", {})
    data_array = result.get("data", [])

    tables = []
    for row in data_array:
        tables.append(
            {"table_name": row[0], "catalog_name": row[1], "schema_name": row[2]}
        )

    return tables


def get_table_schema(client, warehouse_id, catalog_name, schema_name, table_name):
    """
    Retrieves the extended schema of a specified table.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse
        catalog_name: Catalog name
        schema_name: Schema name
        table_name: Table name

    Returns:
        List of schema details
    """
    sql_text = f"DESCRIBE EXTENDED {catalog_name}.{schema_name}.{table_name}"
    # Same statement-execution payload shape as in list_tables
    data = {
        "on_wait_timeout": "CONTINUE",
        "statement": sql_text,
        "wait_timeout": "30s",
        "warehouse_id": warehouse_id,
        "catalog": catalog_name,
    }

    # Submit the SQL statement
    response = client.post("/api/2.0/sql/statements", data)
    statement_id = response.get("statement_id")

    # Poll until the statement reaches a terminal state
    state = None
    while True:
        status = client.get(f"/api/2.0/sql/statements/{statement_id}")
        state = status.get("status", {}).get("state", status.get("state"))
        if state not in ["PENDING", "RUNNING"]:
            break
        time.sleep(1)

    # Anything other than SUCCEEDED yields no schema details
    if state != "SUCCEEDED":
        return []

    # Parse results
    result = status.get("result", {})
    data_array = result.get("data", [])

    # Format schema info
    schema = []
    for row in data_array:
        schema.append(
            {
                "col_name": row[0],
                "data_type": row[1],
                "comment": row[2] if len(row) > 2 else "",
            }
        )

    return schema


def get_sample_data(client, warehouse_id, catalog_name, schema_name, table_name):
    """
    Retrieves sample data (first 10 rows) from the specified table.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse
        catalog_name: Catalog name
        schema_name: Schema name
        table_name: Table name

    Returns:
        Dictionary with column_names and rows
    """
    sql_text = f"SELECT * FROM {catalog_name}.{schema_name}.{table_name} LIMIT 10"
    # Same statement-execution payload shape as in list_tables
    data = {
        "on_wait_timeout": "CONTINUE",
        "statement": sql_text,
        "wait_timeout": "30s",
        "warehouse_id": warehouse_id,
        "catalog": catalog_name,
    }

    # Submit the SQL statement
    response = client.post("/api/2.0/sql/statements", data)
    statement_id = response.get("statement_id")

    # Poll until the statement reaches a terminal state
    state = None
    while True:
        status = client.get(f"/api/2.0/sql/statements/{statement_id}")
        state = status.get("status", {}).get("state", status.get("state"))
        if state not in ["PENDING", "RUNNING"]:
            break
        time.sleep(1)

    # Anything other than SUCCEEDED yields no sample data
    if state != "SUCCEEDED":
        return {}

    # Parse results
    result = status.get("result", {})
    column_names = [col["name"] for col in result.get("schema", [])]
    sample_rows = []

    # Map each positional row onto the column names; pad missing cells with None
    for row in result.get("data", []):
        sample_row = {}
        for i, col_name in enumerate(column_names):
            sample_row[col_name] = row[i] if i < len(row) else None
        sample_rows.append(sample_row)

    return {"column_names": column_names, "rows": sample_rows}


def query_llm(client, endpoint_name, input_data):
    """
    Queries the LLM via the Serving Endpoints API.

    Args:
        client: DatabricksAPIClient instance
        endpoint_name: Name of the serving endpoint
        input_data: Data to send to the LLM

    Returns:
        Response from the LLM
    """
    endpoint = f"/api/2.0/serving-endpoints/{endpoint_name}/invocations"

    # Format input for the LLM
    # This format may need adjustment based on model requirements
    request_data = {
        "inputs": [
            {"schema": input_data["schema"], "sample_data": input_data["sample_data"]}
        ]
    }

    response = client.post(endpoint, request_data)
    return response


def generate_manifest(table_info, schema, sample_data, pii_tags):
    """
    Generates a JSON manifest with profiling results.

    Args:
        table_info: Dictionary with table_name, catalog_name, schema_name
        schema: Table schema information
        sample_data: Sample data from the table
        pii_tags: PII tags from LLM response

    Returns:
        Dictionary representing the manifest
    """
    manifest = {
        "table": {
            "catalog_name": table_info["catalog_name"],
            "schema_name": table_info["schema_name"],
            "table_name": table_info["table_name"],
        },
        "schema": schema,
        "pii_tags": pii_tags,
        "profiling_timestamp": datetime.datetime.now().isoformat(),
    }

    return manifest


def store_manifest(client, manifest_path, manifest):
    """
    Stores the manifest in DBFS.

    Args:
        client: DatabricksAPIClient instance
        manifest_path: Path in DBFS
        manifest: Dictionary to store

    Returns:
        True if successful, False otherwise
    """
    # DBFS API endpoint for file upload
    endpoint = "/api/2.0/dbfs/put"

    # Convert manifest to JSON string
    manifest_json = json.dumps(manifest, indent=2)

    # Prepare the request with the base64-encoded file content and target path
    request_data = {
        "path": manifest_path,
        "contents": base64.b64encode(manifest_json.encode()).decode(),
        "overwrite": True,
    }

    try:
        client.post(endpoint, request_data)
        return True
    except Exception as e:
        logging.error(f"Failed to store manifest: {e}")
        return False


def profile_table(client, warehouse_id, endpoint_name, table_info=None):
    """
    Main function to orchestrate the profiling process.

    Args:
        client: DatabricksAPIClient instance
        warehouse_id: ID of the SQL warehouse
        endpoint_name: Name of the serving endpoint
        table_info: Optional dictionary with table_name, catalog_name, schema_name.
            If None, the first table found will be used.

    Returns:
        Path to the stored manifest, or None if profiling failed
    """
    try:
        # Step 1: List tables if no specific table provided
        if table_info is None:
            tables = list_tables(client, warehouse_id)
            if not tables:
                logging.error("No tables found")
                return None

            # Select the first table
            table_info = tables[0]

        # Step 2: Get schema and sample data
        schema = get_table_schema(
            client,
            warehouse_id,
            table_info["catalog_name"],
            table_info["schema_name"],
            table_info["table_name"],
        )

        if not schema:
            logging.error("Failed to retrieve schema")
            return None

        sample_data = get_sample_data(
            client,
            warehouse_id,
            table_info["catalog_name"],
            table_info["schema_name"],
            table_info["table_name"],
        )

        if not sample_data:
            logging.error("Failed to retrieve sample data")
            return None

        # Step 3: Prepare input for LLM
        input_data = {"schema": schema, "sample_data": sample_data}

        # Step 4: Query LLM
        llm_response = query_llm(client, endpoint_name, input_data)

        # Extract PII tags from the response (format may vary based on the LLM);
        # guard against an empty predictions list so a malformed response does
        # not abort profiling with an IndexError.
        predictions = llm_response.get("predictions", [])
        pii_tags = predictions[0].get("pii_tags", []) if predictions else []

        # Step 5: Generate manifest
        manifest = generate_manifest(table_info, schema, sample_data, pii_tags)

        # Step 6: Store manifest in DBFS
        manifest_path = f"/chuck/manifests/{table_info['table_name']}_manifest.json"
        success = store_manifest(client, manifest_path, manifest)

        if success:
            return manifest_path
        else:
            return None

    except Exception as e:
        logging.error(f"Error during profiling: {e}")
        return None
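

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; this module has 0% test coverage,
# so nothing below is exercised anywhere in the repository). The import path
# for DatabricksAPIClient, its constructor arguments, and the warehouse and
# endpoint identifiers are placeholders, not values taken from this codebase.
#
#   from chuck_data.client import DatabricksAPIClient  # hypothetical import path
#
#   client = DatabricksAPIClient(host="https://<workspace-url>", token="<pat>")
#   manifest_path = profile_table(
#       client,
#       warehouse_id="<sql-warehouse-id>",
#       endpoint_name="<serving-endpoint-name>",
#   )
#   if manifest_path:
#       print(f"Manifest stored at dbfs:{manifest_path}")
# ---------------------------------------------------------------------------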