Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DifferentiationInterface for the jacobian #258

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c47f611
Use DifferentiationInterface for the jacobian
ErikQQY Nov 21, 2024
ad04fc9
Fix incorrect dense choice
ErikQQY Nov 21, 2024
7c9ed85
Use SparseConnectivityTracer
ErikQQY Nov 21, 2024
309b809
FIRK and Ascher use SparseConnectivityTracer
ErikQQY Nov 21, 2024
cb29d91
Proper usage of SparseConnectivityTracer
ErikQQY Nov 22, 2024
e46534c
Fix incorrect diffmode usage
ErikQQY Nov 22, 2024
f712ebb
Done MIRKN
ErikQQY Nov 25, 2024
cd42ede
Done FIRK and single shooting
ErikQQY Nov 28, 2024
9cdaa8a
Merge branch 'master' into qqy/di
ErikQQY Nov 28, 2024
8524567
Remove SparseDiffTools everywhere
ErikQQY Nov 28, 2024
3af4f5e
Should reexport ADTypes
ErikQQY Nov 28, 2024
351b91e
Small tweaks
ErikQQY Nov 28, 2024
4dd08ce
Need to use OrdinaryDiffEq in test
ErikQQY Nov 28, 2024
907a106
Fix oop single shooting
ErikQQY Nov 29, 2024
286d635
And multiple shooting is done
ErikQQY Nov 29, 2024
24923f6
Dont forget multiple shooting for TwoPointBVProblem
ErikQQY Nov 29, 2024
c201ac8
Bump DI for shooting methods
ErikQQY Nov 30, 2024
1880155
Fix some CI complainings
ErikQQY Dec 5, 2024
b8b02ba
Merge branch 'master' of https://github.com/SciML/BoundaryValueDiffEq…
ErikQQY Dec 23, 2024
45cb9c6
Merge branch 'master' into qqy/di
ErikQQY Dec 23, 2024
0297d87
Using sparsity_pattern
ErikQQY Dec 24, 2024
162d14b
Fix some conflicts from merging
ErikQQY Dec 24, 2024
2f06695
Fix some incorrect utils in MIRKN
ErikQQY Dec 24, 2024
65a68e8
Fix some incorrect using in extension
ErikQQY Dec 25, 2024
e7670cc
Unify GreedyColoringAlgorithm usage
ErikQQY Jan 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lib/BoundaryValueDiffEqAscher/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqAscher"
uuid = "7227322d-7511-4e07-9247-ad6ff830280e"
authors = ["Qingyu Qu <[email protected]>"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -12,6 +12,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -25,7 +26,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[compat]
ADTypes = "1.9"
Expand All @@ -37,6 +38,7 @@ BoundaryValueDiffEqCore = "1.1.0"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
DiffEqDevTools = "2.44"
DifferentiationInterface = "0.6.22"
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
FastClosures = "0.3.2"
ForwardDiff = "0.10.38"
Hwloc = "3"
Expand All @@ -55,7 +57,7 @@ Reexport = "1.2"
SciMLBase = "2.59.1"
Setfield = "1.1.1"
SparseArrays = "1.10"
SparseDiffTools = "2.23"
SparseMatrixColorings = "0.4.10"
StaticArrays = "1.9.8"
Test = "1.10"
julia = "1.10"
Expand Down
39 changes: 21 additions & 18 deletions lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
module BoundaryValueDiffEqAscher

using ADTypes
using ADTypes: ADTypes
using AlmostBlockDiagonals
using BoundaryValueDiffEqCore
using ConcreteStructs
using FastClosures
using ForwardDiff
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LinearAlgebra
using PreallocationTools
using RecursiveArrayTools
using Reexport
using SciMLBase
using Setfield
using PreallocationTools: PreallocationTools, DiffCache
using RecursiveArrayTools: VectorOfArray, recursivecopy
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, StandardBVProblem, __solve,
_unwrap_val
using Setfield: @set!
using SparseMatrixColorings: SparseMatrixColorings, GreedyColoringAlgorithm, LargestFirst

import BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
concrete_jacobian_algorithm, __Fix3,
__concrete_nonlinearsolve_algorithm,
BoundaryValueDiffEqAlgorithm, __sparse_jacobian_cache,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
__extract_mesh
const DI = DifferentiationInterface

import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
using BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
concrete_jacobian_algorithm, __Fix3,
__concrete_nonlinearsolve_algorithm,
BoundaryValueDiffEqAlgorithm, __sparse_jacobian_cache, __vec,
__vec_f, __vec_f!, __vec_bc, __vec_bc!, __extract_mesh,
get_dense_ad

@reexport using ADTypes, DiffEqBase, BoundaryValueDiffEqCore, SparseDiffTools, SciMLBase
@reexport using BoundaryValueDiffEqCore, SciMLBase

include("types.jl")
include("utils.jl")
Expand Down
45 changes: 32 additions & 13 deletions lib/BoundaryValueDiffEqAscher/src/ascher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,31 +315,50 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T}
else
@closure (z, p) -> @views Φ(cache, z, pt)
end

lz = reduce(vcat, cache.z)
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
ad = alg.jac_alg.diffmode
lossₚ = (iip ? __Fix3 : Base.Fix2)(loss, cache.p)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, lz)
jac_prototype = init_jacobian(jac_cache)
resid_prototype = zero(lz)
diffmode = if alg.jac_alg.diffmode isa AutoSparse
AutoSparse(alg.jac_alg.diffmode;
coloring_algorithm = GreedyColoringAlgorithm(LargestFirst()))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
else
alg.jac_alg.diffmode
end

jac_cache = if iip
DI.prepare_jacobian(
loss, resid_prototype, get_dense_ad(diffmode), lz, Constant(cache.p))
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
else
DI.prepare_jacobian(loss, get_dense_ad(diffmode), lz, Constant(cache.p))
end

jac_prototype = if iip
DI.jacobian(
loss, resid_prototype, jac_cache, get_dense_ad(diffmode), lz, Constant(cache.p))
else
DI.jacobian(loss, get_dense_ad(diffmode), lz, Constant(cache.p))
end

jac = if iip
@closure (J, u, p) -> __ascher_mpoint_jacobian!(J, u, ad, jac_cache, lossₚ, lz)
@closure (J, u, p) -> __ascher_mpoint_jacobian!(
J, u, get_dense_ad(diffmode), jac_cache, loss, lz, cache.p)
else
@closure (u, p) -> __ascher_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
@closure (u, p) -> __ascher_mpoint_jacobian(
jac_prototype, u, get_dense_ad(diffmode), jac_cache, loss, cache.p)
end
resid_prototype = zero(lz)

_nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
end

function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid)
sparse_jacobian!(J, diffmode, diffcache, loss, resid, x)
function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid, p)
DI.jacobian!(loss, resid, J, diffcache, diffmode, x, Constant(p))
return nothing
end
function __ascher_mpoint_jacobian(J, x, diffmode, diffcache, loss)
sparse_jacobian!(J, diffmode, diffcache, loss, x)
function __ascher_mpoint_jacobian(J, x, diffmode, diffcache, loss, p)
DI.jacobian!(loss, J, diffcache, diffmode, x, Constant(p))
return J
end

Expand Down
8 changes: 5 additions & 3 deletions lib/BoundaryValueDiffEqCore/src/BoundaryValueDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import Logging
using NonlinearSolveFirstOrder: NonlinearSolvePolyAlgorithm
import LineSearch: BackTracking
import RecursiveArrayTools: VectorOfArray, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
import SciMLBase: AbstractBVProblem, AbstractDiffEqInterpolation, StandardBVProblem,
__solve, _unwrap_val

@reexport using ADTypes, NonlinearSolveFirstOrder, SparseDiffTools, SciMLBase
@reexport using NonlinearSolveFirstOrder, SparseDiffTools, SciMLBase
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

include("types.jl")
include("utils.jl")
Expand All @@ -29,7 +30,8 @@ include("alg_utils.jl")
include("default_nlsolve.jl")
include("sparse_jacobians.jl")

function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
function __solve(
prob::AbstractBVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end
Expand Down
4 changes: 4 additions & 0 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,7 @@ end
@inline (f::__Fix3{F})(a, b) where {F} = f.f(a, b, f.x)

# convert every vector of vector to AbstractVectorOfArray, especially if them come from get_tmp of PreallocationTools.jl

get_dense_ad(::Nothing) = nothing
get_dense_ad(ad) = ad
get_dense_ad(ad::AutoSparse) = ADTypes.dense_ad(ad)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 5 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEqFIRK"
uuid = "85d9eb09-370e-4000-bb32-543851f73618"
version = "1.2.0"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -10,6 +10,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -24,7 +25,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[compat]
ADTypes = "1.9"
Expand All @@ -36,6 +37,7 @@ BoundaryValueDiffEqCore = "1.1.0"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
DiffEqDevTools = "2.44"
DifferentiationInterface = "0.6.22"
FastAlmostBandedMatrices = "0.1.4"
FastClosures = "0.3.2"
ForwardDiff = "0.10.38"
Expand All @@ -56,7 +58,7 @@ Reexport = "1.2"
SciMLBase = "2.59.1"
Setfield = "1.1.1"
SparseArrays = "1.10"
SparseDiffTools = "2.23"
SparseMatrixColorings = "0.4.10"
StaticArrays = "1.9.8"
Test = "1.10"
julia = "1.10"
Expand Down
82 changes: 46 additions & 36 deletions lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,54 @@
module BoundaryValueDiffEqFIRK

import PrecompileTools: @compile_workload, @setup_workload

using ADTypes, Adapt, ArrayInterface, BoundaryValueDiffEqCore, DiffEqBase, ForwardDiff,
LinearAlgebra, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield,
SparseDiffTools

using PrecompileTools: @compile_workload, @setup_workload
using SparseMatrixColorings: ColoringProblem, GreedyColoringAlgorithm,
ConstantColoringAlgorithm, row_colors, column_colors, coloring,
LargestFirst
using PreallocationTools: PreallocationTools, DiffCache

# Special Matrix Types
using BandedMatrices, FastAlmostBandedMatrices, SparseArrays

import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorithm,
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, MaybeDiffCache, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
import ConcreteStructs: @concrete
import DiffEqBase: solve
import FastClosures: @closure
import ForwardDiff: ForwardDiff, pickchunksize
import Logging
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val

@reexport using ADTypes, DiffEqBase, BoundaryValueDiffEqCore, SparseDiffTools, SciMLBase
using BandedMatrices: BandedMatrix, Ones
using SparseArrays: sparse
using LinearAlgebra
using FastAlmostBandedMatrices: AlmostBandedMatrix, fillpart, exclusive_bandpart,
finish_part_setindex!

using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorithm,
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f, __vec_f!,
__vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, MaybeDiffCache, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix,
get_dense_ad

using Adapt
using ADTypes: ADTypes
using ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using ForwardDiff: ForwardDiff, pickchunksize
using Logging: Logging
using RecursiveArrayTools: AbstractVectorOfArray, VectorOfArray, recursivecopy
using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, StandardBVProblem, __solve,
_unwrap_val
using Preferences: Preferences
using Reexport: @reexport
using Setfield: @set!
const DI = DifferentiationInterface

@reexport using BoundaryValueDiffEqCore, SciMLBase

include("types.jl")
include("utils.jl")
Expand Down
Loading
Loading