Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure alpha conversion always runs when a new variable is bound #414

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
02b5d9d
Ensure alpha conversion always runs when a new variable is bound
eb8680 Dec 25, 2020
6135a8b
add regression test
eb8680 Dec 28, 2020
cdf8ad0
further code organization change
eb8680 Dec 28, 2020
439d63e
dont strip name
eb8680 Dec 28, 2020
3fba6dd
simplify, remove hash() usage
eb8680 Dec 28, 2020
aa4c89c
simplify alpha_mangle
eb8680 Dec 28, 2020
b5da767
optimization to avoid some unnecessary deep recursion in Funsor._alpha_convert
eb8680 Dec 28, 2020
73d2d50
optimization to avoid unnecessary recursion in Contraction._alpha_convert
eb8680 Dec 28, 2020
5cc64d5
nit
eb8680 Dec 28, 2020
7ba9869
Merge branch 'master' into repeat-alpha-conversion
eb8680 Feb 17, 2021
610037b
fix test
eb8680 Feb 17, 2021
e642186
Merge branch 'master' into repeat-alpha-conversion
eb8680 Feb 21, 2021
c42834d
try forcing reflect in _alpha_convert
eb8680 Feb 21, 2021
b4bed63
centralize reflect application in alpha conversion
eb8680 Feb 21, 2021
9c80304
centralize alpha_convert further
eb8680 Mar 16, 2021
ccfdec5
remove obsolete alpha_convert methods
eb8680 Mar 16, 2021
e982f17
Merge branch 'master' into repeat-alpha-conversion
eb8680 Mar 18, 2021
9b6a60a
delete more _alpha_convert methods
eb8680 Mar 18, 2021
fa79950
lint
eb8680 Mar 18, 2021
d2ca611
lint
eb8680 Mar 18, 2021
6456184
remove outdated comment
eb8680 Mar 18, 2021
c4e8010
Merge branch 'master' into repeat-alpha-conversion
eb8680 Apr 9, 2021
cd4fade
Merge branch 'master' into repeat-alpha-conversion
eb8680 Apr 13, 2021
01bfd39
avoid renaming twice
eb8680 Apr 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
Subs,
Unary,
Variable,
to_funsor,
)
from funsor.typing import Variadic
from funsor.util import broadcast_shape, get_backend, quote
Expand Down Expand Up @@ -213,15 +212,6 @@ def align(self, names):
) # raise NotImplementedError("TODO align all terms")
return result

def _alpha_convert(self, alpha_subs):
    """
    Rename bound variables of this Contraction per ``alpha_subs``.

    Rebuilds ``reduced_vars`` with the renamed variables, then delegates
    term renaming to the base class and splices the new variable set back
    into the returned AST values.
    """
    # Apply the renaming to each reduced variable, preserving its output domain.
    renamed_vars = []
    for var in self.reduced_vars:
        replacement = alpha_subs.get(var.name, var)
        renamed_vars.append(to_funsor(replacement, var.output))
    renamed_vars = frozenset(renamed_vars)

    # Coerce substitution values to funsors typed by the bound-variable domains.
    funsor_subs = {
        name: to_funsor(value, self.bound[name])
        for name, value in alpha_subs.items()
    }

    # Base class renames the terms; discard its reduced_vars in favor of ours.
    red_op, bin_op, _, terms = super()._alpha_convert(funsor_subs)
    return red_op, bin_op, renamed_vars, terms


GaussianMixture = Contraction[
Union[ops.LogaddexpOp, NullOp],
Expand Down
18 changes: 0 additions & 18 deletions funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
Unary,
Variable,
_convert_reduced_vars,
substitute,
to_funsor,
)


Expand Down Expand Up @@ -65,22 +63,6 @@ def __init__(self, log_measure, integrand, reduced_vars):
self.integrand = integrand
self.reduced_vars = reduced_vars

def _alpha_convert(self, alpha_subs):
    """
    Rename bound variables of this Integrate per ``alpha_subs``.

    Renames the reduced variables and substitutes the new names into both
    the log measure and the integrand.
    """
    assert set(self.bound).issuperset(alpha_subs)

    # Build the renamed reduced-variable set, keeping each output domain.
    renamed = frozenset(
        Variable(alpha_subs.get(v.name, v.name), v.output)
        for v in self.reduced_vars
    )

    # Type each substitution value by the input domain of whichever
    # subexpression mentions it (integrand takes precedence).
    typed_subs = {}
    for name, value in alpha_subs.items():
        domain = self.integrand.inputs.get(name, self.log_measure.inputs.get(name))
        typed_subs[name] = to_funsor(value, domain)

    new_log_measure = substitute(self.log_measure, typed_subs)
    new_integrand = substitute(self.integrand, typed_subs)
    return new_log_measure, new_integrand, renamed


@normalize.register(Integrate, Funsor, Funsor, frozenset)
def normalize_integrate(log_measure, integrand, reduced_vars):
Expand Down
68 changes: 32 additions & 36 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,18 @@ def stop(x):
return env[expr]


def _alpha_mangle(expr):
def _alpha_mangle(bound_vars):
"""
Rename bound variables in expr to avoid conflict with any free variables.

FIXME this does not avoid conflict with other bound variables.
Returns substitution dictionary with mangled names for consumption by Funsor._alpha_convert.
"""
alpha_subs = {
name: interpreter.gensym(name + "__BOUND")
for name in expr.bound
if "__BOUND" not in name
return {
name: interpreter.gensym(name.split("__BOUND_")[0] + "__BOUND_")
for name in bound_vars
}
if not alpha_subs:
return expr

ast_values = instrument.debug_logged(expr._alpha_convert)(alpha_subs)
return reflect.interpret(type(expr), *ast_values)

_SKIP_ALPHA = False


@reflect.set_callable
Expand All @@ -142,6 +138,29 @@ def reflect(cls, *args, **kwargs):
result = super(FunsorMeta, cls_specific).__call__(*args)
result._ast_values = args

# alpha-convert eagerly upon binding any variable.
global _SKIP_ALPHA
if result.bound and not _SKIP_ALPHA:
alpha_subs = _alpha_mangle(result.bound)
try:
# optimization: don't perform alpha-conversion again
# when renaming subexpressions of result
_SKIP_ALPHA = True
alpha_mangled_args = reflect(result._alpha_convert)(alpha_subs)
finally:
_SKIP_ALPHA = False

# TODO eliminate code duplication below
# this is currently necessary because .bound is computed in __init__().
result = super(FunsorMeta, cls_specific).__call__(*alpha_mangled_args)
result._ast_values = alpha_mangled_args

# we also make the old cons cache_key point to the new mangled value.
# this guarantees that alpha-conversion only runs once for this expression.
cls._cons_cache[cache_key] = result
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line ensures that _alpha_mangle does not break cons-hashing.


cache_key = reflect.make_hash_key(cls, *alpha_mangled_args)

if instrument.PROFILE:
size, depth, width = _get_ast_stats(result)
instrument.COUNTERS["ast_size"][size] += 1
Expand All @@ -150,9 +169,6 @@ def reflect(cls, *args, **kwargs):
instrument.COUNTERS["funsor"][classname] += 1
instrument.COUNTERS[classname][width] += 1

# alpha-convert eagerly upon binding any variable
result = _alpha_mangle(result)

cls._cons_cache[cache_key] = result
return result

Expand Down Expand Up @@ -326,9 +342,9 @@ def _alpha_convert(self, alpha_subs):
Rename bound variables while preserving all free variables.
"""
# Substitute all funsor values.
# Subclasses must handle string conversion.
assert set(alpha_subs).issubset(self.bound)
return tuple(substitute(v, alpha_subs) for v in self._ast_values)
alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
return substitute(self._ast_values, alpha_subs)

def __call__(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -1095,14 +1111,6 @@ def __str__(self):
str(self.arg), self.op.__name__, ", ".join(rvars)
)

def _alpha_convert(self, alpha_subs):
    """
    Rename bound variables of this Reduce per ``alpha_subs``.

    Substitution values are typed by the corresponding input domains of
    ``self.arg``; the renamed variables are folded back into ``reduced_vars``.
    """
    typed_subs = {
        name: to_funsor(value, self.arg.inputs[name])
        for name, value in alpha_subs.items()
    }
    op, arg, reduced_vars = super()._alpha_convert(typed_subs)
    # Replace each reduced variable by its renamed counterpart, if any.
    renamed = frozenset(
        typed_subs.get(var.name, var) for var in reduced_vars
    )
    return op, arg, renamed


def _reduce_unrelated_vars(op, arg, reduced_vars):
factor_vars = reduced_vars - arg.input_vars
Expand Down Expand Up @@ -1235,12 +1243,6 @@ def __init__(self, op, subs, source, reduced_vars):
self.source = source
self.reduced_vars = reduced_vars

def _alpha_convert(self, alpha_subs):
    """
    Rename bound variables per ``alpha_subs``.

    Substitution values are typed by the bound-variable domains; the
    renamed variables are folded back into ``reduced_vars``.
    """
    typed_subs = {
        name: to_funsor(value, self.bound[name])
        for name, value in alpha_subs.items()
    }
    op, subs, source, reduced_vars = super()._alpha_convert(typed_subs)
    # Swap in the renamed variable for each entry of reduced_vars.
    renamed = frozenset(
        typed_subs.get(var.name, var) for var in reduced_vars
    )
    return op, subs, source, renamed

def eager_subs(self, subs):
subs = OrderedDict(subs)
new_subs = []
Expand Down Expand Up @@ -1720,12 +1722,6 @@ def __init__(self, var, expr):
self.var = var
self.expr = expr

def _alpha_convert(self, alpha_subs):
    """
    Rename the bound variable of this Lambda per ``alpha_subs``.

    Substitution values are typed by the input domains of ``self.var``
    before delegating to the base-class renaming.
    """
    typed_subs = {}
    for name, value in alpha_subs.items():
        typed_subs[name] = to_funsor(value, self.var.inputs[name])
    return super()._alpha_convert(typed_subs)


@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align))
def eager_getitem_lambda(op, lhs, rhs):
Expand Down
8 changes: 8 additions & 0 deletions test/test_alpha_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,11 @@ def test_sample_independent():
actual = Independent(f, "x", "i", "x_i")
assert actual.sample("i")
assert actual.sample("j", {"i": 2})


def test_subs_already_bound():
    # Regression test: substituting into an expression must re-run alpha
    # conversion, so the bound-variable names of the two results differ.
    # NOTE(review): assumes the final assert sits inside the `with reflect`
    # block — confirm against the upstream diff.
    with reflect:
        x = Variable("x", Real)
        y1 = (2 * x)(x=3)
        y2 = y1.arg(4)
        assert y1.bound != y2.bound