diff --git a/Project.toml b/Project.toml index 6ba8586..0315ac8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "MixedStructTypes" uuid = "3d69f371-6fa5-5add-b11c-3293622cad62" -version = "0.2.2" +version = "0.2.3" [deps] ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" diff --git a/src/CompactStructs.jl b/src/CompactStructs.jl index 03fc524..2eccc49 100644 --- a/src/CompactStructs.jl +++ b/src/CompactStructs.jl @@ -1,4 +1,5 @@ + struct Uninitialized end const uninit = Uninitialized() @@ -54,7 +55,10 @@ macro compact_structs(new_type, struct_defs) field_type = is_mutable ? Expr(:const, :($(gensym_type)::Symbol)) : (:($(gensym_type)::Symbol)) expr_comp_types = [Expr(:struct, false, t, :(begin sdfnsdfsdfak() = 1 end)) for t in types_each] - expr_new_type = Expr(:struct, is_mutable, :($new_type <: $abstract_type), + type_name = new_type isa Symbol ? new_type : new_type.args[1] + uninit_val = :(MixedStructTypes.Uninitialized) + compact_t = MacroTools.postwalk(s -> s isa Expr && s.head == :(<:) ? make_union_uninit(s, type_name, uninit_val) : s, new_type) + expr_new_type = Expr(:struct, is_mutable, :($compact_t <: $abstract_type), :(begin $(all_fields_transf...) $field_type @@ -73,6 +77,7 @@ macro compact_structs(new_type, struct_defs) if new_type_p === nothing new_type_n, new_type_p = new_type, [] end + new_type_p = [t isa Expr && t.head == :(<:) ? t.args[1] : t for t in new_type_p] f_params_args_with_T = [!any(p -> inexpr(x, p), new_type_p) ? (x isa Symbol ? x : x.args[1]) : x for x in f_params_args_with_T] struct_spec_n2_d = [d != "#328723329" ? Expr(:kw, n, d) : (:($n)) @@ -90,7 +95,9 @@ macro compact_structs(new_type, struct_defs) @capture(struct_t, struct_t_n_{struct_t_p__}) struct_t_p === nothing && (struct_t_p = []) - new_type_p = [t in struct_t_p ? t : (:(MixedStructTypes.Uninitialized)) + struct_t_p_no_sup = [p isa Expr && p.head == :(<:) ? p.args[1] : p for p in struct_t_p] + struct_t_arg = struct_t_p_no_sup != [] ? :($struct_t_n{$(struct_t_p_no_sup...)}) : struct_t + new_type_p = [t in struct_t_p_no_sup ? t : (:(MixedStructTypes.Uninitialized)) for t in new_type_p] expr_function_kwargs = :() @@ -115,7 +122,7 @@ macro compact_structs(new_type, struct_defs) return $new_type_n{$(new_type_p...)}($(f_inside_args...)) end) if !isempty(struct_t_p) - expr_function_args2 = :(function $(struct_t)($(f_params_args...)) where {$(struct_t_p...)} + expr_function_args2 = :(function $(struct_t_arg)($(f_params_args...)) where {$(struct_t_p...)} return $new_type_n{$(new_type_p...)}($(f_inside_args2...)) end) end @@ -126,7 +133,7 @@ macro compact_structs(new_type, struct_defs) end) if !isempty(struct_t_p) expr_function_kwargs2 = :( - function $(struct_t)($f_params_kwargs) where {$(struct_t_p...)} + function $(struct_t_arg)($f_params_kwargs) where {$(struct_t_p...)} return $new_type_n{$(new_type_p...)}($(f_inside_args2...)) end) end diff --git a/src/SumStructs.jl b/src/SumStructs.jl index 6a75a0e..ef8d0c3 100644 --- a/src/SumStructs.jl +++ b/src/SumStructs.jl @@ -29,14 +29,15 @@ macro sum_structs(type, struct_defs) t = d.args[2] c = @capture(t, t_n_{t_p__}) c == false && ((t_n, t_p) = (t, [])) - push!(variants_types, t) + t_p_no_sup = [p isa Expr && p.head == :(<:) ? p.args[1] : p for p in t_p] + push!(variants_types, t_p != [] ? :($t_n{$(t_p_no_sup...)}) : t_n) h_t = gensym(t_n) - if t_p != [] - h_t = :($h_t{$(t_p...)}) + if t_p_no_sup != [] + h_t = :($h_t{$(t_p_no_sup...)}) end push!(hidden_struct_types, h_t) d_new = MacroTools.postwalk(s -> s == t ? h_t : s, d) - for p in t_p + for p in t_p_no_sup p_u = gensym(p) d_new = MacroTools.postwalk(s -> s == p ? p_u : s, d_new) end @@ -54,7 +55,10 @@ macro sum_structs(type, struct_defs) variants_defs = [:($t(ht::$ht)) for (t, ht) in zip(variants_types, hidden_struct_types)] - expr_sum_type = :(MixedStructTypes.SumTypes.@sum_type $type begin + type_name = type isa Symbol ? type : type.args[1] + uninit_val = :(MixedStructTypes.SumTypes.Uninit) + sum_t = MacroTools.postwalk(s -> s isa Expr && s.head == :(<:) ? make_union_uninit(s, type_name, uninit_val) : s, type) + expr_sum_type = :(MixedStructTypes.SumTypes.@sum_type $sum_t begin $(variants_defs...) end) expr_sum_type = macroexpand(__module__, expr_sum_type) @@ -215,6 +219,12 @@ function print_transform(x) return x end +function make_union_uninit(s, type_name, uninit_val) + s.args[1] == type_name && return s + s.args[2] = :(Union{$(s.args[2]), $uninit_val}) + return s +end + function remove_redefinitions(e, t, vs, fs) redef = [:($(Base).show), :(($Base).getproperty), :(($Base).propertynames)] @@ -239,5 +249,3 @@ end retrieve_type(::MixedStructTypes.SumTypes.Variant{T}) where T = T retrieve_hidden_type(::MixedStructTypes.SumTypes.Variant{T,F,HT} where {T,F}) where HT = eltype(HT) - - diff --git a/test/compact_structs_macro_tests.jl b/test/compact_structs_macro_tests.jl index e5a1737..5318678 100644 --- a/test/compact_structs_macro_tests.jl +++ b/test/compact_structs_macro_tests.jl @@ -3,8 +3,8 @@ struct ST1 end end -@compact_structs E{X,Y} begin - @kwdef mutable struct F{X} +@compact_structs E{X<:Real,Y<:Real} begin + @kwdef mutable struct F{X<:Int} a::Tuple{X, X} b::Tuple{Float64, Float64} const c::Symbol @@ -15,7 +15,7 @@ end e::Bool const c::Symbol end - @kwdef mutable struct H{X,Y} + @kwdef mutable struct H{X,Y<:Real} a::Tuple{X, X} f::Y g::Tuple{Complex, Complex} @@ -90,6 +90,10 @@ end g1 = G((1,1), 1, 1, :c) g2 = G(; a = (1,1), d = 1, e = 1, c = :c) + @test_throws "" F((1.0,1.0), (1.0, 1.0), :s) + @test_throws "" G((1,1), im, (im, im), :d) + @test_throws "" G((im,im), 1, (im, im), :d) + @test f.a == (1,1) @test f.b == (1.0, 1.0) @test f.c == :s diff --git a/test/sum_structs_macro_tests.jl b/test/sum_structs_macro_tests.jl index ed7ce97..2844a4f 100644 --- a/test/sum_structs_macro_tests.jl +++ b/test/sum_structs_macro_tests.jl @@ -3,7 +3,7 @@ struct ST2 end end -@sum_structs A{X,Y} begin +@sum_structs A{X,Y<:Real} begin @kwdef mutable struct B{X} a::Tuple{X, X} b::Tuple{Float64, Float64} @@ -68,7 +68,10 @@ end b = B((1,1), (1.0, 1.0), :s) c1 = C((1,1), 1, 1, :c) c2 = C(; a = (1,1), d = 1, e = 1, c = :c) - + d = D((1,1), 1, (im, im), :d) + + @test_throws "" D((1,1), im, (im, im), :d) + @test b.a == (1,1) @test b.b == (1.0, 1.0) @test b.c == :s