Unfucked gemini's mess
This commit is contained in:
+2
-2
@@ -1,6 +1,6 @@
|
||||
"""Agent module for media library management."""
|
||||
|
||||
from .agent import Agent, LLMClient
|
||||
from .agent import Agent
|
||||
from .config import settings
|
||||
|
||||
__all__ = ["Agent", "LLMClient", "settings"]
|
||||
__all__ = ["Agent", "settings"]
|
||||
|
||||
+152
-237
@@ -1,8 +1,7 @@
|
||||
"""Main agent for media library management."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
@@ -13,266 +12,182 @@ from .registry import Tool, make_tools
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient(Protocol):
    """Structural interface expected from any LLM backend used by the agent.

    Any object exposing a compatible ``complete`` method satisfies this
    protocol; no inheritance is required.
    """

    def complete(self, messages: list[dict[str, Any]]) -> str:
        """Submit the conversation to the model and return its reply.

        Args:
            messages: Conversation so far, one dict per message.

        Returns:
            The text produced by the model.
        """
        ...
|
||||
|
||||
|
||||
class Agent:
|
||||
"""
|
||||
AI agent for media library management.
|
||||
|
||||
Orchestrates interactions between the LLM, memory, and tools
|
||||
to respond to user requests.
|
||||
|
||||
Attributes:
|
||||
llm: LLM client (DeepSeek or Ollama).
|
||||
tools: Available tools for the agent.
|
||||
prompt_builder: Builds system prompts with context.
|
||||
max_tool_iterations: Maximum tool calls per request.
|
||||
|
||||
Uses OpenAI-compatible tool calling API.
|
||||
"""
|
||||
|
||||
def __init__(self, llm: LLMClient, max_tool_iterations: int = 5):
|
||||
def __init__(self, llm, max_tool_iterations: int = 5):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
|
||||
Args:
|
||||
llm: LLM client compatible with the LLMClient protocol.
|
||||
max_tool_iterations: Maximum tool iterations (default: 5).
|
||||
llm: LLM client with complete() method
|
||||
max_tool_iterations: Maximum number of tool execution iterations
|
||||
"""
|
||||
self.llm = llm
|
||||
self.tools: dict[str, Tool] = make_tools()
|
||||
self.tools: Dict[str, Tool] = make_tools()
|
||||
self.prompt_builder = PromptBuilder(self.tools)
|
||||
self.max_tool_iterations = max_tool_iterations
|
||||
|
||||
def _parse_intent(self, text: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Parse an LLM response to detect a tool call.
|
||||
|
||||
Args:
|
||||
text: LLM response text.
|
||||
|
||||
Returns:
|
||||
Dict with intent if a tool call is detected, None otherwise.
|
||||
"""
|
||||
text = text.strip()
|
||||
|
||||
# Try direct JSON parse
|
||||
if text.startswith("{") and text.endswith("}"):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if self._is_valid_intent(data):
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to extract JSON from text
|
||||
try:
|
||||
start = text.find("{")
|
||||
end = text.rfind("}") + 1
|
||||
if start != -1 and end > start:
|
||||
json_str = text[start:end]
|
||||
data = json.loads(json_str)
|
||||
if self._is_valid_intent(data):
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _is_valid_intent(self, data: Any) -> bool:
|
||||
"""Check if parsed data is a valid tool intent."""
|
||||
if not isinstance(data, dict) or "action" not in data:
|
||||
return False
|
||||
action = data.get("action")
|
||||
return isinstance(action, dict) and isinstance(action.get("name"), str)
|
||||
|
||||
def _execute_action(self, intent: dict[str, Any]) -> dict[str, Any]:
    """
    Run the tool named in a parsed intent and return its result.

    Args:
        intent: Parsed intent; ``intent["action"]`` must hold ``"name"`` and
            may hold ``"args"`` (a missing or falsy value means no arguments).

    Returns:
        The tool's own result on success. Otherwise an error dict:
        ``unknown_tool`` when no tool has that name, ``bad_args`` when the
        call raised ``TypeError``, ``execution_error`` for any other
        exception. Failures, including results that report an error
        themselves, are recorded in episodic memory.
    """
    requested = intent["action"]
    tool_name: str = requested["name"]
    call_args: dict[str, Any] = requested.get("args", {}) or {}

    tool = self.tools.get(tool_name)
    if not tool:
        logger.warning(f"Unknown tool requested: {tool_name}")
        return {
            "error": "unknown_tool",
            "tool": tool_name,
            "available_tools": list(self.tools),
        }

    try:
        outcome = tool.func(**call_args)

        # A tool may signal failure through its result instead of raising.
        reported_failure = outcome.get("status") == "error" or outcome.get("error")
        if reported_failure:
            get_memory().episodic.add_error(
                action=tool_name,
                error=outcome.get("error", outcome.get("message", "Unknown error")),
                context={"args": call_args, "result": outcome},
            )

        return outcome

    except TypeError as exc:
        error_msg = f"Bad arguments for {tool_name}: {exc}"
        logger.error(error_msg)
        get_memory().episodic.add_error(
            action=tool_name, error=error_msg, context={"args": call_args}
        )
        return {"error": "bad_args", "message": str(exc)}

    except Exception as exc:
        error_msg = f"Error executing {tool_name}: {exc}"
        logger.error(error_msg, exc_info=True)
        get_memory().episodic.add_error(
            action=tool_name, error=str(exc), context={"args": call_args}
        )
        return {"error": "execution_error", "message": str(exc)}
|
||||
|
||||
def _check_unread_events(self) -> str:
    """
    Summarise unread background events as a text block.

    Returns:
        A ``"Recent events:"`` header followed by one line per unread event,
        or an empty string when episodic memory reports none.
    """
    pending = get_memory().episodic.get_unread_events()
    if not pending:
        return ""

    def describe(event) -> str:
        """Render one event as a bullet line."""
        kind = event.get("type", "unknown")
        payload = event.get("data", {})
        if kind == "download_complete":
            return f" - Download completed: {payload.get('name')}"
        if kind == "new_files_detected":
            return f" - {payload.get('count')} new files detected"
        # Unrecognised event types are shown with their raw payload.
        return f" - {kind}: {payload}"

    return "\n".join(["Recent events:", *(describe(event) for event in pending)])
|
||||
|
||||
def step(self, user_input: str) -> str:
|
||||
"""
|
||||
Execute one agent step with iterative tool execution.
|
||||
|
||||
Process:
|
||||
1. Check for unread events
|
||||
2. Build system prompt with memory context
|
||||
3. Query the LLM
|
||||
4. If tool call detected, execute and loop
|
||||
5. Return final text response
|
||||
|
||||
Execute one agent step with the user input.
|
||||
|
||||
This method:
|
||||
1. Adds user message to memory
|
||||
2. Builds prompt with history and context
|
||||
3. Calls LLM, executing tools as needed
|
||||
4. Returns final response
|
||||
|
||||
Args:
|
||||
user_input: User message.
|
||||
|
||||
user_input: User's message
|
||||
|
||||
Returns:
|
||||
Final response in natural text.
|
||||
Agent's final response
|
||||
"""
|
||||
logger.info("Starting agent step")
|
||||
logger.debug(f"User input: {user_input}")
|
||||
|
||||
memory = get_memory()
|
||||
|
||||
# Check for background events
|
||||
events_notification = self._check_unread_events()
|
||||
if events_notification:
|
||||
logger.info("Found unread background events")
|
||||
|
||||
# Build system prompt
|
||||
|
||||
# Add user message to history
|
||||
memory.stm.add_message("user", user_input)
|
||||
memory.save()
|
||||
|
||||
# Build initial messages
|
||||
system_prompt = self.prompt_builder.build_system_prompt()
|
||||
|
||||
# Initialize conversation
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
|
||||
|
||||
# Add conversation history
|
||||
history = memory.stm.get_recent_history(settings.max_history_messages)
|
||||
if history:
|
||||
for msg in history:
|
||||
messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
logger.debug(f"Added {len(history)} messages from history")
|
||||
|
||||
# Add events notification
|
||||
if events_notification:
|
||||
messages.append(
|
||||
{"role": "system", "content": f"[NOTIFICATION]\n{events_notification}"}
|
||||
)
|
||||
|
||||
# Add user input
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
messages.extend(history)
|
||||
|
||||
# Add unread events if any
|
||||
unread_events = memory.episodic.get_unread_events()
|
||||
if unread_events:
|
||||
events_text = "\n".join([
|
||||
f"- {e['type']}: {e['data']}"
|
||||
for e in unread_events
|
||||
])
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": f"Background events:\n{events_text}"
|
||||
})
|
||||
|
||||
# Get tools specification for OpenAI format
|
||||
tools_spec = self.prompt_builder.build_tools_spec()
|
||||
|
||||
# Tool execution loop
|
||||
iteration = 0
|
||||
while iteration < self.max_tool_iterations:
|
||||
logger.debug(f"Iteration {iteration + 1}/{self.max_tool_iterations}")
|
||||
|
||||
llm_response = self.llm.complete(messages)
|
||||
logger.debug(f"LLM response: {llm_response[:200]}...")
|
||||
|
||||
intent = self._parse_intent(llm_response)
|
||||
|
||||
if not intent:
|
||||
# Final text response
|
||||
logger.info("No tool intent, returning response")
|
||||
memory.stm.add_message("user", user_input)
|
||||
memory.stm.add_message("assistant", llm_response)
|
||||
for iteration in range(self.max_tool_iterations):
|
||||
# Call LLM with tools
|
||||
llm_result = self.llm.complete(messages, tools=tools_spec)
|
||||
|
||||
# Handle both tuple (response, usage) and dict response
|
||||
if isinstance(llm_result, tuple):
|
||||
response_message, usage = llm_result
|
||||
else:
|
||||
response_message = llm_result
|
||||
|
||||
# Check if there are tool calls
|
||||
tool_calls = response_message.get("tool_calls")
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls, this is the final response
|
||||
final_content = response_message.get("content", "")
|
||||
memory.stm.add_message("assistant", final_content)
|
||||
memory.save()
|
||||
return llm_response
|
||||
|
||||
# Execute tool
|
||||
tool_name = intent.get("action", {}).get("name", "unknown")
|
||||
logger.info(f"Executing tool: {tool_name}")
|
||||
tool_result = self._execute_action(intent)
|
||||
logger.debug(f"Tool result: {tool_result}")
|
||||
|
||||
# Add to conversation
|
||||
messages.append(
|
||||
{"role": "assistant", "content": json.dumps(intent, ensure_ascii=False)}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
{"tool_result": tool_result}, ensure_ascii=False
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
iteration += 1
|
||||
|
||||
# Max iterations reached
|
||||
logger.warning(f"Max iterations ({self.max_tool_iterations}) reached")
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please provide a final response based on the results.",
|
||||
}
|
||||
)
|
||||
|
||||
final_response = self.llm.complete(messages)
|
||||
|
||||
memory.stm.add_message("user", user_input)
|
||||
return final_content
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
messages.append(response_message)
|
||||
|
||||
# Execute each tool call
|
||||
for tool_call in tool_calls:
|
||||
tool_result = self._execute_tool_call(tool_call)
|
||||
|
||||
# Add tool result to messages
|
||||
messages.append({
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"role": "tool",
|
||||
"name": tool_call.get("function", {}).get("name"),
|
||||
"content": json.dumps(tool_result, ensure_ascii=False),
|
||||
})
|
||||
|
||||
# Max iterations reached, force final response
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": "Please provide a final response to the user without using any more tools."
|
||||
})
|
||||
|
||||
llm_result = self.llm.complete(messages)
|
||||
if isinstance(llm_result, tuple):
|
||||
final_message, usage = llm_result
|
||||
else:
|
||||
final_message = llm_result
|
||||
|
||||
final_response = final_message.get("content", "I've completed the requested actions.")
|
||||
memory.stm.add_message("assistant", final_response)
|
||||
memory.save()
|
||||
|
||||
return final_response
|
||||
|
||||
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a single tool call.
|
||||
|
||||
Args:
|
||||
tool_call: OpenAI-format tool call dict
|
||||
|
||||
Returns:
|
||||
Result dictionary
|
||||
"""
|
||||
function = tool_call.get("function", {})
|
||||
tool_name = function.get("name", "")
|
||||
|
||||
try:
|
||||
args_str = function.get("arguments", "{}")
|
||||
args = json.loads(args_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse tool arguments: {e}")
|
||||
return {
|
||||
"error": "bad_args",
|
||||
"message": f"Invalid JSON arguments: {e}"
|
||||
}
|
||||
|
||||
# Validate tool exists
|
||||
if tool_name not in self.tools:
|
||||
available = list(self.tools.keys())
|
||||
return {
|
||||
"error": "unknown_tool",
|
||||
"message": f"Tool '{tool_name}' not found",
|
||||
"available_tools": available
|
||||
}
|
||||
|
||||
tool = self.tools[tool_name]
|
||||
|
||||
# Execute tool
|
||||
try:
|
||||
result = tool.func(**args)
|
||||
return result
|
||||
except KeyboardInterrupt:
|
||||
# Don't catch KeyboardInterrupt - let it propagate
|
||||
raise
|
||||
except TypeError as e:
|
||||
# Bad arguments
|
||||
memory = get_memory()
|
||||
memory.episodic.add_error(tool_name, f"bad_args: {e}")
|
||||
return {
|
||||
"error": "bad_args",
|
||||
"message": str(e),
|
||||
"tool": tool_name
|
||||
}
|
||||
except Exception as e:
|
||||
# Other errors
|
||||
memory = get_memory()
|
||||
memory.episodic.add_error(tool_name, str(e))
|
||||
return {
|
||||
"error": "execution_failed",
|
||||
"message": str(e),
|
||||
"tool": tool_name
|
||||
}
|
||||
|
||||
+19
-14
@@ -51,15 +51,16 @@ class DeepSeekClient:
|
||||
|
||||
logger.info(f"DeepSeek client initialized with model: {self.model}")
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]]) -> str:
|
||||
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys
|
||||
tools: Optional list of tool specifications (OpenAI format)
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
|
||||
|
||||
Raises:
|
||||
LLMAPIError: If API request fails
|
||||
@@ -72,12 +73,14 @@ class DeepSeekClient:
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
raise ValueError(f"Each message must be a dict, got {type(msg)}")
|
||||
if "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
f"Each message must have 'role' and 'content' keys, got {msg.keys()}"
|
||||
)
|
||||
if msg["role"] not in ("system", "user", "assistant"):
|
||||
if "role" not in msg:
|
||||
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
|
||||
# Allow system, user, assistant, and tool roles
|
||||
if msg["role"] not in ("system", "user", "assistant", "tool"):
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Content is optional for tool messages (they may have tool_call_id instead)
|
||||
if msg["role"] != "tool" and "content" not in msg:
|
||||
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
|
||||
|
||||
url = f"{self.base_url}/v1/chat/completions"
|
||||
headers = {
|
||||
@@ -89,9 +92,13 @@ class DeepSeekClient:
|
||||
"messages": messages,
|
||||
"temperature": settings.temperature,
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
try:
|
||||
logger.debug(f"Sending request to {url} with {len(messages)} messages")
|
||||
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
|
||||
response = requests.post(
|
||||
url, headers=headers, json=payload, timeout=self.timeout
|
||||
)
|
||||
@@ -105,13 +112,11 @@ class DeepSeekClient:
|
||||
if "message" not in data["choices"][0]:
|
||||
raise LLMAPIError("Invalid API response: missing 'message' in choice")
|
||||
|
||||
if "content" not in data["choices"][0]["message"]:
|
||||
raise LLMAPIError("Invalid API response: missing 'content' in message")
|
||||
# Return the full message dict (OpenAI format)
|
||||
message = data["choices"][0]["message"]
|
||||
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
|
||||
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
logger.debug(f"Received response with {len(content)} characters")
|
||||
|
||||
return content
|
||||
return message
|
||||
|
||||
except Timeout as e:
|
||||
logger.error(f"Request timeout after {self.timeout}s: {e}")
|
||||
|
||||
+19
-14
@@ -66,15 +66,16 @@ class OllamaClient:
|
||||
|
||||
logger.info(f"Ollama client initialized with model: {self.model}")
|
||||
|
||||
def complete(self, messages: list[dict[str, Any]]) -> str:
|
||||
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a completion from the LLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys
|
||||
tools: Optional list of tool specifications (OpenAI format)
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
OpenAI-compatible message dict with 'role', 'content', and optionally 'tool_calls'
|
||||
|
||||
Raises:
|
||||
LLMAPIError: If API request fails
|
||||
@@ -87,12 +88,14 @@ class OllamaClient:
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
raise ValueError(f"Each message must be a dict, got {type(msg)}")
|
||||
if "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
f"Each message must have 'role' and 'content' keys, got {msg.keys()}"
|
||||
)
|
||||
if msg["role"] not in ("system", "user", "assistant"):
|
||||
if "role" not in msg:
|
||||
raise ValueError(f"Message must have 'role' key, got {msg.keys()}")
|
||||
# Allow system, user, assistant, and tool roles
|
||||
if msg["role"] not in ("system", "user", "assistant", "tool"):
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Content is optional for tool messages (they may have tool_call_id instead)
|
||||
if msg["role"] != "tool" and "content" not in msg:
|
||||
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}")
|
||||
|
||||
url = f"{self.base_url}/api/chat"
|
||||
payload = {
|
||||
@@ -103,9 +106,13 @@ class OllamaClient:
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
}
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
try:
|
||||
logger.debug(f"Sending request to {url} with {len(messages)} messages")
|
||||
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools")
|
||||
response = requests.post(url, json=payload, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -114,13 +121,11 @@ class OllamaClient:
|
||||
if "message" not in data:
|
||||
raise LLMAPIError("Invalid API response: missing 'message'")
|
||||
|
||||
if "content" not in data["message"]:
|
||||
raise LLMAPIError("Invalid API response: missing 'content' in message")
|
||||
# Return the full message dict (OpenAI format)
|
||||
message = data["message"]
|
||||
logger.debug(f"Received response: {message.get('content', '')[:100]}...")
|
||||
|
||||
content = data["message"]["content"]
|
||||
logger.debug(f"Received response with {len(content)} characters")
|
||||
|
||||
return content
|
||||
return message
|
||||
|
||||
except Timeout as e:
|
||||
logger.error(f"Request timeout after {self.timeout}s: {e}")
|
||||
|
||||
+114
-109
@@ -1,31 +1,36 @@
|
||||
"""Prompt builder for the agent system."""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
import json
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
from .parameters import format_parameters_for_prompt, get_missing_required_parameters
|
||||
from .registry import Tool
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""Builds system prompts for the agent with memory context.
|
||||
"""Builds system prompts for the agent with memory context."""
|
||||
|
||||
Attributes:
|
||||
tools: Dictionary of available tools.
|
||||
"""
|
||||
|
||||
def __init__(self, tools: dict[str, Tool]):
|
||||
"""
|
||||
Initialize the prompt builder.
|
||||
|
||||
Args:
|
||||
tools: Dictionary mapping tool names to Tool instances.
|
||||
"""
|
||||
def __init__(self, tools: Dict[str, Tool]):
|
||||
self.tools = tools
|
||||
|
||||
def build_tools_spec(self) -> List[Dict[str, Any]]:
|
||||
"""Build the tool specification for the LLM API."""
|
||||
tool_specs = []
|
||||
for tool in self.tools.values():
|
||||
spec = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
},
|
||||
}
|
||||
tool_specs.append(spec)
|
||||
return tool_specs
|
||||
|
||||
def _format_tools_description(self) -> str:
|
||||
"""Format tools with their descriptions and parameters."""
|
||||
if not self.tools:
|
||||
return ""
|
||||
return "\n".join(
|
||||
f"- {tool.name}: {tool.description}\n"
|
||||
f" Parameters: {json.dumps(tool.parameters, ensure_ascii=False)}"
|
||||
@@ -37,134 +42,134 @@ class PromptBuilder:
|
||||
memory = get_memory()
|
||||
lines = []
|
||||
|
||||
# Last search results
|
||||
if memory.episodic.last_search_results:
|
||||
search = memory.episodic.last_search_results
|
||||
lines.append(f"LAST SEARCH: '{search.get('query')}'")
|
||||
results = search.get("results", [])
|
||||
if results:
|
||||
lines.append(f" {len(results)} results available:")
|
||||
for r in results[:5]:
|
||||
name = r.get("name", r.get("title", "Unknown"))
|
||||
lines.append(f" {r.get('index')}. {name}")
|
||||
if len(results) > 5:
|
||||
lines.append(f" ... and {len(results) - 5} more")
|
||||
results = memory.episodic.last_search_results
|
||||
result_list = results.get('results', [])
|
||||
lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)")
|
||||
# Show first 5 results
|
||||
for i, result in enumerate(result_list[:5]):
|
||||
name = result.get('name', 'Unknown')
|
||||
lines.append(f" {i+1}. {name}")
|
||||
if len(result_list) > 5:
|
||||
lines.append(f" ... and {len(result_list) - 5} more")
|
||||
|
||||
# Pending question
|
||||
if memory.episodic.pending_question:
|
||||
q = memory.episodic.pending_question
|
||||
lines.append(f"\nPENDING QUESTION: {q.get('question')}")
|
||||
for opt in q.get("options", []):
|
||||
lines.append(f" {opt.get('index')}. {opt.get('label')}")
|
||||
question = memory.episodic.pending_question
|
||||
lines.append(f"\nPENDING QUESTION: {question.get('question')}")
|
||||
lines.append(f" Type: {question.get('type')}")
|
||||
if question.get('options'):
|
||||
lines.append(f" Options: {len(question.get('options'))}")
|
||||
|
||||
# Active downloads
|
||||
if memory.episodic.active_downloads:
|
||||
lines.append(f"\nACTIVE DOWNLOADS: {len(memory.episodic.active_downloads)}")
|
||||
for dl in memory.episodic.active_downloads[:3]:
|
||||
lines.append(f" - {dl.get('name')}: {dl.get('progress', 0)}%")
|
||||
lines.append(f" - {dl.get('name')}: {dl.get('progress', 0)}%")
|
||||
|
||||
# Recent errors
|
||||
if memory.episodic.recent_errors:
|
||||
last_error = memory.episodic.recent_errors[-1]
|
||||
lines.append(
|
||||
f"\nLAST ERROR: {last_error.get('error')} "
|
||||
f"(action: {last_error.get('action')})"
|
||||
)
|
||||
lines.append("\nRECENT ERRORS (up to 3):")
|
||||
for error in memory.episodic.recent_errors[-3:]:
|
||||
lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}")
|
||||
|
||||
# Unread events
|
||||
unread = [e for e in memory.episodic.background_events if not e.get("read")]
|
||||
unread = [e for e in memory.episodic.background_events if not e.get('read')]
|
||||
if unread:
|
||||
lines.append(f"\nUNREAD EVENTS: {len(unread)}")
|
||||
for e in unread[:3]:
|
||||
lines.append(f" - {e.get('type')}: {e.get('data', {})}")
|
||||
for event in unread[:3]:
|
||||
lines.append(f" - {event.get('type')}: {event.get('data')}")
|
||||
|
||||
return "\n".join(lines) if lines else ""
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_stm_context(self) -> str:
|
||||
"""Format short-term memory context for the prompt."""
|
||||
memory = get_memory()
|
||||
lines = []
|
||||
|
||||
# Current workflow
|
||||
if memory.stm.current_workflow:
|
||||
wf = memory.stm.current_workflow
|
||||
lines.append(f"CURRENT WORKFLOW: {wf.get('type')}")
|
||||
lines.append(f" Target: {wf.get('target', {}).get('title', 'Unknown')}")
|
||||
lines.append(f" Stage: {wf.get('stage')}")
|
||||
workflow = memory.stm.current_workflow
|
||||
lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})")
|
||||
if workflow.get('target'):
|
||||
lines.append(f" Target: {workflow.get('target')}")
|
||||
|
||||
# Current topic
|
||||
if memory.stm.current_topic:
|
||||
lines.append(f"CURRENT TOPIC: {memory.stm.current_topic}")
|
||||
|
||||
# Extracted entities
|
||||
if memory.stm.extracted_entities:
|
||||
entities_json = json.dumps(
|
||||
memory.stm.extracted_entities, ensure_ascii=False
|
||||
)
|
||||
lines.append(f"EXTRACTED ENTITIES: {entities_json}")
|
||||
lines.append("EXTRACTED ENTITIES:")
|
||||
for key, value in memory.stm.extracted_entities.items():
|
||||
lines.append(f" - {key}: {value}")
|
||||
|
||||
if memory.stm.language:
|
||||
lines.append(f"CONVERSATION LANGUAGE: {memory.stm.language}")
|
||||
|
||||
return "\n".join(lines) if lines else ""
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_config_context(self) -> str:
    """Render the configuration held in long-term memory as prompt text.

    Returns:
        A ``CURRENT CONFIGURATION:`` header followed by one ``key: value``
        bullet per entry, or by a placeholder line when the configuration
        is empty.
    """
    config = get_memory().ltm.config
    header = "CURRENT CONFIGURATION:"

    if not config:
        return "\n".join([header, " (no configuration set)"])

    entries = [f" - {name}: {setting}" for name, setting in config.items()]
    return "\n".join([header, *entries])
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""
|
||||
Build the system prompt with context from memory.
|
||||
|
||||
Returns:
|
||||
The complete system prompt string.
|
||||
"""
|
||||
"""Build the complete system prompt."""
|
||||
memory = get_memory()
|
||||
|
||||
# Base instruction
|
||||
base = "You are a helpful AI assistant for managing a media library."
|
||||
|
||||
# Language instruction
|
||||
language_instruction = (
|
||||
"Your first task is to determine the user's language from their message "
|
||||
"and use the `set_language` tool if it's different from the current one. "
|
||||
"After that, proceed to help the user."
|
||||
)
|
||||
|
||||
# Available tools
|
||||
tools_desc = self._format_tools_description()
|
||||
params_desc = format_parameters_for_prompt()
|
||||
tools_section = f"\nAVAILABLE TOOLS:\n{tools_desc}" if tools_desc else ""
|
||||
|
||||
# Check for missing required parameters
|
||||
missing_params = get_missing_required_parameters({"config": memory.ltm.config})
|
||||
missing_info = ""
|
||||
if missing_params:
|
||||
missing_info = "\n\nMISSING REQUIRED PARAMETERS:\n"
|
||||
for param in missing_params:
|
||||
missing_info += f"- {param.key}: {param.description}\n"
|
||||
missing_info += f" Why needed: {param.why_needed}\n"
|
||||
# Configuration
|
||||
config_section = self._format_config_context()
|
||||
if config_section:
|
||||
config_section = f"\n{config_section}"
|
||||
|
||||
# Build context sections
|
||||
episodic_context = self._format_episodic_context()
|
||||
# STM context
|
||||
stm_context = self._format_stm_context()
|
||||
if stm_context:
|
||||
stm_context = f"\n{stm_context}"
|
||||
|
||||
config_json = json.dumps(memory.ltm.config, indent=2, ensure_ascii=False)
|
||||
|
||||
return f"""You are an AI agent helping a user manage their local media library.
|
||||
|
||||
{params_desc}
|
||||
|
||||
CURRENT CONFIGURATION:
|
||||
{config_json}
|
||||
{missing_info}
|
||||
|
||||
{f"SESSION CONTEXT:{chr(10)}{stm_context}" if stm_context else ""}
|
||||
|
||||
{f"CURRENT STATE:{chr(10)}{episodic_context}" if episodic_context else ""}
|
||||
# Episodic context
|
||||
episodic_context = self._format_episodic_context()
|
||||
|
||||
# Important rules
|
||||
rules = """
|
||||
IMPORTANT RULES:
|
||||
1. When the user refers to a number (e.g., "the 3rd one", "download number 2"), \
|
||||
use `add_torrent_by_index` or `get_torrent_by_index` with that number.
|
||||
2. If a torrent search was performed, results are numbered. \
|
||||
The user can reference them by number.
|
||||
3. To use a tool, respond STRICTLY with this JSON format:
|
||||
{{ "thought": "explanation", "action": {{ "name": "tool_name", "args": {{ }} }} }}
|
||||
- No text before or after the JSON
|
||||
4. You can use MULTIPLE TOOLS IN SEQUENCE.
|
||||
5. When you have all the information needed, respond in NATURAL TEXT (not JSON).
|
||||
6. If a required parameter is missing, ask the user for it.
|
||||
7. Respond in the same language as the user.
|
||||
|
||||
EXAMPLES:
|
||||
- After a torrent search, if the user says "download the 3rd one":
|
||||
{{ "thought": "User wants torrent #3", "action": {{ "name": "add_torrent_by_index", \
|
||||
"args": {{ "index": 3 }} }} }}
|
||||
|
||||
- To search for torrents:
|
||||
{{ "thought": "Searching torrents", "action": {{ "name": "find_torrents", \
|
||||
"args": {{ "media_title": "Inception 1080p" }} }} }}
|
||||
|
||||
AVAILABLE TOOLS:
|
||||
{tools_desc}
|
||||
- Use tools to accomplish tasks
|
||||
- When search results are available, reference them by index (e.g., "add_torrent_by_index")
|
||||
- Always confirm actions with the user before executing destructive operations
|
||||
- Provide clear, concise responses
|
||||
"""
|
||||
|
||||
# Examples
|
||||
examples = """
|
||||
EXAMPLES:
|
||||
- User: "Find Inception" → Use find_media_imdb_id, then find_torrent
|
||||
- User: "download the 3rd one" → Use add_torrent_by_index with index=3
|
||||
- User: "List my downloads" → Use list_folder with folder_type="download"
|
||||
"""
|
||||
|
||||
return f"""{base}
|
||||
|
||||
{language_instruction}
|
||||
{tools_section}
|
||||
{config_section}
|
||||
{stm_context}
|
||||
{episodic_context}
|
||||
{rules}
|
||||
{examples}
|
||||
"""
|
||||
|
||||
+92
-164
@@ -1,181 +1,109 @@
|
||||
"""Tool registry - defines and registers all available tools for the agent."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .tools import api as api_tools
|
||||
from .tools import filesystem as fs_tools
|
||||
from typing import Callable, Any, Dict
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
"""Represents a tool that can be used by the agent.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the tool.
|
||||
description: Human-readable description for the LLM.
|
||||
func: The callable that implements the tool.
|
||||
parameters: JSON Schema describing the tool's parameters.
|
||||
"""
|
||||
|
||||
"""Represents a tool that can be used by the agent."""
|
||||
name: str
|
||||
description: str
|
||||
func: Callable[..., dict[str, Any]]
|
||||
parameters: dict[str, Any]
|
||||
func: Callable[..., Dict[str, Any]]
|
||||
parameters: Dict[str, Any]
|
||||
|
||||
|
||||
def make_tools() -> dict[str, Tool]:
|
||||
# JSON Schema type names for the Python annotations the registry understands.
_JSON_SCHEMA_TYPES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
}


def _json_schema_type(annotation: Any) -> str:
    """Return the JSON Schema type for a parameter annotation.

    Unannotated parameters and annotations other than str/int/float/bool
    fall back to "string".
    """
    if annotation is inspect.Parameter.empty:
        return "string"
    for py_type, json_type in _JSON_SCHEMA_TYPES.items():
        # Accept the type object itself as well as its name: signatures of
        # modules using ``from __future__ import annotations`` carry strings.
        if annotation is py_type or annotation == py_type.__name__:
            return json_type
    return "string"


def _create_tool_from_function(func: Callable) -> Tool:
    """
    Create a Tool object from a function.

    The tool name is the function's ``__name__``; the description is the
    first line of its docstring (or the function name when it has none);
    the parameter schema is derived from the function signature.

    Args:
        func: Function to convert to a tool

    Returns:
        Tool object with metadata extracted from function
    """
    doc = inspect.getdoc(func)

    # Extract description from docstring (first line)
    description = doc.strip().split("\n")[0] if doc else func.__name__

    # Build JSON schema from function signature
    properties: Dict[str, Any] = {}
    required = []
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    for param_name, param in inspect.signature(func).parameters.items():
        if param_name == "self":
            continue
        # *args / **kwargs cannot be described as a single named property,
        # and must not be reported as required arguments.
        if param.kind in variadic_kinds:
            continue

        properties[param_name] = {
            "type": _json_schema_type(param.annotation),
            "description": f"Parameter {param_name}",
        }

        # Required when there is no default. ``empty`` is a sentinel, so
        # compare by identity: a default with a custom __eq__ stays safe.
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    return Tool(
        name=func.__name__,
        description=description,
        func=func,
        parameters={
            "type": "object",
            "properties": properties,
            "required": required,
        },
    )
|
||||
|
||||
|
||||
def make_tools() -> Dict[str, Tool]:
    """
    Create and register all available tools.

    Every function listed below is converted with
    _create_tool_from_function(), so each tool is registered under the
    function's own name and its parameter schema comes from the signature.

    Tools access memory via get_memory() context.

    Returns:
        Dictionary mapping tool names to Tool objects
    """
    # Import tools here to avoid circular dependencies
    from .tools import filesystem as fs_tools
    from .tools import api as api_tools
    from .tools import language as lang_tools

    # List of all tool functions
    tool_functions = [
        fs_tools.set_path_for_folder,
        fs_tools.list_folder,
        api_tools.find_media_imdb_id,
        api_tools.find_torrent,
        api_tools.add_torrent_by_index,
        api_tools.add_torrent_to_qbittorrent,
        api_tools.get_torrent_by_index,
        lang_tools.set_language,
    ]

    # Create Tool objects from functions
    tools: Dict[str, Tool] = {}
    for func in tool_functions:
        tool = _create_tool_from_function(func)
        tools[tool.name] = tool

    logger.info("Registered %d tools: %s", len(tools), list(tools.keys()))
    return tools
|
||||
|
||||
@@ -8,6 +8,7 @@ from .api import (
|
||||
get_torrent_by_index,
|
||||
)
|
||||
from .filesystem import list_folder, set_path_for_folder
|
||||
from .language import set_language
|
||||
|
||||
__all__ = [
|
||||
"set_path_for_folder",
|
||||
@@ -17,4 +18,5 @@ __all__ = [
|
||||
"get_torrent_by_index",
|
||||
"add_torrent_to_qbittorrent",
|
||||
"add_torrent_by_index",
|
||||
"set_language",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Language management tools for the agent."""
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_language(language: str) -> Dict[str, Any]:
    """
    Switch the language used for the conversation.

    Calls set_language() on the ``stm`` attribute of the object returned by
    get_memory(), then saves that memory object.

    Args:
        language: Language code (e.g., 'en', 'fr', 'es', 'de')

    Returns:
        On success, a dictionary with "status" set to "ok", a "message" and
        the "language" that was set. If anything raises, a dictionary with
        "status" set to "error" and the exception text under "error".
    """
    try:
        store = get_memory()
        store.stm.set_language(language)
        store.save()

        logger.info(f"Language set to: {language}")

        success: Dict[str, Any] = dict(
            status="ok",
            message=f"Language set to {language}",
            language=language,
        )
        return success
    except Exception as exc:
        # Failures are reported through the returned dictionary, not raised.
        logger.error(f"Failed to set language: {exc}")
        failure: Dict[str, Any] = dict(status="error", error=str(exc))
        return failure
|
||||
Reference in New Issue
Block a user