feat: added proper settings handling
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Agent module for media library management."""
|
||||
|
||||
from .agent import Agent
|
||||
from .config import settings
|
||||
from alfred.settings import settings
|
||||
|
||||
__all__ = ["Agent", "settings"]
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
|
||||
from alfred.infrastructure.persistence import get_memory
|
||||
|
||||
from .config import settings
|
||||
from alfred.settings import settings
|
||||
from .prompts import PromptBuilder
|
||||
from .registry import Tool, make_tools
|
||||
|
||||
@@ -21,17 +21,20 @@ class Agent:
|
||||
Uses OpenAI-compatible tool calling API.
|
||||
"""
|
||||
|
||||
def __init__(self, llm, max_tool_iterations: int = 5):
|
||||
def __init__(self, settings, llm, max_tool_iterations: int = 5):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
Args:
|
||||
settings: Application settings instance
|
||||
llm: LLM client with complete() method
|
||||
max_tool_iterations: Maximum number of tool execution iterations
|
||||
"""
|
||||
self.settings = settings
|
||||
self.llm = llm
|
||||
self.tools: dict[str, Tool] = make_tools()
|
||||
self.tools: dict[str, Tool] = make_tools(settings)
|
||||
self.prompt_builder = PromptBuilder(self.tools)
|
||||
self.settings = settings
|
||||
self.max_tool_iterations = max_tool_iterations
|
||||
|
||||
def step(self, user_input: str) -> str:
|
||||
@@ -78,7 +81,7 @@ class Agent:
|
||||
tools_spec = self.prompt_builder.build_tools_spec()
|
||||
|
||||
# Tool execution loop
|
||||
for _iteration in range(self.max_tool_iterations):
|
||||
for _iteration in range(self.settings.max_tool_iterations):
|
||||
# Call LLM with tools
|
||||
llm_result = self.llm.complete(messages, tools=tools_spec)
|
||||
|
||||
@@ -230,7 +233,7 @@ class Agent:
|
||||
tools_spec = self.prompt_builder.build_tools_spec()
|
||||
|
||||
# Tool execution loop
|
||||
for _iteration in range(self.max_tool_iterations):
|
||||
for _iteration in range(self.settings.max_tool_iterations):
|
||||
# Call LLM with tools
|
||||
llm_result = self.llm.complete(messages, tools=tools_spec)
|
||||
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
"""Configuration management with validation."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
"""Raised when configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Settings:
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# LLM Configuration
|
||||
deepseek_api_key: str = field(
|
||||
default_factory=lambda: os.getenv("DEEPSEEK_API_KEY", "")
|
||||
)
|
||||
deepseek_base_url: str = field(
|
||||
default_factory=lambda: os.getenv(
|
||||
"DEEPSEEK_BASE_URL", "https://api.deepseek.com"
|
||||
)
|
||||
)
|
||||
model: str = field(
|
||||
default_factory=lambda: os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
|
||||
)
|
||||
temperature: float = field(
|
||||
default_factory=lambda: float(os.getenv("TEMPERATURE", "0.2"))
|
||||
)
|
||||
|
||||
# TMDB Configuration
|
||||
tmdb_api_key: str = field(default_factory=lambda: os.getenv("TMDB_API_KEY", ""))
|
||||
tmdb_base_url: str = field(
|
||||
default_factory=lambda: os.getenv(
|
||||
"TMDB_BASE_URL", "https://api.themoviedb.org/3"
|
||||
)
|
||||
)
|
||||
|
||||
# Storage Configuration
|
||||
memory_file: str = field(
|
||||
default_factory=lambda: os.getenv("MEMORY_FILE", "memory.json")
|
||||
)
|
||||
|
||||
# Security Configuration
|
||||
max_tool_iterations: int = field(
|
||||
default_factory=lambda: int(os.getenv("MAX_TOOL_ITERATIONS", "5"))
|
||||
)
|
||||
request_timeout: int = field(
|
||||
default_factory=lambda: int(os.getenv("REQUEST_TIMEOUT", "30"))
|
||||
)
|
||||
|
||||
# Memory Configuration
|
||||
max_history_messages: int = field(
|
||||
default_factory=lambda: int(os.getenv("MAX_HISTORY_MESSAGES", "10"))
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate settings after initialization."""
|
||||
self._validate()
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Validate configuration values."""
|
||||
# Validate temperature
|
||||
if not 0.0 <= self.temperature <= 2.0:
|
||||
raise ConfigurationError(
|
||||
f"Temperature must be between 0.0 and 2.0, got {self.temperature}"
|
||||
)
|
||||
|
||||
# Validate max_tool_iterations
|
||||
if self.max_tool_iterations < 1 or self.max_tool_iterations > 20:
|
||||
raise ConfigurationError(
|
||||
f"max_tool_iterations must be between 1 and 20, got {self.max_tool_iterations}"
|
||||
)
|
||||
|
||||
# Validate request_timeout
|
||||
if self.request_timeout < 1 or self.request_timeout > 300:
|
||||
raise ConfigurationError(
|
||||
f"request_timeout must be between 1 and 300 seconds, got {self.request_timeout}"
|
||||
)
|
||||
|
||||
# Validate URLs
|
||||
if not self.deepseek_base_url.startswith(("http://", "https://")):
|
||||
raise ConfigurationError(
|
||||
f"Invalid deepseek_base_url: {self.deepseek_base_url}"
|
||||
)
|
||||
|
||||
if not self.tmdb_base_url.startswith(("http://", "https://")):
|
||||
raise ConfigurationError(f"Invalid tmdb_base_url: {self.tmdb_base_url}")
|
||||
|
||||
# Validate memory file path
|
||||
memory_path = Path(self.memory_file)
|
||||
if memory_path.exists() and not memory_path.is_file():
|
||||
raise ConfigurationError(
|
||||
f"memory_file exists but is not a file: {self.memory_file}"
|
||||
)
|
||||
|
||||
def is_deepseek_configured(self) -> bool:
|
||||
"""Check if DeepSeek API is properly configured."""
|
||||
return bool(self.deepseek_api_key and self.deepseek_base_url)
|
||||
|
||||
def is_tmdb_configured(self) -> bool:
|
||||
"""Check if TMDB API is properly configured."""
|
||||
return bool(self.tmdb_api_key and self.tmdb_base_url)
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
import requests
|
||||
from requests.exceptions import HTTPError, RequestException, Timeout
|
||||
|
||||
from ..config import settings
|
||||
from alfred.settings import settings, Settings
|
||||
from .exceptions import LLMAPIError, LLMConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,6 +21,7 @@ class DeepSeekClient:
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout: int | None = None,
|
||||
settings: Settings | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize DeepSeek client.
|
||||
@@ -34,10 +35,10 @@ class DeepSeekClient:
|
||||
Raises:
|
||||
LLMConfigurationError: If API key is missing
|
||||
"""
|
||||
self.api_key = api_key or settings.deepseek_api_key
|
||||
self.base_url = base_url or settings.deepseek_base_url
|
||||
self.model = model or settings.model
|
||||
self.timeout = timeout or settings.request_timeout
|
||||
self.api_key = api_key or self.settings.deepseek_api_key
|
||||
self.base_url = base_url or self.settings.deepseek_base_url
|
||||
self.model = model or self.settings.deepseek_model
|
||||
self.timeout = timeout or self.settings.request_timeout
|
||||
|
||||
if not self.api_key:
|
||||
raise LLMConfigurationError(
|
||||
@@ -94,7 +95,7 @@ class DeepSeekClient:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": settings.temperature,
|
||||
"temperature": settings.llm_temperature,
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
import requests
|
||||
from requests.exceptions import HTTPError, RequestException, Timeout
|
||||
|
||||
from ..config import settings
|
||||
from alfred.settings import Settings, settings
|
||||
from .exceptions import LLMAPIError, LLMConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,6 +32,7 @@ class OllamaClient:
|
||||
model: str | None = None,
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
settings: Settings | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize Ollama client.
|
||||
@@ -45,13 +46,11 @@ class OllamaClient:
|
||||
Raises:
|
||||
LLMConfigurationError: If configuration is invalid
|
||||
"""
|
||||
self.base_url = base_url or os.getenv(
|
||||
"OLLAMA_BASE_URL", "http://localhost:11434"
|
||||
)
|
||||
self.model = model or os.getenv("OLLAMA_MODEL", "llama3.2")
|
||||
self.base_url = base_url or settings.ollama_base_url
|
||||
self.model = model or settings.ollama_model
|
||||
self.timeout = timeout or settings.request_timeout
|
||||
self.temperature = (
|
||||
temperature if temperature is not None else settings.temperature
|
||||
temperature if temperature is not None else settings.llm_temperature
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
|
||||
@@ -78,10 +78,13 @@ def _create_tool_from_function(func: Callable) -> Tool:
|
||||
)
|
||||
|
||||
|
||||
def make_tools() -> dict[str, Tool]:
|
||||
def make_tools(settings) -> dict[str, Tool]:
|
||||
"""
|
||||
Create and register all available tools.
|
||||
|
||||
Args:
|
||||
settings: Application settings instance
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tool names to Tool objects
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user