"""Tests for the Memory system."""

import json

import pytest

from infrastructure.persistence import (
    EpisodicMemory,
    LongTermMemory,
    Memory,
    ShortTermMemory,
    get_memory,
    has_memory,
    init_memory,
    set_memory,
)
from infrastructure.persistence.context import _memory_ctx


class TestLongTermMemory:
    """Tests for LongTermMemory."""

    def test_default_values(self):
        """LTM should have sensible defaults."""
        ltm = LongTermMemory()

        assert ltm.config == {}
        assert ltm.preferences["preferred_quality"] == "1080p"
        assert "en" in ltm.preferences["preferred_languages"]
        assert ltm.library == {"movies": [], "tv_shows": []}
        assert ltm.following == []

    def test_set_and_get_config(self):
        """Should set and retrieve config values."""
        ltm = LongTermMemory()

        ltm.set_config("download_folder", "/path/to/downloads")

        assert ltm.get_config("download_folder") == "/path/to/downloads"

    def test_get_config_default(self):
        """Should return the default for a missing config key."""
        ltm = LongTermMemory()

        assert ltm.get_config("nonexistent") is None
        assert ltm.get_config("nonexistent", "default") == "default"

    def test_has_config(self):
        """Should report whether a config key exists."""
        ltm = LongTermMemory()

        assert not ltm.has_config("download_folder")

        ltm.set_config("download_folder", "/path")

        assert ltm.has_config("download_folder")

    def test_has_config_none_value(self):
        """Should return False for keys whose value is None."""
        ltm = LongTermMemory()
        ltm.config["key"] = None

        assert not ltm.has_config("key")

    def test_add_to_library(self):
        """Should add media to the library."""
        ltm = LongTermMemory()
        movie = {"imdb_id": "tt1375666", "title": "Inception"}

        ltm.add_to_library("movies", movie)

        assert len(ltm.library["movies"]) == 1
        assert ltm.library["movies"][0]["title"] == "Inception"
        assert "added_at" in ltm.library["movies"][0]

    def test_add_to_library_no_duplicates(self):
        """Should not add duplicate media."""
        ltm = LongTermMemory()
        movie = {"imdb_id": "tt1375666", "title": "Inception"}

        ltm.add_to_library("movies", movie)
        ltm.add_to_library("movies", movie)

        assert len(ltm.library["movies"]) == 1

    def test_add_to_library_new_type(self):
        """Should create a new media type if it does not exist yet."""
        ltm = LongTermMemory()
        subtitle = {"imdb_id": "tt1375666", "language": "en"}

        ltm.add_to_library("subtitles", subtitle)

        assert "subtitles" in ltm.library
        assert len(ltm.library["subtitles"]) == 1

    def test_get_library(self):
        """Should get the library for a media type."""
        ltm = LongTermMemory()
        ltm.add_to_library("movies", {"imdb_id": "tt1", "title": "Movie 1"})
        ltm.add_to_library("movies", {"imdb_id": "tt2", "title": "Movie 2"})

        movies = ltm.get_library("movies")

        assert len(movies) == 2

    def test_get_library_empty(self):
        """Should return an empty list for an unknown media type."""
        ltm = LongTermMemory()

        assert ltm.get_library("unknown") == []

    def test_follow_show(self):
        """Should add a show to the following list."""
        ltm = LongTermMemory()
        show = {"imdb_id": "tt0944947", "title": "Game of Thrones"}

        ltm.follow_show(show)

        assert len(ltm.following) == 1
        assert ltm.following[0]["title"] == "Game of Thrones"
        assert "followed_at" in ltm.following[0]

    def test_follow_show_no_duplicates(self):
        """Should not follow the same show twice."""
        ltm = LongTermMemory()
        show = {"imdb_id": "tt0944947", "title": "Game of Thrones"}

        ltm.follow_show(show)
        ltm.follow_show(show)

        assert len(ltm.following) == 1

    def test_to_dict(self):
        """Should serialize to a dict."""
        ltm = LongTermMemory()
        ltm.set_config("key", "value")

        data = ltm.to_dict()

        assert "config" in data
        assert "preferences" in data
        assert "library" in data
        assert "following" in data
        assert data["config"]["key"] == "value"

    def test_from_dict(self):
        """Should deserialize from a dict."""
        data = {
            "config": {"download_folder": "/downloads"},
            "preferences": {"preferred_quality": "4K"},
            "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
            "following": [],
        }

        ltm = LongTermMemory.from_dict(data)

        assert ltm.get_config("download_folder") == "/downloads"
        assert ltm.preferences["preferred_quality"] == "4K"
        assert len(ltm.library["movies"]) == 1

    def test_from_dict_missing_keys(self):
        """Should fall back to defaults for missing keys."""
        ltm = LongTermMemory.from_dict({})

        assert ltm.config == {}
        assert ltm.preferences["preferred_quality"] == "1080p"


class TestShortTermMemory:
    """Tests for ShortTermMemory."""

    def test_default_values(self):
        """STM should start empty."""
        stm = ShortTermMemory()

        assert stm.conversation_history == []
        assert stm.current_workflow is None
        assert stm.extracted_entities == {}
        assert stm.current_topic is None

    def test_add_message(self):
        """Should add a message to the history."""
        stm = ShortTermMemory()

        stm.add_message("user", "Hello")

        assert len(stm.conversation_history) == 1
        assert stm.conversation_history[0]["role"] == "user"
        assert stm.conversation_history[0]["content"] == "Hello"
        assert "timestamp" in stm.conversation_history[0]

    def test_add_message_max_history(self):
        """Should limit the history to max_history, keeping the newest."""
        stm = ShortTermMemory()
        stm.max_history = 5

        for i in range(10):
            stm.add_message("user", f"Message {i}")

        assert len(stm.conversation_history) == 5
        assert stm.conversation_history[0]["content"] == "Message 5"

    def test_get_recent_history(self):
        """Should get the last N messages."""
        stm = ShortTermMemory()
        for i in range(10):
            stm.add_message("user", f"Message {i}")

        recent = stm.get_recent_history(3)

        assert len(recent) == 3
        assert recent[0]["content"] == "Message 7"

    def test_get_recent_history_less_than_n(self):
        """Should return everything when there are fewer than N messages."""
        stm = ShortTermMemory()
        stm.add_message("user", "Hello")
        stm.add_message("assistant", "Hi")

        recent = stm.get_recent_history(10)

        assert len(recent) == 2

    def test_start_workflow(self):
        """Should start a workflow."""
        stm = ShortTermMemory()

        stm.start_workflow("download", {"title": "Inception"})

        assert stm.current_workflow is not None
        assert stm.current_workflow["type"] == "download"
        assert stm.current_workflow["target"]["title"] == "Inception"
        assert stm.current_workflow["stage"] == "started"

    def test_update_workflow_stage(self):
        """Should update the workflow stage."""
        stm = ShortTermMemory()
        stm.start_workflow("download", {"title": "Inception"})

        stm.update_workflow_stage("searching")

        assert stm.current_workflow["stage"] == "searching"

    def test_update_workflow_stage_no_workflow(self):
        """Should do nothing when no workflow is active."""
        stm = ShortTermMemory()

        stm.update_workflow_stage("searching")  # Should not raise

        assert stm.current_workflow is None

    def test_end_workflow(self):
        """Should end the workflow."""
        stm = ShortTermMemory()
        stm.start_workflow("download", {"title": "Inception"})

        stm.end_workflow()

        assert stm.current_workflow is None

    def test_set_and_get_entity(self):
        """Should set and get entities."""
        stm = ShortTermMemory()

        stm.set_entity("movie_title", "Inception")
        stm.set_entity("year", 2010)

        assert stm.get_entity("movie_title") == "Inception"
        assert stm.get_entity("year") == 2010

    def test_get_entity_default(self):
        """Should return the default for a missing entity."""
        stm = ShortTermMemory()

        assert stm.get_entity("nonexistent") is None
        assert stm.get_entity("nonexistent", "default") == "default"

    def test_clear_entities(self):
        """Should clear all entities."""
        stm = ShortTermMemory()
        stm.set_entity("key1", "value1")
        stm.set_entity("key2", "value2")

        stm.clear_entities()

        assert stm.extracted_entities == {}

    def test_set_topic(self):
        """Should set the current topic."""
        stm = ShortTermMemory()

        stm.set_topic("searching_movie")

        assert stm.current_topic == "searching_movie"

    def test_clear(self):
        """Should clear all STM data."""
        stm = ShortTermMemory()
        stm.add_message("user", "Hello")
        stm.start_workflow("download", {})
        stm.set_entity("key", "value")
        stm.set_topic("topic")

        stm.clear()

        assert stm.conversation_history == []
        assert stm.current_workflow is None
        assert stm.extracted_entities == {}
        assert stm.current_topic is None

    def test_to_dict(self):
        """Should serialize to a dict."""
        stm = ShortTermMemory()
        stm.add_message("user", "Hello")
        stm.set_topic("test")

        data = stm.to_dict()

        assert "conversation_history" in data
        assert "current_workflow" in data
        assert "extracted_entities" in data
        assert "current_topic" in data


class TestEpisodicMemory:
    """Tests for EpisodicMemory."""

    def test_default_values(self):
        """Episodic memory should start empty."""
        episodic = EpisodicMemory()

        assert episodic.last_search_results is None
        assert episodic.active_downloads == []
        assert episodic.recent_errors == []
        assert episodic.pending_question is None
        assert episodic.background_events == []

    def test_store_search_results(self):
        """Should store search results with 1-based indexes."""
        episodic = EpisodicMemory()
        results = [
            {"name": "Result 1", "seeders": 100},
            {"name": "Result 2", "seeders": 50},
        ]

        episodic.store_search_results("test query", results)

        assert episodic.last_search_results is not None
        assert episodic.last_search_results["query"] == "test query"
        assert len(episodic.last_search_results["results"]) == 2
        assert episodic.last_search_results["results"][0]["index"] == 1
        assert episodic.last_search_results["results"][1]["index"] == 2

    def test_get_result_by_index(self):
        """Should get a result by its 1-based index."""
        episodic = EpisodicMemory()
        results = [
            {"name": "Result 1"},
            {"name": "Result 2"},
            {"name": "Result 3"},
        ]
        episodic.store_search_results("query", results)

        result = episodic.get_result_by_index(2)

        assert result is not None
        assert result["name"] == "Result 2"

    def test_get_result_by_index_not_found(self):
        """Should return None for an out-of-range index."""
        episodic = EpisodicMemory()
        results = [{"name": "Result 1"}]
        episodic.store_search_results("query", results)

        assert episodic.get_result_by_index(5) is None
        assert episodic.get_result_by_index(0) is None
        assert episodic.get_result_by_index(-1) is None

    def test_get_result_by_index_no_results(self):
        """Should return None if there are no search results."""
        episodic = EpisodicMemory()

        assert episodic.get_result_by_index(1) is None

    def test_clear_search_results(self):
        """Should clear search results."""
        episodic = EpisodicMemory()
        episodic.store_search_results("query", [{"name": "Result"}])

        episodic.clear_search_results()

        assert episodic.last_search_results is None

    def test_add_active_download(self):
        """Should add a download with a timestamp."""
        episodic = EpisodicMemory()

        episodic.add_active_download(
            {
                "task_id": "123",
                "name": "Test Movie",
                "magnet": "magnet:?xt=...",
            }
        )

        assert len(episodic.active_downloads) == 1
        assert episodic.active_downloads[0]["name"] == "Test Movie"
        assert "started_at" in episodic.active_downloads[0]

    def test_update_download_progress(self):
        """Should update download progress."""
        episodic = EpisodicMemory()
        episodic.add_active_download({"task_id": "123", "name": "Test"})

        episodic.update_download_progress("123", 50, "downloading")

        assert episodic.active_downloads[0]["progress"] == 50
        assert episodic.active_downloads[0]["status"] == "downloading"

    def test_update_download_progress_not_found(self):
        """Should do nothing for an unknown task_id."""
        episodic = EpisodicMemory()
        episodic.add_active_download({"task_id": "123", "name": "Test"})

        episodic.update_download_progress("999", 50)  # Should not raise

        assert episodic.active_downloads[0].get("progress") is None

    def test_complete_download(self):
        """Should complete the download and add a background event."""
        episodic = EpisodicMemory()
        episodic.add_active_download({"task_id": "123", "name": "Test Movie"})

        completed = episodic.complete_download("123", "/path/to/file.mkv")

        assert len(episodic.active_downloads) == 0
        assert completed["status"] == "completed"
        assert completed["file_path"] == "/path/to/file.mkv"
        assert len(episodic.background_events) == 1
        assert episodic.background_events[0]["type"] == "download_complete"

    def test_complete_download_not_found(self):
        """Should return None for an unknown task_id."""
        episodic = EpisodicMemory()

        result = episodic.complete_download("999", "/path")

        assert result is None

    def test_add_error(self):
        """Should add an error with a timestamp."""
        episodic = EpisodicMemory()

        episodic.add_error("find_torrent", "API timeout", {"query": "test"})

        assert len(episodic.recent_errors) == 1
        assert episodic.recent_errors[0]["action"] == "find_torrent"
        assert episodic.recent_errors[0]["error"] == "API timeout"

    def test_add_error_max_limit(self):
        """Should limit errors to max_errors, keeping the newest."""
        episodic = EpisodicMemory()
        episodic.max_errors = 3

        for i in range(5):
            episodic.add_error("action", f"Error {i}")

        assert len(episodic.recent_errors) == 3
        assert episodic.recent_errors[0]["error"] == "Error 2"

    def test_set_pending_question(self):
        """Should set a pending question."""
        episodic = EpisodicMemory()
        options = [
            {"index": 1, "label": "Option 1"},
            {"index": 2, "label": "Option 2"},
        ]

        episodic.set_pending_question(
            "Which one?",
            options,
            {"context": "test"},
            "choice",
        )

        assert episodic.pending_question is not None
        assert episodic.pending_question["question"] == "Which one?"
        assert len(episodic.pending_question["options"]) == 2

    def test_resolve_pending_question(self):
        """Should resolve the question and return the chosen option."""
        episodic = EpisodicMemory()
        options = [
            {"index": 1, "label": "Option 1"},
            {"index": 2, "label": "Option 2"},
        ]
        episodic.set_pending_question("Which?", options, {})

        result = episodic.resolve_pending_question(2)

        assert result["label"] == "Option 2"
        assert episodic.pending_question is None

    def test_resolve_pending_question_cancel(self):
        """Should cancel the question when no index is given."""
        episodic = EpisodicMemory()
        episodic.set_pending_question("Which?", [], {})

        result = episodic.resolve_pending_question(None)

        assert result is None
        assert episodic.pending_question is None

    def test_add_background_event(self):
        """Should add an unread background event."""
        episodic = EpisodicMemory()

        episodic.add_background_event("download_complete", {"name": "Movie"})

        assert len(episodic.background_events) == 1
        assert episodic.background_events[0]["type"] == "download_complete"
        assert episodic.background_events[0]["read"] is False

    def test_add_background_event_max_limit(self):
        """Should limit events to max_events."""
        episodic = EpisodicMemory()
        episodic.max_events = 3

        for i in range(5):
            episodic.add_background_event("event", {"i": i})

        assert len(episodic.background_events) == 3

    def test_get_unread_events(self):
        """Should return unread events and mark them as read."""
        episodic = EpisodicMemory()
        episodic.add_background_event("event1", {})
        episodic.add_background_event("event2", {})

        unread = episodic.get_unread_events()

        assert len(unread) == 2
        assert all(e["read"] for e in episodic.background_events)

    def test_get_unread_events_already_read(self):
        """Should not return events that were already read."""
        episodic = EpisodicMemory()
        episodic.add_background_event("event1", {})
        episodic.get_unread_events()  # Mark as read
        episodic.add_background_event("event2", {})

        unread = episodic.get_unread_events()

        assert len(unread) == 1
        assert unread[0]["type"] == "event2"

    def test_clear(self):
        """Should clear all episodic data."""
        episodic = EpisodicMemory()
        episodic.store_search_results("query", [{}])
        episodic.add_active_download({"task_id": "1", "name": "Test"})
        episodic.add_error("action", "error")
        episodic.set_pending_question("?", [], {})
        episodic.add_background_event("event", {})

        episodic.clear()

        assert episodic.last_search_results is None
        assert episodic.active_downloads == []
        assert episodic.recent_errors == []
        assert episodic.pending_question is None
        assert episodic.background_events == []


class TestMemory:
    """Tests for the Memory manager."""

    def test_init_creates_directories(self, temp_dir):
        """Should create the storage directory."""
        storage = temp_dir / "memory_data"

        # Only the constructor's side effect matters here.
        Memory(storage_dir=str(storage))

        assert storage.exists()

    def test_init_loads_existing_ltm(self, temp_dir):
        """Should load an existing LTM from file."""
        ltm_file = temp_dir / "ltm.json"
        ltm_file.write_text(
            json.dumps(
                {
                    "config": {"download_folder": "/downloads"},
                    "preferences": {"preferred_quality": "4K"},
                    "library": {"movies": []},
                    "following": [],
                }
            )
        )

        memory = Memory(storage_dir=str(temp_dir))

        assert memory.ltm.get_config("download_folder") == "/downloads"
        assert memory.ltm.preferences["preferred_quality"] == "4K"

    def test_init_handles_corrupted_ltm(self, temp_dir):
        """Should fall back to defaults when the LTM file is corrupted."""
        ltm_file = temp_dir / "ltm.json"
        ltm_file.write_text("not valid json {{{")

        memory = Memory(storage_dir=str(temp_dir))

        assert memory.ltm.config == {}  # Default values

    def test_save(self, temp_dir):
        """Should save the LTM to file."""
        memory = Memory(storage_dir=str(temp_dir))
        memory.ltm.set_config("test_key", "test_value")

        memory.save()

        ltm_file = temp_dir / "ltm.json"
        assert ltm_file.exists()
        data = json.loads(ltm_file.read_text())
        assert data["config"]["test_key"] == "test_value"

    def test_get_context_for_prompt(self, memory_with_search_results):
        """Should generate context for the prompt."""
        context = memory_with_search_results.get_context_for_prompt()

        assert "config" in context
        assert "preferences" in context
        assert context["last_search"]["query"] == "Inception 1080p"
        assert context["last_search"]["result_count"] == 3

    def test_get_full_state(self, memory):
        """Should return the full state of all memories."""
        state = memory.get_full_state()

        assert "ltm" in state
        assert "stm" in state
        assert "episodic" in state

    def test_clear_session(self, memory_with_search_results):
        """Should clear STM and Episodic but keep LTM."""
        memory_with_search_results.ltm.set_config("key", "value")
        memory_with_search_results.stm.add_message("user", "Hello")

        memory_with_search_results.clear_session()

        assert memory_with_search_results.ltm.get_config("key") == "value"
        assert memory_with_search_results.stm.conversation_history == []
        assert memory_with_search_results.episodic.last_search_results is None


class TestMemoryContext:
    """Tests for memory context functions."""

    @pytest.fixture(autouse=True)
    def _isolated_memory_ctx(self):
        """Run each test with no memory in context, then restore the prior value.

        The tests below install memory via ``init_memory``/``set_memory``.
        Without restoring, that state would leak into later tests. This
        assumes ``_memory_ctx`` is a ``contextvars.ContextVar``: ``set``
        returns a token, and ``reset(token)`` undoes every change made
        after that ``set``.
        """
        token = _memory_ctx.set(None)
        yield
        _memory_ctx.reset(token)

    def test_init_memory(self, temp_dir):
        """Should initialize memory and set it in context."""
        memory = init_memory(str(temp_dir))

        assert memory is not None
        assert has_memory()
        assert get_memory() is memory

    def test_set_memory(self, temp_dir):
        """Should set an existing memory in context."""
        memory = Memory(storage_dir=str(temp_dir))

        set_memory(memory)

        assert get_memory() is memory

    def test_get_memory_not_initialized(self):
        """Should raise if memory is not initialized."""
        with pytest.raises(RuntimeError, match="Memory not initialized"):
            get_memory()

    def test_has_memory(self, temp_dir):
        """Should report whether memory is initialized."""
        assert not has_memory()

        init_memory(str(temp_dir))

        assert has_memory()