Coverage for src/chuck_data/commands/wizard/steps.py: 0%

153 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-05 22:56 -0700

1""" 

2Step handlers for setup wizard. 

3""" 

4 

5from abc import ABC, abstractmethod 

6import logging 

7 

8from .state import WizardState, StepResult, WizardStep, WizardAction 

9from .validator import InputValidator 

10 

11from ...clients.amperity import AmperityAPIClient 

12from ...config import ( 

13 get_amperity_token, 

14 set_workspace_url, 

15 set_databricks_token, 

16 set_active_model, 

17 set_usage_tracking_consent, 

18) 

19from ...ui.tui import get_chuck_service 

20from ...models import list_models 

21 

22 

class SetupStep(ABC):
    """Abstract base for one step of the setup wizard.

    Concrete steps supply a title, a prompt, and an input handler that turns
    the user's text into a ``StepResult``. The shared ``InputValidator`` is
    kept on ``self.validator`` for subclasses to use.
    """

    def __init__(self, validator: InputValidator):
        self.validator = validator

    @abstractmethod
    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Process ``input_text`` in the context of ``state`` and report the outcome."""

    @abstractmethod
    def get_prompt_message(self, state: WizardState) -> str:
        """Return the text shown to the user while this step is active."""

    @abstractmethod
    def get_step_title(self) -> str:
        """Return the human-readable title of this step."""

    def should_hide_input(self, state: WizardState) -> bool:
        """Report whether the user's input should be masked (e.g. secrets).

        The base implementation never masks input; subclasses override it.
        """
        return False

47 

48 

class AmperityAuthStep(SetupStep):
    """Handle Amperity authentication.

    This step does not use the user's input. It reuses a stored Amperity
    token when one exists. Otherwise it starts the Amperity auth flow and
    blocks until the flow reports completion.
    """

    def get_step_title(self) -> str:
        return "Amperity Authentication"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Starting Amperity authentication..."

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Authenticate with Amperity; ``input_text`` and ``state`` are unused.

        Returns:
            A ``StepResult`` that:
            - CONTINUEs to ``WORKSPACE_URL`` when a token already exists or
              authentication succeeds;
            - EXITs when the failure message mentions "cancelled";
            - RETRYs on any other failure, including unexpected exceptions,
              which are logged.
        """
        # Skip the interactive flow entirely if a token is already stored.
        if get_amperity_token():
            return StepResult(
                success=True,
                message="Amperity token already exists. Proceeding to Databricks setup.",
                next_step=WizardStep.WORKSPACE_URL,
                action=WizardAction.CONTINUE,
            )

        try:
            auth_manager = AmperityAPIClient()
            started, start_message = auth_manager.start_auth()
            if not started:
                return StepResult(
                    success=False,
                    message=f"Error starting Amperity authentication: {start_message}",
                    action=WizardAction.RETRY,
                )

            # Blocks until the flow finishes (poll_interval presumably in seconds).
            auth_success, auth_message = auth_manager.wait_for_auth_completion(
                poll_interval=1
            )
            if auth_success:
                return StepResult(
                    success=True,
                    message="Amperity authentication complete. Proceeding to Databricks setup.",
                    next_step=WizardStep.WORKSPACE_URL,
                    action=WizardAction.CONTINUE,
                )

            # A missing failure message must not crash the string checks below.
            auth_message = auth_message or ""

            if "cancelled" in auth_message.lower():
                return StepResult(
                    success=False,
                    message="Setup cancelled. Run /setup again when ready.",
                    action=WizardAction.EXIT,
                )

            # Drop a leading "Authentication failed:" so the prefix is not
            # repeated in the message built below.
            clean_message = auth_message
            if auth_message.lower().startswith("authentication failed:"):
                clean_message = auth_message.split(":", 1)[1].strip()

            return StepResult(
                success=False,
                message=f"Authentication failed: {clean_message or 'unknown error'}",
                action=WizardAction.RETRY,
            )

        except Exception as e:
            logging.error(f"Error in Amperity authentication: {e}")
            return StepResult(
                success=False,
                message=f"Authentication error: {str(e)}",
                action=WizardAction.RETRY,
            )

122 

123 

class WorkspaceUrlStep(SetupStep):
    """Collect and validate the Databricks workspace URL.

    Nothing is persisted here. The validator's processed URL is returned in
    ``StepResult.data`` under ``"workspace_url"`` for the caller to record.
    """

    def get_step_title(self) -> str:
        return "Databricks Workspace"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Please enter your Databricks workspace URL (e.g., https://my-workspace.cloud.databricks.com)"

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate the URL and advance to token entry on success."""
        checked = self.validator.validate_workspace_url(input_text)

        if checked.is_valid:
            return StepResult(
                success=True,
                message="Workspace URL validated. Now enter your Databricks token.",
                next_step=WizardStep.TOKEN_INPUT,
                action=WizardAction.CONTINUE,
                data={"workspace_url": checked.processed_value},
            )

        # Invalid URL: stay on this step and show the validator's explanation.
        return StepResult(
            success=False,
            message=checked.message,
            action=WizardAction.RETRY,
        )

151 

152 

class TokenInputStep(SetupStep):
    """Handle Databricks token input.

    Validates the token against the workspace URL held in the wizard state,
    saves both to config, and reinitializes the shared service client. It
    then tries to list models to decide whether model selection is needed.
    """

    def get_step_title(self) -> str:
        return "Databricks Token"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Please enter your Databricks API token:"

    def should_hide_input(self, state: WizardState) -> bool:
        return True  # Hide token input: it is a secret.

    @staticmethod
    def _to_usage_consent(token, message: str) -> StepResult:
        """Build the successful result that skips model selection (no models)."""
        return StepResult(
            success=True,
            message=message,
            next_step=WizardStep.USAGE_CONSENT,
            action=WizardAction.CONTINUE,
            data={"token": token, "models": []},
        )

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate and save the token, then choose the next step.

        Returns:
            - EXIT if ``state.workspace_url`` is not set;
            - CONTINUE back to ``WORKSPACE_URL`` if the token fails validation;
            - RETRY if saving the URL or token fails, or raises;
            - CONTINUE to ``MODEL_SELECTION`` when models are listed, else
              CONTINUE to ``USAGE_CONSENT``.
            Successful results carry ``{"token": ..., "models": [...]}`` in ``data``.
        """
        if not state.workspace_url:
            return StepResult(
                success=False,
                message="Workspace URL not set. Please restart the wizard.",
                action=WizardAction.EXIT,
            )

        validation = self.validator.validate_token(input_text, state.workspace_url)
        if not validation.is_valid:
            # The token is validated against the URL, so the URL may be at
            # fault too: send the user back to re-enter both.
            return StepResult(
                success=False,
                message=f"{validation.message}. Please re-enter your workspace URL and token.",
                next_step=WizardStep.WORKSPACE_URL,
                action=WizardAction.CONTINUE,
            )

        token = validation.processed_value

        try:
            if not set_workspace_url(state.workspace_url):
                return StepResult(
                    success=False,
                    message="Failed to save workspace URL. Please try again.",
                    action=WizardAction.RETRY,
                )

            if not set_databricks_token(token):
                return StepResult(
                    success=False,
                    message="Failed to save Databricks token. Please try again.",
                    action=WizardAction.RETRY,
                )

            service = get_chuck_service()
            if not service:
                return self._to_usage_consent(
                    token, "Databricks configured. Proceeding to usage consent."
                )

            # Pick up the new credentials in the running service.
            if not service.reinitialize_client():
                logging.warning(
                    "Failed to reinitialize client, but credentials were saved"
                )
                return self._to_usage_consent(
                    token, "Credentials saved but client reinitialization failed."
                )

            # Listing models is best-effort: any failure just skips selection.
            try:
                # A None result means "no models", not an error.
                models = list_models(service.client) or []
                logging.debug(f"Found {len(models)} models")
            except Exception as e:
                logging.error(f"Error listing models: {e}")
                return self._to_usage_consent(
                    token, "Error listing models. Proceeding to usage consent."
                )

            if not models:
                return self._to_usage_consent(
                    token, "No models found. Proceeding to usage consent."
                )

            return StepResult(
                success=True,
                message="Databricks configured. Select a model.",
                next_step=WizardStep.MODEL_SELECTION,
                action=WizardAction.CONTINUE,
                data={"token": token, "models": models},
            )

        except Exception as e:
            logging.error(f"Error saving Databricks configuration: {e}")
            return StepResult(
                success=False,
                message=f"Error saving configuration: {str(e)}",
                action=WizardAction.RETRY,
            )

269 

270 

class ModelSelectionStep(SetupStep):
    """Let the user choose the active LLM from the models found earlier."""

    def get_step_title(self) -> str:
        return "LLM Model Selection"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Please enter the number or name of the model you want to use:"

    @staticmethod
    def _display_order(models):
        """Return ``models`` in the order used when they are displayed.

        Recommended models come first, in the fixed order below; only the
        first model with each recommended name is used. Every model whose
        name is not recommended follows, in its original order.
        """
        preferred = (
            "databricks-meta-llama-3-3-70b-instruct",
            "databricks-claude-3-7-sonnet",
        )
        ordered = []
        for wanted in preferred:
            hit = next((m for m in models if m["name"] == wanted), None)
            if hit is not None:
                ordered.append(hit)
        ordered.extend(m for m in models if m["name"] not in preferred)
        return ordered

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate the chosen model (by number or name) and save it as active."""
        if not state.models:
            return StepResult(
                success=False,
                message="No models available. Restarting wizard at workspace setup step.",
                next_step=WizardStep.WORKSPACE_URL,
                action=WizardAction.CONTINUE,
            )

        # Validation must see the same ordering as the display, so that a
        # number typed by the user refers to the model they saw at that spot.
        ordered = self._display_order(state.models)
        validation = self.validator.validate_model_selection(input_text, ordered)
        if not validation.is_valid:
            return StepResult(
                success=False,
                message=validation.message,
                action=WizardAction.RETRY,
            )

        try:
            chosen = validation.processed_value
            if not set_active_model(chosen):
                return StepResult(
                    success=False,
                    message="Failed to save model selection. Please try again.",
                    action=WizardAction.RETRY,
                )
            return StepResult(
                success=True,
                message=f"Model '{chosen}' selected. Proceeding to usage consent.",
                next_step=WizardStep.USAGE_CONSENT,
                action=WizardAction.CONTINUE,
                data={"selected_model": chosen},
            )
        except Exception as e:
            logging.error(f"Error saving model selection: {e}")
            return StepResult(
                success=False,
                message=f"Error saving model selection: {str(e)}",
                action=WizardAction.RETRY,
            )

344 

345 

class UsageConsentStep(SetupStep):
    """Ask whether usage information may be shared with Amperity.

    A successfully saved answer completes the wizard (``WizardAction.COMPLETE``).
    """

    def get_step_title(self) -> str:
        return "Usage Tracking Consent"

    def get_prompt_message(self, state: WizardState) -> str:
        return "Do you consent to sharing your usage information with Amperity (yes/no)?"

    def handle_input(self, input_text: str, state: WizardState) -> StepResult:
        """Validate the yes/no answer, save it, and finish the wizard."""
        answer = self.validator.validate_usage_consent(input_text)
        if not answer.is_valid:
            return StepResult(
                success=False,
                message=answer.message,
                action=WizardAction.RETRY,
            )

        try:
            # Only the exact processed value "yes" counts as consent;
            # presumably the validator normalizes the raw input (confirm).
            consent = answer.processed_value == "yes"
            if not set_usage_tracking_consent(consent):
                return StepResult(
                    success=False,
                    message="Failed to save usage tracking preference. Please try again.",
                    action=WizardAction.RETRY,
                )

            closing = (
                "Thank you for helping us make Chuck better! Setup wizard completed successfully!"
                if consent
                else "We understand, Chuck will not share your usage with Amperity. Setup wizard completed successfully!"
            )
            return StepResult(
                success=True,
                message=closing,
                next_step=WizardStep.COMPLETE,
                action=WizardAction.COMPLETE,
                data={"usage_consent": consent},
            )
        except Exception as e:
            logging.error(f"Error saving usage consent: {e}")
            return StepResult(
                success=False,
                message=f"Error saving usage consent: {str(e)}",
                action=WizardAction.RETRY,
            )

399 

400 

# Step factory
def create_step(step_type: WizardStep, validator: InputValidator) -> SetupStep:
    """Instantiate the handler for ``step_type``, passing it ``validator``.

    Raises:
        ValueError: if ``step_type`` has no handler in the table below
            (for example ``WizardStep.COMPLETE``).
    """
    handlers = {
        WizardStep.AMPERITY_AUTH: AmperityAuthStep,
        WizardStep.WORKSPACE_URL: WorkspaceUrlStep,
        WizardStep.TOKEN_INPUT: TokenInputStep,
        WizardStep.MODEL_SELECTION: ModelSelectionStep,
        WizardStep.USAGE_CONSENT: UsageConsentStep,
    }

    if step_type not in handlers:
        raise ValueError(f"Unknown step type: {step_type}")

    return handlers[step_type](validator)