Skip to content

Commit

Permalink
support for kwdef macro in sum_struct_type macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Jan 21, 2024
1 parent fe4eade commit 9b4b899
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
45 changes: 40 additions & 5 deletions src/SumStructTypes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@

macro sum_struct_type(type, struct_defs)
macro sum_struct_type(type, struct_defs = nothing)

if struct_defs === nothing
is_kwdef = true
type, struct_defs = type.args[end-1:end]
else
is_kwdef = false
end

struct_defs = [x for x in struct_defs.args if !(x isa LineNumberNode)]

Expand All @@ -25,6 +32,15 @@ macro sum_struct_type(type, struct_defs)
struct_defs[i] = d_new
end

fields_each, default_each = [], []
for a_spec in struct_defs
a_comps = decompose_struct_no_base(a_spec)
push!(fields_each, a_comps[2][1])
push!(default_each, a_comps[2][2])
end

struct_defs = [:(@kwdef $d) for d in struct_defs]

variants_defs = [:($t(ht::$ht)) for (t, ht) in zip(variants_types, hidden_struct_types)]

expr_sum_type = :(SumTypes.@sum_type $type begin
Expand Down Expand Up @@ -94,25 +110,44 @@ macro sum_struct_type(type, struct_defs)

expr_constructors = []

for (d, t, h_t) in zip(struct_defs, variants_types, hidden_struct_types)
f_d = [x for x in d.args[3].args if !(x isa LineNumberNode)]
f_d_n = retrieve_fields_names(f_d, false)
f_d_n_t = retrieve_fields_names(f_d, true)
for (fs, fd, t, h_t) in zip(fields_each, default_each, variants_types, hidden_struct_types)
f_d_n = retrieve_fields_names(fs, false)
f_d_n_t = retrieve_fields_names(fs, true)
c = @capture(t, t_n_{t_p__})
a_spec_n_d = [d != "#328723329" ? Expr(:kw, n, d) : (:($n))
for (n, d) in zip(f_d_n, fd)]
f_params_kwargs = Expr(:parameters, a_spec_n_d...)
if t_p !== nothing
c1 = :(function $t($(f_d_n...)) where {$(t_p...)}
return $t($(namify(h_t))($(f_d_n...)))
end
)
c4 = :()
if is_kwdef
c4 = :(function $t($(f_params_kwargs)) where {$(t_p...)}
return $t($(namify(h_t))($(f_d_n...)))
end
)
end
else
c1 = :()
c4 = :()
end
c2 = :(function $(namify(t))($(f_d_n...))
return $(namify(t))($(namify(h_t))($(f_d_n...)))
end
)
c3 = :()
if is_kwdef
c3 = :(function $(namify(t))($(f_params_kwargs))
return $(namify(t))($(namify(h_t))($(f_d_n...)))
end
)
end
push!(expr_constructors, c1)
push!(expr_constructors, c2)
push!(expr_constructors, c3)
push!(expr_constructors, c4)
end

expr = quote
Expand Down
16 changes: 11 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ using MixedStructTypes
end
end

@sum_struct_type Animal{T,N,J} begin
@sum_struct_type @kwdef Animal{T,N,J} begin
mutable struct Wolf{T,N}
energy::T
energy::T = 0.5
ground_speed::N
const fur_color::Symbol
end
mutable struct Hawk{T,N,J}
energy::T
energy::T = 0.1
ground_speed::N
flight_speed::J
end
Expand Down Expand Up @@ -59,16 +59,22 @@ end
@test kindof(b) == :B

hawk_1 = Hawk(1.0, 2.0, 3)
hawk_2 = Hawk(; ground_speed = 2.3, flight_speed = 2)
wolf_1 = Wolf(2.0, 3.0, :black)
wolf_2 = Wolf(; ground_speed = 2.0, fur_color = :white)

@test hawk_1.energy == 1.0
@test hawk_2.energy == 0.1
@test wolf_1.energy == 2.0
@test wolf_2.energy == 0.5
@test hawk_1.flight_speed == 3
@test hawk_2.flight_speed == 2
@test wolf_1.fur_color == :black
@test wolf_2.fur_color == :white
@test_throws "" hawk_1.fur_color
@test_throws "" wolf_1.flight_speed
@test kindof(hawk_1) == :Hawk
@test kindof(wolf_1) == :Wolf
@test kindof(hawk_1) == kindof(hawk_2) == :Hawk
@test kindof(wolf_1) == kindof(wolf_2) == :Wolf

b = SimpleA(1)
c = SimpleB(2)
Expand Down

0 comments on commit 9b4b899

Please sign in to comment.