Skip to content

Commit

Permalink
Move predict from Turing (#716)
Browse files Browse the repository at this point in the history
* move `predict` from Turing

* minor fixes

* Update test/ext/DynamicPPLMCMCChainsExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix test error by discard burn-in's

* add some comments

* fix test error

* Update test/ext/DynamicPPLMCMCChainsExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* refactor the code; add `predict` in Turing that takes array of varinfos

* Update model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* stop using `PredictiveSample` type

* use NamedTuple

* remove predict with varinfos function

* update implementation and tests; no longer using AdvancedHMC

* try fixing naming conflict

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent d0cfaaf commit 6657441
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
AbstractPPL = "0.10.1"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
142 changes: 142 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,148 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.
The `model` passed to `predict` is often different from the one used to generate `chain`.
Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
data points), while the model you pass to `predict` may mark these same variables as missing
or unobserved. Calling `predict` then leverages the previously inferred parameter values to
simulate what new, unobserved data might look like, given your posterior beliefs.
For each parameter configuration in `chain`:
1. All random variables present in `chain` are fixed to their sampled values.
2. Any variables not included in `chain` are sampled from their prior distributions.
If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
# Examples
```jldoctest
using AbstractMCMC, Distributions, DynamicPPL, Random
@model function linear_reg(x, y, σ = 0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end
# Generate synthetic chain using known ground truth parameter
ground_truth_β = 2.0
# Create chain of samples from a normal distribution centered on ground truth
β_chain = MCMCChains.Chains(
rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
)
# Generate predictions for two test points
xs_test = [10.1, 10.2]
m_train = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.AbstractPPL.predict(
Random.default_rng(), m_train, β_chain
)
ys_pred = vec(mean(Array(predictions); dims=1))
# Check if predictions match expected values within tolerance
(
isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
)
# output
(true, true)
```
"""
function DynamicPPL.predict(
rng::DynamicPPL.Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
varinfo = DynamicPPL.VarInfo(model)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
varname_vals = mapreduce(
collect,
vcat,
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
end

chain_result = reduce(
MCMCChains.chainscat,
[
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)
else
filter(
k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
names(chain_result, :parameters),
)
end
return chain_result[parameter_names]
end

function _predictive_samples_to_arrays(predictive_samples)
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

sample_dicts = map(predictive_samples) do sample
varname_value_pairs = sample.varname_and_values
varnames = map(first, varname_value_pairs)
values = map(last, varname_value_pairs)
for varname in varnames
push!(variable_names_set, varname)
end

return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
end

variable_names = collect(variable_names_set)
variable_values = [
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
key in variable_names
]

return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
variable_names_symbols = map(Symbol, variable_names)

internal_parameters = [:lp]
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

parameter_names = [variable_names_symbols; internal_parameters]
parameter_values = hcat(variable_values, log_probabilities)
parameter_values = MCMCChains.concretize(parameter_values)

return MCMCChains.Chains(
parameter_values, parameter_names, (internals=internal_parameters,)
)
end

"""
returned(model::Model, chain::MCMCChains.Chains)
Expand Down
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using AbstractPPL
using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedDict
using OrderedCollections: OrderedCollections, OrderedDict

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes
Expand Down Expand Up @@ -40,6 +40,8 @@ import Base:
keys,
haskey

import AbstractPPL: predict

# VarInfo
export AbstractVarInfo,
VarInfo,
Expand Down
20 changes: 20 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,26 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end
end

"""
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
)
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi, SampleFromPrior())
return vi
end
end

"""
returned(model::Model, parameters::NamedTuple)
returned(model::Model, values, keys)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
AbstractPPL = "0.10.1"
Accessors = "0.1"
Bijectors = "0.15.1"
Combinatorics = "1"
Expand Down
8 changes: 4 additions & 4 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,27 +202,27 @@ end
s, m = retval.s, retval.m

# Keword approach.
model_fixed = fix(model; s=s)
model_fixed = DynamicPPL.fix(model; s=s)
@test model_fixed().s == s
@test model_fixed().m != m
# A fixed variable should not contribute at all to the logjoint.
# Assuming `condition` is correctly implemented, the following should hold.
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

# Positional approach.
model_fixed = fix(model, (; s))
model_fixed = DynamicPPL.fix(model, (; s))
@test model_fixed().s == s
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

# Pairs approach.
model_fixed = fix(model, @varname(s) => s)
model_fixed = DynamicPPL.fix(model, @varname(s) => s)
@test model_fixed().s == s
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

# Dictionary approach.
model_fixed = fix(model, Dict(@varname(s) => s))
model_fixed = DynamicPPL.fix(model, Dict(@varname(s) => s))
@test model_fixed().s == s
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
Expand Down
2 changes: 2 additions & 0 deletions test/ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
@test size(chain_generated) == (1000, 1)
@test mean(chain_generated) 0 atol = 0.1
end

# test for `predict` is in `test/model.jl`
105 changes: 105 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,4 +429,109 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
@test getlogp(varinfo_linked) getlogp(varinfo_linked_result)
end
end

@testset "predict" begin
@testset "with MCMCChains.Chains" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal* x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal.* x, σ^2 * I)
end

ground_truth_β = 2
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])

xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01

# Ensure that `rng` is respected
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01

# Multiple chains
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
end
end

@testset "with AbstractVector{<:AbstractVarInfo}" begin
@model function linear_reg(x, y, σ=0.1)
β ~ Normal(1, 1)
for i in eachindex(y)
y[i] ~ Normal* x[i], σ)
end
end

ground_truth_β = 2.0
# the data will be ignored, as we are generating samples from the prior
xs_train = 1:0.1:10
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
m_lin_reg = linear_reg(xs_train, ys_train)
chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000]

# chain is generated from the prior
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain)

@test size(predicted_vis) == size(chain)
@test Set(keys(predicted_vis[1])) ==
Set([@varname(β), @varname(y[1]), @varname(y[2])])
# because β samples are from the prior, the std will be larger
@test mean([
predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
]) 1.0 * xs_test[1] rtol = 0.1
@test mean([
predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis)
]) 1.0 * xs_test[2] rtol = 0.1
end
end
end

0 comments on commit 6657441

Please sign in to comment.