Coverage for src/commands/pii_tools.py: 73%

88 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

1""" 

2PII handling utilities for command handlers. 

3 

4This module contains helper functions for detecting and tagging PII (Personally 

5Identifiable Information) in database tables. 

6""" 

7 

8import logging 

9import json 

10import concurrent.futures 

11from typing import Dict, Any, Optional 

12 

13from src.clients.databricks import DatabricksAPIClient 

14from src.llm.client import LLMClient 

15 

16 

def _helper_tag_pii_columns_logic(
    databricks_client: DatabricksAPIClient,
    llm_client_instance: LLMClient,
    table_name_param: str,
    catalog_name_context: Optional[str] = None,
    schema_name_context: Optional[str] = None,
) -> Dict[str, Any]:
    """Internal logic for PII tagging of a single table.

    Resolves the table via the Databricks API, asks the LLM to assign a PII
    semantic tag to each column, and returns a summary dict. Failures never
    raise to the caller; they are reported as ``{"error": ..., "skipped": True}``
    so bulk scans can continue past individual bad tables.

    Args:
        databricks_client: Client used to fetch table metadata.
        llm_client_instance: Chat-capable LLM client used for PII detection.
        table_name_param: Table name; may be bare or fully qualified.
        catalog_name_context: Catalog used to qualify a bare table name.
        schema_name_context: Schema used to qualify a bare table name.

    Returns:
        Dict summarizing the table's columns and detected PII columns, or an
        error/skip marker dict.
    """
    response_content_for_error = ""
    try:
        # Resolve full table name using APIs directly instead of handler.
        table_details_kwargs = {"full_name": table_name_param}
        if catalog_name_context and schema_name_context and "." not in table_name_param:
            # Only a bare table name was provided, construct full name.
            full_name = (
                f"{catalog_name_context}.{schema_name_context}.{table_name_param}"
            )
            table_details_kwargs = {"full_name": full_name}

        try:
            # Use direct API call instead of handle_table.
            table_info = databricks_client.get_table(**table_details_kwargs)
            if not table_info:
                error_msg = f"Failed to retrieve table details for PII tagging: {table_name_param}"
                return {
                    "error": error_msg,
                    "table_name_param": table_name_param,
                    "skipped": True,
                }

            resolved_full_name = table_info.get("full_name", table_name_param)
            columns = table_info.get("columns", [])
        except Exception as e:
            error_msg = f"Failed to retrieve table details: {str(e)}"
            return {
                "error": error_msg,
                "table_name_param": table_name_param,
                "skipped": True,
            }  # Skipped due to error

        base_name_of_resolved = resolved_full_name.split(".")[-1]

        # Internal Stitch bookkeeping tables are never scanned for PII.
        if base_name_of_resolved.startswith("_stitch"):
            return {
                "skipped": True,
                "reason": f"Table '{resolved_full_name}' starts with _stitch.",
                "full_name": resolved_full_name,
                "table_name": base_name_of_resolved,
            }

        # A table with no columns trivially has no PII; skip the LLM call.
        if not columns:
            return {
                "table_name": base_name_of_resolved,
                "full_name": resolved_full_name,
                "column_count": 0,
                "pii_column_count": 0,
                "has_pii": False,
                "columns": [],
                "pii_columns": [],
                "skipped": False,
            }

        # Use the LLM client instance passed to the function.
        column_details_for_llm = [
            {"name": col.get("name", ""), "type": col.get("type_name", "")}
            for col in columns
        ]

        system_message = (
            "You are an expert PII detection assistant. Your task is to analyze a list of database columns (name and type) "
            "and assign a PII semantic tag to each column if applicable. Use ONLY the following PII semantic tags: "
            "address, address2, birthdate, city, country, create-dt, email, full-name, gender, generational-suffix, "
            "given-name, phone, postal, state, surname, title, update-dt. If a column does not contain PII, assign null. "
            "Respond ONLY with a valid JSON list of objects, where each object represents a column and has the following structure: "
            '{"name": "column_name", "semantic": "pii_tag_or_null"}. '
            "Maintain original order. No explanations or introductory text."
        )
        user_prompt = f"Analyze the following columns from table '{resolved_full_name}' and provide PII semantic tags in the specified JSON format: {json.dumps(column_details_for_llm, indent=2)}"

        llm_response_obj = llm_client_instance.chat(
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_prompt},
            ]
        )
        # Keep the raw content around so error handlers can log what the LLM
        # actually returned.
        response_content_for_error = llm_response_obj.choices[0].message.content
        response_content_clean = response_content_for_error.strip()
        # Strip an optional markdown code fence around the JSON payload.
        # Only remove the closing fence when it is actually present — a blind
        # [7:-3] slice would truncate valid JSON if the LLM omitted the
        # trailing ``` fence.
        if response_content_clean.startswith("```json"):
            response_content_clean = response_content_clean[7:]
        elif response_content_clean.startswith("```"):
            response_content_clean = response_content_clean[3:]
        if response_content_clean.endswith("```"):
            response_content_clean = response_content_clean[:-3]
        response_content_clean = response_content_clean.strip()

        llm_tags = json.loads(response_content_clean)
        if not isinstance(llm_tags, list) or len(llm_tags) != len(columns):
            raise ValueError(
                f"LLM PII tag response format error. Expected {len(columns)} items, got {len(llm_tags)}."
            )

        # Map column name -> semantic tag. Use .get("semantic") so a malformed
        # item that has "name" but no "semantic" key degrades to "no PII"
        # instead of raising KeyError (the guard only checks "name").
        semantic_map = {
            item["name"]: item.get("semantic")
            for item in llm_tags
            if isinstance(item, dict) and "name" in item
        }
        tagged_columns_list = []
        for col in columns:
            col_name = col.get("name", "")
            tagged_columns_list.append(
                {
                    "name": col_name,
                    "type": col.get("type_name", ""),
                    "semantic": semantic_map.get(col_name),
                }
            )

        pii_cols = [col for col in tagged_columns_list if col["semantic"]]
        return {
            "table_name": base_name_of_resolved,
            "full_name": resolved_full_name,
            "column_count": len(columns),
            "pii_column_count": len(pii_cols),
            "has_pii": bool(pii_cols),
            "columns": tagged_columns_list,
            "pii_columns": pii_cols,
            "skipped": False,
        }
    except json.JSONDecodeError as e_json:
        logging.error(
            f"_helper_tag_pii_columns_logic: JSONDecodeError: {e_json} from LLM response: {response_content_for_error[:500]}"
        )  # Log more of the response
        return {"error": f"Failed to parse PII LLM response: {e_json}", "skipped": True}
    except Exception as e_tag:
        logging.error(
            f"_helper_tag_pii_columns_logic error for '{table_name_param}': {e_tag}",
            exc_info=True,
        )
        return {
            "error": f"Error during PII tagging for '{table_name_param}': {str(e_tag)}",
            "skipped": True,
        }

158 

159 

def _helper_scan_schema_for_pii_logic(
    client: DatabricksAPIClient,
    llm_client_instance: LLMClient,
    catalog_name: str,
    schema_name: str,
) -> Dict[str, Any]:
    """Scan every user table in a schema for PII, in parallel.

    Lists the schema's tables (skipping internal ``_stitch*`` tables), fans
    each one out to :func:`_helper_tag_pii_columns_logic` on a small thread
    pool, and aggregates the per-table results into a summary dict. Listing
    failures are returned as ``{"error": ...}``; per-table failures are
    captured inside ``results_detail`` so one bad table never aborts the scan.
    """
    if not catalog_name or not schema_name:
        return {"error": "Catalog and schema names are required for bulk PII scan."}

    # Use direct API call instead of handle_tables.
    try:
        listing = client.list_tables(
            catalog_name=catalog_name, schema_name=schema_name, omit_columns=True
        )
        schema_tables = listing.get("tables", [])
    except Exception as e:
        return {
            "error": f"Failed to list tables for {catalog_name}.{schema_name}: {str(e)}"
        }

    # Internal Stitch tables are excluded from the scan.
    candidate_tables = [
        entry
        for entry in schema_tables
        if isinstance(entry, dict) and not entry.get("name", "").startswith("_stitch")
    ]

    if not candidate_tables:
        return {
            "message": f"No user tables (excluding _stitch*) found in {catalog_name}.{schema_name}.",
            "catalog": catalog_name,
            "schema": schema_name,
            "tables_scanned": 0,
            "tables_with_pii": 0,
            "total_pii_columns": 0,
            "results_detail": [],
        }

    logging.info(
        f"Starting PII Scan for {len(candidate_tables)} tables in {catalog_name}.{schema_name}."
    )
    detail_rows = []
    MAX_WORKERS = 5
    pending = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        for summary in candidate_tables:
            name_only = summary.get("name")
            if not name_only:
                continue
            # Pass client and context to the helper.
            future = executor.submit(
                _helper_tag_pii_columns_logic,
                client,
                llm_client_instance,
                name_only,
                catalog_name,
                schema_name,
            )
            pending[future] = f"{catalog_name}.{schema_name}.{name_only}"

        for done in concurrent.futures.as_completed(pending):
            qualified_name = pending[done]
            try:
                detail_rows.append(done.result())
            except Exception as exc_future:
                # The helper normally swallows its own errors, so this is a
                # last-resort guard for unexpected thread failures.
                logging.error(
                    f"Error processing table '{qualified_name}' in PII scan thread: {exc_future}",
                    exc_info=True,
                )
                detail_rows.append(
                    {
                        "full_name": qualified_name,
                        "error": str(exc_future),
                        "skipped": True,
                    }
                )

    # Deterministic output order regardless of thread completion order.
    detail_rows.sort(key=lambda row: row.get("full_name", ""))
    # Only rows that neither errored nor were skipped count as processed.
    processed_rows = [
        row for row in detail_rows if not row.get("error") and not row.get("skipped")
    ]

    return {
        "catalog": catalog_name,
        "schema": schema_name,
        "tables_scanned_attempted": len(candidate_tables),
        "tables_successfully_processed": len(processed_rows),
        "tables_with_pii": sum(1 for row in processed_rows if row.get("has_pii")),
        "total_pii_columns": sum(
            row.get("pii_column_count", 0) for row in processed_rows
        ),
        "results_detail": detail_rows,
    }