infra: reorganized repo
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
"""Agent module for media library management."""
|
||||
|
||||
from .agent import Agent
|
||||
from .config import settings
|
||||
|
||||
__all__ = ["Agent", "settings"]
|
||||
@@ -0,0 +1,371 @@
|
||||
"""Main agent for media library management."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
from .config import settings
|
||||
from .prompts import PromptBuilder
|
||||
from .registry import Tool, make_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Agent:
|
||||
"""
|
||||
AI agent for media library management.
|
||||
|
||||
Uses OpenAI-compatible tool calling API.
|
||||
"""
|
||||
|
||||
def __init__(self, llm, max_tool_iterations: int = 5):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
Args:
|
||||
llm: LLM client with complete() method
|
||||
max_tool_iterations: Maximum number of tool execution iterations
|
||||
"""
|
||||
self.llm = llm
|
||||
self.tools: dict[str, Tool] = make_tools()
|
||||
self.prompt_builder = PromptBuilder(self.tools)
|
||||
self.max_tool_iterations = max_tool_iterations
|
||||
|
||||
def step(self, user_input: str) -> str:
|
||||
"""
|
||||
Execute one agent step with the user input.
|
||||
|
||||
This method:
|
||||
1. Adds user message to memory
|
||||
2. Builds prompt with history and context
|
||||
3. Calls LLM, executing tools as needed
|
||||
4. Returns final response
|
||||
|
||||
Args:
|
||||
user_input: User's message
|
||||
|
||||
Returns:
|
||||
Agent's final response
|
||||
"""
|
||||
memory = get_memory()
|
||||
|
||||
# Add user message to history
|
||||
memory.stm.add_message("user", user_input)
|
||||
memory.save()
|
||||
|
||||
# Build initial messages
|
||||
system_prompt = self.prompt_builder.build_system_prompt()
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add conversation history
|
||||
history = memory.stm.get_recent_history(settings.max_history_messages)
|
||||
messages.extend(history)
|
||||
|
||||
# Add unread events if any
|
||||
unread_events = memory.episodic.get_unread_events()
|
||||
if unread_events:
|
||||
events_text = "\n".join(
|
||||
[f"- {e['type']}: {e['data']}" for e in unread_events]
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": f"Background events:\n{events_text}"}
|
||||
)
|
||||
|
||||
# Get tools specification for OpenAI format
|
||||
tools_spec = self.prompt_builder.build_tools_spec()
|
||||
|
||||
# Tool execution loop
|
||||
for _iteration in range(self.max_tool_iterations):
|
||||
# Call LLM with tools
|
||||
llm_result = self.llm.complete(messages, tools=tools_spec)
|
||||
|
||||
# Handle both tuple (response, usage) and dict response
|
||||
if isinstance(llm_result, tuple):
|
||||
response_message, usage = llm_result
|
||||
else:
|
||||
response_message = llm_result
|
||||
|
||||
# Check if there are tool calls
|
||||
tool_calls = response_message.get("tool_calls")
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls, this is the final response
|
||||
final_content = response_message.get("content", "")
|
||||
memory.stm.add_message("assistant", final_content)
|
||||
memory.save()
|
||||
return final_content
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
messages.append(response_message)
|
||||
|
||||
# Execute each tool call
|
||||
for tool_call in tool_calls:
|
||||
tool_result = self._execute_tool_call(tool_call)
|
||||
|
||||
# Add tool result to messages
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"role": "tool",
|
||||
"name": tool_call.get("function", {}).get("name"),
|
||||
"content": json.dumps(tool_result, ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
|
||||
# Max iterations reached, force final response
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please provide a final response to the user without using any more tools.",
|
||||
}
|
||||
)
|
||||
|
||||
llm_result = self.llm.complete(messages)
|
||||
if isinstance(llm_result, tuple):
|
||||
final_message, usage = llm_result
|
||||
else:
|
||||
final_message = llm_result
|
||||
|
||||
final_response = final_message.get(
|
||||
"content", "I've completed the requested actions."
|
||||
)
|
||||
memory.stm.add_message("assistant", final_response)
|
||||
memory.save()
|
||||
return final_response
|
||||
|
||||
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a single tool call.
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI-format tool call dict
|
||||
|
||||
Returns:
|
||||
Result dictionary
|
||||
"""
|
||||
function = tool_call.get("function", {})
|
||||
tool_name = function.get("name", "")
|
||||
|
||||
try:
|
||||
args_str = function.get("arguments", "{}")
|
||||
args = json.loads(args_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse tool arguments: {e}")
|
||||
return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
|
||||
|
||||
# Validate tool exists
|
||||
if tool_name not in self.tools:
|
||||
available = list(self.tools.keys())
|
||||
return {
|
||||
"error": "unknown_tool",
|
||||
"message": f"Tool '{tool_name}' not found",
|
||||
"available_tools": available,
|
||||
}
|
||||
|
||||
tool = self.tools[tool_name]
|
||||
|
||||
# Execute tool
|
||||
try:
|
||||
result = tool.func(**args)
|
||||
return result
|
||||
except KeyboardInterrupt:
|
||||
# Don't catch KeyboardInterrupt - let it propagate
|
||||
raise
|
||||
except TypeError as e:
|
||||
# Bad arguments
|
||||
memory = get_memory()
|
||||
memory.episodic.add_error(tool_name, f"bad_args: {e}")
|
||||
return {"error": "bad_args", "message": str(e), "tool": tool_name}
|
||||
except Exception as e:
|
||||
# Other errors
|
||||
memory = get_memory()
|
||||
memory.episodic.add_error(tool_name, str(e))
|
||||
return {"error": "execution_failed", "message": str(e), "tool": tool_name}
|
||||
|
||||
async def step_streaming(
|
||||
self, user_input: str, completion_id: str, created_ts: int, model: str
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
"""
|
||||
Execute agent step with streaming support for LibreChat.
|
||||
|
||||
Yields SSE chunks for tool calls and final response.
|
||||
|
||||
Args:
|
||||
user_input: User's message
|
||||
completion_id: Completion ID for the response
|
||||
created_ts: Timestamp for the response
|
||||
model: Model name
|
||||
|
||||
Yields:
|
||||
SSE chunks in OpenAI format
|
||||
"""
|
||||
memory = get_memory()
|
||||
|
||||
# Add user message to history
|
||||
memory.stm.add_message("user", user_input)
|
||||
memory.save()
|
||||
|
||||
# Build initial messages
|
||||
system_prompt = self.prompt_builder.build_system_prompt()
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add conversation history
|
||||
history = memory.stm.get_recent_history(settings.max_history_messages)
|
||||
messages.extend(history)
|
||||
|
||||
# Add unread events if any
|
||||
unread_events = memory.episodic.get_unread_events()
|
||||
if unread_events:
|
||||
events_text = "\n".join(
|
||||
[f"- {e['type']}: {e['data']}" for e in unread_events]
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": f"Background events:\n{events_text}"}
|
||||
)
|
||||
|
||||
# Get tools specification for OpenAI format
|
||||
tools_spec = self.prompt_builder.build_tools_spec()
|
||||
|
||||
# Tool execution loop
|
||||
for _iteration in range(self.max_tool_iterations):
|
||||
# Call LLM with tools
|
||||
llm_result = self.llm.complete(messages, tools=tools_spec)
|
||||
|
||||
# Handle both tuple (response, usage) and dict response
|
||||
if isinstance(llm_result, tuple):
|
||||
response_message, usage = llm_result
|
||||
else:
|
||||
response_message = llm_result
|
||||
|
||||
# Check if there are tool calls
|
||||
tool_calls = response_message.get("tool_calls")
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls, this is the final response
|
||||
final_content = response_message.get("content", "")
|
||||
memory.stm.add_message("assistant", final_content)
|
||||
memory.save()
|
||||
|
||||
# Stream the final response
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_ts,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": final_content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
return
|
||||
|
||||
# Stream tool calls
|
||||
for tool_call in tool_calls:
|
||||
function = tool_call.get("function", {})
|
||||
tool_name = function.get("name", "")
|
||||
tool_args = function.get("arguments", "{}")
|
||||
|
||||
# Yield chunk indicating tool call
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_ts,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": tool_call.get("id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": tool_args,
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
messages.append(response_message)
|
||||
|
||||
# Execute each tool call and stream results
|
||||
for tool_call in tool_calls:
|
||||
tool_result = self._execute_tool_call(tool_call)
|
||||
function = tool_call.get("function", {})
|
||||
tool_name = function.get("name", "")
|
||||
|
||||
# Add tool result to messages
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"role": "tool",
|
||||
"name": tool_name,
|
||||
"content": json.dumps(tool_result, ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
|
||||
# Stream tool result as content
|
||||
result_text = (
|
||||
f"\n🔧 {tool_name}: {json.dumps(tool_result, ensure_ascii=False)}\n"
|
||||
)
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_ts,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": result_text},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Max iterations reached, force final response
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please provide a final response to the user without using any more tools.",
|
||||
}
|
||||
)
|
||||
|
||||
llm_result = self.llm.complete(messages)
|
||||
if isinstance(llm_result, tuple):
|
||||
final_message, usage = llm_result
|
||||
else:
|
||||
final_message = llm_result
|
||||
|
||||
final_response = final_message.get(
|
||||
"content", "I've completed the requested actions."
|
||||
)
|
||||
memory.stm.add_message("assistant", final_response)
|
||||
memory.save()
|
||||
|
||||
# Stream final response
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_ts,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": final_response},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Configuration management with validation."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
"""Raised when configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Settings:
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# LLM Configuration
|
||||
deepseek_api_key: str = field(
|
||||
default_factory=lambda: os.getenv("DEEPSEEK_API_KEY", "")
|
||||
)
|
||||
deepseek_base_url: str = field(
|
||||
default_factory=lambda: os.getenv(
|
||||
"DEEPSEEK_BASE_URL", "https://api.deepseek.com"
|
||||
)
|
||||
)
|
||||
model: str = field(
|
||||
default_factory=lambda: os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
|
||||
)
|
||||
temperature: float = field(
|
||||
default_factory=lambda: float(os.getenv("TEMPERATURE", "0.2"))
|
||||
)
|
||||
|
||||
# TMDB Configuration
|
||||
tmdb_api_key: str = field(default_factory=lambda: os.getenv("TMDB_API_KEY", ""))
|
||||
tmdb_base_url: str = field(
|
||||
default_factory=lambda: os.getenv(
|
||||
"TMDB_BASE_URL", "https://api.themoviedb.org/3"
|
||||
)
|
||||
)
|
||||
|
||||
# Storage Configuration
|
||||
memory_file: str = field(
|
||||
default_factory=lambda: os.getenv("MEMORY_FILE", "memory.json")
|
||||
)
|
||||
|
||||
# Security Configuration
|
||||
max_tool_iterations: int = field(
|
||||
default_factory=lambda: int(os.getenv("MAX_TOOL_ITERATIONS", "5"))
|
||||
)
|
||||
request_timeout: int = field(
|
||||
default_factory=lambda: int(os.getenv("REQUEST_TIMEOUT", "30"))
|
||||
)
|
||||
|
||||
# Memory Configuration
|
||||
max_history_messages: int = field(
|
||||
default_factory=lambda: int(os.getenv("MAX_HISTORY_MESSAGES", "10"))
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate settings after initialization."""
|
||||
self._validate()
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Validate configuration values."""
|
||||
# Validate temperature
|
||||
if not 0.0 <= self.temperature <= 2.0:
|
||||
raise ConfigurationError(
|
||||
f"Temperature must be between 0.0 and 2.0, got {self.temperature}"
|
||||
)
|
||||
|
||||
# Validate max_tool_iterations
|
||||
if self.max_tool_iterations < 1 or self.max_tool_iterations > 20:
|
||||
raise ConfigurationError(
|
||||
f"max_tool_iterations must be between 1 and 20, got {self.max_tool_iterations}"
|
||||
)
|
||||
|
||||
# Validate request_timeout
|
||||
if self.request_timeout < 1 or self.request_timeout > 300:
|
||||
raise ConfigurationError(
|
||||
f"request_timeout must be between 1 and 300 seconds, got {self.request_timeout}"
|
||||
)
|
||||
|
||||
# Validate URLs
|
||||
if not self.deepseek_base_url.startswith(("http://", "https://")):
|
||||
raise ConfigurationError(
|
||||
f"Invalid deepseek_base_url: {self.deepseek_base_url}"
|
||||
)
|
||||
|
||||
if not self.tmdb_base_url.startswith(("http://", "https://")):
|
||||
raise ConfigurationError(f"Invalid tmdb_base_url: {self.tmdb_base_url}")
|
||||
|
||||
# Validate memory file path
|
||||
memory_path = Path(self.memory_file)
|
||||
if memory_path.exists() and not memory_path.is_file():
|
||||
raise ConfigurationError(
|
||||
f"memory_file exists but is not a file: {self.memory_file}"
|
||||
)
|
||||
|
||||
def is_deepseek_configured(self) -> bool:
|
||||
"""Check if DeepSeek API is properly configured."""
|
||||
return bool(self.deepseek_api_key and self.deepseek_base_url)
|
||||
|
||||
def is_tmdb_configured(self) -> bool:
|
||||
"""Check if TMDB API is properly configured."""
|
||||
return bool(self.tmdb_api_key and self.tmdb_base_url)
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,13 @@
|
||||
"""LLM clients module."""
|
||||
|
||||
from .deepseek import DeepSeekClient
|
||||
from .exceptions import LLMAPIError, LLMConfigurationError, LLMError
|
||||
from .ollama import OllamaClient
|
||||
|
||||
__all__ = [
|
||||
"DeepSeekClient",
|
||||
"OllamaClient",
|
||||
"LLMError",
|
||||
"LLMAPIError",
|
||||
"LLMConfigurationError",
|
||||
]
|
||||
@@ -0,0 +1,148 @@
|
||||
"""DeepSeek LLM client with robust error handling."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError, RequestException, Timeout
|
||||
|
||||
from ..config import settings
|
||||
from .exceptions import LLMAPIError, LLMConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeepSeekClient:
|
||||
"""Client for interacting with DeepSeek API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout: int | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize DeepSeek client.
|
||||
|
||||
Args:
|
||||
api_key: API key for authentication (defaults to settings)
|
||||
base_url: Base URL for API (defaults to settings)
|
||||
model: Model name to use (defaults to settings)
|
||||
timeout: Request timeout in seconds (defaults to settings)
|
||||
|
||||
Raises:
|
||||
LLMConfigurationError: If API key is missing
|
||||
"""
|
||||
self.api_key = api_key or settings.deepseek_api_key
|
||||
self.base_url = base_url or settings.deepseek_base_url
|
||||
self.model = model or settings.model
|
||||
self.timeout = timeout or settings.request_timeout
|
||||
|
||||
if not self.api_key:
|
||||
raise LLMConfigurationError(
|
||||
"DeepSeek API key is required. Set DEEPSEEK_API_KEY environment variable."
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
raise LLMConfigurationError(
|
||||
"DeepSeek base URL is required. Set DEEPSEEK_BASE_URL environment variable."
|
||||
)
|
||||
|
||||
logger.info(f"DeepSeek client initialized with model: {self.model}")
|
||||
|
||||
def complete( # noqa: PLR0912
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys
|
||||
tools: Optional list of tool specifications (OpenAI format)
|
||||
|
||||
Returns:
|
||||
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
|
||||
|
||||
Raises:
|
||||
LLMAPIError: If API request fails
|
||||
ValueError: If messages format is invalid
|
||||
"""
|
||||
# Validate messages format
|
||||
if not messages:
|
||||
raise ValueError("Messages list cannot be empty")
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
raise ValueError(f"Each message must be a dict, got {type(msg)}")
|
||||
if "role" not in msg:
|
||||
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
|
||||
# Allow system, user, assistant, and tool roles
|
||||
if msg["role"] not in ("system", "user", "assistant", "tool"):
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Content is optional for tool messages (they may have tool_call_id instead)
|
||||
if msg["role"] != "tool" and "content" not in msg:
|
||||
raise ValueError(
|
||||
f"Non-tool message must have 'content' key, got {msg.keys()}"
|
||||
)
|
||||
|
||||
url = f"{self.base_url}/v1/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": settings.temperature,
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
|
||||
)
|
||||
response = requests.post(
|
||||
url, headers=headers, json=payload, timeout=self.timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Validate response structure
|
||||
if "choices" not in data or not data["choices"]:
|
||||
raise LLMAPIError("Invalid API response: missing 'choices'")
|
||||
|
||||
if "message" not in data["choices"][0]:
|
||||
raise LLMAPIError("Invalid API response: missing 'message' in choice")
|
||||
|
||||
# Return the full message dict (OpenAI format)
|
||||
message = data["choices"][0]["message"]
|
||||
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
|
||||
|
||||
return message
|
||||
|
||||
except Timeout as e:
|
||||
logger.error(f"Request timeout after {self.timeout}s: {e}")
|
||||
raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e
|
||||
|
||||
except HTTPError as e:
|
||||
logger.error(f"HTTP error from DeepSeek API: {e}")
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_msg = error_data.get("error", {}).get("message", str(e))
|
||||
except Exception:
|
||||
error_msg = str(e)
|
||||
raise LLMAPIError(f"DeepSeek API error: {error_msg}") from e
|
||||
raise LLMAPIError(f"HTTP error: {e}") from e
|
||||
|
||||
except RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
raise LLMAPIError(f"Failed to connect to DeepSeek API: {e}") from e
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
logger.error(f"Failed to parse API response: {e}")
|
||||
raise LLMAPIError(f"Invalid API response format: {e}") from e
|
||||
@@ -0,0 +1,19 @@
|
||||
"""LLM-related exceptions."""
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
"""Base exception for LLM-related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LLMConfigurationError(LLMError):
|
||||
"""Raised when LLM is not properly configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LLMAPIError(LLMError):
|
||||
"""Raised when LLM API returns an error."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Ollama LLM client with robust error handling."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError, RequestException, Timeout
|
||||
|
||||
from ..config import settings
|
||||
from .exceptions import LLMAPIError, LLMConfigurationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""
|
||||
Client for interacting with Ollama API.
|
||||
|
||||
Ollama runs locally and provides an OpenAI-compatible API.
|
||||
|
||||
Example:
|
||||
>>> client = OllamaClient(model="llama3.2")
|
||||
>>> messages = [{"role": "user", "content": "Hello!"}]
|
||||
>>> response = client.complete(messages)
|
||||
>>> print(response)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize Ollama client.
|
||||
|
||||
Args:
|
||||
base_url: Ollama API base URL (defaults to http://localhost:11434)
|
||||
model: Model name to use (e.g., "llama3.2", "mistral", "codellama")
|
||||
timeout: Request timeout in seconds (defaults to settings)
|
||||
temperature: Temperature for generation (defaults to settings)
|
||||
|
||||
Raises:
|
||||
LLMConfigurationError: If configuration is invalid
|
||||
"""
|
||||
self.base_url = base_url or os.getenv(
|
||||
"OLLAMA_BASE_URL", "http://localhost:11434"
|
||||
)
|
||||
self.model = model or os.getenv("OLLAMA_MODEL", "llama3.2")
|
||||
self.timeout = timeout or settings.request_timeout
|
||||
self.temperature = (
|
||||
temperature if temperature is not None else settings.temperature
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
raise LLMConfigurationError(
|
||||
"Ollama base URL is required. Set OLLAMA_BASE_URL environment variable."
|
||||
)
|
||||
|
||||
if not self.model:
|
||||
raise LLMConfigurationError(
|
||||
"Ollama model is required. Set OLLAMA_MODEL environment variable."
|
||||
)
|
||||
|
||||
logger.info(f"Ollama client initialized with model: {self.model}")
|
||||
|
||||
def complete( # noqa: PLR0912
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys
|
||||
tools: Optional list of tool specifications (OpenAI format)
|
||||
|
||||
Returns:
|
||||
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
|
||||
|
||||
Raises:
|
||||
LLMAPIError: If API request fails
|
||||
ValueError: If messages format is invalid
|
||||
"""
|
||||
# Validate messages format
|
||||
if not messages:
|
||||
raise ValueError("Messages list cannot be empty")
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
raise ValueError(f"Each message must be a dict, got {type(msg)}")
|
||||
if "role" not in msg:
|
||||
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
|
||||
# Allow system, user, assistant, and tool roles
|
||||
if msg["role"] not in ("system", "user", "assistant", "tool"):
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Content is optional for tool messages (they may have tool_call_id instead)
|
||||
if msg["role"] != "tool" and "content" not in msg:
|
||||
raise ValueError(
|
||||
f"Non-tool message must have 'content' key, got {msg.keys()}"
|
||||
)
|
||||
|
||||
url = f"{self.base_url}/api/chat"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
|
||||
)
|
||||
response = requests.post(url, json=payload, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Validate response structure
|
||||
if "message" not in data:
|
||||
raise LLMAPIError("Invalid API response: missing 'message'")
|
||||
|
||||
# Return the full message dict (OpenAI format)
|
||||
message = data["message"]
|
||||
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
|
||||
|
||||
return message
|
||||
|
||||
except Timeout as e:
|
||||
logger.error(f"Request timeout after {self.timeout}s: {e}")
|
||||
raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e
|
||||
|
||||
except HTTPError as e:
|
||||
logger.error(f"HTTP error from Ollama API: {e}")
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_msg = error_data.get("error", str(e))
|
||||
except Exception:
|
||||
error_msg = str(e)
|
||||
raise LLMAPIError(f"Ollama API error: {error_msg}") from e
|
||||
raise LLMAPIError(f"HTTP error: {e}") from e
|
||||
|
||||
except RequestException as e:
|
||||
logger.error(f"Request failed: {e}")
|
||||
raise LLMAPIError(f"Failed to connect to Ollama API: {e}") from e
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
logger.error(f"Failed to parse API response: {e}")
|
||||
raise LLMAPIError(f"Invalid API response format: {e}") from e
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""
|
||||
List available models in Ollama.
|
||||
|
||||
Returns:
|
||||
List of model names
|
||||
"""
|
||||
url = f"{self.base_url}/api/tags"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
models = [model["name"] for model in data.get("models", [])]
|
||||
logger.info(f"Found {len(models)} models: {models}")
|
||||
return models
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list models: {e}")
|
||||
return []
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Check if Ollama is running and accessible.
|
||||
|
||||
Returns:
|
||||
True if Ollama is available, False otherwise
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/api/tags"
|
||||
response = requests.get(url, timeout=5)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
@@ -0,0 +1,101 @@
|
||||
# agent/parameters.py
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterSchema:
|
||||
"""Describes a required parameter for the agent."""
|
||||
|
||||
key: str
|
||||
description: str
|
||||
why_needed: str # Explanation for the AI
|
||||
type: str # "string", "number", "object", etc.
|
||||
validator: Callable[[Any], bool] | None = None
|
||||
default: Any = None
|
||||
required: bool = True
|
||||
|
||||
|
||||
# Define all required parameters
|
||||
REQUIRED_PARAMETERS = [
|
||||
ParameterSchema(
|
||||
key="config",
|
||||
description="Configuration object containing all folder paths",
|
||||
why_needed=(
|
||||
"This contains the paths to all important folders:\n"
|
||||
"- download_folder: Where downloaded files arrive before being organized\n"
|
||||
"- tvshow_folder: Where TV show files are organized and stored\n"
|
||||
"- movie_folder: Where movie files are organized and stored\n"
|
||||
"- torrent_folder: Where torrent structures are saved for the torrent client"
|
||||
),
|
||||
type="object",
|
||||
validator=lambda x: isinstance(x, dict),
|
||||
required=True,
|
||||
default={},
|
||||
),
|
||||
ParameterSchema(
|
||||
key="tv_shows",
|
||||
description="List of TV shows the user is following",
|
||||
why_needed=(
|
||||
"This tracks which TV shows you're following. "
|
||||
"Each show includes: IMDB ID, title, number of seasons, and status (ongoing or ended)."
|
||||
),
|
||||
type="array",
|
||||
validator=lambda x: isinstance(x, list),
|
||||
required=False,
|
||||
default=[],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_parameter_schema(key: str) -> ParameterSchema | None:
|
||||
"""Get schema for a specific parameter."""
|
||||
for param in REQUIRED_PARAMETERS:
|
||||
if param.key == key:
|
||||
return param
|
||||
return None
|
||||
|
||||
|
||||
def get_missing_required_parameters(memory_data: dict) -> list[ParameterSchema]:
|
||||
"""Get list of required parameters that are missing or None."""
|
||||
missing = []
|
||||
for param in REQUIRED_PARAMETERS:
|
||||
if param.required:
|
||||
value = memory_data.get(param.key)
|
||||
if value is None:
|
||||
missing.append(param)
|
||||
return missing
|
||||
|
||||
|
||||
def format_parameters_for_prompt() -> str:
|
||||
"""Format parameter descriptions for the AI system prompt."""
|
||||
lines = ["REQUIRED PARAMETERS:"]
|
||||
for param in REQUIRED_PARAMETERS:
|
||||
status = "REQUIRED" if param.required else "OPTIONAL"
|
||||
lines.append(f"\n- {param.key} ({status}):")
|
||||
lines.append(f" Description: {param.description}")
|
||||
lines.append(f" Why needed: {param.why_needed}")
|
||||
lines.append(f" Type: {param.type}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def validate_parameter(key: str, value: Any) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate a parameter value against its schema.
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
schema = get_parameter_schema(key)
|
||||
if not schema:
|
||||
return True, None # Unknown parameters are allowed
|
||||
|
||||
if schema.validator:
|
||||
try:
|
||||
if not schema.validator(value):
|
||||
return False, f"Validation failed for {key}"
|
||||
except Exception as e:
|
||||
return False, f"Validation error for {key}: {str(e)}"
|
||||
|
||||
return True, None
|
||||
@@ -0,0 +1,180 @@
|
||||
"""Prompt builder for the agent system."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
from .registry import Tool
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""Builds system prompts for the agent with memory context."""
|
||||
|
||||
def __init__(self, tools: dict[str, Tool]):
|
||||
self.tools = tools
|
||||
|
||||
def build_tools_spec(self) -> list[dict[str, Any]]:
|
||||
"""Build the tool specification for the LLM API."""
|
||||
tool_specs = []
|
||||
for tool in self.tools.values():
|
||||
spec = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
},
|
||||
}
|
||||
tool_specs.append(spec)
|
||||
return tool_specs
|
||||
|
||||
def _format_tools_description(self) -> str:
|
||||
"""Format tools with their descriptions and parameters."""
|
||||
if not self.tools:
|
||||
return ""
|
||||
return "\n".join(
|
||||
f"- {tool.name}: {tool.description}\n"
|
||||
f" Parameters: {json.dumps(tool.parameters, ensure_ascii=False)}"
|
||||
for tool in self.tools.values()
|
||||
)
|
||||
|
||||
def _format_episodic_context(self, memory) -> str:
|
||||
"""Format episodic memory context for the prompt."""
|
||||
lines = []
|
||||
|
||||
if memory.episodic.last_search_results:
|
||||
results = memory.episodic.last_search_results
|
||||
result_list = results.get("results", [])
|
||||
lines.append(
|
||||
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
|
||||
)
|
||||
# Show first 5 results
|
||||
for i, result in enumerate(result_list[:5]):
|
||||
name = result.get("name", "Unknown")
|
||||
lines.append(f" {i+1}. {name}")
|
||||
if len(result_list) > 5:
|
||||
lines.append(f" ... and {len(result_list) - 5} more")
|
||||
|
||||
if memory.episodic.pending_question:
|
||||
question = memory.episodic.pending_question
|
||||
lines.append(f"\nPENDING QUESTION: {question.get('question')}")
|
||||
lines.append(f" Type: {question.get('type')}")
|
||||
if question.get("options"):
|
||||
lines.append(f" Options: {len(question.get('options'))}")
|
||||
|
||||
if memory.episodic.active_downloads:
|
||||
lines.append(f"\nACTIVE DOWNLOADS: {len(memory.episodic.active_downloads)}")
|
||||
for dl in memory.episodic.active_downloads[:3]:
|
||||
lines.append(f" - {dl.get('name')}: {dl.get('progress', 0)}%")
|
||||
|
||||
if memory.episodic.recent_errors:
|
||||
lines.append("\nRECENT ERRORS (up to 3):")
|
||||
for error in memory.episodic.recent_errors[-3:]:
|
||||
lines.append(
|
||||
f" - Action '{error.get('action')}' failed: {error.get('error')}"
|
||||
)
|
||||
|
||||
# Unread events
|
||||
unread = [e for e in memory.episodic.background_events if not e.get("read")]
|
||||
if unread:
|
||||
lines.append(f"\nUNREAD EVENTS: {len(unread)}")
|
||||
for event in unread[:3]:
|
||||
lines.append(f" - {event.get('type')}: {event.get('data')}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_stm_context(self, memory) -> str:
|
||||
"""Format short-term memory context for the prompt."""
|
||||
lines = []
|
||||
|
||||
if memory.stm.current_workflow:
|
||||
workflow = memory.stm.current_workflow
|
||||
lines.append(
|
||||
f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
|
||||
)
|
||||
if workflow.get("target"):
|
||||
lines.append(f" Target: {workflow.get('target')}")
|
||||
|
||||
if memory.stm.current_topic:
|
||||
lines.append(f"CURRENT TOPIC: {memory.stm.current_topic}")
|
||||
|
||||
if memory.stm.extracted_entities:
|
||||
lines.append("EXTRACTED ENTITIES:")
|
||||
for key, value in memory.stm.extracted_entities.items():
|
||||
lines.append(f" - {key}: {value}")
|
||||
|
||||
if memory.stm.language:
|
||||
lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_config_context(self, memory) -> str:
|
||||
"""Format configuration context."""
|
||||
lines = ["CURRENT CONFIGURATION:"]
|
||||
if memory.ltm.config:
|
||||
for key, value in memory.ltm.config.items():
|
||||
lines.append(f" - {key}: {value}")
|
||||
else:
|
||||
lines.append(" (no configuration set)")
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Build the complete system prompt."""
|
||||
# Get memory once for all context formatting
|
||||
memory = get_memory()
|
||||
|
||||
# Base instruction
|
||||
base = "You are a helpful AI assistant for managing a media library."
|
||||
|
||||
# Language instruction
|
||||
language_instruction = (
|
||||
"Your first task is to determine the user's language from their message "
|
||||
"and use the `set_language` tool if it's different from the current one. "
|
||||
"After that, proceed to help the user."
|
||||
)
|
||||
|
||||
# Available tools
|
||||
tools_desc = self._format_tools_description()
|
||||
tools_section = f"\nAVAILABLE TOOLS:\n{tools_desc}" if tools_desc else ""
|
||||
|
||||
# Configuration
|
||||
config_section = self._format_config_context(memory)
|
||||
if config_section:
|
||||
config_section = f"\n{config_section}"
|
||||
|
||||
# STM context
|
||||
stm_context = self._format_stm_context(memory)
|
||||
if stm_context:
|
||||
stm_context = f"\n{stm_context}"
|
||||
|
||||
# Episodic context
|
||||
episodic_context = self._format_episodic_context(memory)
|
||||
|
||||
# Important rules
|
||||
rules = """
|
||||
IMPORTANT RULES:
|
||||
- Use tools to accomplish tasks
|
||||
- When search results are available, reference them by index (e.g., "add_torrent_by_index")
|
||||
- Always confirm actions with the user before executing destructive operations
|
||||
- Provide clear, concise responses
|
||||
"""
|
||||
|
||||
# Examples
|
||||
examples = """
|
||||
EXAMPLES:
|
||||
- User: "Find Inception" → Use find_media_imdb_id, then find_torrent
|
||||
- User: "download the 3rd one" → Use add_torrent_by_index with index=3
|
||||
- User: "List my downloads" → Use list_folder with folder_type="download"
|
||||
"""
|
||||
|
||||
return f"""{base}
|
||||
|
||||
{language_instruction}
|
||||
{tools_section}
|
||||
{config_section}
|
||||
{stm_context}
|
||||
{episodic_context}
|
||||
{rules}
|
||||
{examples}
|
||||
"""
|
||||
@@ -0,0 +1,112 @@
|
||||
"""Tool registry - defines and registers all available tools for the agent."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
"""Represents a tool that can be used by the agent."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
func: Callable[..., dict[str, Any]]
|
||||
parameters: dict[str, Any]
|
||||
|
||||
|
||||
def _create_tool_from_function(func: Callable) -> Tool:
|
||||
"""
|
||||
Create a Tool object from a function.
|
||||
|
||||
Args:
|
||||
func: Function to convert to a tool
|
||||
|
||||
Returns:
|
||||
Tool object with metadata extracted from function
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
doc = inspect.getdoc(func)
|
||||
|
||||
# Extract description from docstring (first line)
|
||||
description = doc.strip().split("\n")[0] if doc else func.__name__
|
||||
|
||||
# Build JSON schema from function signature
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
# Map Python types to JSON schema types
|
||||
param_type = "string" # default
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
if param.annotation is str:
|
||||
param_type = "string"
|
||||
elif param.annotation is int:
|
||||
param_type = "integer"
|
||||
elif param.annotation is float:
|
||||
param_type = "number"
|
||||
elif param.annotation is bool:
|
||||
param_type = "boolean"
|
||||
|
||||
properties[param_name] = {
|
||||
"type": param_type,
|
||||
"description": f"Parameter {param_name}",
|
||||
}
|
||||
|
||||
# Add to required if no default value
|
||||
if param.default == inspect.Parameter.empty:
|
||||
required.append(param_name)
|
||||
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
return Tool(
|
||||
name=func.__name__,
|
||||
description=description,
|
||||
func=func,
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
|
||||
def make_tools() -> dict[str, Tool]:
|
||||
"""
|
||||
Create and register all available tools.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tool names to Tool objects
|
||||
"""
|
||||
# Import tools here to avoid circular dependencies
|
||||
from .tools import api as api_tools # noqa: PLC0415
|
||||
from .tools import filesystem as fs_tools # noqa: PLC0415
|
||||
from .tools import language as lang_tools # noqa: PLC0415
|
||||
|
||||
# List of all tool functions
|
||||
tool_functions = [
|
||||
fs_tools.set_path_for_folder,
|
||||
fs_tools.list_folder,
|
||||
api_tools.find_media_imdb_id,
|
||||
api_tools.find_torrent,
|
||||
api_tools.add_torrent_by_index,
|
||||
api_tools.add_torrent_to_qbittorrent,
|
||||
api_tools.get_torrent_by_index,
|
||||
lang_tools.set_language,
|
||||
]
|
||||
|
||||
# Create Tool objects from functions
|
||||
tools = {}
|
||||
for func in tool_functions:
|
||||
tool = _create_tool_from_function(func)
|
||||
tools[tool.name] = tool
|
||||
|
||||
logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}")
|
||||
return tools
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Tools module - filesystem and API tools for the agent."""
|
||||
|
||||
from .api import (
|
||||
add_torrent_by_index,
|
||||
add_torrent_to_qbittorrent,
|
||||
find_media_imdb_id,
|
||||
find_torrent,
|
||||
get_torrent_by_index,
|
||||
)
|
||||
from .filesystem import list_folder, set_path_for_folder
|
||||
from .language import set_language
|
||||
|
||||
__all__ = [
|
||||
"set_path_for_folder",
|
||||
"list_folder",
|
||||
"find_media_imdb_id",
|
||||
"find_torrent",
|
||||
"get_torrent_by_index",
|
||||
"add_torrent_to_qbittorrent",
|
||||
"add_torrent_by_index",
|
||||
"set_language",
|
||||
]
|
||||
@@ -0,0 +1,196 @@
|
||||
"""API tools for interacting with external services."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from application.movies import SearchMovieUseCase
|
||||
from application.torrents import AddTorrentUseCase, SearchTorrentsUseCase
|
||||
from infrastructure.api.knaben import knaben_client
|
||||
from infrastructure.api.qbittorrent import qbittorrent_client
|
||||
from infrastructure.api.tmdb import tmdb_client
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_media_imdb_id(media_title: str) -> dict[str, Any]:
|
||||
"""
|
||||
Find the IMDb ID for a given media title using TMDB API.
|
||||
|
||||
Args:
|
||||
media_title: Title of the media to search for.
|
||||
|
||||
Returns:
|
||||
Dict with IMDb ID and media info, or error details.
|
||||
"""
|
||||
use_case = SearchMovieUseCase(tmdb_client)
|
||||
response = use_case.execute(media_title)
|
||||
result = response.to_dict()
|
||||
|
||||
if result.get("status") == "ok":
|
||||
memory = get_memory()
|
||||
memory.stm.set_entity(
|
||||
"last_media_search",
|
||||
{
|
||||
"title": result.get("title"),
|
||||
"imdb_id": result.get("imdb_id"),
|
||||
"media_type": result.get("media_type"),
|
||||
"tmdb_id": result.get("tmdb_id"),
|
||||
},
|
||||
)
|
||||
memory.stm.set_topic("searching_media")
|
||||
logger.debug(f"Stored media search result in STM: {result.get('title')}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def find_torrent(media_title: str) -> dict[str, Any]:
|
||||
"""
|
||||
Find torrents for a given media title using Knaben API.
|
||||
|
||||
Results are stored in episodic memory so the user can reference them
|
||||
by index (e.g., "download the 3rd one").
|
||||
|
||||
Args:
|
||||
media_title: Title of the media to search for.
|
||||
|
||||
Returns:
|
||||
Dict with torrent list or error details.
|
||||
"""
|
||||
logger.info(f"Searching torrents for: {media_title}")
|
||||
|
||||
use_case = SearchTorrentsUseCase(knaben_client)
|
||||
response = use_case.execute(media_title, limit=10)
|
||||
result = response.to_dict()
|
||||
|
||||
if result.get("status") == "ok":
|
||||
memory = get_memory()
|
||||
torrents = result.get("torrents", [])
|
||||
memory.episodic.store_search_results(
|
||||
query=media_title, results=torrents, search_type="torrent"
|
||||
)
|
||||
memory.stm.set_topic("selecting_torrent")
|
||||
logger.info(f"Stored {len(torrents)} torrent results in episodic memory")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_torrent_by_index(index: int) -> dict[str, Any]:
|
||||
"""
|
||||
Get a torrent from the last search results by its index.
|
||||
|
||||
Allows the user to reference results by number after a search.
|
||||
|
||||
Args:
|
||||
index: 1-based index of the torrent in the search results.
|
||||
|
||||
Returns:
|
||||
Dict with torrent data or error if not found.
|
||||
"""
|
||||
logger.info(f"Getting torrent at index: {index}")
|
||||
|
||||
memory = get_memory()
|
||||
|
||||
if memory.episodic.last_search_results:
|
||||
results_count = len(memory.episodic.last_search_results.get("results", []))
|
||||
query = memory.episodic.last_search_results.get("query", "unknown")
|
||||
logger.debug(f"Episodic memory has {results_count} results from: {query}")
|
||||
else:
|
||||
logger.warning("No search results in episodic memory")
|
||||
|
||||
result = memory.episodic.get_result_by_index(index)
|
||||
|
||||
if result:
|
||||
logger.info(f"Found torrent at index {index}: {result.get('name', 'unknown')}")
|
||||
return {"status": "ok", "torrent": result}
|
||||
|
||||
logger.warning(f"No torrent found at index {index}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "not_found",
|
||||
"message": f"No torrent found at index {index}. Search for torrents first.",
|
||||
}
|
||||
|
||||
|
||||
def add_torrent_to_qbittorrent(magnet_link: str) -> dict[str, Any]:
|
||||
"""
|
||||
Add a torrent to qBittorrent using a magnet link.
|
||||
|
||||
Args:
|
||||
magnet_link: Magnet link of the torrent to add.
|
||||
|
||||
Returns:
|
||||
Dict with success status or error details.
|
||||
"""
|
||||
logger.info("Adding torrent to qBittorrent")
|
||||
|
||||
use_case = AddTorrentUseCase(qbittorrent_client)
|
||||
response = use_case.execute(magnet_link)
|
||||
result = response.to_dict()
|
||||
|
||||
if result.get("status") == "ok":
|
||||
memory = get_memory()
|
||||
last_search = memory.episodic.get_search_results()
|
||||
torrent_name = "Unknown"
|
||||
|
||||
if last_search:
|
||||
for t in last_search.get("results", []):
|
||||
if t.get("magnet") == magnet_link:
|
||||
torrent_name = t.get("name", "Unknown")
|
||||
break
|
||||
|
||||
memory.episodic.add_active_download(
|
||||
{
|
||||
"task_id": magnet_link[:20],
|
||||
"name": torrent_name,
|
||||
"magnet": magnet_link,
|
||||
"progress": 0,
|
||||
"status": "queued",
|
||||
}
|
||||
)
|
||||
|
||||
memory.stm.set_topic("downloading")
|
||||
memory.stm.end_workflow()
|
||||
logger.info(f"Added download to episodic memory: {torrent_name}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_torrent_by_index(index: int) -> dict[str, Any]:
|
||||
"""
|
||||
Add a torrent from the last search results by its index.
|
||||
|
||||
Combines get_torrent_by_index and add_torrent_to_qbittorrent.
|
||||
|
||||
Args:
|
||||
index: 1-based index of the torrent in the search results.
|
||||
|
||||
Returns:
|
||||
Dict with success status or error details.
|
||||
"""
|
||||
logger.info(f"Adding torrent by index: {index}")
|
||||
|
||||
torrent_result = get_torrent_by_index(index)
|
||||
|
||||
if torrent_result.get("status") != "ok":
|
||||
return torrent_result
|
||||
|
||||
torrent = torrent_result.get("torrent", {})
|
||||
magnet = torrent.get("magnet")
|
||||
|
||||
if not magnet:
|
||||
logger.error("Torrent has no magnet link")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "no_magnet",
|
||||
"message": "The selected torrent has no magnet link",
|
||||
}
|
||||
|
||||
logger.info(f"Adding torrent: {torrent.get('name', 'unknown')}")
|
||||
|
||||
result = add_torrent_to_qbittorrent(magnet)
|
||||
|
||||
if result.get("status") == "ok":
|
||||
result["torrent_name"] = torrent.get("name", "Unknown")
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Filesystem tools for folder management."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from application.filesystem import ListFolderUseCase, SetFolderPathUseCase
|
||||
from infrastructure.filesystem import FileManager
|
||||
|
||||
|
||||
def set_path_for_folder(folder_name: str, path_value: str) -> dict[str, Any]:
|
||||
"""
|
||||
Set a folder path in the configuration.
|
||||
|
||||
Args:
|
||||
folder_name: Name of folder to set (download, tvshow, movie, torrent).
|
||||
path_value: Absolute path to the folder.
|
||||
|
||||
Returns:
|
||||
Dict with status or error information.
|
||||
"""
|
||||
file_manager = FileManager()
|
||||
use_case = SetFolderPathUseCase(file_manager)
|
||||
response = use_case.execute(folder_name, path_value)
|
||||
return response.to_dict()
|
||||
|
||||
|
||||
def list_folder(folder_type: str, path: str = ".") -> dict[str, Any]:
|
||||
"""
|
||||
List contents of a configured folder.
|
||||
|
||||
Args:
|
||||
folder_type: Type of folder to list (download, tvshow, movie, torrent).
|
||||
path: Relative path within the folder (default: root).
|
||||
|
||||
Returns:
|
||||
Dict with folder contents or error information.
|
||||
"""
|
||||
file_manager = FileManager()
|
||||
use_case = ListFolderUseCase(file_manager)
|
||||
response = use_case.execute(folder_type, path)
|
||||
return response.to_dict()
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Language management tools for the agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_language(language: str) -> dict[str, Any]:
|
||||
"""
|
||||
Set the conversation language.
|
||||
|
||||
Args:
|
||||
language: Language code (e.g., 'en', 'fr', 'es', 'de')
|
||||
|
||||
Returns:
|
||||
Status dictionary
|
||||
"""
|
||||
try:
|
||||
memory = get_memory()
|
||||
memory.stm.set_language(language)
|
||||
memory.save()
|
||||
|
||||
logger.info(f"Language set to: {language}")
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"Language set to {language}",
|
||||
"language": language,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set language: {e}")
|
||||
return {"status": "error", "error": str(e)}
|
||||
Reference in New Issue
Block a user