Skip to content

Commit

Permalink
Simplify modality_is_const_for_axis
Browse files Browse the repository at this point in the history
Now that the types for jkind axes and mode axes have been unified, we
can express the condition for "modality is const for axis" directly,
without resorting to a huge ugly pattern match on the axis and the
modality's atoms.
  • Loading branch information
glittershark committed Dec 28, 2024
1 parent 58a9931 commit e8ad948
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 64 deletions.
66 changes: 4 additions & 62 deletions typing/jkind_axis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -175,71 +175,13 @@ module Axis = struct
| Nonmodal Externality -> true
| Nonmodal Nullability -> false

(* CR aspsmith: This can get a lot simpler once we unify jkind axes with the axes in
Mode *)
let modality_is_const_for_axis (type a) (t : a t) modality =
let modality_is_const_for_axis (type a) (t : a t)
(modality : Mode.Modality.Value.Const.t) =
match t with
| Nonmodal Nullability | Nonmodal Externality -> false
| Modal axis ->
let atoms = Mode.Modality.Value.Const.to_list modality in
List.exists
(fun (modality : Mode.Modality.t) ->
match axis, modality with
(* Constant modalities *)
| Comonadic Areality, Atom (Comonadic Areality, Meet_with Global) ->
true
| Comonadic Linearity, Atom (Comonadic Linearity, Meet_with Many) ->
true
| Monadic Uniqueness, Atom (Monadic Uniqueness, Join_with Aliased) ->
true
| ( Comonadic Portability,
Atom (Comonadic Portability, Meet_with Portable) ) ->
true
| Monadic Contention, Atom (Monadic Contention, Join_with Contended)
->
true
(* Modalities which are actually identity *)
| Comonadic Areality, Atom (Comonadic Areality, Meet_with Local)
| Comonadic Linearity, Atom (Comonadic Linearity, Meet_with Once)
| Monadic Uniqueness, Atom (Monadic Uniqueness, Join_with Unique)
| ( Comonadic Portability,
Atom (Comonadic Portability, Meet_with Nonportable) )
| Monadic Contention, Atom (Monadic Contention, Join_with Uncontended)
->
false
(* Modalities which are neither constant nor identiy *)
| Comonadic Areality, Atom (Comonadic Areality, Meet_with Regional)
| Monadic Contention, Atom (Monadic Contention, Join_with Shared) ->
Misc.fatal_error
"Don't yet know how to interpret non-constant, non-identity \
modalities"
(* Modalities which join or meet on an illegal axis *)
| _, Atom (Comonadic _, Join_with _) | _, Atom (Monadic _, Meet_with _)
->
Misc.fatal_error "Illegal modality"
(* Mismatched axes *)
| Comonadic Areality, Atom (Monadic Uniqueness, _)
| Comonadic Areality, Atom (Monadic Contention, _)
| Comonadic Areality, Atom (Comonadic Linearity, _)
| Comonadic Areality, Atom (Comonadic Portability, _)
| Comonadic Linearity, Atom (Comonadic Areality, _)
| Comonadic Linearity, Atom (Monadic Uniqueness, _)
| Comonadic Portability, Atom (Monadic Uniqueness, _)
| Monadic Contention, Atom (Monadic Uniqueness, _)
| Comonadic Linearity, Atom (Comonadic Portability, _)
| Comonadic Linearity, Atom (Monadic Contention, _)
| Monadic Uniqueness, Atom (Comonadic Areality, _)
| Monadic Uniqueness, Atom (Comonadic Linearity, _)
| Monadic Uniqueness, Atom (Comonadic Portability, _)
| Monadic Contention, Atom (Comonadic Areality, _)
| Monadic Contention, Atom (Comonadic Linearity, _)
| Monadic Contention, Atom (Comonadic Portability, _)
| Monadic Uniqueness, Atom (Monadic Contention, _)
| Comonadic Portability, Atom (Comonadic Areality, _)
| Comonadic Portability, Atom (Comonadic Linearity, _)
| Comonadic Portability, Atom (Monadic Contention, _) ->
false)
atoms
let (P axis) = Mode.Const.Axis.alloc_as_value axis in
Mode.Modality.Value.Const.is_constant_for axis modality
end

module type Axed = sig
Expand Down
56 changes: 54 additions & 2 deletions typing/mode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1910,6 +1910,12 @@ module Value_with (Areality : Areality) = struct
| Comonadic ax -> Comonadic.max_axis ax
| Monadic ax -> Monadic.max_axis ax

let is_max : type m a d. (m, a, d) axis -> a -> bool =
fun ax m -> le_axis ax (max_axis ax) m

let is_min : type m a d. (m, a, d) axis -> a -> bool =
fun ax m -> le_axis ax m (min_axis ax)

let split = split

let merge = merge
Expand Down Expand Up @@ -2168,6 +2174,19 @@ module Const = struct
let areality = C.locality_as_regionality areality in
{ areality; linearity; portability; uniqueness; contention }

module Axis = struct
type 'd packed_value_axis =
| P : ('m, 'a, 'd) Value.axis -> 'd packed_value_axis

let alloc_as_value : type m a d. (m, a, d) Alloc.axis -> d packed_value_axis
= function
| Comonadic Areality -> P (Comonadic Areality)
| Comonadic Linearity -> P (Comonadic Linearity)
| Comonadic Portability -> P (Comonadic Portability)
| Monadic Uniqueness -> P (Monadic Uniqueness)
| Monadic Contention -> P (Monadic Contention)
end

let locality_as_regionality = C.locality_as_regionality
end

Expand Down Expand Up @@ -2213,8 +2232,13 @@ module Modality = struct

let is_id (Atom (ax, a)) =
match a with
| Join_with c -> Value.Const.le_axis ax c (Value.Const.min_axis ax)
| Meet_with c -> Value.Const.le_axis ax (Value.Const.max_axis ax) c
| Join_with c -> Value.Const.is_min ax c
| Meet_with c -> Value.Const.is_max ax c

let is_constant (Atom (ax, a)) =
match a with
| Join_with c -> Value.Const.is_max ax c
| Meet_with c -> Value.Const.is_min ax c

let print ppf = function
| Atom (ax, Join_with c) ->
Expand Down Expand Up @@ -2273,6 +2297,9 @@ module Modality = struct
(let ax : _ Axis.t = Contention in
Atom (Monadic ax, Join_with (Axis.proj ax c))) ]

let proj ax = function
| Join_const c -> Atom (Monadic ax, Join_with (Axis.proj ax c))

let print ppf = function
| Join_const c -> Format.fprintf ppf "join_const(%a)" Mode.Const.print c
end
Expand Down Expand Up @@ -2415,6 +2442,9 @@ module Modality = struct
(let ax : _ Axis.t = Portability in
Atom (Comonadic ax, Meet_with (Axis.proj ax c))) ]

let proj ax = function
| Meet_const c -> Atom (Comonadic ax, Meet_with (Axis.proj ax c))

let print ppf = function
| Meet_const c -> Format.fprintf ppf "meet_const(%a)" Mode.Const.print c
end
Expand Down Expand Up @@ -2540,6 +2570,8 @@ module Modality = struct

let id = { monadic = Monadic.id; comonadic = Comonadic.id }

let modality_is_id = is_id

let is_id { monadic; comonadic } =
Monadic.is_id monadic && Comonadic.is_id comonadic

Expand Down Expand Up @@ -2576,6 +2608,26 @@ module Modality = struct

let to_list { monadic; comonadic } =
Comonadic.to_list comonadic @ Monadic.to_list monadic

let proj_monadic ax { monadic; _ } = Monadic.proj ax monadic

let proj_comonadic ax { comonadic; _ } = Comonadic.proj ax comonadic

let proj (type m a d) (ax : (m, a, d) Value.axis) t =
match ax with
| Monadic ax -> proj_monadic ax t
| Comonadic ax -> proj_comonadic ax t

let is_constant_for (type m a d) (axis : (m, a, d) Value.axis) t =
let modality = proj axis t in
if is_constant modality
then true
else if modality_is_id modality
then false
else
Misc.fatal_error
"Don't yet know how to interpret non-constant, non-identity \
modalities"
end

type t = (Monadic.t, Comonadic.t) monadic_comonadic
Expand Down
13 changes: 13 additions & 0 deletions typing/mode_intf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ module type S = sig
| Contention : (monadic, Contention.Const.t) t

val print : Format.formatter -> ('p, 'r) t -> unit

val eq : ('p, 'r0) t -> ('p, 'r1) t -> ('r0, 'r1) Misc.eq option
end

module type Mode := sig
Expand Down Expand Up @@ -440,6 +442,13 @@ module type S = sig
module Const : sig
val alloc_as_value : Alloc.Const.t -> Value.Const.t

module Axis : sig
type 'd packed_value_axis =
| P : ('m, 'a, 'd) Value.axis -> 'd packed_value_axis

val alloc_as_value : ('m, 'a, 'd) Alloc.axis -> 'd packed_value_axis
end

val locality_as_regionality : Locality.Const.t -> Regionality.Const.t
end

Expand Down Expand Up @@ -518,6 +527,10 @@ module type S = sig
output list exactly once. *)
val to_list : t -> atom list

(** Test if the given modality is a constant modality along the given
axis. *)
val is_constant_for : ('m, 'a, 'd) Value.axis -> t -> bool

(** [equate t0 t1] checks that [t0 = t1].
Definition: [t0 = t1] iff [t0 <= t1] and [t1 <= t0]. *)
val equate : t -> t -> (unit, equate_error) Result.t
Expand Down

0 comments on commit e8ad948

Please sign in to comment.