Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First class variants in InterTypes #104

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions src/intertypes/julia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@
end
end

function extract_variant_names(variantexprs)
map(variantexprs) do vexpr
@match vexpr begin
tag::Symbol => tag

Check warning on line 131 in src/intertypes/julia.jl

View check run for this annotation

Codecov / codecov/patch

src/intertypes/julia.jl#L131

Added line #L131 was not covered by tests
Expr(:call, tag, fieldexprs...) => tag
_ => error("could not parse variant from $vexpr")
end
end
end

function parse_intertype_decl(e; mod::InterTypeModule)
@match e begin
Expr(:const, Expr(:(=), name::Symbol, type)) => Pair(name, Alias(parse_intertype(type; mod)))
Expand All @@ -139,8 +149,15 @@
Expr(:sum, name::Symbol, body) => begin
Base.remove_linenums!(body)
mod.declarations[name] = Alias(TypeRef(RefPath(:nothing)))
variant_names = extract_variant_names(body.args)
for vname in variant_names
mod.declarations[vname] = Alias(TypeRef(RefPath(:nothing)))
end
ret = Pair(name, SumType(parse_variants(body.args; mod)))
delete!(mod.declarations, name)
for vname in variant_names
delete!(mod.declarations, vname)
end
ret
end
Expr(:schema, head, body) => begin
Expand Down Expand Up @@ -412,30 +429,37 @@
Expr(:if, cond, body, expr)
end

function reader(name, decl::InterTypeDecl)
body = @match decl begin
Struct(fields) => variantreader(name, fields)
SumType(variants) => begin
tag = gensym(:tag)
ifs = makeifs(map(variants) do variant
(
:($tag == $(string(variant.tag))),
variantreader(variant.tag, variant.fields)
)
end)
quote
$tag = s[:_type]
$ifs
end
reader(name, decl::InterTypeDecl) = nothing

function reader(name, decl::Union{Struct, Variant})
body = variantreader(name, decl.fields)
quote
function $(GlobalRef(InterTypes, :read))(
format::$(JSONFormat), ::Type{$(name)}, s::$(JSON3.Object)
)
$body
end
_ => nothing
end
if !isnothing(body)
:(function $(GlobalRef(InterTypes, :read))(format::$(JSONFormat), ::Type{$(name)}, s::$(JSON3.Object))
$body
end)
else
nothing
end

function reader(name, decl::SumType)
tag = gensym(:tag)
variants = decl.variants
variantreaders = map(variants) do variant
reader(variant.tag, variant)
end
ifs = makeifs(map(variants) do variant
(
:($tag == $(string(variant.tag))),
:($(GlobalRef(InterTypes, :read))($(JSONFormat()), $(variant.tag), s))
)
end)
quote
$(variantreaders...)
function $(GlobalRef(InterTypes, :read))(format::$(JSONFormat), ::Type{$(name)}, s::$(JSON3.Object))
$tag = s[:_type]
$ifs
end
end
end

Expand Down
1 change: 1 addition & 0 deletions test/intertypes/InterTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ s = jsonwrite(t)
@test s isa String

@test jsonread(s, Term) == t
@test jsonread(s, Plus) == t

generate_module(simpleast, JSONTarget)

Expand Down
1 change: 1 addition & 0 deletions test/intertypes/simpleast.it
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ end
@sum Term begin
Constant(val::Const)
Plus(terms::Vector{Term})
Times(terms::Vector{Plus}) # to test referring to other variants
IfThenElse(ifcase::Term, thencase::Term, elsecase::Term)
end

Expand Down