Coverage for src/chuck_data/commands/agent.py: 0%

75 statements  

coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

1""" 

2Command for processing natural language queries using an LLM agent. 

3 

4This module contains the handler for the agent command which processes 

5natural language queries and interacts with Databricks resources. 

6""" 

7 

8import logging 

9from typing import Optional, Any 

10 

11from ..clients.databricks import DatabricksAPIClient 

12from ..command_registry import CommandDefinition 

13from ..commands.base import CommandResult 

14from ..metrics_collector import get_metrics_collector 

15 

16 


def handle_command(
    client: Optional[DatabricksAPIClient], **kwargs: Any
) -> CommandResult:
    """
    Process a natural language query using the LLM agent.

    Args:
        client: DatabricksAPIClient instance for API calls (optional)
        **kwargs: Command parameters
            - query: The natural language query from the user
            - mode: Optional agent mode (general, pii, bulk_pii, stitch)
            - rest: Any additional text input provided after the command
            - raw_args: Unparsed arguments (fallback when the command parser fails)
            - catalog_name: Optional catalog name for context
            - schema_name: Optional schema name for context
            - tool_output_callback: Optional callback forwarded to the AgentManager

    Returns:
        CommandResult with agent response
    """

    # First check for different ways the query might be provided
    # Priority: 1. query parameter, 2. rest parameter, 3. raw_args
    query = kwargs.get("query")
    rest = kwargs.get("rest")
    raw_args = kwargs.get("raw_args")

    # If query wasn't provided but we have rest or raw_args, use that as the query
    if not query:  # This checks if the initial kwargs.get("query") was empty/None
        if rest:
            query = rest
        elif raw_args:
            if isinstance(raw_args, (list, tuple)):
                query = " ".join(str(arg) for arg in raw_args)
            else:
                # Handle case where raw_args is a single string
                query = str(raw_args)

    # At this point, query might be a string from 'query', 'rest', 'raw_args', or still None.
    # Strip whitespace if query is a string. This handles cases like " ".
    # If query is None, .strip() would error, so we check isinstance.
    if isinstance(query, str):
        query = query.strip()

    # Now, check if the (potentially stripped) query is truly empty or None.
    if not query:
        return CommandResult(
            False, message="Please provide a query. Usage: /ask Your question here"
        )

    # Get optional parameters
    mode = kwargs.get("mode", "general").lower()
    catalog_name = kwargs.get("catalog_name")
    schema_name = kwargs.get("schema_name")
    tool_output_callback = kwargs.get("tool_output_callback")

    try:
        from ..agent import AgentManager
        from ..config import get_agent_history, set_agent_history

        # Get metrics collector
        metrics_collector = get_metrics_collector()

        # Create agent manager with the API client and tool output callback
        agent = AgentManager(client, tool_output_callback=tool_output_callback)

        # Load conversation history
        try:
            history = get_agent_history()
        except Exception:
            history = []

        if history:
            agent.conversation_history = history

        # Process the query based on the selected mode
        if mode == "pii":
            # PII detection mode for a single table
            response = agent.process_pii_detection(
                table_name=query, catalog_name=catalog_name, schema_name=schema_name
            )
        elif mode == "bulk_pii":
            # Bulk PII scanning mode for a schema
            response = agent.process_bulk_pii_scan(
                catalog_name=catalog_name, schema_name=schema_name
            )
        elif mode == "stitch":
            # Stitch setup mode
            response = agent.process_setup_stitch(
                catalog_name=catalog_name, schema_name=schema_name
            )
        else:
            # Default general query mode
            response = agent.process_query(query)

        # Save conversation history
        set_agent_history(agent.conversation_history)

        # Track the agent interaction event
        if mode == "pii":
            # For PII detection mode
            processed_tools = [{"name": "pii_detection", "arguments": {"table": query}}]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}
        elif mode == "bulk_pii":
            # For bulk PII scanning mode
            processed_tools = [
                {
                    "name": "bulk_pii_scan",
                    "arguments": {"catalog": catalog_name, "schema": schema_name},
                }
            ]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}
        elif mode == "stitch":
            # For Stitch setup mode
            processed_tools = [
                {
                    "name": "setup_stitch",
                    "arguments": {"catalog": catalog_name, "schema": schema_name},
                }
            ]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}
        else:
            # For general query mode
            processed_tools = [{"name": "general_query", "arguments": {"query": query}}]
            event_context = "agent_interaction"
            additional_data = {"event_context": event_context, "agent_mode": mode}

        # Get the last AI response from the conversation history
        last_ai_response = None
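        # The history can mix plain dict messages, e.g. {"role": "assistant", "content": "..."},
        # with ChatCompletionMessage-style objects that expose a .role attribute;
        # the loop below handles both shapes when looking for the latest assistant reply.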

        if agent.conversation_history and len(agent.conversation_history) > 0:
            for msg in reversed(agent.conversation_history):
                # Handle both dict messages and ChatCompletionMessage objects
                role = (
                    msg.get("role")
                    if hasattr(msg, "get")
                    else getattr(msg, "role", None)
                )
                if role == "assistant":
                    last_ai_response = msg
                    break

        # Track the event
        metrics_collector.track_event(
            prompt=query,
            tools=processed_tools,
            conversation_history=[last_ai_response] if last_ai_response else None,
            additional_data=additional_data,
        )

        return CommandResult(
            True,
            data={"response": response, "conversation": agent.conversation_history},
        )

    except Exception as e:
        # Handle pagination cancellation specially - let it bubble up
        from ..exceptions import PaginationCancelled

        if isinstance(e, PaginationCancelled):
            raise  # Re-raise to bubble up to main TUI loop

        logging.error(f"Agent error: {e}", exc_info=True)
        return CommandResult(
            False, message=f"Failed to process query: {str(e)}", error=e
        )


DEFINITION = CommandDefinition(
    name="agent",
    description="Process natural language queries using an LLM agent",
    handler=handle_command,
    parameters={
        "query": {
            "type": "string",
            "description": "Natural language query to process",
        },
        "mode": {
            "type": "string",
            "description": "Agent mode (general, pii, bulk_pii, stitch)",
            "default": "general",
        },
        "catalog_name": {
            "type": "string",
            "description": "Optional catalog name for context (uses active catalog if not provided)",
        },
        "schema_name": {
            "type": "string",
            "description": "Optional schema name for context (uses active schema if not provided)",
        },
        "rest": {
            "type": "string",
            "description": "Additional text after the command to use as query",
        },
        "raw_args": {
            "type": ["array", "string"],
            "description": "Raw unparsed arguments in case parsing fails",
        },
    },
    # Not requiring query since we handle combining raw_args and rest in the handler
    required_params=[],
    tui_aliases=["/agent", "/ask"],
    needs_api_client=True,
    visible_to_user=True,
    visible_to_agent=False,  # Don't let the agent use itself
    usage_hint='Usage: /ask Your natural language question here\n /agent --query "Your question here" [--mode general|pii|bulk_pii|stitch]',
)
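
# Minimal invocation sketch (illustrative only, not part of the module). It assumes
# the package is importable as chuck_data and that CommandResult exposes the success
# flag, data, and message as attributes; the DatabricksAPIClient constructor
# arguments are elided because they depend on workspace setup.
#
#     from chuck_data.clients.databricks import DatabricksAPIClient
#     from chuck_data.commands.agent import handle_command
#
#     client = DatabricksAPIClient(...)  # hypothetical: supply real connection details
#     result = handle_command(client, query="Which tables look like they hold PII?")
#     if result.success:
#         print(result.data["response"])
#     else:
#         print(result.message)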