Coverage for src/agent/manager.py: 85%


import json
import logging
from copy import deepcopy
from src.llm.client import LLMClient
from .tool_executor import get_tool_schemas, execute_tool
from src.config import (
    get_active_catalog,
    get_active_schema,
    get_warehouse_id,
    get_workspace_url,
)

from .prompts import (
    DEFAULT_SYSTEM_MESSAGE,
    PII_AGENT_SYSTEM_MESSAGE,
    BULK_PII_AGENT_SYSTEM_MESSAGE,
    STITCH_AGENT_SYSTEM_MESSAGE,
)


class AgentManager:
    def __init__(self, client, model=None, tool_output_callback=None):
        self.api_client = client
        self.llm_client = LLMClient()
        self.model = model
        self.tool_output_callback = tool_output_callback
        self.conversation_history = [
            {"role": "system", "content": DEFAULT_SYSTEM_MESSAGE}
        ]

    def add_user_message(self, content):
        self.conversation_history.append({"role": "user", "content": content})

    def add_assistant_message(self, content):
        self.conversation_history.append({"role": "assistant", "content": content})

    def add_system_message(self, content):
        # If there's already a system message, replace it; otherwise prepend
        for i, msg in enumerate(self.conversation_history):
            if msg["role"] == "system":
                self.conversation_history[i] = {"role": "system", "content": content}
                return
        self.conversation_history.insert(0, {"role": "system", "content": content})
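
    # Illustrative usage (hypothetical `manager` instance): because
    # add_system_message replaces any existing system entry in place,
    # switching agent personas does not grow the history.
    #
    #   manager.add_system_message(PII_AGENT_SYSTEM_MESSAGE)   # swaps the persona
    #   manager.add_system_message(DEFAULT_SYSTEM_MESSAGE)     # swaps it back
    #
    # Both calls rewrite the single system message rather than appending.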

    def process_pii_detection(self, table_name):
        """Process a PII detection request for a specific table

        Args:
            table_name: Name of the table to analyze

        Returns:
            Final response from the LLM
        """
        # Start with a clean conversation specifically for PII detection
        self.conversation_history = []

        # Add system message for PII detection
        self.add_system_message(PII_AGENT_SYSTEM_MESSAGE)

        # Add user message requesting PII analysis
        self.add_user_message(f"Analyze the table '{table_name}' for PII data.")

        # Get available tools - specifically need the tag_pii_columns and get_table_info tools
        tools = get_tool_schemas()

        # Process using the LLM
        return self.process_with_tools(tools)
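
    # Illustrative call (hypothetical table name; assumes catalog, schema, and
    # warehouse are already configured):
    #
    #   report = manager.process_pii_detection("main.sales.customers")
    #
    # The returned string is the LLM's final PII summary for that table.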

    def process_bulk_pii_scan(self, catalog_name=None, schema_name=None):
        """Process a bulk PII scan for all tables in the current catalog and schema

        Args:
            catalog_name: Optional name of the catalog to scan (uses active catalog if None)
            schema_name: Optional name of the schema to scan (uses active schema if None)

        Returns:
            Final response from the LLM with consolidated PII analysis
        """
        # Start with a clean conversation specifically for bulk PII scanning
        self.conversation_history = []

        # Add system message for bulk PII detection
        self.add_system_message(BULK_PII_AGENT_SYSTEM_MESSAGE)

        # Add user message requesting bulk PII analysis
        if catalog_name and schema_name:
            self.add_user_message(
                f"Scan all tables in catalog '{catalog_name}' and schema '{schema_name}' for PII data."
            )
        else:
            self.add_user_message(
                "Scan all tables in the current catalog and schema for PII data."
            )

        # Get available tools
        tools = get_tool_schemas()

        # Process using the LLM
        return self.process_with_tools(tools)

    def process_setup_stitch(self, catalog_name=None, schema_name=None):
        """Process a Stitch setup request

        Args:
            catalog_name: Optional name of the catalog to use
            schema_name: Optional name of the schema to use

        Returns:
            Final response from the LLM with setup instructions
        """
        # Start with a clean conversation
        self.conversation_history = []

        # Add system message for stitch setup
        self.add_system_message(STITCH_AGENT_SYSTEM_MESSAGE)

        # Add user message requesting stitch setup
        if catalog_name and schema_name:
            self.add_user_message(
                f"Set up a Stitch integration for catalog '{catalog_name}' and schema '{schema_name}'."
            )
        else:
            self.add_user_message(
                "Set up a Stitch integration for the current catalog and schema."
            )

        # Get available tools
        tools = get_tool_schemas()

        # Process using the LLM
        return self.process_with_tools(tools)
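
    # Illustrative calls (hypothetical names): both entry points accept either
    # explicit targets or fall back to the active catalog/schema.
    #
    #   manager.process_bulk_pii_scan("main", "sales")   # explicit target
    #   manager.process_setup_stitch()                   # uses the active context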

    def process_with_tools(self, tools, max_iterations: int = 20):
        """Process the current conversation with tools until a final response is received.

        Args:
            tools: Tool schemas to use
            max_iterations: Maximum number of LLM calls to make before aborting

        Returns:
            Final text response from the LLM, or an error message if the limit is reached
        """
        original_system_message_content = None
        system_message_index = -1
        iteration_count = 0

        # Find the system message and store its original content
        # We do this once before the loop starts
        for i, msg in enumerate(self.conversation_history):
            if msg["role"] == "system":
                system_message_index = i
                # Store the content as it was when the process started
                original_system_message_content = msg["content"]
                break

        if system_message_index == -1:
            # This should not happen if the history is initialized correctly
            logging.error("System message not found in conversation history.")
            # Handle error appropriately, maybe raise exception or return error message

        while iteration_count < max_iterations:
            # Prepare a temporary history copy for this specific LLM call
            current_history = deepcopy(self.conversation_history)

            # Get current configuration state
            active_catalog = get_active_catalog() or "Not set"
            active_schema = get_active_schema() or "Not set"
            warehouse_id = get_warehouse_id() or "Not set"
            workspace_url = get_workspace_url() or "Not set"

            config_state_info = (
                f"\n\n--- CURRENT CONTEXT ---\n"
                f"Workspace URL: {workspace_url}\n"
                f"Active Catalog: {active_catalog}\n"
                f"Active Schema: {active_schema}\n"
                f"Active Warehouse ID: {warehouse_id}\n"
                f"-----------------------"
            )
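
            # Illustrative shape of the footer appended to the system prompt
            # for this call only (values hypothetical):
            #
            #   --- CURRENT CONTEXT ---
            #   Workspace URL: https://example.cloud.databricks.com
            #   Active Catalog: main
            #   Active Schema: sales
            #   Active Warehouse ID: abc123
            #   -----------------------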

            # Update the system message *in the temporary copy*
            # Append the current config state to the *original* system message content
            if (
                system_message_index != -1
                and original_system_message_content is not None
            ):
                current_history[system_message_index]["content"] = (
                    original_system_message_content + config_state_info
                )
            # else: log warning or handle case where system message wasn't found initially

            # Get the LLM response using the temporary, updated history
            response = self.llm_client.chat(
                messages=current_history,  # Use the modified temporary history for the call
                model=self.model,
                tools=tools,
                stream=False,  # Important: No streaming within the loop
            )

            response_message = response.choices[0].message
            iteration_count += 1

            # --- IMPORTANT ---
            # All modifications to the conversation history (appending assistant
            # messages, tool calls, and tool results) MUST be made on the original
            # self.conversation_history, NOT the temporary current_history.
            # current_history is only used for the LLM call itself.

            # Check if the response contains tool calls
            if response_message.tool_calls:
                # Add the assistant's response (requesting tool calls) to history
                self.conversation_history.append(response_message)

                # Execute each tool call
                for tool_call in response_message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_id = tool_call.id
                    try:
                        tool_args = json.loads(tool_call.function.arguments)
                        tool_result = execute_tool(
                            self.api_client,
                            tool_name,
                            tool_args,
                            output_callback=self.tool_output_callback,
                        )
                    except json.JSONDecodeError as e:
                        tool_result = {"error": f"Invalid JSON arguments: {e}"}
                    except Exception as e:
                        # Handle pagination cancellation specially - let it bubble up
                        from src.exceptions import PaginationCancelled

                        if isinstance(e, PaginationCancelled):
                            raise  # Re-raise to bubble up to the main TUI loop

                        tool_result = {"error": f"Tool execution failed: {e}"}

                    # Add the tool execution result to the *original* history
                    self.conversation_history.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_id,
                            "name": tool_name,
                            "content": json.dumps(tool_result),
                        }
                    )

                # Check if any tool initiated interactive mode - if so, stop processing
                from src.interactive_context import InteractiveContext

                interactive_context = InteractiveContext()
                if interactive_context.is_in_interactive_mode():
                    # A tool has initiated interactive mode, stop agent processing
                    # Return an empty response to let the TUI handle the interaction
                    logging.debug(
                        "Tool initiated interactive mode, stopping agent processing"
                    )
                    return ""

                # Continue the loop to get the next LLM response based on tool results
                continue
            else:
                # No tool calls, so this is the final response; guard against a
                # None content to avoid crashing on splitlines()
                final_content = response_message.content or ""
                # Strip any lines that contain leaked <function> tags
                final_content = "\n".join(
                    line
                    for line in final_content.splitlines()
                    if "<function" not in line
                )
                self.add_assistant_message(final_content)
                return final_content

        logging.error(
            "process_with_tools reached maximum iterations without final response"
        )
        error_msg = "Error: maximum iterations reached."
        self.add_assistant_message(error_msg)
        return error_msg

    def process_query(self, query):
        """Process a general query using available tools

        Args:
            query: User's query text

        Returns:
            Final response from the LLM
        """
        # If no system message exists, add the default one
        has_system = False
        for msg in self.conversation_history:
            if msg["role"] == "system":
                has_system = True
                break

        if not has_system:
            self.add_system_message(DEFAULT_SYSTEM_MESSAGE)

        # Add user message to history
        self.add_user_message(query)

        # Get available tools
        tools = get_tool_schemas()

        # Process using the LLM
        return self.process_with_tools(tools)
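

# --- Illustrative usage: a minimal sketch, not part of the class above ---
# Assumptions: `api_client` stands in for whatever workspace/API client
# execute_tool expects, and the model name is a placeholder; replace both
# with real values before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    api_client = None  # hypothetical placeholder for a real API client
    manager = AgentManager(api_client, model="example-model")  # model name assumed

    # General query: reuses the default system message and the running history.
    print(manager.process_query("List the tables in the active schema."))

    # Focused PII pass: resets the history and swaps in the PII system prompt.
    print(manager.process_pii_detection("main.sales.customers"))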