Skip to content

Commit

Permalink
add docs and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-roehrich committed Oct 27, 2023
1 parent 1768937 commit 22bead5
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 22 deletions.
6 changes: 6 additions & 0 deletions docs/src/sparse/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ Various products:
*(::SRow{T}, ::SMat{T}) where {T}
```

```@docs
dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T
dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T
dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T
```

Other:
```@docs
sparse(::SMat)
Expand Down
65 changes: 43 additions & 22 deletions src/Sparse/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,59 +729,80 @@ end
#
################################################################################

@doc raw"""
dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T
dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T
dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T
Return the generalized dot product `dot(x, A*y)`.
"""
function dot(x::SRow{T}, A::SMat{T}, y::SRow{T}) where T
v = T(0)

px = 1
py = 1
for i in 1:length(A.rows), j in 1:length(A[i].pos)
while px <= length(x.pos) && x.pos[px] < A[i].pos[j]
for i in 1:length(A.rows)
while px <= length(x.pos) && x.pos[px] < i
px += 1
end
if px > length(x.pos)
break
elseif x.pos[px] > i
continue
end

while py <= length(y.pos) && y.pos[py] < A[i].pos[j]
py += 1
end
if py > length(y.pos)
break
end
s = T(0)
py = 1
for j in 1:length(A[i].pos)
while py <= length(y.pos) && y.pos[py] < A[i].pos[j]
py += 1
end
if py > length(y.pos)
break
elseif y.pos[py] > A[i].pos[j]
continue
end

if x.pos[px] == A[i].pos[j] == y.pos[py]
v += x.values[px] * A[i].values[j] * y.values[py]
s += A[i].values[j] * y.values[py]
end

v += x.values[px] * s
end

return v
end

function dot(x::AbstractVector{T}, A::SMat{T}, y::AbstractVector{T}) where T
@assert length(x) == length(y)
@req length(x) == length(A.rows) == length(y) "incompatible matrix dimensions"

v = T(0)
for i in 1:length(A.rows), j in 1:length(A[i].pos)
if A[i].pos[j] > length(x) || A[i].pos[j] > length(y)
error("incompatible matrix dimensions")
for i in 1:length(A.rows)
s = T(0)
for j in 1:length(A[i].pos)
if A[i].pos[j] > length(y)
error("incompatible matrix dimensions")
end
s += A[i].values[j] * y[A[i].pos[j]]
end
v += x[A[i].pos[j]] * A[i].values[j] * y[A[i].pos[j]]
v += x[i] * s
end

return v
end

# support dot product for vector matrices
function dot(x::MatrixElem{T}, A::SMat{T}, y::MatrixElem{T}) where T
@assert length(x) == length(y)
@req length(x) == length(A.rows) == length(y) "incompatible matrix dimensions"
len = length(x)

v = T(0)
for i in 1:length(A.rows), j in 1:length(A[i].pos)
if A[i].pos[j] > len || A[i].pos[j] > len
error("incompatible matrix dimensions")
for i in 1:length(A.rows)
s = T(0)
for j in 1:length(A[i].pos)
if A[i].pos[j] > len
error("incompatible matrix dimensions")
end
s += A[i].values[j] * y[A[i].pos[j]]
end
v += x[A[i].pos[j]] * A[i].values[j] * y[A[i].pos[j]] # this will throw if x or y is not a vector
v += x[i] * s
end

return v
Expand Down
28 changes: 28 additions & 0 deletions test/Sparse/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,34 @@ using Hecke.SparseArrays
E = @inferred D * R(3)
@test E == sparse_matrix(R, [3 0 0; 0 0 3; 0 0 0])

# Dot product

D = sparse_matrix(ZZ, [4 -2; -2 2])
E = sparse_matrix(ZZ, [3 0 0; 0 0 3; 0 3 0])

@test dot(sparse_row(ZZ, [1], [1]), D, sparse_row(ZZ, [1], [1])) == 4
@test dot(sparse_row(ZZ, [1, 2], [1, 1]), D, sparse_row(ZZ, [1, 2], [1, 2])) == 2
@test dot(sparse_row(ZZ, [1], [1]), D, sparse_row(ZZ, [2], [1])) == -2

@test dot(sparse_row(ZZ, [1, 4], [1, 2]), D, sparse_row(ZZ, [2], [1])) == -2
@test dot(sparse_row(ZZ, [1, 4], [1, 2]), E, sparse_row(ZZ, [2], [1])) == 0
@test dot(sparse_row(ZZ, [1, 3], [1, 2]), E, sparse_row(ZZ, [2], [1])) == 6

@test dot(ZZRingElem[1, 0], D, ZZRingElem[1, 0]) == 4
@test dot(ZZRingElem[1, 1], D, ZZRingElem[1, 2]) == 2
@test dot(ZZRingElem[1, 0], D, ZZRingElem[0, 1]) == -2

@test dot(ZZ[1 0], D, ZZ[1 0]) == 4
@test dot(ZZ[1; 1], D, ZZ[1 2]) == 2
@test dot(ZZ[1 0], D, ZZ[0; 1]) == -2

@test_throws ArgumentError dot(ZZRingElem[1], D, ZZRingElem[0, 1])
@test_throws ArgumentError dot(ZZRingElem[1, 0, 0], D, ZZRingElem[0, 1])
@test_throws ArgumentError dot(ZZRingElem[1, 0], D, ZZRingElem[0, 1, 0])

@test_throws ArgumentError dot(ZZ[1 0 0], D, ZZ[1 0])
@test_throws ArgumentError dot(ZZ[1 0; 0 0], D, ZZ[1 0 0 0])

# Submatrix

D = sparse_matrix(FlintZZ, [1 5 3; 0 0 0; 0 1 0])
Expand Down

0 comments on commit 22bead5

Please sign in to comment.