feat(memory): Phase 1 — STM ToolResultsCache + ReleaseFocus + cache flag in YAML specs

Adds two STM components and a transparent cache hook in the agent loop so
read-only tools don't re-do work the agent already did in this session.

New STM components:
  - ToolResultsCache  — {tool_name: {key: result}}, session-scoped.
    to_dict() exposes only the key inventory (not payloads) to keep the
    prompt cheap.
  - ReleaseFocus      — current_release_path + working_set list, updated
    automatically when a path-keyed inspector runs.

YAML spec layer:
  - New optional 'cache: { key: <param_name> }' block in ToolSpec.
  - Validated at load time: cache.key must be a declared parameter.
  - Surfaced on Tool dataclass as cache_key: str | None.

Agent._execute_tool_call:
  - Pre-exec cache lookup; hit short-circuits and adds _from_cache=true.
  - Post-exec: stores successful results, updates release_focus for
    path-keyed tools, refreshes episodic.last_search_results when
    find_torrent's hit served the response (so get_torrent_by_index
    keeps pointing at the right list).

Cacheable tools (5): analyze_release, probe_media, list_folder,
find_media_imdb_id, find_torrent.
This commit is contained in:
2026-05-15 10:44:14 +02:00
parent 2db3198ef2
commit 3c7c6695f2
12 changed files with 269 additions and 7 deletions
+65 -4
View File
@@ -140,7 +140,7 @@ class Agent:
memory.save() memory.save()
return final_response return final_response
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0911
""" """
Execute a single tool call. Execute a single tool call.
@@ -183,25 +183,86 @@ class Agent:
} }
tool = self.tools[tool_name] tool = self.tools[tool_name]
memory = get_memory()
# Cache lookup — for tools flagged cacheable, short-circuit on hit.
cache_key_value = self._cache_key_for(tool, args)
if cache_key_value is not None:
cached = memory.stm.tool_results.get(tool_name, cache_key_value)
if cached is not None:
logger.info(
f"Tool cache HIT: {tool_name}[{cache_key_value}]"
)
self._post_tool_side_effects(tool_name, args, cached, from_cache=True)
return {**cached, "_from_cache": True}
# Execute tool # Execute tool
try: try:
result = tool.func(**args) result = tool.func(**args)
return result
except KeyboardInterrupt: except KeyboardInterrupt:
# Don't catch KeyboardInterrupt - let it propagate # Don't catch KeyboardInterrupt - let it propagate
raise raise
except TypeError as e: except TypeError as e:
# Bad arguments # Bad arguments
memory = get_memory()
memory.episodic.add_error(tool_name, f"bad_args: {e}") memory.episodic.add_error(tool_name, f"bad_args: {e}")
return {"error": "bad_args", "message": str(e), "tool": tool_name} return {"error": "bad_args", "message": str(e), "tool": tool_name}
except Exception as e: except Exception as e:
# Other errors # Other errors
memory = get_memory()
memory.episodic.add_error(tool_name, str(e)) memory.episodic.add_error(tool_name, str(e))
return {"error": "execution_failed", "message": str(e), "tool": tool_name} return {"error": "execution_failed", "message": str(e), "tool": tool_name}
# Persist + side effects only on successful results.
if isinstance(result, dict) and result.get("status") == "ok":
if cache_key_value is not None:
memory.stm.tool_results.put(tool_name, cache_key_value, result)
self._post_tool_side_effects(tool_name, args, result, from_cache=False)
memory.save()
return result
@staticmethod
def _cache_key_for(tool: Tool, args: dict[str, Any]) -> str | None:
"""Return the cache key value for this call, or None if not cacheable."""
if tool.cache_key is None:
return None
value = args.get(tool.cache_key)
if value is None:
return None
return str(value)
def _post_tool_side_effects(
self,
tool_name: str,
args: dict[str, Any],
result: dict[str, Any],
*,
from_cache: bool,
) -> None:
"""
Tool-agnostic side effects applied after a successful run or cache hit.
Today:
- Update release_focus when a path-keyed inspector runs.
- Refresh episodic.last_search_results on find_torrent cache hits so
get_torrent_by_index keeps pointing at the right list.
"""
memory = get_memory()
tool = self.tools.get(tool_name)
# Release focus: any path-keyed inspector updates current_release_path.
if tool is not None and tool.cache_key in {"source_path"}:
path = args.get(tool.cache_key)
if isinstance(path, str) and path:
memory.stm.release_focus.focus(path)
# Episodic refresh when find_torrent's cache short-circuits the call.
if from_cache and tool_name == "find_torrent":
torrents = result.get("torrents") or []
query = args.get("media_title") or ""
memory.episodic.store_search_results(
query=query, results=torrents, search_type="torrent"
)
async def step_streaming( async def step_streaming(
self, user_input: str, completion_id: str, created_ts: int, model: str self, user_input: str, completion_id: str, created_ts: int, model: str
) -> AsyncGenerator[dict[str, Any]]: ) -> AsyncGenerator[dict[str, Any]]:
+4
View File
@@ -20,6 +20,7 @@ class Tool:
description: str description: str
func: Callable[..., dict[str, Any]] func: Callable[..., dict[str, Any]]
parameters: dict[str, Any] parameters: dict[str, Any]
cache_key: str | None = None # Parameter name to use as STM cache key.
_PY_TYPE_TO_JSON = { _PY_TYPE_TO_JSON = {
@@ -84,11 +85,14 @@ def _create_tool_from_function(func: Callable, spec: ToolSpec | None = None) ->
"required": required, "required": required,
} }
cache_key = spec.cache.key if spec is not None and spec.cache is not None else None
return Tool( return Tool(
name=func.__name__, name=func.__name__,
description=description, description=description,
func=func, func=func,
parameters=parameters, parameters=parameters,
cache_key=cache_key,
) )
+26 -1
View File
@@ -61,6 +61,18 @@ class ReturnsSpec:
) )
@dataclass(frozen=True)
class CacheSpec:
"""Marks a tool as cacheable in STM.tool_results, keyed by one of its parameters."""
key: str # Name of the parameter whose value is the cache key.
@classmethod
def from_dict(cls, data: dict) -> CacheSpec:
_require(data, "key", "cache")
return cls(key=str(data["key"]).strip())
@dataclass(frozen=True) @dataclass(frozen=True)
class ToolSpec: class ToolSpec:
"""Full semantic spec for one tool.""" """Full semantic spec for one tool."""
@@ -73,6 +85,7 @@ class ToolSpec:
next_steps: str | None next_steps: str | None
parameters: dict[str, ParameterSpec] # name -> ParameterSpec parameters: dict[str, ParameterSpec] # name -> ParameterSpec
returns: dict[str, ReturnsSpec] # status_key -> ReturnsSpec returns: dict[str, ReturnsSpec] # status_key -> ReturnsSpec
cache: CacheSpec | None = None # If present, tool is cached.
@classmethod @classmethod
def from_yaml_path(cls, path: Path) -> ToolSpec: def from_yaml_path(cls, path: Path) -> ToolSpec:
@@ -108,7 +121,12 @@ class ToolSpec:
for rkey, rdata in returns_raw.items() for rkey, rdata in returns_raw.items()
} }
return cls( cache_raw = data.get("cache")
if cache_raw is not None and not isinstance(cache_raw, dict):
raise ToolSpecError("cache must be a mapping")
cache = CacheSpec.from_dict(cache_raw) if cache_raw else None
spec = cls(
name=str(data["name"]).strip(), name=str(data["name"]).strip(),
summary=str(data["summary"]).strip(), summary=str(data["summary"]).strip(),
description=str(data["description"]).strip(), description=str(data["description"]).strip(),
@@ -117,7 +135,14 @@ class ToolSpec:
next_steps=_strip_or_none(data.get("next_steps")), next_steps=_strip_or_none(data.get("next_steps")),
parameters=parameters, parameters=parameters,
returns=returns, returns=returns,
cache=cache,
) )
if cache is not None and cache.key not in parameters:
raise ToolSpecError(
f"cache.key '{cache.key}' is not a declared parameter "
f"(declared: {sorted(parameters)})"
)
return spec
def compile_description(self) -> str: def compile_description(self) -> str:
""" """
@@ -37,6 +37,9 @@ next_steps: |
- media_type in (other, unknown) → ask the user what to do; do not - media_type in (other, unknown) → ask the user what to do; do not
auto-route. auto-route.
cache:
key: source_path
parameters: parameters:
release_name: release_name:
description: Raw release folder or file name as it appears on disk. description: Raw release folder or file name as it appears on disk.
@@ -27,6 +27,9 @@ next_steps: |
- On status=error (not_found): show the error and ask the user for - On status=error (not_found): show the error and ask the user for
a more precise title. a more precise title.
cache:
key: media_title
parameters: parameters:
media_title: media_title:
description: Title to search for. Free-form — TMDB does the matching. description: Title to search for. Free-form — TMDB does the matching.
@@ -27,6 +27,9 @@ next_steps: |
- Once chosen: call add_torrent_by_index(N) — that wraps - Once chosen: call add_torrent_by_index(N) — that wraps
get_torrent_by_index + add_torrent_to_qbittorrent. get_torrent_by_index + add_torrent_to_qbittorrent.
cache:
key: media_title
parameters: parameters:
media_title: media_title:
description: Title to search for on Knaben. Free-form. description: Title to search for on Knaben. Free-form.
@@ -29,6 +29,9 @@ next_steps: |
- After listing a library folder: use the result to disambiguate a - After listing a library folder: use the result to disambiguate a
destination during resolve_*_destination. destination during resolve_*_destination.
cache:
key: path
parameters: parameters:
folder_type: folder_type:
description: Logical folder key (download, torrent, movie, tv_show, ...). description: Logical folder key (download, torrent, movie, tv_show, ...).
@@ -27,6 +27,9 @@ next_steps: |
"this is 7.1 DTS, want to keep it?"); rarely chained directly to "this is 7.1 DTS, want to keep it?"); rarely chained directly to
another tool. another tool.
cache:
key: source_path
parameters: parameters:
source_path: source_path:
description: Absolute path to the video file to probe. description: Absolute path to the video file to probe.
@@ -1,5 +1,13 @@
from .conversation import Conversation from .conversation import Conversation
from .entities import Entities from .entities import Entities
from .release_focus import ReleaseFocus
from .tool_results import ToolResultsCache
from .workflow import Workflow from .workflow import Workflow
__all__ = ["Conversation", "Workflow", "Entities"] __all__ = [
"Conversation",
"Entities",
"ReleaseFocus",
"ToolResultsCache",
"Workflow",
]
@@ -0,0 +1,68 @@
"""ReleaseFocus — which release(s) the conversation is currently about.
Two slots:
- current_release_path: the single release the user is actively
working on. Updated automatically when an inspector tool
(analyze_release, probe_media, ...) is called on a path.
- working_set: every release path mentioned in this session, in
insertion order. Useful to answer "what were we looking at?".
Only paths are stored — full metadata lives in each release's
`.alfred/metadata.yaml`. Read it via a dedicated tool when needed.
"""
import logging
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class ReleaseFocus:
current_release_path: str | None = None
working_set: list[str] = field(default_factory=list)
def focus(self, path: str) -> None:
"""Set the current release and add it to the working set."""
self.current_release_path = path
if path not in self.working_set:
self.working_set.append(path)
logger.debug(f"ReleaseFocus: added '{path}' to working set")
logger.debug(f"ReleaseFocus: current -> '{path}'")
def clear(self) -> None:
self.current_release_path = None
self.working_set = []
@classmethod
def describe(cls) -> dict:
return {
"name": "ReleaseFocus",
"tier": "stm",
"access": "read",
"description": (
"Tracks which release the conversation is centred on. "
"Updated automatically when an inspector tool runs on a "
"release path. Read 'current_release_path' to know what "
"the user is talking about without re-asking; read "
"'working_set' for the list of releases touched in this "
"session."
),
"fields": {
"current_release_path": (
"Absolute path of the release currently in focus, or "
"None if no release has been inspected yet."
),
"working_set": (
"Ordered list of every release path inspected during "
"the session. Refer to entries by their basename when "
"talking to the user."
),
},
}
def to_dict(self) -> dict:
return {
"current_release_path": self.current_release_path,
"working_set": list(self.working_set),
}
@@ -0,0 +1,69 @@
"""ToolResultsCache — per-session cache of read-only tool results.
Tools flagged with `cache.key` in their YAML spec write their results
here automatically (via Agent._execute_tool_call). Lookup happens
before execution: a hit returns the cached value without re-calling
the underlying tool.
Lives for the duration of the session — cleared on memory.clear_session()
or on a fresh process start.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class ToolResultsCache:
# {tool_name: {key: result}}
results: dict[str, dict[str, Any]] = field(default_factory=dict)
def get(self, tool_name: str, key: str) -> dict | None:
bucket = self.results.get(tool_name)
if bucket is None:
return None
return bucket.get(key)
def put(self, tool_name: str, key: str, result: dict) -> None:
self.results.setdefault(tool_name, {})[key] = result
logger.debug(f"ToolResultsCache: stored {tool_name}[{key}]")
def keys_for(self, tool_name: str) -> list[str]:
return list(self.results.get(tool_name, {}).keys())
def clear(self) -> None:
self.results = {}
@classmethod
def describe(cls) -> dict:
return {
"name": "ToolResultsCache",
"tier": "stm",
"access": "read",
"description": (
"Per-session cache of read-only tool results, keyed by tool "
"name and a natural identifier from the tool's arguments "
"(usually a path or a title). Hits short-circuit the next "
"call to the same tool with the same key. Populated "
"automatically by the agent — do not write to it directly. "
"Consult it before suggesting a probe / analyze you may "
"have already done in this session."
),
"fields": {
"results": (
"Nested dict: {tool_name: {cache_key: result}}. "
"Read-only inventory; the agent should reference cached "
"values rather than re-calling the tool."
),
},
}
def to_dict(self) -> dict:
# Surface only the index (tool + keys), not the payloads — payloads
# can be large and the prompt only needs to know what's available.
return {
tool: list(bucket.keys()) for tool, bucket in self.results.items()
}
@@ -3,7 +3,13 @@
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .components import Conversation, Entities, Workflow from .components import (
Conversation,
Entities,
ReleaseFocus,
ToolResultsCache,
Workflow,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,6 +28,8 @@ class ShortTermMemory:
conversation: Conversation = field(default_factory=Conversation) conversation: Conversation = field(default_factory=Conversation)
workflow: Workflow = field(default_factory=Workflow) workflow: Workflow = field(default_factory=Workflow)
entities: Entities = field(default_factory=Entities) entities: Entities = field(default_factory=Entities)
tool_results: ToolResultsCache = field(default_factory=ToolResultsCache)
release_focus: ReleaseFocus = field(default_factory=ReleaseFocus)
# Convenience proxies kept for backward compatibility with existing callers # Convenience proxies kept for backward compatibility with existing callers
@property @property
@@ -79,6 +87,8 @@ class ShortTermMemory:
self.conversation.clear() self.conversation.clear()
self.workflow.clear() self.workflow.clear()
self.entities.clear() self.entities.clear()
self.tool_results.clear()
self.release_focus.clear()
logger.info("STM: Cleared") logger.info("STM: Cleared")
def to_dict(self) -> dict: def to_dict(self) -> dict:
@@ -88,4 +98,6 @@ class ShortTermMemory:
"extracted_entities": self.entities.data, "extracted_entities": self.entities.data,
"current_topic": self.entities.topic, "current_topic": self.entities.topic,
"language": self.conversation.language, "language": self.conversation.language,
"tool_results": self.tool_results.to_dict(),
"release_focus": self.release_focus.to_dict(),
} }