Updated folder structure (for Docker)

This commit is contained in:
2025-12-09 05:35:59 +01:00
parent 6940c76e58
commit ec7d2d623f
108 changed files with 0 additions and 0 deletions
+6
View File
@@ -0,0 +1,6 @@
"""Agent module for media library management."""
from .agent import Agent
from .config import settings
__all__ = ["Agent", "settings"]
+371
View File
@@ -0,0 +1,371 @@
"""Main agent for media library management."""
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any
from infrastructure.persistence import get_memory
from .config import settings
from .prompts import PromptBuilder
from .registry import Tool, make_tools
logger = logging.getLogger(__name__)
class Agent:
"""
AI agent for media library management.
Uses OpenAI-compatible tool calling API.
"""
def __init__(self, llm, max_tool_iterations: int = 5):
"""
Initialize the agent.
Args:
llm: LLM client with complete() method
max_tool_iterations: Maximum number of tool execution iterations
"""
self.llm = llm
self.tools: dict[str, Tool] = make_tools()
self.prompt_builder = PromptBuilder(self.tools)
self.max_tool_iterations = max_tool_iterations
def step(self, user_input: str) -> str:
"""
Execute one agent step with the user input.
This method:
1. Adds user message to memory
2. Builds prompt with history and context
3. Calls LLM, executing tools as needed
4. Returns final response
Args:
user_input: User's message
Returns:
Agent's final response
"""
memory = get_memory()
# Add user message to history
memory.stm.add_message("user", user_input)
memory.save()
# Build initial messages
system_prompt = self.prompt_builder.build_system_prompt()
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
# Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages)
messages.extend(history)
# Add unread events if any
unread_events = memory.episodic.get_unread_events()
if unread_events:
events_text = "\n".join(
[f"- {e['type']}: {e['data']}" for e in unread_events]
)
messages.append(
{"role": "system", "content": f"Background events:\n{events_text}"}
)
# Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec()
# Tool execution loop
for _iteration in range(self.max_tool_iterations):
# Call LLM with tools
llm_result = self.llm.complete(messages, tools=tools_spec)
# Handle both tuple (response, usage) and dict response
if isinstance(llm_result, tuple):
response_message, usage = llm_result
else:
response_message = llm_result
# Check if there are tool calls
tool_calls = response_message.get("tool_calls")
if not tool_calls:
# No tool calls, this is the final response
final_content = response_message.get("content", "")
memory.stm.add_message("assistant", final_content)
memory.save()
return final_content
# Add assistant message with tool calls to conversation
messages.append(response_message)
# Execute each tool call
for tool_call in tool_calls:
tool_result = self._execute_tool_call(tool_call)
# Add tool result to messages
messages.append(
{
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_call.get("function", {}).get("name"),
"content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Max iterations reached, force final response
messages.append(
{
"role": "system",
"content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple):
final_message, usage = llm_result
else:
final_message = llm_result
final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response)
memory.save()
return final_response
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
"""
Execute a single tool call.
Args:
tool_call: OpenAI-format tool call dict
Returns:
Result dictionary
"""
function = tool_call.get("function", {})
tool_name = function.get("name", "")
try:
args_str = function.get("arguments", "{}")
args = json.loads(args_str)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {e}")
return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
# Validate tool exists
if tool_name not in self.tools:
available = list(self.tools.keys())
return {
"error": "unknown_tool",
"message": f"Tool '{tool_name}' not found",
"available_tools": available,
}
tool = self.tools[tool_name]
# Execute tool
try:
result = tool.func(**args)
return result
except KeyboardInterrupt:
# Don't catch KeyboardInterrupt - let it propagate
raise
except TypeError as e:
# Bad arguments
memory = get_memory()
memory.episodic.add_error(tool_name, f"bad_args: {e}")
return {"error": "bad_args", "message": str(e), "tool": tool_name}
except Exception as e:
# Other errors
memory = get_memory()
memory.episodic.add_error(tool_name, str(e))
return {"error": "execution_failed", "message": str(e), "tool": tool_name}
async def step_streaming(
self, user_input: str, completion_id: str, created_ts: int, model: str
) -> AsyncGenerator[dict[str, Any], None]:
"""
Execute agent step with streaming support for LibreChat.
Yields SSE chunks for tool calls and final response.
Args:
user_input: User's message
completion_id: Completion ID for the response
created_ts: Timestamp for the response
model: Model name
Yields:
SSE chunks in OpenAI format
"""
memory = get_memory()
# Add user message to history
memory.stm.add_message("user", user_input)
memory.save()
# Build initial messages
system_prompt = self.prompt_builder.build_system_prompt()
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
# Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages)
messages.extend(history)
# Add unread events if any
unread_events = memory.episodic.get_unread_events()
if unread_events:
events_text = "\n".join(
[f"- {e['type']}: {e['data']}" for e in unread_events]
)
messages.append(
{"role": "system", "content": f"Background events:\n{events_text}"}
)
# Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec()
# Tool execution loop
for _iteration in range(self.max_tool_iterations):
# Call LLM with tools
llm_result = self.llm.complete(messages, tools=tools_spec)
# Handle both tuple (response, usage) and dict response
if isinstance(llm_result, tuple):
response_message, usage = llm_result
else:
response_message = llm_result
# Check if there are tool calls
tool_calls = response_message.get("tool_calls")
if not tool_calls:
# No tool calls, this is the final response
final_content = response_message.get("content", "")
memory.stm.add_message("assistant", final_content)
memory.save()
# Stream the final response
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": final_content},
"finish_reason": "stop",
}
],
}
return
# Stream tool calls
for tool_call in tool_calls:
function = tool_call.get("function", {})
tool_name = function.get("name", "")
tool_args = function.get("arguments", "{}")
# Yield chunk indicating tool call
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
"id": tool_call.get("id"),
"type": "function",
"function": {
"name": tool_name,
"arguments": tool_args,
},
}
]
},
"finish_reason": None,
}
],
}
# Add assistant message with tool calls to conversation
messages.append(response_message)
# Execute each tool call and stream results
for tool_call in tool_calls:
tool_result = self._execute_tool_call(tool_call)
function = tool_call.get("function", {})
tool_name = function.get("name", "")
# Add tool result to messages
messages.append(
{
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_name,
"content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Stream tool result as content
result_text = (
f"\n🔧 {tool_name}: {json.dumps(tool_result, ensure_ascii=False)}\n"
)
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": result_text},
"finish_reason": None,
}
],
}
# Max iterations reached, force final response
messages.append(
{
"role": "system",
"content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple):
final_message, usage = llm_result
else:
final_message = llm_result
final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response)
memory.save()
# Stream final response
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": final_response},
"finish_reason": "stop",
}
],
}
+115
View File
@@ -0,0 +1,115 @@
"""Configuration management with validation."""
import os
from dataclasses import dataclass, field
from pathlib import Path
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
class ConfigurationError(Exception):
"""Raised when configuration is invalid."""
pass
@dataclass
class Settings:
"""Application settings loaded from environment variables."""
# LLM Configuration
deepseek_api_key: str = field(
default_factory=lambda: os.getenv("DEEPSEEK_API_KEY", "")
)
deepseek_base_url: str = field(
default_factory=lambda: os.getenv(
"DEEPSEEK_BASE_URL", "https://api.deepseek.com"
)
)
model: str = field(
default_factory=lambda: os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
)
temperature: float = field(
default_factory=lambda: float(os.getenv("TEMPERATURE", "0.2"))
)
# TMDB Configuration
tmdb_api_key: str = field(default_factory=lambda: os.getenv("TMDB_API_KEY", ""))
tmdb_base_url: str = field(
default_factory=lambda: os.getenv(
"TMDB_BASE_URL", "https://api.themoviedb.org/3"
)
)
# Storage Configuration
memory_file: str = field(
default_factory=lambda: os.getenv("MEMORY_FILE", "memory.json")
)
# Security Configuration
max_tool_iterations: int = field(
default_factory=lambda: int(os.getenv("MAX_TOOL_ITERATIONS", "5"))
)
request_timeout: int = field(
default_factory=lambda: int(os.getenv("REQUEST_TIMEOUT", "30"))
)
# Memory Configuration
max_history_messages: int = field(
default_factory=lambda: int(os.getenv("MAX_HISTORY_MESSAGES", "10"))
)
def __post_init__(self):
"""Validate settings after initialization."""
self._validate()
def _validate(self) -> None:
"""Validate configuration values."""
# Validate temperature
if not 0.0 <= self.temperature <= 2.0:
raise ConfigurationError(
f"Temperature must be between 0.0 and 2.0, got {self.temperature}"
)
# Validate max_tool_iterations
if self.max_tool_iterations < 1 or self.max_tool_iterations > 20:
raise ConfigurationError(
f"max_tool_iterations must be between 1 and 20, got {self.max_tool_iterations}"
)
# Validate request_timeout
if self.request_timeout < 1 or self.request_timeout > 300:
raise ConfigurationError(
f"request_timeout must be between 1 and 300 seconds, got {self.request_timeout}"
)
# Validate URLs
if not self.deepseek_base_url.startswith(("http://", "https://")):
raise ConfigurationError(
f"Invalid deepseek_base_url: {self.deepseek_base_url}"
)
if not self.tmdb_base_url.startswith(("http://", "https://")):
raise ConfigurationError(f"Invalid tmdb_base_url: {self.tmdb_base_url}")
# Validate memory file path
memory_path = Path(self.memory_file)
if memory_path.exists() and not memory_path.is_file():
raise ConfigurationError(
f"memory_file exists but is not a file: {self.memory_file}"
)
def is_deepseek_configured(self) -> bool:
"""Check if DeepSeek API is properly configured."""
return bool(self.deepseek_api_key and self.deepseek_base_url)
def is_tmdb_configured(self) -> bool:
"""Check if TMDB API is properly configured."""
return bool(self.tmdb_api_key and self.tmdb_base_url)
# Global settings instance
settings = Settings()
+13
View File
@@ -0,0 +1,13 @@
"""LLM clients module."""
from .deepseek import DeepSeekClient
from .exceptions import LLMAPIError, LLMConfigurationError, LLMError
from .ollama import OllamaClient
__all__ = [
"DeepSeekClient",
"OllamaClient",
"LLMError",
"LLMAPIError",
"LLMConfigurationError",
]
+148
View File
@@ -0,0 +1,148 @@
"""DeepSeek LLM client with robust error handling."""
import logging
from typing import Any
import requests
from requests.exceptions import HTTPError, RequestException, Timeout
from ..config import settings
from .exceptions import LLMAPIError, LLMConfigurationError
logger = logging.getLogger(__name__)
class DeepSeekClient:
"""Client for interacting with DeepSeek API."""
def __init__(
self,
api_key: str | None = None,
base_url: str | None = None,
model: str | None = None,
timeout: int | None = None,
):
"""
Initialize DeepSeek client.
Args:
api_key: API key for authentication (defaults to settings)
base_url: Base URL for API (defaults to settings)
model: Model name to use (defaults to settings)
timeout: Request timeout in seconds (defaults to settings)
Raises:
LLMConfigurationError: If API key is missing
"""
self.api_key = api_key or settings.deepseek_api_key
self.base_url = base_url or settings.deepseek_base_url
self.model = model or settings.model
self.timeout = timeout or settings.request_timeout
if not self.api_key:
raise LLMConfigurationError(
"DeepSeek API key is required. Set DEEPSEEK_API_KEY environment variable."
)
if not self.base_url:
raise LLMConfigurationError(
"DeepSeek base URL is required. Set DEEPSEEK_BASE_URL environment variable."
)
logger.info(f"DeepSeek client initialized with model: {self.model}")
def complete( # noqa: PLR0912
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
Args:
messages: List of message dicts with 'role' and 'content' keys
tools: Optional list of tool specifications (OpenAI format)
Returns:
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
Raises:
LLMAPIError: If API request fails
ValueError: If messages format is invalid
"""
# Validate messages format
if not messages:
raise ValueError("Messages list cannot be empty")
for msg in messages:
if not isinstance(msg, dict):
raise ValueError(f"Each message must be a dict, got {type(msg)}")
if "role" not in msg:
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
# Allow system, user, assistant, and tool roles
if msg["role"] not in ("system", "user", "assistant", "tool"):
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/v1/chat/completions"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"messages": messages,
"temperature": settings.temperature,
}
# Add tools if provided
if tools:
payload["tools"] = tools
try:
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(
url, headers=headers, json=payload, timeout=self.timeout
)
response.raise_for_status()
data = response.json()
# Validate response structure
if "choices" not in data or not data["choices"]:
raise LLMAPIError("Invalid API response: missing 'choices'")
if "message" not in data["choices"][0]:
raise LLMAPIError("Invalid API response: missing 'message' in choice")
# Return the full message dict (OpenAI format)
message = data["choices"][0]["message"]
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
return message
except Timeout as e:
logger.error(f"Request timeout after {self.timeout}s: {e}")
raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e
except HTTPError as e:
logger.error(f"HTTP error from DeepSeek API: {e}")
if e.response is not None:
try:
error_data = e.response.json()
error_msg = error_data.get("error", {}).get("message", str(e))
except Exception:
error_msg = str(e)
raise LLMAPIError(f"DeepSeek API error: {error_msg}") from e
raise LLMAPIError(f"HTTP error: {e}") from e
except RequestException as e:
logger.error(f"Request failed: {e}")
raise LLMAPIError(f"Failed to connect to DeepSeek API: {e}") from e
except (KeyError, IndexError, TypeError) as e:
logger.error(f"Failed to parse API response: {e}")
raise LLMAPIError(f"Invalid API response format: {e}") from e
+19
View File
@@ -0,0 +1,19 @@
"""LLM-related exceptions."""
class LLMError(Exception):
"""Base exception for LLM-related errors."""
pass
class LLMConfigurationError(LLMError):
"""Raised when LLM is not properly configured."""
pass
class LLMAPIError(LLMError):
"""Raised when LLM API returns an error."""
pass
+193
View File
@@ -0,0 +1,193 @@
"""Ollama LLM client with robust error handling."""
import logging
import os
from typing import Any
import requests
from requests.exceptions import HTTPError, RequestException, Timeout
from ..config import settings
from .exceptions import LLMAPIError, LLMConfigurationError
logger = logging.getLogger(__name__)
class OllamaClient:
"""
Client for interacting with Ollama API.
Ollama runs locally and provides an OpenAI-compatible API.
Example:
>>> client = OllamaClient(model="llama3.2")
>>> messages = [{"role": "user", "content": "Hello!"}]
>>> response = client.complete(messages)
>>> print(response)
"""
def __init__(
self,
base_url: str | None = None,
model: str | None = None,
timeout: int | None = None,
temperature: float | None = None,
):
"""
Initialize Ollama client.
Args:
base_url: Ollama API base URL (defaults to http://localhost:11434)
model: Model name to use (e.g., "llama3.2", "mistral", "codellama")
timeout: Request timeout in seconds (defaults to settings)
temperature: Temperature for generation (defaults to settings)
Raises:
LLMConfigurationError: If configuration is invalid
"""
self.base_url = base_url or os.getenv(
"OLLAMA_BASE_URL", "http://localhost:11434"
)
self.model = model or os.getenv("OLLAMA_MODEL", "llama3.2")
self.timeout = timeout or settings.request_timeout
self.temperature = (
temperature if temperature is not None else settings.temperature
)
if not self.base_url:
raise LLMConfigurationError(
"Ollama base URL is required. Set OLLAMA_BASE_URL environment variable."
)
if not self.model:
raise LLMConfigurationError(
"Ollama model is required. Set OLLAMA_MODEL environment variable."
)
logger.info(f"Ollama client initialized with model: {self.model}")
def complete( # noqa: PLR0912
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
"""
Generate a completion from the LLM.
Args:
messages: List of message dicts with 'role' and 'content' keys
tools: Optional list of tool specifications (OpenAI format)
Returns:
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
Raises:
LLMAPIError: If API request fails
ValueError: If messages format is invalid
"""
# Validate messages format
if not messages:
raise ValueError("Messages list cannot be empty")
for msg in messages:
if not isinstance(msg, dict):
raise ValueError(f"Each message must be a dict, got {type(msg)}")
if "role" not in msg:
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
# Allow system, user, assistant, and tool roles
if msg["role"] not in ("system", "user", "assistant", "tool"):
raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg:
raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/api/chat"
payload = {
"model": self.model,
"messages": messages,
"stream": False,
"options": {
"temperature": self.temperature,
},
}
# Add tools if provided
if tools:
payload["tools"] = tools
try:
logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(url, json=payload, timeout=self.timeout)
response.raise_for_status()
data = response.json()
# Validate response structure
if "message" not in data:
raise LLMAPIError("Invalid API response: missing 'message'")
# Return the full message dict (OpenAI format)
message = data["message"]
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
return message
except Timeout as e:
logger.error(f"Request timeout after {self.timeout}s: {e}")
raise LLMAPIError(f"Request timeout after {self.timeout} seconds") from e
except HTTPError as e:
logger.error(f"HTTP error from Ollama API: {e}")
if e.response is not None:
try:
error_data = e.response.json()
error_msg = error_data.get("error", str(e))
except Exception:
error_msg = str(e)
raise LLMAPIError(f"Ollama API error: {error_msg}") from e
raise LLMAPIError(f"HTTP error: {e}") from e
except RequestException as e:
logger.error(f"Request failed: {e}")
raise LLMAPIError(f"Failed to connect to Ollama API: {e}") from e
except (KeyError, IndexError, TypeError) as e:
logger.error(f"Failed to parse API response: {e}")
raise LLMAPIError(f"Invalid API response format: {e}") from e
def list_models(self) -> list[str]:
"""
List available models in Ollama.
Returns:
List of model names
"""
url = f"{self.base_url}/api/tags"
try:
response = requests.get(url, timeout=self.timeout)
response.raise_for_status()
data = response.json()
models = [model["name"] for model in data.get("models", [])]
logger.info(f"Found {len(models)} models: {models}")
return models
except Exception as e:
logger.error(f"Failed to list models: {e}")
return []
def is_available(self) -> bool:
"""
Check if Ollama is running and accessible.
Returns:
True if Ollama is available, False otherwise
"""
try:
url = f"{self.base_url}/api/tags"
response = requests.get(url, timeout=5)
return response.status_code == 200
except Exception:
return False
+101
View File
@@ -0,0 +1,101 @@
# agent/parameters.py
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@dataclass
class ParameterSchema:
"""Describes a required parameter for the agent."""
key: str
description: str
why_needed: str # Explanation for the AI
type: str # "string", "number", "object", etc.
validator: Callable[[Any], bool] | None = None
default: Any = None
required: bool = True
# Define all required parameters
REQUIRED_PARAMETERS = [
ParameterSchema(
key="config",
description="Configuration object containing all folder paths",
why_needed=(
"This contains the paths to all important folders:\n"
"- download_folder: Where downloaded files arrive before being organized\n"
"- tvshow_folder: Where TV show files are organized and stored\n"
"- movie_folder: Where movie files are organized and stored\n"
"- torrent_folder: Where torrent structures are saved for the torrent client"
),
type="object",
validator=lambda x: isinstance(x, dict),
required=True,
default={},
),
ParameterSchema(
key="tv_shows",
description="List of TV shows the user is following",
why_needed=(
"This tracks which TV shows you're following. "
"Each show includes: IMDB ID, title, number of seasons, and status (ongoing or ended)."
),
type="array",
validator=lambda x: isinstance(x, list),
required=False,
default=[],
),
]
def get_parameter_schema(key: str) -> ParameterSchema | None:
"""Get schema for a specific parameter."""
for param in REQUIRED_PARAMETERS:
if param.key == key:
return param
return None
def get_missing_required_parameters(memory_data: dict) -> list[ParameterSchema]:
"""Get list of required parameters that are missing or None."""
missing = []
for param in REQUIRED_PARAMETERS:
if param.required:
value = memory_data.get(param.key)
if value is None:
missing.append(param)
return missing
def format_parameters_for_prompt() -> str:
"""Format parameter descriptions for the AI system prompt."""
lines = ["REQUIRED PARAMETERS:"]
for param in REQUIRED_PARAMETERS:
status = "REQUIRED" if param.required else "OPTIONAL"
lines.append(f"\n- {param.key} ({status}):")
lines.append(f" Description: {param.description}")
lines.append(f" Why needed: {param.why_needed}")
lines.append(f" Type: {param.type}")
return "\n".join(lines)
def validate_parameter(key: str, value: Any) -> tuple[bool, str | None]:
"""
Validate a parameter value against its schema.
Returns:
(is_valid, error_message)
"""
schema = get_parameter_schema(key)
if not schema:
return True, None # Unknown parameters are allowed
if schema.validator:
try:
if not schema.validator(value):
return False, f"Validation failed for {key}"
except Exception as e:
return False, f"Validation error for {key}: {str(e)}"
return True, None
+180
View File
@@ -0,0 +1,180 @@
"""Prompt builder for the agent system."""
import json
from typing import Any
from infrastructure.persistence import get_memory
from .registry import Tool
class PromptBuilder:
"""Builds system prompts for the agent with memory context."""
def __init__(self, tools: dict[str, Tool]):
self.tools = tools
def build_tools_spec(self) -> list[dict[str, Any]]:
"""Build the tool specification for the LLM API."""
tool_specs = []
for tool in self.tools.values():
spec = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
tool_specs.append(spec)
return tool_specs
def _format_tools_description(self) -> str:
"""Format tools with their descriptions and parameters."""
if not self.tools:
return ""
return "\n".join(
f"- {tool.name}: {tool.description}\n"
f" Parameters: {json.dumps(tool.parameters, ensure_ascii=False)}"
for tool in self.tools.values()
)
def _format_episodic_context(self, memory) -> str:
"""Format episodic memory context for the prompt."""
lines = []
if memory.episodic.last_search_results:
results = memory.episodic.last_search_results
result_list = results.get("results", [])
lines.append(
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
)
# Show first 5 results
for i, result in enumerate(result_list[:5]):
name = result.get("name", "Unknown")
lines.append(f" {i+1}. {name}")
if len(result_list) > 5:
lines.append(f" ... and {len(result_list) - 5} more")
if memory.episodic.pending_question:
question = memory.episodic.pending_question
lines.append(f"\nPENDING QUESTION: {question.get('question')}")
lines.append(f" Type: {question.get('type')}")
if question.get("options"):
lines.append(f" Options: {len(question.get('options'))}")
if memory.episodic.active_downloads:
lines.append(f"\nACTIVE DOWNLOADS: {len(memory.episodic.active_downloads)}")
for dl in memory.episodic.active_downloads[:3]:
lines.append(f" - {dl.get('name')}: {dl.get('progress', 0)}%")
if memory.episodic.recent_errors:
lines.append("\nRECENT ERRORS (up to 3):")
for error in memory.episodic.recent_errors[-3:]:
lines.append(
f" - Action '{error.get('action')}' failed: {error.get('error')}"
)
# Unread events
unread = [e for e in memory.episodic.background_events if not e.get("read")]
if unread:
lines.append(f"\nUNREAD EVENTS: {len(unread)}")
for event in unread[:3]:
lines.append(f" - {event.get('type')}: {event.get('data')}")
return "\n".join(lines)
def _format_stm_context(self, memory) -> str:
"""Format short-term memory context for the prompt."""
lines = []
if memory.stm.current_workflow:
workflow = memory.stm.current_workflow
lines.append(
f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
)
if workflow.get("target"):
lines.append(f" Target: {workflow.get('target')}")
if memory.stm.current_topic:
lines.append(f"CURRENT TOPIC: {memory.stm.current_topic}")
if memory.stm.extracted_entities:
lines.append("EXTRACTED ENTITIES:")
for key, value in memory.stm.extracted_entities.items():
lines.append(f" - {key}: {value}")
if memory.stm.language:
lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}")
return "\n".join(lines)
def _format_config_context(self, memory) -> str:
"""Format configuration context."""
lines = ["CURRENT CONFIGURATION:"]
if memory.ltm.config:
for key, value in memory.ltm.config.items():
lines.append(f" - {key}: {value}")
else:
lines.append(" (no configuration set)")
return "\n".join(lines)
def build_system_prompt(self) -> str:
"""Build the complete system prompt."""
# Get memory once for all context formatting
memory = get_memory()
# Base instruction
base = "You are a helpful AI assistant for managing a media library."
# Language instruction
language_instruction = (
"Your first task is to determine the user's language from their message "
"and use the `set_language` tool if it's different from the current one. "
"After that, proceed to help the user."
)
# Available tools
tools_desc = self._format_tools_description()
tools_section = f"\nAVAILABLE TOOLS:\n{tools_desc}" if tools_desc else ""
# Configuration
config_section = self._format_config_context(memory)
if config_section:
config_section = f"\n{config_section}"
# STM context
stm_context = self._format_stm_context(memory)
if stm_context:
stm_context = f"\n{stm_context}"
# Episodic context
episodic_context = self._format_episodic_context(memory)
# Important rules
rules = """
IMPORTANT RULES:
- Use tools to accomplish tasks
- When search results are available, reference them by index (e.g., "add_torrent_by_index")
- Always confirm actions with the user before executing destructive operations
- Provide clear, concise responses
"""
# Examples
examples = """
EXAMPLES:
- User: "Find Inception" → Use find_media_imdb_id, then find_torrent
- User: "download the 3rd one" → Use add_torrent_by_index with index=3
- User: "List my downloads" → Use list_folder with folder_type="download"
"""
return f"""{base}
{language_instruction}
{tools_section}
{config_section}
{stm_context}
{episodic_context}
{rules}
{examples}
"""
+112
View File
@@ -0,0 +1,112 @@
"""Tool registry - defines and registers all available tools for the agent."""
import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class Tool:
"""Represents a tool that can be used by the agent."""
name: str
description: str
func: Callable[..., dict[str, Any]]
parameters: dict[str, Any]
def _create_tool_from_function(func: Callable) -> Tool:
"""
Create a Tool object from a function.
Args:
func: Function to convert to a tool
Returns:
Tool object with metadata extracted from function
"""
sig = inspect.signature(func)
doc = inspect.getdoc(func)
# Extract description from docstring (first line)
description = doc.strip().split("\n")[0] if doc else func.__name__
# Build JSON schema from function signature
properties = {}
required = []
for param_name, param in sig.parameters.items():
if param_name == "self":
continue
# Map Python types to JSON schema types
param_type = "string" # default
if param.annotation != inspect.Parameter.empty:
if param.annotation is str:
param_type = "string"
elif param.annotation is int:
param_type = "integer"
elif param.annotation is float:
param_type = "number"
elif param.annotation is bool:
param_type = "boolean"
properties[param_name] = {
"type": param_type,
"description": f"Parameter {param_name}",
}
# Add to required if no default value
if param.default == inspect.Parameter.empty:
required.append(param_name)
parameters = {
"type": "object",
"properties": properties,
"required": required,
}
return Tool(
name=func.__name__,
description=description,
func=func,
parameters=parameters,
)
def make_tools() -> dict[str, Tool]:
"""
Create and register all available tools.
Returns:
Dictionary mapping tool names to Tool objects
"""
# Import tools here to avoid circular dependencies
from .tools import api as api_tools # noqa: PLC0415
from .tools import filesystem as fs_tools # noqa: PLC0415
from .tools import language as lang_tools # noqa: PLC0415
# List of all tool functions
tool_functions = [
fs_tools.set_path_for_folder,
fs_tools.list_folder,
api_tools.find_media_imdb_id,
api_tools.find_torrent,
api_tools.add_torrent_by_index,
api_tools.add_torrent_to_qbittorrent,
api_tools.get_torrent_by_index,
lang_tools.set_language,
]
# Create Tool objects from functions
tools = {}
for func in tool_functions:
tool = _create_tool_from_function(func)
tools[tool.name] = tool
logger.info(f"Registered {len(tools)} tools: {list(tools.keys())}")
return tools
+22
View File
@@ -0,0 +1,22 @@
"""Tools module - filesystem and API tools for the agent."""
from .api import (
add_torrent_by_index,
add_torrent_to_qbittorrent,
find_media_imdb_id,
find_torrent,
get_torrent_by_index,
)
from .filesystem import list_folder, set_path_for_folder
from .language import set_language
__all__ = [
"set_path_for_folder",
"list_folder",
"find_media_imdb_id",
"find_torrent",
"get_torrent_by_index",
"add_torrent_to_qbittorrent",
"add_torrent_by_index",
"set_language",
]
+196
View File
@@ -0,0 +1,196 @@
"""API tools for interacting with external services."""
import logging
from typing import Any
from application.movies import SearchMovieUseCase
from application.torrents import AddTorrentUseCase, SearchTorrentsUseCase
from infrastructure.api.knaben import knaben_client
from infrastructure.api.qbittorrent import qbittorrent_client
from infrastructure.api.tmdb import tmdb_client
from infrastructure.persistence import get_memory
logger = logging.getLogger(__name__)
def find_media_imdb_id(media_title: str) -> dict[str, Any]:
"""
Find the IMDb ID for a given media title using TMDB API.
Args:
media_title: Title of the media to search for.
Returns:
Dict with IMDb ID and media info, or error details.
"""
use_case = SearchMovieUseCase(tmdb_client)
response = use_case.execute(media_title)
result = response.to_dict()
if result.get("status") == "ok":
memory = get_memory()
memory.stm.set_entity(
"last_media_search",
{
"title": result.get("title"),
"imdb_id": result.get("imdb_id"),
"media_type": result.get("media_type"),
"tmdb_id": result.get("tmdb_id"),
},
)
memory.stm.set_topic("searching_media")
logger.debug(f"Stored media search result in STM: {result.get('title')}")
return result
def find_torrent(media_title: str) -> dict[str, Any]:
"""
Find torrents for a given media title using Knaben API.
Results are stored in episodic memory so the user can reference them
by index (e.g., "download the 3rd one").
Args:
media_title: Title of the media to search for.
Returns:
Dict with torrent list or error details.
"""
logger.info(f"Searching torrents for: {media_title}")
use_case = SearchTorrentsUseCase(knaben_client)
response = use_case.execute(media_title, limit=10)
result = response.to_dict()
if result.get("status") == "ok":
memory = get_memory()
torrents = result.get("torrents", [])
memory.episodic.store_search_results(
query=media_title, results=torrents, search_type="torrent"
)
memory.stm.set_topic("selecting_torrent")
logger.info(f"Stored {len(torrents)} torrent results in episodic memory")
return result
def get_torrent_by_index(index: int) -> dict[str, Any]:
"""
Get a torrent from the last search results by its index.
Allows the user to reference results by number after a search.
Args:
index: 1-based index of the torrent in the search results.
Returns:
Dict with torrent data or error if not found.
"""
logger.info(f"Getting torrent at index: {index}")
memory = get_memory()
if memory.episodic.last_search_results:
results_count = len(memory.episodic.last_search_results.get("results", []))
query = memory.episodic.last_search_results.get("query", "unknown")
logger.debug(f"Episodic memory has {results_count} results from: {query}")
else:
logger.warning("No search results in episodic memory")
result = memory.episodic.get_result_by_index(index)
if result:
logger.info(f"Found torrent at index {index}: {result.get('name', 'unknown')}")
return {"status": "ok", "torrent": result}
logger.warning(f"No torrent found at index {index}")
return {
"status": "error",
"error": "not_found",
"message": f"No torrent found at index {index}. Search for torrents first.",
}
def add_torrent_to_qbittorrent(magnet_link: str) -> dict[str, Any]:
"""
Add a torrent to qBittorrent using a magnet link.
Args:
magnet_link: Magnet link of the torrent to add.
Returns:
Dict with success status or error details.
"""
logger.info("Adding torrent to qBittorrent")
use_case = AddTorrentUseCase(qbittorrent_client)
response = use_case.execute(magnet_link)
result = response.to_dict()
if result.get("status") == "ok":
memory = get_memory()
last_search = memory.episodic.get_search_results()
torrent_name = "Unknown"
if last_search:
for t in last_search.get("results", []):
if t.get("magnet") == magnet_link:
torrent_name = t.get("name", "Unknown")
break
memory.episodic.add_active_download(
{
"task_id": magnet_link[:20],
"name": torrent_name,
"magnet": magnet_link,
"progress": 0,
"status": "queued",
}
)
memory.stm.set_topic("downloading")
memory.stm.end_workflow()
logger.info(f"Added download to episodic memory: {torrent_name}")
return result
def add_torrent_by_index(index: int) -> dict[str, Any]:
"""
Add a torrent from the last search results by its index.
Combines get_torrent_by_index and add_torrent_to_qbittorrent.
Args:
index: 1-based index of the torrent in the search results.
Returns:
Dict with success status or error details.
"""
logger.info(f"Adding torrent by index: {index}")
torrent_result = get_torrent_by_index(index)
if torrent_result.get("status") != "ok":
return torrent_result
torrent = torrent_result.get("torrent", {})
magnet = torrent.get("magnet")
if not magnet:
logger.error("Torrent has no magnet link")
return {
"status": "error",
"error": "no_magnet",
"message": "The selected torrent has no magnet link",
}
logger.info(f"Adding torrent: {torrent.get('name', 'unknown')}")
result = add_torrent_to_qbittorrent(magnet)
if result.get("status") == "ok":
result["torrent_name"] = torrent.get("name", "Unknown")
return result
+40
View File
@@ -0,0 +1,40 @@
"""Filesystem tools for folder management."""
from typing import Any
from application.filesystem import ListFolderUseCase, SetFolderPathUseCase
from infrastructure.filesystem import FileManager
def set_path_for_folder(folder_name: str, path_value: str) -> dict[str, Any]:
"""
Set a folder path in the configuration.
Args:
folder_name: Name of folder to set (download, tvshow, movie, torrent).
path_value: Absolute path to the folder.
Returns:
Dict with status or error information.
"""
file_manager = FileManager()
use_case = SetFolderPathUseCase(file_manager)
response = use_case.execute(folder_name, path_value)
return response.to_dict()
def list_folder(folder_type: str, path: str = ".") -> dict[str, Any]:
"""
List contents of a configured folder.
Args:
folder_type: Type of folder to list (download, tvshow, movie, torrent).
path: Relative path within the folder (default: root).
Returns:
Dict with folder contents or error information.
"""
file_manager = FileManager()
use_case = ListFolderUseCase(file_manager)
response = use_case.execute(folder_type, path)
return response.to_dict()
+35
View File
@@ -0,0 +1,35 @@
"""Language management tools for the agent."""
import logging
from typing import Any
from infrastructure.persistence import get_memory
logger = logging.getLogger(__name__)
def set_language(language: str) -> dict[str, Any]:
"""
Set the conversation language.
Args:
language: Language code (e.g., 'en', 'fr', 'es', 'de')
Returns:
Status dictionary
"""
try:
memory = get_memory()
memory.stm.set_language(language)
memory.save()
logger.info(f"Language set to: {language}")
return {
"status": "ok",
"message": f"Language set to {language}",
"language": language,
}
except Exception as e:
logger.error(f"Failed to set language: {e}")
return {"status": "error", "error": str(e)}