Skip to content

Commit

Permalink
implement if else function (#787)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin authored Jan 7, 2025
1 parent 6862726 commit ad44884
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import case, greatest, least
from .conditional import case, greatest, ifelse, least
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
from .string import byte_hamming_distance
Expand All @@ -40,6 +40,7 @@
"euclidean_distance",
"first",
"greatest",
"ifelse",
"int_hash_64",
"least",
"length",
Expand Down
40 changes: 31 additions & 9 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from .func import ColT, Func

CaseT = Union[int, float, complex, bool, str]


def greatest(*args: Union[ColT, float]) -> Func:
"""
Expand Down Expand Up @@ -85,9 +87,7 @@ def least(*args: Union[ColT, float]) -> Func:
)


def case(
*args: tuple[BinaryExpression, Union[int, float, complex, bool, str]], else_=None
) -> Func:
def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
"""
Returns the case function that produces case expression which has a list of
conditions and corresponding results. Results can only be python primitives
Expand All @@ -108,26 +108,48 @@ def case(
res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
)
```
Note:
- Result column will always be of the same type as the input columns.
"""
supported_types = [int, float, complex, str, bool]

type_ = type(else_) if else_ else None

if not args:
raise DataChainParamsError("Missing case statements")
raise DataChainParamsError("Missing statements")

for arg in args:
if type_ and not isinstance(arg[1], type_):
raise DataChainParamsError("Case statement values must be of the same type")
raise DataChainParamsError("Statement values must be of the same type")
type_ = type(arg[1])

if type_ not in supported_types:
raise DataChainParamsError(
f"Case supports only python literals ({supported_types}) for values"
f"Only python literals ({supported_types}) are supported for values"
)

kwargs = {"else_": else_}
return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)


def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
"""
Returns the ifelse function that produces if expression which has a condition
and values for true and false outcome. Results can only be python primitives
like string, numbes or booleans. Result type is inferred from the values.
Args:
condition: BinaryExpression - condition which is evaluated
if_val: (str | int | float | complex | bool): value for true condition outcome
else_val: (str | int | float | complex | bool): value for false condition
outcome
Returns:
Func: A Func object that represents the ifelse function.
Example:
```py
dc.mutate(
res=func.ifelse(C("num") > 0, "P", "N"),
)
```
"""
return case((condition, if_val), else_=else_val)
25 changes: 21 additions & 4 deletions tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,38 @@ def test_case(warehouse, val, expected):
def test_case_missing_statements(warehouse):
with pytest.raises(DataChainParamsError) as exc_info:
select(func.case(*[], else_="D"))
assert str(exc_info.value) == "Missing case statements"
assert str(exc_info.value) == "Missing statements"


def test_case_not_same_result_types(warehouse):
val = 2
with pytest.raises(DataChainParamsError) as exc_info:
select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D"))
assert str(exc_info.value) == "Case statement values must be of the same type"
assert str(exc_info.value) == "Statement values must be of the same type"


def test_case_wrong_result_type(warehouse):
val = 2
with pytest.raises(DataChainParamsError) as exc_info:
select(func.case(*[(val > 1, ["a", "b"]), (2 < val < 4, [])], else_=[]))
assert str(exc_info.value) == (
"Case supports only python literals ([<class 'int'>, <class 'float'>, "
"<class 'complex'>, <class 'str'>, <class 'bool'>]) for values"
"Only python literals ([<class 'int'>, <class 'float'>, "
"<class 'complex'>, <class 'str'>, <class 'bool'>]) are supported for values"
)


@pytest.mark.parametrize(
"val,expected",
[
(1, "L"),
(2, "L"),
(3, "L"),
(4, "H"),
(5, "H"),
(100, "H"),
],
)
def test_ifelse(warehouse, val, expected):
query = select(func.ifelse(val <= 3, "L", "H"))
result = tuple(warehouse.db.execute(query))
assert result == ((expected,),)
18 changes: 18 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
bit_hamming_distance,
byte_hamming_distance,
case,
ifelse,
int_hash_64,
literal,
)
Expand Down Expand Up @@ -660,3 +661,20 @@ def test_case_mutate(dc, val, else_, type_):
[val, else_, else_, else_, else_]
)
assert res.schema["test"] == type_


@pytest.mark.parametrize(
"if_val,else_val,type_",
[
["A", "D", str],
[1, 2, int],
[1.5, 2.5, float],
[True, False, bool],
],
)
def test_ifelse_mutate(dc, if_val, else_val, type_):
res = dc.mutate(test=ifelse(C("num") < 2, if_val, else_val))
assert list(res.order_by("test").collect("test")) == sorted(
[if_val, else_val, else_val, else_val, else_val]
)
assert res.schema["test"] == type_

0 comments on commit ad44884

Please sign in to comment.