Skip to content

Commit

Permalink
Merge pull request #214 from JuliaGPU/tb/adapt_collect
Browse files Browse the repository at this point in the history
Use Adapt.jl for generating collect methods.
  • Loading branch information
maleadt authored Oct 31, 2019
2 parents 54de102 + 42a4d3b commit 43746d7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
25 changes: 19 additions & 6 deletions src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,36 @@ function to_cartesian(A, indices::Tuple)
CartesianIndices(start, stop)
end

## showing
## convert to CPU (keeping wrapper type)

# Storage rule for `Array` targets: an `AbstractArray` is handed to
# `convert(Array, xs)`.
function Adapt.adapt_storage(::Type{<:Array}, xs::AbstractArray)
    return convert(Array, xs)
end

# Both helpers forward to `adapt(Array, xs)`.
function cpu(xs)
    return adapt(Array, xs)
end

function convert_to_cpu(xs)
    return adapt(Array, xs)
end

## showing

# Define printing/show methods for a bare `GPUArray` (`:AT`) and for each entry of
# `Adapt.wrappers`: every method rebuilds its argument through `$ctor` with a
# CPU-conversion function and forwards to the matching `Base` method.
# NOTE(review): this span seems to contain both the pre-change (`cpu`) and the
# post-change (`convert_to_cpu`) lines of a diff — confirm which lines are live.
for (W, ctor) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
@eval begin
# display: forward `print_array` to the converted array
Base.print_array(io::IO, X::$W where {AT <: GPUArray}) = Base.print_array(io, $ctor(X, cpu))
Base.print_array(io::IO, X::$W where {AT <: GPUArray}) =
Base.print_array(io, $ctor(X, convert_to_cpu))

# show: forward the non-empty, empty and vector show helpers the same way
Base._show_nonempty(io::IO, X::$W where {AT <: GPUArray}, prefix::String) =
Base._show_nonempty(io, $ctor(X, cpu), prefix)
Base._show_nonempty(io, $ctor(X, convert_to_cpu), prefix)
Base._show_empty(io::IO, X::$W where {AT <: GPUArray}) =
Base._show_empty(io, $ctor(X, cpu))
Base._show_empty(io, $ctor(X, convert_to_cpu))
Base.show_vector(io::IO, v::$W where {AT <: GPUArray}, args...) =
Base.show_vector(io, $ctor(v, cpu), args...)
Base.show_vector(io, $ctor(v, convert_to_cpu), args...)
end
end

## collect to CPU (discarding wrapper type)

# Run `xs` through `convert_to_cpu`, then `collect` the result.
function collect_to_cpu(xs::AbstractArray)
    host = convert_to_cpu(xs)
    return collect(host)
end

# Define `Base.collect` for a bare `GPUArray` (`:AT`) and for each type expression
# taken from `Adapt.wrappers`; every generated method forwards to `collect_to_cpu`.
# The second element of each entry is not needed here, so it is discarded.
for (wrapped, _) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
    @eval Base.collect(X::$wrapped where {AT <: GPUArray}) = collect_to_cpu(X)
end

Expand Down
9 changes: 1 addition & 8 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,11 @@ for i = 0:10

end

# Fallback: a value with no more specific method is returned unchanged.
function to_cpu(x)
    return x
end

# A `GPUArray` is passed to the `Array` constructor.
function to_cpu(x::GPUArray)
    return Array(x)
end

# A broadcast expression in a `GPUArray` style is materialized first, and the
# result is then sent through `to_cpu` again.
function to_cpu(x::Broadcasted{ArrayStyle{AT}}) where {AT <: GPUArray}
    materialized = Base.Broadcast.materialize(x)
    return to_cpu(materialized)
end

# Wrappers: apply `to_cpu` to the parent and rebuild the same kind of wrapper.
function to_cpu(x::LinearAlgebra.Transpose)
    return LinearAlgebra.Transpose(to_cpu(parent(x)))
end

function to_cpu(x::LinearAlgebra.Adjoint)
    return LinearAlgebra.Adjoint(to_cpu(parent(x)))
end

function to_cpu(x::SubArray)
    return SubArray(to_cpu(parent(x)), parentindices(x))
end

function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest::Tuple) where {OT}
blocksize = 80
threads = 256
if length(A) <= blocksize * threads
args = zip(to_cpu(A), to_cpu.(rest)...)
args = zip(convert_to_cpu(A), convert_to_cpu.(rest)...)
return mapreduce(x-> f(x...), op, args, init = v0)
end
out = similar(A, OT, (blocksize,))
Expand Down

0 comments on commit 43746d7

Please sign in to comment.