feat!: migrate to OpenAI native tool calls and fix circular deps (#fuck-gemini)
- Fix circular dependencies in agent/tools - Migrate from custom JSON to OpenAI tool calls format - Add async streaming (step_stream, complete_stream) - Simplify prompt system and remove token counting - Add 5 new API endpoints (/health, /v1/models, /api/memory/*) - Add 3 new tools (get_torrent_by_index, add_torrent_by_index, set_language) - Fix all 500 tests and add coverage config (80% threshold) - Add comprehensive docs (README, pytest guide) BREAKING: LLM interface changed, memory injection via get_memory()
This commit is contained in:
@@ -1 +1,25 @@
|
||||
"""Persistence layer - Data storage implementations."""
|
||||
|
||||
from .context import (
|
||||
get_memory,
|
||||
has_memory,
|
||||
init_memory,
|
||||
set_memory,
|
||||
)
|
||||
from .memory import (
|
||||
EpisodicMemory,
|
||||
LongTermMemory,
|
||||
Memory,
|
||||
ShortTermMemory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Memory",
|
||||
"LongTermMemory",
|
||||
"ShortTermMemory",
|
||||
"EpisodicMemory",
|
||||
"init_memory",
|
||||
"set_memory",
|
||||
"get_memory",
|
||||
"has_memory",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Memory context using contextvars.
|
||||
|
||||
Provides thread-safe and async-safe access to the Memory instance
|
||||
without passing it explicitly through all function calls.
|
||||
|
||||
Usage:
|
||||
# At application startup
|
||||
from infrastructure.persistence import init_memory, get_memory
|
||||
|
||||
init_memory("memory_data")
|
||||
|
||||
# Anywhere in the code
|
||||
memory = get_memory()
|
||||
memory.ltm.set_config("key", "value")
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from .memory import Memory
|
||||
|
||||
_memory_ctx: ContextVar[Memory | None] = ContextVar("memory", default=None)
|
||||
|
||||
|
||||
def init_memory(storage_dir: str = "memory_data") -> Memory:
|
||||
"""
|
||||
Initialize the memory and set it in the context.
|
||||
|
||||
Call this once at application startup.
|
||||
|
||||
Args:
|
||||
storage_dir: Directory for persistent storage.
|
||||
|
||||
Returns:
|
||||
The initialized Memory instance.
|
||||
"""
|
||||
memory = Memory(storage_dir=storage_dir)
|
||||
_memory_ctx.set(memory)
|
||||
return memory
|
||||
|
||||
|
||||
def set_memory(memory: Memory) -> None:
|
||||
"""
|
||||
Set an existing Memory instance in the context.
|
||||
|
||||
Useful for testing or when injecting a specific instance.
|
||||
|
||||
Args:
|
||||
memory: Memory instance to set.
|
||||
"""
|
||||
_memory_ctx.set(memory)
|
||||
|
||||
|
||||
def get_memory() -> Memory:
|
||||
"""
|
||||
Get the Memory instance from the context.
|
||||
|
||||
Returns:
|
||||
The Memory instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If memory has not been initialized.
|
||||
"""
|
||||
memory = _memory_ctx.get()
|
||||
if memory is None:
|
||||
raise RuntimeError(
|
||||
"Memory not initialized. Call init_memory() at application startup."
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
def has_memory() -> bool:
|
||||
"""
|
||||
Check if memory has been initialized.
|
||||
|
||||
Returns:
|
||||
True if memory is available, False otherwise.
|
||||
"""
|
||||
return _memory_ctx.get() is not None
|
||||
@@ -1,7 +1,8 @@
|
||||
"""JSON-based repository implementations."""
|
||||
|
||||
from .movie_repository import JsonMovieRepository
|
||||
from .tvshow_repository import JsonTVShowRepository
|
||||
from .subtitle_repository import JsonSubtitleRepository
|
||||
from .tvshow_repository import JsonTVShowRepository
|
||||
|
||||
__all__ = [
|
||||
"JsonMovieRepository",
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
"""JSON-based movie repository implementation."""
|
||||
from typing import List, Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
from domain.movies.repositories import MovieRepository
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from domain.movies.entities import Movie
|
||||
from domain.shared.value_objects import ImdbId
|
||||
from ..memory import Memory
|
||||
from domain.movies.repositories import MovieRepository
|
||||
from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
|
||||
from domain.shared.value_objects import FilePath, FileSize, ImdbId
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,103 +16,129 @@ logger = logging.getLogger(__name__)
|
||||
class JsonMovieRepository(MovieRepository):
|
||||
"""
|
||||
JSON-based implementation of MovieRepository.
|
||||
|
||||
Stores movies in the memory.json file.
|
||||
|
||||
Stores movies in the LTM library using the memory context.
|
||||
"""
|
||||
|
||||
def __init__(self, memory: Memory):
|
||||
"""
|
||||
Initialize repository.
|
||||
|
||||
Args:
|
||||
memory: Memory instance for persistence
|
||||
"""
|
||||
self.memory = memory
|
||||
|
||||
|
||||
def save(self, movie: Movie) -> None:
|
||||
"""Save a movie to the repository."""
|
||||
movies = self._load_all()
|
||||
|
||||
"""
|
||||
Save a movie to the repository.
|
||||
|
||||
Updates existing movie if IMDb ID matches.
|
||||
|
||||
Args:
|
||||
movie: Movie entity to save.
|
||||
"""
|
||||
memory = get_memory()
|
||||
movies = memory.ltm.library.get("movies", [])
|
||||
|
||||
# Remove existing movie with same IMDb ID
|
||||
movies = [m for m in movies if m.get('imdb_id') != str(movie.imdb_id)]
|
||||
|
||||
# Add new movie
|
||||
movies = [m for m in movies if m.get("imdb_id") != str(movie.imdb_id)]
|
||||
|
||||
movies.append(self._to_dict(movie))
|
||||
|
||||
# Save to memory
|
||||
self.memory.set('movies', movies)
|
||||
|
||||
memory.ltm.library["movies"] = movies
|
||||
memory.save()
|
||||
logger.debug(f"Saved movie: {movie.imdb_id}")
|
||||
|
||||
def find_by_imdb_id(self, imdb_id: ImdbId) -> Optional[Movie]:
|
||||
"""Find a movie by its IMDb ID."""
|
||||
movies = self._load_all()
|
||||
|
||||
|
||||
def find_by_imdb_id(self, imdb_id: ImdbId) -> Movie | None:
|
||||
"""
|
||||
Find a movie by its IMDb ID.
|
||||
|
||||
Args:
|
||||
imdb_id: IMDb ID to search for.
|
||||
|
||||
Returns:
|
||||
Movie if found, None otherwise.
|
||||
"""
|
||||
memory = get_memory()
|
||||
movies = memory.ltm.library.get("movies", [])
|
||||
|
||||
for movie_dict in movies:
|
||||
if movie_dict.get('imdb_id') == str(imdb_id):
|
||||
if movie_dict.get("imdb_id") == str(imdb_id):
|
||||
return self._from_dict(movie_dict)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def find_all(self) -> List[Movie]:
|
||||
"""Get all movies in the repository."""
|
||||
movies_dict = self._load_all()
|
||||
|
||||
def find_all(self) -> list[Movie]:
|
||||
"""
|
||||
Get all movies in the repository.
|
||||
|
||||
Returns:
|
||||
List of all Movie entities.
|
||||
"""
|
||||
memory = get_memory()
|
||||
movies_dict = memory.ltm.library.get("movies", [])
|
||||
return [self._from_dict(m) for m in movies_dict]
|
||||
|
||||
|
||||
def delete(self, imdb_id: ImdbId) -> bool:
|
||||
"""Delete a movie from the repository."""
|
||||
movies = self._load_all()
|
||||
"""
|
||||
Delete a movie from the repository.
|
||||
|
||||
Args:
|
||||
imdb_id: IMDb ID of movie to delete.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
memory = get_memory()
|
||||
movies = memory.ltm.library.get("movies", [])
|
||||
initial_count = len(movies)
|
||||
|
||||
# Filter out the movie
|
||||
movies = [m for m in movies if m.get('imdb_id') != str(imdb_id)]
|
||||
|
||||
|
||||
movies = [m for m in movies if m.get("imdb_id") != str(imdb_id)]
|
||||
|
||||
if len(movies) < initial_count:
|
||||
self.memory.set('movies', movies)
|
||||
memory.ltm.library["movies"] = movies
|
||||
memory.save()
|
||||
logger.debug(f"Deleted movie: {imdb_id}")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def exists(self, imdb_id: ImdbId) -> bool:
|
||||
"""Check if a movie exists in the repository."""
|
||||
"""
|
||||
Check if a movie exists in the repository.
|
||||
|
||||
Args:
|
||||
imdb_id: IMDb ID to check.
|
||||
|
||||
Returns:
|
||||
True if exists, False otherwise.
|
||||
"""
|
||||
return self.find_by_imdb_id(imdb_id) is not None
|
||||
|
||||
def _load_all(self) -> List[Dict[str, Any]]:
|
||||
"""Load all movies from memory."""
|
||||
return self.memory.get('movies', [])
|
||||
|
||||
def _to_dict(self, movie: Movie) -> Dict[str, Any]:
|
||||
|
||||
def _to_dict(self, movie: Movie) -> dict[str, Any]:
|
||||
"""Convert Movie entity to dict for storage."""
|
||||
return {
|
||||
'imdb_id': str(movie.imdb_id),
|
||||
'title': movie.title.value,
|
||||
'release_year': movie.release_year.value if movie.release_year else None,
|
||||
'quality': movie.quality.value,
|
||||
'file_path': str(movie.file_path) if movie.file_path else None,
|
||||
'file_size': movie.file_size.bytes if movie.file_size else None,
|
||||
'tmdb_id': movie.tmdb_id,
|
||||
'overview': movie.overview,
|
||||
'poster_path': movie.poster_path,
|
||||
'vote_average': movie.vote_average,
|
||||
'added_at': movie.added_at.isoformat(),
|
||||
"imdb_id": str(movie.imdb_id),
|
||||
"title": movie.title.value,
|
||||
"release_year": movie.release_year.value if movie.release_year else None,
|
||||
"quality": movie.quality.value,
|
||||
"file_path": str(movie.file_path) if movie.file_path else None,
|
||||
"file_size": movie.file_size.bytes if movie.file_size else None,
|
||||
"tmdb_id": movie.tmdb_id,
|
||||
"added_at": movie.added_at.isoformat(),
|
||||
}
|
||||
|
||||
def _from_dict(self, data: Dict[str, Any]) -> Movie:
|
||||
|
||||
def _from_dict(self, data: dict[str, Any]) -> Movie:
|
||||
"""Convert dict from storage to Movie entity."""
|
||||
from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality
|
||||
from domain.shared.value_objects import FilePath, FileSize
|
||||
from datetime import datetime
|
||||
|
||||
# Parse quality string to enum
|
||||
quality_str = data.get("quality", "unknown")
|
||||
quality = Quality.from_string(quality_str)
|
||||
|
||||
return Movie(
|
||||
imdb_id=ImdbId(data['imdb_id']),
|
||||
title=MovieTitle(data['title']),
|
||||
release_year=ReleaseYear(data['release_year']) if data.get('release_year') else None,
|
||||
quality=Quality(data.get('quality', 'unknown')),
|
||||
file_path=FilePath(data['file_path']) if data.get('file_path') else None,
|
||||
file_size=FileSize(data['file_size']) if data.get('file_size') else None,
|
||||
tmdb_id=data.get('tmdb_id'),
|
||||
overview=data.get('overview'),
|
||||
poster_path=data.get('poster_path'),
|
||||
vote_average=data.get('vote_average'),
|
||||
added_at=datetime.fromisoformat(data['added_at']) if data.get('added_at') else datetime.now(),
|
||||
imdb_id=ImdbId(data["imdb_id"]),
|
||||
title=MovieTitle(data["title"]),
|
||||
release_year=(
|
||||
ReleaseYear(data["release_year"]) if data.get("release_year") else None
|
||||
),
|
||||
quality=quality,
|
||||
file_path=FilePath(data["file_path"]) if data.get("file_path") else None,
|
||||
file_size=FileSize(data["file_size"]) if data.get("file_size") else None,
|
||||
tmdb_id=data.get("tmdb_id"),
|
||||
added_at=(
|
||||
datetime.fromisoformat(data["added_at"])
|
||||
if data.get("added_at")
|
||||
else datetime.now()
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""JSON-based subtitle repository implementation."""
|
||||
from typing import List, Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
from domain.subtitles.repositories import SubtitleRepository
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from domain.shared.value_objects import FilePath, ImdbId
|
||||
from domain.subtitles.entities import Subtitle
|
||||
from domain.subtitles.repositories import SubtitleRepository
|
||||
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
|
||||
from domain.shared.value_objects import ImdbId, FilePath
|
||||
from ..memory import Memory
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,114 +15,130 @@ logger = logging.getLogger(__name__)
|
||||
class JsonSubtitleRepository(SubtitleRepository):
|
||||
"""
|
||||
JSON-based implementation of SubtitleRepository.
|
||||
|
||||
Stores subtitles in the memory.json file.
|
||||
|
||||
Stores subtitles in the LTM library using the memory context.
|
||||
"""
|
||||
|
||||
def __init__(self, memory: Memory):
|
||||
"""
|
||||
Initialize repository.
|
||||
|
||||
Args:
|
||||
memory: Memory instance for persistence
|
||||
"""
|
||||
self.memory = memory
|
||||
|
||||
|
||||
def save(self, subtitle: Subtitle) -> None:
|
||||
"""Save a subtitle to the repository."""
|
||||
subtitles = self._load_all()
|
||||
|
||||
# Add new subtitle (we allow multiple subtitles for same media)
|
||||
"""
|
||||
Save a subtitle to the repository.
|
||||
|
||||
Multiple subtitles can exist for the same media.
|
||||
|
||||
Args:
|
||||
subtitle: Subtitle entity to save.
|
||||
"""
|
||||
memory = get_memory()
|
||||
subtitles = memory.ltm.library.get("subtitles", [])
|
||||
|
||||
subtitles.append(self._to_dict(subtitle))
|
||||
|
||||
# Save to memory
|
||||
self.memory.set('subtitles', subtitles)
|
||||
|
||||
if "subtitles" not in memory.ltm.library:
|
||||
memory.ltm.library["subtitles"] = []
|
||||
memory.ltm.library["subtitles"] = subtitles
|
||||
memory.save()
|
||||
logger.debug(f"Saved subtitle for: {subtitle.media_imdb_id}")
|
||||
|
||||
|
||||
def find_by_media(
|
||||
self,
|
||||
media_imdb_id: ImdbId,
|
||||
language: Optional[Language] = None,
|
||||
season: Optional[int] = None,
|
||||
episode: Optional[int] = None
|
||||
) -> List[Subtitle]:
|
||||
"""Find subtitles for a media item."""
|
||||
subtitles = self._load_all()
|
||||
language: Language | None = None,
|
||||
season: int | None = None,
|
||||
episode: int | None = None,
|
||||
) -> list[Subtitle]:
|
||||
"""
|
||||
Find subtitles for a media item.
|
||||
|
||||
Args:
|
||||
media_imdb_id: IMDb ID of the media.
|
||||
language: Optional language filter.
|
||||
season: Optional season number filter.
|
||||
episode: Optional episode number filter.
|
||||
|
||||
Returns:
|
||||
List of matching Subtitle entities.
|
||||
"""
|
||||
memory = get_memory()
|
||||
subtitles = memory.ltm.library.get("subtitles", [])
|
||||
results = []
|
||||
|
||||
|
||||
for sub_dict in subtitles:
|
||||
# Filter by IMDb ID
|
||||
if sub_dict.get('media_imdb_id') != str(media_imdb_id):
|
||||
if sub_dict.get("media_imdb_id") != str(media_imdb_id):
|
||||
continue
|
||||
|
||||
# Filter by language if specified
|
||||
if language and sub_dict.get('language') != language.value:
|
||||
|
||||
if language and sub_dict.get("language") != language.value:
|
||||
continue
|
||||
|
||||
# Filter by season/episode if specified
|
||||
if season is not None and sub_dict.get('season_number') != season:
|
||||
|
||||
if season is not None and sub_dict.get("season_number") != season:
|
||||
continue
|
||||
if episode is not None and sub_dict.get('episode_number') != episode:
|
||||
|
||||
if episode is not None and sub_dict.get("episode_number") != episode:
|
||||
continue
|
||||
|
||||
|
||||
results.append(self._from_dict(sub_dict))
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def delete(self, subtitle: Subtitle) -> bool:
|
||||
"""Delete a subtitle from the repository."""
|
||||
subtitles = self._load_all()
|
||||
"""
|
||||
Delete a subtitle from the repository.
|
||||
|
||||
Matches by file path.
|
||||
|
||||
Args:
|
||||
subtitle: Subtitle entity to delete.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
memory = get_memory()
|
||||
subtitles = memory.ltm.library.get("subtitles", [])
|
||||
initial_count = len(subtitles)
|
||||
|
||||
# Filter out the subtitle (match by file path)
|
||||
|
||||
subtitles = [
|
||||
s for s in subtitles
|
||||
if s.get('file_path') != str(subtitle.file_path)
|
||||
s for s in subtitles if s.get("file_path") != str(subtitle.file_path)
|
||||
]
|
||||
|
||||
|
||||
if len(subtitles) < initial_count:
|
||||
self.memory.set('subtitles', subtitles)
|
||||
memory.ltm.library["subtitles"] = subtitles
|
||||
memory.save()
|
||||
logger.debug(f"Deleted subtitle: {subtitle.file_path}")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def _load_all(self) -> List[Dict[str, Any]]:
|
||||
"""Load all subtitles from memory."""
|
||||
return self.memory.get('subtitles', [])
|
||||
|
||||
def _to_dict(self, subtitle: Subtitle) -> Dict[str, Any]:
|
||||
|
||||
def _to_dict(self, subtitle: Subtitle) -> dict[str, Any]:
|
||||
"""Convert Subtitle entity to dict for storage."""
|
||||
return {
|
||||
'media_imdb_id': str(subtitle.media_imdb_id),
|
||||
'language': subtitle.language.value,
|
||||
'format': subtitle.format.value,
|
||||
'file_path': str(subtitle.file_path),
|
||||
'season_number': subtitle.season_number,
|
||||
'episode_number': subtitle.episode_number,
|
||||
'timing_offset': subtitle.timing_offset.milliseconds,
|
||||
'hearing_impaired': subtitle.hearing_impaired,
|
||||
'forced': subtitle.forced,
|
||||
'source': subtitle.source,
|
||||
'uploader': subtitle.uploader,
|
||||
'download_count': subtitle.download_count,
|
||||
'rating': subtitle.rating,
|
||||
"media_imdb_id": str(subtitle.media_imdb_id),
|
||||
"language": subtitle.language.value,
|
||||
"format": subtitle.format.value,
|
||||
"file_path": str(subtitle.file_path),
|
||||
"season_number": subtitle.season_number,
|
||||
"episode_number": subtitle.episode_number,
|
||||
"timing_offset": subtitle.timing_offset.milliseconds,
|
||||
"hearing_impaired": subtitle.hearing_impaired,
|
||||
"forced": subtitle.forced,
|
||||
"source": subtitle.source,
|
||||
"uploader": subtitle.uploader,
|
||||
"download_count": subtitle.download_count,
|
||||
"rating": subtitle.rating,
|
||||
}
|
||||
|
||||
def _from_dict(self, data: Dict[str, Any]) -> Subtitle:
|
||||
|
||||
def _from_dict(self, data: dict[str, Any]) -> Subtitle:
|
||||
"""Convert dict from storage to Subtitle entity."""
|
||||
return Subtitle(
|
||||
media_imdb_id=ImdbId(data['media_imdb_id']),
|
||||
language=Language.from_code(data['language']),
|
||||
format=SubtitleFormat.from_extension(data['format']),
|
||||
file_path=FilePath(data['file_path']),
|
||||
season_number=data.get('season_number'),
|
||||
episode_number=data.get('episode_number'),
|
||||
timing_offset=TimingOffset(data.get('timing_offset', 0)),
|
||||
hearing_impaired=data.get('hearing_impaired', False),
|
||||
forced=data.get('forced', False),
|
||||
source=data.get('source'),
|
||||
uploader=data.get('uploader'),
|
||||
download_count=data.get('download_count'),
|
||||
rating=data.get('rating'),
|
||||
media_imdb_id=ImdbId(data["media_imdb_id"]),
|
||||
language=Language.from_code(data["language"]),
|
||||
format=SubtitleFormat.from_extension(data["format"]),
|
||||
file_path=FilePath(data["file_path"]),
|
||||
season_number=data.get("season_number"),
|
||||
episode_number=data.get("episode_number"),
|
||||
timing_offset=TimingOffset(data.get("timing_offset", 0)),
|
||||
hearing_impaired=data.get("hearing_impaired", False),
|
||||
forced=data.get("forced", False),
|
||||
source=data.get("source"),
|
||||
uploader=data.get("uploader"),
|
||||
download_count=data.get("download_count"),
|
||||
rating=data.get("rating"),
|
||||
)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""JSON-based TV show repository implementation."""
|
||||
from typing import List, Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
from domain.tv_shows.repositories import TVShowRepository
|
||||
from domain.tv_shows.entities import TVShow
|
||||
from domain.tv_shows.value_objects import ShowStatus
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from domain.shared.value_objects import ImdbId
|
||||
from ..memory import Memory
|
||||
from domain.tv_shows.entities import TVShow
|
||||
from domain.tv_shows.repositories import TVShowRepository
|
||||
from domain.tv_shows.value_objects import ShowStatus
|
||||
from infrastructure.persistence import get_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,99 +16,121 @@ logger = logging.getLogger(__name__)
|
||||
class JsonTVShowRepository(TVShowRepository):
|
||||
"""
|
||||
JSON-based implementation of TVShowRepository.
|
||||
|
||||
Stores TV shows in the memory.json file (compatible with existing tv_shows structure).
|
||||
|
||||
Stores TV shows in the LTM library using the memory context.
|
||||
"""
|
||||
|
||||
def __init__(self, memory: Memory):
|
||||
"""
|
||||
Initialize repository.
|
||||
|
||||
Args:
|
||||
memory: Memory instance for persistence
|
||||
"""
|
||||
self.memory = memory
|
||||
|
||||
|
||||
def save(self, show: TVShow) -> None:
|
||||
"""Save a TV show to the repository."""
|
||||
shows = self._load_all()
|
||||
|
||||
"""
|
||||
Save a TV show to the repository.
|
||||
|
||||
Updates existing show if IMDb ID matches.
|
||||
|
||||
Args:
|
||||
show: TVShow entity to save.
|
||||
"""
|
||||
memory = get_memory()
|
||||
shows = memory.ltm.library.get("tv_shows", [])
|
||||
|
||||
# Remove existing show with same IMDb ID
|
||||
shows = [s for s in shows if s.get('imdb_id') != str(show.imdb_id)]
|
||||
|
||||
# Add new show
|
||||
shows = [s for s in shows if s.get("imdb_id") != str(show.imdb_id)]
|
||||
|
||||
shows.append(self._to_dict(show))
|
||||
|
||||
# Save to memory
|
||||
self.memory.set('tv_shows', shows)
|
||||
|
||||
memory.ltm.library["tv_shows"] = shows
|
||||
memory.save()
|
||||
logger.debug(f"Saved TV show: {show.imdb_id}")
|
||||
|
||||
def find_by_imdb_id(self, imdb_id: ImdbId) -> Optional[TVShow]:
|
||||
"""Find a TV show by its IMDb ID."""
|
||||
shows = self._load_all()
|
||||
|
||||
|
||||
def find_by_imdb_id(self, imdb_id: ImdbId) -> TVShow | None:
|
||||
"""
|
||||
Find a TV show by its IMDb ID.
|
||||
|
||||
Args:
|
||||
imdb_id: IMDb ID to search for.
|
||||
|
||||
Returns:
|
||||
TVShow if found, None otherwise.
|
||||
"""
|
||||
memory = get_memory()
|
||||
shows = memory.ltm.library.get("tv_shows", [])
|
||||
|
||||
for show_dict in shows:
|
||||
if show_dict.get('imdb_id') == str(imdb_id):
|
||||
if show_dict.get("imdb_id") == str(imdb_id):
|
||||
return self._from_dict(show_dict)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def find_all(self) -> List[TVShow]:
|
||||
"""Get all TV shows in the repository."""
|
||||
shows_dict = self._load_all()
|
||||
|
||||
def find_all(self) -> list[TVShow]:
|
||||
"""
|
||||
Get all TV shows in the repository.
|
||||
|
||||
Returns:
|
||||
List of all TVShow entities.
|
||||
"""
|
||||
memory = get_memory()
|
||||
shows_dict = memory.ltm.library.get("tv_shows", [])
|
||||
return [self._from_dict(s) for s in shows_dict]
|
||||
|
||||
|
||||
def delete(self, imdb_id: ImdbId) -> bool:
|
||||
"""Delete a TV show from the repository."""
|
||||
shows = self._load_all()
|
||||
"""
|
||||
Delete a TV show from the repository.
|
||||
|
||||
Args:
|
||||
imdb_id: IMDb ID of show to delete.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
memory = get_memory()
|
||||
shows = memory.ltm.library.get("tv_shows", [])
|
||||
initial_count = len(shows)
|
||||
|
||||
# Filter out the show
|
||||
shows = [s for s in shows if s.get('imdb_id') != str(imdb_id)]
|
||||
|
||||
|
||||
shows = [s for s in shows if s.get("imdb_id") != str(imdb_id)]
|
||||
|
||||
if len(shows) < initial_count:
|
||||
self.memory.set('tv_shows', shows)
|
||||
memory.ltm.library["tv_shows"] = shows
|
||||
memory.save()
|
||||
logger.debug(f"Deleted TV show: {imdb_id}")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def exists(self, imdb_id: ImdbId) -> bool:
|
||||
"""Check if a TV show exists in the repository."""
|
||||
"""
|
||||
Check if a TV show exists in the repository.
|
||||
|
||||
Args:
|
||||
imdb_id: IMDb ID to check.
|
||||
|
||||
Returns:
|
||||
True if exists, False otherwise.
|
||||
"""
|
||||
return self.find_by_imdb_id(imdb_id) is not None
|
||||
|
||||
def _load_all(self) -> List[Dict[str, Any]]:
|
||||
"""Load all TV shows from memory."""
|
||||
return self.memory.get('tv_shows', [])
|
||||
|
||||
def _to_dict(self, show: TVShow) -> Dict[str, Any]:
|
||||
|
||||
def _to_dict(self, show: TVShow) -> dict[str, Any]:
|
||||
"""Convert TVShow entity to dict for storage."""
|
||||
return {
|
||||
'imdb_id': str(show.imdb_id),
|
||||
'title': show.title,
|
||||
'seasons_count': show.seasons_count,
|
||||
'status': show.status.value,
|
||||
'tmdb_id': show.tmdb_id,
|
||||
'overview': show.overview,
|
||||
'poster_path': show.poster_path,
|
||||
'first_air_date': show.first_air_date,
|
||||
'vote_average': show.vote_average,
|
||||
'added_at': show.added_at.isoformat(),
|
||||
"imdb_id": str(show.imdb_id),
|
||||
"title": show.title,
|
||||
"seasons_count": show.seasons_count,
|
||||
"status": show.status.value,
|
||||
"tmdb_id": show.tmdb_id,
|
||||
"first_air_date": show.first_air_date,
|
||||
"added_at": show.added_at.isoformat(),
|
||||
}
|
||||
|
||||
def _from_dict(self, data: Dict[str, Any]) -> TVShow:
|
||||
|
||||
def _from_dict(self, data: dict[str, Any]) -> TVShow:
|
||||
"""Convert dict from storage to TVShow entity."""
|
||||
from datetime import datetime
|
||||
|
||||
return TVShow(
|
||||
imdb_id=ImdbId(data['imdb_id']),
|
||||
title=data['title'],
|
||||
seasons_count=data['seasons_count'],
|
||||
status=ShowStatus.from_string(data['status']),
|
||||
tmdb_id=data.get('tmdb_id'),
|
||||
overview=data.get('overview'),
|
||||
poster_path=data.get('poster_path'),
|
||||
first_air_date=data.get('first_air_date'),
|
||||
vote_average=data.get('vote_average'),
|
||||
added_at=datetime.fromisoformat(data['added_at']) if data.get('added_at') else datetime.now(),
|
||||
imdb_id=ImdbId(data["imdb_id"]),
|
||||
title=data["title"],
|
||||
seasons_count=data["seasons_count"],
|
||||
status=ShowStatus.from_string(data["status"]),
|
||||
tmdb_id=data.get("tmdb_id"),
|
||||
first_air_date=data.get("first_air_date"),
|
||||
added_at=(
|
||||
datetime.fromisoformat(data["added_at"])
|
||||
if data.get("added_at")
|
||||
else datetime.now()
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,86 +1,571 @@
|
||||
"""Memory storage - Migrated from agent/memory.py"""
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
import json
|
||||
"""
|
||||
Memory - Unified management of 3 memory types.
|
||||
|
||||
from agent.config import settings
|
||||
from agent.parameters import validate_parameter, get_parameter_schema
|
||||
Architecture:
|
||||
- LTM (Long-Term Memory): Configuration, library, preferences - Persistent
|
||||
- STM (Short-Term Memory): Conversation, current workflow - Volatile
|
||||
- Episodic Memory: Search results, transient states - Very volatile
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LONG-TERM MEMORY (LTM) - Persistent
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class LongTermMemory:
|
||||
"""
|
||||
Long-term memory - Persistent and static.
|
||||
|
||||
Stores:
|
||||
- User configuration (folders, URLs)
|
||||
- Preferences (quality, languages)
|
||||
- Library (owned movies/TV shows)
|
||||
- Followed shows (watchlist)
|
||||
"""
|
||||
|
||||
# Folder and service configuration
|
||||
config: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# User preferences
|
||||
preferences: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"preferred_quality": "1080p",
|
||||
"preferred_languages": ["en", "fr"],
|
||||
"auto_organize": False,
|
||||
"naming_format": "{title}.{year}.{quality}",
|
||||
}
|
||||
)
|
||||
|
||||
# Library of owned media
|
||||
library: dict[str, list[dict]] = field(
|
||||
default_factory=lambda: {"movies": [], "tv_shows": []}
|
||||
)
|
||||
|
||||
# Followed shows (watchlist)
|
||||
following: list[dict] = field(default_factory=list)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a configuration value."""
|
||||
return self.config.get(key, default)
|
||||
|
||||
def set_config(self, key: str, value: Any) -> None:
|
||||
"""Set a configuration value."""
|
||||
self.config[key] = value
|
||||
logger.debug(f"LTM: Set config {key}")
|
||||
|
||||
def has_config(self, key: str) -> bool:
|
||||
"""Check if a configuration exists."""
|
||||
return key in self.config and self.config[key] is not None
|
||||
|
||||
def add_to_library(self, media_type: str, media: dict) -> None:
|
||||
"""Add a media item to the library."""
|
||||
if media_type not in self.library:
|
||||
self.library[media_type] = []
|
||||
|
||||
# Avoid duplicates by imdb_id
|
||||
existing_ids = [m.get("imdb_id") for m in self.library[media_type]]
|
||||
if media.get("imdb_id") not in existing_ids:
|
||||
media["added_at"] = datetime.now().isoformat()
|
||||
self.library[media_type].append(media)
|
||||
logger.info(f"LTM: Added {media.get('title')} to {media_type}")
|
||||
|
||||
def get_library(self, media_type: str) -> list[dict]:
|
||||
"""Get the library for a media type."""
|
||||
return self.library.get(media_type, [])
|
||||
|
||||
def follow_show(self, show: dict) -> None:
|
||||
"""Add a show to the watchlist."""
|
||||
existing_ids = [s.get("imdb_id") for s in self.following]
|
||||
if show.get("imdb_id") not in existing_ids:
|
||||
show["followed_at"] = datetime.now().isoformat()
|
||||
self.following.append(show)
|
||||
logger.info(f"LTM: Now following {show.get('title')}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"config": self.config,
|
||||
"preferences": self.preferences,
|
||||
"library": self.library,
|
||||
"following": self.following,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "LongTermMemory":
|
||||
"""Create an instance from a dictionary."""
|
||||
return cls(
|
||||
config=data.get("config", {}),
|
||||
preferences=data.get(
|
||||
"preferences",
|
||||
{
|
||||
"preferred_quality": "1080p",
|
||||
"preferred_languages": ["en", "fr"],
|
||||
"auto_organize": False,
|
||||
"naming_format": "{title}.{year}.{quality}",
|
||||
},
|
||||
),
|
||||
library=data.get("library", {"movies": [], "tv_shows": []}),
|
||||
following=data.get("following", []),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SHORT-TERM MEMORY (STM) - Conversation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShortTermMemory:
|
||||
"""
|
||||
Short-term memory - Volatile and conversational.
|
||||
|
||||
Stores:
|
||||
- Current conversation history
|
||||
- Current workflow (what we're doing)
|
||||
- Extracted entities from conversation
|
||||
- Current discussion topic
|
||||
"""
|
||||
|
||||
# Conversation message history
|
||||
conversation_history: list[dict[str, str]] = field(default_factory=list)
|
||||
|
||||
# Current workflow
|
||||
current_workflow: dict | None = None
|
||||
|
||||
# Extracted entities (title, year, requested quality, etc.)
|
||||
extracted_entities: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Current conversation topic
|
||||
current_topic: str | None = None
|
||||
|
||||
# History message limit
|
||||
max_history: int = 20
|
||||
|
||||
def add_message(self, role: str, content: str) -> None:
|
||||
"""Add a message to history."""
|
||||
self.conversation_history.append(
|
||||
{"role": role, "content": content, "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
# Keep only the last N messages
|
||||
if len(self.conversation_history) > self.max_history:
|
||||
self.conversation_history = self.conversation_history[-self.max_history :]
|
||||
logger.debug(f"STM: Added {role} message")
|
||||
|
||||
def get_recent_history(self, n: int = 10) -> list[dict]:
|
||||
"""Get the last N messages."""
|
||||
return self.conversation_history[-n:]
|
||||
|
||||
def start_workflow(self, workflow_type: str, target: dict) -> None:
|
||||
"""Start a new workflow."""
|
||||
self.current_workflow = {
|
||||
"type": workflow_type,
|
||||
"target": target,
|
||||
"stage": "started",
|
||||
"started_at": datetime.now().isoformat(),
|
||||
}
|
||||
logger.info(f"STM: Started workflow '{workflow_type}'")
|
||||
|
||||
def update_workflow_stage(self, stage: str) -> None:
|
||||
"""Update the workflow stage."""
|
||||
if self.current_workflow:
|
||||
self.current_workflow["stage"] = stage
|
||||
logger.debug(f"STM: Workflow stage -> {stage}")
|
||||
|
||||
def end_workflow(self) -> None:
|
||||
"""End the current workflow."""
|
||||
if self.current_workflow:
|
||||
logger.info(f"STM: Ended workflow '{self.current_workflow.get('type')}'")
|
||||
self.current_workflow = None
|
||||
|
||||
def set_entity(self, key: str, value: Any) -> None:
|
||||
"""Store an extracted entity."""
|
||||
self.extracted_entities[key] = value
|
||||
logger.debug(f"STM: Set entity {key}={value}")
|
||||
|
||||
def get_entity(self, key: str, default: Any = None) -> Any:
|
||||
"""Get an extracted entity."""
|
||||
return self.extracted_entities.get(key, default)
|
||||
|
||||
def clear_entities(self) -> None:
|
||||
"""Clear extracted entities."""
|
||||
self.extracted_entities = {}
|
||||
|
||||
def set_topic(self, topic: str) -> None:
|
||||
"""Set the current topic."""
|
||||
self.current_topic = topic
|
||||
logger.debug(f"STM: Topic -> {topic}")
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Reset short-term memory."""
|
||||
self.conversation_history = []
|
||||
self.current_workflow = None
|
||||
self.extracted_entities = {}
|
||||
self.current_topic = None
|
||||
logger.info("STM: Cleared")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"conversation_history": self.conversation_history,
|
||||
"current_workflow": self.current_workflow,
|
||||
"extracted_entities": self.extracted_entities,
|
||||
"current_topic": self.current_topic,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# EPISODIC MEMORY - Transient states
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodicMemory:
|
||||
"""
|
||||
Episodic/sensory memory - Temporary and event-driven.
|
||||
|
||||
Stores:
|
||||
- Last search results
|
||||
- Active downloads
|
||||
- Recent errors
|
||||
- Pending questions awaiting user response
|
||||
- Background events
|
||||
"""
|
||||
|
||||
# Last search results
|
||||
last_search_results: dict | None = None
|
||||
|
||||
# Active downloads
|
||||
active_downloads: list[dict] = field(default_factory=list)
|
||||
|
||||
# Recent errors
|
||||
recent_errors: list[dict] = field(default_factory=list)
|
||||
|
||||
# Pending question awaiting user response
|
||||
pending_question: dict | None = None
|
||||
|
||||
# Background events (download complete, new files, etc.)
|
||||
background_events: list[dict] = field(default_factory=list)
|
||||
|
||||
# Limits for errors/events kept
|
||||
max_errors: int = 5
|
||||
max_events: int = 10
|
||||
|
||||
def store_search_results(
|
||||
self, query: str, results: list[dict], search_type: str = "torrent"
|
||||
) -> None:
|
||||
"""
|
||||
Store search results with index.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
results: List of results
|
||||
search_type: Type of search (torrent, movie, tvshow)
|
||||
"""
|
||||
self.last_search_results = {
|
||||
"query": query,
|
||||
"type": search_type,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"results": [{"index": i + 1, **r} for i, r in enumerate(results)],
|
||||
}
|
||||
logger.info(f"Episodic: Stored {len(results)} search results for '{query}'")
|
||||
|
||||
def get_result_by_index(self, index: int) -> dict | None:
|
||||
"""
|
||||
Get a result by its number (1-indexed).
|
||||
|
||||
Args:
|
||||
index: Result number (1, 2, 3, ...)
|
||||
|
||||
Returns:
|
||||
The result or None if not found
|
||||
"""
|
||||
if not self.last_search_results:
|
||||
logger.warning("Episodic: No search results stored")
|
||||
return None
|
||||
|
||||
for result in self.last_search_results.get("results", []):
|
||||
if result.get("index") == index:
|
||||
return result
|
||||
|
||||
logger.warning(f"Episodic: Result #{index} not found")
|
||||
return None
|
||||
|
||||
def get_search_results(self) -> dict | None:
|
||||
"""Get the last search results."""
|
||||
return self.last_search_results
|
||||
|
||||
def clear_search_results(self) -> None:
|
||||
"""Clear search results."""
|
||||
self.last_search_results = None
|
||||
|
||||
def add_active_download(self, download: dict) -> None:
|
||||
"""Add an active download."""
|
||||
download["started_at"] = datetime.now().isoformat()
|
||||
self.active_downloads.append(download)
|
||||
logger.info(f"Episodic: Added download '{download.get('name')}'")
|
||||
|
||||
def update_download_progress(
|
||||
self, task_id: str, progress: int, status: str = "downloading"
|
||||
) -> None:
|
||||
"""Update download progress."""
|
||||
for dl in self.active_downloads:
|
||||
if dl.get("task_id") == task_id:
|
||||
dl["progress"] = progress
|
||||
dl["status"] = status
|
||||
dl["updated_at"] = datetime.now().isoformat()
|
||||
break
|
||||
|
||||
def complete_download(self, task_id: str, file_path: str) -> dict | None:
|
||||
"""Mark a download as complete and remove it."""
|
||||
for i, dl in enumerate(self.active_downloads):
|
||||
if dl.get("task_id") == task_id:
|
||||
completed = self.active_downloads.pop(i)
|
||||
completed["status"] = "completed"
|
||||
completed["file_path"] = file_path
|
||||
completed["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
# Add a background event
|
||||
self.add_background_event(
|
||||
"download_complete",
|
||||
{"name": completed.get("name"), "file_path": file_path},
|
||||
)
|
||||
|
||||
logger.info(f"Episodic: Download completed '{completed.get('name')}'")
|
||||
return completed
|
||||
return None
|
||||
|
||||
def get_active_downloads(self) -> list[dict]:
|
||||
"""Get active downloads."""
|
||||
return self.active_downloads
|
||||
|
||||
def add_error(
|
||||
self, action: str, error: str, context: dict | None = None
|
||||
) -> None:
|
||||
"""Record a recent error."""
|
||||
self.recent_errors.append(
|
||||
{
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"action": action,
|
||||
"error": error,
|
||||
"context": context or {},
|
||||
}
|
||||
)
|
||||
# Keep only the last N errors
|
||||
self.recent_errors = self.recent_errors[-self.max_errors :]
|
||||
logger.warning(f"Episodic: Error in '{action}': {error}")
|
||||
|
||||
def get_recent_errors(self) -> list[dict]:
|
||||
"""Get recent errors."""
|
||||
return self.recent_errors
|
||||
|
||||
def set_pending_question(
|
||||
self,
|
||||
question: str,
|
||||
options: list[dict],
|
||||
context: dict,
|
||||
question_type: str = "choice",
|
||||
) -> None:
|
||||
"""
|
||||
Record a question awaiting user response.
|
||||
|
||||
Args:
|
||||
question: The question asked
|
||||
options: List of possible options
|
||||
context: Question context
|
||||
question_type: Type of question (choice, confirmation, input)
|
||||
"""
|
||||
self.pending_question = {
|
||||
"type": question_type,
|
||||
"question": question,
|
||||
"options": options,
|
||||
"context": context,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
logger.info(f"Episodic: Pending question set ({question_type})")
|
||||
|
||||
def get_pending_question(self) -> dict | None:
|
||||
"""Get the pending question."""
|
||||
return self.pending_question
|
||||
|
||||
def resolve_pending_question(
|
||||
self, answer_index: int | None = None
|
||||
) -> dict | None:
|
||||
"""
|
||||
Resolve the pending question and return the chosen option.
|
||||
|
||||
Args:
|
||||
answer_index: Answer index (1-indexed) or None to cancel
|
||||
|
||||
Returns:
|
||||
The chosen option or None
|
||||
"""
|
||||
if not self.pending_question:
|
||||
return None
|
||||
|
||||
result = None
|
||||
if answer_index is not None and self.pending_question.get("options"):
|
||||
for opt in self.pending_question["options"]:
|
||||
if opt.get("index") == answer_index:
|
||||
result = opt
|
||||
break
|
||||
|
||||
self.pending_question = None
|
||||
logger.info("Episodic: Pending question resolved")
|
||||
return result
|
||||
|
||||
def add_background_event(self, event_type: str, data: dict) -> None:
|
||||
"""Add a background event."""
|
||||
self.background_events.append(
|
||||
{
|
||||
"type": event_type,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": data,
|
||||
"read": False,
|
||||
}
|
||||
)
|
||||
# Keep only the last N events
|
||||
self.background_events = self.background_events[-self.max_events :]
|
||||
logger.info(f"Episodic: Background event '{event_type}'")
|
||||
|
||||
def get_unread_events(self) -> list[dict]:
|
||||
"""Get unread events and mark them as read."""
|
||||
unread = [e for e in self.background_events if not e.get("read")]
|
||||
for e in self.background_events:
|
||||
e["read"] = True
|
||||
return unread
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Reset episodic memory."""
|
||||
self.last_search_results = None
|
||||
self.active_downloads = []
|
||||
self.recent_errors = []
|
||||
self.pending_question = None
|
||||
self.background_events = []
|
||||
logger.info("Episodic: Cleared")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"last_search_results": self.last_search_results,
|
||||
"active_downloads": self.active_downloads,
|
||||
"recent_errors": self.recent_errors,
|
||||
"pending_question": self.pending_question,
|
||||
"background_events": self.background_events,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MEMORY MANAGER - Unified manager
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Memory:
|
||||
"""
|
||||
Generic memory storage for agent state.
|
||||
Unified manager for the 3 memory types.
|
||||
|
||||
Provides a simple key-value store that persists to JSON.
|
||||
Usage:
|
||||
memory = Memory("memory_data")
|
||||
memory.ltm.set_config("download_folder", "/path")
|
||||
memory.stm.add_message("user", "Hello")
|
||||
memory.episodic.store_search_results("query", results)
|
||||
memory.save()
|
||||
"""
|
||||
|
||||
def __init__(self, path: str = "memory.json"):
|
||||
self.file = Path(path)
|
||||
self.data: Dict[str, Any] = {}
|
||||
self.load()
|
||||
def __init__(self, storage_dir: str = "memory_data"):
|
||||
"""
|
||||
Initialize the memory.
|
||||
|
||||
def load(self) -> None:
|
||||
"""Load memory from file or initialize with defaults."""
|
||||
if self.file.exists():
|
||||
Args:
|
||||
storage_dir: Directory for persistent storage
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.storage_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.ltm_file = self.storage_dir / "ltm.json"
|
||||
|
||||
# Initialize the 3 memory types
|
||||
self.ltm = self._load_ltm()
|
||||
self.stm = ShortTermMemory()
|
||||
self.episodic = EpisodicMemory()
|
||||
|
||||
logger.info(f"Memory initialized (storage: {storage_dir})")
|
||||
|
||||
def _load_ltm(self) -> LongTermMemory:
|
||||
"""Load LTM from file."""
|
||||
if self.ltm_file.exists():
|
||||
try:
|
||||
self.data = json.loads(self.file.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
print(f"Warning: Could not load memory file: {e}")
|
||||
self.data = {
|
||||
"config": {},
|
||||
"tv_shows": [],
|
||||
"history": [],
|
||||
}
|
||||
else:
|
||||
self.data = {
|
||||
"config": {},
|
||||
"tv_shows": [],
|
||||
"history": [],
|
||||
}
|
||||
data = json.loads(self.ltm_file.read_text(encoding="utf-8"))
|
||||
logger.info("LTM loaded from file")
|
||||
return LongTermMemory.from_dict(data)
|
||||
except (OSError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Could not load LTM: {e}")
|
||||
return LongTermMemory()
|
||||
|
||||
def save(self) -> None:
|
||||
self.file.write_text(
|
||||
json.dumps(self.data, indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
"""Save LTM (the only persistent memory)."""
|
||||
try:
|
||||
self.ltm_file.write_text(
|
||||
json.dumps(self.ltm.to_dict(), indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.debug("LTM saved to file")
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to save LTM: {e}")
|
||||
raise
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a value from memory by key."""
|
||||
return self.data.get(key, default)
|
||||
def get_context_for_prompt(self) -> dict:
|
||||
"""
|
||||
Generate context to include in the system prompt.
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
Returns:
|
||||
Dictionary with relevant context from all 3 memories
|
||||
"""
|
||||
Set a value in memory and save.
|
||||
|
||||
Validates the value against the parameter schema if one exists.
|
||||
"""
|
||||
# Validate if schema exists
|
||||
is_valid, error_msg = validate_parameter(key, value)
|
||||
if not is_valid:
|
||||
print(f'Validation failed for {key}: {error_msg}')
|
||||
raise ValueError(f"Invalid value for {key}: {error_msg}")
|
||||
|
||||
print(f'Setting {key} in memory to: {value}')
|
||||
self.data[key] = value
|
||||
self.save()
|
||||
return {
|
||||
"config": self.ltm.config,
|
||||
"preferences": self.ltm.preferences,
|
||||
"current_workflow": self.stm.current_workflow,
|
||||
"current_topic": self.stm.current_topic,
|
||||
"extracted_entities": self.stm.extracted_entities,
|
||||
"last_search": {
|
||||
"query": (
|
||||
self.episodic.last_search_results.get("query")
|
||||
if self.episodic.last_search_results
|
||||
else None
|
||||
),
|
||||
"result_count": (
|
||||
len(self.episodic.last_search_results.get("results", []))
|
||||
if self.episodic.last_search_results
|
||||
else 0
|
||||
),
|
||||
},
|
||||
"active_downloads_count": len(self.episodic.active_downloads),
|
||||
"pending_question": self.episodic.pending_question is not None,
|
||||
"unread_events": len(
|
||||
[e for e in self.episodic.background_events if not e.get("read")]
|
||||
),
|
||||
}
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
"""Check if a key exists and has a non-None value."""
|
||||
return key in self.data and self.data[key] is not None
|
||||
|
||||
def append_history(self, role: str, content: str) -> None:
|
||||
"""
|
||||
Append a message to conversation history.
|
||||
|
||||
Args:
|
||||
role: Message role ('user' or 'assistant')
|
||||
content: Message content
|
||||
"""
|
||||
if "history" not in self.data:
|
||||
self.data["history"] = []
|
||||
|
||||
self.data["history"].append({
|
||||
"role": role,
|
||||
"content": content
|
||||
})
|
||||
self.save()
|
||||
def get_full_state(self) -> dict:
|
||||
"""Return the full state of all 3 memories (for debug)."""
|
||||
return {
|
||||
"ltm": self.ltm.to_dict(),
|
||||
"stm": self.stm.to_dict(),
|
||||
"episodic": self.episodic.to_dict(),
|
||||
}
|
||||
|
||||
def clear_session(self) -> None:
|
||||
"""Clear session memories (STM + Episodic)."""
|
||||
self.stm.clear()
|
||||
self.episodic.clear()
|
||||
logger.info("Session memories cleared")
|
||||
|
||||
Reference in New Issue
Block a user