Coverage for src/chuck_data/agent/manager.py: 0%
110 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
1import json
2import logging
3from copy import deepcopy
4from ..llm.client import LLMClient
5from .tool_executor import get_tool_schemas, execute_tool
6from ..config import (
7 get_active_catalog,
8 get_active_schema,
9 get_warehouse_id,
10 get_workspace_url,
11)
13from .prompts import (
14 DEFAULT_SYSTEM_MESSAGE,
15 PII_AGENT_SYSTEM_MESSAGE,
16 BULK_PII_AGENT_SYSTEM_MESSAGE,
17 STITCH_AGENT_SYSTEM_MESSAGE,
18)
class AgentManager:
    """Drives multi-turn, tool-calling conversations with the LLM.

    Maintains the running conversation history as a list of OpenAI-style
    message dicts and exposes task entry points (PII detection, bulk PII
    scan, Stitch setup, general queries) that all funnel into
    ``process_with_tools``, the shared LLM/tool-execution loop.
    """

    def __init__(self, client, model=None, tool_output_callback=None):
        """
        Args:
            client: API client handed through to ``execute_tool`` for tool calls.
            model: Optional model identifier forwarded to the LLM client.
            tool_output_callback: Optional callable that receives tool output
                as tools execute.
        """
        self.api_client = client
        self.llm_client = LLMClient()
        self.model = model
        self.tool_output_callback = tool_output_callback
        self.conversation_history = [
            {"role": "system", "content": DEFAULT_SYSTEM_MESSAGE}
        ]

    def add_user_message(self, content):
        """Append a user message to the conversation history."""
        self.conversation_history.append({"role": "user", "content": content})

    def add_assistant_message(self, content):
        """Append an assistant message to the conversation history."""
        self.conversation_history.append({"role": "assistant", "content": content})

    def add_system_message(self, content):
        """Set the system message.

        Replaces the first existing system message in place; if none exists,
        prepends one so it stays at the front of the history.
        """
        for i, msg in enumerate(self.conversation_history):
            if msg["role"] == "system":
                self.conversation_history[i] = {"role": "system", "content": content}
                return
        self.conversation_history.insert(0, {"role": "system", "content": content})

    def _run_task(self, system_message, user_message):
        """Reset the history, seed it with the given messages, and run the tool loop.

        Shared boilerplate for the specialized task entry points below.

        Args:
            system_message: System prompt for the task.
            user_message: Initial user request for the task.

        Returns:
            Final text response from ``process_with_tools``.
        """
        self.conversation_history = []
        self.add_system_message(system_message)
        self.add_user_message(user_message)
        tools = get_tool_schemas()
        return self.process_with_tools(tools)

    def process_pii_detection(self, table_name):
        """Process a PII detection request for a specific table.

        Args:
            table_name: Name of the table to analyze.

        Returns:
            Final response from the LLM.
        """
        return self._run_task(
            PII_AGENT_SYSTEM_MESSAGE,
            f"Analyze the table '{table_name}' for PII data.",
        )

    def process_bulk_pii_scan(self, catalog_name=None, schema_name=None):
        """Process a bulk PII scan for all tables in the current catalog and schema.

        Args:
            catalog_name: Optional name of the catalog to scan (uses active catalog if None).
            schema_name: Optional name of the schema to scan (uses active schema if None).

        Returns:
            Final response from the LLM with consolidated PII analysis.
        """
        if catalog_name and schema_name:
            request = (
                f"Scan all tables in catalog '{catalog_name}' and schema '{schema_name}' for PII data."
            )
        else:
            request = "Scan all tables in the current catalog and schema for PII data."
        return self._run_task(BULK_PII_AGENT_SYSTEM_MESSAGE, request)

    def process_setup_stitch(self, catalog_name=None, schema_name=None):
        """Process a Stitch setup request.

        Args:
            catalog_name: Optional name of the catalog to use.
            schema_name: Optional name of the schema to use.

        Returns:
            Final response from the LLM with setup instructions.
        """
        if catalog_name and schema_name:
            request = (
                f"Set up a Stitch integration for catalog '{catalog_name}' and schema '{schema_name}'."
            )
        else:
            request = "Set up a Stitch integration for the current catalog and schema."
        return self._run_task(STITCH_AGENT_SYSTEM_MESSAGE, request)

    def process_with_tools(self, tools, max_iterations: int = 20):
        """Process the current conversation with tools until a final response is received.

        Args:
            tools: Tool schemas to use.
            max_iterations: Maximum number of LLM calls to make before aborting.

        Returns:
            Final text response from the LLM, or an error message if the limit
            is reached. Returns "" if a tool puts the app into interactive mode.
        """
        original_system_message_content = None
        system_message_index = -1
        iteration_count = 0

        # Locate the system message once, before the loop; its original content
        # is re-used on every iteration with fresh context appended.
        for i, msg in enumerate(self.conversation_history):
            if msg["role"] == "system":
                system_message_index = i
                original_system_message_content = msg["content"]
                break

        if system_message_index == -1:
            # Should not happen when the history is initialized correctly; the
            # loop still runs, just without context injection.
            logging.error("System message not found in conversation history.")

        while iteration_count < max_iterations:
            # Work on a deep copy so the per-call context injection never
            # mutates the real history.
            current_history = deepcopy(self.conversation_history)

            # Snapshot the current configuration state for this iteration.
            active_catalog = get_active_catalog() or "Not set"
            active_schema = get_active_schema() or "Not set"
            warehouse_id = get_warehouse_id() or "Not set"
            workspace_url = get_workspace_url() or "Not set"

            config_state_info = (
                f"\n\n--- CURRENT CONTEXT ---\n"
                f"Workspace URL: {workspace_url}\n"
                f"Active Catalog: {active_catalog}\n"
                f"Active Schema: {active_schema}\n"
                f"Active Warehouse ID: {warehouse_id}\n"
                f"-----------------------"
            )

            # Append the live config state to the *original* system message
            # content, only in the temporary copy used for this call.
            if (
                system_message_index != -1
                and original_system_message_content is not None
            ):
                current_history[system_message_index]["content"] = (
                    original_system_message_content + config_state_info
                )

            response = self.llm_client.chat(
                messages=current_history,  # modified temporary history
                model=self.model,
                tools=tools,
                stream=False,  # no streaming within the loop
            )

            response_message = response.choices[0].message
            iteration_count += 1

            # IMPORTANT: all history mutations below (assistant messages, tool
            # calls, tool results) go to self.conversation_history, never to
            # the temporary current_history used only for the LLM call.
            if response_message.tool_calls:
                # Record the assistant's tool-call request.
                self.conversation_history.append(response_message)

                for tool_call in response_message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_id = tool_call.id
                    try:
                        tool_args = json.loads(tool_call.function.arguments)
                        tool_result = execute_tool(
                            self.api_client,
                            tool_name,
                            tool_args,
                            output_callback=self.tool_output_callback,
                        )
                    except json.JSONDecodeError as e:
                        tool_result = {"error": f"Invalid JSON arguments: {e}"}
                    except Exception as e:
                        # Pagination cancellation must bubble up to the TUI.
                        from ..exceptions import PaginationCancelled

                        if isinstance(e, PaginationCancelled):
                            raise
                        tool_result = {"error": f"Tool execution failed: {e}"}

                    # Record the tool result in the *original* history.
                    self.conversation_history.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_id,
                            "name": tool_name,
                            "content": json.dumps(tool_result),
                        }
                    )

                # If any tool put the app into interactive mode, stop and let
                # the TUI handle the interaction.
                from ..interactive_context import InteractiveContext

                if InteractiveContext().is_in_interactive_mode():
                    logging.debug(
                        "Tool initiated interactive mode, stopping agent processing"
                    )
                    return ""

                # Otherwise, loop for the next LLM response based on results.
                continue

            # No tool calls: this is the final response. content may be None
            # for some model responses; treat it as empty rather than crashing
            # on .splitlines().
            final_content = response_message.content or ""
            # Strip any lines containing leaked <function> tags.
            final_content = "\n".join(
                line
                for line in final_content.splitlines()
                if "<function" not in line
            )
            self.add_assistant_message(final_content)
            return final_content

        logging.error(
            "process_with_tools reached maximum iterations without final response"
        )
        error_msg = "Error: maximum iterations reached."
        self.add_assistant_message(error_msg)
        return error_msg

    def process_query(self, query):
        """Process a general query using available tools.

        Args:
            query: User's query text.

        Returns:
            Final response from the LLM.
        """
        # Ensure a system message exists without disturbing existing history.
        if not any(msg["role"] == "system" for msg in self.conversation_history):
            self.add_system_message(DEFAULT_SYSTEM_MESSAGE)

        self.add_user_message(query)
        tools = get_tool_schemas()
        return self.process_with_tools(tools)