diff --git a/alfred/agent/agent.py b/alfred/agent/agent.py index 31a4252..cfa9ebe 100644 --- a/alfred/agent/agent.py +++ b/alfred/agent/agent.py @@ -140,7 +140,7 @@ class Agent: memory.save() return final_response - def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: + def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0911 """ Execute a single tool call. @@ -183,25 +183,86 @@ class Agent: } tool = self.tools[tool_name] + memory = get_memory() + + # Cache lookup — for tools flagged cacheable, short-circuit on hit. + cache_key_value = self._cache_key_for(tool, args) + if cache_key_value is not None: + cached = memory.stm.tool_results.get(tool_name, cache_key_value) + if cached is not None: + logger.info( + f"Tool cache HIT: {tool_name}[{cache_key_value}]" + ) + self._post_tool_side_effects(tool_name, args, cached, from_cache=True) + return {**cached, "_from_cache": True} # Execute tool try: result = tool.func(**args) - return result except KeyboardInterrupt: # Don't catch KeyboardInterrupt - let it propagate raise except TypeError as e: # Bad arguments - memory = get_memory() memory.episodic.add_error(tool_name, f"bad_args: {e}") return {"error": "bad_args", "message": str(e), "tool": tool_name} except Exception as e: # Other errors - memory = get_memory() memory.episodic.add_error(tool_name, str(e)) return {"error": "execution_failed", "message": str(e), "tool": tool_name} + # Persist + side effects only on successful results. 
@staticmethod
def _cache_key_for(tool: Tool, args: dict[str, Any]) -> str | None:
    """Derive the cache key for one tool invocation.

    The tool's ``cache_key`` attribute names the argument whose value
    identifies the call in the tool-results cache.

    Args:
        tool: Tool being invoked; only its ``cache_key`` attribute is read.
        args: Keyword arguments the tool is about to be called with.

    Returns:
        ``str(value)`` of the named argument, or ``None`` when the tool has
        no ``cache_key`` or the argument is missing / ``None`` (the call is
        then treated as not cacheable).
    """
    param_name = tool.cache_key
    # Without a cache_key the arguments are never consulted.
    raw_value = args.get(param_name) if param_name is not None else None
    return str(raw_value) if raw_value is not None else None
@dataclass(frozen=True)
class CacheSpec:
    """Declares a tool as cacheable in STM.tool_results.

    The cache entry for a call is keyed by the value of one of the tool's
    parameters; this spec records which parameter that is.
    """

    # Name of the tool parameter whose value serves as the cache key.
    key: str

    @classmethod
    def from_dict(cls, data: dict) -> CacheSpec:
        """Build a CacheSpec from the ``cache`` mapping of a tool spec.

        ``_require`` is invoked first for the ``"key"`` field of the
        ``"cache"`` section (presumably rejecting a missing field — the
        helper is defined outside this block; confirm there). The value is
        then coerced to ``str`` and stripped of surrounding whitespace.
        """
        _require(data, "key", "cache")
        param_name = str(data["key"]).strip()
        return cls(key=param_name)
@classmethod def from_yaml_path(cls, path: Path) -> ToolSpec: @@ -108,7 +121,12 @@ class ToolSpec: for rkey, rdata in returns_raw.items() } - return cls( + cache_raw = data.get("cache") + if cache_raw is not None and not isinstance(cache_raw, dict): + raise ToolSpecError("cache must be a mapping") + cache = CacheSpec.from_dict(cache_raw) if cache_raw else None + + spec = cls( name=str(data["name"]).strip(), summary=str(data["summary"]).strip(), description=str(data["description"]).strip(), @@ -117,7 +135,14 @@ class ToolSpec: next_steps=_strip_or_none(data.get("next_steps")), parameters=parameters, returns=returns, + cache=cache, ) + if cache is not None and cache.key not in parameters: + raise ToolSpecError( + f"cache.key '{cache.key}' is not a declared parameter " + f"(declared: {sorted(parameters)})" + ) + return spec def compile_description(self) -> str: """ diff --git a/alfred/agent/tools/specs/analyze_release.yaml b/alfred/agent/tools/specs/analyze_release.yaml index 352e061..c701fc6 100644 --- a/alfred/agent/tools/specs/analyze_release.yaml +++ b/alfred/agent/tools/specs/analyze_release.yaml @@ -37,6 +37,9 @@ next_steps: | - media_type in (other, unknown) → ask the user what to do; do not auto-route. +cache: + key: source_path + parameters: release_name: description: Raw release folder or file name as it appears on disk. diff --git a/alfred/agent/tools/specs/find_media_imdb_id.yaml b/alfred/agent/tools/specs/find_media_imdb_id.yaml index 6d2e207..899f765 100644 --- a/alfred/agent/tools/specs/find_media_imdb_id.yaml +++ b/alfred/agent/tools/specs/find_media_imdb_id.yaml @@ -27,6 +27,9 @@ next_steps: | - On status=error (not_found): show the error and ask the user for a more precise title. +cache: + key: media_title + parameters: media_title: description: Title to search for. Free-form — TMDB does the matching. 
diff --git a/alfred/agent/tools/specs/find_torrent.yaml b/alfred/agent/tools/specs/find_torrent.yaml index a568939..a14cd9c 100644 --- a/alfred/agent/tools/specs/find_torrent.yaml +++ b/alfred/agent/tools/specs/find_torrent.yaml @@ -27,6 +27,9 @@ next_steps: | - Once chosen: call add_torrent_by_index(N) — that wraps get_torrent_by_index + add_torrent_to_qbittorrent. +cache: + key: media_title + parameters: media_title: description: Title to search for on Knaben. Free-form. diff --git a/alfred/agent/tools/specs/list_folder.yaml b/alfred/agent/tools/specs/list_folder.yaml index 67600a3..445fcae 100644 --- a/alfred/agent/tools/specs/list_folder.yaml +++ b/alfred/agent/tools/specs/list_folder.yaml @@ -29,6 +29,9 @@ next_steps: | - After listing a library folder: use the result to disambiguate a destination during resolve_*_destination. +cache: + key: path + parameters: folder_type: description: Logical folder key (download, torrent, movie, tv_show, ...). diff --git a/alfred/agent/tools/specs/probe_media.yaml b/alfred/agent/tools/specs/probe_media.yaml index a0c7a74..8721d61 100644 --- a/alfred/agent/tools/specs/probe_media.yaml +++ b/alfred/agent/tools/specs/probe_media.yaml @@ -27,6 +27,9 @@ next_steps: | "this is 7.1 DTS, want to keep it?"); rarely chained directly to another tool. +cache: + key: source_path + parameters: source_path: description: Absolute path to the video file to probe. 
diff --git a/alfred/infrastructure/persistence/memory/stm/components/__init__.py b/alfred/infrastructure/persistence/memory/stm/components/__init__.py index 22cddcc..5b0e737 100644 --- a/alfred/infrastructure/persistence/memory/stm/components/__init__.py +++ b/alfred/infrastructure/persistence/memory/stm/components/__init__.py @@ -1,5 +1,13 @@ from .conversation import Conversation from .entities import Entities +from .release_focus import ReleaseFocus +from .tool_results import ToolResultsCache from .workflow import Workflow -__all__ = ["Conversation", "Workflow", "Entities"] +__all__ = [ + "Conversation", + "Entities", + "ReleaseFocus", + "ToolResultsCache", + "Workflow", +] diff --git a/alfred/infrastructure/persistence/memory/stm/components/release_focus.py b/alfred/infrastructure/persistence/memory/stm/components/release_focus.py new file mode 100644 index 0000000..3a3a4c5 --- /dev/null +++ b/alfred/infrastructure/persistence/memory/stm/components/release_focus.py @@ -0,0 +1,68 @@ +"""ReleaseFocus — which release(s) the conversation is currently about. + +Two slots: + - current_release_path: the single release the user is actively + working on. Updated automatically when an inspector tool + (analyze_release, probe_media, ...) is called on a path. + - working_set: every release path mentioned in this session, in + insertion order. Useful to answer "what were we looking at?". + +Only paths are stored — full metadata lives in each release's +`.alfred/metadata.yaml`. Read it via a dedicated tool when needed. 
+""" + +import logging +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class ReleaseFocus: + current_release_path: str | None = None + working_set: list[str] = field(default_factory=list) + + def focus(self, path: str) -> None: + """Set the current release and add it to the working set.""" + self.current_release_path = path + if path not in self.working_set: + self.working_set.append(path) + logger.debug(f"ReleaseFocus: added '{path}' to working set") + logger.debug(f"ReleaseFocus: current -> '{path}'") + + def clear(self) -> None: + self.current_release_path = None + self.working_set = [] + + @classmethod + def describe(cls) -> dict: + return { + "name": "ReleaseFocus", + "tier": "stm", + "access": "read", + "description": ( + "Tracks which release the conversation is centred on. " + "Updated automatically when an inspector tool runs on a " + "release path. Read 'current_release_path' to know what " + "the user is talking about without re-asking; read " + "'working_set' for the list of releases touched in this " + "session." + ), + "fields": { + "current_release_path": ( + "Absolute path of the release currently in focus, or " + "None if no release has been inspected yet." + ), + "working_set": ( + "Ordered list of every release path inspected during " + "the session. Refer to entries by their basename when " + "talking to the user." + ), + }, + } + + def to_dict(self) -> dict: + return { + "current_release_path": self.current_release_path, + "working_set": list(self.working_set), + } diff --git a/alfred/infrastructure/persistence/memory/stm/components/tool_results.py b/alfred/infrastructure/persistence/memory/stm/components/tool_results.py new file mode 100644 index 0000000..62eb51f --- /dev/null +++ b/alfred/infrastructure/persistence/memory/stm/components/tool_results.py @@ -0,0 +1,69 @@ +"""ToolResultsCache — per-session cache of read-only tool results. 
+ +Tools flagged with `cache.key` in their YAML spec write their results +here automatically (via Agent._execute_tool_call). Lookup happens +before execution: a hit returns the cached value without re-calling +the underlying tool. + +Lives for the duration of the session — cleared on memory.clear_session() +or on a fresh process start. +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolResultsCache: + # {tool_name: {key: result}} + results: dict[str, dict[str, Any]] = field(default_factory=dict) + + def get(self, tool_name: str, key: str) -> dict | None: + bucket = self.results.get(tool_name) + if bucket is None: + return None + return bucket.get(key) + + def put(self, tool_name: str, key: str, result: dict) -> None: + self.results.setdefault(tool_name, {})[key] = result + logger.debug(f"ToolResultsCache: stored {tool_name}[{key}]") + + def keys_for(self, tool_name: str) -> list[str]: + return list(self.results.get(tool_name, {}).keys()) + + def clear(self) -> None: + self.results = {} + + @classmethod + def describe(cls) -> dict: + return { + "name": "ToolResultsCache", + "tier": "stm", + "access": "read", + "description": ( + "Per-session cache of read-only tool results, keyed by tool " + "name and a natural identifier from the tool's arguments " + "(usually a path or a title). Hits short-circuit the next " + "call to the same tool with the same key. Populated " + "automatically by the agent — do not write to it directly. " + "Consult it before suggesting a probe / analyze you may " + "have already done in this session." + ), + "fields": { + "results": ( + "Nested dict: {tool_name: {cache_key: result}}. " + "Read-only inventory; the agent should reference cached " + "values rather than re-calling the tool." 
+ ), + }, + } + + def to_dict(self) -> dict: + # Surface only the index (tool + keys), not the payloads — payloads + # can be large and the prompt only needs to know what's available. + return { + tool: list(bucket.keys()) for tool, bucket in self.results.items() + } diff --git a/alfred/infrastructure/persistence/memory/stm/stm.py b/alfred/infrastructure/persistence/memory/stm/stm.py index 5338228..a25bcf2 100644 --- a/alfred/infrastructure/persistence/memory/stm/stm.py +++ b/alfred/infrastructure/persistence/memory/stm/stm.py @@ -3,7 +3,13 @@ import logging from dataclasses import dataclass, field -from .components import Conversation, Entities, Workflow +from .components import ( + Conversation, + Entities, + ReleaseFocus, + ToolResultsCache, + Workflow, +) logger = logging.getLogger(__name__) @@ -22,6 +28,8 @@ class ShortTermMemory: conversation: Conversation = field(default_factory=Conversation) workflow: Workflow = field(default_factory=Workflow) entities: Entities = field(default_factory=Entities) + tool_results: ToolResultsCache = field(default_factory=ToolResultsCache) + release_focus: ReleaseFocus = field(default_factory=ReleaseFocus) # Convenience proxies kept for backward compatibility with existing callers @property @@ -79,6 +87,8 @@ class ShortTermMemory: self.conversation.clear() self.workflow.clear() self.entities.clear() + self.tool_results.clear() + self.release_focus.clear() logger.info("STM: Cleared") def to_dict(self) -> dict: @@ -88,4 +98,6 @@ class ShortTermMemory: "extracted_entities": self.entities.data, "current_topic": self.entities.topic, "language": self.conversation.language, + "tool_results": self.tool_results.to_dict(), + "release_focus": self.release_focus.to_dict(), }