From c2ef523fa35d50e8bdebc10a65713cafa33bfc45 Mon Sep 17 00:00:00 2001 From: aterenin Date: Sun, 10 May 2020 19:12:02 +0100 Subject: [PATCH 1/2] Add inv for triangular matrices. --- src/host/linalg.jl | 9 +++++++++ test/testsuite/linalg.jl | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index ffb7a596..f3c7729f 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -200,3 +200,12 @@ function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArra end return dest end + + +## inv for Triangular +for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function Base.inv(x::$TR{<:Any,<:AbstractGPUArray}) + out = typeof(parent(x))(I(size(x,1))) + $TR(LinearAlgebra.ldiv!(x,out)) + end +end \ No newline at end of file diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 880b28d6..c5a64d2d 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -46,6 +46,10 @@ function test_linalg(AT) end end + @testset "inv for triangular" for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @test compare(x -> inv(TR(x)), AT, rand(Float32, 32, 32)) + end + @testset "permutedims" begin @test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3)) @test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6)) From a6c3e2f7da6f0c31e4ec979c33d4f7a4889a3800 Mon Sep 17 00:00:00 2001 From: aterenin Date: Tue, 12 May 2020 15:08:42 +0100 Subject: [PATCH 2/2] Add ldiv! and rdiv! for JLArray. --- src/reference.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/reference.jl b/src/reference.jl index 832f7c55..cba3c401 100644 --- a/src/reference.jl +++ b/src/reference.jl @@ -305,4 +305,14 @@ function GPUArrays.mapreducedim!(f, op, R::JLArray, As::AbstractArray...; init=n @allowscalar Base.reducedim!(op, R.data, map(f, As...)) end + +## LinearAlgebra + +using LinearAlgebra + +for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval LinearAlgebra.ldiv!(x::$TR{T,<:JLArray{T,2}}, y::JLArray{T,2}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} = JLArray(LinearAlgebra.ldiv!($TR(parent(x).data),y.data)) + @eval LinearAlgebra.rdiv!(x::JLArray{T,2}, y::$TR{T,<:JLArray{T,2}}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} = JLArray(LinearAlgebra.rdiv!(x.data,$TR(parent(y).data))) end + +end \ No newline at end of file