Formatting

This commit is contained in:
2025-12-07 03:33:51 +01:00
parent a923a760ef
commit 4eae1d6d58
24 changed files with 1003 additions and 833 deletions
+32 -40
View File
@@ -1,7 +1,8 @@
"""Main agent for media library management.""" """Main agent for media library management."""
import json import json
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
@@ -28,7 +29,7 @@ class Agent:
max_tool_iterations: Maximum number of tool execution iterations max_tool_iterations: Maximum number of tool execution iterations
""" """
self.llm = llm self.llm = llm
self.tools: Dict[str, Tool] = make_tools() self.tools: dict[str, Tool] = make_tools()
self.prompt_builder = PromptBuilder(self.tools) self.prompt_builder = PromptBuilder(self.tools)
self.max_tool_iterations = max_tool_iterations self.max_tool_iterations = max_tool_iterations
@@ -56,9 +57,7 @@ class Agent:
# Build initial messages # Build initial messages
system_prompt = self.prompt_builder.build_system_prompt() system_prompt = self.prompt_builder.build_system_prompt()
messages: List[Dict[str, Any]] = [ messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
{"role": "system", "content": system_prompt}
]
# Add conversation history # Add conversation history
history = memory.stm.get_recent_history(settings.max_history_messages) history = memory.stm.get_recent_history(settings.max_history_messages)
@@ -67,14 +66,12 @@ class Agent:
# Add unread events if any # Add unread events if any
unread_events = memory.episodic.get_unread_events() unread_events = memory.episodic.get_unread_events()
if unread_events: if unread_events:
events_text = "\n".join([ events_text = "\n".join(
f"- {e['type']}: {e['data']}" [f"- {e['type']}: {e['data']}" for e in unread_events]
for e in unread_events )
]) messages.append(
messages.append({ {"role": "system", "content": f"Background events:\n{events_text}"}
"role": "system", )
"content": f"Background events:\n{events_text}"
})
# Get tools specification for OpenAI format # Get tools specification for OpenAI format
tools_spec = self.prompt_builder.build_tools_spec() tools_spec = self.prompt_builder.build_tools_spec()
@@ -108,18 +105,22 @@ class Agent:
tool_result = self._execute_tool_call(tool_call) tool_result = self._execute_tool_call(tool_call)
# Add tool result to messages # Add tool result to messages
messages.append({ messages.append(
"tool_call_id": tool_call.get("id"), {
"role": "tool", "tool_call_id": tool_call.get("id"),
"name": tool_call.get("function", {}).get("name"), "role": "tool",
"content": json.dumps(tool_result, ensure_ascii=False), "name": tool_call.get("function", {}).get("name"),
}) "content": json.dumps(tool_result, ensure_ascii=False),
}
)
# Max iterations reached, force final response # Max iterations reached, force final response
messages.append({ messages.append(
"role": "system", {
"content": "Please provide a final response to the user without using any more tools." "role": "system",
}) "content": "Please provide a final response to the user without using any more tools.",
}
)
llm_result = self.llm.complete(messages) llm_result = self.llm.complete(messages)
if isinstance(llm_result, tuple): if isinstance(llm_result, tuple):
@@ -127,12 +128,14 @@ class Agent:
else: else:
final_message = llm_result final_message = llm_result
final_response = final_message.get("content", "I've completed the requested actions.") final_response = final_message.get(
"content", "I've completed the requested actions."
)
memory.stm.add_message("assistant", final_response) memory.stm.add_message("assistant", final_response)
memory.save() memory.save()
return final_response return final_response
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: def _execute_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
""" """
Execute a single tool call. Execute a single tool call.
@@ -150,10 +153,7 @@ class Agent:
args = json.loads(args_str) args = json.loads(args_str)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {e}") logger.error(f"Failed to parse tool arguments: {e}")
return { return {"error": "bad_args", "message": f"Invalid JSON arguments: {e}"}
"error": "bad_args",
"message": f"Invalid JSON arguments: {e}"
}
# Validate tool exists # Validate tool exists
if tool_name not in self.tools: if tool_name not in self.tools:
@@ -161,7 +161,7 @@ class Agent:
return { return {
"error": "unknown_tool", "error": "unknown_tool",
"message": f"Tool '{tool_name}' not found", "message": f"Tool '{tool_name}' not found",
"available_tools": available "available_tools": available,
} }
tool = self.tools[tool_name] tool = self.tools[tool_name]
@@ -177,17 +177,9 @@ class Agent:
# Bad arguments # Bad arguments
memory = get_memory() memory = get_memory()
memory.episodic.add_error(tool_name, f"bad_args: {e}") memory.episodic.add_error(tool_name, f"bad_args: {e}")
return { return {"error": "bad_args", "message": str(e), "tool": tool_name}
"error": "bad_args",
"message": str(e),
"tool": tool_name
}
except Exception as e: except Exception as e:
# Other errors # Other errors
memory = get_memory() memory = get_memory()
memory.episodic.add_error(tool_name, str(e)) memory.episodic.add_error(tool_name, str(e))
return { return {"error": "execution_failed", "message": str(e), "tool": tool_name}
"error": "execution_failed",
"message": str(e),
"tool": tool_name
}
+9 -3
View File
@@ -51,7 +51,9 @@ class DeepSeekClient:
logger.info(f"DeepSeek client initialized with model: {self.model}") logger.info(f"DeepSeek client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]: def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
""" """
Generate a completion from the LLM. Generate a completion from the LLM.
@@ -80,7 +82,9 @@ class DeepSeekClient:
raise ValueError(f"Invalid role: {msg['role']}") raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead) # Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg: if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}") raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/v1/chat/completions" url = f"{self.base_url}/v1/chat/completions"
headers = { headers = {
@@ -98,7 +102,9 @@ class DeepSeekClient:
payload["tools"] = tools payload["tools"] = tools
try: try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools") logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post( response = requests.post(
url, headers=headers, json=payload, timeout=self.timeout url, headers=headers, json=payload, timeout=self.timeout
) )
+9 -3
View File
@@ -66,7 +66,9 @@ class OllamaClient:
logger.info(f"Ollama client initialized with model: {self.model}") logger.info(f"Ollama client initialized with model: {self.model}")
def complete(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None) -> dict[str, Any]: def complete(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
""" """
Generate a completion from the LLM. Generate a completion from the LLM.
@@ -95,7 +97,9 @@ class OllamaClient:
raise ValueError(f"Invalid role: {msg['role']}") raise ValueError(f"Invalid role: {msg['role']}")
# Content is optional for tool messages (they may have tool_call_id instead) # Content is optional for tool messages (they may have tool_call_id instead)
if msg["role"] != "tool" and "content" not in msg: if msg["role"] != "tool" and "content" not in msg:
raise ValueError(f"Non-tool message must have 'content' key, got {msg.keys()}") raise ValueError(
f"Non-tool message must have 'content' key, got {msg.keys()}"
)
url = f"{self.base_url}/api/chat" url = f"{self.base_url}/api/chat"
payload = { payload = {
@@ -112,7 +116,9 @@ class OllamaClient:
payload["tools"] = tools payload["tools"] = tools
try: try:
logger.debug(f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools") logger.debug(
f"Sending request to {url} with {len(messages)} messages and {len(tools) if tools else 0} tools"
)
response = requests.post(url, json=payload, timeout=self.timeout) response = requests.post(url, json=payload, timeout=self.timeout)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
+20 -12
View File
@@ -1,18 +1,20 @@
"""Prompt builder for the agent system.""" """Prompt builder for the agent system."""
from typing import Dict, List, Any
import json import json
from typing import Any
from infrastructure.persistence import get_memory
from .registry import Tool from .registry import Tool
from infrastructure.persistence import get_memory
class PromptBuilder: class PromptBuilder:
"""Builds system prompts for the agent with memory context.""" """Builds system prompts for the agent with memory context."""
def __init__(self, tools: Dict[str, Tool]): def __init__(self, tools: dict[str, Tool]):
self.tools = tools self.tools = tools
def build_tools_spec(self) -> List[Dict[str, Any]]: def build_tools_spec(self) -> list[dict[str, Any]]:
"""Build the tool specification for the LLM API.""" """Build the tool specification for the LLM API."""
tool_specs = [] tool_specs = []
for tool in self.tools.values(): for tool in self.tools.values():
@@ -44,11 +46,13 @@ class PromptBuilder:
if memory.episodic.last_search_results: if memory.episodic.last_search_results:
results = memory.episodic.last_search_results results = memory.episodic.last_search_results
result_list = results.get('results', []) result_list = results.get("results", [])
lines.append(f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)") lines.append(
f"\nLAST SEARCH: '{results.get('query')}' ({len(result_list)} results)"
)
# Show first 5 results # Show first 5 results
for i, result in enumerate(result_list[:5]): for i, result in enumerate(result_list[:5]):
name = result.get('name', 'Unknown') name = result.get("name", "Unknown")
lines.append(f" {i+1}. {name}") lines.append(f" {i+1}. {name}")
if len(result_list) > 5: if len(result_list) > 5:
lines.append(f" ... and {len(result_list) - 5} more") lines.append(f" ... and {len(result_list) - 5} more")
@@ -57,7 +61,7 @@ class PromptBuilder:
question = memory.episodic.pending_question question = memory.episodic.pending_question
lines.append(f"\nPENDING QUESTION: {question.get('question')}") lines.append(f"\nPENDING QUESTION: {question.get('question')}")
lines.append(f" Type: {question.get('type')}") lines.append(f" Type: {question.get('type')}")
if question.get('options'): if question.get("options"):
lines.append(f" Options: {len(question.get('options'))}") lines.append(f" Options: {len(question.get('options'))}")
if memory.episodic.active_downloads: if memory.episodic.active_downloads:
@@ -68,10 +72,12 @@ class PromptBuilder:
if memory.episodic.recent_errors: if memory.episodic.recent_errors:
lines.append("\nRECENT ERRORS (up to 3):") lines.append("\nRECENT ERRORS (up to 3):")
for error in memory.episodic.recent_errors[-3:]: for error in memory.episodic.recent_errors[-3:]:
lines.append(f" - Action '{error.get('action')}' failed: {error.get('error')}") lines.append(
f" - Action '{error.get('action')}' failed: {error.get('error')}"
)
# Unread events # Unread events
unread = [e for e in memory.episodic.background_events if not e.get('read')] unread = [e for e in memory.episodic.background_events if not e.get("read")]
if unread: if unread:
lines.append(f"\nUNREAD EVENTS: {len(unread)}") lines.append(f"\nUNREAD EVENTS: {len(unread)}")
for event in unread[:3]: for event in unread[:3]:
@@ -86,8 +92,10 @@ class PromptBuilder:
if memory.stm.current_workflow: if memory.stm.current_workflow:
workflow = memory.stm.current_workflow workflow = memory.stm.current_workflow
lines.append(f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})") lines.append(
if workflow.get('target'): f"CURRENT WORKFLOW: {workflow.get('type')} (stage: {workflow.get('stage')})"
)
if workflow.get("target"):
lines.append(f" Target: {workflow.get('target')}") lines.append(f" Target: {workflow.get('target')}")
if memory.stm.current_topic: if memory.stm.current_topic:
+12 -9
View File
@@ -1,8 +1,10 @@
"""Tool registry - defines and registers all available tools for the agent.""" """Tool registry - defines and registers all available tools for the agent."""
from dataclasses import dataclass
from typing import Callable, Any, Dict
import logging
import inspect import inspect
import logging
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -10,10 +12,11 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class Tool: class Tool:
"""Represents a tool that can be used by the agent.""" """Represents a tool that can be used by the agent."""
name: str name: str
description: str description: str
func: Callable[..., Dict[str, Any]] func: Callable[..., dict[str, Any]]
parameters: Dict[str, Any] parameters: dict[str, Any]
def _create_tool_from_function(func: Callable) -> Tool: def _create_tool_from_function(func: Callable) -> Tool:
@@ -30,7 +33,7 @@ def _create_tool_from_function(func: Callable) -> Tool:
doc = inspect.getdoc(func) doc = inspect.getdoc(func)
# Extract description from docstring (first line) # Extract description from docstring (first line)
description = doc.strip().split('\n')[0] if doc else func.__name__ description = doc.strip().split("\n")[0] if doc else func.__name__
# Build JSON schema from function signature # Build JSON schema from function signature
properties = {} properties = {}
@@ -54,7 +57,7 @@ def _create_tool_from_function(func: Callable) -> Tool:
properties[param_name] = { properties[param_name] = {
"type": param_type, "type": param_type,
"description": f"Parameter {param_name}" "description": f"Parameter {param_name}",
} }
# Add to required if no default value # Add to required if no default value
@@ -75,7 +78,7 @@ def _create_tool_from_function(func: Callable) -> Tool:
) )
def make_tools() -> Dict[str, Tool]: def make_tools() -> dict[str, Tool]:
""" """
Create and register all available tools. Create and register all available tools.
@@ -83,8 +86,8 @@ def make_tools() -> Dict[str, Tool]:
Dictionary mapping tool names to Tool objects Dictionary mapping tool names to Tool objects
""" """
# Import tools here to avoid circular dependencies # Import tools here to avoid circular dependencies
from .tools import filesystem as fs_tools
from .tools import api as api_tools from .tools import api as api_tools
from .tools import filesystem as fs_tools
from .tools import language as lang_tools from .tools import language as lang_tools
# List of all tool functions # List of all tool functions
+5 -7
View File
@@ -1,13 +1,14 @@
"""Language management tools for the agent.""" """Language management tools for the agent."""
import logging import logging
from typing import Dict, Any from typing import Any
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def set_language(language: str) -> Dict[str, Any]: def set_language(language: str) -> dict[str, Any]:
""" """
Set the conversation language. Set the conversation language.
@@ -27,11 +28,8 @@ def set_language(language: str) -> Dict[str, Any]:
return { return {
"status": "ok", "status": "ok",
"message": f"Language set to {language}", "message": f"Language set to {language}",
"language": language "language": language,
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to set language: {e}") logger.error(f"Failed to set language: {e}")
return { return {"status": "error", "error": str(e)}
"status": "error",
"error": str(e)
}
+2 -6
View File
@@ -359,9 +359,7 @@ class EpisodicMemory:
"""Get active downloads.""" """Get active downloads."""
return self.active_downloads return self.active_downloads
def add_error( def add_error(self, action: str, error: str, context: dict | None = None) -> None:
self, action: str, error: str, context: dict | None = None
) -> None:
"""Record a recent error.""" """Record a recent error."""
self.recent_errors.append( self.recent_errors.append(
{ {
@@ -408,9 +406,7 @@ class EpisodicMemory:
"""Get the pending question.""" """Get the pending question."""
return self.pending_question return self.pending_question
def resolve_pending_question( def resolve_pending_question(self, answer_index: int | None = None) -> dict | None:
self, answer_index: int | None = None
) -> dict | None:
""" """
Resolve the pending question and return the chosen option. Resolve the pending question and return the chosen option.
+1 -1
View File
@@ -110,4 +110,4 @@ select = [
"PL", "PL",
"UP", "UP",
] ]
ignore = ["W503", "PLR0913", "PLR2004"] ignore = ["PLR0913", "PLR2004"]
+28 -33
View File
@@ -1,16 +1,13 @@
"""Pytest configuration and shared fixtures.""" """Pytest configuration and shared fixtures."""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, MagicMock
from infrastructure.persistence import Memory, init_memory, set_memory, get_memory import shutil
from infrastructure.persistence.memory import ( import tempfile
LongTermMemory, from pathlib import Path
ShortTermMemory, from unittest.mock import MagicMock, Mock
EpisodicMemory,
) import pytest
from infrastructure.persistence import Memory, set_memory
@pytest.fixture @pytest.fixture
@@ -122,12 +119,11 @@ def memory_with_library(memory):
def mock_llm(): def mock_llm():
"""Create a mock LLM client that returns OpenAI-compatible format.""" """Create a mock LLM client that returns OpenAI-compatible format."""
llm = Mock() llm = Mock()
# Return OpenAI-style message dict without tool calls # Return OpenAI-style message dict without tool calls
def complete_func(messages, tools=None): def complete_func(messages, tools=None):
return { return {"role": "assistant", "content": "I found what you're looking for!"}
"role": "assistant",
"content": "I found what you're looking for!"
}
llm.complete = Mock(side_effect=complete_func) llm.complete = Mock(side_effect=complete_func)
return llm return llm
@@ -139,7 +135,7 @@ def mock_llm_with_tool_call():
# First call returns a tool call, second returns final response # First call returns a tool call, second returns final response
def complete_side_effect(messages, tools=None): def complete_side_effect(messages, tools=None):
if not hasattr(complete_side_effect, 'call_count'): if not hasattr(complete_side_effect, "call_count"):
complete_side_effect.call_count = 0 complete_side_effect.call_count = 0
complete_side_effect.call_count += 1 complete_side_effect.call_count += 1
@@ -148,21 +144,20 @@ def mock_llm_with_tool_call():
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_123", {
"type": "function", "id": "call_123",
"function": { "type": "function",
"name": "find_torrent", "function": {
"arguments": '{"media_title": "Inception"}' "name": "find_torrent",
"arguments": '{"media_title": "Inception"}',
},
} }
}] ],
} }
else: else:
# Second call: return final response # Second call: return final response
return { return {"role": "assistant", "content": "I found 3 torrents for Inception!"}
"role": "assistant",
"content": "I found 3 torrents for Inception!"
}
llm.complete = Mock(side_effect=complete_side_effect) llm.complete = Mock(side_effect=complete_side_effect)
return llm return llm
@@ -254,10 +249,10 @@ def mock_deepseek():
# Your test code here # Your test code here
""" """
import sys import sys
from unittest.mock import Mock, MagicMock from unittest.mock import Mock
# Save the original module if it exists # Save the original module if it exists
original_module = sys.modules.get('agent.llm.deepseek') original_module = sys.modules.get("agent.llm.deepseek")
# Create a mock module for deepseek # Create a mock module for deepseek
mock_deepseek_module = MagicMock() mock_deepseek_module = MagicMock()
@@ -269,15 +264,15 @@ def mock_deepseek():
mock_deepseek_module.DeepSeekClient = MockDeepSeekClient mock_deepseek_module.DeepSeekClient = MockDeepSeekClient
# Inject the mock # Inject the mock
sys.modules['agent.llm.deepseek'] = mock_deepseek_module sys.modules["agent.llm.deepseek"] = mock_deepseek_module
yield mock_deepseek_module yield mock_deepseek_module
# Restore the original module # Restore the original module
if original_module is not None: if original_module is not None:
sys.modules['agent.llm.deepseek'] = original_module sys.modules["agent.llm.deepseek"] = original_module
elif 'agent.llm.deepseek' in sys.modules: elif "agent.llm.deepseek" in sys.modules:
del sys.modules['agent.llm.deepseek'] del sys.modules["agent.llm.deepseek"]
@pytest.fixture @pytest.fixture
+37 -43
View File
@@ -1,6 +1,6 @@
"""Tests for the Agent.""" """Tests for the Agent."""
from unittest.mock import Mock, patch from unittest.mock import Mock
from agent.agent import Agent from agent.agent import Agent
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
@@ -55,8 +55,8 @@ class TestExecuteToolCall:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "list_folder", "name": "list_folder",
"arguments": '{"folder_type": "download"}' "arguments": '{"folder_type": "download"}',
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -68,10 +68,7 @@ class TestExecuteToolCall:
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "unknown_tool", "arguments": "{}"},
"name": "unknown_tool",
"arguments": '{}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -84,10 +81,7 @@ class TestExecuteToolCall:
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "set_path_for_folder", "arguments": "{}"},
"name": "set_path_for_folder",
"arguments": '{}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -102,8 +96,8 @@ class TestExecuteToolCall:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "set_path_for_folder", "name": "set_path_for_folder",
"arguments": '{"folder_name": 123}' # Wrong type "arguments": '{"folder_name": 123}', # Wrong type
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -116,10 +110,7 @@ class TestExecuteToolCall:
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "list_folder", "arguments": "{invalid json}"},
"name": "list_folder",
"arguments": '{invalid json}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -163,8 +154,8 @@ class TestStep:
# CRITICAL: Verify tools were passed to LLM # CRITICAL: Verify tools were passed to LLM
first_call_args = mock_llm_with_tool_call.complete.call_args_list[0] first_call_args = mock_llm_with_tool_call.complete.call_args_list[0]
assert first_call_args[1]['tools'] is not None, "Tools not passed to LLM!" assert first_call_args[1]["tools"] is not None, "Tools not passed to LLM!"
assert len(first_call_args[1]['tools']) > 0, "Tools list is empty!" assert len(first_call_args[1]["tools"]) > 0, "Tools list is empty!"
def test_step_max_iterations(self, memory, mock_llm): def test_step_max_iterations(self, memory, mock_llm):
"""Should stop after max iterations.""" """Should stop after max iterations."""
@@ -180,19 +171,18 @@ class TestStep:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": f"call_{call_count[0]}", {
"function": { "id": f"call_{call_count[0]}",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
else: else:
return { return {"role": "assistant", "content": "I couldn't complete the task."}
"role": "assistant",
"content": "I couldn't complete the task."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3) agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -251,34 +241,38 @@ class TestAgentIntegration:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
elif call_count[0] == 2: elif call_count[0] == 2:
# CRITICAL: Verify tool result was sent back # CRITICAL: Verify tool result was sent back
tool_messages = [m for m in messages if m.get('role') == 'tool'] tool_messages = [m for m in messages if m.get("role") == "tool"]
assert len(tool_messages) > 0, "Tool result not sent back to LLM!" assert len(tool_messages) > 0, "Tool result not sent back to LLM!"
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_2", {
"function": { "id": "call_2",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "movie"}' "name": "list_folder",
"arguments": '{"folder_type": "movie"}',
},
} }
}] ],
} }
else: else:
return { return {
"role": "assistant", "role": "assistant",
"content": "I listed both folders for you." "content": "I listed both folders for you.",
} }
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
+69 -84
View File
@@ -1,7 +1,9 @@
"""Edge case tests for the Agent.""" """Edge case tests for the Agent."""
import pytest
from unittest.mock import Mock from unittest.mock import Mock
import pytest
from agent.agent import Agent from agent.agent import Agent
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
@@ -15,19 +17,14 @@ class TestExecuteToolCallEdgeCases:
# Mock a tool that returns None # Mock a tool that returns None
from agent.registry import Tool from agent.registry import Tool
agent.tools["test_tool"] = Tool( agent.tools["test_tool"] = Tool(
name="test_tool", name="test_tool", description="Test", func=lambda: None, parameters={}
description="Test",
func=lambda: None,
parameters={}
) )
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "test_tool", "arguments": "{}"},
"name": "test_tool",
"arguments": '{}'
}
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -38,22 +35,17 @@ class TestExecuteToolCallEdgeCases:
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
from agent.registry import Tool from agent.registry import Tool
def raise_interrupt(): def raise_interrupt():
raise KeyboardInterrupt() raise KeyboardInterrupt()
agent.tools["test_tool"] = Tool( agent.tools["test_tool"] = Tool(
name="test_tool", name="test_tool", description="Test", func=raise_interrupt, parameters={}
description="Test",
func=raise_interrupt,
parameters={}
) )
tool_call = { tool_call = {
"id": "call_123", "id": "call_123",
"function": { "function": {"name": "test_tool", "arguments": "{}"},
"name": "test_tool",
"arguments": '{}'
}
} }
with pytest.raises(KeyboardInterrupt): with pytest.raises(KeyboardInterrupt):
@@ -68,8 +60,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "list_folder", "name": "list_folder",
"arguments": '{"folder_type": "download", "extra_arg": "ignored"}' "arguments": '{"folder_type": "download", "extra_arg": "ignored"}',
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -84,8 +76,8 @@ class TestExecuteToolCallEdgeCases:
"id": "call_123", "id": "call_123",
"function": { "function": {
"name": "get_torrent_by_index", "name": "get_torrent_by_index",
"arguments": '{"index": "not an int"}' "arguments": '{"index": "not an int"}',
} },
} }
result = agent._execute_tool_call(tool_call) result = agent._execute_tool_call(tool_call)
@@ -115,11 +107,9 @@ class TestStepEdgeCases:
def test_step_with_unicode_input(self, memory, mock_llm): def test_step_with_unicode_input(self, memory, mock_llm):
"""Should handle unicode input.""" """Should handle unicode input."""
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
return { return {"role": "assistant", "content": "日本語の応答"}
"role": "assistant",
"content": "日本語の応答"
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -130,11 +120,9 @@ class TestStepEdgeCases:
def test_step_llm_returns_empty(self, memory, mock_llm): def test_step_llm_returns_empty(self, memory, mock_llm):
"""Should handle LLM returning empty string.""" """Should handle LLM returning empty string."""
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
return { return {"role": "assistant", "content": ""}
"role": "assistant",
"content": ""
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -161,18 +149,17 @@ class TestStepEdgeCases:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": f"call_{call_count[0]}", {
"function": { "id": f"call_{call_count[0]}",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Done looping"}
"role": "assistant",
"content": "Done looping"
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3) agent = Agent(llm=mock_llm, max_tool_iterations=3)
@@ -212,11 +199,13 @@ class TestStepEdgeCases:
def test_step_with_active_downloads(self, memory, mock_llm): def test_step_with_active_downloads(self, memory, mock_llm):
"""Should include active downloads in context.""" """Should include active downloads in context."""
memory.episodic.add_active_download({ memory.episodic.add_active_download(
"task_id": "123", {
"name": "Movie.mkv", "task_id": "123",
"progress": 50, "name": "Movie.mkv",
}) "progress": 50,
}
)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
response = agent.step("Hello") response = agent.step("Hello")
@@ -264,18 +253,17 @@ class TestAgentConcurrencyEdgeCases:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "set_path_for_folder", "function": {
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}' "name": "set_path_for_folder",
"arguments": f'{{"folder_name": "movie", "path_value": "{str(real_folder["movies"])}"}}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Path set successfully."}
"role": "assistant",
"content": "Path set successfully."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -299,18 +287,17 @@ class TestAgentErrorRecovery:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "The folder is not configured."}
"role": "assistant",
"content": "The folder is not configured."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -329,18 +316,17 @@ class TestAgentErrorRecovery:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "set_path_for_folder", "function": {
"arguments": '{}' # Missing required args "name": "set_path_for_folder",
"arguments": "{}", # Missing required args
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Error occurred."}
"role": "assistant",
"content": "Error occurred."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm) agent = Agent(llm=mock_llm)
@@ -360,18 +346,17 @@ class TestAgentErrorRecovery:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": f"call_{call_count[0]}", {
"function": { "id": f"call_{call_count[0]}",
"name": "set_path_for_folder", "function": {
"arguments": '{}' # Missing required args - will error "name": "set_path_for_folder",
"arguments": "{}", # Missing required args - will error
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "All attempts failed."}
"role": "assistant",
"content": "All attempts failed."
}
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
agent = Agent(llm=mock_llm, max_tool_iterations=3) agent = Agent(llm=mock_llm, max_tool_iterations=3)
+63 -31
View File
@@ -1,6 +1,7 @@
"""Tests for FastAPI endpoints.""" """Tests for FastAPI endpoints."""
import pytest
from unittest.mock import Mock, patch, MagicMock from unittest.mock import patch
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -10,6 +11,7 @@ class TestHealthEndpoint:
def test_health_check(self, memory): def test_health_check(self, memory):
"""Should return healthy status.""" """Should return healthy status."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/health") response = client.get("/health")
@@ -24,6 +26,7 @@ class TestModelsEndpoint:
def test_list_models(self, memory): def test_list_models(self, memory):
"""Should return model list.""" """Should return model list."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/v1/models") response = client.get("/v1/models")
@@ -41,6 +44,7 @@ class TestMemoryEndpoints:
def test_get_memory_state(self, memory): def test_get_memory_state(self, memory):
"""Should return full memory state.""" """Should return full memory state."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/memory/state") response = client.get("/memory/state")
@@ -54,6 +58,7 @@ class TestMemoryEndpoints:
def test_get_search_results_empty(self, memory): def test_get_search_results_empty(self, memory):
"""Should return empty when no search results.""" """Should return empty when no search results."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/memory/episodic/search-results") response = client.get("/memory/episodic/search-results")
@@ -65,6 +70,7 @@ class TestMemoryEndpoints:
def test_get_search_results_with_data(self, memory_with_search_results): def test_get_search_results_with_data(self, memory_with_search_results):
"""Should return search results when available.""" """Should return search results when available."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/memory/episodic/search-results") response = client.get("/memory/episodic/search-results")
@@ -78,6 +84,7 @@ class TestMemoryEndpoints:
def test_clear_session(self, memory_with_search_results): def test_clear_session(self, memory_with_search_results):
"""Should clear session memories.""" """Should clear session memories."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/memory/clear-session") response = client.post("/memory/clear-session")
@@ -96,14 +103,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_success(self, memory): def test_chat_completion_success(self, memory):
"""Should return chat completion.""" """Should return chat completion."""
from app import app from app import app
# Patch the agent's step method directly # Patch the agent's step method directly
with patch("app.agent.step", return_value="Hello! How can I help?"): with patch("app.agent.step", return_value="Hello! How can I help?"):
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Hello"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "Hello"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -113,12 +124,16 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_no_user_message(self, memory): def test_chat_completion_no_user_message(self, memory):
"""Should return error if no user message.""" """Should return error if no user message."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "system", "content": "You are helpful"}], json={
}) "model": "agent-media",
"messages": [{"role": "system", "content": "You are helpful"}],
},
)
assert response.status_code == 422 assert response.status_code == 422
detail = response.json()["detail"] detail = response.json()["detail"]
@@ -132,18 +147,23 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_empty_messages(self, memory): def test_chat_completion_empty_messages(self, memory):
"""Should return error for empty messages.""" """Should return error for empty messages."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [], json={
}) "model": "agent-media",
"messages": [],
},
)
assert response.status_code == 422 assert response.status_code == 422
def test_chat_completion_invalid_json(self, memory): def test_chat_completion_invalid_json(self, memory):
"""Should return error for invalid JSON.""" """Should return error for invalid JSON."""
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post( response = client.post(
@@ -157,14 +177,18 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_streaming(self, memory): def test_chat_completion_streaming(self, memory):
"""Should support streaming mode.""" """Should support streaming mode."""
from app import app from app import app
with patch("app.agent.step", return_value="Streaming response"): with patch("app.agent.step", return_value="Streaming response"):
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Hello"}], json={
"stream": True, "model": "agent-media",
}) "messages": [{"role": "user", "content": "Hello"}],
"stream": True,
},
)
assert response.status_code == 200 assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"] assert "text/event-stream" in response.headers["content-type"]
@@ -172,17 +196,21 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_extracts_last_user_message(self, memory): def test_chat_completion_extracts_last_user_message(self, memory):
"""Should use last user message.""" """Should use last user message."""
from app import app from app import app
with patch("app.agent.step", return_value="Response") as mock_step: with patch("app.agent.step", return_value="Response") as mock_step:
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [ json={
{"role": "user", "content": "First message"}, "model": "agent-media",
{"role": "assistant", "content": "Response"}, "messages": [
{"role": "user", "content": "Second message"}, {"role": "user", "content": "First message"},
], {"role": "assistant", "content": "Response"},
}) {"role": "user", "content": "Second message"},
],
},
)
assert response.status_code == 200 assert response.status_code == 200
# Verify the agent received the last user message # Verify the agent received the last user message
@@ -191,13 +219,17 @@ class TestChatCompletionsEndpoint:
def test_chat_completion_response_format(self, memory): def test_chat_completion_response_format(self, memory):
"""Should return OpenAI-compatible format.""" """Should return OpenAI-compatible format."""
from app import app from app import app
with patch("app.agent.step", return_value="Test response"): with patch("app.agent.step", return_value="Test response"):
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Test"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "Test"}],
},
)
data = response.json() data = response.json()
assert "id" in data assert "id" in data
+169 -120
View File
@@ -1,7 +1,7 @@
"""Edge case tests for FastAPI endpoints.""" """Edge case tests for FastAPI endpoints."""
import pytest
import json from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, MagicMock
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@@ -10,43 +10,46 @@ class TestChatCompletionsEdgeCases:
def test_very_long_message(self, memory): def test_very_long_message(self, memory):
"""Should handle very long user message.""" """Should handle very long user message."""
from app import app, agent from app import agent, app
# Patch the agent's LLM directly # Patch the agent's LLM directly
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
long_message = "x" * 100000 long_message = "x" * 100000
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": long_message}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": long_message}],
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_unicode_message(self, memory): def test_unicode_message(self, memory):
"""Should handle unicode in message.""" """Should handle unicode in message."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {
"role": "assistant", "role": "assistant",
"content": "日本語の応答" "content": "日本語の応答",
} }
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "日本語のメッセージ 🎬"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"] content = response.json()["choices"][0]["message"]["content"]
@@ -54,22 +57,22 @@ class TestChatCompletionsEdgeCases:
def test_special_characters_in_message(self, memory): def test_special_characters_in_message(self, memory):
"""Should handle special characters.""" """Should handle special characters."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
special_message = 'Test with "quotes" and \\backslash and \n newline' special_message = 'Test with "quotes" and \\backslash and \n newline'
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": special_message}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": special_message}],
},
)
assert response.status_code == 200 assert response.status_code == 200
@@ -81,12 +84,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": ""}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": ""}],
},
)
# Empty content should be rejected # Empty content should be rejected
assert response.status_code == 422 assert response.status_code == 422
@@ -98,12 +105,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": None}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": None}],
},
)
assert response.status_code == 422 assert response.status_code == 422
@@ -114,12 +125,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user"}], # No content json={
}) "model": "agent-media",
"messages": [{"role": "user"}], # No content
},
)
# May accept or reject depending on validation # May accept or reject depending on validation
assert response.status_code in [200, 400, 422] assert response.status_code in [200, 400, 422]
@@ -131,12 +146,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"content": "Hello"}], # No role json={
}) "model": "agent-media",
"messages": [{"content": "Hello"}], # No role
},
)
# Should reject or accept depending on validation # Should reject or accept depending on validation
assert response.status_code in [200, 400, 422] assert response.status_code in [200, 400, 422]
@@ -149,25 +168,26 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "invalid_role", "content": "Hello"}], json={
}) "model": "agent-media",
"messages": [{"role": "invalid_role", "content": "Hello"}],
},
)
# Should reject or ignore invalid role # Should reject or ignore invalid role
assert response.status_code in [200, 400, 422] assert response.status_code in [200, 400, 422]
def test_many_messages(self, memory): def test_many_messages(self, memory):
"""Should handle many messages in conversation.""" """Should handle many messages in conversation."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
@@ -178,10 +198,13 @@ class TestChatCompletionsEdgeCases:
messages.append({"role": "assistant", "content": f"Response {i}"}) messages.append({"role": "assistant", "content": f"Response {i}"})
messages.append({"role": "user", "content": "Final message"}) messages.append({"role": "user", "content": "Final message"})
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": messages, json={
}) "model": "agent-media",
"messages": messages,
},
)
assert response.status_code == 200 assert response.status_code == 200
@@ -192,15 +215,19 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [ json={
{"role": "system", "content": "You are helpful"}, "model": "agent-media",
{"role": "system", "content": "Be concise"}, "messages": [
], {"role": "system", "content": "You are helpful"},
}) {"role": "system", "content": "Be concise"},
],
},
)
assert response.status_code == 422 assert response.status_code == 422
@@ -211,14 +238,18 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [ json={
{"role": "assistant", "content": "Hello"}, "model": "agent-media",
], "messages": [
}) {"role": "assistant", "content": "Hello"},
],
},
)
assert response.status_code == 422 assert response.status_code == 422
@@ -229,12 +260,16 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": "not an array", json={
}) "model": "agent-media",
"messages": "not an array",
},
)
assert response.status_code == 422 assert response.status_code == 422
# Pydantic validation error # Pydantic validation error
@@ -246,66 +281,70 @@ class TestChatCompletionsEdgeCases:
mock_llm_class.return_value = mock_llm mock_llm_class.return_value = mock_llm
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": ["not an object", 123, None], json={
}) "model": "agent-media",
"messages": ["not an object", 123, None],
},
)
assert response.status_code == 422 assert response.status_code == 422
# Pydantic validation error # Pydantic validation error
def test_extra_fields_in_request(self, memory): def test_extra_fields_in_request(self, memory):
"""Should ignore extra fields in request.""" """Should ignore extra fields in request."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Hello"}], json={
"extra_field": "should be ignored", "model": "agent-media",
"temperature": 0.7, "messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100, "extra_field": "should be ignored",
}) "temperature": 0.7,
"max_tokens": 100,
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_streaming_with_tool_call(self, memory, real_folder): def test_streaming_with_tool_call(self, memory, real_folder):
"""Should handle streaming with tool execution.""" """Should handle streaming with tool execution."""
from app import app, agent from app import agent, app
from infrastructure.persistence import get_memory from infrastructure.persistence import get_memory
mem = get_memory() mem = get_memory()
mem.ltm.set_config("download_folder", str(real_folder["downloads"])) mem.ltm.set_config("download_folder", str(real_folder["downloads"]))
call_count = [0] call_count = [0]
def mock_complete(messages, tools=None): def mock_complete(messages, tools=None):
call_count[0] += 1 call_count[0] += 1
if call_count[0] == 1: if call_count[0] == 1:
return { return {
"role": "assistant", "role": "assistant",
"content": None, "content": None,
"tool_calls": [{ "tool_calls": [
"id": "call_1", {
"function": { "id": "call_1",
"name": "list_folder", "function": {
"arguments": '{"folder_type": "download"}' "name": "list_folder",
"arguments": '{"folder_type": "download"}',
},
} }
}] ],
} }
return { return {"role": "assistant", "content": "Listed the folder."}
"role": "assistant",
"content": "Listed the folder."
}
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete = Mock(side_effect=mock_complete) mock_llm.complete = Mock(side_effect=mock_complete)
@@ -313,51 +352,57 @@ class TestChatCompletionsEdgeCases:
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "List downloads"}], json={
"stream": True, "model": "agent-media",
}) "messages": [{"role": "user", "content": "List downloads"}],
"stream": True,
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_concurrent_requests_simulation(self, memory): def test_concurrent_requests_simulation(self, memory):
"""Should handle rapid sequential requests.""" """Should handle rapid sequential requests."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {"role": "assistant", "content": "Response"}
"role": "assistant",
"content": "Response"
}
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
for i in range(10): for i in range(10):
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": f"Request {i}"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": f"Request {i}"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
def test_llm_returns_json_in_response(self, memory): def test_llm_returns_json_in_response(self, memory):
"""Should handle LLM returning JSON in text response.""" """Should handle LLM returning JSON in text response."""
from app import app, agent from app import agent, app
mock_llm = Mock() mock_llm = Mock()
mock_llm.complete.return_value = { mock_llm.complete.return_value = {
"role": "assistant", "role": "assistant",
"content": '{"result": "some data", "count": 5}' "content": '{"result": "some data", "count": 5}',
} }
agent.llm = mock_llm agent.llm = mock_llm
client = TestClient(app) client = TestClient(app)
response = client.post("/v1/chat/completions", json={ response = client.post(
"model": "agent-media", "/v1/chat/completions",
"messages": [{"role": "user", "content": "Give me JSON"}], json={
}) "model": "agent-media",
"messages": [{"role": "user", "content": "Give me JSON"}],
},
)
assert response.status_code == 200 assert response.status_code == 200
content = response.json()["choices"][0]["message"]["content"] content = response.json()["choices"][0]["message"]["content"]
@@ -425,6 +470,7 @@ class TestMemoryEndpointsEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
# Clear multiple times # Clear multiple times
@@ -459,6 +505,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/health") response = client.get("/health")
@@ -471,6 +518,7 @@ class TestHealthEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/health?extra=param&another=value") response = client.get("/health?extra=param&another=value")
@@ -486,6 +534,7 @@ class TestModelsEndpointEdgeCases:
with patch("app.DeepSeekClient") as mock_llm: with patch("app.DeepSeekClient") as mock_llm:
mock_llm.return_value = Mock() mock_llm.return_value = Mock()
from app import app from app import app
client = TestClient(app) client = TestClient(app)
response = client.get("/v1/models") response = client.get("/v1/models")
+7 -13
View File
@@ -1,9 +1,9 @@
"""Critical tests for configuration validation.""" """Critical tests for configuration validation."""
import pytest
import os
from agent.config import Settings, ConfigurationError import pytest
from agent.config import ConfigurationError, Settings
class TestConfigValidation: class TestConfigValidation:
@@ -86,8 +86,7 @@ class TestConfigChecks:
def test_is_deepseek_configured_with_key(self): def test_is_deepseek_configured_with_key(self):
"""Verify is_deepseek_configured returns True with API key.""" """Verify is_deepseek_configured returns True with API key."""
settings = Settings( settings = Settings(
deepseek_api_key="test-key", deepseek_api_key="test-key", deepseek_base_url="https://api.test.com"
deepseek_base_url="https://api.test.com"
) )
assert settings.is_deepseek_configured() is True assert settings.is_deepseek_configured() is True
@@ -95,8 +94,7 @@ class TestConfigChecks:
def test_is_deepseek_configured_without_key(self): def test_is_deepseek_configured_without_key(self):
"""Verify is_deepseek_configured returns False without API key.""" """Verify is_deepseek_configured returns False without API key."""
settings = Settings( settings = Settings(
deepseek_api_key="", deepseek_api_key="", deepseek_base_url="https://api.test.com"
deepseek_base_url="https://api.test.com"
) )
assert settings.is_deepseek_configured() is False assert settings.is_deepseek_configured() is False
@@ -110,18 +108,14 @@ class TestConfigChecks:
def test_is_tmdb_configured_with_key(self): def test_is_tmdb_configured_with_key(self):
"""Verify is_tmdb_configured returns True with API key.""" """Verify is_tmdb_configured returns True with API key."""
settings = Settings( settings = Settings(
tmdb_api_key="test-key", tmdb_api_key="test-key", tmdb_base_url="https://api.test.com"
tmdb_base_url="https://api.test.com"
) )
assert settings.is_tmdb_configured() is True assert settings.is_tmdb_configured() is True
def test_is_tmdb_configured_without_key(self): def test_is_tmdb_configured_without_key(self):
"""Verify is_tmdb_configured returns False without API key.""" """Verify is_tmdb_configured returns False without API key."""
settings = Settings( settings = Settings(tmdb_api_key="", tmdb_base_url="https://api.test.com")
tmdb_api_key="",
tmdb_base_url="https://api.test.com"
)
assert settings.is_tmdb_configured() is False assert settings.is_tmdb_configured() is False
+21 -11
View File
@@ -1,12 +1,14 @@
"""Edge case tests for configuration and parameters.""" """Edge case tests for configuration and parameters."""
import pytest
import os import os
from unittest.mock import patch from unittest.mock import patch
from agent.config import Settings, ConfigurationError import pytest
from agent.config import ConfigurationError, Settings
from agent.parameters import ( from agent.parameters import (
ParameterSchema,
REQUIRED_PARAMETERS, REQUIRED_PARAMETERS,
ParameterSchema,
format_parameters_for_prompt, format_parameters_for_prompt,
get_missing_required_parameters, get_missing_required_parameters,
) )
@@ -110,19 +112,27 @@ class TestSettingsEdgeCases:
def test_http_url_accepted(self): def test_http_url_accepted(self):
"""Should accept http:// URLs.""" """Should accept http:// URLs."""
with patch.dict(os.environ, { with patch.dict(
"DEEPSEEK_BASE_URL": "http://localhost:8080", os.environ,
"TMDB_BASE_URL": "http://localhost:3000", {
}, clear=True): "DEEPSEEK_BASE_URL": "http://localhost:8080",
"TMDB_BASE_URL": "http://localhost:3000",
},
clear=True,
):
settings = Settings() settings = Settings()
assert settings.deepseek_base_url == "http://localhost:8080" assert settings.deepseek_base_url == "http://localhost:8080"
def test_https_url_accepted(self): def test_https_url_accepted(self):
"""Should accept https:// URLs.""" """Should accept https:// URLs."""
with patch.dict(os.environ, { with patch.dict(
"DEEPSEEK_BASE_URL": "https://api.example.com", os.environ,
"TMDB_BASE_URL": "https://api.example.com", {
}, clear=True): "DEEPSEEK_BASE_URL": "https://api.example.com",
"TMDB_BASE_URL": "https://api.example.com",
},
clear=True,
):
settings = Settings() settings = Settings()
assert settings.deepseek_base_url == "https://api.example.com" assert settings.deepseek_base_url == "https://api.example.com"
+19 -12
View File
@@ -1,18 +1,17 @@
"""Tests for the Memory system.""" """Tests for the Memory system."""
import pytest
import json
from datetime import datetime from datetime import datetime
from pathlib import Path
import pytest
from infrastructure.persistence import ( from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory, EpisodicMemory,
init_memory, LongTermMemory,
Memory,
ShortTermMemory,
get_memory, get_memory,
set_memory,
has_memory, has_memory,
init_memory,
) )
from infrastructure.persistence.context import _memory_ctx from infrastructure.persistence.context import _memory_ctx
@@ -23,11 +22,12 @@ def is_iso_format(s: str) -> bool:
return False return False
try: try:
# Attempt to parse the string as an ISO 8601 timestamp # Attempt to parse the string as an ISO 8601 timestamp
datetime.fromisoformat(s.replace('Z', '+00:00')) datetime.fromisoformat(s.replace("Z", "+00:00"))
return True return True
except (ValueError, TypeError): except (ValueError, TypeError):
return False return False
class TestLongTermMemory: class TestLongTermMemory:
"""Tests for LongTermMemory.""" """Tests for LongTermMemory."""
@@ -116,12 +116,18 @@ class TestLongTermMemory:
assert data["config"]["key"] == "value" assert data["config"]["key"] == "value"
def test_from_dict(self): def test_from_dict(self):
data = {"config": {"download_folder": "/downloads"}, "preferences": {"preferred_quality": "4K"}, "library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]}, "following": []} data = {
"config": {"download_folder": "/downloads"},
"preferences": {"preferred_quality": "4K"},
"library": {"movies": [{"imdb_id": "tt1", "title": "Test"}]},
"following": [],
}
ltm = LongTermMemory.from_dict(data) ltm = LongTermMemory.from_dict(data)
assert ltm.get_config("download_folder") == "/downloads" assert ltm.get_config("download_folder") == "/downloads"
assert ltm.preferences["preferred_quality"] == "4K" assert ltm.preferences["preferred_quality"] == "4K"
assert len(ltm.library["movies"]) == 1 assert len(ltm.library["movies"]) == 1
class TestShortTermMemory: class TestShortTermMemory:
"""Tests for ShortTermMemory.""" """Tests for ShortTermMemory."""
@@ -162,6 +168,7 @@ class TestShortTermMemory:
assert stm.conversation_history == [] assert stm.conversation_history == []
assert stm.language == "en" assert stm.language == "en"
class TestEpisodicMemory: class TestEpisodicMemory:
"""Tests for EpisodicMemory.""" """Tests for EpisodicMemory."""
@@ -192,6 +199,7 @@ class TestEpisodicMemory:
assert result is not None assert result is not None
assert result["name"] == "Result 2" assert result["name"] == "Result 2"
class TestMemory: class TestMemory:
"""Tests for the Memory manager.""" """Tests for the Memory manager."""
@@ -217,11 +225,10 @@ class TestMemory:
assert memory.stm.conversation_history == [] assert memory.stm.conversation_history == []
assert memory.episodic.recent_errors == [] assert memory.episodic.recent_errors == []
class TestMemoryContext: class TestMemoryContext:
"""Tests for memory context functions.""" """Tests for memory context functions."""
def test_get_memory_not_initialized(self): def test_get_memory_not_initialized(self):
_memory_ctx.set(None) _memory_ctx.set(None)
with pytest.raises(RuntimeError, match="Memory not initialized"): with pytest.raises(RuntimeError, match="Memory not initialized"):
+7 -9
View File
@@ -1,18 +1,17 @@
"""Edge case tests for the Memory system.""" """Edge case tests for the Memory system."""
import pytest
import json import json
import os import os
from pathlib import Path
from datetime import datetime import pytest
from unittest.mock import patch, mock_open
from infrastructure.persistence import ( from infrastructure.persistence import (
Memory,
LongTermMemory,
ShortTermMemory,
EpisodicMemory, EpisodicMemory,
init_memory, LongTermMemory,
Memory,
ShortTermMemory,
get_memory, get_memory,
init_memory,
set_memory, set_memory,
) )
from infrastructure.persistence.context import _memory_ctx from infrastructure.persistence.context import _memory_ctx
@@ -529,7 +528,6 @@ class TestMemoryContextEdgeCases:
def test_context_isolation(self, temp_dir): def test_context_isolation(self, temp_dir):
"""Context should be isolated per context.""" """Context should be isolated per context."""
import asyncio
from contextvars import copy_context from contextvars import copy_context
_memory_ctx.set(None) _memory_ctx.set(None)
+14 -16
View File
@@ -1,10 +1,8 @@
"""Critical tests for prompt builder - Tests that would have caught bugs.""" """Critical tests for prompt builder - Tests that would have caught bugs."""
import pytest
from agent.registry import make_tools
from agent.prompts import PromptBuilder from agent.prompts import PromptBuilder
from infrastructure.persistence import get_memory from agent.registry import make_tools
class TestPromptBuilderToolsInjection: class TestPromptBuilderToolsInjection:
@@ -18,7 +16,9 @@ class TestPromptBuilderToolsInjection:
# Verify each tool is mentioned # Verify each tool is mentioned
for tool_name in tools.keys(): for tool_name in tools.keys():
assert tool_name in prompt, f"Tool {tool_name} not mentioned in system prompt" assert (
tool_name in prompt
), f"Tool {tool_name} not mentioned in system prompt"
def test_tools_spec_contains_all_registered_tools(self): def test_tools_spec_contains_all_registered_tools(self):
"""CRITICAL: Verify build_tools_spec() returns all tools.""" """CRITICAL: Verify build_tools_spec() returns all tools."""
@@ -26,7 +26,7 @@ class TestPromptBuilderToolsInjection:
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs} spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys()) tool_names = set(tools.keys())
assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}" assert spec_names == tool_names, f"Missing tools: {tool_names - spec_names}"
@@ -46,12 +46,12 @@ class TestPromptBuilderToolsInjection:
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
for spec in specs: for spec in specs:
assert 'type' in spec assert "type" in spec
assert spec['type'] == 'function' assert spec["type"] == "function"
assert 'function' in spec assert "function" in spec
assert 'name' in spec['function'] assert "name" in spec["function"]
assert 'description' in spec['function'] assert "description" in spec["function"]
assert 'parameters' in spec['function'] assert "parameters" in spec["function"]
class TestPromptBuilderMemoryContext: class TestPromptBuilderMemoryContext:
@@ -92,11 +92,9 @@ class TestPromptBuilderMemoryContext:
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
memory.episodic.add_active_download({ memory.episodic.add_active_download(
"task_id": "123", {"task_id": "123", "name": "Test Movie", "progress": 50}
"name": "Test Movie", )
"progress": 50
})
prompt = builder.build_system_prompt() prompt = builder.build_system_prompt()
+25 -19
View File
@@ -1,10 +1,8 @@
"""Edge case tests for PromptBuilder.""" """Edge case tests for PromptBuilder."""
import pytest
import json
from agent.prompts import PromptBuilder from agent.prompts import PromptBuilder
from agent.registry import make_tools from agent.registry import make_tools
from infrastructure.persistence import get_memory
class TestPromptBuilderEdgeCases: class TestPromptBuilderEdgeCases:
@@ -93,11 +91,13 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_many_active_downloads(self, memory): def test_prompt_with_many_active_downloads(self, memory):
"""Should limit displayed active downloads.""" """Should limit displayed active downloads."""
for i in range(20): for i in range(20):
memory.episodic.add_active_download({ memory.episodic.add_active_download(
"task_id": str(i), {
"name": f"Download {i}", "task_id": str(i),
"progress": i * 5, "name": f"Download {i}",
}) "progress": i * 5,
}
)
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
@@ -136,12 +136,15 @@ class TestPromptBuilderEdgeCases:
def test_prompt_with_complex_workflow(self, memory): def test_prompt_with_complex_workflow(self, memory):
"""Should handle complex workflow state.""" """Should handle complex workflow state."""
memory.stm.start_workflow("download", { memory.stm.start_workflow(
"title": "Test Movie", "download",
"year": 2024, {
"quality": "1080p", "title": "Test Movie",
"nested": {"deep": {"value": "test"}}, "year": 2024,
}) "quality": "1080p",
"nested": {"deep": {"value": "test"}},
},
)
memory.stm.update_workflow_stage("searching_torrents") memory.stm.update_workflow_stage("searching_torrents")
tools = make_tools() tools = make_tools()
@@ -313,11 +316,14 @@ class TestFormatEpisodicContextEdgeCases:
def test_format_with_search_results_none_names(self, memory): def test_format_with_search_results_none_names(self, memory):
"""Should handle results with None names.""" """Should handle results with None names."""
memory.episodic.store_search_results("test", [ memory.episodic.store_search_results(
{"name": None}, "test",
{"title": None}, [
{}, {"name": None},
]) {"title": None},
{},
],
)
tools = make_tools() tools = make_tools()
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
+44 -32
View File
@@ -1,10 +1,11 @@
"""Critical tests for tool registry - Tests that would have caught bugs.""" """Critical tests for tool registry - Tests that would have caught bugs."""
import pytest
import inspect import inspect
from agent.registry import make_tools, _create_tool_from_function, Tool import pytest
from agent.prompts import PromptBuilder from agent.prompts import PromptBuilder
from agent.registry import Tool, _create_tool_from_function, make_tools
class TestToolSpecFormat: class TestToolSpecFormat:
@@ -22,22 +23,25 @@ class TestToolSpecFormat:
for spec in specs: for spec in specs:
# OpenAI format requires these fields # OpenAI format requires these fields
assert spec['type'] == 'function', f"Tool type must be 'function', got {spec.get('type')}" assert (
assert 'function' in spec, "Tool spec missing 'function' key" spec["type"] == "function"
), f"Tool type must be 'function', got {spec.get('type')}"
assert "function" in spec, "Tool spec missing 'function' key"
func = spec['function'] func = spec["function"]
assert 'name' in func, "Function missing 'name'" assert "name" in func, "Function missing 'name'"
assert 'description' in func, "Function missing 'description'" assert "description" in func, "Function missing 'description'"
assert 'parameters' in func, "Function missing 'parameters'" assert "parameters" in func, "Function missing 'parameters'"
params = func['parameters'] params = func["parameters"]
assert params['type'] == 'object', "Parameters type must be 'object'" assert params["type"] == "object", "Parameters type must be 'object'"
assert 'properties' in params, "Parameters missing 'properties'" assert "properties" in params, "Parameters missing 'properties'"
assert 'required' in params, "Parameters missing 'required'" assert "required" in params, "Parameters missing 'required'"
assert isinstance(params['required'], list), "Required must be a list" assert isinstance(params["required"], list), "Required must be a list"
def test_tool_parameters_match_function_signature(self): def test_tool_parameters_match_function_signature(self):
"""CRITICAL: Verify generated parameters match function signature.""" """CRITICAL: Verify generated parameters match function signature."""
def test_func(name: str, age: int, active: bool = True): def test_func(name: str, age: int, active: bool = True):
"""Test function with typed parameters.""" """Test function with typed parameters."""
return {"status": "ok"} return {"status": "ok"}
@@ -45,14 +49,16 @@ class TestToolSpecFormat:
tool = _create_tool_from_function(test_func) tool = _create_tool_from_function(test_func)
# Verify types are correctly mapped # Verify types are correctly mapped
assert tool.parameters['properties']['name']['type'] == 'string' assert tool.parameters["properties"]["name"]["type"] == "string"
assert tool.parameters['properties']['age']['type'] == 'integer' assert tool.parameters["properties"]["age"]["type"] == "integer"
assert tool.parameters['properties']['active']['type'] == 'boolean' assert tool.parameters["properties"]["active"]["type"] == "boolean"
# Verify required vs optional # Verify required vs optional
assert 'name' in tool.parameters['required'], "name should be required" assert "name" in tool.parameters["required"], "name should be required"
assert 'age' in tool.parameters['required'], "age should be required" assert "age" in tool.parameters["required"], "age should be required"
assert 'active' not in tool.parameters['required'], "active has default, should not be required" assert (
"active" not in tool.parameters["required"]
), "active has default, should not be required"
def test_all_registered_tools_are_callable(self): def test_all_registered_tools_are_callable(self):
"""CRITICAL: Verify all registered tools are actually callable.""" """CRITICAL: Verify all registered tools are actually callable."""
@@ -76,7 +82,7 @@ class TestToolSpecFormat:
builder = PromptBuilder(tools) builder = PromptBuilder(tools)
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
spec_names = {spec['function']['name'] for spec in specs} spec_names = {spec["function"]["name"] for spec in specs}
tool_names = set(tools.keys()) tool_names = set(tools.keys())
missing = tool_names - spec_names missing = tool_names - spec_names
@@ -88,6 +94,7 @@ class TestToolSpecFormat:
def test_tool_description_extracted_from_docstring(self): def test_tool_description_extracted_from_docstring(self):
"""Verify tool description is extracted from function docstring.""" """Verify tool description is extracted from function docstring."""
def test_func(param: str): def test_func(param: str):
"""This is the description. """This is the description.
@@ -102,6 +109,7 @@ class TestToolSpecFormat:
def test_tool_without_docstring_uses_function_name(self): def test_tool_without_docstring_uses_function_name(self):
"""Verify tool without docstring uses function name as description.""" """Verify tool without docstring uses function name as description."""
def test_func_no_doc(param: str): def test_func_no_doc(param: str):
return {} return {}
@@ -116,23 +124,25 @@ class TestToolSpecFormat:
specs = builder.build_tools_spec() specs = builder.build_tools_spec()
for spec in specs: for spec in specs:
params = spec['function']['parameters'] params = spec["function"]["parameters"]
properties = params.get('properties', {}) properties = params.get("properties", {})
for param_name, param_spec in properties.items(): for param_name, param_spec in properties.items():
assert 'description' in param_spec, \ assert (
f"Parameter {param_name} in {spec['function']['name']} missing description" "description" in param_spec
), f"Parameter {param_name} in {spec['function']['name']} missing description"
def test_required_parameters_are_marked_correctly(self): def test_required_parameters_are_marked_correctly(self):
"""Verify required parameters are correctly identified.""" """Verify required parameters are correctly identified."""
def func_with_optional(required: str, optional: int = 5): def func_with_optional(required: str, optional: int = 5):
return {} return {}
tool = _create_tool_from_function(func_with_optional) tool = _create_tool_from_function(func_with_optional)
assert 'required' in tool.parameters['required'] assert "required" in tool.parameters["required"]
assert 'optional' not in tool.parameters['required'] assert "optional" not in tool.parameters["required"]
assert len(tool.parameters['required']) == 1 assert len(tool.parameters["required"]) == 1
class TestToolRegistry: class TestToolRegistry:
@@ -195,6 +205,7 @@ class TestToolDataclass:
def test_tool_creation(self): def test_tool_creation(self):
"""Verify Tool can be created with all fields.""" """Verify Tool can be created with all fields."""
def dummy_func(): def dummy_func():
return {} return {}
@@ -202,7 +213,7 @@ class TestToolDataclass:
name="test_tool", name="test_tool",
description="Test description", description="Test description",
func=dummy_func, func=dummy_func,
parameters={"type": "object", "properties": {}, "required": []} parameters={"type": "object", "properties": {}, "required": []},
) )
assert tool.name == "test_tool" assert tool.name == "test_tool"
@@ -212,12 +223,13 @@ class TestToolDataclass:
def test_tool_parameters_structure(self): def test_tool_parameters_structure(self):
"""Verify Tool parameters have correct structure.""" """Verify Tool parameters have correct structure."""
def dummy_func(arg: str): def dummy_func(arg: str):
return {} return {}
tool = _create_tool_from_function(dummy_func) tool = _create_tool_from_function(dummy_func)
assert 'type' in tool.parameters assert "type" in tool.parameters
assert 'properties' in tool.parameters assert "properties" in tool.parameters
assert 'required' in tool.parameters assert "required" in tool.parameters
assert tool.parameters['type'] == 'object' assert tool.parameters["type"] == "object"
+8 -3
View File
@@ -1,6 +1,7 @@
"""Edge case tests for tool registry.""" """Edge case tests for tool registry."""
import pytest import pytest
from unittest.mock import Mock
from agent.registry import Tool, make_tools from agent.registry import Tool, make_tools
@@ -182,7 +183,9 @@ class TestMakeToolsEdgeCases:
params = tool.parameters params = tool.parameters
if "required" in params and "properties" in params: if "required" in params and "properties" in params:
for req in params["required"]: for req in params["required"]:
assert req in params["properties"], f"Required param {req} not in properties for {tool.name}" assert (
req in params["properties"]
), f"Required param {req} not in properties for {tool.name}"
def test_make_tools_descriptions_not_empty(self, memory): def test_make_tools_descriptions_not_empty(self, memory):
"""Should have non-empty descriptions.""" """Should have non-empty descriptions."""
@@ -233,7 +236,9 @@ class TestMakeToolsEdgeCases:
if "properties" in tool.parameters: if "properties" in tool.parameters:
for prop_name, prop_schema in tool.parameters["properties"].items(): for prop_name, prop_schema in tool.parameters["properties"].items():
if "type" in prop_schema: if "type" in prop_schema:
assert prop_schema["type"] in valid_types, f"Invalid type for {tool.name}.{prop_name}" assert (
prop_schema["type"] in valid_types
), f"Invalid type for {tool.name}.{prop_name}"
def test_make_tools_enum_values(self, memory): def test_make_tools_enum_values(self, memory):
"""Should have valid enum values.""" """Should have valid enum values."""
+49 -40
View File
@@ -1,19 +1,18 @@
"""Tests for JSON repositories.""" """Tests for JSON repositories."""
import pytest
from datetime import datetime
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonTVShowRepository,
JsonSubtitleRepository,
)
from domain.movies.entities import Movie from domain.movies.entities import Movie
from domain.movies.value_objects import MovieTitle, ReleaseYear, Quality from domain.movies.value_objects import MovieTitle, Quality, ReleaseYear
from domain.tv_shows.entities import TVShow from domain.shared.value_objects import FilePath, FileSize, ImdbId
from domain.tv_shows.value_objects import ShowStatus
from domain.subtitles.entities import Subtitle from domain.subtitles.entities import Subtitle
from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset from domain.subtitles.value_objects import Language, SubtitleFormat, TimingOffset
from domain.shared.value_objects import ImdbId, FilePath, FileSize from domain.tv_shows.entities import TVShow
from domain.tv_shows.value_objects import ShowStatus
from infrastructure.persistence.json import (
JsonMovieRepository,
JsonSubtitleRepository,
JsonTVShowRepository,
)
class TestJsonMovieRepository: class TestJsonMovieRepository:
@@ -224,7 +223,9 @@ class TestJsonTVShowRepository:
"""Should preserve show status.""" """Should preserve show status."""
repo = JsonTVShowRepository() repo = JsonTVShowRepository()
for i, status in enumerate([ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]): for i, status in enumerate(
[ShowStatus.ONGOING, ShowStatus.ENDED, ShowStatus.UNKNOWN]
):
show = TVShow( show = TVShow(
imdb_id=ImdbId(f"tt{i+1000000:07d}"), imdb_id=ImdbId(f"tt{i+1000000:07d}"),
title=f"Show {status.value}", title=f"Show {status.value}",
@@ -294,18 +295,22 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_language_filter(self, memory): def test_find_by_media_with_language_filter(self, memory):
"""Should filter by language.""" """Should filter by language."""
repo = JsonSubtitleRepository() repo = JsonSubtitleRepository()
repo.save(Subtitle( repo.save(
media_imdb_id=ImdbId("tt1375666"), Subtitle(
language=Language.ENGLISH, media_imdb_id=ImdbId("tt1375666"),
format=SubtitleFormat.SRT, language=Language.ENGLISH,
file_path=FilePath("/subs/en.srt"), format=SubtitleFormat.SRT,
)) file_path=FilePath("/subs/en.srt"),
repo.save(Subtitle( )
media_imdb_id=ImdbId("tt1375666"), )
language=Language.FRENCH, repo.save(
format=SubtitleFormat.SRT, Subtitle(
file_path=FilePath("/subs/fr.srt"), media_imdb_id=ImdbId("tt1375666"),
)) language=Language.FRENCH,
format=SubtitleFormat.SRT,
file_path=FilePath("/subs/fr.srt"),
)
)
results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH) results = repo.find_by_media(ImdbId("tt1375666"), language=Language.FRENCH)
@@ -315,22 +320,26 @@ class TestJsonSubtitleRepository:
def test_find_by_media_with_episode_filter(self, memory): def test_find_by_media_with_episode_filter(self, memory):
"""Should filter by season/episode.""" """Should filter by season/episode."""
repo = JsonSubtitleRepository() repo = JsonSubtitleRepository()
repo.save(Subtitle( repo.save(
media_imdb_id=ImdbId("tt0944947"), Subtitle(
language=Language.ENGLISH, media_imdb_id=ImdbId("tt0944947"),
format=SubtitleFormat.SRT, language=Language.ENGLISH,
file_path=FilePath("/subs/s01e01.srt"), format=SubtitleFormat.SRT,
season_number=1, file_path=FilePath("/subs/s01e01.srt"),
episode_number=1, season_number=1,
)) episode_number=1,
repo.save(Subtitle( )
media_imdb_id=ImdbId("tt0944947"), )
language=Language.ENGLISH, repo.save(
format=SubtitleFormat.SRT, Subtitle(
file_path=FilePath("/subs/s01e02.srt"), media_imdb_id=ImdbId("tt0944947"),
season_number=1, language=Language.ENGLISH,
episode_number=2, format=SubtitleFormat.SRT,
)) file_path=FilePath("/subs/s01e02.srt"),
season_number=1,
episode_number=2,
)
)
results = repo.find_by_media( results = repo.find_by_media(
ImdbId("tt0944947"), ImdbId("tt0944947"),
+113 -68
View File
@@ -21,21 +21,27 @@ def create_mock_response(status_code, json_data=None, text=None):
class TestFindMediaImdbId: class TestFindMediaImdbId:
"""Tests for find_media_imdb_id tool.""" """Tests for find_media_imdb_id tool."""
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_success(self, mock_get, memory): def test_success(self, mock_get, memory):
"""Should return movie info on success.""" """Should return movie info on success."""
# Mock HTTP responses # Mock HTTP responses
def mock_get_side_effect(url, **kwargs): def mock_get_side_effect(url, **kwargs):
if "search" in url: if "search" in url:
return create_mock_response(200, json_data={ return create_mock_response(
"results": [{ 200,
"id": 27205, json_data={
"title": "Inception", "results": [
"release_date": "2010-07-16", {
"overview": "A thief...", "id": 27205,
"media_type": "movie" "title": "Inception",
}] "release_date": "2010-07-16",
}) "overview": "A thief...",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url: elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"}) return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
@@ -50,19 +56,25 @@ class TestFindMediaImdbId:
# Verify HTTP calls # Verify HTTP calls
assert mock_get.call_count == 2 assert mock_get.call_count == 2
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_stores_in_stm(self, mock_get, memory): def test_stores_in_stm(self, mock_get, memory):
"""Should store result in STM on success.""" """Should store result in STM on success."""
def mock_get_side_effect(url, **kwargs): def mock_get_side_effect(url, **kwargs):
if "search" in url: if "search" in url:
return create_mock_response(200, json_data={ return create_mock_response(
"results": [{ 200,
"id": 27205, json_data={
"title": "Inception", "results": [
"release_date": "2010-07-16", {
"media_type": "movie" "id": 27205,
}] "title": "Inception",
}) "release_date": "2010-07-16",
"media_type": "movie",
}
]
},
)
elif "external_ids" in url: elif "external_ids" in url:
return create_mock_response(200, json_data={"imdb_id": "tt1375666"}) return create_mock_response(200, json_data={"imdb_id": "tt1375666"})
@@ -76,7 +88,7 @@ class TestFindMediaImdbId:
assert entity["title"] == "Inception" assert entity["title"] == "Inception"
assert mem.stm.current_topic == "searching_media" assert mem.stm.current_topic == "searching_media"
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_not_found(self, mock_get, memory): def test_not_found(self, mock_get, memory):
"""Should return error when not found.""" """Should return error when not found."""
mock_get.return_value = create_mock_response(200, json_data={"results": []}) mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -86,7 +98,7 @@ class TestFindMediaImdbId:
assert result["status"] == "error" assert result["status"] == "error"
assert result["error"] == "not_found" assert result["error"] == "not_found"
@patch('infrastructure.api.tmdb.client.requests.get') @patch("infrastructure.api.tmdb.client.requests.get")
def test_does_not_store_on_error(self, mock_get, memory): def test_does_not_store_on_error(self, mock_get, memory):
"""Should not store in STM on error.""" """Should not store in STM on error."""
mock_get.return_value = create_mock_response(200, json_data={"results": []}) mock_get.return_value = create_mock_response(200, json_data={"results": []})
@@ -100,27 +112,30 @@ class TestFindMediaImdbId:
class TestFindTorrent: class TestFindTorrent:
"""Tests for find_torrent tool.""" """Tests for find_torrent tool."""
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_success(self, mock_post, memory): def test_success(self, mock_post, memory):
"""Should return torrents on success.""" """Should return torrents on success."""
mock_post.return_value = create_mock_response(200, json_data={ mock_post.return_value = create_mock_response(
"hits": [ 200,
{ json_data={
"title": "Torrent 1", "hits": [
"seeders": 100, {
"leechers": 10, "title": "Torrent 1",
"magnetUrl": "magnet:?xt=...", "seeders": 100,
"size": "2.5 GB" "leechers": 10,
}, "magnetUrl": "magnet:?xt=...",
{ "size": "2.5 GB",
"title": "Torrent 2", },
"seeders": 50, {
"leechers": 5, "title": "Torrent 2",
"magnetUrl": "magnet:?xt=...", "seeders": 50,
"size": "1.8 GB" "leechers": 5,
} "magnetUrl": "magnet:?xt=...",
] "size": "1.8 GB",
}) },
]
},
)
result = api_tools.find_torrent("Inception 1080p") result = api_tools.find_torrent("Inception 1080p")
@@ -128,21 +143,26 @@ class TestFindTorrent:
assert len(result["torrents"]) == 2 assert len(result["torrents"]) == 2
# Verify HTTP payload # Verify HTTP payload
payload = mock_post.call_args[1]['json'] payload = mock_post.call_args[1]["json"]
assert payload['query'] == "Inception 1080p" assert payload["query"] == "Inception 1080p"
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_stores_in_episodic(self, mock_post, memory): def test_stores_in_episodic(self, mock_post, memory):
"""Should store results in episodic memory.""" """Should store results in episodic memory."""
mock_post.return_value = create_mock_response(200, json_data={ mock_post.return_value = create_mock_response(
"hits": [{ 200,
"title": "Torrent 1", json_data={
"seeders": 100, "hits": [
"leechers": 10, {
"magnetUrl": "magnet:?xt=...", "title": "Torrent 1",
"size": "2.5 GB" "seeders": 100,
}] "leechers": 10,
}) "magnetUrl": "magnet:?xt=...",
"size": "2.5 GB",
}
]
},
)
api_tools.find_torrent("Inception") api_tools.find_torrent("Inception")
@@ -151,16 +171,37 @@ class TestFindTorrent:
assert mem.episodic.last_search_results["query"] == "Inception" assert mem.episodic.last_search_results["query"] == "Inception"
assert mem.stm.current_topic == "selecting_torrent" assert mem.stm.current_topic == "selecting_torrent"
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_results_have_indexes(self, mock_post, memory): def test_results_have_indexes(self, mock_post, memory):
"""Should add indexes to results.""" """Should add indexes to results."""
mock_post.return_value = create_mock_response(200, json_data={ mock_post.return_value = create_mock_response(
"hits": [ 200,
{"title": "Torrent 1", "seeders": 100, "leechers": 10, "magnetUrl": "magnet:?xt=1", "size": "1GB"}, json_data={
{"title": "Torrent 2", "seeders": 50, "leechers": 5, "magnetUrl": "magnet:?xt=2", "size": "2GB"}, "hits": [
{"title": "Torrent 3", "seeders": 25, "leechers": 2, "magnetUrl": "magnet:?xt=3", "size": "3GB"} {
] "title": "Torrent 1",
}) "seeders": 100,
"leechers": 10,
"magnetUrl": "magnet:?xt=1",
"size": "1GB",
},
{
"title": "Torrent 2",
"seeders": 50,
"leechers": 5,
"magnetUrl": "magnet:?xt=2",
"size": "2GB",
},
{
"title": "Torrent 3",
"seeders": 25,
"leechers": 2,
"magnetUrl": "magnet:?xt=3",
"size": "3GB",
},
]
},
)
api_tools.find_torrent("Test") api_tools.find_torrent("Test")
@@ -170,7 +211,7 @@ class TestFindTorrent:
assert results[1]["index"] == 2 assert results[1]["index"] == 2
assert results[2]["index"] == 3 assert results[2]["index"] == 3
@patch('infrastructure.api.knaben.client.requests.post') @patch("infrastructure.api.knaben.client.requests.post")
def test_not_found(self, mock_post, memory): def test_not_found(self, mock_post, memory):
"""Should return error when no torrents found.""" """Should return error when no torrents found."""
mock_post.return_value = create_mock_response(200, json_data={"hits": []}) mock_post.return_value = create_mock_response(200, json_data={"hits": []})
@@ -245,7 +286,7 @@ class TestAddTorrentToQbittorrent:
This is acceptable mocking because we're testing the TOOL logic, not the client. This is acceptable mocking because we're testing the TOOL logic, not the client.
""" """
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory): def test_success(self, mock_client, memory):
"""Should add torrent successfully and update memory.""" """Should add torrent successfully and update memory."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -257,7 +298,7 @@ class TestAddTorrentToQbittorrent:
# Verify client was called correctly # Verify client was called correctly
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123") mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_adds_to_active_downloads(self, mock_client, memory_with_search_results): def test_adds_to_active_downloads(self, mock_client, memory_with_search_results):
"""Should add to active downloads on success.""" """Should add to active downloads on success."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -267,9 +308,12 @@ class TestAddTorrentToQbittorrent:
# Test memory update logic # Test memory update logic
mem = get_memory() mem = get_memory()
assert len(mem.episodic.active_downloads) == 1 assert len(mem.episodic.active_downloads) == 1
assert mem.episodic.active_downloads[0]["name"] == "Inception.2010.1080p.BluRay.x264" assert (
mem.episodic.active_downloads[0]["name"]
== "Inception.2010.1080p.BluRay.x264"
)
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_sets_topic_and_ends_workflow(self, mock_client, memory): def test_sets_topic_and_ends_workflow(self, mock_client, memory):
"""Should set topic and end workflow.""" """Should set topic and end workflow."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -282,10 +326,11 @@ class TestAddTorrentToQbittorrent:
assert mem.stm.current_topic == "downloading" assert mem.stm.current_topic == "downloading"
assert mem.stm.current_workflow is None assert mem.stm.current_workflow is None
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_error_handling(self, mock_client, memory): def test_error_handling(self, mock_client, memory):
"""Should handle client errors correctly.""" """Should handle client errors correctly."""
from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError from infrastructure.api.qbittorrent.exceptions import QBittorrentAPIError
mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed") mock_client.add_torrent.side_effect = QBittorrentAPIError("Connection failed")
result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...") result = api_tools.add_torrent_to_qbittorrent("magnet:?xt=...")
@@ -304,7 +349,7 @@ class TestAddTorrentByIndex:
- Error handling for edge cases - Error handling for edge cases
""" """
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_success(self, mock_client, memory_with_search_results): def test_success(self, mock_client, memory_with_search_results):
"""Should get torrent by index and add it.""" """Should get torrent by index and add it."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
@@ -317,7 +362,7 @@ class TestAddTorrentByIndex:
# Verify correct magnet was extracted and used # Verify correct magnet was extracted and used
mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123") mock_client.add_torrent.assert_called_once_with("magnet:?xt=urn:btih:abc123")
@patch('agent.tools.api.qbittorrent_client') @patch("agent.tools.api.qbittorrent_client")
def test_uses_correct_magnet(self, mock_client, memory_with_search_results): def test_uses_correct_magnet(self, mock_client, memory_with_search_results):
"""Should extract correct magnet from index.""" """Should extract correct magnet from index."""
mock_client.add_torrent.return_value = True mock_client.add_torrent.return_value = True
+30 -8
View File
@@ -1,7 +1,8 @@
"""Edge case tests for tools.""" """Edge case tests for tools."""
from unittest.mock import Mock, patch
import pytest import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from agent.tools import api as api_tools from agent.tools import api as api_tools
from agent.tools import filesystem as fs_tools from agent.tools import filesystem as fs_tools
@@ -15,7 +16,10 @@ class TestFindTorrentEdgeCases:
def test_empty_query(self, mock_use_case_class, memory): def test_empty_query(self, mock_use_case_class, memory):
"""Should handle empty query.""" """Should handle empty query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "invalid_query"} mock_response.to_dict.return_value = {
"status": "error",
"error": "invalid_query",
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -28,7 +32,11 @@ class TestFindTorrentEdgeCases:
def test_very_long_query(self, mock_use_case_class, memory): def test_very_long_query(self, mock_use_case_class, memory):
"""Should handle very long query.""" """Should handle very long query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -43,7 +51,11 @@ class TestFindTorrentEdgeCases:
def test_special_characters_in_query(self, mock_use_case_class, memory): def test_special_characters_in_query(self, mock_use_case_class, memory):
"""Should handle special characters in query.""" """Should handle special characters in query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -57,7 +69,11 @@ class TestFindTorrentEdgeCases:
def test_unicode_query(self, mock_use_case_class, memory): def test_unicode_query(self, mock_use_case_class, memory):
"""Should handle unicode in query.""" """Should handle unicode in query."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "ok", "torrents": [], "count": 0} mock_response.to_dict.return_value = {
"status": "ok",
"torrents": [],
"count": 0,
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -161,7 +177,10 @@ class TestAddTorrentEdgeCases:
def test_empty_magnet_link(self, mock_use_case_class, memory): def test_empty_magnet_link(self, mock_use_case_class, memory):
"""Should handle empty magnet link.""" """Should handle empty magnet link."""
mock_response = Mock() mock_response = Mock()
mock_response.to_dict.return_value = {"status": "error", "error": "empty_magnet"} mock_response.to_dict.return_value = {
"status": "error",
"error": "empty_magnet",
}
mock_use_case = Mock() mock_use_case = Mock()
mock_use_case.execute.return_value = mock_response mock_use_case.execute.return_value = mock_response
mock_use_case_class.return_value = mock_use_case mock_use_case_class.return_value = mock_use_case
@@ -326,7 +345,10 @@ class TestFilesystemEdgeCases:
for attempt in attempts: for attempt in attempts:
result = fs_tools.list_folder("download", attempt) result = fs_tools.list_folder("download", attempt)
# Should either be forbidden or not found # Should either be forbidden or not found
assert result.get("error") in ["forbidden", "not_found", None] or result.get("status") == "ok" assert (
result.get("error") in ["forbidden", "not_found", None]
or result.get("status") == "ok"
)
def test_path_with_null_byte(self, memory, real_folder): def test_path_with_null_byte(self, memory, real_folder):
"""Should block null byte injection.""" """Should block null byte injection."""