Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(Views): multiple joins #1611

Merged
merged 3 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pandasai/query_builders/view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,23 @@ def _get_table_expression(self) -> str:

query = select(*columns).from_(first_query)

# Group relations by target dataset to combine multiple join conditions
join_conditions = {}
for relation in relations:
to_datasets = relation.to.split(".")[0]
if to_datasets not in join_conditions:
join_conditions[to_datasets] = []
join_conditions[to_datasets].append(
f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}"
)

# Create joins with combined conditions
for to_datasets, conditions in join_conditions.items():
loader = self.schema_dependencies_dict[to_datasets]
subquery = self._get_sub_query_from_loader(loader)
query = query.join(
subquery,
on=f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}",
on=" AND ".join(conditions),
append=True,
)
alias = normalize_identifiers(self.schema.name).sql()
Expand Down
123 changes: 123 additions & 0 deletions tests/unit_tests/query_builders/test_view_query_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from unittest.mock import MagicMock

import pytest

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.data_loader.sql_loader import SQLDatasetLoader
from pandasai.query_builders.sql_query_builder import SqlQueryBuilder
from pandasai.query_builders.view_query_builder import ViewQueryBuilder


Expand All @@ -8,6 +13,29 @@ class TestViewQueryBuilder:
def view_query_builder(self, mysql_view_schema, mysql_view_dependencies_dict):
    """Return a ViewQueryBuilder built from the given view schema and dependencies."""
    builder = ViewQueryBuilder(mysql_view_schema, mysql_view_dependencies_dict)
    return builder

def _create_mock_loader(self, table_name):
    """Build a mocked SQL dataset loader whose schema points at ``table_name``.

    The returned mock is spec'd on ``SQLDatasetLoader`` and carries a real
    ``SemanticLayerSchema`` (MySQL source) plus a ``SqlQueryBuilder`` built
    from that same schema.
    """
    connection_settings = {
        "host": "localhost",
        "port": 3306,
        "database": "test_db",
        "user": "test_user",
        "password": "test_password",
    }
    table_schema = SemanticLayerSchema(
        name=table_name,
        source={
            "type": "mysql",
            "connection": connection_settings,
            "table": table_name,
        },
    )

    loader = MagicMock(spec=SQLDatasetLoader)
    loader.schema = table_schema
    loader.query_builder = SqlQueryBuilder(schema=table_schema)
    return loader

def test__init__(self, mysql_view_schema, mysql_view_dependencies_dict):
query_builder = ViewQueryBuilder(
mysql_view_schema, mysql_view_dependencies_dict
Expand Down Expand Up @@ -212,6 +240,101 @@ def test_table_name_comment_injection(self, view_query_builder):
) AS users"""
)

def test_multiple_joins_same_table(self):
"""Check that several relations to one target dataset yield a single JOIN.

Both relations in the schema point at ``heart``, so the expected SQL
contains exactly one JOIN on ``heart`` whose ON clause combines the two
conditions with AND.
"""
# View schema with two relations that both target the "heart" dataset.
schema_dict = {
"name": "health_combined",
"columns": [
{"name": "diabetes.age"},
{"name": "diabetes.bloodpressure"},
{"name": "heart.age"},
{"name": "heart.restingbp"},
],
"relations": [
{"from": "diabetes.age", "to": "heart.age"},
{"from": "diabetes.bloodpressure", "to": "heart.restingbp"},
],
"view": "true",
}
schema = SemanticLayerSchema(**schema_dict)
# One mock loader per dataset referenced by the view.
dependencies = {
"diabetes": self._create_mock_loader("diabetes"),
"heart": self._create_mock_loader("heart"),
}
query_builder = ViewQueryBuilder(schema, dependencies)

# NOTE(review): the whitespace inside the expected SQL literal looks like it
# was flattened in this copy of the file -- confirm against the builder's
# actual output before relying on the exact string.
assert (
query_builder._get_table_expression()
== """(
SELECT
diabetes.age AS diabetes_age,
diabetes.bloodpressure AS diabetes_bloodpressure,
heart.age AS heart_age,
heart.restingbp AS heart_restingbp
FROM (
SELECT
*
FROM diabetes
) AS diabetes
JOIN (
SELECT
*
FROM heart
) AS heart
ON diabetes.age = heart.age AND diabetes.bloodpressure = heart.restingbp
) AS health_combined"""
)

# NOTE(review): mysql_view_dependencies_dict is requested here but never used
# in the body; the test builds its own ``dependencies`` mapping instead.
def test_three_table_join(self, mysql_view_dependencies_dict):
"""Check that relations to two different target datasets yield two JOINs.

``patients`` is the FROM source; ``diabetes`` and ``heart`` each get
their own JOIN with a single ON condition in the expected SQL.
"""
# View schema whose two relations target different datasets.
schema_dict = {
"name": "patient_records",
"columns": [
{"name": "patients.id"},
{"name": "diabetes.glucose"},
{"name": "heart.cholesterol"},
],
"relations": [
{"from": "patients.id", "to": "diabetes.patient_id"},
{"from": "patients.id", "to": "heart.patient_id"},
],
"view": "true",
}
schema = SemanticLayerSchema(**schema_dict)
# One mock loader per dataset referenced by the view.
dependencies = {
"patients": self._create_mock_loader("patients"),
"diabetes": self._create_mock_loader("diabetes"),
"heart": self._create_mock_loader("heart"),
}
query_builder = ViewQueryBuilder(schema, dependencies)

# NOTE(review): the whitespace inside the expected SQL literal looks like it
# was flattened in this copy of the file -- confirm against the builder's
# actual output before relying on the exact string.
assert (
query_builder._get_table_expression()
== """(
SELECT
patients.id AS patients_id,
diabetes.glucose AS diabetes_glucose,
heart.cholesterol AS heart_cholesterol
FROM (
SELECT
*
FROM patients
) AS patients
JOIN (
SELECT
*
FROM diabetes
) AS diabetes
ON patients.id = diabetes.patient_id
JOIN (
SELECT
*
FROM heart
) AS heart
ON patients.id = heart.patient_id
) AS patient_records"""
)

def test_column_name_comment_injection(self, view_query_builder):
view_query_builder.schema.columns[0].name = "column --"
query = view_query_builder.build_query()
Expand Down
Loading