From b5fbb0fa58fa0a5becf972965026d9a71c1b321c Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 1 Apr 2024 14:25:17 +0200 Subject: [PATCH 01/22] Delete Dagger.cleanup() Because it doesn't actually do anything now. --- src/compute.jl | 6 ------ src/sch/Sch.jl | 3 --- 2 files changed, 9 deletions(-) diff --git a/src/compute.jl b/src/compute.jl index f421eaccc..093b527f4 100644 --- a/src/compute.jl +++ b/src/compute.jl @@ -36,12 +36,6 @@ end Base.@deprecate gather(ctx, x) collect(ctx, x) Base.@deprecate gather(x) collect(x) -cleanup() = cleanup(Context(global_context())) -function cleanup(ctx::Context) - Sch.cleanup(ctx) - nothing -end - function get_type(s::String) local T for t in split(s, ".") diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 5c330841e..2553142a3 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -313,9 +313,6 @@ function populate_defaults(opts::ThunkOptions, Tf, Targs) ) end -function cleanup(ctx) -end - # Eager scheduling include("eager.jl") From 080d7bc207e5d88f87ea8ee126f27b78a94b46a4 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 9 Apr 2024 03:44:35 +0200 Subject: [PATCH 02/22] Use procs() when initializing EAGER_CONTEXT Using `myid()` with `workers()` meant that when the context was initialized with a single worker the processor list would be: `[OSProc(1), OSProc(1)]`. `procs()` will always include PID 1 and any other workers, which is what we want. --- src/sch/eager.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 87a109788..259c1b6f8 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -6,7 +6,7 @@ const EAGER_STATE = Ref{Union{ComputeState,Nothing}}(nothing) function eager_context() if EAGER_CONTEXT[] === nothing - EAGER_CONTEXT[] = Context([myid(),workers()...]) + EAGER_CONTEXT[] = Context(procs()) end return EAGER_CONTEXT[] end From ebdfda43089428b1d02cbbe299de20814d8f2d4d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:42:50 -0700 Subject: [PATCH 03/22] Add metadata to EagerThunk --- Project.toml | 1 + src/dtask.jl | 14 +++++++++++++- src/submission.jl | 14 +++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 8735298dc..45cf988c8 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" +Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" diff --git a/src/dtask.jl b/src/dtask.jl index 68f2d3c1b..98f74005a 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -39,6 +39,16 @@ end Options(;options...) = Options((;options...)) Options(options...) = Options((;options...)) +""" + DTaskMetadata + +Represents some useful metadata pertaining to a `DTask`: +- `return_type::Type` - The inferred return type of the task +""" +mutable struct DTaskMetadata + return_type::Type +end + """ DTask @@ -50,9 +60,11 @@ more details. 
mutable struct DTask uid::UInt future::ThunkFuture + metadata::DTaskMetadata finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + + DTask(uid, future, metadata, finalizer_ref) = new(uid, future, metadata, finalizer_ref) end const EagerThunk = DTask diff --git a/src/submission.jl b/src/submission.jl index 7312e378d..bfb8cb8be 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -218,15 +218,27 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) return options end end + +function DTaskMetadata(spec::DTaskSpec) + f = chunktype(spec.f).instance + arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + return_type = Base.promote_op(f, arg_types...) + return DTaskMetadata(return_type) +end + function eager_spawn(spec::DTaskSpec) # Generate new DTask uid = eager_next_id() future = ThunkFuture() + metadata = DTaskMetadata(spec) finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) # Create unlaunched DTask - return DTask(uid, future, finalizer_ref) + return DTask(uid, future, metadata, finalizer_ref) end + +chunktype(t::DTask) = t.metadata.return_type + function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Assign a name, if specified eager_assign_name!(spec, task) From fcff9911e862a107b64d349d99b7e615746194ff Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:44:14 -0700 Subject: [PATCH 04/22] Sch: Allow occupancy key to be Any --- src/sch/util.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index e81703db5..cd006838b 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -406,12 +406,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) else get(state.signature_alloc_cost, sig, UInt64(0)) end::UInt64 - est_occupancy = if occupancy !== nothing && haskey(occupancy, T) - # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` - Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32)) - else - typemax(UInt32) - end::UInt32 + est_occupancy::UInt32 = typemax(UInt32) + if occupancy !== nothing + occ = nothing + if haskey(occupancy, T) + occ = occupancy[T] + elseif haskey(occupancy, Any) + occ = occupancy[Any] + end + if occ !== nothing + # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` + est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32)) + end + end #= FIXME: Estimate if cached data can be swapped to storage storage = storage_resource(p) real_alloc_util = state.worker_storage_pressure[gp][storage] From 50446760c41ab063e782481a09a4d8d84fb6a045 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 26 Nov 2024 12:29:32 -0600 Subject: [PATCH 05/22] Add a --verbose option to runtests.jl --- test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index cfdab8177..00ed5862c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,9 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ arg_type = Int default = additional_workers help = "How many additional workers to launch" + "-v", "--verbose" + action = :store_true + help = "Run the tests with debug logs from Dagger" end end @@ -81,6 +84,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ parsed_args["simulate"] && exit(0) additional_workers = parsed_args["procs"] + + if parsed_args["verbose"] + ENV["JULIA_DEBUG"] = "Dagger" + end else to_test = all_test_names 
@info "Running all tests" From ed2493c9b8f39f0b76f4d70b4fc72dd598b812cf Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 22 May 2024 12:48:53 -0500 Subject: [PATCH 06/22] task-tls: Refactor into DTaskTLS struct --- src/Dagger.jl | 2 ++ src/array/indexing.jl | 2 -- src/sch/Sch.jl | 2 +- src/sch/dynamic.jl | 2 +- src/task-tls.jl | 49 ++++++++++++++++++++++--------------------- 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index 8bc2c24a1..2fc2da84e 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -22,6 +22,8 @@ else import Base.ScopedValues: ScopedValue, with end +import TaskLocalValues: TaskLocalValue + if !isdefined(Base, :get_extension) import Requires: @require end diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..69725eb7a 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -1,5 +1,3 @@ -import TaskLocalValues: TaskLocalValue - ### getindex struct GetIndex{T,N} <: ArrayOp{T,N} diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 2553142a3..0360432ac 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1202,7 +1202,7 @@ function proc_states(f::Base.Callable, uid::UInt64) end end proc_states(f::Base.Callable) = - proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64) + proc_states(f, Dagger.get_tls().sch_uid) task_tid_for_processor(::Processor) = nothing task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index e02085ee6..5b917fdb5 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -17,7 +17,7 @@ struct SchedulerHandle end "Gets the scheduler handle for the currently-executing thunk." -sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle +sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle "Thrown when the scheduler halts before finishing processing the DAG." struct SchedulerHaltedException <: Exception end diff --git a/src/task-tls.jl b/src/task-tls.jl index ea188e004..90fdfedb3 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,41 +1,42 @@ # In-Thunk Helpers +struct DTaskTLS + processor::Processor + sch_uid::UInt + sch_handle::Any # FIXME: SchedulerHandle + task_spec::Vector{Any} # FIXME: TaskSpec +end + +const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) + """ - task_processor() + get_tls() -> DTaskTLS -Get the current processor executing the current Dagger task. +Gets all Dagger TLS variable as a `DTaskTLS`. """ -task_processor() = task_local_storage(:_dagger_processor)::Processor -@deprecate thunk_processor() task_processor() +get_tls() = DTASK_TLS[]::DTaskTLS """ - in_task() + set_tls!(tls) -Returns `true` if currently executing in a [`DTask`](@ref), else `false`. +Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ -in_task() = haskey(task_local_storage(), :_dagger_sch_uid) -@deprecate in_thunk() in_task() +function set_tls!(tls) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) +end """ - get_tls() + in_task() -> Bool -Gets all Dagger TLS variable as a `NamedTuple`. +Returns `true` if currently executing in a [`DTask`](@ref), else `false`. 
""" -get_tls() = ( - sch_uid=task_local_storage(:_dagger_sch_uid), - sch_handle=task_local_storage(:_dagger_sch_handle), - processor=task_processor(), - task_spec=task_local_storage(:_dagger_task_spec), -) +in_task() = DTASK_TLS[] !== nothing +@deprecate in_thunk() in_task() """ - set_tls!(tls) + task_processor() -> Processor -Sets all Dagger TLS variables from the `NamedTuple` `tls`. +Get the current processor executing the current [`DTask`](@ref). """ -function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_task_spec, tls.task_spec) -end +task_processor() = get_tls().processor +@deprecate thunk_processor() task_processor() From c07db4cfbbe7e8142f9756e44e549b3b6adbe0d7 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 12:21:13 -0400 Subject: [PATCH 07/22] cancellation: Add cancel token support --- src/Dagger.jl | 4 ++-- src/cancellation.jl | 38 +++++++++++++++++++++++++++++++++++++- src/sch/Sch.jl | 14 ++++++++++++++ src/task-tls.jl | 21 ++++++++++++++++++++- src/threadproc.jl | 3 ++- 5 files changed, 75 insertions(+), 5 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index 2fc2da84e..8c5d6ae9d 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -57,16 +57,16 @@ include("processor.jl") include("threadproc.jl") include("context.jl") include("utils/processors.jl") +include("dtask.jl") +include("cancellation.jl") include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") -include("dtask.jl") include("queue.jl") include("thunk.jl") include("submission.jl") include("chunks.jl") include("memory-spaces.jl") -include("cancellation.jl") # Task scheduling include("compute.jl") diff --git a/src/cancellation.jl b/src/cancellation.jl index c982fd20c..dcb0f5add 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,3 +1,38 @@ +# DTask-level cancellation + +struct CancelToken + cancelled::Base.RefValue{Bool} + event::Base.Event +end +CancelToken() = CancelToken(Ref(false), Base.Event()) +function cancel!(token::CancelToken) + token.cancelled[] = true + notify(token.event) + return +end +is_cancelled(token::CancelToken) = token.cancelled[] +Base.wait(token::CancelToken) = wait(token.event) +# TODO: Enable this for safety +#Serialization.serialize(io::AbstractSerializer, ::CancelToken) = +# throw(ConcurrencyViolationError("Cannot serialize a CancelToken")) + +const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing) + +function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer) + remote_token = remotecall_fetch(wid) do + return poolset(CancelToken()) + end + errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin + wait(orig_token) + @dagdebug nothing :cancel "Cancelling remote token on worker $wid" + MemPool.access_ref(remote_token) do remote_token + cancel!(remote_token) + end + end) +end + +# Global-level cancellation + """ cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) @@ -80,11 +115,11 @@ function _cancel!(state, tid, force, halt_sch) Tf === typeof(Sch.eager_thunk) && continue istaskdone(task) && continue any_cancelled = true - @dagdebug tid :cancel "Cancelling running task ($Tf)" if force @dagdebug tid :cancel "Interrupting running task ($Tf)" Threads.@spawn Base.throwto(task, InterruptException()) else + @dagdebug tid :cancel "Cancelling running task ($Tf)" # Tell the processor to just drop this task task_occupancy 
= task_spec[4] time_util = task_spec[2] @@ -93,6 +128,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) + cancel!(istate.cancel_tokens[tid]) end end end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 0360432ac..a5e8fb879 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1183,6 +1183,7 @@ struct ProcessorInternalState proc_occupancy::Base.RefValue{UInt32} time_pressure::Base.RefValue{UInt64} cancelled::Set{Int} + cancel_tokens::Dict{Int,Dagger.CancelToken} done::Base.RefValue{Bool} end struct ProcessorState @@ -1332,7 +1333,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Execute the task and return its result t = @task begin + # Set up cancellation + cancel_token = Dagger.CancelToken() + Dagger.DTASK_CANCEL_TOKEN[] = cancel_token + lock(istate.queue) do _ + istate.cancel_tokens[thunk_id] = cancel_token + end was_cancelled = false + result = try do_task(to_proc, task) catch err @@ -1349,6 +1357,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Task was cancelled, so occupancy and pressure are # already reduced pop!(istate.cancelled, thunk_id) + delete!(istate.cancel_tokens, thunk_id) was_cancelled = true end end @@ -1366,6 +1375,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re else rethrow(err) end + finally + # Ensure that any spawned tasks get cleaned up + Dagger.cancel!(cancel_token) end end lock(istate.queue) do _ @@ -1415,6 +1427,7 @@ function do_tasks(to_proc, return_queue, tasks) Dict{Int,Vector{Any}}(), Ref(UInt32(0)), Ref(UInt64(0)), Set{Int}(), + Dict{Int,Dagger.CancelToken}(), Ref(false)) runner = start_processor_runner!(istate, uid, return_queue) @static if VERSION < v"1.9" @@ -1656,6 +1669,7 @@ function do_task(to_proc, task_desc) sch_handle, processor=to_proc, task_spec=task_desc, + cancel_token=Dagger.DTASK_CANCEL_TOKEN[], )) res = Dagger.with_options(propagated) do diff --git a/src/task-tls.jl b/src/task-tls.jl index 90fdfedb3..8a8b6c66d 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -5,6 +5,7 @@ struct DTaskTLS sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle task_spec::Vector{Any} # FIXME: TaskSpec + cancel_token::CancelToken end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) @@ -22,7 +23,7 @@ get_tls() = DTASK_TLS[]::DTaskTLS Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ function set_tls!(tls) - DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) end """ @@ -40,3 +41,21 @@ Get the current processor executing the current [`DTask`](@ref). """ task_processor() = get_tls().processor @deprecate thunk_processor() task_processor() + +""" + task_cancelled() -> Bool + +Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +""" +task_cancelled() = get_tls().cancel_token.cancelled[] + +""" + task_may_cancel!() + +Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. 
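+
+For example, a long-running task can poll for cancellation between units of
+work (a sketch; `process_chunk` and `chunks` are hypothetical stand-ins):
+
+```julia
+function process_all(chunks)
+    for chunk in chunks
+        Dagger.task_may_cancel!() # throws if this task was cancelled
+        process_chunk(chunk)
+    end
+end
+t = Dagger.@spawn process_all(chunks)
+```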
+""" +function task_may_cancel!() + if task_cancelled() + throw(InterruptException()) + end +end diff --git a/src/threadproc.jl b/src/threadproc.jl index 09099889a..b75c90ca3 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n return result[] catch err if err isa InterruptException + # Direct interrupt hit us, propagate cancellation signal + # FIXME: We should tell the scheduler that the user hit Ctrl-C if !istaskdone(task) - # Propagate cancellation signal Threads.@spawn Base.throwto(task, InterruptException()) end end From 2af3f1076ebdfa5671c9ae4f2171635e0ecf0fe5 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:59:55 -0500 Subject: [PATCH 08/22] task-tls: Tweaks and fixes, task_id helper --- src/task-tls.jl | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/task-tls.jl b/src/task-tls.jl index 8a8b6c66d..f6889bbb1 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,6 +1,6 @@ # In-Thunk Helpers -struct DTaskTLS +mutable struct DTaskTLS processor::Processor sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle @@ -10,6 +10,8 @@ end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) +Base.copy(tls::DTaskTLS) = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) + """ get_tls() -> DTaskTLS @@ -32,7 +34,14 @@ end Returns `true` if currently executing in a [`DTask`](@ref), else `false`. """ in_task() = DTASK_TLS[] !== nothing -@deprecate in_thunk() in_task() +@deprecate(in_thunk(), in_task()) + +""" + task_id() -> Int + +Returns the ID of the current [`DTask`](@ref). +""" +task_id() = get_tls().sch_handle.thunk_id.id """ task_processor() -> Processor @@ -40,14 +49,14 @@ in_task() = DTASK_TLS[] !== nothing Get the current processor executing the current [`DTask`](@ref). """ task_processor() = get_tls().processor -@deprecate thunk_processor() task_processor() +@deprecate(thunk_processor(), task_processor()) """ task_cancelled() -> Bool Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. """ -task_cancelled() = get_tls().cancel_token.cancelled[] +task_cancelled() = is_cancelled(get_tls().cancel_token) """ task_may_cancel!() @@ -59,3 +68,10 @@ function task_may_cancel!() throw(InterruptException()) end end + +""" + task_cancel!() + +Cancels the current [`DTask`](@ref). +""" +task_cancel!() = cancel!(get_tls().cancel_token) From 6da10fa6fbd4359d6dfbc6e11d53cbb63d096507 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:07:18 -0500 Subject: [PATCH 09/22] cancellation: Add graceful vs. 
forced --- src/cancellation.jl | 49 +++++++++++++++++++++++++++++---------------- src/task-tls.jl | 20 ++++++++++-------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index dcb0f5add..491b6756d 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,16 +1,29 @@ # DTask-level cancellation -struct CancelToken - cancelled::Base.RefValue{Bool} +mutable struct CancelToken + @atomic cancelled::Bool + @atomic graceful::Bool event::Base.Event end -CancelToken() = CancelToken(Ref(false), Base.Event()) -function cancel!(token::CancelToken) - token.cancelled[] = true +CancelToken() = CancelToken(false, false, Base.Event()) +function cancel!(token::CancelToken; graceful::Bool=true) + if !graceful + @atomic token.graceful = false + end + @atomic token.cancelled = true notify(token.event) return end -is_cancelled(token::CancelToken) = token.cancelled[] +function is_cancelled(token::CancelToken; must_force::Bool=false) + if token.cancelled[] + if must_force && token.graceful[] + # If we're only responding to forced cancellation, ignore graceful cancellations + return false + end + return true + end + return false +end Base.wait(token::CancelToken) = wait(token.event) # TODO: Enable this for safety #Serialization.serialize(io::AbstractSerializer, ::CancelToken) = @@ -34,13 +47,15 @@ end # Global-level cancellation """ - cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) + cancel!(task::DTask; force::Bool=false, graceful::Bool=true, halt_sch::Bool=false) Cancels `task` at any point in its lifecycle, causing the scheduler to abandon -it. If `force` is `true`, the task will be interrupted with an -`InterruptException` (not recommended, this is unsafe). If `halt_sch` is -`true`, the scheduler will be halted after the task is cancelled (it will -restart automatically upon the next `@spawn`/`spawn` call). +it. + +# Keyword arguments +- `force`: If `true`, the task will be interrupted with an `InterruptException` (not recommended, this is unsafe). +- `graceful`: If `true`, the task will be allowed to finish its current execution before being cancelled; otherwise, it will be cancelled as soon as possible. +- `halt_sch`: If `true`, the scheduler will be halted after the task is cancelled (it will restart automatically upon the next `@spawn`/`spawn` call). As an example, the following code will cancel task `t` before it finishes executing: @@ -56,24 +71,24 @@ tasks which are waiting to run. Using `cancel!` is generally a much safer alternative to Ctrl+C, as it cooperates with the scheduler and runtime and avoids unintended side effects. 
""" -function cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) +function cancel!(task::DTask; force::Bool=false, graceful::Bool=true, halt_sch::Bool=false) tid = lock(Dagger.Sch.EAGER_ID_MAP) do id_map id_map[task.uid] end - cancel!(tid; force, halt_sch) + cancel!(tid; force, graceful, halt_sch) end function cancel!(tid::Union{Int,Nothing}=nothing; - force::Bool=false, halt_sch::Bool=false) + force::Bool=false, graceful::Bool=true, halt_sch::Bool=false) remotecall_fetch(1, tid, force, halt_sch) do tid, force, halt_sch state = Sch.EAGER_STATE[] # Check that the scheduler isn't stopping or has already stopped if !isnothing(state) && !state.halt.set - @lock state.lock _cancel!(state, tid, force, halt_sch) + @lock state.lock _cancel!(state, tid, force, graceful, halt_sch) end end end -function _cancel!(state, tid, force, halt_sch) +function _cancel!(state, tid, force, graceful, halt_sch) @assert islocked(state.lock) # Get the scheduler uid @@ -128,7 +143,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) - cancel!(istate.cancel_tokens[tid]) + cancel!(istate.cancel_tokens[tid]; graceful) end end end diff --git a/src/task-tls.jl b/src/task-tls.jl index f6889bbb1..5c7d0375b 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -52,26 +52,30 @@ task_processor() = get_tls().processor @deprecate(thunk_processor(), task_processor()) """ - task_cancelled() -> Bool + task_cancelled(; must_force::Bool=false) -> Bool Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +If `must_force=true`, then only return `true` if the cancellation was forced. """ -task_cancelled() = is_cancelled(get_tls().cancel_token) +task_cancelled(; must_force::Bool=false) = + is_cancelled(get_tls().cancel_token; must_force) """ - task_may_cancel!() + task_may_cancel!(; must_force::Bool=false) Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. +If `must_force=true`, then only throw if the cancellation was forced. """ -function task_may_cancel!() - if task_cancelled() +function task_may_cancel!(;must_force::Bool=false) + if task_cancelled(;must_force) throw(InterruptException()) end end """ - task_cancel!() + task_cancel!(; graceful::Bool=true) -Cancels the current [`DTask`](@ref). +Cancels the current [`DTask`](@ref). If `graceful=true`, then the task will be +cancelled gracefully, otherwise it will be forced. 
""" -task_cancel!() = cancel!(get_tls().cancel_token) +task_cancel!(; graceful::Bool=true) = cancel!(get_tls().cancel_token; graceful) From 2515c36dce3d61cace45e0d123d95fdedc1d7a93 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:07:56 -0500 Subject: [PATCH 10/22] cancellation: Wrap InterruptException in DTaskFailedException --- src/cancellation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index 491b6756d..63993a0e0 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -98,7 +98,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) for task in state.ready tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling ready task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -108,7 +108,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) for task in keys(state.waiting) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end From 2b70bf3029574dc4ac5b98965a9ffa9b19e1678e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 14 Sep 2024 11:54:34 -0400 Subject: [PATCH 11/22] Sch: Add unwrap_nested_exception for DTaskFailedException --- src/sch/util.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sch/util.jl b/src/sch/util.jl index cd006838b..eb5a285b4 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -29,6 +29,8 @@ unwrap_nested_exception(err::CapturedException) = unwrap_nested_exception(err.ex) unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) +unwrap_nested_exception(err::DTaskFailedException) = + unwrap_nested_exception(err.ex) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." From 3638866714f3de52f338ff0fe57707e1127e093f Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:22:52 -0500 Subject: [PATCH 12/22] Add task_id for DTask --- src/sch/eager.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 259c1b6f8..f3aca2ca0 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -134,3 +134,6 @@ function _find_thunk(e::Dagger.DTask) unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid]) end end +Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map + id_map[t.uid] +end From 2d7560326096e525de55a3f98c9fdcf004b714b0 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:05:49 -0500 Subject: [PATCH 13/22] dagdebug: Always yield to avoid heisenbugs --- src/utils/dagdebug.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 9a9d24167..8b6d3530f 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -31,6 +31,10 @@ macro dagdebug(thunk, category, msg, args...) $debug_ex_noid end end + + # Always yield to reduce differing behavior for debug vs. 
non-debug + # TODO: Remove this eventually + yield() end end) end From 1c806e99beb9145ab6dfb1d6a0bba8d8d81d6bf2 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:06:59 -0500 Subject: [PATCH 14/22] tests: Add offline mode --- test/runtests.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 00ed5862c..baa21c3a4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) - Pkg.instantiate() + try + Pkg.instantiate() + catch + end using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") @@ -55,6 +58,9 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ "-v", "--verbose" action = :store_true help = "Run the tests with debug logs from Dagger" + "-O", "--offline" + action = :store_true + help = "Set Pkg into offline mode" end end @@ -88,6 +94,11 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ if parsed_args["verbose"] ENV["JULIA_DEBUG"] = "Dagger" end + + if parsed_args["offline"] + Pkg.UPDATED_REGISTRY_THIS_SESSION[] = true + Pkg.offline(true) + end else to_test = all_test_names @info "Running all tests" From 5233385ed685c35649a1e57bd6c25f9e29682a7c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 09:00:23 -0500 Subject: [PATCH 15/22] dagdebug: Add JULIA_DAGGER_DEBUG config variable --- src/Dagger.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/Dagger.jl b/src/Dagger.jl index 8c5d6ae9d..4623b5405 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -171,6 +171,20 @@ function __init__() ThreadProc(myid(), tid) end end + + # Set up @dagdebug categories, if specified + try + if haskey(ENV, "JULIA_DAGGER_DEBUG") + empty!(DAGDEBUG_CATEGORIES) + for category in split(ENV["JULIA_DAGGER_DEBUG"], ",") + if category != "" + push!(DAGDEBUG_CATEGORIES, Symbol(category)) + end + end + end + catch err + @warn "Error parsing JULIA_DAGGER_DEBUG" exception=err + end end end # module From 2d20762289e84018523ff1169d3cdd5018eb504c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:08:31 -0500 Subject: [PATCH 16/22] options: Add internal helper to strip all options --- src/options.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/options.jl b/src/options.jl index 1c1e3ff29..00196dd59 100644 --- a/src/options.jl +++ b/src/options.jl @@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple) end with_options(f; options...) 
= with_options(f, NamedTuple(options)) +function _without_options(f) + with(options_context => NamedTuple()) do + f() + end +end + """ get_options(key::Symbol, default) -> Any get_options(key::Symbol) -> Any From 32b2ccd0fd212125582465bba1c3fea131802eb9 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 21 Nov 2024 07:09:17 -0500 Subject: [PATCH 17/22] tests: Test DTaskFailedException inner type --- test/mutation.jl | 2 +- test/processors.jl | 4 ++-- test/scheduler.jl | 12 ++++++------ test/scopes.jl | 11 ++++++----- test/thunk.jl | 30 ++++++++++++++---------------- test/util.jl | 15 ++++++++++----- 6 files changed, 39 insertions(+), 35 deletions(-) diff --git a/test/mutation.jl b/test/mutation.jl index b6ac7143b..fa2f62bcf 100644 --- a/test/mutation.jl +++ b/test/mutation.jl @@ -48,7 +48,7 @@ end x = Dagger.@mutable worker=w Ref{Int}() @test fetch(Dagger.@spawn mutable_update!(x)) == w wo_scope = Dagger.ProcessScope(wo) - @test_throws_unwrap Dagger.DTaskFailedException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) end end # @testset "@mutable" diff --git a/test/processors.jl b/test/processors.jl index e97a1d239..6e56876dd 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -37,9 +37,9 @@ end end @testset "Processor exhaustion" begin opts = ThunkOptions(proclist=[OptOutProc]) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=(proc)->false) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end diff --git a/test/scheduler.jl b/test/scheduler.jl index b9fe01872..b12ad3e1e 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -182,7 +182,7 @@ end @testset "allow errors" begin opts = ThunkOptions(;allow_errors=true) a = delayed(error; options=opts)("Test") - @test_throws_unwrap Dagger.DTaskFailedException collect(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) collect(a) end end @@ -396,7 +396,7 @@ end ([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)), ] for arg in args - if arg isa Chunk + if arg isa Dagger.Chunk aff = Dagger.affinity(arg) @test aff[1] == OSProc(1) @test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle)) @@ -477,7 +477,7 @@ end @test res == 2 @testset "self as input" begin a = delayed(dynamic_add_thunk_self_dominated)(1) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch result of dominated thunk" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch result of dominated thunk" collect(Context(), a) end end @testset "Fetch/Wait" begin @@ -487,11 +487,11 @@ end end @testset "self" begin a = 
delayed(dynamic_fetch_self)(1) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch own result" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch own result" collect(Context(), a) end @testset "dominated" begin a = delayed(identity)(delayed(dynamic_fetch_dominated)(1)) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch result of dominated thunk" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch result of dominated thunk" collect(Context(), a) end end end @@ -540,7 +540,7 @@ end t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) start_time = time_ns() Dagger.cancel!(t) - @test_throws_unwrap Dagger.DTaskFailedException fetch(t) + @test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t) t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield() fetch(t) finish_time = time_ns() diff --git a/test/scopes.jl b/test/scopes.jl index 5f82a71a0..065e5158f 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -1,3 +1,4 @@ +#@everywhere ENV["JULIA_DEBUG"] = "Dagger" @testset "Chunk Scopes" begin wid1, wid2 = addprocs(2, exeflags=["-t 2"]) @everywhere [wid1,wid2] using Dagger @@ -56,7 +57,7 @@ # Different nodes for (ch1, ch2) in [(ns1_ch, ns2_ch), (ns2_ch, ns1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Process Scope" begin @@ -75,7 +76,7 @@ # Different process for (ch1, ch2) in [(ps1_ch, ps2_ch), (ps2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process and node @@ -83,7 +84,7 @@ # Different process and node for (ch1, ch2) in [(ps1_ch, ns2_ch), (ns2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Exact Scope" begin @@ -104,14 +105,14 @@ # Different process, different processor for (ch1, ch2) in [(es1_ch, es2_ch), (es2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process, different processor es1_2 = ExactScope(Dagger.ThreadProc(wid1, 2)) es1_2_ch = Dagger.tochunk(nothing, OSProc(), es1_2) for (ch1, ch2) in [(es1_ch, es1_2_ch), (es1_2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Union Scope" begin diff --git a/test/thunk.jl b/test/thunk.jl index e6fb7e86b..73879545b 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -69,7 +69,7 @@ end A = rand(4, 
4) @test fetch(@spawn sum(A; dims=1)) ≈ sum(A; dims=1) - @test_throws_unwrap Dagger.DTaskFailedException fetch(@spawn sum(A; fakearg=2)) + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(@spawn sum(A; fakearg=2)) @test fetch(@spawn reduce(+, A; dims=1, init=2.0)) ≈ reduce(+, A; dims=1, init=2.0) @@ -194,7 +194,7 @@ end a = @spawn error("Test") wait(a) @test isready(a) - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) b = @spawn 1+2 @test fetch(b) == 3 end @@ -207,8 +207,7 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin(r"^DTaskFailedException:", ex_str) @test occursin("Test", ex_str) @test !occursin("Root Task", ex_str) @@ -218,36 +217,35 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin("Test", ex_str) @test occursin("Root Task", ex_str) end @testset "single dependent" begin a = @spawn error("Test") b = @spawn a+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) end @testset "multi dependent" begin a = @spawn error("Test") b = @spawn a+2 c = @spawn a*2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "dependent chain" begin a = @spawn error("Test") - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) b = @spawn a+1 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) c = @spawn b+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "single input" begin a = @spawn 1+1 b = @spawn (a->error("Test"))(a) @test fetch(a) == 2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) end @testset "multi input" begin a = @spawn 1+1 @@ -255,7 +253,7 @@ end c = @spawn ((a,b)->error("Test"))(a,b) @test fetch(a) == 2 @test fetch(b) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "diamond" begin a = @spawn 1+1 @@ -265,7 +263,7 @@ end @test fetch(a) == 2 @test fetch(b) == 3 @test fetch(c) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(d) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(d) end end @testset "remote spawn" begin @@ -283,7 +281,7 @@ end t1 = Dagger.@spawn 1+"fail" Dagger.@spawn t1+1 end - @test_throws_unwrap Dagger.DTaskFailedException fetch(t2) + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(t2) end @testset "undefined function" begin # Issues #254, #255 diff --git a/test/util.jl b/test/util.jl index f01b3d95d..1131a9ebe 100644 --- a/test/util.jl +++ b/test/util.jl @@ -14,7 +14,7 @@ end replace_obj!(ex::Symbol, obj) = Expr(:(.), obj, QuoteNode(ex)) replace_obj!(ex, obj) = ex function _test_throws_unwrap(terr, ex; 
to_match=[]) - @gensym rerr + @gensym oerr rerr match_expr = Expr(:block) for m in to_match if m.head == :(=) @@ -35,12 +35,17 @@ function _test_throws_unwrap(terr, ex; to_match=[]) end end quote - $rerr = try - $(esc(ex)) + $oerr, $rerr = try + nothing, $(esc(ex)) catch err - Dagger.Sch.unwrap_nested_exception(err) + (err, Dagger.Sch.unwrap_nested_exception(err)) + end + if $terr isa Tuple + @test $oerr isa $terr[1] + @test $rerr isa $terr[2] + else + @test $rerr isa $terr end - @test $rerr isa $terr $match_expr end end From a5b663ec6e9d0658f3d81a7f5e8edd9801330f49 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 25 Nov 2024 12:18:50 -0600 Subject: [PATCH 18/22] Sch: Skip not-yet-inited workers --- src/sch/Sch.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index a5e8fb879..84be60ebc 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -684,6 +684,9 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) safepoint(state) @assert length(procs) > 0 + # Remove processors that aren't yet initialized + procs = filter(p -> haskey(state.worker_chans, Dagger.root_worker_id(p)), procs) + populate_processor_cache_list!(state, procs) # Schedule tasks From 577e17900eae2be57aa219adcda0e95ba46aeaff Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 16 Nov 2024 21:13:29 +0100 Subject: [PATCH 19/22] Bump MemPool compat --- .buildkite/pipeline.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index db4d09e16..6a5658705 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -6,6 +6,7 @@ os: linux arch: x86_64 command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'" + .bench: &bench if: build.message =~ /\[run benchmarks\]/ agents: @@ -14,6 +15,7 @@ os: linux arch: x86_64 num_cpus: 16 + steps: - label: Julia 1.9 timeout_in_minutes: 90 From a765cbe803b0c7654da10f2db930b5f9518e10db Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 12 Sep 2023 10:56:47 -0500 Subject: [PATCH 20/22] Add streaming API Co-authored-by: JamesWrigley Co-authored-by: davidizzle --- Project.toml | 1 - docs/make.jl | 1 + docs/src/index.md | 34 ++ docs/src/streaming.md | 106 +++++++ src/Dagger.jl | 6 +- src/sch/Sch.jl | 10 +- src/sch/eager.jl | 7 + src/sch/util.jl | 2 + src/stream-buffers.jl | 64 ++++ src/stream-transfer.jl | 71 +++++ src/stream.jl | 682 +++++++++++++++++++++++++++++++++++++++++ src/submission.jl | 2 +- src/utils/dagdebug.jl | 3 +- test/runtests.jl | 2 +- test/streaming.jl | 380 +++++++++++++++++++++++ 15 files changed, 1362 insertions(+), 9 deletions(-) create mode 100644 docs/src/streaming.md create mode 100644 src/stream-buffers.jl create mode 100644 src/stream-transfer.jl create mode 100644 src/stream.jl create mode 100644 test/streaming.jl diff --git a/Project.toml b/Project.toml index 45cf988c8..8735298dc 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" -Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" diff --git a/docs/make.jl b/docs/make.jl index c21c03f2d..8f1f97f5c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,6 +22,7 @@ makedocs(; "Task Spawning" => "task-spawning.md", 
"Data Management" => "data-management.md", "Distributed Arrays" => "darray.md", + "Streaming Tasks" => "streaming.md", "Scopes" => "scopes.md", "Processors" => "processors.md", "Task Queues" => "task-queues.md", diff --git a/docs/src/index.md b/docs/src/index.md index 27eb28dd3..87b4ea174 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -394,3 +394,37 @@ Dagger.@spawn copyto!(C, X) In contrast to the previous example, here, the tasks are executed without argument annotations. As a result, there is a possibility of the `copyto!` task being executed before the `sort!` task, leading to unexpected results in the output array `C`. +## Quickstart: Streaming + +Dagger.jl provides a streaming API that allows you to process data in a streaming fashion, where data is processed as it becomes available, rather than waiting for the entire dataset to be loaded into memory. + +For more details: [Streaming](@ref) + +### Syntax + +The `Dagger.spawn_streaming()` function is used to create a streaming region, +where tasks are executed continuously, processing data as it becomes available: + +```julia +# Open a file to write to on this worker +f = Dagger.@mutable open("output.txt", "w") +t = Dagger.spawn_streaming() do + # Generate random numbers continuously + val = Dagger.@spawn rand() + # Write each random number to a file + Dagger.@spawn (f, val) -> begin + if val < 0.01 + # Finish streaming when the random number is less than 0.01 + Dagger.finish_stream() + end + println(f, val) + end +end +# Wait for all values to be generated and written +wait(t) +``` + +The above example demonstrates a streaming region that generates random numbers +continuously and writes each random number to a file. The streaming region is +terminated when a random number less than 0.01 is generated, which is done by +calling `Dagger.finish_stream()` (this exits the current streaming task). diff --git a/docs/src/streaming.md b/docs/src/streaming.md new file mode 100644 index 000000000..25060e1b2 --- /dev/null +++ b/docs/src/streaming.md @@ -0,0 +1,106 @@ +# Streaming + +Dagger tasks have a limited lifetime - they are created, execute, finish, and +are eventually destroyed when they're no longer needed. Thus, if one wants +to run the same kind of computations over and over, one might re-create a +similar set of tasks for each unit of data that needs processing. + +This might be fine for computations which take a long time to run (thus +dwarfing the cost of task creation, which is quite small), or when working with +a limited set of data, but this approach is not great for doing lots of small +computations on a large (or endless) amount of data. For example, processing +image frames from a webcam, reacting to messages from a message bus, reading +samples from a software radio, etc. All of these tasks are better suited to a +"streaming" model of data processing, where data is simply piped into a +continuously-running task (or DAG of tasks) forever, or until the data runs +out. + +Thankfully, if you have a problem which is best modeled as a streaming system +of tasks, Dagger has you covered! 
Building on its support for +[Task Queues](@ref), Dagger provides a means to convert an entire DAG of +tasks into a streaming DAG, where data flows into and out of each task +asynchronously, using the `spawn_streaming` function: + +```julia +Dagger.spawn_streaming() do # enters a streaming region + vals = Dagger.@spawn rand() + print_vals = Dagger.@spawn println(vals) +end # exits the streaming region, and starts the DAG running +``` + +In the above example, `vals` is a Dagger task which has been transformed to run +in a streaming manner - instead of just calling `rand()` once and returning its +result, it will re-run `rand()` endlessly, continuously producing new random +values. In typical Dagger style, `print_vals` is a Dagger task which depends on +`vals`, but in streaming form - it will continuously `println` the random +values produced from `vals`. Both tasks will run forever, and will run +efficiently, only doing the work necessary to generate, transfer, and consume +values. + +As the comments point out, `spawn_streaming` creates a streaming region, during +which `vals` and `print_vals` are created and configured. Both tasks are halted +until `spawn_streaming` returns, allowing large DAGs to be built all at once, +without any task losing a single value. If desired, streaming regions can be +connected, although some values might be lost while tasks are being connected: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.@spawn rand() +end + +# Some values might be generated by `vals` but thrown away +# before `print_vals` is fully setup and connected to it + +print_vals = Dagger.spawn_streaming() do + Dagger.@spawn println(vals) +end +``` + +More complicated streaming DAGs can be easily constructed, without doing +anything different. For example, we can generate multiple streams of random +numbers, write them all to their own files, and print the combined results: + +```julia +Dagger.spawn_streaming() do + all_vals = [Dagger.spawn(rand) for i in 1:4] + all_vals_written = map(1:4) do i + Dagger.spawn(all_vals[i]) do val + open("results_$i.txt"; write=true, create=true, append=true) do io + println(io, repr(val)) + end + return val + end + end + Dagger.spawn(all_vals_written...) do all_vals_written... + vals_sum = sum(all_vals_written) + println(vals_sum) + end +end +``` + +If you want to stop the streaming DAG and tear it all down, you can call +`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to +terminate each streaming task. In the future, a more convenient way to tear +down a full DAG will be added; for now, each task must be cancelled individually. + +Alternatively, tasks can stop themselves from the inside with +`finish_stream`, optionally returning a value that can be `fetch`'d. Let's +do this when our randomly-drawn number falls within some arbitrary range: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.spawn() do + x = rand() + if x < 0.001 + # That's good enough, let's be done + return Dagger.finish_stream("Finished!") + end + return x + end +end +fetch(vals) +``` + +In this example, the call to `fetch` will hang (while random numbers continue +to be drawn), until a drawn number is less than 0.001; at that point, `fetch` +will return with `"Finished!"`, and the task `vals` will have terminated. 
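+
+A single streaming task can likewise be torn down from the outside with
+`Dagger.cancel!` (a sketch; the stream prints values until it is cancelled):
+
+```julia
+vals = Dagger.spawn_streaming() do
+    Dagger.@spawn println(rand())
+end
+sleep(1)
+Dagger.cancel!(vals)
+```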
diff --git a/src/Dagger.jl b/src/Dagger.jl index 4623b5405..fd6395a4b 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -21,7 +21,6 @@ if !isdefined(Base, :ScopedValues) else import Base.ScopedValues: ScopedValue, with end - import TaskLocalValues: TaskLocalValue if !isdefined(Base, :get_extension) @@ -78,6 +77,11 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps.jl") +# Streaming +include("stream.jl") +include("stream-buffers.jl") +include("stream-transfer.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 84be60ebc..b894f4526 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -259,9 +259,11 @@ end Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`. """ function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) - single = topts.single !== nothing ? topts.single : sopts.single - allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors - proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist + select_option = (sopt, topt) -> isnothing(topt) ? sopt : topt + + single = select_option(sopts.single, topts.single) + allow_errors = select_option(sopts.allow_errors, topts.allow_errors) + proclist = select_option(sopts.proclist, topts.proclist) ThunkOptions(single, proclist, topts.time_util, @@ -1376,7 +1378,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re if unwrap_nested_exception(err) isa InvalidStateException || !isopen(return_queue) @dagdebug thunk_id :execute "Return queue is closed, failing to put result" chan=return_queue exception=(err, catch_backtrace()) else - rethrow(err) + rethrow() end finally # Ensure that any spawned tasks get cleaned up diff --git a/src/sch/eager.jl b/src/sch/eager.jl index f3aca2ca0..aea0abbf6 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -124,6 +124,13 @@ function eager_cleanup(state, uid) # N.B. cache and errored expire automatically delete!(state.thunk_dict, tid) end + remotecall_wait(1, uid) do uid + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + delete!(global_streams, uid) + end + end + end end function _find_thunk(e::Dagger.DTask) diff --git a/src/sch/util.jl b/src/sch/util.jl index eb5a285b4..2e090b26c 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -31,6 +31,8 @@ unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) unwrap_nested_exception(err::DTaskFailedException) = unwrap_nested_exception(err.ex) +unwrap_nested_exception(err::TaskFailedException) = + unwrap_nested_exception(err.t.exception) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl new file mode 100644 index 000000000..9770933f6 --- /dev/null +++ b/src/stream-buffers.jl @@ -0,0 +1,64 @@ +"A process-local ring buffer." 
+mutable struct ProcessRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + @atomic open::Bool + function ProcessRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer, true) + end +end +Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 +isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +capacity(rb::ProcessRingBuffer) = length(rb.buffer) +Base.length(rb::ProcessRingBuffer) = @atomic rb.count +Base.isopen(rb::ProcessRingBuffer) = @atomic rb.open +function Base.close(rb::ProcessRingBuffer) + @atomic rb.open = false +end +function Base.put!(rb::ProcessRingBuffer{T}, x) where T + while isfull(rb) + yield() + if !isopen(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + end + to_write_idx = mod1(rb.write_idx, length(rb.buffer)) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ProcessRingBuffer) + while isempty(rb) + yield() + if !isopen(rb) && isempty(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + if task_cancelled() && isempty(rb) + # We respect a graceful cancellation only if the buffer is empty. + # Otherwise, we may have values to continue communicating. + task_may_cancel!() + end + task_may_cancel!(; must_force=true) + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end + +""" +`take!()` all the elements from a buffer and put them in a `Vector`. +""" +function collect!(rb::ProcessRingBuffer{T}) where T + output = Vector{T}(undef, rb.count) + for i in 1:rb.count + output[i] = take!(rb) + end + + return output +end diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl new file mode 100644 index 000000000..96e61fb9c --- /dev/null +++ b/src/stream-transfer.jl @@ -0,0 +1,71 @@ +struct RemoteChannelFetcher + chan::RemoteChannel + RemoteChannelFetcher() = new(RemoteChannel()) +end +const _THEIR_TID = TaskLocalValue{Int}(()->0) +function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_push "taking output value: $our_tid -> $their_tid" + value = try + take!(buffer) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_push "pushing output value: $our_tid -> $their_tid" + try + put!(fetcher.chan, value) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid" + throw(InterruptException()) + end + # N.B. 
We don't close the buffer to allow for eventual reconnection + rethrow() + end + @dagdebug our_tid :stream_push "finished pushing output value: $our_tid -> $their_tid" +end +function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_pull "pulling input value: $their_tid -> $our_tid" + value = try + take!(fetcher.chan) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid" + throw(InterruptException()) + end + # N.B. We don't close the buffer to allow for eventual reconnection + rethrow() + end + @dagdebug our_tid :stream_pull "putting input value: $their_tid -> $our_tid" + try + put!(buffer, value) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid" +end diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 000000000..81becd5ac --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,682 @@ +mutable struct StreamStore{T,B} + uid::UInt + waiters::Vector{Int} + input_streams::Dict{UInt,Any} # FIXME: Concrete type + output_streams::Dict{UInt,Any} # FIXME: Concrete type + input_buffers::Dict{UInt,B} + output_buffers::Dict{UInt,B} + input_buffer_amount::Int + output_buffer_amount::Int + input_fetchers::Dict{UInt,Any} + output_fetchers::Dict{UInt,Any} + open::Bool + migrating::Bool + lock::Threads.Condition + StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} = + new{T,B}(uid, zeros(Int, 0), + Dict{UInt,Any}(), Dict{UInt,Any}(), + Dict{UInt,B}(), Dict{UInt,B}(), + input_buffer_amount, output_buffer_amount, + Dict{UInt,Any}(), Dict{UInt,Any}(), + true, false, Threads.Condition()) +end + +function tid_to_uid(thunk_id) + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end +end + +function Base.put!(store::StreamStore{T,B}, value) where {T,B} + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)" + for output_uid in keys(store.output_streams) + if !haskey(store.output_buffers, output_uid) + initialize_output_stream!(store, output_uid) + end + buffer = store.output_buffers[output_uid] + while isfull(buffer) + if !isopen(store) + @dagdebug thunk_id :stream "closed!" 
+ throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" + wait(store.lock) + if !isfull(buffer) + @dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing" + end + task_may_cancel!() + end + put!(buffer, value) + end + notify(store.lock) + end +end + +function Base.take!(store::StreamStore, id::UInt) + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + error("Must first check isempty(store, id) before taking from a stream") + end + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug thunk_id :stream "no elements, not taking" + wait(store.lock) + task_may_cancel!() + end + @dagdebug thunk_id :stream "wait finished" + if !isopen(store, id) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + unlock(store.lock) + value = try + take!(buffer) + finally + lock(store.lock) + end + @dagdebug thunk_id :stream "value accepted" + notify(store.lock) + return value + end +end + +""" +Returns whether the store is actively open. Only check this when deciding if +new values can be pushed. +""" +Base.isopen(store::StreamStore) = store.open + +""" +Returns whether the store is actively open, or if closing, still has remaining +messages for `id`. Only check this when deciding if existing values can be +taken. +""" +function Base.isopen(store::StreamStore, id::UInt) + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return store.open + end + if !isempty(store.output_buffers[id]) + return true + end + return store.open + end +end + +function Base.close(store::StreamStore) + @lock store.lock begin + store.open || return + store.open = false + for buffer in values(store.input_buffers) + close(buffer) + end + for buffer in values(store.output_buffers) + close(buffer) + end + notify(store.lock) + end +end + +# FIXME: Just pass Stream directly, rather than its uid +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Pair{UInt,Any}}) where {T,B} + our_uid = store.uid + @lock store.lock begin + for (output_uid, output_fetcher) in waiters + store.output_streams[output_uid] = task_to_stream(output_uid) + push!(store.waiters, output_uid) + store.output_fetchers[output_uid] = output_fetcher + end + notify(store.lock) + end +end + +function remove_waiters!(store::StreamStore, waiters::Vector{UInt}) + @lock store.lock begin + for w in waiters + delete!(store.output_buffers, w) + idx = findfirst(wo->wo==w, store.waiters) + deleteat!(store.waiters, idx) + delete!(store.input_streams, w) + end + notify(store.lock) + end +end + +mutable struct Stream{T,B} + uid::UInt + store::Union{StreamStore{T,B},Nothing} + store_ref::Chunk + function Stream{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} + # Creates a new output stream + store = StreamStore{T,B}(uid, input_buffer_amount, output_buffer_amount) + store_ref = tochunk(store) + return new{T,B}(uid, store, store_ref) + end + function Stream(stream::Stream{T,B}) where {T,B} + # References an existing output stream + return new{T,B}(stream.uid, nothing, stream.store_ref) + end +end + +struct StreamingValue{B} + buffer::B +end +Base.take!(sv::StreamingValue) = take!(sv.buffer) + +function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where 
{IT,OT,IB,OB} + input_uid = input_stream.uid + our_uid = our_store.uid + local buffer, input_fetcher + @lock our_store.lock begin + if haskey(our_store.input_buffers, input_uid) + return StreamingValue(our_store.input_buffers[input_uid]) + end + + buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) + # FIXME: Also pass a RemoteChannel to track remote closure + our_store.input_buffers[input_uid] = buffer + input_fetcher = our_store.input_fetchers[input_uid] + end + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while isopen(our_store) + stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) + end + catch err + unwrapped_err = Sch.unwrap_nested_exception(err) + if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow() + end + finally + @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" + end + end) + return StreamingValue(buffer) +end +initialize_input_stream!(our_store::StreamStore, arg) = arg +function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @assert islocked(our_store.lock) + @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + our_uid = our_store.uid + output_stream = our_store.output_streams[output_uid] + output_fetcher = our_store.output_fetchers[output_uid] + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while true + if !isopen(our_store) && isempty(buffer) + # Only exit if the buffer is empty; otherwise, we need to + # continue draining it + break + end + stream_push_values!(output_fetcher, T, our_store, output_stream, buffer) + end + catch err + unwrapped_err = Sch.unwrap_nested_exception(err) + if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow() + end + finally + @dagdebug thunk_id :stream "output stream closed" + end + end) +end + +Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) + +function Base.isopen(stream::Stream, id::UInt)::Bool + return MemPool.access_ref(stream.store_ref.handle, id) do store, id + return isopen(store::StreamStore, id) + end +end + +function Base.close(stream::Stream) + MemPool.access_ref(stream.store_ref.handle) do store + close(store::StreamStore) + return + end + return +end + +function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}}) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + add_waiters!(store::StreamStore, waiters) + return + end + return +end + +function remove_waiters!(stream::Stream, waiters::Vector{UInt}) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + remove_waiters!(store::StreamStore, waiters) + return + end + return +end + +struct StreamingFunction{F, S} + f::F + stream::S + max_evals::Int + + StreamingFunction(f::F, stream::S, max_evals) where {F, S} = + new{F, S}(f, stream, max_evals) +end + +function migrate_stream!(stream::Stream, w::Integer=myid()) + # Perform migration of the StreamStore + # MemPool will block access to the new ref until the migration completes + # 
FIXME: Do this ownership check with MemPool.access_ref, + # in case stream was already migrated + if stream.store_ref.handle.owner != w + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + + # TODO: Wire up listener to ferry cancel_token notifications to remote + # worker once migrations occur during runtime + tls = get_tls() + @assert w == myid() "Only pull-based migration is currently supported" + #remote_cancel_token = clone_cancel_token_remote(get_tls().cancel_token, worker_id) + + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; + pre_migration=store->begin + # Lock store to prevent any further modifications + # N.B. Serialization automatically unlocks the migrated copy + lock((store::StreamStore).lock) + + # Return the serializeable unsent inputs/outputs. We can't send the + # buffers themselves because they may be mmap'ed or something. + unsent_inputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.input_buffers) + unsent_outputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.output_buffers) + empty!(store.input_buffers) + empty!(store.output_buffers) + return (unsent_inputs, unsent_outputs) + end, + dest_post_migration=(store, unsent)->begin + # Initialize the StreamStore on the destination with the unsent inputs/outputs. + STREAM_THUNK_ID[] = thunk_id + @assert !in_task() + set_tls!(tls) + #get_tls().cancel_token = MemPool.access_ref(identity, remote_cancel_token; local_only=true) + unsent_inputs, unsent_outputs = unsent + for (input_uid, inputs) in unsent_inputs + input_stream = store.input_streams[input_uid] + initialize_input_stream!(store, input_stream) + for item in inputs + put!(store.input_buffers[input_uid], item) + end + end + for (output_uid, outputs) in unsent_outputs + initialize_output_stream!(store, output_uid) + for item in outputs + put!(store.output_buffers[output_uid], item) + end + end + + # Reset the state of this new store + store.open = true + store.migrating = false + end, + post_migration=store->begin + # Indicate that this store has migrated + store.migrating = true + store.open = false + + # Unlock the store + unlock((store::StreamStore).lock) + end) + if w == myid() + stream.store_ref.handle = new_store_ref # FIXME: It's not valid to mutate the Chunk handle, but we want to update this to enable fast location queries + stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) + end + + @dagdebug thunk_id :stream "Migration complete ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + end +end + +struct StreamingTaskQueue <: AbstractTaskQueue + tasks::Vector{Pair{DTaskSpec,DTask}} + self_streams::Dict{UInt,Any} + StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + Dict{UInt,Any}()) +end + +function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) + push!(queue.tasks, spec) + initialize_streaming!(queue.self_streams, spec...) +end + +function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) + append!(queue.tasks, specs) + for (spec, task) in specs + initialize_streaming!(queue.self_streams, spec, task) + end +end + +function initialize_streaming!(self_streams, spec, task) + @assert !isa(spec.f, StreamingFunction) "Task is already in streaming form" + + # Calculate the return type of the called function + T_old = Base.uniontypes(task.metadata.return_type) + T_old = map(t->(t !== Union{} && t <: FinishStream) ? 
first(t.parameters) : t, T_old) + # N.B. We treat non-dominating error paths as unreachable + T_old = filter(t->t !== Union{}, T_old) + T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any + + # Get input buffer configuration + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + if input_buffer_amount <= 0 + throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) + end + + # Get output buffer configuration + output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + if output_buffer_amount <= 0 + throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) + end + + # Create the Stream + buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer) + stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) + self_streams[task.uid] = stream + + # Get max evaluation count + max_evals = get(spec.options, :stream_max_evals, -1) + if max_evals == 0 + throw(ArgumentError("stream_max_evals cannot be 0")) + end + + # Wrap the function in a StreamingFunction + spec.f = StreamingFunction(spec.f, stream, max_evals) + + # Mark the task as non-blocking + spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) + + # Register Stream globally + remotecall_wait(1, task.uid, stream) do uid, stream + lock(EAGER_THUNK_STREAMS) do global_streams + global_streams[uid] = stream + end + end +end + +function spawn_streaming(f::Base.Callable) + queue = StreamingTaskQueue() + result = with_options(f; task_queue=queue) + if length(queue.tasks) > 0 + finalize_streaming!(queue.tasks, queue.self_streams) + enqueue!(queue.tasks) + end + return result +end + +struct FinishStream{T,R} + value::Union{Some{T},Nothing} + result::R +end + +finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{T}(value), result) + +finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) + +const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) + +chunktype(sf::StreamingFunction{F}) where F = F + +struct StreamMigrating end + +function (sf::StreamingFunction)(args...; kwargs...) + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # Migrate our output stream store to this worker + if sf.stream isa Stream + remote_cancel_token = migrate_stream!(sf.stream) + end + + @label start + @dagdebug thunk_id :stream "Starting StreamingFunction" + worker_id = sf.stream.store_ref.handle.owner # FIXME: Not valid to access the owner directly + result = if worker_id == myid() + _run_streamingfunction(nothing, nothing, sf, args...; kwargs...) + else + tls = get_tls() + remotecall_fetch(_run_streamingfunction, worker_id, tls, remote_cancel_token, sf, args...; kwargs...) + end + if result === StreamMigrating() + @goto start + end + return result +end + +function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) 
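+    # N.B. This runs on the worker that owns `sf.stream`'s store; `tls` and
+    # `cancel_token` are only non-`nothing` when this function is invoked
+    # remotely via `remotecall_fetch` (see the `StreamingFunction` call above)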
+    @nospecialize sf args kwargs
+
+    store = sf.stream.store = MemPool.access_ref(identity, sf.stream.store_ref.handle; local_only=true)
+    @assert isopen(store)
+
+    if tls !== nothing
+        # Setup TLS on this new task
+        tls.cancel_token = MemPool.access_ref(identity, cancel_token; local_only=true)
+        set_tls!(tls)
+    end
+
+    thunk_id = Sch.sch_handle().thunk_id.id
+    STREAM_THUNK_ID[] = thunk_id
+
+    # FIXME: Remove when scheduler is distributed
+    uid = remotecall_fetch(1, thunk_id) do thunk_id
+        lock(Sch.EAGER_ID_MAP) do id_map
+            for (uid, otid) in id_map
+                if thunk_id == otid
+                    return uid
+                end
+            end
+        end
+    end
+
+    try
+        # TODO: This kwarg song-and-dance is required to ensure that we don't
+        # allocate boxes within `stream!`, when possible
+        kwarg_names = map(name->Val{name}(), map(first, (kwargs...,)))
+        kwarg_values = map(last, (kwargs...,))
+        args = map(arg->initialize_input_stream!(store, arg), args)
+        kwarg_values = map(kwarg->initialize_input_stream!(store, kwarg), kwarg_values)
+        return stream!(sf, uid, (args...,), kwarg_names, kwarg_values)
+    finally
+        if !sf.stream.store.migrating
+            # Remove ourselves as a waiter for upstream Streams
+            streams = Set{Stream}()
+            for (idx, arg) in enumerate(args)
+                if arg isa Stream
+                    push!(streams, arg)
+                end
+            end
+            for (idx, (pos, arg)) in enumerate(kwargs)
+                if arg isa Stream
+                    push!(streams, arg)
+                end
+            end
+            for stream in streams
+                @dagdebug thunk_id :stream "dropping waiter"
+                remove_waiters!(stream, UInt[uid])
+                @dagdebug thunk_id :stream "dropped waiter"
+            end
+
+            # Ensure downstream tasks also terminate
+            close(sf.stream)
+            @dagdebug thunk_id :stream "closed stream store"
+        end
+    end
+end
+
+# N.B. We specialize to minimize/eliminate allocations
+function stream!(sf::StreamingFunction, uid,
+                 args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple)
+    f = move(task_processor(), sf.f)
+    counter = 0
+
+    while true
+        # Yield to other (streaming) tasks
+        yield()
+
+        # Exit streaming on cancellation
+        task_may_cancel!()
+
+        # Exit streaming on migration
+        if sf.stream.store.migrating
+            error("FIXME: max_evals should be retained")
+            @dagdebug STREAM_THUNK_ID[] :stream "returning for migration"
+            return StreamMigrating()
+        end
+
+        # Get values from Stream args/kwargs
+        stream_args = _stream_take_values!(args)
+        stream_kwarg_values = _stream_take_values!(kwarg_values)
+        stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values)
+
+        if length(stream_args) > 0 || length(stream_kwarg_values) > 0
+            # Notify tasks that input buffers may have space
+            @lock sf.stream.store.lock notify(sf.stream.store.lock)
+        end
+
+        # Run a single cycle of f
+        counter += 1
+        @dagdebug STREAM_THUNK_ID[] :stream "executing $f (eval $counter)"
+        stream_result = f(stream_args...; stream_kwargs...)
+ + # Exit streaming on graceful request + if stream_result isa FinishStream + if stream_result.value !== nothing + value = something(stream_result.value) + put!(sf.stream, value) + end + @dagdebug STREAM_THUNK_ID[] :stream "voluntarily returning" + return stream_result.result + end + + # Put the result into the output stream + put!(sf.stream, stream_result) + + # Exit streaming on eval limit + if sf.max_evals > 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached (eval $counter)" + return + end + end +end + +function _stream_take_values!(args) + return ntuple(length(args)) do idx + arg = args[idx] + if arg isa StreamingValue + return take!(arg) + else + return arg + end + end +end + +@inline @generated function _stream_namedtuple(kwarg_names::Tuple, + stream_kwarg_values::Tuple) + name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) + NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) + return :($NT(stream_kwarg_values)) +end + +# Default for buffers, can be customized +initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) + +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) +function task_to_stream(uid::UInt) + if myid() != 1 + return remotecall_fetch(task_to_stream, 1, uid) + end + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + return global_streams[uid] + end + return + end +end + +function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) + stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() + + for (spec, task) in tasks + @assert haskey(self_streams, task.uid) + our_stream = self_streams[task.uid] + + # Adapt args to accept Stream output of other streaming tasks + for (idx, (pos, arg)) in enumerate(spec.args) + if arg isa DTask + # Check if this is a streaming task + if haskey(self_streams, arg.uid) + other_stream = self_streams[arg.uid] + else + other_stream = task_to_stream(arg.uid) + end + + if other_stream !== nothing + # Generate Stream handle for input + # FIXME: Be configurable + input_fetcher = RemoteChannelFetcher() + other_stream_handle = Stream(other_stream) + spec.args[idx] = pos => other_stream_handle + our_stream.store.input_streams[arg.uid] = other_stream_handle + our_stream.store.input_fetchers[arg.uid] = input_fetcher + + # Add this task as a waiter for the associated output Stream + changes = get!(stream_waiter_changes, arg.uid) do + Pair{UInt,Any}[] + end + push!(changes, task.uid => input_fetcher) + end + end + end + + # Filter out all streaming options + to_filter = (:stream_buffer_type, + :stream_input_buffer_amount, :stream_output_buffer_amount, + :stream_max_evals) + spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), + Base.pairs(spec.options))) + if haskey(spec.options, :propagates) + propagates = filter(opt -> !(opt in to_filter), + spec.options.propagates) + spec.options = merge(spec.options, (;propagates)) + end + end + + # Notify Streams of any new waiters + for (uid, waiters) in stream_waiter_changes + stream = task_to_stream(uid) + add_waiters!(stream, waiters) + end +end diff --git a/src/submission.jl b/src/submission.jl index bfb8cb8be..f23539271 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -220,7 +220,7 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) end function DTaskMetadata(spec::DTaskSpec) - f = chunktype(spec.f).instance + f = spec.f isa StreamingFunction ? 
spec.f.f : spec.f arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) return_type = Base.promote_op(f, arg_types...) return DTaskMetadata(return_type) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 8b6d3530f..6a71e5c52 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -2,7 +2,8 @@ function istask end function task_id end const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope, - :take, :execute, :move, :processor, :cancel] + :take, :execute, :move, :processor, :cancel, + :stream] macro dagdebug(thunk, category, msg, args...) cat_sym = category.value @gensym id diff --git a/test/runtests.jl b/test/runtests.jl index baa21c3a4..79ba890d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ tests = [ ("Mutation", "mutation.jl"), ("Task Queues", "task-queues.jl"), ("Datadeps", "datadeps.jl"), + ("Streaming", "streaming.jl"), ("Domain Utilities", "domain.jl"), ("Array - Allocation", "array/allocation.jl"), ("Array - Indexing", "array/indexing.jl"), @@ -104,7 +105,6 @@ else @info "Running all tests" end - using Distributed if additional_workers > 0 # We put this inside a branch because addprocs() takes a minimum of 1s to diff --git a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 000000000..9eb01312c --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,380 @@ +const ACCUMULATOR = Dict{Int,Vector{Real}}() +@everywhere function accumulator(x=0) + tid = Dagger.task_id() + remotecall_wait(1, tid, x) do tid, x + acc = get!(Vector{Real}, ACCUMULATOR, tid) + push!(acc, x) + end + return +end +@everywhere accumulator(xs...) = accumulator(sum(xs)) +@everywhere accumulator(::Nothing) = accumulator(0) + +function catch_interrupt(f) + try + f() + catch err + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return + end + rethrow() + end +end + +function merge_testset!(inner::Test.DefaultTestSet) + outer = Test.get_testset() + append!(outer.results, inner.results) + outer.n_passed += inner.n_passed +end + +function test_finishes(f, message::String; timeout=10, ignore_timeout=false, max_evals=10) + t = @eval Threads.@spawn begin + tset = nothing + try + @testset $message begin + try + @testset $message begin + Dagger.with_options(;stream_max_evals=$max_evals) do + catch_interrupt($f) + end + end + finally + tset = Test.get_testset() + end + end + catch + end + return tset + end + + timed_out = timedwait(()->istaskdone(t), timeout) == :timed_out + if timed_out + if !ignore_timeout + @warn "Testing task timed out: $message" + end + Dagger.cancel!(;halt_sch=true, graceful=false) + @everywhere GC.gc() + fetch(Dagger.@spawn 1+1) + end + + tset = fetch(t)::Test.DefaultTestSet + merge_testset!(tset) + return !timed_out +end + +all_scopes = [Dagger.ExactScope(proc) for proc in Dagger.all_processors()] +for idx in 1:5 + if idx == 1 + scopes = [Dagger.scope(worker = 1, thread = 1)] + scope_str = "Worker 1" + elseif idx == 2 && nprocs() > 1 + scopes = [Dagger.scope(worker = 2, thread = 1)] + scope_str = "Worker 2" + else + scopes = all_scopes + scope_str = "All Workers" + end + + @testset "Single Task Control Flow ($scope_str)" begin + @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + y = rand() + sleep(1) + return y + end + end + @test_throws_unwrap InterruptException fetch(x) + end + + @test 
test_finishes("Single task without result") do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + @test fetch(x) === nothing + end + + @test test_finishes("Single task with result"; max_evals=1_000_000) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + x = rand() + if x < 0.1 + return Dagger.finish_stream(x; result=123) + end + return x + end + end + @test fetch(x) == 123 + end + end + + @testset "Non-Streaming Inputs ($scope_str)" begin + @test test_finishes("() -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(0), values[A_tid]) + end + @test test_finishes("42 -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42), values[A_tid]) + end + @test test_finishes("(42, 43) -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42 + 43), values[A_tid]) + end + end + + @testset "Non-Streaming Outputs ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + end + + @test test_finishes("x -> (A, B)") do + local x, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[B_tid]) + end + end + + @testset "Multiple Tasks ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, A)") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v == 1, values[A_tid]) + end + + @test test_finishes("x -> y -> A") do + local x, y, A + Dagger.spawn_streaming() 
do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 1 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, A)") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, y) -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, z) -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x + 1 + z = Dagger.@spawn scope=rand(scopes) x + 2 + A = Dagger.@spawn scope=rand(scopes) accumulator(y, z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 3 <= v <= 5, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> (A, B)") do + local x, y, z, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + B = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[B_tid]) + end + + for T in (Float64, Int32, BigFloat) + @test test_finishes("Stream eltype $T") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand(T) + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + 
@test fetch(x) === nothing
+                @test fetch(A) === nothing
+                values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+                A_tid = Dagger.task_id(A)
+                @test length(values[A_tid]) == 10
+                @test all(v -> v isa T, values[A_tid])
+            end
+        end
+    end
+
+    @testset "Max Evals ($scope_str)" begin
+        @test test_finishes("max_evals=0"; max_evals=0) do
+            @test_throws ArgumentError Dagger.spawn_streaming() do
+                A = Dagger.@spawn scope=rand(scopes) accumulator()
+            end
+        end
+        @test test_finishes("max_evals=1"; max_evals=1) do
+            local A
+            Dagger.spawn_streaming() do
+                A = Dagger.@spawn scope=rand(scopes) accumulator()
+            end
+            @test fetch(A) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 1
+        end
+        @test test_finishes("max_evals=100"; max_evals=100) do
+            local A
+            Dagger.spawn_streaming() do
+                A = Dagger.@spawn scope=rand(scopes) accumulator()
+            end
+            @test fetch(A) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 100
+        end
+    end
+
+    # FIXME: Varying buffer amounts
+
+    #= TODO: Zero-allocation test
+    # First execution of a streaming task will almost certainly allocate (compiling, setup, etc.)
+    # BUT, second and later executions could possibly not allocate any further ("steady-state")
+    # We want to be able to validate that the steady-state execution for certain tasks is non-allocating
+    =#
+end

From cbe64f8c53390179e79837d0773534670afa46df Mon Sep 17 00:00:00 2001
From: Julian P Samaroo
Date: Wed, 4 Dec 2024 11:25:19 -0600
Subject: [PATCH 21/22] DTask: Add waitany and waitall helpers

---
 src/dtask.jl       |  26 +++++++++++
 src/utils/tasks.jl | 112 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+)

diff --git a/src/dtask.jl b/src/dtask.jl
index 98f74005a..b597db5fa 100644
--- a/src/dtask.jl
+++ b/src/dtask.jl
@@ -85,6 +85,32 @@ function Base.fetch(t::DTask; raw=false)
     end
     return fetch(t.future; raw)
 end
+function waitany(tasks::Vector{DTask})
+    if isempty(tasks)
+        return
+    end
+    # Use a level-triggered `Base.Event` so that a task which finishes before
+    # we start waiting cannot cause a missed notification
+    done = Base.Event()
+    for task in tasks
+        Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin
+            wait(task)
+            notify(done)
+        end)
+    end
+    wait(done)
+    return
+end
+function waitall(tasks::Vector{DTask})
+    if isempty(tasks)
+        return
+    end
+    # `@sync` blocks until every spawned waiter (and thus every task) finishes
+    @sync for task in tasks
+        Threads.@spawn wait(task)
+    end
+    return
+end
 function Base.show(io::IO, t::DTask)
     status = if istaskstarted(t)
         isready(t) ? "finished" : "running"
diff --git a/src/utils/tasks.jl b/src/utils/tasks.jl
index c2796cf21..ddd8da2ee 100644
--- a/src/utils/tasks.jl
+++ b/src/utils/tasks.jl
@@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer)
     end
     @assert Threads.threadid(task) == tid "jl_set_task_tid failed!"
 end
+
+if isdefined(Base, :waitany)
+import Base: waitany, waitall
+else
+# Vendored from Base
+# License is MIT: https://julialang.org/license
+waitany(tasks; throw=true) = _wait_multiple(tasks, throw)
+waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)
+function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
+    tasks = Task[]
+
+    for t in waiting_tasks
+        t isa Task || error("Expected an iterator of `Task` object")
+        push!(tasks, t)
+    end
+
+    if (all && !failfast) || length(tasks) <= 1
+        exception = false
+        # Force everything to finish synchronously for the case of waitall
+        # with failfast=false
+        for t in tasks
+            Base._wait(t)
+            exception |= istaskfailed(t)
+        end
+        if exception && throwexc
+            exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
+            throw(CompositeException(exceptions))
+        else
+            return tasks, Task[]
+        end
+    end
+
+    exception = false
+    nremaining::Int = length(tasks)
+    done_mask = falses(nremaining)
+    for (i, t) in enumerate(tasks)
+        if istaskdone(t)
+            done_mask[i] = true
+            exception |= istaskfailed(t)
+            nremaining -= 1
+        else
+            done_mask[i] = false
+        end
+    end
+
+    if nremaining == 0
+        return tasks, Task[]
+    elseif any(done_mask) && (!all || (failfast && exception))
+        if throwexc && (!all || failfast) && exception
+            exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
+            throw(CompositeException(exceptions))
+        else
+            return tasks[done_mask], tasks[.~done_mask]
+        end
+    end
+
+    chan = Channel{Int}(Inf)
+    sentinel = current_task()
+    waiter_tasks = fill(sentinel, length(tasks))
+
+    for (i, done) in enumerate(done_mask)
+        done && continue
+        t = tasks[i]
+        if istaskdone(t)
+            done_mask[i] = true
+            exception |= istaskfailed(t)
+            nremaining -= 1
+            exception && failfast && break
+        else
+            waiter = @task put!(chan, i)
+            waiter.sticky = false
+            Base._wait2(t, waiter)
+            waiter_tasks[i] = waiter
+        end
+    end
+
+    while nremaining > 0
+        i = take!(chan)
+        t = tasks[i]
+        waiter_tasks[i] = sentinel
+        done_mask[i] = true
+        exception |= istaskfailed(t)
+        nremaining -= 1
+
+        # stop early if requested, unless there is something immediately
+        # ready to consume from the channel (using a race-y check)
+        if (!all || (failfast && exception)) && !isready(chan)
+            break
+        end
+    end
+
+    close(chan)
+
+    if nremaining == 0
+        return tasks, Task[]
+    else
+        remaining_mask = .~done_mask
+        for i in findall(remaining_mask)
+            waiter = waiter_tasks[i]
+            donenotify = tasks[i].donenotify::Base.ThreadSynchronizer
+            @lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
+        end
+        done_tasks = tasks[done_mask]
+        if throwexc && exception
+            exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
+            throw(CompositeException(exceptions))
+        else
+            return done_tasks, tasks[remaining_mask]
+        end
+    end
+end
+end

From 3c5c389fd29228b6386b766a48a7598277765829 Mon Sep 17 00:00:00 2001
From: Julian P Samaroo
Date: Wed, 4 Dec 2024 11:25:39 -0600
Subject: [PATCH 22/22] streaming: Add DAG teardown option

---
 docs/src/index.md     |  3 +-
 docs/src/streaming.md |  5 ++-
 src/stream.jl         | 27 +++++++++++++++-
 test/streaming.jl     | 75 +++++++++++++++++++++++++++++++------------
 4 files changed, 85 insertions(+), 25 deletions(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 87b4ea174..152b95cc5 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -427,4 +427,5 @@ wait(t)
 The above example demonstrates a streaming region that generates random
 numbers continuously and writes each random number to a file.
 The streaming region is terminated when a random number less than 0.01 is generated, which is done by
-calling `Dagger.finish_stream()` (this exits the current streaming task).
+calling `Dagger.finish_stream()` (this terminates the current task, and will
+also terminate all streaming tasks launched by `spawn_streaming`).
diff --git a/docs/src/streaming.md b/docs/src/streaming.md
index 25060e1b2..41c111e82 100644
--- a/docs/src/streaming.md
+++ b/docs/src/streaming.md
@@ -79,9 +79,8 @@ end
 ```
 
 If you want to stop the streaming DAG and tear it all down, you can call
-`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to
-terminate each streaming task. In the future, a more convenient way to tear
-down a full DAG will be added; for now, each task must be cancelled individually.
+`Dagger.cancel!(all_vals[1])` (or `cancel!` on any other task in the streaming
+DAG) to terminate all streaming tasks.
 
 Alternatively, tasks can stop themselves from the inside with
 `finish_stream`, optionally returning a value that can be `fetch`'d. Let's
diff --git a/src/stream.jl b/src/stream.jl
index 81becd5ac..07a3dae95 100644
--- a/src/stream.jl
+++ b/src/stream.jl
@@ -426,12 +426,37 @@ function initialize_streaming!(self_streams, spec, task)
     end
 end
 
-function spawn_streaming(f::Base.Callable)
+"""
+    spawn_streaming(f::Base.Callable; teardown::Bool=true)
+
+Starts a streaming region, within which all tasks run continuously and
+concurrently. Any `DTask` argument that is itself a streaming task will be
+treated as a streaming input/output. The streaming region will automatically
+handle the buffering and synchronization of these tasks' values.
+
+# Keyword Arguments
+- `teardown::Bool=true`: If `true`, the streaming region will automatically
+  cancel all tasks if any task fails or is cancelled. Otherwise, a failing task
+  will not cancel the other tasks, which will continue running.
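+
+# Example
+
+A minimal sketch of a streaming region (here the `stream_max_evals` option
+bounds how many times each task executes, so that the region terminates on its
+own):
+
+```julia
+x, A = Dagger.with_options(;stream_max_evals=10) do
+    Dagger.spawn_streaming() do
+        x = Dagger.@spawn rand()
+        A = Dagger.@spawn println(x)
+        return x, A
+    end
+end
+```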
+""" +function spawn_streaming(f::Base.Callable; teardown::Bool=true) queue = StreamingTaskQueue() result = with_options(f; task_queue=queue) if length(queue.tasks) > 0 finalize_streaming!(queue.tasks, queue.self_streams) enqueue!(queue.tasks) + + if teardown + # Start teardown monitor + dtasks = map(last, queue.tasks)::Vector{DTask} + Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin + # Wait for any task to finish + waitany(dtasks) + + # Cancel all tasks + for task in dtasks + cancel!(task; graceful=false) + end + end) + end end return result end diff --git a/test/streaming.jl b/test/streaming.jl index 9eb01312c..c3bf0e406 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -80,7 +80,7 @@ for idx in 1:5 @testset "Single Task Control Flow ($scope_str)" begin @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do local x - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) () -> begin y = rand() sleep(1) @@ -92,7 +92,7 @@ for idx in 1:5 @test test_finishes("Single task without result") do local x - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() end @test fetch(x) === nothing @@ -100,7 +100,7 @@ for idx in 1:5 @test test_finishes("Single task with result"; max_evals=1_000_000) do local x - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) () -> begin x = rand() if x < 0.1 @@ -116,7 +116,7 @@ for idx in 1:5 @testset "Non-Streaming Inputs ($scope_str)" begin @test test_finishes("() -> A") do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing @@ -127,7 +127,7 @@ for idx in 1:5 end @test test_finishes("42 -> A") do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator(42) end @test fetch(A) === nothing @@ -138,7 +138,7 @@ for idx in 1:5 end @test test_finishes("(42, 43) -> A") do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) end @test fetch(A) === nothing @@ -152,7 +152,7 @@ for idx in 1:5 @testset "Non-Streaming Outputs ($scope_str)" begin @test test_finishes("x -> A") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() end Dagger._without_options() do @@ -168,7 +168,7 @@ for idx in 1:5 @test test_finishes("x -> (A, B)") do local x, A, B - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() end Dagger._without_options() do @@ -188,10 +188,45 @@ for idx in 1:5 end end + @testset "Teardown" begin + @test test_finishes("teardown=true"; max_evals=1_000_000, ignore_timeout=true) do + local x, y + Dagger.spawn_streaming(;teardown=true) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + sleep(0.1) + return rand() + end + y = Dagger.with_options(;stream_max_evals=10) do + Dagger.@spawn scope=rand(scopes) identity(x) + end + end + @test fetch(y) === nothing + sleep(1) # Wait for teardown + @test istaskdone(x) + fetch(x) + end + @test !test_finishes("teardown=false"; max_evals=1_000_000, ignore_timeout=true) do + local x, y + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + 
sleep(0.1) + return rand() + end + y = Dagger.with_options(;stream_max_evals=10) do + Dagger.@spawn scope=rand(scopes) identity(x) + end + end + @test fetch(y) === nothing + sleep(1) # Wait to ensure `x` task is still running + @test !istaskdone(x) + @test_throws_unwrap InterruptException fetch(x) + end + end + @testset "Multiple Tasks ($scope_str)" begin @test test_finishes("x -> A") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(x) end @@ -205,7 +240,7 @@ for idx in 1:5 @test test_finishes("(x, A)") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) end @@ -219,7 +254,7 @@ for idx in 1:5 @test test_finishes("x -> y -> A") do local x, y, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) x+1 A = Dagger.@spawn scope=rand(scopes) accumulator(y) @@ -235,7 +270,7 @@ for idx in 1:5 @test test_finishes("x -> (y, A)") do local x, y, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) x+1 A = Dagger.@spawn scope=rand(scopes) accumulator(x) @@ -251,7 +286,7 @@ for idx in 1:5 @test test_finishes("(x, y) -> A") do local x, y, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) @@ -267,7 +302,7 @@ for idx in 1:5 @test test_finishes("(x, y) -> z -> A") do local x, y, z, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) rand() z = Dagger.@spawn scope=rand(scopes) x + y @@ -285,7 +320,7 @@ for idx in 1:5 @test test_finishes("x -> (y, z) -> A") do local x, y, z, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) x + 1 z = Dagger.@spawn scope=rand(scopes) x + 2 @@ -303,7 +338,7 @@ for idx in 1:5 @test test_finishes("(x, y) -> z -> (A, B)") do local x, y, z, A, B - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) rand() z = Dagger.@spawn scope=rand(scopes) x + y @@ -328,7 +363,7 @@ for idx in 1:5 for T in (Float64, Int32, BigFloat) @test test_finishes("Stream eltype $T") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand(T) A = Dagger.@spawn scope=rand(scopes) accumulator(x) end @@ -344,13 +379,13 @@ for idx in 1:5 @testset "Max Evals ($scope_str)" begin @test test_finishes("max_evals=0"; max_evals=0) do - @test_throws ArgumentError Dagger.spawn_streaming() do + @test_throws ArgumentError Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end end @test test_finishes("max_evals=1"; max_evals=1) do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing @@ -360,7 +395,7 @@ for idx in 1:5 end @test test_finishes("max_evals=100"; 
max_evals=100) do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing