diff --git a/funsor/cnf.py b/funsor/cnf.py index 39e82e1e..43532ebf 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -29,7 +29,6 @@ Subs, Unary, Variable, - to_funsor, ) from funsor.typing import Variadic from funsor.util import broadcast_shape, get_backend, quote @@ -213,15 +212,6 @@ def align(self, names): ) # raise NotImplementedError("TODO align all terms") return result - def _alpha_convert(self, alpha_subs): - reduced_vars = frozenset( - to_funsor(alpha_subs.get(var.name, var), var.output) - for var in self.reduced_vars - ) - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - red_op, bin_op, _, terms = super()._alpha_convert(alpha_subs) - return red_op, bin_op, reduced_vars, terms - GaussianMixture = Contraction[ Union[ops.LogaddexpOp, NullOp], diff --git a/funsor/integrate.py b/funsor/integrate.py index f20879a0..b7dbf95c 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -18,8 +18,6 @@ Unary, Variable, _convert_reduced_vars, - substitute, - to_funsor, ) @@ -65,22 +63,6 @@ def __init__(self, log_measure, integrand, reduced_vars): self.integrand = integrand self.reduced_vars = reduced_vars - def _alpha_convert(self, alpha_subs): - assert set(self.bound).issuperset(alpha_subs) - reduced_vars = frozenset( - Variable(alpha_subs.get(v.name, v.name), v.output) - for v in self.reduced_vars - ) - alpha_subs = { - k: to_funsor( - v, self.integrand.inputs.get(k, self.log_measure.inputs.get(k)) - ) - for k, v in alpha_subs.items() - } - log_measure = substitute(self.log_measure, alpha_subs) - integrand = substitute(self.integrand, alpha_subs) - return log_measure, integrand, reduced_vars - @normalize.register(Integrate, Funsor, Funsor, frozenset) def normalize_integrate(log_measure, integrand, reduced_vars): diff --git a/funsor/terms.py b/funsor/terms.py index ae812e01..45a5fdf4 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -101,22 +101,18 @@ def stop(x): return env[expr] -def _alpha_mangle(expr): +def _alpha_mangle(bound_vars): """ Rename bound variables in expr to avoid conflict with any free variables. - - FIXME this does not avoid conflict with other bound variables. + Returns substitution dictionary with mangled names for consumption by Funsor._alpha_convert. """ - alpha_subs = { - name: interpreter.gensym(name + "__BOUND") - for name in expr.bound - if "__BOUND" not in name + return { + name: interpreter.gensym(name.split("__BOUND_")[0] + "__BOUND_") + for name in bound_vars } - if not alpha_subs: - return expr - ast_values = instrument.debug_logged(expr._alpha_convert)(alpha_subs) - return reflect.interpret(type(expr), *ast_values) + +_SKIP_ALPHA = False @reflect.set_callable @@ -142,6 +138,29 @@ def reflect(cls, *args, **kwargs): result = super(FunsorMeta, cls_specific).__call__(*args) result._ast_values = args + # alpha-convert eagerly upon binding any variable. + global _SKIP_ALPHA + if result.bound and not _SKIP_ALPHA: + alpha_subs = _alpha_mangle(result.bound) + try: + # optimization: don't perform alpha-conversion again + # when renaming subexpressions of result + _SKIP_ALPHA = True + alpha_mangled_args = reflect(result._alpha_convert)(alpha_subs) + finally: + _SKIP_ALPHA = False + + # TODO eliminate code duplication below + # this is currently necessary because .bound is computed in __init__(). + result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args) + result._ast_values = alpha_mangled_args + + # we also make the old cons cache_key point to the new mangled value. + # this guarantees that alpha-conversion only runs once for this expression. + cls._cons_cache[cache_key] = result + + cache_key = reflect.make_hash_key(cls, *alpha_mangled_args) + if instrument.PROFILE: size, depth, width = _get_ast_stats(result) instrument.COUNTERS["ast_size"][size] += 1 @@ -150,9 +169,6 @@ def reflect(cls, *args, **kwargs): instrument.COUNTERS["funsor"][classname] += 1 instrument.COUNTERS[classname][width] += 1 - # alpha-convert eagerly upon binding any variable - result = _alpha_mangle(result) - cls._cons_cache[cache_key] = result return result @@ -326,9 +342,9 @@ def _alpha_convert(self, alpha_subs): Rename bound variables while preserving all free variables. """ # Substitute all funsor values. - # Subclasses must handle string conversion. assert set(alpha_subs).issubset(self.bound) - return tuple(substitute(v, alpha_subs) for v in self._ast_values) + alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} + return substitute(self._ast_values, alpha_subs) def __call__(self, *args, **kwargs): """ @@ -1095,14 +1111,6 @@ def __str__(self): str(self.arg), self.op.__name__, ", ".join(rvars) ) - def _alpha_convert(self, alpha_subs): - alpha_subs = { - k: to_funsor(v, self.arg.inputs[k]) for k, v in alpha_subs.items() - } - op, arg, reduced_vars = super()._alpha_convert(alpha_subs) - reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars) - return op, arg, reduced_vars - def _reduce_unrelated_vars(op, arg, reduced_vars): factor_vars = reduced_vars - arg.input_vars @@ -1235,12 +1243,6 @@ def __init__(self, op, subs, source, reduced_vars): self.source = source self.reduced_vars = reduced_vars - def _alpha_convert(self, alpha_subs): - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - op, subs, source, reduced_vars = super()._alpha_convert(alpha_subs) - reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars) - return op, subs, source, reduced_vars - def eager_subs(self, subs): subs = OrderedDict(subs) new_subs = [] @@ -1720,12 +1722,6 @@ def __init__(self, var, expr): self.var = var self.expr = expr - def _alpha_convert(self, alpha_subs): - alpha_subs = { - k: to_funsor(v, self.var.inputs[k]) for k, v in alpha_subs.items() - } - return super()._alpha_convert(alpha_subs) - @eager.register(Binary, GetitemOp, Lambda, (Funsor, Align)) def eager_getitem_lambda(op, lhs, rhs): diff --git a/test/test_alpha_conversion.py b/test/test_alpha_conversion.py index b9c2819f..19cfecb1 100644 --- a/test/test_alpha_conversion.py +++ b/test/test_alpha_conversion.py @@ -125,3 +125,11 @@ def test_sample_independent(): actual = Independent(f, "x", "i", "x_i") assert actual.sample("i") assert actual.sample("j", {"i": 2}) + + +def test_subs_already_bound(): + with reflect: + x = Variable("x", Real) + y1 = (2 * x)(x=3) + y2 = y1.arg(4) + assert y1.bound != y2.bound