Skip to content

Commit

Permalink
accept expr in default value (#1820)
Browse files Browse the repository at this point in the history
* accept expr in default value

Signed-off-by: gabriel <[email protected]>

* add test

---------

Signed-off-by: gabriel <[email protected]>
Co-authored-by: gabriel <[email protected]>
  • Loading branch information
gab23r and gabriel authored Nov 1, 2024
1 parent 4f8bdbf commit ea4538d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pandera/backends/polars/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,10 @@ def set_default(self, check_obj: pl.LazyFrame, schema) -> pl.LazyFrame:
if hasattr(schema, "default") and schema.default is None:
return check_obj

default_value = pl.lit(schema.default, dtype=schema.dtype.type)
if isinstance(schema.default, pl.Expr):
default_value = schema.default
else:
default_value = pl.lit(schema.default, dtype=schema.dtype.type)
expr = pl.col(schema.selector)
if is_float_dtype(check_obj, schema.selector):
expr = expr.fill_nan(default_value)
Expand Down
20 changes: 20 additions & 0 deletions tests/polars/test_polars_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,24 @@ def test_set_default(data, dtype, default):
assert validated_data.select(pl.col("column").eq(default).any()).item()


def test_expr_as_default():
schema = pa.DataFrameSchema(
columns={
"a": pa.Column(int),
"b": pa.Column(float, default=1),
"c": pa.Column(str, default=pl.lit("foo")),
"d": pa.Column(int, nullable=True, default=pl.col("a")),
},
add_missing_columns=True,
coerce=True,
)
df = pl.LazyFrame({"a": [1, 2, 3]})
assert schema.validate(df).collect().to_dict(as_series=False) == {
"a": [1, 2, 3],
"b": [1.0, 1.0, 1.0],
"c": ["foo", "foo", "foo"],
"d": [1, 2, 3],
}


def test_column_schema_on_lazyframe_coerce(): ...

0 comments on commit ea4538d

Please sign in to comment.