diff --git a/.gitignore b/.gitignore index 559b95aa..4d4f66e4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ wip /.benchmarkci /benchmark/*.json *.json -*.json.tmp \ No newline at end of file +*.json.tmp +*.pdf diff --git a/Project.toml b/Project.toml index 4da652ec..eefb3c31 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" @@ -22,6 +23,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" @@ -40,7 +42,7 @@ Aqua = "0.7" ArrayInterface = "7" BandedMatrices = "1" ConcreteStructs = "0.2" -DiffEqBase = "6.135" +DiffEqBase = "6.138" ForwardDiff = "0.10" LinearAlgebra = "1.9" LinearSolve = "2" @@ -56,6 +58,7 @@ SciMLBase = "2.5" Setfield = "1" SparseArrays = "1.9" SparseDiffTools = "2.9" +Tricks = "0.1" TruncatedStacktraces = "1" UnPack = "1" julia = "1.9" diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 0113cc87..ef62dc20 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -5,7 +5,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat @recompile_invalidations begin using ADTypes, Adapt, BandedMatrices, DiffEqBase, ForwardDiff, LinearAlgebra, NonlinearSolve, PreallocationTools, Preferences, RecursiveArrayTools, Reexport, - SciMLBase, Setfield, SparseArrays, SparseDiffTools + SciMLBase, Setfield, SparseArrays, SparseDiffTools, Tricks import ADTypes: AbstractADType import ArrayInterface: matrix_colors, diff --git a/src/algorithms.jl b/src/algorithms.jl index c735088c..01226802 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -62,7 +62,8 @@ end """ MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(), - grid_coarsening = true, jac_alg = BVPJacobianAlgorithm()) + grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(), + auto_static_nodes::Val = Val(false)) Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP. Significantly more stable than Single Shooting. @@ -98,11 +99,18 @@ Significantly more stable than Single Shooting. of shooting points. For example, if `nshoots = 10` and `grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`. +## Experimental Features + + - `auto_static_nodes`: Automatically detect the timepoints used in the boundary condition + and use a faster version of the algorithm! This particular keyword argument should be + considered experimental and should be used with care! (Note that we ignore + `grid_coarsening` if this is set to `Val(true)`. We plan to support this in the future.) + !!! note For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm` must be provided. """ -@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm} +@concrete struct MultipleShooting{S, J <: BVPJacobianAlgorithm} ode_alg nlsolve jac_alg::J @@ -110,9 +118,9 @@ Significantly more stable than Single Shooting. grid_coarsening end -function concretize_jacobian_algorithm(alg::MultipleShooting, prob) +function concretize_jacobian_algorithm(alg::MultipleShooting{S}, prob) where {S} jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg) - return MultipleShooting(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots, + return MultipleShooting{S}(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots, alg.grid_coarsening) end @@ -121,17 +129,29 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int) alg.grid_coarsening) end +function __without_static_nodes(ms::MultipleShooting{S}) where {S} + return MultipleShooting{false}(ms.ode_alg, ms.nlsolve, ms.jac_alg, ms.nshoots, + ms.grid_coarsening) +end + function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(), - grid_coarsening = true, jac_alg = BVPJacobianAlgorithm()) - @assert grid_coarsening isa Bool || grid_coarsening isa Function || - grid_coarsening isa AbstractVector{<:Integer} || - grid_coarsening isa NTuple{N, <:Integer} where {N} + grid_coarsening = missing, jac_alg = BVPJacobianAlgorithm(), + auto_static_nodes::Val{S} = Val(false)) where {S} + @assert S isa Bool "`auto_static_nodes` must be either `Val(true)` or `Val(false)`." + if S + @assert grid_coarsening === missing||(grid_coarsening isa Bool && !grid_coarsening) "`auto_static_nodes` doesn't support grid_coarsening." + else + grid_coarsening === missing && (grid_coarsening = false) + @assert grid_coarsening isa Bool || grid_coarsening isa Function || + grid_coarsening isa AbstractVector{<:Integer} || + grid_coarsening isa NTuple{N, <:Integer} where {N} + end grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...)) if grid_coarsening isa AbstractVector sort!(grid_coarsening; rev = true) @assert all(grid_coarsening .> 0) && 1 ∉ grid_coarsening end - return MultipleShooting(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening) + return MultipleShooting{S}(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening) end for order in (2, 3, 4, 5, 6) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 45857e73..468484ca 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -35,7 +35,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, abstol = 1e-3, adaptive = true, kwargs...) @set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg) iip = isinplace(prob) + _, T, M, n, X = __extract_problem_details(prob; dt, check_positive_dt = true) + # NOTE: Assumes the user provided initial guess is on a uniform mesh + mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1)) + + mesh_dt = diff(mesh) + chunksize = pickchunksize(M * (n + 1)) __alloc = x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg) @@ -43,10 +49,6 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, fᵢ_cache = __alloc(similar(X)) fᵢ₂_cache = vec(similar(X)) - # NOTE: Assumes the user provided initial guess is on a uniform mesh - mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1)) - mesh_dt = diff(mesh) - defect_threshold = T(0.1) # TODO: Allow user to specify these MxNsub = 3000 # TODO: Allow user to specify these @@ -100,7 +102,9 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, vecf, vecbc end - return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob, + prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob + + return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages, resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...)) diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index 9326c3db..a7db9293 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -1,4 +1,86 @@ -function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), +function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs = (;), + nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) + # For TwoPointBVPs there is nothing to do. Forward to general multiple shooting + prob.problem_type isa TwoPointBVProblem && + return __solve_internal(prob, __without_static_nodes(_alg); odesolve_kwargs, + nlsolve_kwargs, ensemblealg, verbose, kwargs...) + + ig, T, N, Nig, u0 = __extract_problem_details(prob; dt = 0.1) + + if _unwrap_val(ig) && prob.u0 isa AbstractVector + if verbose + @warn "Static Nodes for Multiple-Shooting is not supported when Vector of \ + initial guesses are provided. Falling back to using the generic method!" + end + return __solve_internal(prob, __without_static_nodes(_alg); odesolve_kwargs, + nlsolve_kwargs, ensemblealg, verbose, kwargs...) + end + + has_initial_guess = _unwrap_val(ig) + + bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0) + iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0) + + # Extract the time-points used in BC + _prob = ODEProblem{iip}(prob.f, prob.u0, prob.tspan, prob.p) + _fake_ode_sol = __construct_fake_ode_solution(_prob, _alg.ode_alg) + if iip + bc(bcresid_prototype, _fake_ode_sol, prob.p, _fake_ode_sol.sol.t) + else + bc(_fake_ode_sol, prob.p, _fake_ode_sol.sol.t) + end + __finalize_nodes!(_fake_ode_sol) + + __alg = concretize_jacobian_algorithm(_alg, prob) + alg = if has_initial_guess && Nig != __alg.nshoots + verbose && + @warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig)`" + update_nshoots(__alg, Nig) + else + __alg + end + nshoots = alg.nshoots + M = length(bcresid_prototype) + + internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true) + + function solve_internal_odes!(resid_nodes::T1, us::T2, p::T3, cur_nshoot::Int, + nodes::T4, odecache::C) where {T1, T2, T3, T4, C} + return __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot, + odecache, nodes, u0_size, N, ensemblealg) + end + + ode_cache_loss_fn = __multiple_shooting_init_odecache(ensemblealg, prob, + alg.ode_alg, u0, nshoots; internal_ode_kwargs...) + + nodes = typeof(first(tspan))[] + u_at_nodes = __multiple_shooting_initialize!(nodes, prob, alg, ig, nshoots, + ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs..., + static_nodes = _fake_ode_sol.nodes) + + __solve_nlproblem!(prob.problem_type, alg, bcresid_prototype, u_at_nodes, nodes, + nshoots, M, N, prod(resid_size), solve_internal_odes!, bc, prob, prob.f, + u0_size, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; verbose, + kwargs..., nlsolve_kwargs...) + + if prob.problem_type isa TwoPointBVProblem + diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.diffmode) + else + diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.bc_diffmode) + end + shooting_alg = Shooting(alg.ode_alg, alg.nlsolve, + BVPJacobianAlgorithm(diffmode_shooting)) + + single_shooting_prob = remake(prob; u0 = reshape(@view(u_at_nodes[1:N]), u0_size)) + return __solve(single_shooting_prob, shooting_alg; odesolve_kwargs, nlsolve_kwargs, + verbose, kwargs...) +end + +function __solve(prob::BVProblem, _alg::MultipleShooting{false}; kwargs...) + return __solve_internal(prob, _alg; kwargs...) +end + +function __solve_internal(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) @unpack f, tspan = prob @@ -49,8 +131,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs...) else u_at_nodes = __multiple_shooting_initialize!(nodes, u_at_nodes, prob, alg, - cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn; kwargs..., verbose, - odesolve_kwargs...) + cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn, u0; kwargs..., + verbose, odesolve_kwargs...) end if prob.problem_type isa TwoPointBVProblem @@ -130,10 +212,71 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_ return nothing end -function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_prototype, - u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, resid_len::Int, - solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, ode_cache_loss_fn, - ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S} +function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting{true}, + bcresid_prototype, u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, + resid_len::Int, solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, + ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S} + if __any_sparse_ad(alg.jac_alg) + J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type, + bcresid_prototype, u0, N, cur_nshoot) + end + resid_prototype = vcat(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N)) + + __resid_nodes = resid_prototype[(end - cur_nshoot * N + 1):end] + resid_nodes = __maybe_allocate_diffcache(__resid_nodes, + pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode) + + loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot, + nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan, + alg.ode_alg, u0, ode_cache_loss_fn) + + # ODE Part + sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ? + __sparsity_detection_alg(J_proto) : NoSparsityDetection() + ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode, + nothing, similar(u_at_nodes, cur_nshoot * N), u_at_nodes) + ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob, + ode_jac_cache, alg.jac_alg.nonbc_diffmode, alg.ode_alg, cur_nshoot, u0; + internal_ode_kwargs...) + + # BC Part + sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ? + SymbolicsSparsityDetection() : NoSparsityDetection() + bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode, + sd_bc, nothing, similar(bcresid_prototype), u_at_nodes) + ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob, + bc_jac_cache, alg.jac_alg.bc_diffmode, alg.ode_alg, cur_nshoot, u0; + internal_ode_kwargs...) + + jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache)) + + # Define the functions now + ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes, + ode_cache_ode_jac_fn) + bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc_static_node!(du, u, prob.p, + cur_nshoot, nodes, + prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0, + ode_cache_bc_jac_fn) + + jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p, + similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache, + ode_fn, bc_fn, alg, N, M) + + loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn, + jac_prototype) + + # NOTE: u_at_nodes is updated inplace + nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!, + u_at_nodes, prob.p) + __solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true) + + return nothing +end + +function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting{false}, + bcresid_prototype, u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, + resid_len::Int, solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, + ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S} if __any_sparse_ad(alg.jac_alg) J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type, bcresid_prototype, u0, N, cur_nshoot) @@ -333,6 +476,29 @@ end return nothing end +@views function __multiple_shooting_mpoint_loss_bc_static_node!(resid_bc, us, p, + cur_nshoots::Int, nodes, prob, solve_internal_odes!::S, N, f::F, bc::BC, u0_size, + tspan, ode_alg, u0, ode_cache) where {S, F, BC} + iip = isinplace(prob) + + # NOTE: We placed the nodes at the points `bc` is evaluated so we don't need to + # recompute the solution + _ts = nodes + _us = [reshape(us[((i - 1) * prod(u0_size) + 1):(i * prod(u0_size))], u0_size) + for i in eachindex(_ts)] + + odeprob = ODEProblem{iip}(f, u0, tspan, p) + total_solution = SciMLBase.build_solution(odeprob, ode_alg, _ts, _us) + + if iip + eval_bc_residual!(resid_bc, StandardBVProblem(), bc, total_solution, p) + else + resid_bc .= eval_bc_residual(StandardBVProblem(), bc, total_solution, p) + end + + return nothing +end + @views function __multiple_shooting_mpoint_loss!(resid, us, p, cur_nshoots::Int, nodes, prob, solve_internal_odes!::S, resid_len, N, f::F, bc::BC, u0_size, tspan, ode_alg, u0, ode_cache) where {S, F, BC} @@ -362,21 +528,35 @@ end resize!(nodes, nshoots + 1) nodes .= range(tspan[1], tspan[2]; length = nshoots + 1) - N = length(first(u0)) - u_at_nodes = similar(first(u0), (nshoots + 1) * N) - recursive_flatten!(u_at_nodes, u0) + # NOTE: We don't check `u0 isa Function` since `u0` in-principle can be a callable + # struct + u0_ = u0 isa AbstractArray ? u0 : [__initial_guess(u0, prob.p, t) for t in nodes] + + N = length(first(u0_)) + u_at_nodes = similar(first(u0_), (nshoots + 1) * N) + recursive_flatten!(u_at_nodes, u0_) return u_at_nodes end # No initial guess @views function __multiple_shooting_initialize!(nodes, prob, alg::MultipleShooting, - ::Val{false}, nshoots::Int, odecache_; verbose, kwargs...) + ::Val{false}, nshoots::Int, odecache_; verbose, static_nodes = nothing, kwargs...) @unpack f, u0, tspan, p = prob @unpack ode_alg = alg resize!(nodes, nshoots + 1) nodes .= range(tspan[1], tspan[2]; length = nshoots + 1) + + if static_nodes !== nothing + idx = 1 + for snode in static_nodes + sidx = searchsortedfirst(nodes[idx:end], snode) + nodes[idx + sidx - 1] = snode + idx = sidx + 1 + end + end + N = length(u0) # Ensures type stability in case the parameters are dual numbers @@ -401,7 +581,8 @@ end end else @warn "Initialization using odesolve failed. Initializing using 0s. It is \ - recommended to provide an `initial_guess` in this case." + recommended to provide an initial guess function via \ + `u0 = (p, t)` or `u0 = (t)` in this case." fill!(u_at_nodes, 0) end @@ -410,16 +591,16 @@ end # Grid coarsening @views function __multiple_shooting_initialize!(nodes, u_at_nodes_prev, prob, alg, - nshoots, old_nshoots, ig, odecache_; kwargs...) - @unpack f, u0, tspan, p = prob + nshoots, old_nshoots, ig, odecache_, u0; kwargs...) + @unpack f, tspan, p = prob prev_nodes = copy(nodes) odecache = odecache_ isa Vector ? first(odecache_) : odecache_ resize!(nodes, nshoots + 1) nodes .= range(tspan[1], tspan[2]; length = nshoots + 1) - N = _unwrap_val(ig) ? length(first(u0)) : length(u0) + N = length(u0) - u_at_nodes = similar(_unwrap_val(ig) ? first(u0) : u0, N + nshoots * N) + u_at_nodes = similar(u0, N + nshoots * N) u_at_nodes[1:N] .= u_at_nodes_prev[1:N] u_at_nodes[(end - N + 1):end] .= u_at_nodes_prev[(end - N + 1):end] diff --git a/src/utils.jl b/src/utils.jl index b302267b..9da06b62 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -139,17 +139,38 @@ function __extract_problem_details(prob, u0::AbstractArray; dt = 0.0, t₀, t₁ = prob.tspan return Val(false), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), prob.u0 end -function __extract_problem_details(prob, ::F; kwargs...) where {F <: Function} - throw(ArgumentError("passing `u0` as a function is not supported yet. Curently we only \ - support AbstractArray or Vector of AbstractArrays as input! \ - Use the latter format for passing in initial guess!")) +function __extract_problem_details(prob, f::F; dt = 0.0, + check_positive_dt::Bool = false) where {F <: Function} + # Problem passes in a initial guess function + check_positive_dt && dt ≤ 0 && throw(ArgumentError("dt must be positive")) + u0 = __initial_guess(f, prob.p, prob.tspan[1]) + t₀, t₁ = prob.tspan + return Val(true), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), u0 +end + +function __initial_guess(f::F, p::P, t::T) where {F, P, T} + if static_hasmethod(f, Tuple{P, T}) + return f(p, t) + elseif static_hasmethod(f, Tuple{T}) + return f(t) + else + throw(ArgumentError("`initial_guess` must be a function of the form `f(p, t)` or \ + `f(t)`")) + end end -__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh) -__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh] -function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _) +function __initial_state_from_prob(prob::BVProblem, mesh) + return __initial_state_from_prob(prob, prob.u0, mesh) +end +function __initial_state_from_prob(::BVProblem, u0::AbstractArray, mesh) + return [copy(vec(u0)) for _ in mesh] +end +function __initial_state_from_prob(::BVProblem, u0::AbstractVector{<:AbstractVector}, _) return [copy(vec(u)) for u in u0] end +function __initial_state_from_prob(prob::BVProblem, f::F, mesh) where {F} + return [__initial_guess(f, prob.p, t) for t in mesh] +end function __get_bcresid_prototype(prob::BVProblem, u) return __get_bcresid_prototype(prob.problem_type, prob, u) @@ -228,3 +249,45 @@ function __restructure_sol(sol::Vector{<:AbstractArray}, u_size) end # TODO: Add dispatch for a ODESolution Type as well + +# Fake ODE Solution to capture calls to the solution object +@concrete struct __FakeODESolution2 + sol + nodes +end + +__FakeODESolutionXXX = __FakeODESolution2 + +function __construct_fake_ode_solution(prob::ODEProblem, alg) + nodes = Vector{promote_type(typeof(prob.tspan[1]), typeof(prob.tspan[2]))}() + return __FakeODESolutionXXX(SciMLBase.build_solution(prob, alg, + [prob.tspan[1], prob.tspan[2]], [prob.u0, prob.u0]), nodes) +end + +function __finalize_nodes!(sol::__FakeODESolutionXXX) + sort!(sol.nodes) + unique!(sol.nodes) + return sol +end + +function (s::__FakeODESolutionXXX)(t::T, args...; kwargs...) where {T <: Number} + push!(s.nodes, t) + return s.sol(t, args...; kwargs...) +end + +function (s::__FakeODESolutionXXX)(t::T, args...; kwargs...) where {T <: AbstractVector} + append!(s.nodes, t) + return s.sol(t, args...; kwargs...) +end + +function Base.getindex(::__FakeODESolutionXXX, args...) + throw(ArgumentError("`static_auto_nodes = Val(true)` doesn't support indexing into \ + the solution object. Please rewrite your code to call the \ + solution object with the time points you want to evaluate at \ + or use `static_auto_nodes = Val(false)`")) +end + +function Base.show(io::IO, sol::__FakeODESolutionXXX) + print(io, "ODESolution evaluated @ nodes: $(sol.nodes)") + return +end