Skip to content

Commit

Permalink
A few small test-related fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanhogg committed Jan 2, 2025
1 parent c2dd586 commit e71967c
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 18 deletions.
24 changes: 14 additions & 10 deletions src/flitter/language/tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ cdef class Call(Expression):
if isinstance(value, Function):
func_expr = <Function>value
cdef bint literal_func = isinstance(function, Literal)
if literal_func and not (<Literal>function).value.objects:
if literal_func and (<Literal>function).value.length == 0:
return NoOp
cdef bint all_literal_args=True, all_dynamic_args=True
cdef Expression arg, sarg, expr
Expand Down Expand Up @@ -1056,15 +1056,19 @@ cdef class Call(Expression):
vector_args = [literal_arg.value for literal_arg in args]
kwargs = {binding.name: (<Literal>binding.expr).value for binding in keyword_args}
results = []
for func in (<Literal>function).value.objects:
if callable(func):
try:
assert not hasattr(func, 'context_func')
results.append(func(*vector_args, **kwargs))
except Exception as exc:
context.errors.add(f"Error calling {func.__name__}: {str(exc)}")
else:
context.errors.add(f"{func!r} is not callable")
if (<Literal>function).value.objects is not None:
for func in (<Literal>function).value.objects:
if callable(func):
try:
assert not hasattr(func, 'context_func')
results.append(func(*vector_args, **kwargs))
except Exception as exc:
context.errors.add(f"Error calling {func.__name__}: {str(exc)}")
else:
context.errors.add(f"{func!r} is not callable")
elif (<Literal>function).value.numbers != NULL:
for i in range((<Literal>function).value.length):
context.errors.add(f"{(<Literal>function).value.numbers[i]!r} is not callable")
return Literal(Vector._compose(results))
if isinstance(function, Literal) and len(args) == 1:
if (<Literal>function).value == static_builtins['ceil']:
Expand Down
5 changes: 1 addition & 4 deletions src/flitter/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,7 @@ cdef class Vector:
return SymbolTable.get(self.numbers[0], f'{self.numbers[0]:.9g}')
elif n:
for i in range(n):
if self.numbers[i] == 0:
text += "0"
else:
text += SymbolTable.get(self.numbers[i], f'{self.numbers[i]:.9g}')
text += SymbolTable.get(self.numbers[i], f'{self.numbers[i]:.9g}')
return text

def __iter__(self):
Expand Down
10 changes: 7 additions & 3 deletions tests/test_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,10 +1057,14 @@ def test_static_failure(self):
self.assertSimplifiesTo(Call(Literal(functions.sqrtv), (), ()), Literal(null),
with_errors={'Error calling sqrtv: sqrtv() takes exactly 1 positional argument (0 given)'})

def test_non_callable_literal(self):
"""Calls to literals that are definitely not callable (i.e., empty or numeric) are replaced with null"""
def test_null_literal(self):
"""Calls to null literals are replaced with null"""
self.assertSimplifiesTo(Call(Literal(null), (Name('x'),), ()), Literal(null), dynamic={'x'})
self.assertSimplifiesTo(Call(Literal(5), (Name('x'),), ()), Literal(null), dynamic={'x'})

def test_non_callable_literals(self):
"""Literal calls to non-callables will evaluate to null with an error"""
self.assertSimplifiesTo(Call(Literal(5), (Literal(10),), ()), Literal(null), with_errors={'5.0 is not callable'})
self.assertSimplifiesTo(Call(Literal('Hello'), (Literal(10),), ()), Literal(null), with_errors={"'Hello' is not callable"})

def test_simple_named_inlining(self):
"""Calls to names that resolve to Function objects are inlined as let expressions"""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,10 @@ def test_as_string(self):
(Vector([1 / 3]), "0.333333333"),
(Vector("Hello world!"), "Hello world!"),
(Vector(["Hello ", "world!"]), "Hello world!"),
(Vector(["testing", "testing", 1, 2.2, 3.0]), "testingtesting12.23"),
(Vector(["testing", "testing", 0, 1, 2.2, 3.0]), "testingtesting012.23"),
(Vector.symbol('foo'), "foo"),
(Vector.symbol('foo').concat(Vector.symbol('bar')), "foobar"),
(Vector.symbol('foo').concat(Vector([0, 1, 2.2, 3.0])), "foo012.23"),
(Vector(Node('foo', {'bar'}, {'baz': Vector(2)})), "foo"),
(Vector(self.test_as_string), "test_as_string"),
]
Expand Down

0 comments on commit e71967c

Please sign in to comment.