Skip to content

Commit

Permalink
test: add more tests in the agent
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Jan 31, 2025
1 parent 4ca228f commit e128221
Showing 1 changed file with 107 additions and 2 deletions.
109 changes: 107 additions & 2 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
from typing import Optional
from unittest.mock import MagicMock, Mock, mock_open, patch
from unittest.mock import ANY, MagicMock, Mock, mock_open, patch

import pandas as pd
import pytest

from pandasai import DatasetLoader, VirtualDataFrame
from pandasai.agent.base import Agent
from pandasai.config import Config, ConfigManager
from pandasai.core.response.error import ErrorResponse
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.exceptions import CodeExecutionError, InvalidLLMOutputType
from pandasai.llm.fake import FakeLLM


Expand Down Expand Up @@ -466,3 +467,107 @@ def test_execute_sql_query_error_no_dataframe(self, agent):

with pytest.raises(ValueError, match="No DataFrames available"):
agent._execute_sql_query(query)

def test_process_query(self, agent, config):
"""Test the _process_query method with successful execution"""
query = "What is the average age?"
output_type = "number"

# Mock the necessary methods
agent.generate_code = Mock(return_value="result = df['age'].mean()")
agent.execute_with_retries = Mock(return_value=30.5)
agent._state.config.enable_cache = True
agent._state.cache = Mock()

# Execute the query
result = agent._process_query(query, output_type)

# Verify the result
assert result == 30.5

# Verify method calls
agent.generate_code.assert_called_once()
agent.execute_with_retries.assert_called_once_with("result = df['age'].mean()")
agent._state.cache.set.assert_called_once()

def test_process_query_execution_error(self, agent, config):
"""Test the _process_query method with execution error"""
query = "What is the invalid operation?"

# Mock methods to simulate error
agent.generate_code = Mock(return_value="invalid_code")
agent.execute_with_retries = Mock(
side_effect=CodeExecutionError("Execution failed")
)
agent._handle_exception = Mock(return_value="Error handled")

# Execute the query
result = agent._process_query(query)

# Verify error handling
assert result == "Error handled"
agent._handle_exception.assert_called_once_with("invalid_code")

def test_regenerate_code_after_invalid_llm_output_error(self, agent):
"""Test code regeneration with InvalidLLMOutputType error"""
from pandasai.exceptions import InvalidLLMOutputType

code = "test code"
error = InvalidLLMOutputType("Invalid output type")

with patch(
"pandasai.agent.base.get_correct_output_type_error_prompt"
) as mock_prompt:
mock_prompt.return_value = "corrected prompt"
agent._code_generator.generate_code = MagicMock(return_value="new code")

result = agent._regenerate_code_after_error(code, error)

mock_prompt.assert_called_once_with(agent._state, code, ANY)
agent._code_generator.generate_code.assert_called_once_with(
"corrected prompt"
)
assert result == "new code"

def test_regenerate_code_after_other_error(self, agent):
"""Test code regeneration with non-InvalidLLMOutputType error"""
code = "test code"
error = ValueError("Some other error")

with patch(
"pandasai.agent.base.get_correct_error_prompt_for_sql"
) as mock_prompt:
mock_prompt.return_value = "sql error prompt"
agent._code_generator.generate_code = MagicMock(return_value="new code")

result = agent._regenerate_code_after_error(code, error)

mock_prompt.assert_called_once_with(agent._state, code, ANY)
agent._code_generator.generate_code.assert_called_once_with(
"sql error prompt"
)
assert result == "new code"

def test_handle_exception(self, agent):
"""Test that _handle_exception properly formats and logs exceptions"""
test_code = "print(1/0)" # Code that will raise a ZeroDivisionError

# Mock the logger to verify it's called
mock_logger = MagicMock()
agent._state.logger = mock_logger

# Create an actual exception to handle
try:
exec(test_code)
except:
# Call the method
result = agent._handle_exception(test_code)

# Verify the result is an ErrorResponse
assert isinstance(result, ErrorResponse)
assert result.last_code_executed == test_code
assert "ZeroDivisionError" in result.error

# Verify the error was logged
mock_logger.log.assert_called_once()
assert "Processing failed with error" in mock_logger.log.call_args[0][0]

0 comments on commit e128221

Please sign in to comment.