diff --git a/Project.toml b/Project.toml index 0315ac8..7f71e25 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "MixedStructTypes" uuid = "3d69f371-6fa5-5add-b11c-3293622cad62" -version = "0.2.3" +version = "0.2.4" [deps] ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" diff --git a/src/SumStructs.jl b/src/SumStructs.jl index ef8d0c3..c4823cc 100644 --- a/src/SumStructs.jl +++ b/src/SumStructs.jl @@ -25,10 +25,20 @@ macro sum_structs(type, struct_defs) variants_types = [] hidden_struct_types = [] + variants_params_unconstr = [[] for _ in 1:length(struct_defs)] + + fields_each, default_each = [], [] + for a_spec in struct_defs + a_comps = decompose_struct_no_base(a_spec) + push!(fields_each, a_comps[2][1]) + push!(default_each, a_comps[2][2]) + end + for (i, d) in enumerate(struct_defs) t = d.args[2] c = @capture(t, t_n_{t_p__}) c == false && ((t_n, t_p) = (t, [])) + append!(variants_params_unconstr[i], t_p) t_p_no_sup = [p isa Expr && p.head == :(<:) ? p.args[1] : p for p in t_p] push!(variants_types, t_p != [] ? :($t_n{$(t_p_no_sup...)}) : t_n) h_t = gensym(t_n) @@ -44,13 +54,6 @@ macro sum_structs(type, struct_defs) struct_defs[i] = d_new end - fields_each, default_each = [], [] - for a_spec in struct_defs - a_comps = decompose_struct_no_base(a_spec) - push!(fields_each, a_comps[2][1]) - push!(default_each, a_comps[2][2]) - end - struct_defs = [:($Base.@kwdef $d) for d in struct_defs] variants_defs = [:($t(ht::$ht)) for (t, ht) in zip(variants_types, hidden_struct_types)] @@ -139,22 +142,34 @@ macro sum_structs(type, struct_defs) expr_constructors = [] - for (fs, fd, t, h_t, is_kw) in zip(fields_each, default_each, variants_types, hidden_struct_types, is_kws) - f_d_n = retrieve_fields_names(fs, false) - f_d_n_t = retrieve_fields_names(fs, true) + for (fs, fd, t, h_t, t_p_u, is_kw) in zip(fields_each, default_each, variants_types, hidden_struct_types, variants_params_unconstr, is_kws) + f_params_args = retrieve_fields_names(fs, false) + f_params_args_with_T = retrieve_fields_names(fs, true) c = @capture(t, t_n_{t_p__}) a_spec_n_d = [d != "#328723329" ? Expr(:kw, n, d) : (:($n)) - for (n, d) in zip(f_d_n, fd)] + for (n, d) in zip(f_params_args, fd)] f_params_kwargs = Expr(:parameters, a_spec_n_d...) + @capture(type, new_type_n_{new_type_p__}) + if new_type_p === nothing + new_type_n, new_type_p = type, [] + end + new_type_p = [t isa Expr && t.head == :(<:) ? t.args[1] : t for t in new_type_p] + f_params_args_with_T = [!any(p -> inexpr(x, p), new_type_p) ? (x isa Symbol ? x : x.args[1]) : x + for x in f_params_args_with_T] + struct_spec_n2_d = [d != "#328723329" ? Expr(:kw, n, d) : (:($n)) + for (n, d) in zip(f_params_args_with_T, fd)] + f_params_kwargs_with_T = struct_spec_n2_d + f_params_kwargs_with_T = Expr(:parameters, f_params_kwargs_with_T...) + if t_p !== nothing - c1 = :(function $t($(f_d_n...)) where {$(t_p...)} - return $t($h_t($(f_d_n...))) + c1 = :(function $t($(f_params_args...)) where {$(t_p_u...)} + return $t($h_t($(f_params_args...))) end ) c4 = :() if is_kw - c4 = :(function $t($(f_params_kwargs)) where {$(t_p...)} - return $t($h_t($(f_d_n...))) + c4 = :(function $t($(f_params_kwargs)) where {$(t_p_u...)} + return $t($h_t($(f_params_args...))) end ) end @@ -162,14 +177,14 @@ macro sum_structs(type, struct_defs) c1 = :() c4 = :() end - c2 = :(function $(namify(t))($(f_d_n...)) - return $(namify(t))($(namify(h_t))($(f_d_n...))) + c2 = :(function $(namify(t))($(f_params_args_with_T...)) where {$(t_p_u...)} + return $(namify(t))($(namify(h_t))($(f_params_args...))) end ) c3 = :() if is_kw - c3 = :(function $(namify(t))($(f_params_kwargs)) - return $(namify(t))($(namify(h_t))($(f_d_n...))) + c3 = :(function $(namify(t))($(f_params_kwargs_with_T)) where {$(t_p_u...)} + return $(namify(t))($(namify(h_t))($(f_params_args...))) end ) end