Skip to content

Commit

Permalink
Use @pattern instead of @dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Jun 5, 2024
1 parent 9f30500 commit 3e9dcf9
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 65 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ while the second uses [SumTypes.jl](https://github.com/MasonProtter/SumTypes.jl)
which is more memory efficient and allows to mix mutable and immutable structs.

Even if there is only a unique type defined by this macro, you can access a symbol containing the conceptual
type of an instance with the function `kindof` and use the `@dispatch` macro to define functions which
type of an instance with the function `kindof` and use the `@pattern` macro to define functions which
can operate differently on each kind.

## Construct mixed structs
Expand Down Expand Up @@ -96,7 +96,7 @@ There are currently two ways to define function on the types created
with this package:

- Use manual branching;
- Use the `@dispatch` macro.
- Use the `@pattern` macro.

For example, let's say we want to create a sum function where different values are added
depending on the kind of each element in a vector:
Expand Down Expand Up @@ -131,7 +131,7 @@ julia> value_D() = 3;

julia> value_E() = 4;

julia> function sum2(v) # with @dispatch macro
julia> function sum2(v) # with @pattern macro
s = 0
for x in v
s += value(x)
Expand All @@ -140,13 +140,13 @@ julia> function sum2(v) # with @dispatch macro
end
sum2 (generic function with 1 method)

julia> @dispatch value(::B) = 1;
julia> @pattern value(::B) = 1;

julia> @dispatch value(::C) = 2;
julia> @pattern value(::C) = 2;

julia> @dispatch value(::D) = 3;
julia> @pattern value(::D) = 3;

julia> @dispatch value(::E) = 4;
julia> @pattern value(::E) = 4;

julia> sum1(v)
2499517
Expand All @@ -155,12 +155,12 @@ julia> sum2(v)
2499517
```

As you can see the version using the `@dispatch` macro is much less verbose and more intuitive. In some more
As you can see the version using the `@pattern` macro is much less verbose and more intuitive. In some more
advanced cases the verbosity of the first approach could be even stronger.

Since the macro essentially reconstruct the branching version described above, to ensure that everything will
work correctly when using it, do not define functions operating on the main type of a mixed struct without
using the `@dispatch` macro.
using the `@pattern` macro.

Consult the [API page](https://juliadynamics.github.io/DynamicSumTypes.jl/stable/) for more information on
the available functionalities.
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

```@docs
@sum_structs
@dispatch
@pattern
kindof
allkinds
kindconstructor
Expand Down
4 changes: 2 additions & 2 deletions src/DynamicSumTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using MacroTools
using SumTypes

export @sum_structs
export @dispatch
export @pattern
export kindof
export allkinds
export kindconstructor
Expand Down Expand Up @@ -73,7 +73,7 @@ function kindconstructor end

include("SumStructsSpeed.jl")
include("SumStructsMem.jl")
include("Dispatch.jl")
include("Pattern.jl")
include("precompile.jl")

end
44 changes: 22 additions & 22 deletions src/Dispatch.jl → src/Pattern.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

"""
@dispatch(function_definition)
@pattern(function_definition)
This macro allows to dispatch on types created by [`@sum_structs`](@ref).
This macro allows to pattern on types created by [`@sum_structs`](@ref).
Notice that this only works when the kinds in the macro are not wrapped
by any type containing them.
Expand All @@ -15,14 +15,14 @@ julia> @sum_structs AB begin
struct B y::Int end
end
julia> @dispatch f(::A) = 1;
julia> @pattern f(::A) = 1;
julia> @dispatch f(::B) = 2;
julia> @pattern f(::B) = 2;
julia> @dispatch f(::Vector{AB}) = 3; # this works
julia> @pattern f(::Vector{AB}) = 3; # this works
julia> @dispatch f(::Vector{B}) = 3; # this doesn't work
ERROR: LoadError: It is not possible to dispatch on a variant wrapped in another type
julia> @pattern f(::Vector{B}) = 3; # this doesn't work
ERROR: LoadError: It is not possible to pattern on a variant wrapped in another type
...
julia> f(A(0))
Expand All @@ -35,10 +35,10 @@ julia> f([A(0), B(0)])
3
```
"""
macro dispatch(f_def)
macro pattern(f_def)
vtc = __variants_types_cache__[__module__]
vtwpc = __variants_types_with_params_cache__[__module__]
f_sub, f_super_dict, f_cache = _dispatch(f_def, vtc, vtwpc)
f_sub, f_super_dict, f_cache = _pattern(f_def, vtc, vtwpc)

if f_super_dict == nothing
return Expr(:toplevel, esc(f_sub))
Expand All @@ -52,8 +52,8 @@ macro dispatch(f_def)
end

if is_first
expr_m = :(module Methods_Dispatch_Module_219428042303
const __dispatch_cache__ = Dict{Any, Any}()
expr_m = :(module Methods_Pattern_Module_219428042303
const __pattern_cache__ = Dict{Any, Any}()
function __init__()
define_all()
end
Expand All @@ -73,15 +73,15 @@ macro dispatch(f_def)
expr_d = :(DynamicSumTypes.define_f_super($(__module__), $(QuoteNode(f_super_dict)), $(QuoteNode(f_cache))))
expr_fire = quote
if isinteractive() && (@__MODULE__) == Main
Methods_Dispatch_Module_219428042303.define_all()
Methods_Pattern_Module_219428042303.define_all()
$(f_super_dict[:name])
end
end

return Expr(:toplevel, esc(f_sub), esc(expr_m), esc(expr_d), esc(expr_fire))
end

function _dispatch(f_def, vtc, vtwpc)
function _pattern(f_def, vtc, vtwpc)
macros = []
while f_def.head == :macrocall
f_def_comps = rmlines(f_def.args)
Expand All @@ -103,12 +103,12 @@ function _dispatch(f_def, vtc, vtwpc)

for k in keys(vtc)
if any(a -> inexpr(a[2], k) && !(a[1] in idxs_mvtc), enumerate(f_args_t))
error("It is not possible to dispatch on a variant wrapped in another type")
error("It is not possible to pattern on a variant wrapped in another type")
end
end

if !isempty(idxs_mctc) && !isempty(idxs_mvtc)
error("Dispatching on overall types and variants at the same time is not supported")
error("Using `@pattern` with signatures containing sum types and variants at the same time is not supported")
end

if !any(a -> a in keys(vtc) || a in values(vtc), f_args_n)
Expand Down Expand Up @@ -241,7 +241,7 @@ function _dispatch(f_def, vtc, vtwpc)
f_super_dict[:macros] = macros
f_super_dict[:condition] = new_cond
f_super_dict[:subcall] = :(return $(f_sub_dict[:name])($(g_args_names...)))
f_sub_name_default = Symbol(f_comps[:name], :_sub_, collect(Iterators.flatten(all_types_args1))..., :_, length(f_args))
f_sub_name_default = Symbol(Symbol("##"), f_comps[:name], :_, collect(Iterators.flatten(all_types_args1))...)
f_super_dict[:subcall_default] = :(return $(f_sub_name_default)($(g_args_names...)))

return f_sub, f_super_dict, f_cache
Expand All @@ -255,7 +255,7 @@ end

function define_f_sub(whereparams, f_comps, all_types_args0, f_args)
f_sub_dict = Dict{Symbol, Any}()
f_sub_name = Symbol(f_comps[:name], :_sub_, collect(Iterators.flatten(all_types_args0))..., :_, length(f_args))
f_sub_name = Symbol(Symbol("##"), f_comps[:name], :_, collect(Iterators.flatten(all_types_args0))...)
f_sub_dict[:name] = f_sub_name
f_sub_dict[:args] = f_args
f_sub_dict[:kwargs] = :kwargs in keys(f_comps) ? f_comps[:kwargs] : []
Expand All @@ -268,7 +268,7 @@ function inspect_sig end

function define_f_super(mod, f_super_dict, f_cache)
f_name = f_super_dict[:name]
cache = mod.Methods_Dispatch_Module_219428042303.__dispatch_cache__
cache = mod.Methods_Pattern_Module_219428042303.__pattern_cache__
if !(f_name in keys(cache))
cache[f_name] = Dict{Any, Any}(f_cache => [f_super_dict])
else
Expand All @@ -291,7 +291,7 @@ function define_f_super(mod, f_super_dict, f_cache)
end

function generate_defs(mod)
cache = mod.Methods_Dispatch_Module_219428042303.__dispatch_cache__
cache = mod.Methods_Pattern_Module_219428042303.__pattern_cache__
return generate_defs(mod, cache)
end

Expand All @@ -302,9 +302,9 @@ function generate_defs(mod, cache)
new_d = Dict{Symbol, Any}()
new_d[:args] = ds[end][:args]
new_d[:name] = ds[end][:name]
!allequal(d[:whereparams] for d in ds) && error("Parameters in where {...} should be the same for all @dispatch methods with same signature")
!allequal(d[:whereparams] for d in ds) && error("Parameters in where {...} should be the same for all @pattern methods with same signature")
new_d[:whereparams] = ds[end][:whereparams]
!allequal(d[:kwargs] for d in ds) && error("Keyword arguments should be the same for all @dispatch methods with same signature")
!allequal(d[:kwargs] for d in ds) && error("Keyword arguments should be the same for all @pattern methods with same signature")
new_d[:kwargs] = ds[end][:kwargs]
default = findfirst(d -> d[:condition] == nothing, ds)
subcall_default = nothing
Expand All @@ -327,7 +327,7 @@ function generate_defs(mod, cache)
end
new_d[:body] = quote $body end
new_df = mod.DynamicSumTypes.ExprTools.combinedef(new_d)
!allequal(d[:macros] for d in ds) && error("Applied macros should be the same for all @dispatch methods with same signature")
!allequal(d[:macros] for d in ds) && error("Applied macros should be the same for all @pattern methods with same signature")
for m in ds[end][:macros]
new_df = Expr(:macrocall, m, :(), new_df)
end
Expand Down
8 changes: 4 additions & 4 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ using PrecompileTools

_compact_structs(type, struct_defs, vtc, vtwpc)
_sum_structs(type, struct_defs, vtc, vtwpc)
_dispatch(f0, vtc, vtwpc)
_dispatch(f1, vtc, vtwpc)
_dispatch(f2, vtc, vtwpc)
f_sub, f_super_dict, f_cache = _dispatch(f3, vtc, vtwpc)
_pattern(f0, vtc, vtwpc)
_pattern(f1, vtc, vtwpc)
_pattern(f2, vtc, vtwpc)
f_sub, f_super_dict, f_cache = _pattern(f3, vtc, vtwpc)
cache = Dict{Any, Any}()
cache[:f] = Dict{Any, Any}(f_cache => [f_super_dict])
generate_defs(parentmodule(@__MODULE__), cache)
Expand Down
52 changes: 26 additions & 26 deletions test/dispatch_macro_tests.jl → test/pattern_macro_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,41 @@ end
end
end

@dispatch g(x::X, q, a::X) = -10
@dispatch g(x::B1, q, a::A1) = -1
@dispatch g(x::B1, q::Int, a::A1) = 0
@dispatch g(x::B1, q::Int, b::B1) = 1
@dispatch g(x::B1, q::Int, c::C1) = 2
@dispatch g(a::A1, q::Int, c::B1) = 3
@pattern g(x::X, q, a::X) = -10
@pattern g(x::B1, q, a::A1) = -1
@pattern g(x::B1, q::Int, a::A1) = 0
@pattern g(x::B1, q::Int, b::B1) = 1
@pattern g(x::B1, q::Int, c::C1) = 2
@pattern g(a::A1, q::Int, c::B1) = 3

@dispatch g(a::A1, q::Int, c::B1{Int}; s = 1) = 10 + s
@dispatch g(a::A1, q::Int, c::C1{Int}; s = 1) = 11 + s
@dispatch g(a::X, q::Int, c::X{DynamicSumTypes.Uninitialized, Int}; s = 1) = 12 + s
@pattern g(a::A1, q::Int, c::B1{Int}; s = 1) = 10 + s
@pattern g(a::A1, q::Int, c::C1{Int}; s = 1) = 11 + s
@pattern g(a::X, q::Int, c::X{DynamicSumTypes.Uninitialized, Int}; s = 1) = 12 + s

@dispatch g(x::X, q::Vararg{Int, 2}) = 1000
@dispatch g(x::A1, q::Vararg{Int, 2}) = 1001
@pattern g(x::X, q::Vararg{Int, 2}) = 1000
@pattern g(x::A1, q::Vararg{Int, 2}) = 1001

@dispatch g(x::X, q::Vararg{Any, N}) where N = 2000
@dispatch g(x::A1, q::Vararg{Any, N}) where N = 2001
@pattern g(x::X, q::Vararg{Any, N}) where N = 2000
@pattern g(x::A1, q::Vararg{Any, N}) where N = 2001

@dispatch g(a::E1, b::Int, c::D1) = 0
@dispatch g(a::E1, b::Int, c::E1) = 1
@dispatch g(a::E1, b::Int, c::F1) = 2
@dispatch g(a::D1, b::Int, c::E1) = 3
@dispatch g(a::E1, b::Int, c::F1) = 4
@pattern g(a::E1, b::Int, c::D1) = 0
@pattern g(a::E1, b::Int, c::E1) = 1
@pattern g(a::E1, b::Int, c::F1) = 2
@pattern g(a::D1, b::Int, c::E1) = 3
@pattern g(a::E1, b::Int, c::F1) = 4

@dispatch g(a::B1, b::Int, c::Vector{<:X}) = c
@pattern g(a::B1, b::Int, c::Vector{<:X}) = c

@dispatch g(a::H1{Int}, b::G1{Int}, c::I1{Int}) = a.a + c.b
@dispatch g(a::G1{Int}, b::G1{Int}, c::I1{Int}) = c.b
@dispatch g(a::H1{Float64}, b::G1{Float64}, c::I1{Float64}) = a.a
@dispatch g(a::X, q::Int, c::X{Int}; s = 1) = 12 + s
@pattern g(a::H1{Int}, b::G1{Int}, c::I1{Int}) = a.a + c.b
@pattern g(a::G1{Int}, b::G1{Int}, c::I1{Int}) = c.b
@pattern g(a::H1{Float64}, b::G1{Float64}, c::I1{Float64}) = a.a
@pattern g(a::X, q::Int, c::X{Int}; s = 1) = 12 + s

@dispatch t(::A1) = 100
@pattern t(::A1) = 100

Methods_Dispatch_Module_219428042303.define_all()
Methods_Pattern_Module_219428042303.define_all()

@testset "@dispatch" begin
@testset "@pattern" begin

a, b1, b2, c = A1(), B1(0.0, 0.0), B1(1.0, 1.0), C1(1.0)

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ using DynamicSumTypes
include("package_sanity_tests.jl")
include("sum_structs_memory_macro_tests.jl")
include("sum_structs_speed_macro_tests.jl")
include("dispatch_macro_tests.jl")
include("pattern_macro_tests.jl")
end

0 comments on commit 3e9dcf9

Please sign in to comment.