Coverage for src/chuck_data/commands/wizard/steps.py: 0%
153 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-05 22:56 -0700
1"""
2Step handlers for setup wizard.
3"""
5from abc import ABC, abstractmethod
6import logging
8from .state import WizardState, StepResult, WizardStep, WizardAction
9from .validator import InputValidator
11from ...clients.amperity import AmperityAPIClient
12from ...config import (
13 get_amperity_token,
14 set_workspace_url,
15 set_databricks_token,
16 set_active_model,
17 set_usage_tracking_consent,
18)
19from ...ui.tui import get_chuck_service
20from ...models import list_models
class SetupStep(ABC):
    """Abstract base for a single step of the setup wizard.

    Concrete steps provide a title, a prompt, and an input handler that turns
    the user's text into a ``StepResult``. Masking of input is opt-in.
    """

    def __init__(self, validator: InputValidator):
        # Validator shared with subclasses for checking user input.
        self.validator = validator

    @abstractmethod
    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Process ``input_text`` for this step given the current ``state``."""

    @abstractmethod
    def get_prompt_message(self, state: WizardState) -> str:
        """Return the prompt text displayed for this step."""

    @abstractmethod
    def get_step_title(self) -> str:
        """Return the human-readable title of this step."""

    def should_hide_input(self, state: WizardState) -> bool:
        """Return True when typed input must be masked; the default is False."""
        return False
class AmperityAuthStep(SetupStep):
    """Wizard step that authenticates the user with Amperity.

    The step takes no meaningful user input. It reuses a stored Amperity token
    when one exists; otherwise it starts the auth flow via
    ``AmperityAPIClient`` and blocks until that flow finishes.
    """

    def get_step_title(self) -> str:
        return "Amperity Authentication"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Starting Amperity authentication..."

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Run Amperity authentication; ``input_text`` is ignored.

        Returns:
            A ``StepResult`` that continues to ``WORKSPACE_URL`` on success (or
            when a token is already stored), exits when the auth message
            mentions "cancelled", and requests a retry on any other failure or
            unexpected exception.
        """
        # A stored token means authentication already happened; skip the flow.
        existing_token = get_amperity_token()
        if existing_token:
            return StepResult(
                success=True,
                message="Amperity token already exists. Proceeding to Databricks setup.",
                next_step=WizardStep.WORKSPACE_URL,
                action=WizardAction.CONTINUE,
            )

        try:
            auth_manager = AmperityAPIClient()
            success, message = auth_manager.start_auth()
            if not success:
                return StepResult(
                    success=False,
                    message=f"Error starting Amperity authentication: {message}",
                    action=WizardAction.RETRY,
                )

            # Block until authentication completes (poll_interval=1, presumably
            # seconds — TODO confirm against AmperityAPIClient).
            auth_success, auth_message = auth_manager.wait_for_auth_completion(
                poll_interval=1
            )
            if auth_success:
                return StepResult(
                    success=True,
                    message="Amperity authentication complete. Proceeding to Databricks setup.",
                    next_step=WizardStep.WORKSPACE_URL,
                    action=WizardAction.CONTINUE,
                )

            # The message is presumably a str; guard against a missing one so
            # it is not reported as an AttributeError from ``.lower()``.
            auth_message = auth_message or ""

            if "cancelled" in auth_message.lower():
                return StepResult(
                    success=False,
                    message="Setup cancelled. Run /setup again when ready.",
                    action=WizardAction.EXIT,
                )

            # Strip a leading "Authentication failed:" so the prefix added
            # below is not duplicated.
            clean_message = auth_message
            if auth_message.lower().startswith("authentication failed:"):
                clean_message = auth_message.split(":", 1)[1].strip()

            return StepResult(
                success=False,
                message=f"Authentication failed: {clean_message or 'unknown error'}",
                action=WizardAction.RETRY,
            )

        except Exception as e:
            # exc_info keeps the traceback in the log; the user only sees str(e).
            logging.error(f"Error in Amperity authentication: {e}", exc_info=True)
            return StepResult(
                success=False,
                message=f"Authentication error: {str(e)}",
                action=WizardAction.RETRY,
            )
class WorkspaceUrlStep(SetupStep):
    """Wizard step that asks for and validates the Databricks workspace URL."""

    def get_step_title(self) -> str:
        return "Databricks Workspace"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Please enter your Databricks workspace URL (e.g., https://my-workspace.cloud.databricks.com)"

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate ``input_text`` as a workspace URL.

        On success the validator's processed value is returned under
        ``data["workspace_url"]`` and the wizard continues to ``TOKEN_INPUT``.
        Otherwise the step is retried with the validator's message.
        """
        result = self.validator.validate_workspace_url(input_text)

        if result.is_valid:
            return StepResult(
                success=True,
                message="Workspace URL validated. Now enter your Databricks token.",
                next_step=WizardStep.TOKEN_INPUT,
                action=WizardAction.CONTINUE,
                data={"workspace_url": result.processed_value},
            )

        return StepResult(
            success=False,
            message=result.message,
            action=WizardAction.RETRY,
        )
class TokenInputStep(SetupStep):
    """Wizard step that collects, validates and saves the Databricks token."""

    def get_step_title(self) -> str:
        return "Databricks Token"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Please enter your Databricks API token:"

    def should_hide_input(self, state: WizardState) -> bool:
        return True  # Hide token input while it is typed.

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate the token, persist URL and token, then discover models.

        Returns:
            EXIT if the state has no workspace URL.
            CONTINUE back to ``WORKSPACE_URL`` if the token is invalid.
            RETRY if persisting the configuration fails or raises.
            Otherwise CONTINUE to ``MODEL_SELECTION`` when models were listed,
            or to ``USAGE_CONSENT`` when none were. Successful results carry
            ``data`` with ``"token"`` and ``"models"`` keys.
        """
        if not state.workspace_url:
            return StepResult(
                success=False,
                message="Workspace URL not set. Please restart the wizard.",
                action=WizardAction.EXIT,
            )

        validation = self.validator.validate_token(input_text, state.workspace_url)
        if not validation.is_valid:
            # Both values are re-collected because the token was validated
            # against the stored workspace URL.
            return StepResult(
                success=False,
                message=f"{validation.message}. Please re-enter your workspace URL and token.",
                next_step=WizardStep.WORKSPACE_URL,
                action=WizardAction.CONTINUE,
            )

        token = validation.processed_value
        try:
            if not set_workspace_url(state.workspace_url):
                return StepResult(
                    success=False,
                    message="Failed to save workspace URL. Please try again.",
                    action=WizardAction.RETRY,
                )

            if not set_databricks_token(token):
                return StepResult(
                    success=False,
                    message="Failed to save Databricks token. Please try again.",
                    action=WizardAction.RETRY,
                )

            # NOTE(review): exceptions raised while reinitializing the client
            # (inside _load_models) land in the handler below and are reported
            # as a save error, although the credentials are already persisted
            # at this point — consider treating them like a failed reinit.
            return self._load_models(token)

        except Exception as e:
            logging.error(f"Error saving Databricks configuration: {e}", exc_info=True)
            return StepResult(
                success=False,
                message=f"Error saving configuration: {str(e)}",
                action=WizardAction.RETRY,
            )

    def _load_models(self, token) -> StepResult:
        """Reinitialize the service client and choose the next step from its models.

        Only ``list_models`` failures are handled here. Exceptions from
        ``get_chuck_service`` or ``reinitialize_client`` propagate to the caller.
        """
        service = get_chuck_service()
        if not service:
            return self._consent_result(
                token, "Databricks configured. Proceeding to usage consent."
            )

        if not service.reinitialize_client():
            logging.warning("Failed to reinitialize client, but credentials were saved")
            return self._consent_result(
                token, "Credentials saved but client reinitialization failed."
            )

        try:
            models = list_models(service.client)
            logging.debug(f"Found {len(models)} models")
        except Exception as e:
            # Model discovery is best-effort: setup continues without a model.
            logging.error(f"Error listing models: {e}", exc_info=True)
            return self._consent_result(
                token, "Error listing models. Proceeding to usage consent."
            )

        if not models:
            return self._consent_result(
                token, "No models found. Proceeding to usage consent."
            )

        return StepResult(
            success=True,
            message="Databricks configured. Select a model.",
            next_step=WizardStep.MODEL_SELECTION,
            action=WizardAction.CONTINUE,
            data={
                "token": token,
                "models": models,
            },
        )

    @staticmethod
    def _consent_result(token, message: str) -> StepResult:
        """Build a successful result that skips model selection."""
        return StepResult(
            success=True,
            message=message,
            next_step=WizardStep.USAGE_CONSENT,
            action=WizardAction.CONTINUE,
            data={"token": token, "models": []},
        )
class ModelSelectionStep(SetupStep):
    """Wizard step that lets the user choose the active LLM model."""

    # Models listed first, in this order. NOTE(review): this ordering must
    # match the one used when the model list is displayed, since a numeric
    # choice is validated against it — confirm against the display code.
    RECOMMENDED_MODELS = (
        "databricks-meta-llama-3-3-70b-instruct",
        "databricks-claude-3-7-sonnet",
    )

    def get_step_title(self) -> str:
        return "LLM Model Selection"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Please enter the number or name of the model you want to use:"

    def _sort_models(self, models):
        """Return recommended models first, then the rest in their original order.

        For each recommended name only the first matching model is kept. Every
        model is expected to be a mapping with a ``"name"`` key.
        """
        ordered = []
        for recommended in self.RECOMMENDED_MODELS:
            match = next((m for m in models if m["name"] == recommended), None)
            if match is not None:
                ordered.append(match)
        ordered.extend(
            m for m in models if m["name"] not in self.RECOMMENDED_MODELS
        )
        return ordered

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate the user's model choice and save it as the active model.

        Returns:
            CONTINUE to ``WORKSPACE_URL`` when the state has no models.
            RETRY on an invalid choice or a failed or raising save.
            Otherwise CONTINUE to ``USAGE_CONSENT`` with ``data["selected_model"]``.
        """
        if not state.models:
            return StepResult(
                success=False,
                message="No models available. Restarting wizard at workspace setup step.",
                next_step=WizardStep.WORKSPACE_URL,
                action=WizardAction.CONTINUE,
            )

        validation = self.validator.validate_model_selection(
            input_text, self._sort_models(state.models)
        )
        if not validation.is_valid:
            return StepResult(
                success=False, message=validation.message, action=WizardAction.RETRY
            )

        selected_model = validation.processed_value
        try:
            if set_active_model(selected_model):
                return StepResult(
                    success=True,
                    message=f"Model '{selected_model}' selected. Proceeding to usage consent.",
                    next_step=WizardStep.USAGE_CONSENT,
                    action=WizardAction.CONTINUE,
                    data={"selected_model": selected_model},
                )
            return StepResult(
                success=False,
                message="Failed to save model selection. Please try again.",
                action=WizardAction.RETRY,
            )

        except Exception as e:
            # exc_info keeps the traceback in the log; the user only sees str(e).
            logging.error(f"Error saving model selection: {e}", exc_info=True)
            return StepResult(
                success=False,
                message=f"Error saving model selection: {str(e)}",
                action=WizardAction.RETRY,
            )
class UsageConsentStep(SetupStep):
    """Wizard step that records whether usage may be shared with Amperity."""

    def get_step_title(self) -> str:
        return "Usage Tracking Consent"

    def get_prompt_message(self, state: WizardState) -> str:
        return (
            "Do you consent to sharing your usage information with Amperity (yes/no)?"
        )

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate a yes/no answer and persist the usage-tracking preference.

        Consent is recorded as granted only when the validator's processed
        value equals ``"yes"``.

        Returns:
            COMPLETE with ``data["usage_consent"]`` once the preference is saved.
            RETRY on invalid input or a failed or raising save.
        """
        validation = self.validator.validate_usage_consent(input_text)
        if not validation.is_valid:
            return StepResult(
                success=False, message=validation.message, action=WizardAction.RETRY
            )

        try:
            consent = validation.processed_value == "yes"
            if not set_usage_tracking_consent(consent):
                return StepResult(
                    success=False,
                    message="Failed to save usage tracking preference. Please try again.",
                    action=WizardAction.RETRY,
                )

            if consent:
                message = "Thank you for helping us make Chuck better! Setup wizard completed successfully!"
            else:
                message = "We understand, Chuck will not share your usage with Amperity. Setup wizard completed successfully!"

            return StepResult(
                success=True,
                message=message,
                next_step=WizardStep.COMPLETE,
                action=WizardAction.COMPLETE,
                data={"usage_consent": consent},
            )

        except Exception as e:
            # exc_info keeps the traceback in the log; the user only sees str(e).
            logging.error(f"Error saving usage consent: {e}", exc_info=True)
            return StepResult(
                success=False,
                message=f"Error saving usage consent: {str(e)}",
                action=WizardAction.RETRY,
            )
# Step factory
def create_step(step_type: WizardStep, validator: InputValidator) -> SetupStep:
    """Instantiate the handler class registered for ``step_type``.

    Raises:
        ValueError: if no handler is registered for ``step_type``
            (for example ``WizardStep.COMPLETE``).
    """
    handlers = {
        WizardStep.AMPERITY_AUTH: AmperityAuthStep,
        WizardStep.WORKSPACE_URL: WorkspaceUrlStep,
        WizardStep.TOKEN_INPUT: TokenInputStep,
        WizardStep.MODEL_SELECTION: ModelSelectionStep,
        WizardStep.USAGE_CONSENT: UsageConsentStep,
    }

    handler_cls = handlers.get(step_type)
    if handler_cls is None:
        raise ValueError(f"Unknown step type: {step_type}")
    return handler_cls(validator)