Skip to content

Commit

Permalink
Test with Tapir (TuringLang#2289)
Browse files Browse the repository at this point in the history
* Test with Tapir

* Relax Tapir version bound

* Relax Tapir version bounds more

* Add test/test_utils/ad_utils.jl

* Change how Tapir is installed for tests

* Typo fix

* Turn Tapir's safe mode off

* Use standard AutoReverseDiff constructor

Co-authored-by: Hong Ge <[email protected]>

* Revert back to previous AutoReverseDiff constructor

* modify `setvarinfo`

* fix test error

* fix more error

* fix error

* fix error

* Exclude Tapir from AdvancedHMC tests

Co-authored-by: Will Tebbutt <[email protected]>

* Update ad_utils.jl (TuringLang#2313)

* Update test/test_utils/ad_utils.jl

* Move code around in ad_utils.jl

* Add a todo note

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Will Tebbutt <[email protected]>
  • Loading branch information
5 people authored Sep 3, 2024
1 parent a26ce11 commit f92c93f
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 17 deletions.
14 changes: 10 additions & 4 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))
function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
return getvarinfo(LogDensityProblemsAD.parent(f))
end

setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
return Accessors.@set f.= setvarinfo(f.ℓ, varinfo)
function setvarinfo(
f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
)
return LogDensityProblemsAD.ADgradient(
adtype, setvarinfo(LogDensityProblemsAD.parent(f), varinfo)
)
end

"""
Expand Down Expand Up @@ -120,7 +126,7 @@ function AbstractMCMC.step(
varinfo = DynamicPPL.link(varinfo, model)
end
end
f = setvarinfo(f, varinfo)
f = setvarinfo(f, varinfo, alg.adtype)

# Then just call `AdvancedHMC.step` with the right arguments.
if initial_state === nothing
Expand Down
5 changes: 4 additions & 1 deletion test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module InferenceTests

using ..Models: gdemo_d, gdemo_default
using ..NumericalTests: check_gdemo, check_numerical
import ..ADUtils
using Distributions: Bernoulli, Beta, InverseGamma, Normal
using Distributions: sample
import DynamicPPL
Expand All @@ -14,7 +15,9 @@ import ReverseDiff
using Test: @test, @test_throws, @testset
using Turing

@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
ADUtils.install_tapir && import Tapir

@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
Expand Down
11 changes: 9 additions & 2 deletions test/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module AbstractMCMCTests

import ..ADUtils
using AdvancedMH: AdvancedMH
using Distributions: sample
using Distributions.FillArrays: Zeros
Expand All @@ -15,14 +16,18 @@ using Test: @test, @test_throws, @testset
using Turing
using Turing.Inference: AdvancedHMC

ADUtils.install_tapir && import Tapir

function initialize_nuts(model::Turing.Model)
# Create a log-density function with an implementation of the
# gradient so we ensure that we're using the same AD backend as in Turing.
f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model))

# Link the varinfo.
f = Turing.Inference.setvarinfo(
f, DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model)
f,
DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model),
Turing.Inference.getADType(DynamicPPL.getcontext(LogDensityProblemsAD.parent(f))),
)

# Choose parameter dimensionality and initial parameter value
Expand Down Expand Up @@ -112,7 +117,9 @@ end

@testset "External samplers" begin
@testset "AdvancedHMC.jl" begin
# Try a few different AD backends.
# TODO(mhauru) The below tests fail with Tapir, see
# https://github.com/TuringLang/Turing.jl/pull/2289.
# Once that is fixed, this should say `for adtype in ADUtils.adbackends`.
@testset "adtype=$adtype" for adtype in [AutoForwardDiff(), AutoReverseDiff()]
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Need some functionality to initialize the sampler.
Expand Down
7 changes: 4 additions & 3 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GibbsTests

using ..Models: MoGtest_default, gdemo, gdemo_default
using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical
import ..ADUtils
using Distributions: InverseGamma, Normal
using Distributions: sample
using ForwardDiff: ForwardDiff
Expand All @@ -12,9 +13,9 @@ using Turing
using Turing: Inference
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess

@testset "Testing gibbs.jl with $adbackend" for adbackend in (
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
)
ADUtils.install_tapir && import Tapir

@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
@testset "gibbs constructor" begin
N = 500
s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend))
Expand Down
7 changes: 4 additions & 3 deletions test/mcmc/gibbs_conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GibbsConditionalTests

using ..Models: gdemo, gdemo_default
using ..NumericalTests: check_gdemo, check_numerical
import ..ADUtils
using Clustering: Clustering
using Distributions: Categorical, InverseGamma, Normal, sample
using ForwardDiff: ForwardDiff
Expand All @@ -14,9 +15,9 @@ using StatsFuns: StatsFuns
using Test: @test, @testset
using Turing

@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in (
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
)
ADUtils.install_tapir && import Tapir

@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends
Random.seed!(1000)
rng = StableRNG(123)

Expand Down
5 changes: 4 additions & 1 deletion test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ..Models: gdemo_default
using ..ADUtils: ADTypeCheckContext
#using ..Models: gdemo
using ..NumericalTests: check_gdemo, check_numerical
import ..ADUtils
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
import DynamicPPL
using DynamicPPL: Sampler
Expand All @@ -17,7 +18,9 @@ using StatsFuns: logistic
using Test: @test, @test_logs, @testset
using Turing

@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
ADUtils.install_tapir && import Tapir

@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends
# Set a seed
rng = StableRNG(123)
@testset "constrained bounded" begin
Expand Down
7 changes: 5 additions & 2 deletions test/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module SGHMCTests

using ..Models: gdemo_default
using ..NumericalTests: check_gdemo
import ..ADUtils
using Distributions: sample
import ForwardDiff
using LinearAlgebra: dot
Expand All @@ -10,7 +11,9 @@ using StableRNGs: StableRNG
using Test: @test, @testset
using Turing

@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
ADUtils.install_tapir && import Tapir

@testset "Testing sghmc.jl with $adbackend" for adbackend in ADUtils.adbackends
@testset "sghmc constructor" begin
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
@test alg isa SGHMC
Expand All @@ -36,7 +39,7 @@ using Turing
end
end

@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing sgld.jl with $adbackend" for adbackend in ADUtils.adbackends
@testset "sgld constructor" begin
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
@test alg isa SGLD
Expand Down
26 changes: 25 additions & 1 deletion test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ADUtils

using ForwardDiff: ForwardDiff
using Pkg: Pkg
using Random: Random
using ReverseDiff: ReverseDiff
using Test: Test
Expand All @@ -9,7 +10,10 @@ using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

export ADTypeCheckContext
export ADTypeCheckContext, adbackends

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Stuff for checking that the right AD backend is being used.

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
Expand Down Expand Up @@ -270,4 +274,24 @@ Test.@testset "ADTypeCheckContext" begin
end
end

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# List of AD backends to test.

"""
All the ADTypes on which we want to run the tests.
"""
adbackends = [
Turing.AutoForwardDiff(; chunksize=0), Turing.AutoReverseDiff(; compile=false)
]

# Tapir isn't supported for older Julia versions, hence the check.
install_tapir = isdefined(Turing, :AutoTapir)
if install_tapir
# TODO(mhauru) Is there a better way to install optional dependencies like this?
Pkg.add("Tapir")
using Tapir
push!(adbackends, Turing.AutoTapir(false))
push!(eltypes_by_adtype, Turing.AutoTapir => (Tapir.CoDual,))
end

end

0 comments on commit f92c93f

Please sign in to comment.