From 02b5d9dfbb927e8861ee4befcc0ae2e40b6125d6 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 24 Dec 2020 21:19:51 -0500 Subject: [PATCH 01/19] Ensure alpha conversion always runs when a new variable is bound --- funsor/terms.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 34501c58d..ef62d5345 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -41,19 +41,19 @@ def subs_interpreter(cls, *args): return interpreter.reinterpret(expr) -def _alpha_mangle(expr): +def _alpha_mangle(expr, identifier): """ Rename bound variables in expr to avoid conflict with any free variables. - - FIXME this does not avoid conflict with other bound variables. + Returns expr._ast_values with mangled names. """ - alpha_subs = {name: interpreter.gensym(name + "__BOUND") - for name in expr.bound if "__BOUND" not in name} + # how we know which variables to mangle: we assume variable names include + # constant-size information about the original binding context, + # in the form of a cons-hash key of the binding context. + alpha_subs = {name: name.split("__BOUND")[0] + "__BOUND_" + identifier + for name in expr.bound if identifier not in name} if not alpha_subs: - return expr - - ast_values = expr._alpha_convert(alpha_subs) - return reflect(type(expr), *ast_values) + return expr._ast_values + return expr._alpha_convert(alpha_subs) # return mangled _ast_values def reflect(cls, *args, **kwargs): @@ -81,8 +81,15 @@ def reflect(cls, *args, **kwargs): result = super(FunsorMeta, cls_specific).__call__(*args) result._ast_values = args - # alpha-convert eagerly upon binding any variable - result = _alpha_mangle(result) + # alpha-convert eagerly upon binding any variable. + # the identifier we use to reconcile alpha-conversion and cons-hashing + # is the string literal of hash() of the type and cons-hashing key: + alpha_mangled_args = _alpha_mangle(result, str(hash((cls_specific,) + cache_key))) + # TODO eliminate code duplication here... + result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args) + result._ast_values = alpha_mangled_args + cache_key = tuple(id(arg) if type(arg).__name__ == "DeviceArray" or not isinstance(arg, Hashable) + else arg for arg in alpha_mangled_args) cls._cons_cache[cache_key] = result return result From 6135a8bd711f11535d3096515180a4da15cd2543 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 27 Dec 2020 23:18:15 -0500 Subject: [PATCH 02/19] add regression test --- test/test_alpha_conversion.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_alpha_conversion.py b/test/test_alpha_conversion.py index 3a1be13b2..a22938b32 100644 --- a/test/test_alpha_conversion.py +++ b/test/test_alpha_conversion.py @@ -115,3 +115,11 @@ def test_sample_independent(): actual = Independent(f, 'x', 'i', 'x_i') assert actual.sample('i') assert actual.sample('j', {'i': 2}) + + +def test_subs_already_bound(): + with interpretation(reflect): + x = Variable('x', Real) + y1 = (2 * x)(x=3) + y2 = y1.arg(4) + assert y1.bound != y2.bound From cdf8ad0faeafc96aa6adbdfce2571ba067fdd0a0 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 00:31:53 -0500 Subject: [PATCH 03/19] further code organization change --- funsor/terms.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index ef62d5345..683d15b94 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -44,16 +44,13 @@ def subs_interpreter(cls, *args): def _alpha_mangle(expr, identifier): """ Rename bound variables in expr to avoid conflict with any free variables. - Returns expr._ast_values with mangled names. + Returns substitution dictionary with mangled names for consumption by Funsor._alpha_convert. """ # how we know which variables to mangle: we assume variable names include # constant-size information about the original binding context, # in the form of a cons-hash key of the binding context. - alpha_subs = {name: name.split("__BOUND")[0] + "__BOUND_" + identifier - for name in expr.bound if identifier not in name} - if not alpha_subs: - return expr._ast_values - return expr._alpha_convert(alpha_subs) # return mangled _ast_values + return {name: name.split("__BOUND")[0] + "__BOUND_" + identifier + for name in expr.bound if identifier not in name} def reflect(cls, *args, **kwargs): @@ -84,13 +81,16 @@ def reflect(cls, *args, **kwargs): # alpha-convert eagerly upon binding any variable. # the identifier we use to reconcile alpha-conversion and cons-hashing # is the string literal of hash() of the type and cons-hashing key: - alpha_mangled_args = _alpha_mangle(result, str(hash((cls_specific,) + cache_key))) - # TODO eliminate code duplication here... - result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args) - result._ast_values = alpha_mangled_args - cache_key = tuple(id(arg) if type(arg).__name__ == "DeviceArray" or not isinstance(arg, Hashable) - else arg for arg in alpha_mangled_args) - + alpha_subs = _alpha_mangle(result, str(hash((cls_specific,) + cache_key))) + if alpha_subs: + # TODO eliminate code duplication here... + alpha_mangled_args = result._alpha_convert(alpha_subs) + result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args) + result._ast_values = alpha_mangled_args + cache_key = tuple(id(arg) if type(arg).__name__ == "DeviceArray" or not isinstance(arg, Hashable) + else arg for arg in alpha_mangled_args) + + result._cache_key = cache_key cls._cons_cache[cache_key] = result return result From 439d63e249c20c1482b4651b1f8533d24ccb841d Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 01:03:48 -0500 Subject: [PATCH 04/19] dont strip name --- funsor/terms.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 683d15b94..5237dd7b6 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -49,8 +49,8 @@ def _alpha_mangle(expr, identifier): # how we know which variables to mangle: we assume variable names include # constant-size information about the original binding context, # in the form of a cons-hash key of the binding context. - return {name: name.split("__BOUND")[0] + "__BOUND_" + identifier - for name in expr.bound if identifier not in name} + return {name: name + "__BOUND_" + identifier + for name in expr.bound if not name.endswith(identifier)} def reflect(cls, *args, **kwargs): @@ -87,10 +87,15 @@ def reflect(cls, *args, **kwargs): alpha_mangled_args = result._alpha_convert(alpha_subs) result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args) result._ast_values = alpha_mangled_args + + # we also make the old cons cache_key point to the new mangled value. + # XXX this matches the previous behavior of reflect, but may be a hack + # necessitated by ambiguous behavior of cons-hashing for funsor.Tensor + cls._cons_cache[cache_key] = result + cache_key = tuple(id(arg) if type(arg).__name__ == "DeviceArray" or not isinstance(arg, Hashable) else arg for arg in alpha_mangled_args) - result._cache_key = cache_key cls._cons_cache[cache_key] = result return result From 3fba6dddc748ed4ea129dfe44a4afb07cf9ae591 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 09:31:42 -0500 Subject: [PATCH 05/19] simplify, remove hash() usage --- funsor/terms.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 5237dd7b6..3b82dfa63 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -41,16 +41,12 @@ def subs_interpreter(cls, *args): return interpreter.reinterpret(expr) -def _alpha_mangle(expr, identifier): +def _alpha_mangle(expr): """ Rename bound variables in expr to avoid conflict with any free variables. Returns substitution dictionary with mangled names for consumption by Funsor._alpha_convert. """ - # how we know which variables to mangle: we assume variable names include - # constant-size information about the original binding context, - # in the form of a cons-hash key of the binding context. - return {name: name + "__BOUND_" + identifier - for name in expr.bound if not name.endswith(identifier)} + return {name: interpreter.gensym(name.split("__BOUND_")[0] + "__BOUND_") for name in expr.bound} def reflect(cls, *args, **kwargs): @@ -81,16 +77,17 @@ def reflect(cls, *args, **kwargs): # alpha-convert eagerly upon binding any variable. # the identifier we use to reconcile alpha-conversion and cons-hashing # is the string literal of hash() of the type and cons-hashing key: - alpha_subs = _alpha_mangle(result, str(hash((cls_specific,) + cache_key))) - if alpha_subs: - # TODO eliminate code duplication here... + if result.bound: + alpha_subs = _alpha_mangle(result) alpha_mangled_args = result._alpha_convert(alpha_subs) + + # TODO eliminate code duplication below + # this is currently necessary because .bound is computed in __init__(). result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args) result._ast_values = alpha_mangled_args # we also make the old cons cache_key point to the new mangled value. - # XXX this matches the previous behavior of reflect, but may be a hack - # necessitated by ambiguous behavior of cons-hashing for funsor.Tensor + # this guarantees that alpha-conversion only runs once for this expression. cls._cons_cache[cache_key] = result cache_key = tuple(id(arg) if type(arg).__name__ == "DeviceArray" or not isinstance(arg, Hashable) From aa4c89cd104786b0d1da6a25f1ce86a6c85d6946 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 09:36:11 -0500 Subject: [PATCH 06/19] simplify alpha_mangle --- funsor/terms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 3b82dfa63..9097354c9 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -41,12 +41,12 @@ def subs_interpreter(cls, *args): return interpreter.reinterpret(expr) -def _alpha_mangle(expr): +def _alpha_mangle(bound_vars): """ Rename bound variables in expr to avoid conflict with any free variables. Returns substitution dictionary with mangled names for consumption by Funsor._alpha_convert. """ - return {name: interpreter.gensym(name.split("__BOUND_")[0] + "__BOUND_") for name in expr.bound} + return {name: interpreter.gensym(name.split("__BOUND_")[0] + "__BOUND_") for name in bound_vars} def reflect(cls, *args, **kwargs): @@ -78,7 +78,7 @@ def reflect(cls, *args, **kwargs): # the identifier we use to reconcile alpha-conversion and cons-hashing # is the string literal of hash() of the type and cons-hashing key: if result.bound: - alpha_subs = _alpha_mangle(result) + alpha_subs = _alpha_mangle(result.bound) alpha_mangled_args = result._alpha_convert(alpha_subs) # TODO eliminate code duplication below From b5da767ae8a72bcc653912634fc52f5085436cba Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 09:56:15 -0500 Subject: [PATCH 07/19] optimization to avoid some unnecessary deep recursion in Funsor._alpha_convert --- funsor/terms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funsor/terms.py b/funsor/terms.py index 9097354c9..c4386704b 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -403,7 +403,8 @@ def _alpha_convert(self, alpha_subs): # Substitute all funsor values. # Subclasses must handle string conversion. assert self.bound.issuperset(alpha_subs) - return tuple(substitute(v, alpha_subs) for v in self._ast_values) + return tuple(substitute(v, alpha_subs) if not isinstance(v, Funsor) or self.bound.intersection(v.inputs) else v + for v in self._ast_values) def __call__(self, *args, **kwargs): """ From 73d2d5013539f85f630f7c30aa76963a3ab71219 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 10:39:38 -0500 Subject: [PATCH 08/19] optimization to avoid unnecessary recursion in Contraction._alpha_convert --- funsor/cnf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index e472534ec..c5c5ef3fc 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -31,6 +31,7 @@ eager, normalize, reflect, + substitute, to_funsor ) from funsor.util import broadcast_shape, get_backend, quote @@ -158,8 +159,10 @@ def _alpha_convert(self, alpha_subs): for term in self.terms: bound_types.update({k: term.inputs[k] for k in self.bound.intersection(term.inputs)}) alpha_subs = {k: to_funsor(v, bound_types[k]) for k, v in alpha_subs.items()} - red_op, bin_op, _, terms = super()._alpha_convert(alpha_subs) - return red_op, bin_op, reduced_vars, terms + terms = tuple(term if isinstance(term, Funsor) and not self.bound.intersection(term.inputs) + else substitute(term, alpha_subs) + for term in self.terms) + return self.red_op, self.bin_op, reduced_vars, terms GaussianMixture = Contraction[Union[ops.LogAddExpOp, NullOp], ops.AddOp, frozenset, From 5cc64d5e19e0a974d52037ceacbfb624081a1791 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 28 Dec 2020 10:40:56 -0500 Subject: [PATCH 09/19] nit --- funsor/terms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funsor/terms.py b/funsor/terms.py index c4386704b..c73c58d2f 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -403,7 +403,8 @@ def _alpha_convert(self, alpha_subs): # Substitute all funsor values. # Subclasses must handle string conversion. assert self.bound.issuperset(alpha_subs) - return tuple(substitute(v, alpha_subs) if not isinstance(v, Funsor) or self.bound.intersection(v.inputs) else v + return tuple(v if isinstance(v, Funsor) and not self.bound.intersection(v.inputs) + else substitute(v, alpha_subs) for v in self._ast_values) def __call__(self, *args, **kwargs): From 610037be7eee143954fadab1d7ad1883cb9dff89 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Feb 2021 20:56:00 -0500 Subject: [PATCH 10/19] fix test --- funsor/cnf.py | 3 ++- test/test_alpha_conversion.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 71bf7c40a..37426278e 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -222,7 +222,8 @@ def _alpha_convert(self, alpha_subs): alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} terms = tuple( term - if isinstance(term, Funsor) and not self.bound.intersection(term.inputs) + if isinstance(term, Funsor) + and not set(self.bound).intersection(term.inputs) else substitute(term, alpha_subs) for term in self.terms ) diff --git a/test/test_alpha_conversion.py b/test/test_alpha_conversion.py index 3b401a78c..19cfecb1d 100644 --- a/test/test_alpha_conversion.py +++ b/test/test_alpha_conversion.py @@ -128,7 +128,7 @@ def test_sample_independent(): def test_subs_already_bound(): - with interpretation(reflect): + with reflect: x = Variable("x", Real) y1 = (2 * x)(x=3) y2 = y1.arg(4) From c42834d7a8001b2bd89d09e42d44d0cf60d00832 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 01:16:23 -0500 Subject: [PATCH 11/19] try forcing reflect in _alpha_convert --- funsor/cnf.py | 2 +- funsor/terms.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 7a4653e89..cc264758e 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -224,7 +224,7 @@ def _alpha_convert(self, alpha_subs): term if isinstance(term, Funsor) and not set(self.bound).intersection(term.inputs) - else substitute(term, alpha_subs) + else reflect(substitute)(term, alpha_subs) for term in self.terms ) return self.red_op, self.bin_op, reduced_vars, terms diff --git a/funsor/terms.py b/funsor/terms.py index d2b6a11d9..c820d8b21 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -303,7 +303,7 @@ def _alpha_convert(self, alpha_subs): return tuple( v if isinstance(v, Funsor) and not set(self.bound).intersection(v.inputs) - else substitute(v, alpha_subs) + else reflect(substitute)(v, alpha_subs) for v in self._ast_values ) From b4bed6353a9f1473e5b2bcfb6c15e7de46733035 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 21 Feb 2021 01:47:54 -0500 Subject: [PATCH 12/19] centralize reflect application in alpha conversion --- funsor/cnf.py | 2 +- funsor/terms.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index cc264758e..7a4653e89 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -224,7 +224,7 @@ def _alpha_convert(self, alpha_subs): term if isinstance(term, Funsor) and not set(self.bound).intersection(term.inputs) - else reflect(substitute)(term, alpha_subs) + else substitute(term, alpha_subs) for term in self.terms ) return self.red_op, self.bin_op, reduced_vars, terms diff --git a/funsor/terms.py b/funsor/terms.py index c820d8b21..4464c5380 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -108,7 +108,7 @@ def reflect(cls, *args, **kwargs): # is the string literal of hash() of the type and cons-hashing key: if result.bound: alpha_subs = _alpha_mangle(result.bound) - alpha_mangled_args = result._alpha_convert(alpha_subs) + alpha_mangled_args = reflect(result._alpha_convert)(alpha_subs) # TODO eliminate code duplication below # this is currently necessary because .bound is computed in __init__(). @@ -303,7 +303,7 @@ def _alpha_convert(self, alpha_subs): return tuple( v if isinstance(v, Funsor) and not set(self.bound).intersection(v.inputs) - else reflect(substitute)(v, alpha_subs) + else substitute(v, alpha_subs) for v in self._ast_values ) From 9c80304932e26e773e36de7e63e5b3b37cb9ef25 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Mar 2021 17:08:36 -0400 Subject: [PATCH 13/19] centralize alpha_convert further --- funsor/cnf.py | 9 +-------- funsor/integrate.py | 4 +--- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 7a4653e89..e863bb694 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -29,7 +29,6 @@ Subs, Unary, Variable, - substitute, to_funsor, ) from funsor.typing import Variadic @@ -220,13 +219,7 @@ def _alpha_convert(self, alpha_subs): for var in self.reduced_vars ) alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - terms = tuple( - term - if isinstance(term, Funsor) - and not set(self.bound).intersection(term.inputs) - else substitute(term, alpha_subs) - for term in self.terms - ) + terms = super()._alpha_convert(alpha_subs)[-1] return self.red_op, self.bin_op, reduced_vars, terms diff --git a/funsor/integrate.py b/funsor/integrate.py index b75b6f50e..62b9235d0 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -18,7 +18,6 @@ Unary, Variable, _convert_reduced_vars, - substitute, to_funsor, ) @@ -77,8 +76,7 @@ def _alpha_convert(self, alpha_subs): ) for k, v in alpha_subs.items() } - log_measure = substitute(self.log_measure, alpha_subs) - integrand = substitute(self.integrand, alpha_subs) + log_measure, integrand, _ = super()._alpha_convert(alpha_subs) return log_measure, integrand, reduced_vars From ccfdec535e893d551e80b89e818f7c4ed392a61f Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 16 Mar 2021 17:22:11 -0400 Subject: [PATCH 14/19] remove obsolete alpha_convert methods --- funsor/cnf.py | 10 ---------- funsor/integrate.py | 16 ---------------- funsor/terms.py | 16 +--------------- 3 files changed, 1 insertion(+), 41 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index e863bb694..88f13790e 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -29,7 +29,6 @@ Subs, Unary, Variable, - to_funsor, ) from funsor.typing import Variadic from funsor.util import broadcast_shape, get_backend, quote @@ -213,15 +212,6 @@ def align(self, names): ) # raise NotImplementedError("TODO align all terms") return result - def _alpha_convert(self, alpha_subs): - reduced_vars = frozenset( - to_funsor(alpha_subs.get(var.name, var), var.output) - for var in self.reduced_vars - ) - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - terms = super()._alpha_convert(alpha_subs)[-1] - return self.red_op, self.bin_op, reduced_vars, terms - GaussianMixture = Contraction[ Union[ops.LogaddexpOp, NullOp], diff --git a/funsor/integrate.py b/funsor/integrate.py index 62b9235d0..a1c28b9c6 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -18,7 +18,6 @@ Unary, Variable, _convert_reduced_vars, - to_funsor, ) @@ -64,21 +63,6 @@ def __init__(self, log_measure, integrand, reduced_vars): self.integrand = integrand self.reduced_vars = reduced_vars - def _alpha_convert(self, alpha_subs): - assert set(self.bound).issuperset(alpha_subs) - reduced_vars = frozenset( - Variable(alpha_subs.get(v.name, v.name), v.output) - for v in self.reduced_vars - ) - alpha_subs = { - k: to_funsor( - v, self.integrand.inputs.get(k, self.log_measure.inputs.get(k)) - ) - for k, v in alpha_subs.items() - } - log_measure, integrand, _ = super()._alpha_convert(alpha_subs) - return log_measure, integrand, reduced_vars - @normalize.register(Integrate, Funsor, Funsor, frozenset) def normalize_integrate(log_measure, integrand, reduced_vars): diff --git a/funsor/terms.py b/funsor/terms.py index 4464c5380..c35faa821 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -298,8 +298,8 @@ def _alpha_convert(self, alpha_subs): Rename bound variables while preserving all free variables. """ # Substitute all funsor values. - # Subclasses must handle string conversion. assert set(alpha_subs).issubset(self.bound) + alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} return tuple( v if isinstance(v, Funsor) and not set(self.bound).intersection(v.inputs) @@ -1027,14 +1027,6 @@ def __str__(self): str(self.arg), self.op.__name__, ", ".join(rvars) ) - def _alpha_convert(self, alpha_subs): - alpha_subs = { - k: to_funsor(v, self.arg.inputs[k]) for k, v in alpha_subs.items() - } - op, arg, reduced_vars = super()._alpha_convert(alpha_subs) - reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars) - return op, arg, reduced_vars - def _reduce_unrelated_vars(op, arg, reduced_vars): factor_vars = reduced_vars - arg.input_vars @@ -1496,12 +1488,6 @@ def __init__(self, var, expr): self.var = var self.expr = expr - def _alpha_convert(self, alpha_subs): - alpha_subs = { - k: to_funsor(v, self.var.inputs[k]) for k, v in alpha_subs.items() - } - return super()._alpha_convert(alpha_subs) - @eager.register(Binary, GetitemOp, Lambda, (Funsor, Align)) def eager_getitem_lambda(op, lhs, rhs): From 9b6a60ad96f8793b4e88a372bcea929f0d5b37e1 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 18 Mar 2021 12:01:37 -0400 Subject: [PATCH 15/19] delete more _alpha_convert methods --- funsor/factory.py | 4 ---- funsor/terms.py | 6 ------ 2 files changed, 10 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index 651cf76d8..71e37bcff 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -246,10 +246,6 @@ def __init__(self, **kwargs): for name, arg in zip(self._ast_fields, args): setattr(self, name, arg) - def _alpha_convert(self, alpha_subs): - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs) - ResultMeta.__name__ = f"{fn.__name__}Meta" Result = ResultMeta( fn.__name__, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} diff --git a/funsor/terms.py b/funsor/terms.py index 13d3388c6..6fb778905 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1194,12 +1194,6 @@ def __init__(self, op, subs, source, reduced_vars): self.source = source self.reduced_vars = reduced_vars - def _alpha_convert(self, alpha_subs): - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - op, subs, source, reduced_vars = super()._alpha_convert(alpha_subs) - reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars) - return op, subs, source, reduced_vars - class Approximate(Funsor): """ From fa7995090723fa248660fc569363902884ec2d0b Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 18 Mar 2021 12:28:48 -0400 Subject: [PATCH 16/19] lint --- funsor/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/factory.py b/funsor/factory.py index 71e37bcff..47198a47a 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -248,7 +248,7 @@ def __init__(self, **kwargs): ResultMeta.__name__ = f"{fn.__name__}Meta" Result = ResultMeta( - fn.__name__, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} + fn.__name__, (Funsor,), {"__init__": __init__} ) pattern = (Result,) + tuple( _hint_to_pattern(input_types[k]) for k in Result._ast_fields From d2ca611fc3117777da72158ddf246b1d636f495c Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 18 Mar 2021 12:29:29 -0400 Subject: [PATCH 17/19] lint --- funsor/factory.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/funsor/factory.py b/funsor/factory.py index 47198a47a..3d05f63f4 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -247,9 +247,7 @@ def __init__(self, **kwargs): setattr(self, name, arg) ResultMeta.__name__ = f"{fn.__name__}Meta" - Result = ResultMeta( - fn.__name__, (Funsor,), {"__init__": __init__} - ) + Result = ResultMeta(fn.__name__, (Funsor,), {"__init__": __init__}) pattern = (Result,) + tuple( _hint_to_pattern(input_types[k]) for k in Result._ast_fields ) From 645618436cb4da76a7b79369b7f0f2d7c6dd783b Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 18 Mar 2021 14:15:24 -0400 Subject: [PATCH 18/19] remove outdated comment --- funsor/terms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 6fb778905..a26cb69f2 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -105,8 +105,6 @@ def reflect(cls, *args, **kwargs): result._ast_values = args # alpha-convert eagerly upon binding any variable. - # the identifier we use to reconcile alpha-conversion and cons-hashing - # is the string literal of hash() of the type and cons-hashing key: if result.bound: alpha_subs = _alpha_mangle(result.bound) alpha_mangled_args = reflect(result._alpha_convert)(alpha_subs) From 01bfd391900f1c95fc979483d918b9e1cbfcf967 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 13 Apr 2021 02:05:20 -0400 Subject: [PATCH 19/19] avoid renaming twice --- funsor/terms.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 6bd7bb334..45a5fdf4a 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -112,6 +112,9 @@ def _alpha_mangle(bound_vars): } +_SKIP_ALPHA = False + + @reflect.set_callable def reflect(cls, *args, **kwargs): """ @@ -136,9 +139,16 @@ def reflect(cls, *args, **kwargs): result._ast_values = args # alpha-convert eagerly upon binding any variable. - if result.bound: + global _SKIP_ALPHA + if result.bound and not _SKIP_ALPHA: alpha_subs = _alpha_mangle(result.bound) - alpha_mangled_args = reflect(result._alpha_convert)(alpha_subs) + try: + # optimization: don't perform alpha-conversion again + # when renaming subexpressions of result + _SKIP_ALPHA = True + alpha_mangled_args = reflect(result._alpha_convert)(alpha_subs) + finally: + _SKIP_ALPHA = False # TODO eliminate code duplication below # this is currently necessary because .bound is computed in __init__(). @@ -334,12 +344,7 @@ def _alpha_convert(self, alpha_subs): # Substitute all funsor values. assert set(alpha_subs).issubset(self.bound) alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return tuple( - v - if isinstance(v, Funsor) and not set(self.bound).intersection(v.inputs) - else substitute(v, alpha_subs) - for v in self._ast_values - ) + return substitute(self._ast_values, alpha_subs) def __call__(self, *args, **kwargs): """