Skip to content

Commit

Permalink
Make (almost) every constrains to parameters possible (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Feb 10, 2024
1 parent b5834d2 commit dcadfa0
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "MixedStructTypes"
uuid = "3d69f371-6fa5-5add-b11c-3293622cad62"
version = "0.2.2"
version = "0.2.3"

[deps]
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand Down
15 changes: 11 additions & 4 deletions src/CompactStructs.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@


struct Uninitialized end
const uninit = Uninitialized()

Expand Down Expand Up @@ -54,7 +55,10 @@ macro compact_structs(new_type, struct_defs)
field_type = is_mutable ? Expr(:const, :($(gensym_type)::Symbol)) : (:($(gensym_type)::Symbol))

expr_comp_types = [Expr(:struct, false, t, :(begin sdfnsdfsdfak() = 1 end)) for t in types_each]
expr_new_type = Expr(:struct, is_mutable, :($new_type <: $abstract_type),
type_name = new_type isa Symbol ? new_type : new_type.args[1]
uninit_val = :(MixedStructTypes.Uninitialized)
compact_t = MacroTools.postwalk(s -> s isa Expr && s.head == :(<:) ? make_union_uninit(s, type_name, uninit_val) : s, new_type)
expr_new_type = Expr(:struct, is_mutable, :($compact_t <: $abstract_type),
:(begin
$(all_fields_transf...)
$field_type
Expand All @@ -73,6 +77,7 @@ macro compact_structs(new_type, struct_defs)
if new_type_p === nothing
new_type_n, new_type_p = new_type, []
end
new_type_p = [t isa Expr && t.head == :(<:) ? t.args[1] : t for t in new_type_p]
f_params_args_with_T = [!any(p -> inexpr(x, p), new_type_p) ? (x isa Symbol ? x : x.args[1]) : x
for x in f_params_args_with_T]
struct_spec_n2_d = [d != "#328723329" ? Expr(:kw, n, d) : (:($n))
Expand All @@ -90,7 +95,9 @@ macro compact_structs(new_type, struct_defs)

@capture(struct_t, struct_t_n_{struct_t_p__})
struct_t_p === nothing && (struct_t_p = [])
new_type_p = [t in struct_t_p ? t : (:(MixedStructTypes.Uninitialized))
struct_t_p_no_sup = [p isa Expr && p.head == :(<:) ? p.args[1] : p for p in struct_t_p]
struct_t_arg = struct_t_p_no_sup != [] ? :($struct_t_n{$(struct_t_p_no_sup...)}) : struct_t
new_type_p = [t in struct_t_p_no_sup ? t : (:(MixedStructTypes.Uninitialized))
for t in new_type_p]

expr_function_kwargs = :()
Expand All @@ -115,7 +122,7 @@ macro compact_structs(new_type, struct_defs)
return $new_type_n{$(new_type_p...)}($(f_inside_args...))
end)
if !isempty(struct_t_p)
expr_function_args2 = :(function $(struct_t)($(f_params_args...)) where {$(struct_t_p...)}
expr_function_args2 = :(function $(struct_t_arg)($(f_params_args...)) where {$(struct_t_p...)}
return $new_type_n{$(new_type_p...)}($(f_inside_args2...))
end)
end
Expand All @@ -126,7 +133,7 @@ macro compact_structs(new_type, struct_defs)
end)
if !isempty(struct_t_p)
expr_function_kwargs2 = :(
function $(struct_t)($f_params_kwargs) where {$(struct_t_p...)}
function $(struct_t_arg)($f_params_kwargs) where {$(struct_t_p...)}
return $new_type_n{$(new_type_p...)}($(f_inside_args2...))
end)
end
Expand Down
22 changes: 15 additions & 7 deletions src/SumStructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ macro sum_structs(type, struct_defs)
t = d.args[2]
c = @capture(t, t_n_{t_p__})
c == false && ((t_n, t_p) = (t, []))
push!(variants_types, t)
t_p_no_sup = [p isa Expr && p.head == :(<:) ? p.args[1] : p for p in t_p]
push!(variants_types, t_p != [] ? :($t_n{$(t_p_no_sup...)}) : t_n)
h_t = gensym(t_n)
if t_p != []
h_t = :($h_t{$(t_p...)})
if t_p_no_sup != []
h_t = :($h_t{$(t_p_no_sup...)})
end
push!(hidden_struct_types, h_t)
d_new = MacroTools.postwalk(s -> s == t ? h_t : s, d)
for p in t_p
for p in t_p_no_sup
p_u = gensym(p)
d_new = MacroTools.postwalk(s -> s == p ? p_u : s, d_new)
end
Expand All @@ -54,7 +55,10 @@ macro sum_structs(type, struct_defs)

variants_defs = [:($t(ht::$ht)) for (t, ht) in zip(variants_types, hidden_struct_types)]

expr_sum_type = :(MixedStructTypes.SumTypes.@sum_type $type begin
type_name = type isa Symbol ? type : type.args[1]
uninit_val = :(MixedStructTypes.SumTypes.Uninit)
sum_t = MacroTools.postwalk(s -> s isa Expr && s.head == :(<:) ? make_union_uninit(s, type_name, uninit_val) : s, type)
expr_sum_type = :(MixedStructTypes.SumTypes.@sum_type $sum_t begin
$(variants_defs...)
end)
expr_sum_type = macroexpand(__module__, expr_sum_type)
Expand Down Expand Up @@ -215,6 +219,12 @@ function print_transform(x)
return x
end

function make_union_uninit(s, type_name, uninit_val)
s.args[1] == type_name && return s
s.args[2] = :(Union{$(s.args[2]), $uninit_val})
return s
end

function remove_redefinitions(e, t, vs, fs)

redef = [:($(Base).show), :(($Base).getproperty), :(($Base).propertynames)]
Expand All @@ -239,5 +249,3 @@ end

retrieve_type(::MixedStructTypes.SumTypes.Variant{T}) where T = T
retrieve_hidden_type(::MixedStructTypes.SumTypes.Variant{T,F,HT} where {T,F}) where HT = eltype(HT)


10 changes: 7 additions & 3 deletions test/compact_structs_macro_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
struct ST1 end
end

@compact_structs E{X,Y} begin
@kwdef mutable struct F{X}
@compact_structs E{X<:Real,Y<:Real} begin
@kwdef mutable struct F{X<:Int}
a::Tuple{X, X}
b::Tuple{Float64, Float64}
const c::Symbol
Expand All @@ -15,7 +15,7 @@ end
e::Bool
const c::Symbol
end
@kwdef mutable struct H{X,Y}
@kwdef mutable struct H{X,Y<:Real}
a::Tuple{X, X}
f::Y
g::Tuple{Complex, Complex}
Expand Down Expand Up @@ -90,6 +90,10 @@ end
g1 = G((1,1), 1, 1, :c)
g2 = G(; a = (1,1), d = 1, e = 1, c = :c)

@test_throws "" F((1.0,1.0), (1.0, 1.0), :s)
@test_throws "" G((1,1), im, (im, im), :d)
@test_throws "" G((im,im), 1, (im, im), :d)

@test f.a == (1,1)
@test f.b == (1.0, 1.0)
@test f.c == :s
Expand Down
7 changes: 5 additions & 2 deletions test/sum_structs_macro_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
struct ST2 end
end

@sum_structs A{X,Y} begin
@sum_structs A{X,Y<:Real} begin
@kwdef mutable struct B{X}
a::Tuple{X, X}
b::Tuple{Float64, Float64}
Expand Down Expand Up @@ -68,7 +68,10 @@ end
b = B((1,1), (1.0, 1.0), :s)
c1 = C((1,1), 1, 1, :c)
c2 = C(; a = (1,1), d = 1, e = 1, c = :c)

d = D((1,1), 1, (im, im), :d)

@test_throws "" D((1,1), im, (im, im), :d)

@test b.a == (1,1)
@test b.b == (1.0, 1.0)
@test b.c == :s
Expand Down

0 comments on commit dcadfa0

Please sign in to comment.