From cb9f1f94bec66767592783d27dca38b50ea833fb Mon Sep 17 00:00:00 2001 From: Aspen Smith Date: Thu, 26 Dec 2024 13:56:40 -0500 Subject: [PATCH] Simplify modality_is_const_for_axis Now that the types for jkind axes and mode axes have been unified, we can express the condition for "modality is const for axis" directly, without resorting to a huge ugly pattern match on the axis and the modality's atoms. --- typing/jkind_axis.ml | 66 +++----------------------------------------- typing/mode.ml | 56 +++++++++++++++++++++++++++++++++++-- typing/mode_intf.mli | 11 ++++++++ 3 files changed, 69 insertions(+), 64 deletions(-) diff --git a/typing/jkind_axis.ml b/typing/jkind_axis.ml index a9961f6955e..4874ceb8cfc 100644 --- a/typing/jkind_axis.ml +++ b/typing/jkind_axis.ml @@ -175,71 +175,13 @@ module Axis = struct | Nonmodal Externality -> true | Nonmodal Nullability -> false - (* CR aspsmith: This can get a lot simpler once we unify jkind axes with the axes in - Mode *) - let modality_is_const_for_axis (type a) (t : a t) modality = + let modality_is_const_for_axis (type a) (t : a t) + (modality : Mode.Modality.Value.Const.t) = match t with | Nonmodal Nullability | Nonmodal Externality -> false | Modal axis -> - let atoms = Mode.Modality.Value.Const.to_list modality in - List.exists - (fun (modality : Mode.Modality.t) -> - match axis, modality with - (* Constant modalities *) - | Comonadic Areality, Atom (Comonadic Areality, Meet_with Global) -> - true - | Comonadic Linearity, Atom (Comonadic Linearity, Meet_with Many) -> - true - | Monadic Uniqueness, Atom (Monadic Uniqueness, Join_with Aliased) -> - true - | ( Comonadic Portability, - Atom (Comonadic Portability, Meet_with Portable) ) -> - true - | Monadic Contention, Atom (Monadic Contention, Join_with Contended) - -> - true - (* Modalities which are actually identity *) - | Comonadic Areality, Atom (Comonadic Areality, Meet_with Local) - | Comonadic Linearity, Atom (Comonadic Linearity, Meet_with Once) - | Monadic Uniqueness, Atom (Monadic Uniqueness, Join_with Unique) - | ( Comonadic Portability, - Atom (Comonadic Portability, Meet_with Nonportable) ) - | Monadic Contention, Atom (Monadic Contention, Join_with Uncontended) - -> - false - (* Modalities which are neither constant nor identiy *) - | Comonadic Areality, Atom (Comonadic Areality, Meet_with Regional) - | Monadic Contention, Atom (Monadic Contention, Join_with Shared) -> - Misc.fatal_error - "Don't yet know how to interpret non-constant, non-identity \ - modalities" - (* Modalities which join or meet on an illegal axis *) - | _, Atom (Comonadic _, Join_with _) | _, Atom (Monadic _, Meet_with _) - -> - Misc.fatal_error "Illegal modality" - (* Mismatched axes *) - | Comonadic Areality, Atom (Monadic Uniqueness, _) - | Comonadic Areality, Atom (Monadic Contention, _) - | Comonadic Areality, Atom (Comonadic Linearity, _) - | Comonadic Areality, Atom (Comonadic Portability, _) - | Comonadic Linearity, Atom (Comonadic Areality, _) - | Comonadic Linearity, Atom (Monadic Uniqueness, _) - | Comonadic Portability, Atom (Monadic Uniqueness, _) - | Monadic Contention, Atom (Monadic Uniqueness, _) - | Comonadic Linearity, Atom (Comonadic Portability, _) - | Comonadic Linearity, Atom (Monadic Contention, _) - | Monadic Uniqueness, Atom (Comonadic Areality, _) - | Monadic Uniqueness, Atom (Comonadic Linearity, _) - | Monadic Uniqueness, Atom (Comonadic Portability, _) - | Monadic Contention, Atom (Comonadic Areality, _) - | Monadic Contention, Atom (Comonadic Linearity, _) - | Monadic Contention, Atom (Comonadic Portability, _) - | Monadic Uniqueness, Atom (Monadic Contention, _) - | Comonadic Portability, Atom (Comonadic Areality, _) - | Comonadic Portability, Atom (Comonadic Linearity, _) - | Comonadic Portability, Atom (Monadic Contention, _) -> - false) - atoms + let (P axis) = Mode.Const.Axis.alloc_as_value axis in + Mode.Modality.Value.Const.is_constant_for axis modality end module type Axed = sig diff --git a/typing/mode.ml b/typing/mode.ml index 2831536b9d6..6205e03aba2 100644 --- a/typing/mode.ml +++ b/typing/mode.ml @@ -1910,6 +1910,12 @@ module Value_with (Areality : Areality) = struct | Comonadic ax -> Comonadic.max_axis ax | Monadic ax -> Monadic.max_axis ax + let is_max : type m a d. (m, a, d) axis -> a -> bool = + fun ax m -> le_axis ax (max_axis ax) m + + let is_min : type m a d. (m, a, d) axis -> a -> bool = + fun ax m -> le_axis ax m (min_axis ax) + let split = split let merge = merge @@ -2168,6 +2174,19 @@ module Const = struct let areality = C.locality_as_regionality areality in { areality; linearity; portability; uniqueness; contention } + module Axis = struct + type 'd packed_value_axis = + | P : ('m, 'a, 'd) Value.axis -> 'd packed_value_axis + + let alloc_as_value : type m a d. (m, a, d) Alloc.axis -> d packed_value_axis + = function + | Comonadic Areality -> P (Comonadic Areality) + | Comonadic Linearity -> P (Comonadic Linearity) + | Comonadic Portability -> P (Comonadic Portability) + | Monadic Uniqueness -> P (Monadic Uniqueness) + | Monadic Contention -> P (Monadic Contention) + end + let locality_as_regionality = C.locality_as_regionality end @@ -2213,8 +2232,13 @@ module Modality = struct let is_id (Atom (ax, a)) = match a with - | Join_with c -> Value.Const.le_axis ax c (Value.Const.min_axis ax) - | Meet_with c -> Value.Const.le_axis ax (Value.Const.max_axis ax) c + | Join_with c -> Value.Const.is_min ax c + | Meet_with c -> Value.Const.is_max ax c + + let is_constant (Atom (ax, a)) = + match a with + | Join_with c -> Value.Const.is_max ax c + | Meet_with c -> Value.Const.is_min ax c let print ppf = function | Atom (ax, Join_with c) -> @@ -2273,6 +2297,9 @@ module Modality = struct (let ax : _ Axis.t = Contention in Atom (Monadic ax, Join_with (Axis.proj ax c))) ] + let proj ax = function + | Join_const c -> Atom (Monadic ax, Join_with (Axis.proj ax c)) + let print ppf = function | Join_const c -> Format.fprintf ppf "join_const(%a)" Mode.Const.print c end @@ -2415,6 +2442,9 @@ module Modality = struct (let ax : _ Axis.t = Portability in Atom (Comonadic ax, Meet_with (Axis.proj ax c))) ] + let proj ax = function + | Meet_const c -> Atom (Comonadic ax, Meet_with (Axis.proj ax c)) + let print ppf = function | Meet_const c -> Format.fprintf ppf "meet_const(%a)" Mode.Const.print c end @@ -2540,6 +2570,8 @@ module Modality = struct let id = { monadic = Monadic.id; comonadic = Comonadic.id } + let modality_is_id = is_id + let is_id { monadic; comonadic } = Monadic.is_id monadic && Comonadic.is_id comonadic @@ -2576,6 +2608,26 @@ module Modality = struct let to_list { monadic; comonadic } = Comonadic.to_list comonadic @ Monadic.to_list monadic + + let proj_monadic ax { monadic; _ } = Monadic.proj ax monadic + + let proj_comonadic ax { comonadic; _ } = Comonadic.proj ax comonadic + + let proj (type m a d) (ax : (m, a, d) Value.axis) t = + match ax with + | Monadic ax -> proj_monadic ax t + | Comonadic ax -> proj_comonadic ax t + + let is_constant_for (type m a d) (axis : (m, a, d) Value.axis) t = + let modality = proj axis t in + if is_constant modality + then true + else if modality_is_id modality + then false + else + Misc.fatal_error + "Don't yet know how to interpret non-constant, non-identity \ + modalities" end type t = (Monadic.t, Comonadic.t) monadic_comonadic diff --git a/typing/mode_intf.mli b/typing/mode_intf.mli index 781094749a8..8098ed8e1d8 100644 --- a/typing/mode_intf.mli +++ b/typing/mode_intf.mli @@ -286,6 +286,8 @@ module type S = sig | Contention : (monadic, Contention.Const.t) t val print : Format.formatter -> ('p, 'r) t -> unit + + val eq : ('p, 'r0) t -> ('p, 'r1) t -> ('r0, 'r1) Misc.eq option end module type Mode := sig @@ -440,6 +442,13 @@ module type S = sig module Const : sig val alloc_as_value : Alloc.Const.t -> Value.Const.t + module Axis : sig + type 'd packed_value_axis = + | P : ('m, 'a, 'd) Value.axis -> 'd packed_value_axis + + val alloc_as_value : ('m, 'a, 'd) Alloc.axis -> 'd packed_value_axis + end + val locality_as_regionality : Locality.Const.t -> Regionality.Const.t end @@ -518,6 +527,8 @@ module type S = sig output list exactly once. *) val to_list : t -> atom list + val is_constant_for : ('m, 'a, 'd) Value.axis -> t -> bool + (** [equate t0 t1] checks that [t0 = t1]. Definition: [t0 = t1] iff [t0 <= t1] and [t1 <= t0]. *) val equate : t -> t -> (unit, equate_error) Result.t