fix: fixing datasets naming convention
scaliseraoul committed Feb 18, 2025
1 parent ef8388f commit 614a9ca
Showing 10 changed files with 79 additions and 24 deletions.
13 changes: 9 additions & 4 deletions pandasai/__init__.py
@@ -19,7 +19,11 @@
Source,
)
from pandasai.exceptions import DatasetNotFound, InvalidConfigError, PandaAIApiKeyError
from pandasai.helpers.path import find_project_root, get_validated_dataset_path
from pandasai.helpers.path import (
find_project_root,
get_validated_dataset_path,
transform_dash_to_underscore,
)
from pandasai.helpers.session import get_pandaai_session
from pandasai.query_builders import SqlQueryBuilder
from pandasai.sandbox.sandbox import Sandbox
@@ -119,6 +123,7 @@ def create(
raise ValueError("df must be a PandaAI DataFrame")

org_name, dataset_name = get_validated_dataset_path(path)
underscore_dataset_name = transform_dash_to_underscore(dataset_name)
dataset_directory = str(os.path.join(org_name, dataset_name))

schema_path = os.path.join(dataset_directory, "schema.yaml")
@@ -140,7 +145,7 @@

if df is not None:
schema = df.schema
schema.name = dataset_name
schema.name = underscore_dataset_name
if (
parsed_columns
): # if no columns are passed it automatically parse the columns from the df
@@ -153,15 +158,15 @@
elif view:
_relation = [Relation(**relation) for relation in relations or ()]
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name,
name=underscore_dataset_name,
relations=_relation,
view=True,
columns=parsed_columns,
group_by=group_by,
)
elif source.get("table"):
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name,
name=underscore_dataset_name,
source=Source(**source),
columns=parsed_columns,
group_by=group_by,
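A minimal, self-contained sketch of the naming flow that create() follows after this change: the dataset path and directory keep the dashed form, while the schema name written to schema.yaml is converted to underscores. The helper body is copied from the pandasai/helpers/path.py hunk further down; the example path is invented for illustration.

def transform_dash_to_underscore(value: str) -> str:
    # Same one-liner as in pandasai/helpers/path.py: "heart-data" -> "heart_data"
    return value.replace("-", "_")

# Hypothetical dataset path following the "org-name/dataset-name" convention
org_name, dataset_name = "my-org/heart-data".split("/")

dataset_directory = f"{org_name}/{dataset_name}"          # stays dashed: "my-org/heart-data"
schema_name = transform_dash_to_underscore(dataset_name)  # stored in the schema: "heart_data"

print(dataset_directory, schema_name)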
7 changes: 5 additions & 2 deletions pandasai/data_loader/loader.py
@@ -6,7 +6,10 @@

from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import MethodNotImplementedError
from pandasai.helpers.path import get_validated_dataset_path
from pandasai.helpers.path import (
get_validated_dataset_path,
transform_underscore_to_dash,
)
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

from .. import ConfigManager
@@ -60,6 +63,7 @@ def create_loader_from_path(cls, dataset_path: str) -> "DatasetLoader":
"""
Factory method to create the appropriate loader based on the dataset type.
"""
dataset_path = transform_underscore_to_dash(dataset_path)
schema = cls._read_schema_file(dataset_path)
return DatasetLoader.create_loader_from_schema(schema, dataset_path)

@@ -74,7 +78,6 @@ def _read_schema_file(dataset_path: str) -> SemanticLayerSchema:

schema_file = file_manager.load(schema_path)
raw_schema = yaml.safe_load(schema_file)
raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
return SemanticLayerSchema(**raw_schema)

def load(self) -> DataFrame:
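For the loading side, a short sketch of the inverse conversion applied by create_loader_from_path: a dataset path written with underscores is normalized back to the dashed on-disk convention before the schema file is resolved. The helper body mirrors pandasai/helpers/path.py; the directory layout shown is only an assumption for illustration.

def transform_underscore_to_dash(value: str) -> str:
    # Mirrors pandasai/helpers/path.py: "my_org/heart_data" -> "my-org/heart-data"
    return value.replace("_", "-")

dataset_path = transform_underscore_to_dash("my_org/heart_data")
schema_path = f"datasets/{dataset_path}/schema.yaml"  # hypothetical layout, for illustration only
print(dataset_path, schema_path)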
20 changes: 16 additions & 4 deletions pandasai/data_loader/semantic_layer_schema.py
@@ -17,6 +17,9 @@
VALID_COLUMN_TYPES,
VALID_TRANSFORMATION_TYPES,
)
from pandasai.helpers.path import (
validate_underscore_name_format,
)


class SQLConnectionConfig(BaseModel):
@@ -299,10 +302,17 @@ class SemanticLayerSchema(BaseModel):

@model_validator(mode="after")
def validate_schema(self) -> "SemanticLayerSchema":
self._validate_name()
self._validate_group_by_columns()
self._validate_columns_relations()
return self

def _validate_name(self) -> None:
if not self.name or not validate_underscore_name_format(self.name):
raise ValueError(
"Dataset name must be lowercase and use underscores instead of spaces. E.g. 'dataset_name'."
)

def _validate_group_by_columns(self) -> None:
if not self.group_by or not self.columns:
return
@@ -321,7 +331,7 @@ def _validate_group_by_columns(self) -> None:
)

def _validate_columns_relations(self):
column_re_check = r"^[a-zA-Z0-9-]+\.[a-zA-Z0-9_]+$"
column_re_check = r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$"
is_view_column_name = partial(re.match, column_re_check)

# unpack columns info
@@ -360,15 +370,15 @@ def _validate_columns_relations(self):
is_view_column_name(column_name) for column_name in _column_names
):
raise ValueError(
"All columns in a view must be in the format '[dataset].[column]'."
"All columns in a view must be in the format '[dataset_name].[column_name]' accepting only letters, numbers, and underscores."
)

if not all(
is_view_column_name(column_name)
for column_name in _column_names_in_relations
):
raise ValueError(
"All params 'from' and 'to' in the relations must be in the format '[dataset].[column]'."
"All params 'from' and 'to' in the relations must be in the format '[dataset_name].[column_name]' accepting only letters, numbers, and underscores."
)

uncovered_tables = _tables_names_in_columns - _tables_names_in_relations
@@ -378,7 +388,9 @@
)

elif any(is_view_column_name(column_name) for column_name in _column_names):
raise ValueError("All columns in a table must be in the format '[column]'.")
raise ValueError(
"All columns in a table must be in the format '[column_name]' accepting only letters, numbers, and underscores."
)
return self

def to_dict(self) -> dict[str, Any]:
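The two updated validation messages above boil down to a single regex change: the dataset part of a view column reference now accepts underscores rather than dashes. A quick, runnable check of the pattern used in _validate_columns_relations (the column names below are made up):

import re

# Pattern from _validate_columns_relations: "<dataset_name>.<column_name>",
# letters, digits and underscores only.
column_re_check = r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$"

print(bool(re.match(column_re_check, "heart_data.age")))  # True
print(bool(re.match(column_re_check, "heart-data.age")))  # False: dashes are no longer accepted
print(bool(re.match(column_re_check, "age")))             # False: bare names are for table columns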
6 changes: 5 additions & 1 deletion pandasai/dataframe/base.py
@@ -10,6 +10,7 @@
from pandas._typing import Axes, Dtype

import pandasai as pai
from pandasai import get_validated_dataset_path
from pandasai.config import Config, ConfigManager
from pandasai.core.response import BaseResponse
from pandasai.data_loader.semantic_layer_schema import (
@@ -148,12 +149,15 @@ def push(self):
"Please save the dataset before pushing to the remote server."
)

SemanticLayerSchema.model_validate(self.schema)
org_name, dataset_name = get_validated_dataset_path(self.path)

api_key = os.environ.get("PANDABI_API_KEY", None)

request_session = get_pandaai_session()

params = {
"path": self.path,
"path": f"{org_name}/{dataset_name}",
"description": self.schema.description,
"name": self.schema.name,
}
15 changes: 15 additions & 0 deletions pandasai/helpers/path.py
@@ -49,6 +49,21 @@ def validate_name_format(value):
return bool(re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", value))


def validate_underscore_name_format(value):
"""
Validate name format to be 'my_organization'
"""
return bool(re.match(r"^[a-z0-9]+(?:_[a-z0-9]+)*$", value))


def transform_dash_to_underscore(value: str) -> str:
return value.replace("-", "_")


def transform_underscore_to_dash(value: str) -> str:
return value.replace("_", "-")


def get_validated_dataset_path(path: str):
# Validate path format
path_parts = path.split("/")
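The new helpers in pandasai/helpers/path.py make the two naming formats explicit: validate_name_format keeps the dashed, lowercase convention used for organization names in dataset paths, while validate_underscore_name_format enforces the underscored form that SemanticLayerSchema._validate_name now requires for schema names. A short sketch contrasting the two validators (sample values invented):

import re

def validate_name_format(value):
    # Dashed form used in dataset paths, e.g. "my-org/heart-data"
    return bool(re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", value))

def validate_underscore_name_format(value):
    # Underscored form used for schema names, e.g. "heart_data"
    return bool(re.match(r"^[a-z0-9]+(?:_[a-z0-9]+)*$", value))

print(validate_name_format("heart-data"), validate_underscore_name_format("heart-data"))  # True False
print(validate_name_format("heart_data"), validate_underscore_name_format("heart_data"))  # False True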
2 changes: 1 addition & 1 deletion tests/unit_tests/conftest.py
@@ -35,7 +35,7 @@ def sample_dataframes():
@pytest.fixture
def raw_sample_schema():
return {
"name": "Users",
"name": "users",
"update_frequency": "weekly",
"columns": [
{
21 changes: 12 additions & 9 deletions tests/unit_tests/data_loader/test_loader.py
@@ -32,6 +32,18 @@ def test_local_loader_properties(self, sample_schema):
loader = LocalDatasetLoader(sample_schema, "test/test")
assert isinstance(loader.query_builder, LocalQueryBuilder)

def test_load_schema_mysql_invalid_name(self, mysql_schema):
mysql_schema.name = "invalid-name"

with patch("os.path.exists", return_value=True), patch(
"builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
):
with pytest.raises(
ValueError,
match="Dataset name must be lowercase and use underscores instead of spaces.",
):
DatasetLoader._read_schema_file("test/users")

def test_load_from_local_source_invalid_source_type(self, sample_schema):
sample_schema.source.type = "mysql"
loader = LocalDatasetLoader(sample_schema, "test/test")
@@ -53,15 +65,6 @@ def test_load_schema_mysql(self, mysql_schema):
schema = DatasetLoader._read_schema_file("test/users")
assert schema == mysql_schema

def test_load_schema_mysql_sanitized_name(self, mysql_schema):
mysql_schema.name = "non-sanitized-name"

with patch("os.path.exists", return_value=True), patch(
"builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
):
schema = DatasetLoader._read_schema_file("test/users")
assert schema.name == "non_sanitized_name"

def test_load_schema_file_not_found(self):
with patch("os.path.exists", return_value=False):
with pytest.raises(FileNotFoundError):
8 changes: 7 additions & 1 deletion tests/unit_tests/dataframe/test_semantic_layer_schema.py
@@ -13,7 +13,7 @@ class TestSemanticLayerSchema:
def test_valid_schema(self, raw_sample_schema):
schema = SemanticLayerSchema(**raw_sample_schema)

assert schema.name == "Users"
assert schema.name == "users"
assert schema.update_frequency == "weekly"
assert len(schema.columns) == 3
assert schema.order_by == ["created_at DESC"]
@@ -39,6 +39,12 @@ def test_valid_raw_mysql_view_schema(self, raw_mysql_view_schema):
assert len(schema.columns) == 3
assert schema.view == True

def test_invalid_name(self, raw_sample_schema):
raw_sample_schema["name"] = "invalid-name"

with pytest.raises(ValidationError):
SemanticLayerSchema(**raw_sample_schema)

def test_missing_source_path(self, raw_sample_schema):
raw_sample_schema["source"].pop("path")

2 changes: 1 addition & 1 deletion tests/unit_tests/query_builders/test_query_builder.py
@@ -13,7 +13,7 @@ class TestQueryBuilder:
@pytest.fixture
def mysql_schema(self):
raw_schema = {
"name": "Users",
"name": "users",
"update_frequency": "weekly",
"columns": [
{
9 changes: 8 additions & 1 deletion tests/unit_tests/test_pandasai_init.py
@@ -230,6 +230,13 @@ def test_load_without_api_credentials(
== 'The dataset "test/dataset" does not exist in your local datasets directory. In addition, no API Key has been provided. Set an API key with valid permits if you want to fetch the dataset from the remote server.'
)

def test_load_invalid_name(self):
with pytest.raises(
ValueError,
match="Organization name must be lowercase and use hyphens instead of spaces",
):
pandasai.load("test_test/data_set")

def test_clear_cache(self):
with patch("pandasai.core.cache.Cache.clear") as mock_clear:
pandasai.clear_cache()
@@ -415,7 +422,7 @@ def test_create_valid_dataset_with_description(
from pandasai.data_loader.semantic_layer_schema import Source

schema = SemanticLayerSchema(
name="test-dataset",
name="test_dataset",
description="test_description",
source=Source(type="parquet", path="data.parquet"),
)
