Skip to content

Commit

Permalink
Merge pull request #261 from ErikQQY/qqy/interpsol_for_all
Browse files Browse the repository at this point in the history
Use solution object in all solvers
  • Loading branch information
ChrisRackauckas authored Dec 21, 2024
2 parents fec7b7e + 25f6c2b commit f06cbf7
Show file tree
Hide file tree
Showing 28 changed files with 351 additions and 198 deletions.
18 changes: 2 additions & 16 deletions benchmark/simple_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ function bc_pendulum!(residual, u, p, t)
return nothing
end

function bc_pendulum_mirk!(residual, u, p, t)
residual[1] = u[:, end ÷ 2][1] + π / 2
residual[2] = u[:, end][1] - π / 2
return nothing
end

function simple_pendulum(u, p, t)
g, L, θ, dθ = 9.81, 1.0, u[1], u[2]
return [dθ, -(g / L) * sin(θ)]
Expand All @@ -34,16 +28,8 @@ function bc_pendulum(u, p, t)
return [u((t0 + t1) / 2)[1] + π / 2, u(t1)[1] - π / 2]
end

function bc_pendulum_mirk(u, p, t)
return [u[:, end ÷ 2][1] + π / 2, u[:, end][1] - π / 2]
end

const prob_oop = BVProblem{false}(simple_pendulum, bc_pendulum, [π / 2, π / 2], tspan)
const prob_iip = BVProblem{true}(simple_pendulum!, bc_pendulum!, [π / 2, π / 2], tspan)
const prob_oop_mirk = BVProblem{false}(
simple_pendulum, bc_pendulum_mirk, [π / 2, π / 2], tspan)
const prob_iip_mirk = BVProblem{true}(
simple_pendulum!, bc_pendulum_mirk!, [π / 2, π / 2], tspan)

end

Expand Down Expand Up @@ -77,7 +63,7 @@ function create_simple_pendulum_benchmark()
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
iip_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip_mirk, $alg(), dt = 0.05)
$SimplePendulumBenchmark.prob_iip, $alg(), dt = 0.05)
end
end

Expand All @@ -102,7 +88,7 @@ function create_simple_pendulum_benchmark()
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
oop_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop_mirk, $alg(), dt = 0.05)
$SimplePendulumBenchmark.prob_oop, $alg(), dt = 0.05)
end
end

Expand Down
13 changes: 10 additions & 3 deletions lib/BoundaryValueDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# Intermidiate solution evaluation
@concrete struct EvalSol{iip}
@concrete struct EvalSol{C}
u
t
alg
k_discrete
cache::C
end

Base.size(e::EvalSol) = (size(e.u[1])..., length(e.u))
Base.size(e::EvalSol, i) = size(e)[i]

Base.axes(e::EvalSol) = Base.OneTo.(size(e))
Base.axes(e::EvalSol, d::Int) = Base.OneTo.(size(e)[d])

Base.getindex(e::EvalSol, args...) = Base.getindex(VectorOfArray(e.u), args...)
8 changes: 5 additions & 3 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ end
__vec_f(u, p, t, f, u_size) = vec(f(reshape(u, u_size), p, t))

function __vec_bc!(resid, sol, p, t, bc!, resid_size, u_size)
bc!(reshape(resid, resid_size), __restructure_sol(sol, u_size), p, t)
bc!(reshape(resid, resid_size), sol, p, t)
return nothing
end

Expand All @@ -232,17 +232,19 @@ function __vec_bc!(resid, sol, p, bc!, resid_size, u_size)
return nothing
end

__vec_bc(sol, p, t, bc, u_size) = vec(bc(__restructure_sol(sol, u_size), p, t))
__vec_bc(sol, p, t, bc, u_size) = vec(bc(sol, p, t))
__vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))

@inline __get_non_sparse_ad(ad::AbstractADType) = ad
@inline __get_non_sparse_ad(ad::AutoSparse) = ADTypes.dense_ad(ad)

# Restructure Solution
function __restructure_sol(sol::AbstractVectorOfArray, u_size)
(size(first(sol)) == u_size) && return sol
return VectorOfArray(map(Base.Fix2(reshape, u_size), sol))
end
function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
function __restructure_sol(sol::AbstractArray{<:AbstractArray}, u_size)
(size(first(sol)) == u_size) && return sol
return map(Base.Fix2(reshape, u_size), sol)
end

Expand Down
25 changes: 13 additions & 12 deletions lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg, EvalSol,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, MaybeDiffCache, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__restructure_sol, __get_bcresid_prototype, __similar,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
recursive_flatten_twopoint!, __internal_nlsolve_problem,
MaybeDiffCache, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix
Expand All @@ -33,7 +34,7 @@ import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scal
import ConcreteStructs: @concrete
import DiffEqBase: solve
import FastClosures: @closure
import ForwardDiff: ForwardDiff, pickchunksize
import ForwardDiff: ForwardDiff, pickchunksize, Dual
import Logging
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
Expand All @@ -60,11 +61,11 @@ include("sparse_jacobians.jl")
f1 = (u, p, t) -> [u[2], 0]

function bc1!(residual, u, p, t)
residual[1] = u[:, 1][1] - 5
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 5
residual[2] = u(5.0)[1]
end

bc1 = (u, p, t) -> [u[:, 1][1] - 5, u[:, end][1]]
bc1 = (u, p, t) -> [u(0.0)[1] - 5, u(5.0)[1]]

bc1_a! = (residual, ua, p) -> (residual[1] = ua[1] - 5)
bc1_b! = (residual, ub, p) -> (residual[1] = ub[1])
Expand Down Expand Up @@ -143,14 +144,14 @@ include("sparse_jacobians.jl")
f1_nlls = (u, p, t) -> [u[2], -u[1]]

bc1_nlls! = (resid, sol, p, t) -> begin
solₜ₁ = sol[:, 1]
solₜ₂ = sol[:, end]
solₜ₁ = sol(0.0)
solₜ₂ = sol(100.0)
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end
bc1_nlls = (sol, p, t) -> [sol[:, 1][1], sol[:, end][1] - 1, sol[:, end][2] + 1.729109]
bc1_nlls = (sol, p, t) -> [sol(0.0)[1], sol(100.0)[1] - 1, sol(100.0)[2] + 1.729109]

bc1_nlls_a! = (resid, ua, p) -> (resid[1] = ua[1])
bc1_nlls_b! = (resid, ub, p) -> (resid[1] = ub[1] - 1;
Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,10 @@ function apply_q_prime(τ, h, coeffs)
return sum(i * coeffs[i] ** h)^(i - 1) for i in axes(coeffs, 1))
end

function eval_q(y_i, τ, h, A, K)
function eval_q(y_i::AbstractArray{T}, τ, h, A, K) where {T}
M = size(K, 1)
q = zeros(M)
q′ = zeros(M)
q = zeros(T, M)
q′ = zeros(T, M)
for i in 1:M
ki = @view K[i, :]
coeffs = get_q_coeffs(A, ki, h)
Expand Down
40 changes: 22 additions & 18 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ end

function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}, abstol,
adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
Expand Down Expand Up @@ -402,9 +402,11 @@ end

# Constructing the Nonlinear Problem
function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{iip}},
y::AbstractVector) where {iip}
y::AbstractVector, y₀::AbstractVectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol(__restructure_sol(y₀.u, cache.in_size), cache.mesh, cache)

loss_bc = if iip
@closure (du, u, p) -> __firk_loss_bc!(
du, u, p, pt, cache.bc, cache.y, cache.mesh, cache)
Expand All @@ -422,9 +424,10 @@ function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpan

loss = if iip
@closure (du, u, p) -> __firk_loss!(
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache)
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol)
else
@closure (u, p) -> __firk_loss(u, p, cache.y, pt, cache.bc, cache.mesh, cache)
@closure (u, p) -> __firk_loss(
u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol)
end

return __construct_nlproblem(cache, y, loss_bc, loss_collocation, loss, pt)
Expand Down Expand Up @@ -658,19 +661,19 @@ function __construct_nlproblem(
return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end

@views function __firk_loss!(
resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, cache) where {BC}
@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC,
residual, mesh, cache, eval_sol) where {BC}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
soly_ = VectorOfArray(y_)
eval_bc_residual!(resids[1], pt, bc!, soly_, p, mesh)
Φ!(resids[2:end], cache, y_, u, p)
eval_sol.u[1:end] .= y_
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
end

@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache) where {BC1, BC2}
residual, mesh, cache, _) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resids = [get_tmp(r, u) for r in residual]
Expand All @@ -682,16 +685,17 @@ end
return nothing
end

@views function __firk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache) where {BC}
@views function __firk_loss(
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bc = eval_bc_residual(pt, bc, soly_, p, mesh)
eval_sol.u[1:end] .= y_
resid_bc = eval_bc_residual(pt, bc, eval_sol, p, mesh)
resid_co = Φ(cache, y_, u, p)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end

@views function __firk_loss(
u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2}, mesh, cache) where {BC1, BC2}
@views function __firk_loss(u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2},
mesh, cache, _) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bca, resid_bcb = eval_bc_residual(pt, bc, soly_, p, mesh)
Expand All @@ -702,16 +706,16 @@ end
@views function __firk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh,
cache::Union{FIRKCacheNested, FIRKCacheExpand}) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
eval_sol = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
eval_bc_residual!(resid, pt, bc!, eval_sol, p, mesh)
return nothing
end

@views function __firk_loss_bc(u, p, pt, bc!::BC, y, mesh,
cache::Union{FIRKCacheNested, FIRKCacheExpand}) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
eval_sol = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
return eval_bc_residual(pt, bc!, eval_sol, p, mesh)
end

@views function __firk_loss_collocation!(resid, u, p, y, mesh, residual, cache)
Expand Down
94 changes: 94 additions & 0 deletions lib/BoundaryValueDiffEqFIRK/src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,97 @@ end
cache.mesh, u, cache)
@inline __build_interpolation(cache::FIRKCacheNested, u::AbstractVector) = FIRKNestedInterpolation(
cache.mesh, u, cache)

# Intermidiate solution for evaluating boundry conditions
# basically simplified version of the interpolation for FIRK
# Expanded FIRK
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheExpand}
(; t, u, cache) = s
(; f, alg, ITU, p) = cache
(; q_coeff) = ITU
stage = alg_stage(alg)
# Quick handle for the case where tval is at the boundary
(tval == t[1]) && return first(u)
(tval == t[end]) && return last(u)
K = __similar(first(u), length(first(u)), stage)
j = interval(t, tval)
ctr_y = (j - 1) * (stage + 1) + 1

yᵢ = u[ctr_y]
yᵢ₊₁ = u[ctr_y + stage + 1]

if SciMLBase.isinplace(cache.prob)
dyᵢ = similar(yᵢ)
dyᵢ₊₁ = similar(yᵢ₊₁)

f(dyᵢ, yᵢ, p, t[j])
f(dyᵢ₊₁, yᵢ₊₁, p, t[j + 1])
else
dyᵢ = f(yᵢ, p, t[j])
dyᵢ₊₁ = f(yᵢ₊₁, p, t[j + 1])
end

# Load interpolation residual
for jj in 1:stage
K[:, jj] = u[ctr_y + jj]
end
h = t[j + 1] - t[j]
τ = tval - t[j]

z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)

z = similar(yᵢ)

S_interpolate!(z, τ, S_coeffs)
return z
end

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

# Nested FIRK
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheNested}
(; t, u, cache) = s
(; f, nest_prob, nest_tol, alg, mesh_dt, p, ITU) = cache
(; q_coeff) = ITU
stage = alg_stage(alg)
# Quick handle for the case where tval is at the boundary
(tval == t[1]) && return first(u)
(tval == t[end]) && return last(u)
j = interval(t, tval)
h = t[j + 1] - t[j]
τ = tval - t[j]

nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, alg.nlsolve)
nestprob_p = zeros(cache.M + 2)

yᵢ = u[j]
yᵢ₊₁ = u[j + 1]

if SciMLBase.isinplace(cache.prob)
dyᵢ = similar(yᵢ)
dyᵢ₊₁ = similar(yᵢ₊₁)

f(dyᵢ, yᵢ, p, t[j])
f(dyᵢ₊₁, yᵢ₊₁, p, t[j + 1])
else
dyᵢ = f(yᵢ, p, t[j])
dyᵢ₊₁ = f(yᵢ₊₁, p, t[j + 1])
end

nestprob_p[1] = t[j]
nestprob_p[2] = mesh_dt[j]
nestprob_p[3:end] .= nodual_value(yᵢ)

_nestprob = remake(nest_prob, p = nestprob_p)
nestsol = __solve(_nestprob, nest_nlsolve_alg; abstol = nest_tol)
K = nestsol.u

z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
z = similar(yᵢ)
S_interpolate!(z, τ, S_coeffs)
return z
end
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqFIRK/test/expanded/ensemble_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
end

function bc!(residual, u, p, t)
residual[1] = u[:, 1][1] - 1.0
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 1.0
residual[2] = u(1.0)[1]
end

prob_func(prob, i, repeat) = remake(prob, p = [rand()])
Expand Down
Loading

0 comments on commit f06cbf7

Please sign in to comment.