Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add more tests in the agent #1572

Merged
merged 6 commits into from
Jan 31, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 107 additions & 2 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
from typing import Optional
from unittest.mock import MagicMock, Mock, mock_open, patch
from unittest.mock import ANY, MagicMock, Mock, mock_open, patch

import pandas as pd
import pytest

from pandasai import DatasetLoader, VirtualDataFrame
from pandasai.agent.base import Agent
from pandasai.config import Config, ConfigManager
from pandasai.core.response.error import ErrorResponse
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.exceptions import CodeExecutionError, InvalidLLMOutputType
from pandasai.llm.fake import FakeLLM


Expand Down Expand Up @@ -466,3 +467,107 @@ def test_execute_sql_query_error_no_dataframe(self, agent):

with pytest.raises(ValueError, match="No DataFrames available"):
agent._execute_sql_query(query)

def test_process_query(self, agent, config):
"""Test the _process_query method with successful execution"""
query = "What is the average age?"
output_type = "number"

# Mock the necessary methods
agent.generate_code = Mock(return_value="result = df['age'].mean()")
agent.execute_with_retries = Mock(return_value=30.5)
agent._state.config.enable_cache = True
agent._state.cache = Mock()

# Execute the query
result = agent._process_query(query, output_type)

# Verify the result
assert result == 30.5

# Verify method calls
agent.generate_code.assert_called_once()
agent.execute_with_retries.assert_called_once_with("result = df['age'].mean()")
agent._state.cache.set.assert_called_once()

def test_process_query_execution_error(self, agent, config):
"""Test the _process_query method with execution error"""
query = "What is the invalid operation?"

# Mock methods to simulate error
agent.generate_code = Mock(return_value="invalid_code")
agent.execute_with_retries = Mock(
side_effect=CodeExecutionError("Execution failed")
)
agent._handle_exception = Mock(return_value="Error handled")

# Execute the query
result = agent._process_query(query)

# Verify error handling
assert result == "Error handled"
agent._handle_exception.assert_called_once_with("invalid_code")

def test_regenerate_code_after_invalid_llm_output_error(self, agent):
"""Test code regeneration with InvalidLLMOutputType error"""
from pandasai.exceptions import InvalidLLMOutputType

code = "test code"
error = InvalidLLMOutputType("Invalid output type")

with patch(
"pandasai.agent.base.get_correct_output_type_error_prompt"
) as mock_prompt:
mock_prompt.return_value = "corrected prompt"
agent._code_generator.generate_code = MagicMock(return_value="new code")

result = agent._regenerate_code_after_error(code, error)

mock_prompt.assert_called_once_with(agent._state, code, ANY)
agent._code_generator.generate_code.assert_called_once_with(
"corrected prompt"
)
assert result == "new code"

def test_regenerate_code_after_other_error(self, agent):
"""Test code regeneration with non-InvalidLLMOutputType error"""
code = "test code"
error = ValueError("Some other error")

with patch(
"pandasai.agent.base.get_correct_error_prompt_for_sql"
) as mock_prompt:
mock_prompt.return_value = "sql error prompt"
agent._code_generator.generate_code = MagicMock(return_value="new code")

result = agent._regenerate_code_after_error(code, error)

mock_prompt.assert_called_once_with(agent._state, code, ANY)
agent._code_generator.generate_code.assert_called_once_with(
"sql error prompt"
)
assert result == "new code"

def test_handle_exception(self, agent):
"""Test that _handle_exception properly formats and logs exceptions"""
test_code = "print(1/0)" # Code that will raise a ZeroDivisionError

# Mock the logger to verify it's called
mock_logger = MagicMock()
agent._state.logger = mock_logger

# Create an actual exception to handle
try:
exec(test_code)
except:
# Call the method
result = agent._handle_exception(test_code)

# Verify the result is an ErrorResponse
assert isinstance(result, ErrorResponse)
assert result.last_code_executed == test_code
assert "ZeroDivisionError" in result.error

# Verify the error was logged
mock_logger.log.assert_called_once()
assert "Processing failed with error" in mock_logger.log.call_args[0][0]
Loading