# interactor
class Interactor:
    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: str = "openai:gpt-4o-mini",
        fallback_model: str = "ollama:mistral-nemo:latest",
        tools: Optional[bool] = True,
        stream: bool = True,
        quiet: bool = False,
        context_length: int = 128000,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        log_path: Optional[str] = None,
        raw: Optional[bool] = False,
        session_enabled: bool = False,
        session_id: Optional[str] = None,
        session_path: Optional[str] = None
    ):
        """Initialize the universal AI interaction client.

        Args:
            base_url: Optional base URL for the API. If None, uses the provider's default URL.
            api_key: Optional API key. If None, attempts to use environment variables based on provider.
            model: Model identifier in format "provider:model_name" (e.g., "openai:gpt-4o-mini").
            fallback_model: Model to switch to when the primary model exhausts its retries.
            tools: Enable (True) or disable (False) tool calling; None for auto-detection based on model support.
            stream: Enable (True) or disable (False) streaming responses.
            quiet: If True, suppress console status output.
            context_length: Maximum number of tokens to maintain in conversation history.
            max_retries: Maximum number of retries for failed API calls.
            retry_delay: Initial delay (in seconds) for exponential backoff retries.
            log_path: Optional file path for DEBUG-level logging.
            raw: Default value for raw streaming mode; see interact().
            session_enabled: Enable persistent session support.
            session_id: Optional session ID to load messages from.
            session_path: Optional directory for persisted session files.

        Raises:
            ValueError: If provider is not supported or API key is missing for non-Ollama providers.
        """
        self.system = "You are a helpful Assistant."
        self.raw = raw
        self.quiet = quiet
        self.logger = logging.getLogger(f"InteractorLogger_{id(self)}")
        self.logger.setLevel(logging.DEBUG)
        self.providers = {
            "openai": {
                "sdk": "openai",
                "base_url": "https://api.openai.com/v1",
                "api_key": api_key or os.getenv("OPENAI_API_KEY") or None
            },
            "ollama": {
                "sdk": "openai",
                "base_url": "http://localhost:11434/v1",
                "api_key": api_key or "ollama"
            },
            "nvidia": {
                "sdk": "openai",
                "base_url": "https://integrate.api.nvidia.com/v1",
                "api_key": api_key or os.getenv("NVIDIA_API_KEY") or None
            },
            "google": {
                "sdk": "openai",
                "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
                "api_key": api_key or os.getenv("GEMINI_API_KEY") or None
            },
            "anthropic": {
                "sdk": "anthropic",
                "base_url": "https://api.anthropic.com/v1",
                "api_key": api_key or os.getenv("ANTHROPIC_API_KEY") or None
            },
            "mistral": {
                "sdk": "openai",
                "base_url": "https://api.mistral.ai/v1",
                "api_key": api_key or os.getenv("MISTRAL_API_KEY") or None
            },
            "deepseek": {
                "sdk": "openai",
                "base_url": "https://api.deepseek.com",
                "api_key": api_key or os.getenv("DEEPSEEK_API_KEY") or None
            },
        }
        # Disabled provider, kept for reference:
        # "grok": {
        #     "sdk": "grok",
        #     "base_url": "https://api.x.ai/v1",
        #     "api_key": api_key or os.getenv("GROK_API_KEY") or None
        # }

        # Console log handler (always enabled at WARNING+)
        if not self.logger.handlers:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(logging.WARNING)
            console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
            self.logger.addHandler(console_handler)

        self._log_enabled = False
        if log_path:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(logging.Formatter(
                "%(asctime)s - %(levelname)s - %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S"
            ))
            self.logger.addHandler(file_handler)
            self._log_enabled = True

        self.token_estimate = 0
        self.last_token_estimate = 0
        self.stream = stream
        self.tools = []
        self.session_history = []
        self.history = []
        self.context_length = context_length
        self.encoding = None
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.reveal_tool = []
        self.fallback_model = fallback_model
        self.sdk = None

        # Session support
        self.session_enabled = session_enabled
        self.session_id = session_id
        self._last_session_id = session_id
        self.session = Session(directory=session_path) if session_enabled else None

        if model is None:
            model = "openai:gpt-4o-mini"

        # Initialize model + encoding
        self._setup_client(model, base_url, api_key)
        self.tools_enabled = self.tools_supported if tools is None else tools and self.tools_supported
        self._setup_encoding()
        self.messages_add(role="system", content=self.system)
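
    # Usage sketch (illustrative, not executed): constructing a client and
    # checking what was resolved. Assumes OPENAI_API_KEY is set and that a
    # local Ollama daemon backs the default fallback model.
    #
    #     ai = Interactor(model="openai:gpt-4o-mini", log_path="interactor.log")
    #     print(ai.provider, ai.model)  # "openai" "gpt-4o-mini"
    #     print(ai.tools_enabled)       # True if the tool-support probe succeeded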

    def _log(self, message: str, level: str = "info"):
        """Log a message to the configured logging handlers.

        This internal method handles logging to both console and file handlers
        if configured. It respects the logging level and only logs if logging
        is enabled.

        Args:
            message (str): The message to log
            level (str): Logging level - one of "debug", "info", "warning", "error"
        """
        if self._log_enabled:
            getattr(self.logger, level)(message)

    def _setup_client(
        self,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """Initialize or reconfigure the Interactor for the given model and SDK.

        Ensures idempotent setup, assigns SDK-specific clients and tool handling logic,
        and normalizes history to match the provider-specific message schema.
        """
        if not model:
            raise ValueError("Model must be specified as 'provider:model_name'")

        provider, model_name = model.split(":", 1)

        if not hasattr(self, "session_history"):
            self.session_history = []

        # Skip setup if nothing has changed (client may not yet exist on first call)
        if (
            hasattr(self, "client")
            and self.client
            and self.provider == provider
            and self.model == model_name
            and self.base_url == (base_url or self.base_url)
        ):
            return

        if provider not in self.providers:
            raise ValueError(
                f"Unsupported provider: {provider}. Supported: {list(self.providers.keys())}"
            )

        # Load provider configuration
        provider_config = self.providers[provider]
        self.sdk = provider_config.get("sdk", "openai")
        self.provider = provider
        self.model = model_name
        self.base_url = base_url or provider_config["base_url"]
        effective_api_key = api_key or provider_config["api_key"]

        if not effective_api_key and provider != "ollama":
            raise ValueError(f"API key not provided and not found in environment for {provider.upper()}_API_KEY")

        # SDK-specific configuration
        if self.sdk == "openai":
            self.client = openai.OpenAI(base_url=self.base_url, api_key=effective_api_key)
            self.async_client = openai.AsyncOpenAI(base_url=self.base_url, api_key=effective_api_key)
            self.sdk_runner = self._openai_runner
            self.tool_key = "tool_call_id"

        elif self.sdk == "anthropic":
            self.client = anthropic.Anthropic(api_key=effective_api_key)
            self.async_client = anthropic.AsyncAnthropic(api_key=effective_api_key)
            self.sdk_runner = self._anthropic_runner
            self.tool_key = "tool_use_id"

        else:
            raise ValueError(f"Unsupported SDK type: {self.sdk}")

        # Determine tool support
        self.tools_supported = self._check_tool_support()
        if not self.tools_supported:
            self.logger.warning(f"Tool calling not supported for {provider}:{model_name}")

        # Normalize session history to match SDK after any provider/model change
        self._normalizer(force=True)

        self._log(f"[MODEL] Switched to {provider}:{model_name}")

    def _check_tool_support(self) -> bool:
        """Determine if the current model supports tool calling.

        Returns:
            bool: True if tools are supported for the active provider/model, False otherwise.
        """
        try:
            if self.sdk == "openai":
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": "Test tool support."}],
                    stream=False,
                    tools=[{
                        "type": "function",
                        "function": {
                            "name": "test_tool",
                            "description": "Check tool support",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "query": {"type": "string"}
                                },
                                "required": ["query"]
                            }
                        }
                    }],
                    tool_choice="auto"
                )
                message = response.choices[0].message
                return bool(message.tool_calls and len(message.tool_calls) > 0)

            elif self.sdk == "anthropic":
                # For Claude models, we pre-define support based on model ID
                # Known tool-supporting Claude models
                claude_tool_models = ["claude-3-opus", "claude-3-sonnet", "claude-3-haiku",
                                      "claude-3.5-sonnet", "claude-3.7-sonnet"]

                # Check if the current model supports tools
                for supported_model in claude_tool_models:
                    if supported_model in self.model.lower():
                        self._log(f"[TOOLS] Anthropic model {self.model} is known to support tools")
                        return True

                # If not explicitly supported, try to test
                try:
                    _ = self.client.messages.create(
                        model=self.model,
                        messages=[{"role": "user", "content": "What's the weather?"}],
                        tools=[{
                            "name": "test_tool",
                            "description": "Check tool support",
                            "input_schema": {
                                "type": "object",
                                "properties": {
                                    "query": {"type": "string"}
                                },
                                "required": ["query"]
                            }
                        }],
                        max_tokens=10
                    )
                    return True
                except anthropic.BadRequestError as e:
                    error_msg = str(e).lower()
                    if "tool" in error_msg and "not supported" in error_msg:
                        self._log(f"[TOOLS] Anthropic model {self.model} does not support tools: {e}")
                        return False
                    if "not a supported tool field" in error_msg:
                        self._log(f"[TOOLS] Anthropic API rejected tool format: {e}")
                        return False
                    raise
                except Exception as e:
                    self._log(f"[TOOLS] Unexpected error testing tool support: {e}", level="error")
                    return False

            else:
                self.logger.warning(f"Tool support check not implemented for SDK '{self.sdk}'")
                return False

        except Exception as e:
            self.logger.error(f"Tool support check failed for {self.provider}:{self.model} — {e}")
            return False

    def add_function(
        self,
        external_callable: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        override: bool = False,
        disabled: bool = False,
        schema_extensions: Optional[Dict[str, Any]] = None
    ):
        """
        Register a function for LLM tool calling with full type hints and metadata.

        Args:
            external_callable (Callable): The function to register.
            name (Optional[str]): Optional custom name. Defaults to function's __name__.
            description (Optional[str]): Optional custom description. Defaults to first line of docstring.
            override (bool): If True, replaces an existing tool with the same name.
            disabled (bool): If True, registers the function in a disabled state.
            schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to
                schema extensions that override or add to the auto-generated schema.

        Raises:
            ValueError: If the callable is invalid or duplicate name found without override.

        Example:
            interactor.add_function(
                my_tool,
                override=True,
                disabled=False,
                schema_extensions={
                    "param1": {"minimum": 0, "maximum": 100},
                    "param2": {"format": "email"}
                }
            )
        """
        def _python_type_to_schema(ptype: Any) -> dict:
            """Convert a Python type annotation to OpenAI-compatible JSON Schema."""
            # Handle None case
            if ptype is None:
                return {"type": "null"}

            # Get the origin and arguments of the type
            origin = get_origin(ptype)
            args = get_args(ptype)

            # Handle Union types (including Optional)
            if origin is Union:
                # Check for Optional (Union with None)
                none_type = type(None)
                if none_type in args:
                    non_none = [a for a in args if a is not none_type]
                    if len(non_none) == 1:
                        inner = _python_type_to_schema(non_none[0])
                        inner_copy = inner.copy()
                        inner_copy["nullable"] = True
                        return inner_copy
                    # Multiple types excluding None
                    types = [_python_type_to_schema(a) for a in non_none]
                    return {"anyOf": types, "nullable": True}
                # Regular Union without None
                return {"anyOf": [_python_type_to_schema(a) for a in args]}

            # Handle List and similar container types
            if origin in (list, List):
                item_type = args[0] if args else Any
                if item_type is Any:
                    return {"type": "array"}
                return {"type": "array", "items": _python_type_to_schema(item_type)}

            # Handle Dict types with typing info
            if origin in (dict, Dict):
                if not args or len(args) != 2:
                    return {"type": "object"}

                key_type, val_type = args
                # We can only really use val_type in JSON Schema
                if val_type is not Any and val_type is not object:
                    return {
                        "type": "object",
                        "additionalProperties": _python_type_to_schema(val_type)
                    }
                return {"type": "object"}

            # Handle Literal types for enums
            if origin is Literal:
                values = args
                # Try to determine type from values
                if all(isinstance(v, str) for v in values):
                    return {"type": "string", "enum": list(values)}
                elif all(isinstance(v, bool) for v in values):
                    return {"type": "boolean", "enum": list(values)}
                elif all(isinstance(v, (int, float)) for v in values):
                    return {"type": "number", "enum": list(values)}
                else:
                    # Mixed types, use anyOf
                    return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]}

            # Handle basic types
            if ptype is str:
                return {"type": "string"}
            if ptype is int:
                return {"type": "integer"}
            if ptype is float:
                return {"type": "number"}
            if ptype is bool:
                return {"type": "boolean"}

            # Handle common datetime types
            if ptype is datetime:
                return {"type": "string", "format": "date-time"}
            if ptype is date:
                return {"type": "string", "format": "date"}

            # Handle UUID
            if ptype is uuid.UUID:
                return {"type": "string", "format": "uuid"}

            # Default to object for any other types
            return {"type": "object"}

        def _get_json_type(value):
            """Get the JSON Schema type name for a Python value.

            This helper function maps Python types to their corresponding
            JSON Schema type names. It handles basic types and provides
            sensible defaults for complex types.

            Args:
                value: The Python value to get the JSON type for

            Returns:
                str: The JSON Schema type name ('string', 'boolean', 'number',
                    'array', or 'object' as the default)
            """
            if isinstance(value, str):
                return "string"
            elif isinstance(value, bool):
                return "boolean"
            elif isinstance(value, (int, float)):
                return "number"
            elif isinstance(value, list):
                return "array"
            elif isinstance(value, dict):
                return "object"
            else:
                return "object"  # Default

        def _parse_param_docs(docstring: str) -> dict:
            """Extract parameter descriptions from a docstring."""
            if not docstring:
                return {}

            lines = docstring.splitlines()
            param_docs = {}
            current_param = None
            in_params = False

            # Regular expressions for finding parameter sections and param lines
            param_section_re = re.compile(r"^(Args|Parameters):\s*$")
            param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)")

            for line in lines:
                # Check if we're entering the parameters section
                if param_section_re.match(line.strip()):
                    in_params = True
                    continue

                if in_params:
                    # Skip empty lines
                    if not line.strip():
                        continue

                    # Check for a parameter definition line
                    match = param_line_re.match(line)
                    if match:
                        current_param = match.group(1)
                        param_docs[current_param] = match.group(2).strip()
                    # Check for continuation of a parameter description
                    elif current_param and line.startswith(" " * 8):
                        param_docs[current_param] += " " + line.strip()
                    # If we see a line that doesn't match our patterns, we're out of the params section
                    else:
                        current_param = None

            return param_docs

        # Start of main function logic

        # Skip if tools are disabled
        if not self.tools_enabled:
            return

        # Validate input callable
        if not external_callable:
            raise ValueError("A valid external callable must be provided.")

        # Set function name, either from parameter or from callable's __name__
        function_name = name or external_callable.__name__

        # Try to get docstring and extract description
        try:
            docstring = inspect.getdoc(external_callable)
            description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.")
        except Exception as e:
            self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning")
            docstring = ""
            description = description or "No description provided."

        # Extract parameter documentation from docstring
        param_docs = _parse_param_docs(docstring)

        # Handle conflicts with existing functions
        if override:
            self.delete_function(function_name)
        elif any(t["function"]["name"] == function_name for t in self.tools):
            raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.")

        # Try to get function signature for parameter info
        try:
            signature = inspect.signature(external_callable)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Cannot inspect callable '{function_name}': {e}")

        # Process parameters to build schema
        properties = {}
        required = []

        for param_name, param in signature.parameters.items():
            # Skip self, cls parameters for instance/class methods
            if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                continue

            # Get parameter annotation, defaulting to Any
            annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any

            try:
                # Convert Python type to JSON Schema
                schema = _python_type_to_schema(annotation)

                # Add description from docstring or create a default one
                schema["description"] = param_docs.get(param_name, f"{param_name} parameter")

                # Add to properties
                properties[param_name] = schema

                # If no default value is provided, parameter is required
                if param.default == inspect.Parameter.empty:
                    required.append(param_name)
                    self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug")
                else:
                    self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug")

            except Exception as e:
                self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error")
                # Add a basic schema as fallback
                properties[param_name] = {
                    "type": "string",  # Default to string instead of object for better compatibility
                    "description": f"{param_name} parameter (type conversion failed)"
                }

                # For parameters with no default value, mark as required even if processing failed
                if param.default == inspect.Parameter.empty:
                    required.append(param_name)
                    self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug")

        # Apply schema extensions if provided
        if schema_extensions:
            for param_name, extensions in schema_extensions.items():
                if param_name in properties:
                    properties[param_name].update(extensions)

        # Create parameters object with proper placement of 'required' field
        parameters = {
            "type": "object",
            "properties": properties,
        }

        # Only add required field if there are required parameters
        if required:
            parameters["required"] = required

        # Build the final tool specification
        tool_spec = {
            "type": "function",
            "function": {
                "name": function_name,
                "description": description,
                "parameters": parameters
            }
        }

        # Set disabled flag if requested
        if disabled:
            tool_spec["function"]["disabled"] = True

        # Add to tools list
        self.tools.append(tool_spec)

        # Make the function available as an attribute on the instance
        setattr(self, function_name, external_callable)

        # Log the registration with detailed information
        self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info")
        if required:
            self._log(f"[TOOL] Required parameters: {required}", level="info")

        return function_name  # Return the name for reference
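
    # Schema-generation sketch (illustrative; `set_volume` is a hypothetical
    # function, not part of this module). Shows how type hints and docstring
    # params map into the generated tool spec:
    #
    #     def set_volume(level: int, mode: Literal["music", "alarm"] = "music"):
    #         """Set output volume.
    #
    #         Args:
    #             level (int): Target volume.
    #             mode (str): Output channel to adjust.
    #         """
    #
    #     ai.add_function(set_volume,
    #                     schema_extensions={"level": {"minimum": 0, "maximum": 100}})
    #     # Resulting "parameters" schema, approximately:
    #     # {"type": "object",
    #     #  "properties": {
    #     #      "level": {"type": "integer", "description": "Target volume.",
    #     #                "minimum": 0, "maximum": 100},
    #     #      "mode": {"type": "string", "enum": ["music", "alarm"],
    #     #               "description": "Output channel to adjust."}},
    #     #  "required": ["level"]}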

    def disable_function(self, name: str) -> bool:
        """
        Disable a registered tool function by name.

        This marks the function as inactive for tool calling without removing it from the internal registry.
        The function remains visible in the tool listing but is skipped during tool selection by the LLM.

        Args:
            name (str): The name of the function to disable.

        Returns:
            bool: True if the function was found and disabled, False otherwise.

        Example:
            interactor.disable_function("extract_text")
        """
        for tool in self.tools:
            if tool["function"]["name"] == name:
                tool["function"]["disabled"] = True
                return True
        return False

    def enable_function(self, name: str) -> bool:
        """
        Re-enable a previously disabled tool function by name.

        This removes the 'disabled' flag from a tool function, making it available again for LLM use.

        Args:
            name (str): The name of the function to enable.

        Returns:
            bool: True if the function was found and enabled, False otherwise.

        Example:
            interactor.enable_function("extract_text")
        """
        for tool in self.tools:
            if tool["function"]["name"] == name:
                tool["function"].pop("disabled", None)
                return True
        return False

    def delete_function(self, name: str) -> bool:
        """
        Permanently remove a registered tool function from the Interactor.

        This deletes both the tool metadata and the callable attribute, making it fully inaccessible
        from the active session. Useful for dynamically trimming the toolset.

        Args:
            name (str): The name of the function to delete.

        Returns:
            bool: True if the function was found and removed, False otherwise.

        Example:
            interactor.delete_function("extract_text")
        """
        before = len(self.tools)
        self.tools = [tool for tool in self.tools if tool["function"]["name"] != name]
        if hasattr(self, name):
            delattr(self, name)
        return len(self.tools) < before

    def list_functions(self) -> List[Dict[str, Any]]:
        """Get the list of registered functions for tool calling.

        Returns:
            List[Dict[str, Any]]: List of registered functions.
        """
        return self.tools

    def list_models(
        self,
        providers: Optional[Union[str, List[str]]] = None,
        filter: Optional[str] = None
    ) -> List[str]:
        """Retrieve available models from configured providers.

        Args:
            providers: Provider name or list of provider names. If None, all are queried.
            filter: Optional regex to filter model names.

        Returns:
            List[str]: Sorted list of "provider:model_id" strings.
        """
        models = []

        if providers is None:
            providers_to_list = self.providers
        elif isinstance(providers, str):
            providers_to_list = {providers: self.providers.get(providers)}
        elif isinstance(providers, list):
            providers_to_list = {p: self.providers.get(p) for p in providers}
        else:
            return []

        invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None]
        if invalid_providers:
            self.logger.error(f"Invalid providers: {invalid_providers}")
            return []

        regex_pattern = None
        if filter:
            try:
                regex_pattern = re.compile(filter, re.IGNORECASE)
            except re.error as e:
                self.logger.error(f"Invalid regex pattern: {e}")
                return []

        for provider_name, config in providers_to_list.items():
            sdk = config.get("sdk", "openai")
            base_url = config.get("base_url")
            api_key = config.get("api_key")

            try:
                if sdk == "openai":
                    client = openai.OpenAI(api_key=api_key, base_url=base_url)
                    response = client.models.list()
                    for model in response.data:
                        model_id = f"{provider_name}:{model.id}"
                        if not regex_pattern or regex_pattern.search(model_id):
                            models.append(model_id)

                elif sdk == "anthropic":
                    client = anthropic.Anthropic(api_key=api_key)
                    response = client.models.list()
                    for model in response:
                        model_id = f"{provider_name}:{model.id}"
                        if not regex_pattern or regex_pattern.search(model_id):
                            models.append(model_id)
                else:
                    self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()")

            except Exception as e:
                self.logger.error(f"Failed to list models for {provider_name}: {e}")

        return sorted(models, key=str.lower)
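
    # Listing sketch (illustrative): regex-filtering the merged catalog.
    #
    #     ai.list_models(providers="openai", filter=r"gpt-4o")
    #     # -> ["openai:gpt-4o", "openai:gpt-4o-mini", ...]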

    async def _retry_with_backoff(self, func: Callable, *args, **kwargs):
        """Execute a function with exponential backoff retry logic.

        This method implements a robust retry mechanism for API calls with
        exponential backoff. It handles rate limits, connection errors, and
        other transient failures. If all retries fail, it will attempt to
        switch to a fallback model if configured.

        Args:
            func (Callable): The async function to execute
            *args: Positional arguments to pass to the function
            **kwargs: Keyword arguments to pass to the function

        Returns:
            The result of the function call if successful

        Raises:
            Exception: If all retries fail and no fallback model is available
        """
        for attempt in range(self.max_retries + 1):
            try:
                return await func(*args, **kwargs)

            except (RateLimitError, APIConnectionError, aiohttp.ClientError) as e:
                if attempt == self.max_retries:
                    model_key = f"{self.provider}:{self.model}"
                    if self.fallback_model and model_key != self.fallback_model:
                        print(f"[yellow]Model '{model_key}' failed. Switching to fallback: {self.fallback_model}[/yellow]")
                        self._setup_client(self.fallback_model)
                        self._setup_encoding()
                        self._normalizer()
                        return await func(*args, **kwargs)  # retry once with fallback model
                    else:
                        self.logger.error(f"All {self.max_retries} retries failed: {e}")
                        raise

                delay = self.retry_delay * (2 ** attempt)
                self.logger.warning(f"Retry {attempt + 1}/{self.max_retries} after {delay}s due to {e}")
                self._log(f"[RETRY] Attempt {attempt + 1}/{self.max_retries} failed: {e}", level="warning")
                await asyncio.sleep(delay)

            except OpenAIError as e:
                self.logger.error(f"OpenAI error: {e}")
                raise

            except Exception as e:
                self.logger.error(f"Unexpected error: {e}")
                raise
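
    # Backoff arithmetic (for reference): delay = retry_delay * 2 ** attempt,
    # so with retry_delay=1.0 and max_retries=3 the waits are 1s, 2s, 4s;
    # after the final failure the call is retried once on fallback_model or
    # the original exception is raised.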

    def interact(
        self,
        user_input: Optional[str],
        quiet: bool = False,
        tools: bool = True,
        stream: bool = True,
        markdown: bool = False,
        model: Optional[str] = None,
        output_callback: Optional[Callable[[str], None]] = None,
        session_id: Optional[str] = None,
        raw: Optional[bool] = None,
        tool_suppress: bool = True,
        timeout: float = 60.0
    ) -> Union[Optional[str], "TokenStream"]:
        """Main universal gateway for all LLM interaction.

        This function serves as the single entry point for all interactions with the language model.
        When `raw=False` (default), it handles the interaction internally and returns the full response.
        When `raw=True`, it returns a context manager that yields chunks of the response for custom handling.

        Args:
            user_input: Text input from the user.
            quiet: If True, don't print status info or progress.
            tools: Enable (True) or disable (False) tool calling.
            stream: Enable (True) or disable (False) streaming responses.
            markdown: If True, renders content as markdown.
            model: Optional model override.
            output_callback: Optional callback to handle the output.
            session_id: Optional session ID to load messages from.
            raw: If True, return a context manager instead of handling the interaction internally.
                If None, use the class-level setting from __init__.
            tool_suppress: If True and raw=True, filter out tool-related status messages.
            timeout: Maximum time in seconds to wait for the stream to complete when raw=True.

        Returns:
            If raw=False: The complete response from the model as a string, or None if there was an error.
            If raw=True: A context manager that yields chunks of the response as they arrive.

        Example with default mode:
            response = ai.interact("Tell me a joke")

        Example with raw mode:
            with ai.interact("Tell me a joke", raw=True) as stream:
                for chunk in stream:
                    print(chunk, end="", flush=True)
        """
        if not user_input:
            return None

        if quiet or self.quiet:
            markdown = False
            stream = False

        # Determine if we should use raw mode:
        # if the raw parameter is explicitly provided, use that; otherwise use the class setting
        use_raw = self.raw if raw is None else raw

        # If raw mode is requested, delegate to _interact_raw
        if use_raw:
            return self._interact_raw(
                user_input=user_input,
                tools=tools,
                model=model,
                session_id=session_id,
                tool_suppress=tool_suppress,
                timeout=timeout
            )

        # Setup model if specified
        if model:
            self._setup_client(model)
            self._setup_encoding()

        # Session handling
        if self.session_enabled and session_id:
            self.session_id = session_id
            self.session_load(session_id)

        # Add user message using messages_add
        self.messages_add(role="user", content=user_input)

        # Log token count estimate
        token_count = self._count_tokens(self.history)
        if not quiet:
            print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]")

        # Make sure we have enough context space
        if token_count > self.context_length:
            if self._cycle_messages():
                if not quiet:
                    print("[red]Context window exceeded. Cannot proceed.[/red]")
                return None

        # Log user input
        self._log(f"[USER] {user_input}")

        # Handle the actual interaction with complete streaming for all responses
        result = asyncio.run(self._interact_async_core(
            user_input=user_input,
            quiet=quiet,
            tools=tools,
            stream=stream,
            markdown=markdown,
            output_callback=output_callback
        ))

        # Log completion for this interaction
        self._log(f"[INTERACTION] Completed with {len(self.history)} total messages")

        return result
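
    # Callback sketch (illustrative): routing streamed output to a custom
    # sink instead of stdout; `collected` is a hypothetical buffer.
    #
    #     collected = []
    #     ai.interact("Tell me a joke", stream=True,
    #                 output_callback=collected.append)
    #     print("".join(collected))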

    def _interact_raw(
        self,
        user_input: Optional[str],
        tools: bool = True,
        model: Optional[str] = None,
        session_id: Optional[str] = None,
        tool_suppress: bool = True,
        timeout: float = 60.0
    ):
        """
        Low-level function that returns a raw stream of tokens from the model.

        This method works as a context manager that yields a generator of streaming tokens.
        The caller is responsible for handling the output stream. Typically, this is used
        indirectly through interact() with raw=True.

        Args:
            user_input: Text input from the user.
            tools: Enable (True) or disable (False) tool calling.
            model: Optional model override.
            session_id: Optional session ID to load messages from.
            tool_suppress: If True, filter out tool-related status messages.
            timeout: Maximum time in seconds to wait for the stream to complete.

        Returns:
            A context manager that yields a stream of tokens.

        Example:
            with ai.interact("Hello world", raw=True) as stream:
                for chunk in stream:
                    print(chunk, end="", flush=True)
        """
        if not user_input:
            return None

        # Setup model if specified
        if model:
            self._setup_client(model)
            self._setup_encoding()

        # Session handling
        if self.session_enabled and session_id:
            self.session_id = session_id
            self.session_load(session_id)

        # Add user message
        self.messages_add(role="user", content=user_input)

        # Log token count estimate
        token_count = self._count_tokens(self.history)
        self._log(f"[STREAM] Estimated tokens in context: {token_count} / {self.context_length}")

        # Make sure we have enough context space
        if token_count > self.context_length:
            if self._cycle_messages():
                self._log("[STREAM] Context window exceeded. Cannot proceed.", level="error")
                return None

        # Log user input
        self._log(f"[USER] {user_input}")

        # Create a token stream class using a thread-safe queue
        class TokenStream:
            def __init__(self, interactor, user_input, tools, tool_suppress, timeout):
                """Initialize a new TokenStream instance.

                This class provides a context manager for streaming token responses
                from the AI model. It handles asynchronous token delivery, tool call
                suppression, and timeout management.

                Args:
                    interactor: The parent Interactor instance
                    user_input: The user's input text
                    tools: Whether tool calling is enabled
                    tool_suppress: Whether to suppress tool-related status messages
                    timeout: Maximum time in seconds to wait for stream completion
                """
                self.interactor = interactor
                self.user_input = user_input
                self.tools = tools
                self.tool_suppress = tool_suppress
                self.timeout = timeout
                self.token_queue = queue.Queue()
                self.thread = None
                self.result = None
                self.error = None
                self.completed = False

            def __enter__(self):
                """Enter the context manager and start the streaming process.

                This method initializes the streaming worker thread and returns
                self for iteration. The worker thread handles the actual API
                communication and token delivery.

                Returns:
                    TokenStream: Self for iteration
                """
                # Start the thread for async interaction
                def stream_worker():
                    """Worker thread that handles the streaming interaction.

                    This internal function runs in a separate thread to handle
                    the asynchronous API communication and token delivery.
                    """
                    # Define output callback to put tokens in queue
                    def callback(text):
                        """Process and queue incoming text tokens.

                        This internal function handles incoming text chunks,
                        optionally filtering tool-related messages, and adds
                        them to the token queue.

                        Args:
                            text: The text chunk to process and queue
                        """
                        # Filter out tool messages if requested
                        if self.tool_suppress:
                            try:
                                # Check if this is a tool status message (JSON format)
                                data = json.loads(text)
                                if isinstance(data, dict) and data.get("type") == "tool_call":
                                    # Skip this message
                                    return
                            except (json.JSONDecodeError, TypeError):
                                # Not JSON or not a dict, continue normally
                                pass

                        # Add to queue
                        self.token_queue.put(text)

                    # Run the interaction in a new event loop
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                    try:
                        # Run the interaction
                        self.result = loop.run_until_complete(
                            self.interactor._interact_async_core(
                                user_input=self.user_input,
                                quiet=True,
                                tools=self.tools,
                                stream=True,
                                markdown=False,
                                output_callback=callback
                            )
                        )
                        # Signal successful completion
                        self.completed = True
                    except Exception as e:
                        self.error = str(e)
                        self.interactor.logger.error(f"Streaming error: {traceback.format_exc()}")
                        # Add error information to the queue if we haven't yielded anything yet
                        if self.token_queue.empty():
                            self.token_queue.put(f"Error: {str(e)}")
                    finally:
                        # Signal end of stream regardless of success/failure
                        self.token_queue.put(None)
                        loop.close()

                # Start the worker thread
                self.thread = threading.Thread(target=stream_worker)
                self.thread.daemon = True
                self.thread.start()

                # Return self for iteration
                return self

            def __iter__(self):
                """Return self as an iterator.

                Returns:
                    TokenStream: Self for iteration
                """
                return self

            def __next__(self):
                """Get the next token from the stream.

                This method implements the iterator protocol, retrieving the next
                token from the queue with timeout handling.

                Returns:
                    str: The next token from the stream

                Raises:
                    StopIteration: When the stream is complete or times out
                """
                # Get next token from queue with timeout to prevent hanging
                try:
                    token = self.token_queue.get(timeout=self.timeout)
                    if token is None:
                        # End of stream
                        raise StopIteration
                    return token
                except queue.Empty:
                    # Timeout reached
                    self.interactor.logger.warning(f"Stream timeout after {self.timeout}s")
                    if not self.completed and not self.error:
                        # Clean up the thread - it might be stuck
                        if self.thread and self.thread.is_alive():
                            # We can't forcibly terminate a thread in Python,
                            # but we can report the issue
                            self.interactor.logger.error("Stream worker thread is hung")
                    raise StopIteration

            def __exit__(self, exc_type, exc_val, exc_tb):
                """Exit the context manager and clean up resources.

                This method handles cleanup when the context manager is exited,
                including thread cleanup and message history updates.

                Args:
                    exc_type: The exception type if an exception was raised
                    exc_val: The exception value if an exception was raised
                    exc_tb: The exception traceback if an exception was raised

                Returns:
                    bool: False to not suppress any exceptions
                """
                # Clean up resources
                if self.thread and self.thread.is_alive():
                    self.thread.join(timeout=2.0)

                # Add messages to history if successful
                if self.completed and self.result and not exc_type:
                    if isinstance(self.result, str) and self.result != "No response.":
                        # If we had a successful completion, ensure the result is in the history
                        last_msg = self.interactor.history[-1] if self.interactor.history else None
                        if not last_msg or last_msg.get("role") != "assistant" or last_msg.get("content") != self.result:
                            # Add a clean assistant message to history if not already there
                            self.interactor.messages_add(role="assistant", content=self.result)

                # If there was an error in the stream processing, log it
                if self.error:
                    self.interactor.logger.error(f"Stream processing error: {self.error}")

                return False  # Don't suppress exceptions

        return TokenStream(self, user_input, tools, tool_suppress, timeout)

    async def _interact_async_core(
        self,
        user_input: str,
        quiet: bool = False,
        tools: bool = True,
        stream: bool = True,
        markdown: bool = False,
        output_callback: Optional[Callable] = None
    ) -> str:
        """Main SDK-agnostic async execution pipeline with tool call looping support."""
        # Prepare display handler
        live = Live(console=console, refresh_per_second=100) if markdown and stream else None
        if live:
            live.start()

        # Initialize variables for iteration tracking
        full_content = ""
        max_iterations = 5  # Prevent infinite loops
        iterations = 0

        # Main interaction loop - continues until no more tool calls or max iterations reached
        while iterations < max_iterations:
            iterations += 1

            try:
                # Execute the appropriate SDK runner - history is already normalized
                response_data = await self.sdk_runner(
                    model=self.model,
                    messages=self.history,
                    stream=stream,
                    markdown=markdown,
                    quiet=quiet if iterations == 1 else False,
                    live=live,
                    output_callback=output_callback
                )

                # Extract response data
                content = response_data.get("content", "")
                tool_calls = response_data.get("tool_calls", [])

                # Log the response data for debugging
                self._log(f"[ITERATION {iterations}] Content: {len(content)} chars, Tool calls: {len(tool_calls)}")

                # Add content to full response
                if iterations == 1:
                    full_content = content
                elif content:
                    if full_content and content:
                        full_content += f"\n{content}"
                    else:
                        full_content = content

                # Add assistant message with or without tool calls
                if tool_calls:
                    # Process each tool call
                    for call in tool_calls:
                        # Add assistant message with tool call
                        tool_info = {
                            "id": call["id"],
                            "name": call["function"]["name"],
                            "arguments": call["function"]["arguments"]
                        }

                        # Add the assistant message with tool call
                        self.messages_add(
                            role="assistant",
                            content=content if len(tool_calls) == 1 else "",
                            tool_info=tool_info
                        )

                        # Execute the tool
                        call_name = call["function"]["name"]
                        call_args = call["function"]["arguments"]
                        call_id = call["id"]

                        # Stop Rich Live while executing tool calls
                        live_was_active = False
                        if live and live.is_started:
                            live_was_active = True
                            live.stop()

                        result = await self._handle_tool_call_async(
                            function_name=call_name,
                            function_arguments=call_args,
                            tool_call_id=call_id,
                            quiet=quiet,
                            safe=False,
                            output_callback=output_callback
                        )

                        # Restart live display if it was active before
                        if live_was_active and live:
                            live.start()

                        # Add tool result message
                        tool_result_info = {
                            "id": call_id,
                            "result": result
                        }

                        self.messages_add(
                            role="tool",
                            content=result,
                            tool_info=tool_result_info
                        )
                else:
                    # Simple assistant response without tool calls
                    self.messages_add(role="assistant", content=content)
                    break  # No more tools to process, we're done

                # Reset live display if needed
                if stream and live:
                    live.stop()
                    live = Live(console=console, refresh_per_second=100)
                    live.start()

            except Exception as e:
                self.logger.error(f"[{self.sdk.upper()} ERROR] {str(e)}")
                self._log(f"[ERROR] Error in interaction loop: {str(e)}", level="error")
                if live:
                    live.stop()
                return f"Error: {str(e)}"

        # Clean up display
        if live:
            live.stop()

        return full_content or None
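
    # Message-flow note (for reference): in OpenAI terms, one tool iteration
    # appends an assistant message carrying the call and a tool message
    # carrying its result ("get_time" is a hypothetical tool):
    #
    #     {"role": "assistant", "content": "", "tool_calls": [
    #         {"id": "call_1", "type": "function",
    #          "function": {"name": "get_time", "arguments": "{}"}}]}
    #     {"role": "tool", "tool_call_id": "call_1", "content": "12:00"}
    #
    # messages_add() / _normalizer() translate this shape per SDK.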

    async def _openai_runner(
        self,
        *,
        model,
        messages,
        stream,
        markdown=False,
        quiet=False,
        live=None,
        output_callback=None
    ):
        """Handle OpenAI-specific API interactions and response processing."""
        # Log what we're sending for debugging
        self._log(f"[OPENAI REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug")

        # Prepare API parameters - history is already normalized by _normalizer
        params = {
            "model": model,
            "messages": self.history,
            "stream": stream,
        }

        # Add tools if enabled
        if self.tools_enabled and self.tools_supported:
            enabled_tools = self._get_enabled_tools()
            if enabled_tools:
                params["tools"] = enabled_tools
                params["tool_choice"] = "auto"

        # Call API with retry handling
        try:
            response = await self._retry_with_backoff(
                self.async_client.chat.completions.create,
                **params
            )
        except Exception:
            self.logger.error(f"[OPENAI ERROR RUNNER]: {traceback.format_exc()}")
            raise

        assistant_content = ""
        tool_calls_dict = {}

        # Process streaming response
        if stream and hasattr(response, "__aiter__"):
            async for chunk in response:
                delta = getattr(chunk.choices[0], "delta", None)

                # Handle content chunks
                if hasattr(delta, "content") and delta.content is not None:
                    text = delta.content
                    assistant_content += text
                    if output_callback:
                        output_callback(text)
                    elif live:
                        live.update(Markdown(assistant_content))
                    elif not markdown:
                        print(text, end="")

                # Process tool calls
                if hasattr(delta, "tool_calls") and delta.tool_calls:
                    for tool_call_delta in delta.tool_calls:
                        index = tool_call_delta.index
                        if index not in tool_calls_dict:
                            tool_calls_dict[index] = {
                                "id": tool_call_delta.id if hasattr(tool_call_delta, "id") else None,
                                "function": {"name": "", "arguments": ""}
                            }

                        function = getattr(tool_call_delta, "function", None)
                        if function:
                            name = getattr(function, "name", None)
                            args = getattr(function, "arguments", "")
                            if name:
                                tool_calls_dict[index]["function"]["name"] = name
                            if args:
                                tool_calls_dict[index]["function"]["arguments"] += args
                            if tool_call_delta.id and not tool_calls_dict[index]["id"]:
                                tool_calls_dict[index]["id"] = tool_call_delta.id

                        # Make sure the ID is set regardless
                        if hasattr(tool_call_delta, "id") and tool_call_delta.id and not tool_calls_dict[index]["id"]:
                            tool_calls_dict[index]["id"] = tool_call_delta.id

            if not output_callback and not markdown and not quiet:
                print()

        # Process non-streaming response
        else:
            message = response.choices[0].message
            assistant_content = message.content or ""

            if hasattr(message, "tool_calls") and message.tool_calls:
                for i, tool_call in enumerate(message.tool_calls):
                    tool_calls_dict[i] = {
                        "id": tool_call.id,
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments
                        }
                    }

            if output_callback:
                output_callback(assistant_content)
            elif not quiet:
                print(assistant_content)

        # Log tool calls for debugging
        if tool_calls_dict:
            self._log(f"[OPENAI TOOL CALLS] Found {len(tool_calls_dict)} tool calls", level="debug")
            for idx, call in tool_calls_dict.items():
                self._log(f"[OPENAI TOOL CALL {idx}] {call['function']['name']} with ID {call['id']}", level="debug")

        # Return standardized response format
        return {
            "content": assistant_content,
            "tool_calls": list(tool_calls_dict.values())
        }

    async def _anthropic_runner(
        self,
        *,
        model,
        messages,
        stream,
        markdown=False,
        quiet=False,
        live=None,
        output_callback=None
    ):
        """Handle Anthropic-specific API interactions and response processing."""
        # Log what we're sending for debugging
        self._log(f"[ANTHROPIC REQUEST] Sending request to {model} with {len(self.history)} messages", level="debug")

        # Prepare API parameters - history is already normalized by _normalizer
        params = {
            "model": model,
            "messages": self.history,
            "max_tokens": 8192,
            "system": self.system
        }

        # Add tools support if needed
        if self.tools_enabled and self.tools_supported:
            enabled_tools = []
            for tool in self._get_enabled_tools():
                # Extract parameters from OpenAI format
                tool_params = tool["function"]["parameters"]

                # Create Anthropic-compatible tool definition
                format_tool = {
                    "name": tool["function"]["name"],
                    "description": tool["function"].get("description", ""),
                    "input_schema": {
                        "type": "object",
                        "properties": tool_params.get("properties", {})
                    }
                }

                # Ensure 'required' is at the correct level for Anthropic (a direct child of input_schema)
                if "required" in tool_params:
                    format_tool["input_schema"]["required"] = tool_params["required"]

                enabled_tools.append(format_tool)

            params["tools"] = enabled_tools

        assistant_content = ""
        tool_calls_dict = {}

        try:
            # Process streaming response
            if stream:
                stream_params = params.copy()
                stream_params["stream"] = True

                stream_response = await self._retry_with_backoff(
                    self.async_client.messages.create,
                    **stream_params
                )

                content_type = None
                async for chunk in stream_response:
                    chunk_type = getattr(chunk, "type", "unknown")
                    self._log(f"[ANTHROPIC CHUNK] Type: {chunk_type}", level="debug")
                    if chunk_type == "content_block_start" and hasattr(chunk.content_block, "type"):
                        content_type = chunk.content_block.type
                        if content_type == "tool_use":
                            tool_id = chunk.content_block.id
                            tool_name = chunk.content_block.name
                            tool_input = chunk.content_block.input
                            tool_calls_dict[tool_id] = {
                                "id": tool_id,
                                "function": {
                                    "name": tool_name,
                                    "arguments": ""
                                }
                            }
                            self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug")

                    # Handle text content
                    if chunk_type == "content_block_delta" and hasattr(chunk.delta, "text"):
                        delta = chunk.delta.text
                        assistant_content += delta
                        if output_callback:
                            output_callback(delta)
                        elif live:
                            live.update(Markdown(assistant_content))
                        elif not markdown:
                            print(delta, end="")

                    # Handle complete tool use
                    elif chunk_type == "content_block_delta" and content_type == "tool_use":
                        tool_calls_dict[tool_id]["function"]["arguments"] += chunk.delta.partial_json

            # Process non-streaming response
            else:
                # For non-streaming, ensure we don't send the stream parameter
                non_stream_params = params.copy()
                non_stream_params.pop("stream", None)  # Remove stream if it exists

                response = await self._retry_with_backoff(
                    self.async_client.messages.create,
                    **non_stream_params
                )

                # Extract text content
                for content_block in response.content:
                    if content_block.type == "text":
                        assistant_content += content_block.text

                    if content_block.type == "tool_use":
                        tool_id = content_block.id
                        tool_name = content_block.name
                        tool_input = content_block.input
                        tool_calls_dict[tool_id] = {
                            "id": tool_id,
                            "function": {
                                "name": tool_name,
                                "arguments": tool_input
                            }
                        }
                        self._log(f"[ANTHROPIC TOOL USE] {tool_name}", level="debug")

                if output_callback:
                    output_callback(assistant_content)
                elif not quiet:
                    print(assistant_content)

        except Exception as e:
            self._log(f"[ANTHROPIC ERROR RUNNER] {traceback.format_exc()}", level="error")

            # Return something usable even in case of error
            return {
                "content": f"Error processing Anthropic response: {str(e)}",
                "tool_calls": []
            }

        # Return standardized response format
        return {
            "content": assistant_content,
            "tool_calls": list(tool_calls_dict.values())
        }
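
    # Format note (for reference): an OpenAI-style tool definition
    #
    #     {"type": "function",
    #      "function": {"name": "f", "description": "...",
    #                   "parameters": {"type": "object", "properties": {...},
    #                                  "required": [...]}}}
    #
    # is reshaped above into Anthropic's schema as
    #
    #     {"name": "f", "description": "...",
    #      "input_schema": {"type": "object", "properties": {...},
    #                       "required": [...]}}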

    def _get_enabled_tools(self) -> List[dict]:
        """Return the list of currently enabled tool function definitions."""
        return [
            tool for tool in self.tools
            if not tool["function"].get("disabled", False)
        ]

    async def _handle_tool_call_async(
        self,
        function_name: str,
        function_arguments: str,
        tool_call_id: str,
        quiet: bool = False,
        safe: bool = False,
        output_callback: Optional[Callable[[str], None]] = None
    ) -> Any:
        """Process a tool call asynchronously and return the result.

        Args:
            function_name: Name of the function to call.
            function_arguments: JSON string containing the function arguments.
            tool_call_id: Unique identifier for this tool call.
            quiet: If True, suppress the "Running ..." status line.
            safe: If True, prompts for confirmation before executing the tool call.
            output_callback: Optional callback to handle the tool call result.

        Returns:
            The result of the function call, or an error dict on failure.

        Raises:
            ValueError: If the function is not found or JSON is invalid.
        """
        if isinstance(function_arguments, str):
            arguments = json.loads(function_arguments)
        else:
            arguments = function_arguments

        self._log(f"[TOOL:{function_name}] args={arguments}")

        func = getattr(self, function_name, None)
        if not func:
            raise ValueError(f"Function '{function_name}' not found.")

        be_quiet = self.quiet if quiet is None else quiet

        if not be_quiet:
            print(f"\nRunning {function_name}...")

        if output_callback:
            notification = json.dumps({
                "type": "tool_call",
                "tool_name": function_name,
                "status": "started"
            })
            output_callback(notification)

        try:
            if safe:
                prompt = f"[bold yellow]Proposed tool call:[/bold yellow] {function_name}({json.dumps(arguments, indent=2)})\n[bold cyan]Execute? [y/n]: [/bold cyan]"
                confirmed = Confirm.ask(prompt, default=False)
                if not confirmed:
                    command_result = {
                        "status": "cancelled",
                        "message": "Tool call aborted by user"
                    }
                    print("[red]Tool call cancelled by user[/red]")
                else:
                    loop = asyncio.get_event_loop()
                    command_result = await loop.run_in_executor(None, lambda: func(**arguments))
            else:
                loop = asyncio.get_event_loop()
                command_result = await loop.run_in_executor(None, lambda: func(**arguments))

            try:
                json.dumps(command_result)
            except TypeError as e:
                self.logger.error(f"Tool call result not serializable: {e}")
                return {"error": "Tool call returned unserializable data."}

            if output_callback:
                notification = json.dumps({
                    "type": "tool_call",
                    "tool_name": function_name,
                    "status": "completed"
                })
                output_callback(notification)

            return command_result

        except Exception as e:
            self._log(f"[ERROR] Tool execution failed: {e}", level="error")
            self.logger.error(f"Error executing tool function '{function_name}': {e}")
            return {"error": str(e)}

    def _setup_encoding(self):
        """Initialize the token encoding system for the current model.

        This method sets up the appropriate tokenizer based on the current
        model provider. For OpenAI models, it attempts to use the model-specific
        tokenizer, falling back to cl100k_base if not available. For other
        providers, it uses cl100k_base as a default.

        The encoding is used for token counting and context management.
        """
        try:
            if self.provider == "openai":
                try:
                    self.encoding = tiktoken.encoding_for_model(self.model)
                    self._log(f"[ENCODING] Loaded tokenizer for OpenAI model: {self.model}")
                except Exception:
                    self.encoding = tiktoken.get_encoding("cl100k_base")
                    self._log(f"[ENCODING] Fallback to cl100k_base for model: {self.model}")
            else:
                self.encoding = tiktoken.get_encoding("cl100k_base")
                self._log(f"[ENCODING] Defaulting to cl100k_base for non-OpenAI model: {self.model}")
        except Exception as e:
            self.logger.error(f"Failed to setup encoding: {e}")
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def _estimate_tokens_tiktoken(self, messages) -> int:
        """Rough token count estimate using tiktoken for OpenAI or fallback cases."""
        if not hasattr(self, "encoding") or not self.encoding:
            self._setup_encoding()
        return sum(
            len(self.encoding.encode(msg.get("content", "")))
            for msg in messages
            if isinstance(msg.get("content"), str)
        )
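
    # Estimation sketch (illustrative): cl100k_base is only a proxy for
    # non-OpenAI models, so treat these counts as estimates.
    #
    #     enc = tiktoken.get_encoding("cl100k_base")
    #     len(enc.encode("Hello world"))  # -> 2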
    def _count_tokens(self, messages, use_cache=True) -> int:
        """Accurately estimate the token count for messages, including tool calls, with caching support.

        Args:
            messages: List of message objects in either OpenAI or Anthropic format.
            use_cache: Whether to use and update the token count cache.

        Returns:
            int: Estimated token count.
        """
        # Set up encoding if needed
        if not hasattr(self, "encoding") or not self.encoding:
            self._setup_encoding()

        # Initialize the cache if it doesn't exist
        if not hasattr(self, "_token_count_cache"):
            self._token_count_cache = {}

        # Generate a cache key based on message content hashes
        if use_cache:
            try:
                # Create a cache key using message IDs if available, or content hashes
                cache_key_parts = []
                for msg in messages:
                    if isinstance(msg, dict):
                        # Try to use stable identifiers for the cache key
                        msg_id = msg.get("id", None)
                        timestamp = msg.get("timestamp", None)

                        if msg_id and timestamp:
                            cache_key_parts.append(f"{msg_id}:{timestamp}")
                        else:
                            # Fall back to a content-based hash if no stable IDs
                            content_str = str(msg.get("content", ""))
                            role = msg.get("role", "unknown")
                            cache_key_parts.append(f"{role}:{hash(content_str)}")

                cache_key = ":".join(cache_key_parts)
                if cache_key in self._token_count_cache:
                    return self._token_count_cache[cache_key]
            except Exception as e:
                # If caching fails, just continue with normal counting
                self._log(f"[TOKEN COUNT] Cache key generation failed: {e}", level="debug")
                use_cache = False

        # For Claude models, try to use their built-in token counter
        if self.sdk == "anthropic":
            try:
                # Convert messages to Anthropic format if needed
                anthropic_messages = []
                for msg in messages:
                    if msg.get("role") == "system":
                        continue  # System handled separately

                    if msg.get("role") == "tool":
                        # Skip tool messages in the token count to avoid double-counting
                        continue

                    if msg.get("role") == "user" and isinstance(msg.get("content"), list):
                        # Already in Anthropic format with tool_result
                        anthropic_messages.append(msg)
                    elif msg.get("role") in ["user", "assistant"]:
                        if not msg.get("tool_calls") and not msg.get("tool_use"):
                            # Simple message
                            anthropic_messages.append({
                                "role": msg.get("role"),
                                "content": msg.get("content", "")
                            })

                # Use Anthropic's token counter if messages exist
                if anthropic_messages:
                    response = self.client.messages.count_tokens(
                        model=self.model,
                        messages=anthropic_messages,
                        system=self.system
                    )
                    token_count = response.input_tokens

                    # Cache the result for future use
                    if use_cache and 'cache_key' in locals():
                        self._token_count_cache[cache_key] = token_count

                    return token_count
            except Exception as e:
                # Fall back to our estimation
                self._log(f"[TOKEN COUNT] Error using Anthropic token counter: {e}", level="debug")

        # More accurate token counting for all message types
        num_tokens = 0

        # Count tokens for each message
        for msg in messages:
            # Base token count for message metadata (role + message format)
            num_tokens += 4  # Message overhead

            # Add tokens for the role name
            role = msg.get("role", "")
            num_tokens += len(self.encoding.encode(role))

            # Count tokens in the message content
            if isinstance(msg.get("content"), str):
                content = msg.get("content", "")
                content_tokens = len(self.encoding.encode(content))
                num_tokens += content_tokens

            elif isinstance(msg.get("content"), list):
                # Handle Anthropic-style content lists
                for item in msg.get("content", []):
                    if isinstance(item, dict):
                        # Tool result or other structured content
                        if item.get("type") == "tool_result":
                            result_content = item.get("content", "")
                            if isinstance(result_content, str):
                                num_tokens += len(self.encoding.encode(result_content))
                            else:
                                # JSON serialization for dict/list content
                                num_tokens += len(self.encoding.encode(json.dumps(result_content)))
                            # Add tokens for the tool_use_id and type fields
                            num_tokens += len(self.encoding.encode(item.get("type", "")))
                            num_tokens += len(self.encoding.encode(item.get("tool_use_id", "")))

                        # Text content type
                        elif item.get("type") == "text":
                            num_tokens += len(self.encoding.encode(item.get("text", "")))

                        # Tool use type
                        elif item.get("type") == "tool_use":
                            num_tokens += len(self.encoding.encode(item.get("name", "")))
                            tool_input = item.get("input", {})
                            if isinstance(tool_input, str):
                                num_tokens += len(self.encoding.encode(tool_input))
                            else:
                                num_tokens += len(self.encoding.encode(json.dumps(tool_input)))
                            num_tokens += len(self.encoding.encode(item.get("id", "")))
                    else:
                        # Plain text content
                        num_tokens += len(self.encoding.encode(str(item)))

            # Count tokens in tool calls for the OpenAI format
            if msg.get("tool_calls"):
                for tool_call in msg.get("tool_calls", []):
                    if isinstance(tool_call, dict):
                        # Count tokens for the function name
                        func_name = tool_call.get("function", {}).get("name", "")
                        num_tokens += len(self.encoding.encode(func_name))

                        # Count tokens for the arguments
                        args = tool_call.get("function", {}).get("arguments", "")
                        if isinstance(args, str):
                            num_tokens += len(self.encoding.encode(args))
                        else:
                            num_tokens += len(self.encoding.encode(json.dumps(args)))

                        # Add tokens for the id and type fields
                        num_tokens += len(self.encoding.encode(tool_call.get("id", "")))
                        num_tokens += len(self.encoding.encode(tool_call.get("type", "function")))

            # Count tokens in the Anthropic tool_use field
            if msg.get("tool_use"):
                tool_use = msg.get("tool_use")
                # Count tokens for the name
                num_tokens += len(self.encoding.encode(tool_use.get("name", "")))

                # Count tokens for the input
                tool_input = tool_use.get("input", {})
                if isinstance(tool_input, str):
                    num_tokens += len(self.encoding.encode(tool_input))
                else:
                    num_tokens += len(self.encoding.encode(json.dumps(tool_input)))

                # Add tokens for the id field
                num_tokens += len(self.encoding.encode(tool_use.get("id", "")))

            # Handle the tool response message format
            if msg.get("role") == "tool":
                # Add tokens for tool_call_id
                tool_id = msg.get("tool_call_id", "")
                num_tokens += len(self.encoding.encode(tool_id))

            # Add message end tokens
            num_tokens += 2

        # Cache the result for future use
        if use_cache and 'cache_key' in locals():
            self._token_count_cache[cache_key] = num_tokens

        return num_tokens
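    # Illustrative sketch (not part of the class): _count_tokens accepts both
    # OpenAI-style and Anthropic-style message lists. `ai` is a hypothetical
    # Interactor instance.
    #
    #     mixed = [
    #         {"role": "user", "content": "What is 2 + 2?"},
    #         {"role": "assistant", "content": "4"},
    #     ]
    #     print(ai._count_tokens(mixed))                   # cached on repeat calls
    #     print(ai._count_tokens(mixed, use_cache=False))  # force a fresh count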
    def _cycle_messages(self):
        """Intelligently trim the message history to fit within the allowed context length.

        This method implements a multi-pass trimming strategy that:
        1. Always preserves system messages
        2. Always keeps the most recent complete conversation turn
        3. Prioritizes keeping tool call chains intact
        4. Preserves important context from earlier exchanges
        5. Aggressively prunes redundant information before essential content

        The passes are:
        1. First pass: identify critical messages that must be kept
        2. Second pass: calculate token counts for each message
        3. Third pass: keep important tool chains intact
        4. Fourth pass: fill remaining space with recent messages

        Returns:
            bool: True if the context is still exceeded after trimming (or only
                system messages remain), False if trimming brought the history
                within limits.

        Note:
            This method maintains both session_history and history in sync,
            ensuring proper SDK-specific formatting is preserved.
        """
        # Check if we need to trim
        token_count = self._count_tokens(self.history)

        # If we're already under the limit, return early
        if token_count <= self.context_length:
            return False

        self._log(f"[TRIM] Starting message cycling: {token_count} tokens exceeds {self.context_length} limit", level="info")

        # Track tokens as we reconstruct the history.
        # Target 80% of the limit or 1000 tokens below it, whichever is larger.
        target_tokens = max(self.context_length * 0.8, self.context_length - 1000)

        # First pass: identify critical messages we must keep
        must_keep = []
        tool_chain_groups = {}  # Group related tool calls and their results

        # Always keep system messages (should be first)
        system_indices = []
        for i, msg in enumerate(self.history):
            if msg.get("role") == "system":
                system_indices.append(i)
                must_keep.append(i)

        # Identify the most recent complete exchange (user question + assistant response)
        latest_exchange = []
        # Start from the end and work backward to find the last complete exchange
        for i in range(len(self.history) - 1, -1, -1):
            msg = self.history[i]
            if msg.get("role") == "assistant" and not latest_exchange:
                latest_exchange.append(i)
            elif msg.get("role") == "user" and latest_exchange:
                latest_exchange.append(i)
                break

        # Add the latest exchange to must-keep
        must_keep.extend(latest_exchange)

        # Identify tool chains - track which messages belong to the same tool flow
        tool_id_to_chain = {}
        for i, msg in enumerate(self.history):
            # For assistant messages with tool calls
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tool_call in msg.get("tool_calls"):
                    tool_id = tool_call.get("id")
                    if tool_id:
                        tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i]

            # For tool response messages
            elif msg.get("role") == "tool" and msg.get("tool_call_id"):
                tool_id = msg.get("tool_call_id")
                tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i]

            # For Anthropic format with tool use
            elif msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
                for block in msg.get("content", []):
                    if isinstance(block, dict) and block.get("type") == "tool_use":
                        tool_id = block.get("id")
                        if tool_id:
                            tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i]

            # For Anthropic tool result messages
            elif msg.get("role") == "user" and isinstance(msg.get("content"), list):
                for block in msg.get("content", []):
                    if isinstance(block, dict) and block.get("type") == "tool_result":
                        tool_id = block.get("tool_use_id")
                        if tool_id:
                            tool_id_to_chain[tool_id] = tool_id_to_chain.get(tool_id, []) + [i]

        # Group together all indices for each tool chain
        for tool_id, indices in tool_id_to_chain.items():
            chain_key = f"tool_{min(indices)}"  # Group by the earliest message
            if chain_key not in tool_chain_groups:
                tool_chain_groups[chain_key] = set()
            tool_chain_groups[chain_key].update(indices)

        # Second pass: calculate tokens for each message
        message_tokens = []
        for i, msg in enumerate(self.history):
            # Count tokens for this individual message
            tokens = self._count_tokens([msg])
            message_tokens.append((i, tokens))

        # Keep the messages identified as must-keep
        keep_indices = set(must_keep)

        # Calculate the tokens we've committed to keeping
        keep_tokens = sum(tokens for i, tokens in message_tokens if i in keep_indices)

        # Check if we've already exceeded the target with just the must-keep messages
        if keep_tokens > self.context_length:
            # The essential messages alone exceed the context;
            # drop older messages until we're under the limit
            all_indices = sorted(keep_indices)

            # Start dropping the oldest messages, but NEVER drop system messages
            for idx in all_indices:
                if idx not in system_indices:
                    keep_indices.remove(idx)
                    keep_tokens -= message_tokens[idx][1]
                    if keep_tokens <= target_tokens:
                        break

            # If we've removed everything but system messages and are still over the limit
            if keep_tokens > self.context_length:
                self._log(f"[TRIM] Critical failure: even with minimal context ({keep_tokens} tokens), we exceed the limit", level="error")
                # Keep only system messages, if any
                keep_indices = set(system_indices)
                return True  # Context exceeded completely

        # Third pass: keep the most important tool chains intact
        available_tokens = target_tokens - keep_tokens
        # Sort tool chains by recency: the chain whose earliest message is newest comes first
        sorted_chains = sorted(tool_chain_groups.items(), key=lambda item: min(item[1]), reverse=True)

        for chain_key, indices in sorted_chains:
            # Skip if we've already decided to keep all messages in this chain
            if indices.issubset(keep_indices):
                continue

            # Calculate how many tokens this chain would add
            chain_tokens = sum(tokens for i, tokens in message_tokens if i in indices and i not in keep_indices)

            # If we can fit the entire chain, keep it
            if chain_tokens <= available_tokens:
                keep_indices.update(indices)
                available_tokens -= chain_tokens
            # Otherwise skip it; partial chains could be kept in the future

        # Fourth pass: fill in with as many remaining messages as possible, prioritizing recency
        # Get the remaining messages sorted by recency (newest first)
        remaining_indices = [(i, tokens) for i, tokens in message_tokens if i not in keep_indices]
        remaining_indices.sort(reverse=True)  # Sort newest first

        for i, tokens in remaining_indices:
            if tokens <= available_tokens:
                keep_indices.add(i)
                available_tokens -= tokens

        # Final message reconstruction
        self._log(f"[TRIM] Keeping {len(keep_indices)}/{len(self.history)} messages, estimated {target_tokens - available_tokens} tokens", level="info")

        # Create a new history with just the kept messages, preserving order
        new_history = [self.history[i] for i in sorted(keep_indices)]
        self.history = new_history

        # Update session_history to match the pruned history
        if hasattr(self, "session_history"):
            # Map between history items and session_history
            session_to_keep = []

            # For each session history message, check whether it corresponds to a kept message
            for session_msg in self.session_history:
                # Keep system messages
                if session_msg.get("role") == "system":
                    session_to_keep.append(session_msg)
                    continue

                # Try to match based on available IDs or content
                msg_id = session_msg.get("id")

                # For tool messages, check tool_info.id against tool_call_id
                if "metadata" in session_msg and "tool_info" in session_msg["metadata"]:
                    tool_id = session_msg["metadata"]["tool_info"].get("id")

                    # Check whether this tool_id is still in the kept history
                    for history_msg in new_history:
                        # Check standard ids
                        history_tool_id = None

                        # Check OpenAI format
                        if history_msg.get("role") == "tool":
                            history_tool_id = history_msg.get("tool_call_id")
                        elif history_msg.get("role") == "assistant" and history_msg.get("tool_calls"):
                            for call in history_msg.get("tool_calls", []):
                                if call.get("id") == tool_id:
                                    history_tool_id = call.get("id")
                                    break

                        # Check Anthropic format
                        elif isinstance(history_msg.get("content"), list):
                            for block in history_msg.get("content", []):
                                if isinstance(block, dict):
                                    if block.get("type") == "tool_use" and block.get("id") == tool_id:
                                        history_tool_id = block.get("id")
                                        break
                                    elif block.get("type") == "tool_result" and block.get("tool_use_id") == tool_id:
                                        history_tool_id = block.get("tool_use_id")
                                        break

                        if history_tool_id == tool_id:
                            session_to_keep.append(session_msg)
                            break

                # For regular messages, try content matching as a fallback
                else:
                    content_match = False
                    if isinstance(session_msg.get("content"), str) and session_msg.get("content"):
                        for history_msg in new_history:
                            if history_msg.get("role") == session_msg.get("role") and history_msg.get("content") == session_msg.get("content"):
                                content_match = True
                                break

                    if content_match:
                        session_to_keep.append(session_msg)

            # Update session_history with the kept messages
            self.session_history = session_to_keep

        # Re-normalize to ensure consistency
        self._normalizer(force=True)

        # Verify our final token count
        final_token_count = self._count_tokens(self.history)
        self._log(f"[TRIM] Final history has {len(self.history)} messages, {final_token_count} tokens", level="info")

        # Return whether we've completely exceeded the context
        return final_token_count > self.context_length or len(self.history) == len(system_indices)
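    # Illustrative sketch (not part of the class): a long-running loop can call
    # _cycle_messages() before each request to keep the history inside the
    # context window. `ai` is a hypothetical Interactor instance.
    #
    #     exceeded = ai._cycle_messages()
    #     if exceeded:
    #         ai.session_reset()  # recover by starting a fresh conversation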
    def messages_add(
        self,
        role: str,
        content: Any,
        tool_info: Optional[Dict] = None,
        normalize: bool = True
    ) -> str:
        """
        Add a message to the standardized session_history and then update the SDK-specific history.

        This method is the central point for all message additions to the conversation.

        Args:
            role: The role of the message ("user", "assistant", "system", "tool")
            content: The message content (text or structured)
            tool_info: Optional tool-related metadata
            normalize: Whether to normalize history after adding this message

        Returns:
            str: Unique ID of the added message
        """
        # Generate a unique message ID
        message_id = str(uuid.uuid4())

        # Create the standardized message for session_history
        timestamp = datetime.now(timezone.utc).isoformat()

        # Store system messages directly
        if role == "system":
            self.system = content

        # Create the standard-format message
        standard_message = {
            "role": role,
            "content": content,
            "id": message_id,
            "timestamp": timestamp,
            "metadata": {
                "sdk": self.sdk
            }
        }

        # Add tool info if provided
        if tool_info:
            standard_message["metadata"]["tool_info"] = tool_info

        # Add to session_history
        if not hasattr(self, "session_history"):
            self.session_history = []

        self.session_history.append(standard_message)

        # Save to the persistent session if enabled
        if self.session_enabled and self.session_id:
            # Convert the standard message to a session-compatible format
            session_msg = {
                "role": role,
                "content": content,
                "id": message_id,
                "timestamp": timestamp
            }

            # Add tool-related fields if present
            if tool_info:
                for key, value in tool_info.items():
                    session_msg[key] = value

            # Store in the session
            self.session.msg_insert(self.session_id, session_msg)

        # Update the SDK-specific format in self.history by running the normalizer
        if normalize:
            # We only need to normalize the most recent message for efficiency,
            # so pass a flag indicating we're just normalizing a new message
            self._normalizer(force=False, new_message_only=True)

        # Log the added message
        self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...")

        return message_id
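    # Illustrative sketch (not part of the class): recording a tool round-trip
    # by hand. The tool name and id below are made up for the example.
    #
    #     call_id = "call_abc123"  # hypothetical tool-call id
    #     ai.messages_add(role="assistant", content="",
    #                     tool_info={"id": call_id, "name": "get_time", "arguments": {}})
    #     ai.messages_add(role="tool", content={"time": "12:00"},
    #                     tool_info={"id": call_id, "name": "get_time", "arguments": {}})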
    def messages_system(self, prompt: str):
        """Set or retrieve the current system prompt.

        This method manages the system prompt that guides the AI's behavior.
        It can be used both to set a new system prompt and to retrieve the
        current one. When setting a new prompt, it updates the system message
        in the conversation history and persists it to the session if enabled.

        Args:
            prompt (str): The new system prompt to set. If empty or None,
                returns the current system prompt without changes.

        Returns:
            str: The current system prompt after any updates
        """
        if not isinstance(prompt, str) or not prompt:
            return self.system

        # If the prompt hasn't changed, don't do anything
        if self.system == prompt:
            return self.system

        # Update the system prompt
        old_system = self.system
        self.system = prompt

        # For OpenAI, update or insert the system message in history
        if self.sdk == "openai":
            # Check if there's already a system message
            system_index = next((i for i, msg in enumerate(self.history)
                                 if msg.get("role") == "system"), None)

            if system_index is not None:
                # Update the existing system message
                self.history[system_index]["content"] = prompt
            else:
                # Insert a new system message at the beginning
                self.history.insert(0, {"role": "system", "content": prompt})

        # For Anthropic, the system message is not part of history; it is passed separately on API calls

        # Log to the session only if the prompt actually changed
        if self.session_enabled and self.session_id and old_system != prompt:
            self.session.msg_insert(self.session_id, {"role": "system", "content": prompt})

        return self.system


    def messages(self) -> list:
        """Return the full session messages (persisted or in-memory)."""
        if self.session_enabled and self.session_id:
            return self.session.load_full(self.session_id).get("messages", [])
        return self.session_history


    def messages_length(self) -> int:
        """Calculate the total token count for the message history."""
        if not self.encoding:
            return 0

        total_tokens = 0
        for message in self.history:
            # Content may be a string (OpenAI) or a list of blocks (Anthropic);
            # only plain-text content is counted here.
            if isinstance(message.get("content"), str) and message["content"]:
                total_tokens += len(self.encoding.encode(message["content"]))
            if message.get("tool_calls"):
                for tool_call in message["tool_calls"]:
                    if tool_call.get("function"):
                        total_tokens += len(self.encoding.encode(tool_call["function"].get("name", "")))
                        total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", "")))
        return total_tokens
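    # Illustrative sketch (not part of the class): switching the system prompt
    # mid-conversation. Passing an empty value reads it back unchanged.
    #
    #     ai.messages_system("You are a terse code reviewer.")
    #     current = ai.messages_system("")  # returns the active prompt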
    def session_load(self, session_id: Optional[str]):
        """Load and normalize messages for a specific session.

        This method loads a conversation session from persistent storage and
        normalizes the messages to the current SDK format. It handles system
        messages, tool calls, and maintains message ordering. If loading fails,
        it resets to an empty session with the default system prompt.

        Args:
            session_id (Optional[str]): The ID of the session to load.
                If None, resets to in-memory mode.

        Note:
            This method will update both session_history and history to match
            the loaded session's state. It also ensures proper SDK-specific
            message formatting.
        """
        self.session_id = session_id
        self._last_session_id = session_id

        if self.session_enabled and session_id:
            try:
                # Load the raw session data
                session_data = self.session.load_full(session_id)
                messages = session_data.get("messages", [])

                # Convert the session format to our standard format
                self.session_history = []

                # Track the most recent system message
                latest_system_msg = None

                for msg in messages:
                    # Extract fields
                    role = msg.get("role", "user")
                    content = msg.get("content", "")
                    msg_id = msg.get("id", str(uuid.uuid4()))
                    timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat())

                    # If this is a system message, track it but don't add it to session_history yet
                    if role == "system":
                        if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]:
                            latest_system_msg = {
                                "role": role,
                                "content": content,
                                "id": msg_id,
                                "timestamp": timestamp,
                                "metadata": {"sdk": self.sdk}
                            }
                        continue

                    # Build tool_info if present
                    tool_info = None
                    if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]):
                        tool_info = {
                            "id": msg.get("tool_use_id") or msg.get("tool_call_id"),
                            "name": msg.get("name", "unknown_tool"),
                            "arguments": msg.get("arguments", {})
                        }

                    # Create the standard message
                    standard_msg = {
                        "role": role,
                        "content": content,
                        "id": msg_id,
                        "timestamp": timestamp,
                        "metadata": {
                            "sdk": self.sdk
                        }
                    }

                    if tool_info:
                        standard_msg["metadata"]["tool_info"] = tool_info

                    self.session_history.append(standard_msg)

                # If we found a system message, update the system property and add it to history
                if latest_system_msg:
                    self.system = latest_system_msg["content"]
                    # Insert at the beginning of session_history
                    self.session_history.insert(0, latest_system_msg)
                else:
                    # If no system message was found, add the current system message
                    self.messages_add(role="system", content=self.system)

                # Normalize to the current SDK format
                self._normalizer(force=True)

                self._log(f"[SESSION] Switched to session '{session_id}'")
            except Exception as e:
                self.logger.error(f"Failed to load session '{session_id}': {e}")
                self.session_reset()
        else:
            # Reset to an empty state with the system message
            self.session_reset()
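    # Illustrative sketch (not part of the class): persistent sessions require
    # session_enabled=True at construction. The session id and path below are
    # hypothetical.
    #
    #     ai = Interactor(session_enabled=True, session_path="~/.ai_sessions")
    #     ai.session_load("support-ticket-42")  # load, or fall back to a reset
    #     ai.session_load(None)                 # back to in-memory mode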
    def session_reset(self):
        """Reset the current session state and reinitialize with the default system prompt.

        This method performs a complete reset of the conversation state:
        1. Clears all message history
        2. Disables session ID tracking
        3. Returns to in-memory mode
        4. Reinitializes with the default system prompt

        The reset is useful for starting fresh conversations or recovering
        from error states. It maintains the basic system configuration while
        clearing all conversation context.
        """
        self.session_id = None
        self._last_session_id = None

        # Clear histories
        self.session_history = []
        self.history = []

        # Reapply the system message
        if hasattr(self, "system") and self.system:
            # Add to session_history
            self.messages_add(role="system", content=self.system)
        else:
            # Ensure we have a default system message
            self.system = "You are a helpful Assistant."
            self.messages_add(role="system", content=self.system)

        self._log("[SESSION] Reset to in-memory mode")


    def _normalizer(self, force=False, new_message_only=False):
        """Central normalization function for message format conversion.

        This method transforms the standardized session_history into the
        SDK-specific format needed in self.history. It handles different
        message types (system, user, assistant, tool) and their various
        formats across different SDKs.

        Args:
            force (bool): If True, always normalize even if the SDK hasn't changed.
                Default is False, which only normalizes on SDK change.
            new_message_only (bool): If True, only normalize the most recent message
                for efficiency when adding single messages.

        Note:
            This method is the central point for message format conversion and
            ensures consistency between session storage and API communication.
        """
        # Skip normalization if the SDK hasn't changed and force is False
        if not force and hasattr(self, '_last_sdk') and self._last_sdk == self.sdk:
            # If we only need to normalize the most recent message
            if new_message_only and self.session_history:
                # Get the most recent message from session_history
                recent_msg = self.session_history[-1]

                # Apply SDK-specific normalization for just this message
                if self.sdk == "openai":
                    self._openai_normalize_message(recent_msg)
                elif self.sdk == "anthropic":
                    self._anthropic_normalize_message(recent_msg)
                else:
                    # Generic handler for unknown SDKs
                    self._generic_normalize_message(recent_msg)

            return

        # Record the current SDK to detect future changes
        self._last_sdk = self.sdk

        # For full normalization, clear the current history and rebuild it
        self.history = []

        # Call the appropriate SDK-specific normalizer
        if self.sdk == "openai":
            self._openai_normalizer()
        elif self.sdk == "anthropic":
            self._anthropic_normalizer()
        else:
            self.logger.warning(f"No normalizer available for SDK: {self.sdk}")
            # Fall back to a simple conversion for unknown SDKs
            for msg in self.session_history:
                self._generic_normalize_message(msg)
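    # Illustrative sketch (not part of the class): after switching models across
    # SDKs, a forced pass rebuilds self.history in the new wire format. This
    # assumes the private _setup_client(model, base_url, api_key) helper updates
    # self.sdk; the model name is hypothetical.
    #
    #     ai._setup_client("anthropic:claude-3-5-sonnet-latest", None, None)
    #     ai._normalizer(force=True)  # history is now in the Anthropic format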
    def _openai_normalizer(self):
        """Convert the standardized session_history to OpenAI-compatible format.

        This method transforms the internal message format into the structure
        required by the OpenAI API. It handles:
        - System messages at the start of history
        - User messages with plain text
        - Assistant messages with optional tool calls
        - Tool response messages with tool_call_id

        The resulting format matches OpenAI's chat completion API requirements
        for both regular messages and function calling.
        """
        # For OpenAI, we need to include the system message in the history
        # and convert tool calls/results to OpenAI format

        # Start with an empty history
        self.history = []

        # First, add the current system message at position 0
        self.history.append({
            "role": "system",
            "content": self.system
        })

        # Process all non-system messages
        for msg in self.session_history:
            if msg["role"] == "system":
                continue  # Skip system messages, already handled

            # Handle the different message types
            if msg["role"] == "user":
                # User messages are straightforward
                self.history.append({
                    "role": "user",
                    "content": msg["content"]
                })

            elif msg["role"] == "assistant":
                # For assistant messages with tool calls
                if "metadata" in msg and msg["metadata"].get("tool_info"):
                    # This is an assistant message with tool calls
                    tool_info = msg["metadata"]["tool_info"]

                    # Create an OpenAI assistant message with tool calls
                    assistant_msg = {
                        "role": "assistant",
                        "content": msg["content"] if isinstance(msg["content"], str) else "",
                        "tool_calls": [{
                            "id": tool_info["id"],
                            "type": "function",
                            "function": {
                                "name": tool_info["name"],
                                "arguments": json.dumps(tool_info["arguments"]) if isinstance(tool_info["arguments"], dict) else tool_info["arguments"]
                            }
                        }]
                    }
                    self.history.append(assistant_msg)
                else:
                    # Regular assistant message
                    self.history.append({
                        "role": "assistant",
                        "content": msg["content"]
                    })

            elif msg["role"] == "tool":
                # Tool response messages
                if "metadata" in msg and "tool_info" in msg["metadata"]:
                    tool_msg = {
                        "role": "tool",
                        "tool_call_id": msg["metadata"]["tool_info"]["id"],
                        "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else msg["content"]
                    }
                    self.history.append(tool_msg)
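    # Illustrative sketch (not part of the class): the history shape produced
    # above for a tool round-trip. Ids and names are hypothetical.
    #
    #     [
    #         {"role": "system", "content": "You are a helpful Assistant."},
    #         {"role": "assistant", "content": "",
    #          "tool_calls": [{"id": "call_1", "type": "function",
    #                          "function": {"name": "get_time", "arguments": "{}"}}]},
    #         {"role": "tool", "tool_call_id": "call_1", "content": '{"time": "12:00"}'},
    #     ]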
    def _anthropic_normalizer(self):
        """Convert the standardized session_history to Anthropic-compatible format.

        This method transforms the internal message format into the structure
        required by the Anthropic API. It handles:
        - System messages (stored separately, not in history)
        - User messages with optional tool results
        - Assistant messages with optional tool use
        - Content blocks for structured responses

        The resulting format matches Anthropic's message API requirements
        for both regular messages and tool use.
        """
        # For Anthropic, we don't include the system message in the history,
        # but we need to handle content blocks for tool use/results

        # Start with an empty history
        self.history = []

        # Process all non-system messages
        for msg in self.session_history:
            if msg["role"] == "system":
                # Update the system prompt if this is the most recent system message
                # (only apply the most recent system message if we have multiple)
                if msg == self.session_history[-1] or all(m["role"] != "system" for m in self.session_history[self.session_history.index(msg) + 1:]):
                    self.system = msg["content"]
                continue  # Skip system messages in history

            # Handle the different message types
            if msg["role"] == "user":
                # User messages - check whether this one contains a tool result
                if "metadata" in msg and "tool_info" in msg["metadata"] and msg["metadata"]["tool_info"].get("result"):
                    # This is a tool result message
                    tool_info = msg["metadata"]["tool_info"]

                    # Create the Anthropic tool result format
                    tool_result_msg = {
                        "role": "user",
                        "content": [{
                            "type": "tool_result",
                            "tool_use_id": tool_info["id"],
                            "content": json.dumps(tool_info["result"]) if isinstance(tool_info["result"], (dict, list)) else str(tool_info["result"])
                        }]
                    }
                    self.history.append(tool_result_msg)
                else:
                    # Regular user message
                    self.history.append({
                        "role": "user",
                        "content": msg["content"]
                    })

            elif msg["role"] == "assistant":
                # For assistant messages, check for tool use
                if "metadata" in msg and "tool_info" in msg["metadata"]:
                    # This is an assistant message with tool use
                    tool_info = msg["metadata"]["tool_info"]

                    # Build the content blocks
                    content_blocks = []

                    # Add text content if present
                    if msg["content"]:
                        content_blocks.append({
                            "type": "text",
                            "text": msg["content"] if isinstance(msg["content"], str) else ""
                        })

                    # Add the tool use block
                    content_blocks.append({
                        "type": "tool_use",
                        "id": tool_info["id"],
                        "name": tool_info["name"],
                        "input": tool_info["arguments"] if isinstance(tool_info["arguments"], dict) else json.loads(tool_info["arguments"])
                    })

                    # Create the Anthropic assistant message with tool use
                    self.history.append({
                        "role": "assistant",
                        "content": content_blocks
                    })
                else:
                    # Regular assistant message
                    self.history.append({
                        "role": "assistant",
                        "content": msg["content"]
                    })

            elif msg["role"] == "tool":
                # Tool messages in the standard format get converted to user messages with a tool_result
                if "metadata" in msg and "tool_info" in msg["metadata"]:
                    tool_info = msg["metadata"]["tool_info"]

                    # Create the Anthropic tool result message
                    tool_result_msg = {
                        "role": "user",
                        "content": [{
                            "type": "tool_result",
                            "tool_use_id": tool_info["id"],
                            "content": json.dumps(msg["content"]) if isinstance(msg["content"], (dict, list)) else str(msg["content"])
                        }]
                    }
                    self.history.append(tool_result_msg)
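    # Illustrative sketch (not part of the class): the same tool round-trip in
    # the Anthropic block format produced above. Ids and names are hypothetical.
    #
    #     [
    #         {"role": "assistant", "content": [
    #             {"type": "tool_use", "id": "toolu_1", "name": "get_time", "input": {}}]},
    #         {"role": "user", "content": [
    #             {"type": "tool_result", "tool_use_id": "toolu_1",
    #              "content": '{"time": "12:00"}'}]},
    #     ]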
    def _openai_normalize_message(self, msg):
        """Normalize a single message to OpenAI format and add it to history."""
        role = msg.get("role")
        content = msg.get("content")

        if role == "system":
            # Check whether we already have a system message in history
            system_index = next((i for i, m in enumerate(self.history)
                                 if m.get("role") == "system"), None)
            if system_index is not None:
                # Update the existing system message
                self.history[system_index]["content"] = content
            else:
                # Insert a new system message at the beginning
                self.history.insert(0, {
                    "role": "system",
                    "content": content
                })
            # Update the system property
            self.system = content

        elif role == "user":
            self.history.append({
                "role": "user",
                "content": content
            })

        elif role == "assistant":
            # For assistant messages, handle potential tool calls
            if "metadata" in msg and msg["metadata"].get("tool_info"):
                # This is an assistant message with tool calls
                tool_info = msg["metadata"]["tool_info"]

                # Create an OpenAI assistant message with tool calls
                arguments = tool_info.get("arguments", {})
                try:
                    arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else arguments
                except TypeError:
                    arguments_str = str(arguments)

                assistant_msg = {
                    "role": "assistant",
                    "content": content if isinstance(content, str) else "",
                    "tool_calls": [{
                        "id": tool_info["id"],
                        "type": "function",
                        "function": {
                            "name": tool_info["name"],
                            "arguments": arguments_str
                        }
                    }]
                }
                self.history.append(assistant_msg)
            else:
                # Regular assistant message
                self.history.append({
                    "role": "assistant",
                    "content": content
                })

        elif role == "tool":
            # Tool response messages
            if "metadata" in msg and "tool_info" in msg["metadata"]:
                tool_info = msg["metadata"]["tool_info"]
                tool_msg = {
                    "role": "tool",
                    "tool_call_id": tool_info["id"],
                    "content": json.dumps(content) if isinstance(content, (dict, list)) else str(content)
                }
                self.history.append(tool_msg)


    def _anthropic_normalize_message(self, msg):
        """Normalize a single message to Anthropic format and add it to history."""
        role = msg.get("role")
        content = msg.get("content")

        if role == "system":
            # Store the system prompt separately; it is not part of history for Anthropic
            self.system = content

        elif role == "user":
            # User messages - check whether this one contains a tool result
            if "metadata" in msg and "tool_info" in msg["metadata"]:
                tool_info = msg["metadata"]["tool_info"]
                # Check for a result, or directly use the content
                result_content = tool_info.get("result", content)

                # Create the Anthropic tool result format
                try:
                    result_str = json.dumps(result_content) if isinstance(result_content, (dict, list)) else str(result_content)
                except TypeError:
                    result_str = str(result_content)

                tool_result_msg = {
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": tool_info["id"],
                        "content": result_str
                    }]
                }
                self.history.append(tool_result_msg)
            else:
                # Regular user message
                self.history.append({
                    "role": "user",
                    "content": content
                })

        elif role == "assistant":
            # For assistant messages, check for tool use
            if "metadata" in msg and "tool_info" in msg["metadata"]:
                # This is an assistant message with tool use
                tool_info = msg["metadata"]["tool_info"]

                # Build the content blocks
                content_blocks = []

                # Add text content if present
                if content:
                    content_blocks.append({
                        "type": "text",
                        "text": content if isinstance(content, str) else ""
                    })

                # Add the tool use block - safely convert the arguments,
                # parsing them to ensure we end up with a dictionary
                try:
                    if isinstance(tool_info["arguments"], str):
                        try:
                            args_dict = json.loads(tool_info["arguments"])
                        except json.JSONDecodeError:
                            args_dict = {"text": tool_info["arguments"]}
                    else:
                        args_dict = tool_info["arguments"]
                except Exception:
                    args_dict = {"error": "Failed to parse arguments"}

                content_blocks.append({
                    "type": "tool_use",
                    "id": tool_info["id"],
                    "name": tool_info["name"],
                    "input": args_dict
                })

                # Create the Anthropic assistant message with tool use
                self.history.append({
                    "role": "assistant",
                    "content": content_blocks
                })
            else:
                # Regular assistant message
                self.history.append({
                    "role": "assistant",
                    "content": content
                })

        elif role == "tool":
            # Tool messages in the standard format get converted to user messages with a tool_result
            if "metadata" in msg and "tool_info" in msg["metadata"]:
                tool_info = msg["metadata"]["tool_info"]

                try:
                    result_str = json.dumps(content) if isinstance(content, (dict, list)) else str(content)
                except TypeError:
                    result_str = str(content)

                # Create the Anthropic tool result message
                tool_result_msg = {
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": tool_info["id"],
                        "content": result_str
                    }]
                }
                self.history.append(tool_result_msg)


    def _generic_normalize_message(self, msg):
        """Generic normalizer for unknown SDKs.

        This method provides basic message normalization for SDKs that
        don't have specific handling. It performs minimal conversion to
        ensure the basic message structure is maintained.

        Args:
            msg (dict): The message to normalize, containing at minimum:
                - role: Message role (user/assistant/system)
                - content: Message content

        Note:
            This is a fallback method and should be overridden with specific
            SDK implementations when possible.
        """
        role = msg.get("role")
        content = msg.get("content")

        if role in ["user", "assistant", "system"]:
            self.history.append({
                "role": role,
                "content": content
            })
    def track_token_usage(self):
        """Track and return token usage across the conversation history.

        This method maintains a history of token usage measurements and provides
        current usage statistics. It tracks:
        - Current token count
        - Context length limit
        - Usage percentage
        - Historical measurements
        - Current provider and model info

        Returns:
            dict: Dictionary containing:
                - current: Current token count
                - limit: Maximum context length
                - percentage: Usage as a percentage of the limit
                - history: Last 10 measurements with timestamps
                - provider: Current provider name
                - model: Current model name

        Note:
            The history is limited to the last 100 measurements to prevent
            unbounded memory growth.
        """
        if not hasattr(self, "_token_history"):
            self._token_history = []

        # Count the current tokens
        current_count = self._count_tokens(self.history)

        # Add to the history
        timestamp = datetime.now(timezone.utc).isoformat()
        self._token_history.append({
            "timestamp": timestamp,
            "count": current_count,
            "limit": self.context_length,
            "provider": self.provider,
            "model": self.model
        })

        # Keep only the last 100 measurements to avoid unlimited growth
        if len(self._token_history) > 100:
            self._token_history = self._token_history[-100:]

        # Return the current tracking info
        return {
            "current": current_count,
            "limit": self.context_length,
            "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0,
            "history": self._token_history[-10:],  # Return the last 10 measurements
            "provider": self.provider,
            "model": self.model
        }


    def get_message_token_breakdown(self):
        """Analyze token usage by message type and provide a detailed breakdown.

        This method performs a detailed analysis of token usage across the
        conversation history, breaking down usage by:
        - Message role (system, user, assistant, tool)
        - Content type (text, tool calls, tool results)
        - Individual message statistics

        Returns:
            dict: Token usage breakdown containing:
                - total: Total tokens used
                - by_role: Tokens used by each role
                - by_type: Tokens used by content type
                - messages: List of individual message stats including:
                    - index: Message position
                    - role: Message role
                    - tokens: Tokens used
                    - has_tools: Whether the message contains tool calls

        Note:
            This analysis is useful for understanding token usage patterns
            and optimizing conversation context.
        """
        breakdown = {
            "total": 0,
            "by_role": {
                "system": 0,
                "user": 0,
                "assistant": 0,
                "tool": 0
            },
            "by_type": {
                "text": 0,
                "tool_calls": 0,
                "tool_results": 0
            },
            "messages": []
        }

        # Analyze each message
        for i, msg in enumerate(self.history):
            msg_tokens = self._count_tokens([msg])
            role = msg.get("role", "unknown")

            # Track by role
            if role in breakdown["by_role"]:
                breakdown["by_role"][role] += msg_tokens

            # Track by content type
            if role == "assistant" and msg.get("tool_calls"):
                breakdown["by_type"]["tool_calls"] += msg_tokens
            elif role == "tool":
                breakdown["by_type"]["tool_results"] += msg_tokens
            else:
                breakdown["by_type"]["text"] += msg_tokens

            # Add the individual message data
            breakdown["messages"].append({
                "index": i,
                "role": role,
                "tokens": msg_tokens,
                "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or
                                  (isinstance(msg.get("content"), list) and
                                   any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"]
                                       for c in msg.get("content", []))))
            })

            # Update the total
            breakdown["total"] += msg_tokens

        return breakdown
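    # Illustrative sketch (not part of the class): polling usage before a call
    # and inspecting where the tokens went.
    #
    #     usage = ai.track_token_usage()
    #     print(f"{usage['current']}/{usage['limit']} tokens ({usage['percentage']}%)")
    #     breakdown = ai.get_message_token_breakdown()
    #     print(breakdown["by_role"], breakdown["by_type"])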
    def add_function(
        self,
        external_callable: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        override: bool = False,
        disabled: bool = False,
        schema_extensions: Optional[Dict[str, Any]] = None
    ):
        """
        Register a function for LLM tool calling with full type hints and metadata.

        Args:
            external_callable (Callable): The function to register.
            name (Optional[str]): Optional custom name. Defaults to the function's __name__.
            description (Optional[str]): Optional custom description. Defaults to the first line of the docstring.
            override (bool): If True, replaces an existing tool with the same name.
            disabled (bool): If True, registers the function in a disabled state.
            schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to
                schema extensions that override or add to the auto-generated schema.

        Raises:
            ValueError: If the callable is invalid or a duplicate name is found without override.

        Example:
            interactor.add_function(
                my_tool,
                override=True,
                disabled=False,
                schema_extensions={
                    "param1": {"minimum": 0, "maximum": 100},
                    "param2": {"format": "email"}
                }
            )
        """
        def _python_type_to_schema(ptype: Any) -> dict:
            """Convert a Python type annotation to OpenAI-compatible JSON Schema."""
            # Handle the None case
            if ptype is None:
                return {"type": "null"}

            # Get the origin and arguments of the type
            origin = get_origin(ptype)
            args = get_args(ptype)

            # Handle Union types (including Optional)
            if origin is Union:
                # Check for Optional (a Union with None)
                none_type = type(None)
                if none_type in args:
                    non_none = [a for a in args if a is not none_type]
                    if len(non_none) == 1:
                        inner = _python_type_to_schema(non_none[0])
                        inner_copy = inner.copy()
                        inner_copy["nullable"] = True
                        return inner_copy
                    # Multiple types excluding None
                    types = [_python_type_to_schema(a) for a in non_none]
                    return {"anyOf": types, "nullable": True}
                # Regular Union without None
                return {"anyOf": [_python_type_to_schema(a) for a in args]}

            # Handle List and similar container types
            if origin in (list, List):
                item_type = args[0] if args else Any
                if item_type is Any:
                    return {"type": "array"}
                return {"type": "array", "items": _python_type_to_schema(item_type)}

            # Handle Dict types with typing info
            if origin in (dict, Dict):
                if not args or len(args) != 2:
                    return {"type": "object"}

                key_type, val_type = args
                # We can only really use val_type in JSON Schema
                if val_type is not Any and val_type is not object:
                    return {
                        "type": "object",
                        "additionalProperties": _python_type_to_schema(val_type)
                    }
                return {"type": "object"}

            # Handle Literal types for enums
            if origin is Literal:
                values = args
                # Try to determine the type from the values
                if all(isinstance(v, str) for v in values):
                    return {"type": "string", "enum": list(values)}
                elif all(isinstance(v, bool) for v in values):
                    return {"type": "boolean", "enum": list(values)}
                elif all(isinstance(v, (int, float)) for v in values):
                    return {"type": "number", "enum": list(values)}
                else:
                    # Mixed types, use anyOf
                    return {"anyOf": [{"type": _get_json_type(v), "enum": [v]} for v in values]}

            # Handle basic types
            if ptype is str:
                return {"type": "string"}
            if ptype is int:
                return {"type": "integer"}
            if ptype is float:
                return {"type": "number"}
            if ptype is bool:
                return {"type": "boolean"}

            # Handle common datetime types
            if ptype is datetime:
                return {"type": "string", "format": "date-time"}
            if ptype is date:
                return {"type": "string", "format": "date"}

            # Handle UUID
            if ptype is uuid.UUID:
                return {"type": "string", "format": "uuid"}

            # Default to object for any other types
            return {"type": "object"}

        def _get_json_type(value):
            """Get the JSON Schema type name for a Python value.

            This helper function maps Python types to their corresponding
            JSON Schema type names. It handles basic types and provides
            sensible defaults for complex types.

            Args:
                value: The Python value to get the JSON type for

            Returns:
                str: The JSON Schema type name ('string', 'boolean', 'number',
                    'array', or 'object'; 'object' is the default)
            """
            if isinstance(value, str):
                return "string"
            elif isinstance(value, bool):
                return "boolean"
            elif isinstance(value, (int, float)):
                return "number"
            elif isinstance(value, list):
                return "array"
            elif isinstance(value, dict):
                return "object"
            else:
                return "object"  # Default

        def _parse_param_docs(docstring: str) -> dict:
            """Extract parameter descriptions from a docstring."""
            if not docstring:
                return {}

            lines = docstring.splitlines()
            param_docs = {}
            current_param = None
            in_params = False

            # Regular expressions for finding the parameter section and param lines
            param_section_re = re.compile(r"^(Args|Parameters):\s*$")
            param_line_re = re.compile(r"^\s{4}(\w+)\s*(?:\([^\)]*\))?:\s*(.*)")

            for line in lines:
                # Check whether we're entering the parameters section
                if param_section_re.match(line.strip()):
                    in_params = True
                    continue

                if in_params:
                    # Skip empty lines
                    if not line.strip():
                        continue

                    # Check for a parameter definition line
                    match = param_line_re.match(line)
                    if match:
                        current_param = match.group(1)
                        param_docs[current_param] = match.group(2).strip()
                    # Check for a continuation of a parameter description
                    elif current_param and line.startswith(" " * 8):
                        param_docs[current_param] += " " + line.strip()
                    # If we see a line that doesn't match our patterns, we're out of the params section
                    else:
                        current_param = None

            return param_docs

        # Start of the main function logic

        # Skip if tools are disabled
        if not self.tools_enabled:
            return

        # Validate the input callable
        if not external_callable:
            raise ValueError("A valid external callable must be provided.")

        # Set the function name, either from the parameter or from the callable's __name__
        function_name = name or external_callable.__name__

        # Try to get the docstring and extract a description
        try:
            docstring = inspect.getdoc(external_callable)
            description = description or (docstring.split("\n")[0].strip() if docstring else "No description provided.")
        except Exception as e:
            self._log(f"[TOOL] Warning: Could not extract docstring from {function_name}: {e}", level="warning")
            docstring = ""
            description = description or "No description provided."

        # Extract parameter documentation from the docstring
        param_docs = _parse_param_docs(docstring)

        # Handle conflicts with existing functions
        if override:
            self.delete_function(function_name)
        elif any(t["function"]["name"] == function_name for t in self.tools):
            raise ValueError(f"Function '{function_name}' is already registered. Use override=True to replace.")

        # Try to get the function signature for parameter info
        try:
            signature = inspect.signature(external_callable)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Cannot inspect callable '{function_name}': {e}")

        # Process the parameters to build the schema
        properties = {}
        required = []

        for param_name, param in signature.parameters.items():
            # Skip self/cls parameters for instance/class methods
            if param_name in ("self", "cls") and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                continue

            # Get the parameter annotation, defaulting to Any
            annotation = param.annotation if param.annotation != inspect.Parameter.empty else Any

            try:
                # Convert the Python type to JSON Schema
                schema = _python_type_to_schema(annotation)

                # Add a description from the docstring or create a default one
                schema["description"] = param_docs.get(param_name, f"{param_name} parameter")

                # Add to properties
                properties[param_name] = schema

                # If no default value is provided, the parameter is required
                if param.default == inspect.Parameter.empty:
                    required.append(param_name)
                    self._log(f"[TOOL] Parameter '{param_name}' is required", level="debug")
                else:
                    self._log(f"[TOOL] Parameter '{param_name}' has default value: {param.default}", level="debug")

            except Exception as e:
                self._log(f"[TOOL] Error processing parameter {param_name} for {function_name}: {e}", level="error")
                # Add a basic schema as a fallback
                properties[param_name] = {
                    "type": "string",  # Default to string instead of object for better compatibility
                    "description": f"{param_name} parameter (type conversion failed)"
                }

                # Parameters with no default value are required even if processing failed
                if param.default == inspect.Parameter.empty:
                    required.append(param_name)
                    self._log(f"[TOOL] Parameter '{param_name}' marked as required despite conversion failure", level="debug")

        # Apply schema extensions if provided
        if schema_extensions:
            for param_name, extensions in schema_extensions.items():
                if param_name in properties:
                    properties[param_name].update(extensions)

        # Create the parameters object with proper placement of the 'required' field
        parameters = {
            "type": "object",
            "properties": properties,
        }

        # Only add the required field if there are required parameters
        if required:
            parameters["required"] = required

        # Build the final tool specification
        tool_spec = {
            "type": "function",
            "function": {
                "name": function_name,
                "description": description,
                "parameters": parameters
            }
        }

        # Set the disabled flag if requested
        if disabled:
            tool_spec["function"]["disabled"] = True

        # Add to the tools list
        self.tools.append(tool_spec)

        # Make the function available as an attribute on the instance
        setattr(self, function_name, external_callable)

        # Log the registration with detailed information
        self._log(f"[TOOL] Registered function '{function_name}' with {len(properties)} parameters", level="info")
        if required:
            self._log(f"[TOOL] Required parameters: {required}", level="info")

        return function_name  # Return the name for reference
Register a function for LLM tool calling with full type hints and metadata.

Args:
    external_callable (Callable): The function to register.
    name (Optional[str]): Optional custom name. Defaults to the function's __name__.
    description (Optional[str]): Optional custom description. Defaults to the first line of the docstring.
    override (bool): If True, replaces an existing tool with the same name.
    disabled (bool): If True, registers the function in a disabled state.
    schema_extensions (Optional[Dict[str, Any]]): Optional dictionary mapping parameter names to
        schema extensions that override or add to the auto-generated schema.

Raises:
    ValueError: If the callable is invalid or a duplicate name is found without override.

Example:
    interactor.add_function(
        my_tool,
        override=True,
        disabled=False,
        schema_extensions={
            "param1": {"minimum": 0, "maximum": 100},
            "param2": {"format": "email"}
        }
    )
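For illustration, here is roughly what registration produces for a small typed function. This is a sketch: get_weather and the ai Interactor instance are hypothetical, but the spec layout follows the tool_spec construction shown above.

    def get_weather(city: str, units: str = "metric") -> str:
        """Fetch current weather.

        Args:
            city: Name of the city to look up.
            units: Unit system to report in.
        """
        ...

    ai.add_function(get_weather)
    # Per the code above, the registered spec is approximately:
    # {
    #   "type": "function",
    #   "function": {
    #     "name": "get_weather",
    #     "description": "Fetch current weather.",
    #     "parameters": {
    #       "type": "object",
    #       "properties": {
    #         "city":  {"type": "string", "description": "Name of the city to look up."},
    #         "units": {"type": "string", "description": "Unit system to report in."}
    #       },
    #       "required": ["city"]   # 'units' has a default, so it is not required
    #     }
    #   }
    # }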
    def disable_function(self, name: str) -> bool:
        """
        Disable a registered tool function by name.

        This marks the function as inactive for tool calling without removing it from the internal registry.
        The function remains visible in the tool listing but is skipped during tool selection by the LLM.

        Args:
            name (str): The name of the function to disable.

        Returns:
            bool: True if the function was found and disabled, False otherwise.

        Example:
            interactor.disable_function("extract_text")
        """
        for tool in self.tools:
            if tool["function"]["name"] == name:
                tool["function"]["disabled"] = True
                return True
        return False
    def enable_function(self, name: str) -> bool:
        """
        Re-enable a previously disabled tool function by name.

        This removes the 'disabled' flag from a tool function, making it available again for LLM use.

        Args:
            name (str): The name of the function to enable.

        Returns:
            bool: True if the function was found and enabled, False otherwise.

        Example:
            interactor.enable_function("extract_text")
        """
        for tool in self.tools:
            if tool["function"]["name"] == name:
                tool["function"].pop("disabled", None)
                return True
        return False
    def delete_function(self, name: str) -> bool:
        """
        Permanently remove a registered tool function from the Interactor.

        This deletes both the tool metadata and the callable attribute, making it fully inaccessible
        from the active session. Useful for dynamically trimming the toolset.

        Args:
            name (str): The name of the function to delete.

        Returns:
            bool: True if the function was found and removed, False otherwise.

        Example:
            interactor.delete_function("extract_text")
        """
        before = len(self.tools)
        self.tools = [tool for tool in self.tools if tool["function"]["name"] != name]
        if hasattr(self, name):
            delattr(self, name)
        return len(self.tools) < before
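A brief sketch of the registry lifecycle, assuming a tool named "get_weather" was registered earlier (the name is illustrative):

    ai.disable_function("get_weather")   # stays in the registry, skipped by the LLM
    ai.enable_function("get_weather")    # available to the LLM again
    ai.delete_function("get_weather")    # removed from the registry and the instance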
    def list_functions(self) -> List[Dict[str, Any]]:
        """Get the list of registered functions for tool calling.

        Returns:
            List[Dict[str, Any]]: List of registered functions.
        """
        return self.tools
    def list_models(
        self,
        providers: Optional[Union[str, List[str]]] = None,
        filter: Optional[str] = None
    ) -> List[str]:
        """Retrieve available models from configured providers.

        Args:
            providers: Provider name or list of provider names. If None, all are queried.
            filter: Optional regex to filter model names.

        Returns:
            List[str]: Sorted list of "provider:model_id" strings.
        """
        models = []

        if providers is None:
            providers_to_list = self.providers
        elif isinstance(providers, str):
            providers_to_list = {providers: self.providers.get(providers)}
        elif isinstance(providers, list):
            providers_to_list = {p: self.providers.get(p) for p in providers}
        else:
            return []

        invalid_providers = [p for p in providers_to_list if p not in self.providers or self.providers[p] is None]
        if invalid_providers:
            self.logger.error(f"Invalid providers: {invalid_providers}")
            return []

        regex_pattern = None
        if filter:
            try:
                regex_pattern = re.compile(filter, re.IGNORECASE)
            except re.error as e:
                self.logger.error(f"Invalid regex pattern: {e}")
                return []

        for provider_name, config in providers_to_list.items():
            sdk = config.get("sdk", "openai")
            base_url = config.get("base_url")
            api_key = config.get("api_key")

            try:
                if sdk == "openai":
                    client = openai.OpenAI(api_key=api_key, base_url=base_url)
                    response = client.models.list()
                    for model in response.data:
                        model_id = f"{provider_name}:{model.id}"
                        if not regex_pattern or regex_pattern.search(model_id):
                            models.append(model_id)

                elif sdk == "anthropic":
                    client = Anthropic(api_key=api_key)
                    response = client.models.list()
                    for model in response:
                        model_id = f"{provider_name}:{model.id}"
                        if not regex_pattern or regex_pattern.search(model_id):
                            models.append(model_id)
                else:
                    self.logger.warning(f"SDK '{sdk}' for provider '{provider_name}' is not supported by list_models()")

            except Exception as e:
                self.logger.error(f"Failed to list models for {provider_name}: {e}")

        return sorted(models, key=str.lower)
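For example, to narrow the listing to a single provider and a model family (the regex is illustrative):

    for model_id in ai.list_models(providers="openai", filter=r"gpt-4"):
        print(model_id)   # e.g. "openai:gpt-4o-mini"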
    def interact(
        self,
        user_input: Optional[str],
        quiet: bool = False,
        tools: bool = True,
        stream: bool = True,
        markdown: bool = False,
        model: Optional[str] = None,
        output_callback: Optional[Callable[[str], None]] = None,
        session_id: Optional[str] = None,
        raw: Optional[bool] = None,
        tool_suppress: bool = True,
        timeout: float = 60.0
    ) -> Union[Optional[str], "TokenStream"]:
        """Main universal gateway for all LLM interaction.

        This method serves as the single entry point for all interactions with the language model.
        When `raw=False` (default), it handles the interaction internally and returns the full response.
        When `raw=True`, it returns a context manager that yields chunks of the response for custom handling.

        Args:
            user_input: Text input from the user.
            quiet: If True, don't print status info or progress.
            tools: Enable (True) or disable (False) tool calling.
            stream: Enable (True) or disable (False) streaming responses.
            markdown: If True, renders content as markdown.
            model: Optional model override.
            output_callback: Optional callback to handle the output.
            session_id: Optional session ID to load messages from.
            raw: If True, return a context manager instead of handling the interaction internally.
                If None, use the class-level setting from __init__.
            tool_suppress: If True and raw=True, filter out tool-related status messages.
            timeout: Maximum time in seconds to wait for the stream to complete when raw=True.

        Returns:
            If raw=False: The complete response from the model as a string, or None if there was an error.
            If raw=True: A context manager that yields chunks of the response as they arrive.

        Example with default mode:
            response = ai.interact("Tell me a joke")

        Example with raw mode:
            with ai.interact("Tell me a joke", raw=True) as stream:
                for chunk in stream:
                    print(chunk, end="", flush=True)
        """
        if not user_input:
            return None

        if quiet or self.quiet:
            markdown = False
            stream = False

        # Determine if we should use raw mode:
        # if the raw parameter is explicitly provided, use it; otherwise use the class setting
        use_raw = self.raw if raw is None else raw

        # If raw mode is requested, delegate to the raw handler
        if use_raw:
            return self._interact_raw(
                user_input=user_input,
                tools=tools,
                model=model,
                session_id=session_id,
                tool_suppress=tool_suppress,
                timeout=timeout
            )

        # Set up model if specified
        if model:
            self._setup_client(model)
            self._setup_encoding()

        # Session handling
        if self.session_enabled and session_id:
            self.session_id = session_id
            self.session_load(session_id)

        # Add user message using messages_add
        self.messages_add(role="user", content=user_input)

        # Log token count estimate
        token_count = self._count_tokens(self.history)
        if not quiet:
            print(f"[dim]Estimated tokens in context: {token_count} / {self.context_length}[/dim]")

        # Make sure we have enough context space
        if token_count > self.context_length:
            if self._cycle_messages():
                if not quiet:
                    print("[red]Context window exceeded. Cannot proceed.[/red]")
                return None

        # Log user input
        self._log(f"[USER] {user_input}")

        # Handle the actual interaction with complete streaming for all responses
        result = asyncio.run(self._interact_async_core(
            user_input=user_input,
            quiet=quiet,
            tools=tools,
            stream=stream,
            markdown=markdown,
            output_callback=output_callback
        ))

        # Log completion for this interaction
        self._log(f"[INTERACTION] Completed with {len(self.history)} total messages")

        return result
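As a sketch of non-interactive capture, output can also be routed through output_callback instead of stdout. Whether the callback receives incremental chunks or a single final string depends on the streaming internals, which are outside this excerpt:

    collected = []
    reply = ai.interact(
        "Tell me a joke",
        quiet=True,                      # suppress console status output
        output_callback=collected.append,
    )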
    def messages_add(
        self,
        role: str,
        content: Any,
        tool_info: Optional[Dict] = None,
        normalize: bool = True
    ) -> str:
        """
        Add a message to the standardized session_history and then update SDK-specific history.

        This method is the central point for all message additions to the conversation.

        Args:
            role: The role of the message ("user", "assistant", "system", "tool")
            content: The message content (text or structured)
            tool_info: Optional tool-related metadata
            normalize: Whether to normalize history after adding this message

        Returns:
            str: Unique ID of the added message
        """
        # Generate a unique message ID
        message_id = str(uuid.uuid4())

        # Create the standardized message for session_history
        timestamp = datetime.now(timezone.utc).isoformat()

        # Store system messages directly
        if role == "system":
            self.system = content

        # Create standard format message
        standard_message = {
            "role": role,
            "content": content,
            "id": message_id,
            "timestamp": timestamp,
            "metadata": {
                "sdk": self.sdk
            }
        }

        # Add tool info if provided
        if tool_info:
            standard_message["metadata"]["tool_info"] = tool_info

        # Add to session_history
        if not hasattr(self, "session_history"):
            self.session_history = []

        self.session_history.append(standard_message)

        # Save to persistent session if enabled
        if self.session_enabled and self.session_id:
            # Convert standard message to session-compatible format
            session_msg = {
                "role": role,
                "content": content,
                "id": message_id,
                "timestamp": timestamp
            }

            # Add tool-related fields if present
            if tool_info:
                for key, value in tool_info.items():
                    session_msg[key] = value

            # Store in session
            self.session.msg_insert(self.session_id, session_msg)

        # Update the SDK-specific format in self.history by running the normalizer
        if normalize:
            # We only need to normalize the most recent message for efficiency;
            # pass a flag indicating we're just normalizing a new message
            self._normalizer(force=False, new_message_only=True)

        # Log the added message
        self._log(f"[MESSAGE ADDED] {role}: {str(content)[:50]}...")

        return message_id
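A sketch of recording a tool result by hand. The tool_info keys mirror the ones session_load reads back (id, name, arguments); the values here are illustrative:

    msg_id = ai.messages_add(
        role="tool",
        content='{"temperature_c": 21}',
        tool_info={"id": "call_123", "name": "get_weather", "arguments": {"city": "Berlin"}},
    )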
    def messages_system(self, prompt: str):
        """Set or retrieve the current system prompt.

        This method manages the system prompt that guides the AI's behavior.
        It can be used to both set a new system prompt and retrieve the current one.
        When setting a new prompt, it updates the system message in the conversation
        history and persists it to the session if enabled.

        Args:
            prompt (str): The new system prompt to set. If empty or None,
                returns the current system prompt without changes.

        Returns:
            str: The current system prompt after any updates
        """
        if not isinstance(prompt, str) or not prompt:
            return self.system

        # If the prompt hasn't changed, don't do anything
        if self.system == prompt:
            return self.system

        # Update the system prompt
        old_system = self.system
        self.system = prompt

        # For OpenAI, update or insert the system message in history
        if self.sdk == "openai":
            # Check if there's already a system message
            system_index = next((i for i, msg in enumerate(self.history)
                                 if msg.get("role") == "system"), None)

            if system_index is not None:
                # Update existing system message
                self.history[system_index]["content"] = prompt
            else:
                # Insert new system message at the beginning
                self.history.insert(0, {"role": "system", "content": prompt})

        # For Anthropic, the system message is not part of history; just save it for API calls

        # Log to session only if the prompt actually changed
        if self.session_enabled and self.session_id and old_system != prompt:
            self.session.msg_insert(self.session_id, {"role": "system", "content": prompt})

        return self.system
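Usage sketch; note that passing an empty value is a read, not a write:

    ai.messages_system("You are a terse assistant. Answer in one sentence.")
    current = ai.messages_system("")   # returns the current prompt unchanged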
    def messages(self) -> list:
        """Return full session messages (persisted or in-memory)."""
        if self.session_enabled and self.session_id:
            return self.session.load_full(self.session_id).get("messages", [])
        return self.session_history
    def messages_length(self) -> int:
        """Calculate the total token count for the message history."""
        if not self.encoding:
            return 0

        total_tokens = 0
        for message in self.history:
            if message.get("content"):
                total_tokens += len(self.encoding.encode(message["content"]))
            if message.get("tool_calls"):
                for tool_call in message["tool_calls"]:
                    if tool_call.get("function"):
                        total_tokens += len(self.encoding.encode(tool_call["function"].get("name", "")))
                        total_tokens += len(self.encoding.encode(tool_call["function"].get("arguments", "")))
        return total_tokens
    def session_load(self, session_id: Optional[str]):
        """Load and normalize messages for a specific session.

        This method loads a conversation session from persistent storage and
        normalizes the messages to the current SDK format. It handles system
        messages, tool calls, and maintains message ordering. If loading fails,
        it resets to an empty session with the default system prompt.

        Args:
            session_id (Optional[str]): The ID of the session to load.
                If None, resets to in-memory mode.

        Note:
            This method will update both session_history and history to match
            the loaded session's state. It also ensures proper SDK-specific
            message formatting.
        """
        self.session_id = session_id
        self._last_session_id = session_id

        if self.session_enabled and session_id:
            try:
                # Load raw session data
                session_data = self.session.load_full(session_id)
                messages = session_data.get("messages", [])

                # Convert session format to our standard format
                self.session_history = []

                # Track the most recent system message
                latest_system_msg = None

                for msg in messages:
                    # Extract fields
                    role = msg.get("role", "user")
                    content = msg.get("content", "")
                    msg_id = msg.get("id", str(uuid.uuid4()))
                    timestamp = msg.get("timestamp", datetime.now(timezone.utc).isoformat())

                    # If this is a system message, track it but don't add to session_history yet
                    if role == "system":
                        if latest_system_msg is None or timestamp > latest_system_msg["timestamp"]:
                            latest_system_msg = {
                                "role": role,
                                "content": content,
                                "id": msg_id,
                                "timestamp": timestamp,
                                "metadata": {"sdk": self.sdk}
                            }
                        continue

                    # Build tool_info if present
                    tool_info = None
                    if any(key in msg for key in ["tool_use_id", "tool_call_id", "name", "arguments"]):
                        tool_info = {
                            "id": msg.get("tool_use_id") or msg.get("tool_call_id"),
                            "name": msg.get("name", "unknown_tool"),
                            "arguments": msg.get("arguments", {})
                        }

                    # Create standard message
                    standard_msg = {
                        "role": role,
                        "content": content,
                        "id": msg_id,
                        "timestamp": timestamp,
                        "metadata": {
                            "sdk": self.sdk
                        }
                    }

                    if tool_info:
                        standard_msg["metadata"]["tool_info"] = tool_info

                    self.session_history.append(standard_msg)

                # If we found a system message, update the system property and add to history
                if latest_system_msg:
                    self.system = latest_system_msg["content"]
                    # Insert at the beginning of session_history
                    self.session_history.insert(0, latest_system_msg)
                else:
                    # If no system message was found, add the current system message
                    self.messages_add(role="system", content=self.system)

                # Normalize to current SDK format
                self._normalizer(force=True)

                self._log(f"[SESSION] Switched to session '{session_id}'")
            except Exception as e:
                self.logger.error(f"Failed to load session '{session_id}': {e}")
                self.session_reset()
        else:
            # Reset to empty state with system message
            self.session_reset()
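A sketch of switching between persisted conversations, assuming sessions were enabled at construction (the directory name is illustrative):

    ai = Interactor(session_enabled=True, session_path="~/.myapp/sessions")
    sid = ai.session.create("support-chat")
    ai.session_load(sid)            # loads, then normalizes to the active SDK format
    ai.interact("Where did we leave off?")
    ai.session_load(None)           # falls through to session_reset() below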
    def session_reset(self):
        """Reset the current session state and reinitialize to the default system prompt.

        This method performs a complete reset of the conversation state:
        1. Clears all message history
        2. Disables session ID tracking
        3. Returns to in-memory mode
        4. Reinitializes with the default system prompt

        The reset is useful for starting fresh conversations or recovering
        from error states. It maintains the basic system configuration while
        clearing all conversation context.
        """
        self.session_id = None
        self._last_session_id = None

        # Clear histories
        self.session_history = []
        self.history = []

        # Reapply the system message
        if hasattr(self, "system") and self.system:
            # Add to session_history
            self.messages_add(role="system", content=self.system)
        else:
            # Ensure we have a default system message
            self.system = "You are a helpful Assistant."
            self.messages_add(role="system", content=self.system)

        self._log("[SESSION] Reset to in-memory mode")
    def track_token_usage(self):
        """Track and return token usage across the conversation history.

        This method maintains a history of token usage measurements and provides
        current usage statistics. It tracks:
        - Current token count
        - Context length limit
        - Usage percentage
        - Historical measurements
        - Current provider and model info

        Returns:
            dict: Dictionary containing:
                - current: Current token count
                - limit: Maximum context length
                - percentage: Usage as percentage of limit
                - history: Last 10 measurements with timestamps
                - provider: Current provider name
                - model: Current model name

        Note:
            The history is limited to the last 100 measurements to prevent
            unbounded memory growth.
        """
        if not hasattr(self, "_token_history"):
            self._token_history = []

        # Count current tokens
        current_count = self._count_tokens(self.history)

        # Add to history
        timestamp = datetime.now(timezone.utc).isoformat()
        self._token_history.append({
            "timestamp": timestamp,
            "count": current_count,
            "limit": self.context_length,
            "provider": self.provider,
            "model": self.model
        })

        # Keep only the last 100 measurements to avoid unlimited growth
        if len(self._token_history) > 100:
            self._token_history = self._token_history[-100:]

        # Return current tracking info
        return {
            "current": current_count,
            "limit": self.context_length,
            "percentage": round((current_count / self.context_length) * 100, 1) if self.context_length else 0,
            "history": self._token_history[-10:],  # Return last 10 measurements
            "provider": self.provider,
            "model": self.model
        }
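For example, a simple context-pressure check built on the returned dict:

    usage = ai.track_token_usage()
    print(f"{usage['current']}/{usage['limit']} tokens ({usage['percentage']}%)")
    if usage["percentage"] > 90:
        print("Consider trimming history or starting a new session.")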
    def get_message_token_breakdown(self):
        """Analyze token usage by message type and provide a detailed breakdown.

        This method performs a detailed analysis of token usage across the
        conversation history, breaking down usage by:
        - Message role (system, user, assistant, tool)
        - Content type (text, tool calls, tool results)
        - Individual message statistics

        Returns:
            dict: Token usage breakdown containing:
                - total: Total tokens used
                - by_role: Tokens used by each role
                - by_type: Tokens used by content type
                - messages: List of individual message stats including:
                    - index: Message position
                    - role: Message role
                    - tokens: Tokens used
                    - has_tools: Whether message contains tool calls

        Note:
            This analysis is useful for understanding token usage patterns
            and optimizing conversation context.
        """
        breakdown = {
            "total": 0,
            "by_role": {
                "system": 0,
                "user": 0,
                "assistant": 0,
                "tool": 0
            },
            "by_type": {
                "text": 0,
                "tool_calls": 0,
                "tool_results": 0
            },
            "messages": []
        }

        # Analyze each message
        for i, msg in enumerate(self.history):
            msg_tokens = self._count_tokens([msg])
            role = msg.get("role", "unknown")

            # Track by role
            if role in breakdown["by_role"]:
                breakdown["by_role"][role] += msg_tokens

            # Track by content type
            if role == "assistant" and msg.get("tool_calls"):
                breakdown["by_type"]["tool_calls"] += msg_tokens
            elif role == "tool":
                breakdown["by_type"]["tool_results"] += msg_tokens
            else:
                breakdown["by_type"]["text"] += msg_tokens

            # Add individual message data
            breakdown["messages"].append({
                "index": i,
                "role": role,
                "tokens": msg_tokens,
                "has_tools": bool(msg.get("tool_calls") or msg.get("tool_use") or
                                  (isinstance(msg.get("content"), list) and
                                   any(isinstance(c, dict) and c.get("type") in ["tool_use", "tool_result"]
                                       for c in msg.get("content", []))))
            })

            # Update total
            breakdown["total"] += msg_tokens

        return breakdown
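A small sketch that prints the per-role totals from the breakdown:

    breakdown = ai.get_message_token_breakdown()
    for role, tokens in breakdown["by_role"].items():
        print(f"{role:>9}: {tokens} tokens")
    print(f"    total: {breakdown['total']} tokens")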
class Session:
    def __init__(self, directory: str = None):
        """
        Initialize the session manager and ensure the session directory exists.

        Args:
            directory (str): Filesystem path for session storage. Must not be None or empty.

        Raises:
            ValueError: If directory is None or empty.
            OSError: If the directory cannot be created or accessed.
        """
        if not directory:
            raise ValueError("Session directory must be a valid non-empty string path.")

        try:
            self.path = Path(os.path.expanduser(directory))
            self.path.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            raise OSError(f"Failed to initialize session directory '{directory}': {e}")

    # ---------------------------
    # Core CRUD
    # ---------------------------

    def list(self) -> List[Dict]:
        """
        Return metadata for all sessions in the directory.

        Returns:
            List[Dict]: Sorted list of session metadata dictionaries.
        """
        out = []
        for file in self.path.glob("*.json"):
            try:
                with open(file, "r") as f:
                    d = json.load(f)
                out.append({
                    "id": d.get("id"),
                    "name": d.get("name"),
                    "created": d.get("created"),
                    "tags": d.get("tags", []),
                    "summary": d.get("summary")
                })
            except Exception:
                continue
        return sorted(out, key=lambda x: x["created"], reverse=True)

    def create(self, name: str, tags: Optional[List[str]] = None) -> str:
        """
        Create and persist a new session.

        Args:
            name (str): Name of the new session.
            tags (List[str], optional): Optional list of tags.

        Returns:
            str: Unique session ID of the new session.
        """
        sid = str(uuid.uuid4())
        session = {
            "id": sid,
            "name": name,
            "created": datetime.now(timezone.utc).isoformat(),
            "parent": None,
            "branch_point": None,
            "tags": tags or [],
            "summary": None,
            "messages": []
        }
        self._save_file(sid, session)
        return sid

    def load(self, session_id: str) -> List[Dict]:
        """
        Return an OpenAI-compatible message list from a session.

        Filters out internal keys and leaves only standard API-compatible fields.

        Args:
            session_id (str): ID of the session to load.

        Returns:
            List[Dict]: List of clean message dictionaries.
        """
        session = self._read_file(session_id)
        return [
            {k: v for k, v in m.items() if k in {
                "role", "content", "tool_calls", "name", "function_call", "tool_call_id"
            }} for m in session.get("messages", [])
        ]

    def load_full(self, session_id: str) -> Dict:
        """
        Return the complete session file as-is.

        Args:
            session_id (str): ID of the session.

        Returns:
            Dict: Entire raw session data.
        """
        return self._read_file(session_id)

    def delete(self, session_id: str):
        """
        Delete a session file from disk.

        Args:
            session_id (str): ID of the session to delete.
        """
        file = self.path / f"{session_id}.json"
        if file.exists():
            file.unlink()

    def update(self, session_id: str, key: str, value: Any):
        """
        Update a top-level key in a session file.

        Args:
            session_id (str): Session ID.
            key (str): Field to update.
            value (Any): New value for the field.
        """
        session = self._read_file(session_id)
        session[key] = value
        self._save_file(session_id, session)

    # ---------------------------
    # Message Operations
    # ---------------------------

    def msg_insert(self, session_id: str, message: Dict) -> str:
        """
        Insert a new message into a session.

        Args:
            session_id (str): Session ID.
            message (Dict): Message dictionary to insert.

        Returns:
            str: ID of the inserted message.
        """
        session = self._read_file(session_id)
        entry = {
            "id": str(uuid.uuid4()),
            "timestamp": datetime.now(timezone.utc).isoformat(),
            **message
        }
        session["messages"].append(entry)
        self._save_file(session_id, session)
        return entry["id"]

    def msg_get(self, session_id: str, message_id: str) -> Optional[Dict]:
        """
        Retrieve a specific message from a session.

        Args:
            session_id (str): Session ID.
            message_id (str): ID of the message to retrieve.

        Returns:
            Optional[Dict]: The message if found, else None.
        """
        session = self._read_file(session_id)
        for msg in session.get("messages", []):
            if msg.get("id") == message_id:
                return msg
        return None

    def msg_index(self, session_id: str, message_id: str) -> Optional[int]:
        """
        Get the index of a message within a session.

        Args:
            session_id (str): Session ID.
            message_id (str): Message ID.

        Returns:
            Optional[int]: Index if found, else None.
        """
        session = self._read_file(session_id)
        for i, msg in enumerate(session.get("messages", [])):
            if msg.get("id") == message_id:
                return i
        return None

    def msg_update(self, session_id: str, message_id: str, new_content: str) -> bool:
        """
        Update the content of a specific message.

        Args:
            session_id (str): Session ID.
            message_id (str): Message ID.
            new_content (str): New content for the message.

        Returns:
            bool: True if the update succeeded, False otherwise.
        """
        session = self._read_file(session_id)
        for m in session["messages"]:
            if m.get("id") == message_id:
                m["content"] = new_content
                self._save_file(session_id, session)
                return True
        return False

    def msg_delete(self, session_id: str, message_id: str) -> bool:
        """
        Delete a message from a session.

        Args:
            session_id (str): Session ID.
            message_id (str): Message ID.

        Returns:
            bool: True if deletion occurred, False otherwise.
        """
        session = self._read_file(session_id)
        before = len(session["messages"])
        session["messages"] = [m for m in session["messages"] if m.get("id") != message_id]
        self._save_file(session_id, session)
        return len(session["messages"]) < before

    # ---------------------------
    # Branching & Summarization
    # ---------------------------

    def branch(self, from_id: str, message_id: str, new_name: str) -> str:
        """Create a new session by branching from a specific message.

        This method creates a new session that branches from an existing one at a specific
        message point. The new session inherits all messages up to and including the
        specified message, then starts fresh from there.

        Args:
            from_id (str): ID of the source session to branch from.
            message_id (str): ID of the message to branch at.
            new_name (str): Name for the new branched session.

        Returns:
            str: ID of the newly created branched session.

        Raises:
            ValueError: If the source session or message ID is not found.
        """
        # Get source session
        source = self._read_file(from_id)
        if not source:
            raise ValueError(f"Source session '{from_id}' not found")

        # Find the branch point
        branch_index = self.msg_index(from_id, message_id)
        if branch_index is None:
            raise ValueError(f"Message '{message_id}' not found in session '{from_id}'")

        # Create new session
        new_id = self.create(new_name, source.get("tags", []))
        new_session = self._read_file(new_id)

        # Copy messages up to and including the branch point
        new_session["messages"] = source["messages"][:branch_index + 1]
        new_session["parent"] = from_id
        new_session["branch_point"] = message_id

        # Save and return
        self._save_file(new_id, new_session)
        return new_id

    def summarize(self, interactor, session_id: str) -> str:
        """Generate a summary of the session using the provided interactor.

        This method uses the AI interactor to analyze the session content and generate
        a concise summary. The summary is stored in the session metadata and returned.

        Args:
            interactor: An AI interactor instance capable of generating summaries.
            session_id (str): ID of the session to summarize.

        Returns:
            str: The generated summary text.

        Note:
            The summary is automatically stored in the session metadata and can be
            retrieved later using load_full().
        """
        session = self._read_file(session_id)
        if not session:
            return ""

        # Get the clean message list (used here to confirm the session is non-empty;
        # the interactor is expected to already hold this session's context)
        messages = self.load(session_id)
        if not messages:
            return ""

        # Generate summary
        summary = interactor.interact(
            "Summarize this conversation in 2-3 sentences:",
            tools=False,
            stream=False,
            markdown=False
        )

        # Store and return
        session["summary"] = summary
        self._save_file(session_id, session)
        return summary

    # ---------------------------
    # Search Capabilities
    # ---------------------------

    def search(self, query: str, session_id: Optional[str] = None) -> List[Dict]:
        """Search for messages containing the query text within a session or all sessions.

        This method performs a case-insensitive text search across message content.
        If a session_id is provided, it searches only within that session; otherwise
        it searches across all sessions.

        Args:
            query (str): Text to search for.
            session_id (Optional[str]): Optional session ID to limit search scope.

        Returns:
            List[Dict]: List of matching messages with their session context.
                Each dict contains:
                - session_id: ID of the containing session
                - message: The matching message
                - context: Surrounding messages for context
        """
        results = []
        query = query.lower()

        # Determine search scope
        if session_id:
            sessions = [(session_id, self._read_file(session_id))]
        else:
            sessions = [(f.stem, self._read_file(f.stem)) for f in self.path.glob("*.json")]

        # Search each session
        for sid, session in sessions:
            if not session:
                continue

            messages = session.get("messages", [])
            for i, msg in enumerate(messages):
                content = str(msg.get("content", "")).lower()
                if query in content:
                    # Get context (2 messages before and after)
                    start = max(0, i - 2)
                    end = min(len(messages), i + 3)
                    context = messages[start:end]

                    results.append({
                        "session_id": sid,
                        "message": msg,
                        "context": context
                    })

        return results

    def search_meta(self, query: str) -> List[Dict]:
        """Search session metadata (name, tags, summary) for matching sessions.

        This method performs a case-insensitive search across session metadata fields
        including name, tags, and summary. It returns matching sessions with their
        full metadata.

        Args:
            query (str): Text to search for in metadata.

        Returns:
            List[Dict]: List of matching session metadata dictionaries.
                Each dict contains:
                - id: Session ID
                - name: Session name
                - created: Creation timestamp
                - tags: List of tags
                - summary: Session summary if available
        """
        results = []
        query = query.lower()

        for file in self.path.glob("*.json"):
            try:
                with open(file, "r") as f:
                    session = json.load(f)

                # Check metadata fields
                name = str(session.get("name", "")).lower()
                tags = [str(t).lower() for t in session.get("tags", [])]
                summary = str(session.get("summary", "")).lower()

                if (query in name or
                        any(query in tag for tag in tags) or
                        query in summary):
                    results.append({
                        "id": session.get("id"),
                        "name": session.get("name"),
                        "created": session.get("created"),
                        "tags": session.get("tags", []),
                        "summary": session.get("summary")
                    })
            except Exception:
                continue

        return sorted(results, key=lambda x: x["created"], reverse=True)

    # ---------------------------
    # Internal I/O
    # ---------------------------

    def _read_file(self, session_id: str) -> Dict:
        """Read and parse a session file from disk.

        This internal method handles reading and parsing session files.
        It ensures proper error handling and returns an empty session
        structure if the file doesn't exist or is invalid.

        Args:
            session_id (str): ID of the session to read.

        Returns:
            Dict: Session data dictionary or empty session structure.
        """
        file = self.path / f"{session_id}.json"
        if not file.exists():
            return {
                "id": session_id,
                "name": "New Session",
                "created": datetime.now(timezone.utc).isoformat(),
                "messages": []
            }

        try:
            with open(file, "r") as f:
                return json.load(f)
        except Exception:
            return {
                "id": session_id,
                "name": "New Session",
                "created": datetime.now(timezone.utc).isoformat(),
                "messages": []
            }

    def _save_file(self, session_id: str, data: Dict):
        """Write session data to disk.

        This internal method handles writing session data to disk.
        It ensures proper error handling and atomic writes.

        Args:
            session_id (str): ID of the session to save.
            data (Dict): Session data to write.

        Raises:
            OSError: If the file cannot be written.
        """
        file = self.path / f"{session_id}.json"
        temp_file = file.with_suffix(".tmp")

        try:
            # Write to a temporary file first
            with open(temp_file, "w") as f:
                json.dump(data, f, indent=2)

            # Atomic rename
            temp_file.replace(file)
        except Exception as e:
            if temp_file.exists():
                temp_file.unlink()
            raise OSError(f"Failed to save session '{session_id}': {e}")
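The session store can also be used on its own, without an Interactor. A sketch (directory, names, and content are illustrative):

    store = Session("~/.myapp/sessions")
    sid = store.create("research", tags=["demo"])
    mid = store.msg_insert(sid, {"role": "user", "content": "What is entropy?"})
    fork = store.branch(sid, mid, "research-fork")   # copies history up to and including mid
    hits = store.search("entropy")                   # case-insensitive, across all sessions
    print([h["session_id"] for h in hits])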