feat(memory): Phase 1 — STM ToolResultsCache + ReleaseFocus + cache flag in YAML specs
Adds two STM components and a transparent cache hook in the agent loop so
read-only tools don't re-do work the agent already did in this session.
New STM components:
- ToolResultsCache — {tool_name: {key: result}}, session-scoped.
to_dict() exposes only the key inventory (not payloads) to keep the
prompt cheap.
- ReleaseFocus — current_release_path + working_set list, updated
automatically when a path-keyed inspector runs.
YAML spec layer:
- New optional 'cache: { key: <param_name> }' block in ToolSpec.
- Validated at load time: cache.key must be a declared parameter.
- Surfaced on Tool dataclass as cache_key: str | None.
Agent._execute_tool_call:
- Pre-exec cache lookup; hit short-circuits and adds _from_cache=true.
- Post-exec: stores successful results, updates release_focus for
path-keyed tools, refreshes episodic.last_search_results when
find_torrent's hit served the response (so get_torrent_by_index
keeps pointing at the right list).
Cacheable tools (5): analyze_release, probe_media, list_folder,
find_media_imdb_id, find_torrent.
This commit is contained in:
+65
-4
@@ -140,7 +140,7 @@ class Agent:
|
|||||||
memory.save()
|
memory.save()
|
||||||
return final_response
|
return final_response
|
||||||
|
|
||||||
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
|
def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0911
|
||||||
"""
|
"""
|
||||||
Execute a single tool call.
|
Execute a single tool call.
|
||||||
|
|
||||||
@@ -183,25 +183,86 @@ class Agent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
tool = self.tools[tool_name]
|
tool = self.tools[tool_name]
|
||||||
|
memory = get_memory()
|
||||||
|
|
||||||
|
# Cache lookup — for tools flagged cacheable, short-circuit on hit.
|
||||||
|
cache_key_value = self._cache_key_for(tool, args)
|
||||||
|
if cache_key_value is not None:
|
||||||
|
cached = memory.stm.tool_results.get(tool_name, cache_key_value)
|
||||||
|
if cached is not None:
|
||||||
|
logger.info(
|
||||||
|
f"Tool cache HIT: {tool_name}[{cache_key_value}]"
|
||||||
|
)
|
||||||
|
self._post_tool_side_effects(tool_name, args, cached, from_cache=True)
|
||||||
|
return {**cached, "_from_cache": True}
|
||||||
|
|
||||||
# Execute tool
|
# Execute tool
|
||||||
try:
|
try:
|
||||||
result = tool.func(**args)
|
result = tool.func(**args)
|
||||||
return result
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# Don't catch KeyboardInterrupt - let it propagate
|
# Don't catch KeyboardInterrupt - let it propagate
|
||||||
raise
|
raise
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
# Bad arguments
|
# Bad arguments
|
||||||
memory = get_memory()
|
|
||||||
memory.episodic.add_error(tool_name, f"bad_args: {e}")
|
memory.episodic.add_error(tool_name, f"bad_args: {e}")
|
||||||
return {"error": "bad_args", "message": str(e), "tool": tool_name}
|
return {"error": "bad_args", "message": str(e), "tool": tool_name}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Other errors
|
# Other errors
|
||||||
memory = get_memory()
|
|
||||||
memory.episodic.add_error(tool_name, str(e))
|
memory.episodic.add_error(tool_name, str(e))
|
||||||
return {"error": "execution_failed", "message": str(e), "tool": tool_name}
|
return {"error": "execution_failed", "message": str(e), "tool": tool_name}
|
||||||
|
|
||||||
|
# Persist + side effects only on successful results.
|
||||||
|
if isinstance(result, dict) and result.get("status") == "ok":
|
||||||
|
if cache_key_value is not None:
|
||||||
|
memory.stm.tool_results.put(tool_name, cache_key_value, result)
|
||||||
|
self._post_tool_side_effects(tool_name, args, result, from_cache=False)
|
||||||
|
memory.save()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cache_key_for(tool: Tool, args: dict[str, Any]) -> str | None:
|
||||||
|
"""Return the cache key value for this call, or None if not cacheable."""
|
||||||
|
if tool.cache_key is None:
|
||||||
|
return None
|
||||||
|
value = args.get(tool.cache_key)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def _post_tool_side_effects(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
args: dict[str, Any],
|
||||||
|
result: dict[str, Any],
|
||||||
|
*,
|
||||||
|
from_cache: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Tool-agnostic side effects applied after a successful run or cache hit.
|
||||||
|
|
||||||
|
Today:
|
||||||
|
- Update release_focus when a path-keyed inspector runs.
|
||||||
|
- Refresh episodic.last_search_results on find_torrent cache hits so
|
||||||
|
get_torrent_by_index keeps pointing at the right list.
|
||||||
|
"""
|
||||||
|
memory = get_memory()
|
||||||
|
tool = self.tools.get(tool_name)
|
||||||
|
|
||||||
|
# Release focus: any path-keyed inspector updates current_release_path.
|
||||||
|
if tool is not None and tool.cache_key in {"source_path"}:
|
||||||
|
path = args.get(tool.cache_key)
|
||||||
|
if isinstance(path, str) and path:
|
||||||
|
memory.stm.release_focus.focus(path)
|
||||||
|
|
||||||
|
# Episodic refresh when find_torrent's cache short-circuits the call.
|
||||||
|
if from_cache and tool_name == "find_torrent":
|
||||||
|
torrents = result.get("torrents") or []
|
||||||
|
query = args.get("media_title") or ""
|
||||||
|
memory.episodic.store_search_results(
|
||||||
|
query=query, results=torrents, search_type="torrent"
|
||||||
|
)
|
||||||
|
|
||||||
async def step_streaming(
|
async def step_streaming(
|
||||||
self, user_input: str, completion_id: str, created_ts: int, model: str
|
self, user_input: str, completion_id: str, created_ts: int, model: str
|
||||||
) -> AsyncGenerator[dict[str, Any]]:
|
) -> AsyncGenerator[dict[str, Any]]:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class Tool:
|
|||||||
description: str
|
description: str
|
||||||
func: Callable[..., dict[str, Any]]
|
func: Callable[..., dict[str, Any]]
|
||||||
parameters: dict[str, Any]
|
parameters: dict[str, Any]
|
||||||
|
cache_key: str | None = None # Parameter name to use as STM cache key.
|
||||||
|
|
||||||
|
|
||||||
_PY_TYPE_TO_JSON = {
|
_PY_TYPE_TO_JSON = {
|
||||||
@@ -84,11 +85,14 @@ def _create_tool_from_function(func: Callable, spec: ToolSpec | None = None) ->
|
|||||||
"required": required,
|
"required": required,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cache_key = spec.cache.key if spec is not None and spec.cache is not None else None
|
||||||
|
|
||||||
return Tool(
|
return Tool(
|
||||||
name=func.__name__,
|
name=func.__name__,
|
||||||
description=description,
|
description=description,
|
||||||
func=func,
|
func=func,
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
|
cache_key=cache_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -61,6 +61,18 @@ class ReturnsSpec:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CacheSpec:
|
||||||
|
"""Marks a tool as cacheable in STM.tool_results, keyed by one of its parameters."""
|
||||||
|
|
||||||
|
key: str # Name of the parameter whose value is the cache key.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> CacheSpec:
|
||||||
|
_require(data, "key", "cache")
|
||||||
|
return cls(key=str(data["key"]).strip())
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ToolSpec:
|
class ToolSpec:
|
||||||
"""Full semantic spec for one tool."""
|
"""Full semantic spec for one tool."""
|
||||||
@@ -73,6 +85,7 @@ class ToolSpec:
|
|||||||
next_steps: str | None
|
next_steps: str | None
|
||||||
parameters: dict[str, ParameterSpec] # name -> ParameterSpec
|
parameters: dict[str, ParameterSpec] # name -> ParameterSpec
|
||||||
returns: dict[str, ReturnsSpec] # status_key -> ReturnsSpec
|
returns: dict[str, ReturnsSpec] # status_key -> ReturnsSpec
|
||||||
|
cache: CacheSpec | None = None # If present, tool is cached.
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_yaml_path(cls, path: Path) -> ToolSpec:
|
def from_yaml_path(cls, path: Path) -> ToolSpec:
|
||||||
@@ -108,7 +121,12 @@ class ToolSpec:
|
|||||||
for rkey, rdata in returns_raw.items()
|
for rkey, rdata in returns_raw.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
return cls(
|
cache_raw = data.get("cache")
|
||||||
|
if cache_raw is not None and not isinstance(cache_raw, dict):
|
||||||
|
raise ToolSpecError("cache must be a mapping")
|
||||||
|
cache = CacheSpec.from_dict(cache_raw) if cache_raw else None
|
||||||
|
|
||||||
|
spec = cls(
|
||||||
name=str(data["name"]).strip(),
|
name=str(data["name"]).strip(),
|
||||||
summary=str(data["summary"]).strip(),
|
summary=str(data["summary"]).strip(),
|
||||||
description=str(data["description"]).strip(),
|
description=str(data["description"]).strip(),
|
||||||
@@ -117,7 +135,14 @@ class ToolSpec:
|
|||||||
next_steps=_strip_or_none(data.get("next_steps")),
|
next_steps=_strip_or_none(data.get("next_steps")),
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
returns=returns,
|
returns=returns,
|
||||||
|
cache=cache,
|
||||||
)
|
)
|
||||||
|
if cache is not None and cache.key not in parameters:
|
||||||
|
raise ToolSpecError(
|
||||||
|
f"cache.key '{cache.key}' is not a declared parameter "
|
||||||
|
f"(declared: {sorted(parameters)})"
|
||||||
|
)
|
||||||
|
return spec
|
||||||
|
|
||||||
def compile_description(self) -> str:
|
def compile_description(self) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ next_steps: |
|
|||||||
- media_type in (other, unknown) → ask the user what to do; do not
|
- media_type in (other, unknown) → ask the user what to do; do not
|
||||||
auto-route.
|
auto-route.
|
||||||
|
|
||||||
|
cache:
|
||||||
|
key: source_path
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
release_name:
|
release_name:
|
||||||
description: Raw release folder or file name as it appears on disk.
|
description: Raw release folder or file name as it appears on disk.
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ next_steps: |
|
|||||||
- On status=error (not_found): show the error and ask the user for
|
- On status=error (not_found): show the error and ask the user for
|
||||||
a more precise title.
|
a more precise title.
|
||||||
|
|
||||||
|
cache:
|
||||||
|
key: media_title
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
media_title:
|
media_title:
|
||||||
description: Title to search for. Free-form — TMDB does the matching.
|
description: Title to search for. Free-form — TMDB does the matching.
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ next_steps: |
|
|||||||
- Once chosen: call add_torrent_by_index(N) — that wraps
|
- Once chosen: call add_torrent_by_index(N) — that wraps
|
||||||
get_torrent_by_index + add_torrent_to_qbittorrent.
|
get_torrent_by_index + add_torrent_to_qbittorrent.
|
||||||
|
|
||||||
|
cache:
|
||||||
|
key: media_title
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
media_title:
|
media_title:
|
||||||
description: Title to search for on Knaben. Free-form.
|
description: Title to search for on Knaben. Free-form.
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ next_steps: |
|
|||||||
- After listing a library folder: use the result to disambiguate a
|
- After listing a library folder: use the result to disambiguate a
|
||||||
destination during resolve_*_destination.
|
destination during resolve_*_destination.
|
||||||
|
|
||||||
|
cache:
|
||||||
|
key: path
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
folder_type:
|
folder_type:
|
||||||
description: Logical folder key (download, torrent, movie, tv_show, ...).
|
description: Logical folder key (download, torrent, movie, tv_show, ...).
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ next_steps: |
|
|||||||
"this is 7.1 DTS, want to keep it?"); rarely chained directly to
|
"this is 7.1 DTS, want to keep it?"); rarely chained directly to
|
||||||
another tool.
|
another tool.
|
||||||
|
|
||||||
|
cache:
|
||||||
|
key: source_path
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
source_path:
|
source_path:
|
||||||
description: Absolute path to the video file to probe.
|
description: Absolute path to the video file to probe.
|
||||||
|
|||||||
@@ -1,5 +1,13 @@
|
|||||||
from .conversation import Conversation
|
from .conversation import Conversation
|
||||||
from .entities import Entities
|
from .entities import Entities
|
||||||
|
from .release_focus import ReleaseFocus
|
||||||
|
from .tool_results import ToolResultsCache
|
||||||
from .workflow import Workflow
|
from .workflow import Workflow
|
||||||
|
|
||||||
__all__ = ["Conversation", "Workflow", "Entities"]
|
__all__ = [
|
||||||
|
"Conversation",
|
||||||
|
"Entities",
|
||||||
|
"ReleaseFocus",
|
||||||
|
"ToolResultsCache",
|
||||||
|
"Workflow",
|
||||||
|
]
|
||||||
|
|||||||
@@ -0,0 +1,68 @@
|
|||||||
|
"""ReleaseFocus — which release(s) the conversation is currently about.
|
||||||
|
|
||||||
|
Two slots:
|
||||||
|
- current_release_path: the single release the user is actively
|
||||||
|
working on. Updated automatically when an inspector tool
|
||||||
|
(analyze_release, probe_media, ...) is called on a path.
|
||||||
|
- working_set: every release path mentioned in this session, in
|
||||||
|
insertion order. Useful to answer "what were we looking at?".
|
||||||
|
|
||||||
|
Only paths are stored — full metadata lives in each release's
|
||||||
|
`.alfred/metadata.yaml`. Read it via a dedicated tool when needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReleaseFocus:
|
||||||
|
current_release_path: str | None = None
|
||||||
|
working_set: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def focus(self, path: str) -> None:
|
||||||
|
"""Set the current release and add it to the working set."""
|
||||||
|
self.current_release_path = path
|
||||||
|
if path not in self.working_set:
|
||||||
|
self.working_set.append(path)
|
||||||
|
logger.debug(f"ReleaseFocus: added '{path}' to working set")
|
||||||
|
logger.debug(f"ReleaseFocus: current -> '{path}'")
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self.current_release_path = None
|
||||||
|
self.working_set = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def describe(cls) -> dict:
|
||||||
|
return {
|
||||||
|
"name": "ReleaseFocus",
|
||||||
|
"tier": "stm",
|
||||||
|
"access": "read",
|
||||||
|
"description": (
|
||||||
|
"Tracks which release the conversation is centred on. "
|
||||||
|
"Updated automatically when an inspector tool runs on a "
|
||||||
|
"release path. Read 'current_release_path' to know what "
|
||||||
|
"the user is talking about without re-asking; read "
|
||||||
|
"'working_set' for the list of releases touched in this "
|
||||||
|
"session."
|
||||||
|
),
|
||||||
|
"fields": {
|
||||||
|
"current_release_path": (
|
||||||
|
"Absolute path of the release currently in focus, or "
|
||||||
|
"None if no release has been inspected yet."
|
||||||
|
),
|
||||||
|
"working_set": (
|
||||||
|
"Ordered list of every release path inspected during "
|
||||||
|
"the session. Refer to entries by their basename when "
|
||||||
|
"talking to the user."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"current_release_path": self.current_release_path,
|
||||||
|
"working_set": list(self.working_set),
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
"""ToolResultsCache — per-session cache of read-only tool results.
|
||||||
|
|
||||||
|
Tools flagged with `cache.key` in their YAML spec write their results
|
||||||
|
here automatically (via Agent._execute_tool_call). Lookup happens
|
||||||
|
before execution: a hit returns the cached value without re-calling
|
||||||
|
the underlying tool.
|
||||||
|
|
||||||
|
Lives for the duration of the session — cleared on memory.clear_session()
|
||||||
|
or on a fresh process start.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResultsCache:
|
||||||
|
# {tool_name: {key: result}}
|
||||||
|
results: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def get(self, tool_name: str, key: str) -> dict | None:
|
||||||
|
bucket = self.results.get(tool_name)
|
||||||
|
if bucket is None:
|
||||||
|
return None
|
||||||
|
return bucket.get(key)
|
||||||
|
|
||||||
|
def put(self, tool_name: str, key: str, result: dict) -> None:
|
||||||
|
self.results.setdefault(tool_name, {})[key] = result
|
||||||
|
logger.debug(f"ToolResultsCache: stored {tool_name}[{key}]")
|
||||||
|
|
||||||
|
def keys_for(self, tool_name: str) -> list[str]:
|
||||||
|
return list(self.results.get(tool_name, {}).keys())
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self.results = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def describe(cls) -> dict:
|
||||||
|
return {
|
||||||
|
"name": "ToolResultsCache",
|
||||||
|
"tier": "stm",
|
||||||
|
"access": "read",
|
||||||
|
"description": (
|
||||||
|
"Per-session cache of read-only tool results, keyed by tool "
|
||||||
|
"name and a natural identifier from the tool's arguments "
|
||||||
|
"(usually a path or a title). Hits short-circuit the next "
|
||||||
|
"call to the same tool with the same key. Populated "
|
||||||
|
"automatically by the agent — do not write to it directly. "
|
||||||
|
"Consult it before suggesting a probe / analyze you may "
|
||||||
|
"have already done in this session."
|
||||||
|
),
|
||||||
|
"fields": {
|
||||||
|
"results": (
|
||||||
|
"Nested dict: {tool_name: {cache_key: result}}. "
|
||||||
|
"Read-only inventory; the agent should reference cached "
|
||||||
|
"values rather than re-calling the tool."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
# Surface only the index (tool + keys), not the payloads — payloads
|
||||||
|
# can be large and the prompt only needs to know what's available.
|
||||||
|
return {
|
||||||
|
tool: list(bucket.keys()) for tool, bucket in self.results.items()
|
||||||
|
}
|
||||||
@@ -3,7 +3,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from .components import Conversation, Entities, Workflow
|
from .components import (
|
||||||
|
Conversation,
|
||||||
|
Entities,
|
||||||
|
ReleaseFocus,
|
||||||
|
ToolResultsCache,
|
||||||
|
Workflow,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,6 +28,8 @@ class ShortTermMemory:
|
|||||||
conversation: Conversation = field(default_factory=Conversation)
|
conversation: Conversation = field(default_factory=Conversation)
|
||||||
workflow: Workflow = field(default_factory=Workflow)
|
workflow: Workflow = field(default_factory=Workflow)
|
||||||
entities: Entities = field(default_factory=Entities)
|
entities: Entities = field(default_factory=Entities)
|
||||||
|
tool_results: ToolResultsCache = field(default_factory=ToolResultsCache)
|
||||||
|
release_focus: ReleaseFocus = field(default_factory=ReleaseFocus)
|
||||||
|
|
||||||
# Convenience proxies kept for backward compatibility with existing callers
|
# Convenience proxies kept for backward compatibility with existing callers
|
||||||
@property
|
@property
|
||||||
@@ -79,6 +87,8 @@ class ShortTermMemory:
|
|||||||
self.conversation.clear()
|
self.conversation.clear()
|
||||||
self.workflow.clear()
|
self.workflow.clear()
|
||||||
self.entities.clear()
|
self.entities.clear()
|
||||||
|
self.tool_results.clear()
|
||||||
|
self.release_focus.clear()
|
||||||
logger.info("STM: Cleared")
|
logger.info("STM: Cleared")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
@@ -88,4 +98,6 @@ class ShortTermMemory:
|
|||||||
"extracted_entities": self.entities.data,
|
"extracted_entities": self.entities.data,
|
||||||
"current_topic": self.entities.topic,
|
"current_topic": self.entities.topic,
|
||||||
"language": self.conversation.language,
|
"language": self.conversation.language,
|
||||||
|
"tool_results": self.tool_results.to_dict(),
|
||||||
|
"release_focus": self.release_focus.to_dict(),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user