Skip to content

Commit

Permalink
Merge pull request #274 from aterenin/master
Browse files Browse the repository at this point in the history
Add inv for Triangular types and test.
  • Loading branch information
maleadt authored May 13, 2020
2 parents 451427e + a6c3e2f commit b142b8e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,12 @@ function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArra
end
return dest
end


## inv for Triangular
for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
@eval function Base.inv(x::$TR{<:Any,<:AbstractGPUArray})
out = typeof(parent(x))(I(size(x,1)))
$TR(LinearAlgebra.ldiv!(x,out))
end
end
10 changes: 10 additions & 0 deletions src/reference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,4 +306,14 @@ function GPUArrays.mapreducedim!(f, op, R::JLArray, A::Union{AbstractArray,Broad
@allowscalar Base.reducedim!(op, R.data, map(f, A))
end


## LinearAlgebra

using LinearAlgebra

for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
@eval LinearAlgebra.ldiv!(x::$TR{T,<:JLArray{T,2}}, y::JLArray{T,2}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} = JLArray(LinearAlgebra.ldiv!($TR(parent(x).data),y.data))
@eval LinearAlgebra.rdiv!(x::JLArray{T,2}, y::$TR{T,<:JLArray{T,2}}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} = JLArray(LinearAlgebra.rdiv!(x.data,$TR(parent(y).data)))
end

end
4 changes: 4 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ function test_linalg(AT)
end
end

@testset "inv for triangular" for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
@test compare(x -> inv(TR(x)), AT, rand(Float32, 32, 32))
end

@testset "permutedims" begin
@test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3))
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))
Expand Down

0 comments on commit b142b8e

Please sign in to comment.