Skip to content

Commit

Permalink
parser: Support setindex
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Jan 6, 2025
1 parent 62f8307 commit 3c57330
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,12 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=())
body = nothing
arg1 = nothing
arg2 = nothing
value = nothing
if recur && @capture(ex, f_(allargs__)) ||
@capture(ex, f_(allargs__) do cargs_ body_ end) ||
@capture(ex, allargs__->body_) ||
@capture(ex, arg1_[allargs__]) ||
@capture(ex, arg1_[allargs__] = value_) ||
@capture(ex, arg1_.arg2_) ||
@capture(ex, (;allargs__)) ||
@capture(ex, bf_.(allargs__))
Expand All @@ -387,8 +389,13 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=())
# Getproperty (A.B)
f = Base.getproperty
allargs = Any[arg1, QuoteNode(arg2)]
elseif value !== nothing
# setindex! (A[2,3] = 4)
f = _setindex!_return_value
pushfirst!(allargs, value)
pushfirst!(allargs, arg1)
else
# Indexing (A[2,3])
# getindex (A[2,3])
f = Base.getindex
pushfirst!(allargs, arg1)
end
Expand Down Expand Up @@ -444,6 +451,11 @@ _par(mod, ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression:
_par_inner(mod, ex; kwargs...) = ex
_par_inner(mod, ex::Expr; kwargs...) = _par(mod, ex; kwargs...)

function _setindex!_return_value(A, value, idxs...)
setindex!(A, value, idxs...)
return value
end

"""
Dagger.spawn(f, args...; kwargs...) -> DTask
Expand Down
19 changes: 19 additions & 0 deletions test/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,25 @@ end
@test t isa Dagger.DTask
@test fetch(t) == 42
end
@testset "setindex!" begin
A = Dagger.@mutable rand(4, 4)

t = @spawn A[1, 2] = 3.0
@test t isa Dagger.DTask
@test fetch(t) == 3.0
@test fetch(@spawn A[1, 2]) == 3.0

t = @spawn A[2] = 4.0
@test t isa Dagger.DTask
@test fetch(t) == 4.0
@test fetch(@spawn A[2]) == 4.0

R = Dagger.@mutable Ref(42)
t = @spawn R[] = 43
@test t isa Dagger.DTask
@test fetch(t) == 43
@test fetch(@spawn R[]) == 43
end
@testset "NamedTuple" begin
t = @spawn (;a=1, b=2)
@test t isa Dagger.DTask
Expand Down

0 comments on commit 3c57330

Please sign in to comment.