Files
alfred/alfred/agent/agent.py
T

512 lines
18 KiB
Python

"""Main agent for media library management."""
import json
import logging
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from alfred.infrastructure.metadata import MetadataStore
from alfred.infrastructure.persistence import get_memory
from alfred.settings import settings
from .prompt import PromptBuilder
from .registry import Tool, make_tools
from .workflows import WorkflowLoader
logger = logging.getLogger(__name__)
class Agent:
"""
AI agent for media library management.
Uses OpenAI-compatible tool calling API.
"""
def __init__(self, settings, llm, max_tool_iterations: int = 5):
"""
Initialize the agent.
Args:
settings: Application settings instance
llm: LLM client with complete() method
max_tool_iterations: Maximum number of tool execution iterations
"""
self.settings = settings
self.llm = llm
self.tools: dict[str, Tool] = make_tools(settings)
self.workflow_loader = WorkflowLoader()
self.prompt_builder = PromptBuilder(self.tools, self.workflow_loader)
self.max_tool_iterations = max_tool_iterations
def step(self, user_input: str) -> str:
"""
Execute one agent step with the user input.
This method:
1. Adds user message to memory
2. Builds prompt with history and context
3. Calls LLM, executing tools as needed
4. Returns final response
Args:
user_input: User's message
Returns:
Agent's final response
"""
memory = get_memory()
# Add user message to history
memory.stm.add_message("user", user_input)
memory.save()
# Build initial messages
system_prompt = self.prompt_builder.build_system_prompt()
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
# Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages)
messages.extend(history)
# Add unread events if any
unread_events = memory.episodic.get_unread_events()
if unread_events:
events_text = "\n".join(
[f"- {e['type']}: {e['data']}" for e in unread_events]
)
messages.append(
{"role": "system", "content": f"Background events:\n{events_text}"}
)
# Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec()
# Tool execution loop
for _iteration in range(self.settings.max_tool_iterations):
# Call LLM with tools
llm_result = self.llm.complete(messages, tools=tools_spec)
# Handle both tuple (response, usage) and dict response
if isinstance(llm_result, tuple):
response_message, usage = llm_result
else:
response_message = llm_result
# Check if there are tool calls
tool_calls = response_message.get("tool_calls")
if not tool_calls:
# No tool calls, this is the final response
final_content = response_message.get("content", "")
memory.stm.add_message("assistant", final_content)
memory.save()
return final_content
# Add assistant message with tool calls to conversation
messages.append(response_message)
# Execute each tool call
for tool_call in tool_calls:
tool_result = self._execute_tool_call(tool_call)
# Add tool result to messages
messages.append(
{
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_call.get("function", {}).get("name"),
"content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Max iterations reached, force final response
messages.append(
{
"role": "system",
"content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple):
final_message, usage = llm_result
else:
final_message = llm_result
final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response)
memory.save()
return final_response
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0911
"""
Execute a single tool call.
Args:
tool_call: OpenAI-format tool call dict
Returns:
Result dictionary
"""
function = tool_call.get("function", {})
tool_name = function.get("name", "")
try:
args_str = function.get("arguments", "{}")
args = json.loads(args_str)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {e}")
return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
# Validate tool exists
if tool_name not in self.tools:
available = list(self.tools.keys())
return {
"error": "unknown_tool",
"message": f"Tool '{tool_name}' not found",
"available_tools": available,
}
# Defensive: reject calls to tools that are not currently in scope.
visible = set(self.prompt_builder.visible_tool_names())
if tool_name not in visible:
return {
"error": "tool_out_of_scope",
"message": (
f"Tool '{tool_name}' is not available in the current "
"workflow scope. Call end_workflow first or start the "
"appropriate workflow."
),
"available_tools": sorted(visible),
}
tool = self.tools[tool_name]
memory = get_memory()
# Cache lookup — for tools flagged cacheable, short-circuit on hit.
cache_key_value = self._cache_key_for(tool, args)
if cache_key_value is not None:
cached = memory.stm.tool_results.get(tool_name, cache_key_value)
if cached is not None:
logger.info(f"Tool cache HIT: {tool_name}[{cache_key_value}]")
self._post_tool_side_effects(tool_name, args, cached, from_cache=True)
return {**cached, "_from_cache": True}
# Execute tool
try:
result = tool.func(**args)
except KeyboardInterrupt:
# Don't catch KeyboardInterrupt - let it propagate
raise
except TypeError as e:
# Bad arguments
memory.episodic.add_error(tool_name, f"bad_args: {e}")
return {"error": "bad_args", "message": str(e), "tool": tool_name}
except Exception as e:
# Other errors
memory.episodic.add_error(tool_name, str(e))
return {"error": "execution_failed", "message": str(e), "tool": tool_name}
# Persist + side effects only on successful results.
if isinstance(result, dict) and result.get("status") == "ok":
if cache_key_value is not None:
memory.stm.tool_results.put(tool_name, cache_key_value, result)
self._post_tool_side_effects(tool_name, args, result, from_cache=False)
memory.save()
return result
@staticmethod
def _cache_key_for(tool: Tool, args: dict[str, Any]) -> str | None:
"""Return the cache key value for this call, or None if not cacheable."""
if tool.cache_key is None:
return None
value = args.get(tool.cache_key)
if value is None:
return None
return str(value)
def _post_tool_side_effects(
self,
tool_name: str,
args: dict[str, Any],
result: dict[str, Any],
*,
from_cache: bool,
) -> None:
"""
Tool-agnostic side effects applied after a successful run or cache hit.
Today:
- Update release_focus when a path-keyed inspector runs.
- Persist inspector results into the release's `.alfred/metadata.yaml`.
- Refresh episodic.last_search_results on find_torrent cache hits so
get_torrent_by_index keeps pointing at the right list.
"""
memory = get_memory()
tool = self.tools.get(tool_name)
# Release focus: any path-keyed inspector updates current_release_path.
if tool is not None and tool.cache_key in {"source_path"}:
path = args.get(tool.cache_key)
if isinstance(path, str) and path:
memory.stm.release_focus.focus(path)
# Persist inspector results to .alfred/metadata.yaml (skip on cache
# hit — the file is already up to date from the original run).
if not from_cache:
self._maybe_update_alfred(tool_name, args, result)
# Episodic refresh when find_torrent's cache short-circuits the call.
if from_cache and tool_name == "find_torrent":
torrents = result.get("torrents") or []
query = args.get("media_title") or ""
memory.episodic.store_search_results(
query=query, results=torrents, search_type="torrent"
)
def _maybe_update_alfred(
self,
tool_name: str,
args: dict[str, Any],
result: dict[str, Any],
) -> None:
"""
Persist a successful inspector result into the release's
`.alfred/metadata.yaml`. No-op when the release root can't be resolved.
"""
if tool_name not in {"analyze_release", "probe_media", "find_media_imdb_id"}:
return
release_root = self._resolve_release_root(tool_name, args)
if release_root is None:
return
try:
store = MetadataStore(release_root)
if tool_name == "analyze_release":
store.update_parse(result)
elif tool_name == "probe_media":
store.update_probe(result)
elif tool_name == "find_media_imdb_id":
store.update_tmdb(result)
except Exception as e:
logger.warning(
f"Failed to update .alfred for {tool_name} at {release_root}: {e}"
)
@staticmethod
def _resolve_release_root(
tool_name: str,
args: dict[str, Any],
) -> Path | None:
"""
Figure out which release folder owns this call.
- analyze_release / probe_media: derived from source_path
(folder kept as-is, file walked up to its parent).
- find_media_imdb_id: follow the current release focus in STM.
"""
if tool_name in {"analyze_release", "probe_media"}:
raw = args.get("source_path")
if not isinstance(raw, str) or not raw:
return None
path = Path(raw)
return path if path.is_dir() else path.parent
# find_media_imdb_id has no path arg — rely on release focus.
focus = get_memory().stm.release_focus.current_release_path
if not focus:
return None
path = Path(focus)
return path if path.is_dir() else path.parent
async def step_streaming(
self, user_input: str, completion_id: str, created_ts: int, model: str
) -> AsyncGenerator[dict[str, Any]]:
"""
Execute agent step with streaming support for LibreChat.
Yields SSE chunks for tool calls and final response.
Args:
user_input: User's message
completion_id: Completion ID for the response
created_ts: Timestamp for the response
model: Model name
Yields:
SSE chunks in OpenAI format
"""
memory = get_memory()
# Add user message to history
memory.stm.add_message("user", user_input)
memory.save()
# Build initial messages
system_prompt = self.prompt_builder.build_system_prompt()
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
# Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages)
messages.extend(history)
# Add unread events if any
unread_events = memory.episodic.get_unread_events()
if unread_events:
events_text = "\n".join(
[f"- {e['type']}: {e['data']}" for e in unread_events]
)
messages.append(
{"role": "system", "content": f"Background events:\n{events_text}"}
)
# Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec()
# Tool execution loop
for _iteration in range(self.settings.max_tool_iterations):
# Call LLM with tools
llm_result = self.llm.complete(messages, tools=tools_spec)
# Handle both tuple (response, usage) and dict response
if isinstance(llm_result, tuple):
response_message, usage = llm_result
else:
response_message = llm_result
# Check if there are tool calls
tool_calls = response_message.get("tool_calls")
if not tool_calls:
# No tool calls, this is the final response
final_content = response_message.get("content", "")
memory.stm.add_message("assistant", final_content)
memory.save()
# Stream the final response
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": final_content},
"finish_reason": "stop",
}
],
}
return
# Stream tool calls
for tool_call in tool_calls:
function = tool_call.get("function", {})
tool_name = function.get("name", "")
tool_args = function.get("arguments", "{}")
# Yield chunk indicating tool call
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
"id": tool_call.get("id"),
"type": "function",
"function": {
"name": tool_name,
"arguments": tool_args,
},
}
]
},
"finish_reason": None,
}
],
}
# Add assistant message with tool calls to conversation
messages.append(response_message)
# Execute each tool call and stream results
for tool_call in tool_calls:
tool_result = self._execute_tool_call(tool_call)
function = tool_call.get("function", {})
tool_name = function.get("name", "")
# Add tool result to messages
messages.append(
{
"tool_call_id": tool_call.get("id"),
"role": "tool",
"name": tool_name,
"content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Stream tool result as content
result_text = (
f"\n🔧 {tool_name}: {json.dumps(tool_result, ensure_ascii=False)}\n"
)
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": result_text},
"finish_reason": None,
}
],
}
# Max iterations reached, force final response
messages.append(
{
"role": "system",
"content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple):
final_message, usage = llm_result
else:
final_message = llm_result
final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response)
memory.save()
# Stream final response
yield {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": final_response},
"finish_reason": "stop",
}
],
}