diff --git a/.gitignore b/.gitignore index 61b7e47..87d5612 100644 --- a/.gitignore +++ b/.gitignore @@ -1,31 +1,7 @@ .ipynb_checkpoints/ -*.annot -*.cmo -*.cma -*.cmi -*.a -*.o -*.cmx -*.cmxs -*.cmxa - -# ocamlbuild working directory _build/ -# ocamlbuild targets -*.byte -*.native - -# oasis generated files -setup.data -setup.log - -# Merlin configuring file for Vim and Emacs -.merlin - -*.install - config-*sh config.ini _coverage/ diff --git a/.ocamlformat b/.ocamlformat index e0be211..111a92a 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,3 +1,3 @@ margin = 85 break-cases=fit -profile=conventional \ No newline at end of file +profile=conventional diff --git a/LICENSE b/LICENSE index 753f10b..faaf993 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2018 John K. Feser +Copyright 2022 John K. Feser Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/bin/dune b/bin/dune index e162e4c..2cab6b8 100644 --- a/bin/dune +++ b/bin/dune @@ -61,6 +61,29 @@ (pps ppx_sexp_conv ppx_let)) (modules combine)) -; Local Variables: -; mode: tuareg-dune -; End: +(executable + (name opt) + (public_name opt.exe) + (libraries core core_unix.command_unix castor castor_opt logs logs.fmt fmt + fmt.tty) + (preprocess + (pps ppx_sexp_conv ppx_let ppx_sexp_conv ppx_compare ppx_hash)) + (modules opt)) + +(executable + (name xform) + (public_name xform.exe) + (libraries core core_unix.command_unix castor castor_opt logs logs.fmt fmt + fmt.tty) + (preprocess + (pps ppx_sexp_conv ppx_let ppx_sexp_conv ppx_compare ppx_hash)) + (modules xform)) + +(executable + (name sql) + (public_name sql.exe) + (libraries core core_unix.command_unix castor castor_opt logs logs.fmt fmt + fmt.tty) + (preprocess + (pps ppx_sexp_conv ppx_let ppx_sexp_conv ppx_compare ppx_hash)) + (modules sql)) diff --git a/bin/explore.ml b/bin/explore.ml new file mode 100644 index 0000000..5942604 --- /dev/null +++ b/bin/explore.ml @@ -0,0 +1,117 @@ +open Core +open Castor + +module Node = struct + module T = struct + type t = Abslayout.t * int [@@deriving compare, hash, sexp_of] + end + + include T + include Comparator.Make (T) +end + +module Edge = struct + module T = struct + type t = Ok of Node.t * Node.t * string | Err of Node.t * string + [@@deriving compare, hash, sexp_of] + end + + include T + include Comparator.Make (T) +end + +let choose ls = List.nth_exn ls (Random.int (List.length ls)) +let choose_set ls = Set.nth ls (Random.int (Set.length ls)) + +let main ~params ~db ch = + let params = + List.map params ~f:(fun (n, t) -> Name.create ~type_:t n) + |> Set.of_list (module Name.Compare_no_type) + in + let module Config = struct + let conn = Db.create db + let params = params + let check_transforms = true + end in + let module A = Abslayout_db.Make (Config) in + let module T = Transform.Make (Config) (A) () in + let query_str = In_channel.input_all ch in + let query = Abslayout.of_string_exn query_str |> A.resolve ~params in + let explore ?(max_nodes = 10000) query = + let tfs = + List.filter_map T.transforms ~f:(fun (_, tf) -> + try Some (tf []) with _ -> None) + in + let edges = ref (Set.empty (module Edge)) in + let nodes = ref (Set.singleton (module Node) (query, 0)) in + let rec loop () = + if Set.length !nodes > max_nodes then () + else + match choose_set !nodes with + | Some ((query, _) as n) -> + let tf = choose tfs in + (try + match T.run tf query with + | [] -> () + | ls -> + let n' = (choose ls, Set.length !nodes) in + edges := Set.add !edges (Ok (n, n', tf.name)); + nodes := Set.add !nodes n' + with _ -> edges := Set.add !edges (Err (n, tf.name))); + loop () + | None -> () + in + loop (); + printf "digraph {"; + Set.to_sequence !edges + |> Sequence.iter ~f:(function + | Edge.Err ((_, i), name) -> + printf "%d -> err [label=\"%s\"];\n" i name + | Ok ((_, i1), (_, i2), name) -> + printf "%d -> %d [label=\"%s\"];\n" i1 i2 name); + printf "}" + in + explore query + +let reporter ppf = + let report _ level ~over k msgf = + let k _ = + over (); + k () + in + let with_time h _ k ppf fmt = + let time = Core.Time.now () in + Format.kfprintf k ppf + ("%a [%s] @[" ^^ fmt ^^ "@]@.") + Logs.pp_header (level, h) (Core.Time.to_string time) + in + msgf @@ fun ?header ?tags fmt -> with_time header tags k ppf fmt + in + { Logs.report } + +let () = + Logs.set_reporter (reporter Format.err_formatter); + let open Command in + let open Let_syntax in + Logs.info (fun m -> + m "%s" (Sys.argv |> Array.to_list |> String.concat ~sep:" ")); + basic ~summary:"Compile a query." + (let%map_open verbose = + flag "verbose" ~aliases:[ "v" ] no_arg ~doc:"increase verbosity" + and quiet = flag "quiet" ~aliases:[ "q" ] no_arg ~doc:"decrease verbosity" + and db = + flag "db" (required string) ~doc:"CONNINFO the database to connect to" + and params = + flag "param" ~aliases:[ "p" ] (listed Util.param) + ~doc:"NAME:TYPE query parameters" + and ch = + anon (maybe_with_default In_channel.stdin ("query" %: Util.channel)) + in + fun () -> + if verbose then Logs.set_level (Some Logs.Debug) + else if quiet then Logs.set_level (Some Logs.Error) + else Logs.set_level (Some Logs.Info); + Logs.info (fun m -> + m "%s" (Sys.argv |> Array.to_list |> String.concat ~sep:" ")); + main ~params ~db ch) + |> run diff --git a/bin/opt.ml b/bin/opt.ml new file mode 100644 index 0000000..dcd6d0f --- /dev/null +++ b/bin/opt.ml @@ -0,0 +1,170 @@ +open! Core +open Castor +open Collections +open Castor_opt +open Abslayout_load +module A = Abslayout + +let dump fn r = + Out_channel.with_file fn ~f:(fun ch -> + Fmt.pf (Format.formatter_of_out_channel ch) "%a" Abslayout.pp r) + +(** Run a command and return its output on stdout, logging it if it fails. *) +let command_out cmd = + let open Or_error.Let_syntax in + let ch = Core_unix.open_process_in cmd in + let out = In_channel.input_all ch in + let%map () = + Core_unix.Exit_or_signal.or_error (Core_unix.close_process_in ch) + |> Or_error.tag ~tag:cmd + in + out + +let system_exn cmd = + match Core_unix.system cmd with + | Ok () -> () + | Error (`Exit_non_zero code) -> + failwith @@ sprintf "Command '%s' exited with code %d" cmd code + | Error (`Signal signal) -> + failwith + @@ sprintf "Command '%s' terminated by signal %s" cmd + (Signal.to_string signal) + +let opt conn cost_conn params cost_timeout state query = + let module Config = struct + let conn = conn + let cost_conn = cost_conn + let params = params + let cost_timeout = cost_timeout + let random = state + end in + let module T = Transform.Make (Config) in + match Transform.optimize (module Config) query with + | First opt_query -> + if is_ok @@ T.is_serializable opt_query then Some opt_query + else ( + Logs.warn (fun m -> m "Not serializable:@ %a" A.pp opt_query); + None) + | Second failed_subquery -> + Logs.warn (fun m -> + m "Optimization failed for subquery:@ %a" A.pp failed_subquery); + None + +let eval dir params query = + let open Result.Let_syntax in + Logs.info (fun m -> m "Evaluating:@ %a" A.pp query); + + (* Set up the output directory. *) + system_exn @@ sprintf "rm -rf %s" dir; + system_exn @@ sprintf "mkdir -p %s" dir; + let query_fn = sprintf "%s/query.txt" dir in + dump query_fn query; + + (* Try to build the query. *) + let%bind () = + let compile_cmd = + let params = + List.map params ~f:(fun (n, t, _) -> + Fmt.str "-p %s:%a" n Prim_type.pp t) + |> String.concat ~sep:" " + in + sprintf + "$CASTOR_ROOT/../_build/default/castor/bin/compile.exe -o %s %s %s > \ + %s/compile.log 2>&1" + dir params query_fn dir + in + let%map out = command_out compile_cmd in + Logs.info (fun m -> m "Compile output: %s" out) + in + + (* Try to run the query. *) + let%map run_time = + let run_cmd = + let params = + List.map params ~f:(fun (_, _, v) -> sprintf "'%s'" @@ Value.to_param v) + |> String.concat ~sep:" " + in + sprintf "%s/scanner.exe -t 1 %s/data.bin %s" dir dir params + in + let%map out = command_out run_cmd in + let time, _ = String.lsplit2_exn ~on:' ' out in + String.rstrip ~drop:Char.is_alpha time |> Float.of_string + in + + run_time + +let trial_dir = sprintf "%s-trial" + +let copy_out out_file out_dir query = + dump out_file query; + system_exn @@ sprintf "rm -rf %s" out_dir; + system_exn @@ sprintf "mv -f %s %s" (trial_dir out_dir) out_dir + +let main ~params ~cost_timeout ~timeout ~out_dir ~out_file ch = + Random.init 0; + + let conn = Db.create (Sys.getenv_exn "CASTOR_OPT_DB") in + let cost_conn = conn in + let params_set = + List.map params ~f:(fun (n, t, _) -> Name.create ~type_:t n) + |> Set.of_list (module Name) + in + let query = + load_string_exn ~params:params_set conn @@ In_channel.input_all ch + in + + let best_cost = ref Float.infinity in + let cost state = + Fresh.reset Global.fresh; + match opt conn cost_conn params_set cost_timeout state query with + | Some query' -> ( + match eval (trial_dir out_dir) params query' with + | Ok cost -> + if Float.(cost < !best_cost) then ( + copy_out out_file out_dir query'; + best_cost := cost); + cost + | Error err -> + Logs.warn (fun m -> m "Evaluation failed: %a" Error.pp err); + Float.infinity) + | None -> Float.infinity + in + + let cost = Memo.of_comparable (module Mcmc.Random_choice.C) cost in + let max_time = Option.map ~f:Time.Span.of_sec timeout in + + try Mcmc.run ?max_time cost |> ignore + with Resolve.Resolve_error r -> Fmt.epr "%a@." (Resolve.pp_err Fmt.nop) r + +let spec = + let open Command.Let_syntax in + [%map_open + let () = Log.param + and () = Ops.param + and () = Db.param + and () = Type_cost.param + and () = Join_opt.param + and () = Groupby_tactics.param + and () = Type.param + and () = Simplify_tactic.param + and cost_timeout = + flag "cost-timeout" (optional float) + ~doc:"SEC time to run cost estimation" + and timeout = + flag "timeout" (optional float) ~doc:"SEC time to run optimizer" + and params = + flag "param" ~aliases:[ "p" ] + (listed Util.param_and_value) + ~doc:"NAME:TYPE query parameters" + and out_dir = + flag "out-dir" (required string) ~aliases:[ "o" ] + ~doc:"DIR output directory" + and out_file = + flag "out-file" (required string) ~aliases:[ "f" ] + ~doc:"FILE output directory" + and ch = + anon (maybe_with_default In_channel.stdin ("query" %: Util.channel)) + in + fun () -> main ~params ~cost_timeout ~timeout ~out_dir ~out_file ch] + +let () = Command.basic spec ~summary:"Optimize a query." |> Command_unix.run diff --git a/bin/sql.ml b/bin/sql.ml new file mode 100644 index 0000000..bedc07d --- /dev/null +++ b/bin/sql.ml @@ -0,0 +1,86 @@ +open! Core +open Castor +open Collections +open Ast +open Abslayout_load +module A = Abslayout + +let main ~params:all_params ~simplify ~project ~unnest ~sql ~cse ch = + Logs.set_level (Some Debug); + Logs.Src.set_level Log.src (Some Debug); + Format.set_margin 120; + + let params = + List.map all_params ~f:(fun (n, t, _) -> Name.create ~type_:t n) + |> Set.of_list (module Name) + in + let module Config = struct + let conn = Db.create (Sys.getenv_exn "CASTOR_DB") + let params = params + end in + let module S = Simplify_tactic.Make (Config) in + let module O = Ops.Make (Config) in + let simplify q = + if simplify then + let q = + Cardinality.annotate ~dedup:true q + |> Join_elim.remove_joins |> Unnest.hoist_meta |> strip_meta + in + Option.value_exn (O.apply S.simplify Path.root q) + else q + in + + let query_str = In_channel.input_all ch in + let query = load_string_exn ~params Config.conn query_str in + let query = simplify query in + let query = + let q = + if unnest then Unnest.unnest ~params query + else + (Cardinality.annotate query + :> < cardinality_matters : bool ; why_card_matters : string > + Ast.annot) + in + Cardinality.extend ~dedup:true q + |> Join_elim.remove_joins |> Unnest.hoist_meta |> strip_meta + in + let query = simplify query in + let query = if project then Project.project ~params query else query in + let query = simplify query in + let query = + if cse then `Sql (Cse.extract_common query |> Cse.to_sql) else `Query query + in + let query = + if sql then + match query with + | `Sql _ -> query + | `Query q -> `Sql (Sql.of_ralgebra q |> Sql.to_string) + else query + in + match query with + | `Sql x -> Sql.format x |> print_endline + | `Query x -> Format.printf "%a@." Abslayout.pp x + +let spec = + let open Command.Let_syntax in + [%map_open + let () = Log.param + and sql = flag "sql" no_arg ~doc:"dump sql" + and simplify = + flag "simplify" ~aliases:[ "s" ] no_arg ~doc:"simplify the query" + and unnest = + flag "unnest" ~aliases:[ "u" ] no_arg ~doc:"unnest before simplifying" + and project = + flag "project" ~aliases:[ "r" ] no_arg ~doc:"project the query" + and cse = flag "cse" ~aliases:[ "c" ] no_arg ~doc:"apply cse" + and params = + flag "param" ~aliases:[ "p" ] + (listed Util.param_and_value) + ~doc:"NAME:TYPE query parameters" + and ch = + anon (maybe_with_default In_channel.stdin ("query" %: Util.channel)) + in + fun () -> main ~params ~simplify ~project ~unnest ~sql ~cse ch] + +let () = + Command.basic spec ~summary:"Optimize a query for sql." |> Command_unix.run diff --git a/bin/xform.ml b/bin/xform.ml new file mode 100644 index 0000000..3f00357 --- /dev/null +++ b/bin/xform.ml @@ -0,0 +1,672 @@ +open! Core +open Castor +open Collections +open Castor_opt +open Abslayout_load +module A = Abslayout +open Match + +module type CONFIG = sig + val conn : Db.t + val params : Set.M(Name).t + val cost_conn : Db.t +end + +module Xforms (C : CONFIG) = struct + module O = Ops.Make (C) + module Simplify = Simplify_tactic.Make (C) + module Filter = Filter_tactics.Make (C) + module Groupby = Groupby_tactics.Make (C) + module Orderby = Orderby_tactics.Make (C) + module Simple = Simple_tactics.Make (C) + module Join_elim = Join_elim_tactics.Make (C) + module Select = Select_tactics.Make (C) +end + +let config conn params = + (module struct + let conn = conn + let cost_conn = conn + let params = params + end : CONFIG) + +let main ~name ~params ~ch = + let conn = Db.create (Sys.getenv_exn "CASTOR_DB") in + let params = + List.map params ~f:(fun (n, t, _) -> Name.create ~type_:t n) + |> Set.of_list (module Name) + in + let query = load_string_exn ~params conn @@ In_channel.input_all ch in + + let (module C) = config conn params in + let open Xforms (C) in + let open O in + let open Simplify in + let open Filter in + let open Groupby in + let open Orderby in + let open Select in + let open Simple in + let open Join_elim in + (* Recursively optimize subqueries. *) + let apply_to_subqueries tf = + let subquery_visitor tf = + object (self : 'a) + inherit [_] V.map + + method visit_subquery r = + Option.value_exn ~message:"Transforming subquery failed." + (apply tf Path.root r) + + method! visit_Exists () r = Exists (self#visit_subquery r) + method! visit_First () r = First (self#visit_subquery r) + end + in + + let f r = Some ((subquery_visitor tf)#visit_t () r) in + of_func f ~name:"apply-to-subqueries" + in + + let apply_to_filter_subquery tf = + let open Option.Let_syntax in + let subquery_visitor tf = + object (self : 'a) + inherit [_] V.map + + method visit_subquery r = + let module C = struct + let conn = conn + let cost_conn = conn + let params = Set.union params (Free.free r) + end in + Option.value_exn ~message:"Transforming subquery failed." + (apply (tf (module C : CONFIG)) Path.root r) + + method! visit_Exists () r = Exists (self#visit_subquery r) + method! visit_First () r = First (self#visit_subquery r) + end + in + + let f r = + let%bind p, r' = to_filter r in + return @@ A.filter ((subquery_visitor tf)#visit_pred () p) r' + in + of_func f ~name:"apply-to-filter-subquery" + in + + let xform_1 = + seq_many + [ + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + at_ push_orderby Path.(all >>? is_orderby >>| shallowest); + Branching.at_ elim_cmp_filter Path.(all >>? is_filter >>| deepest) + |> Branching.lower Seq.hd; + at_ push_filter Path.(all >>? is_filter >>| shallowest); + at_ push_filter Path.(all >>? is_filter >>| shallowest); + project; + simplify; + at_ push_select Path.(all >>? is_select >>| shallowest); + at_ row_store Path.(all >>? is_filter >>| shallowest); + project; + simplify; + ] + in + + let xform_2 = + seq_many + [ + id; + at_ hoist_filter + Path.(all >>? is_join >>? has_child is_filter >>| deepest); + at_ hoist_filter + Path.(all >>? is_join >>? has_child is_filter >>| deepest); + at_ hoist_filter + Path.(all >>? is_join >>? has_child is_filter >>| deepest); + at_ hoist_filter + Path.(all >>? is_join >>? has_child is_filter >>| deepest); + at_ hoist_filter + Path.(all >>? is_join >>? has_child is_filter >>| deepest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ hoist_filter Path.(all >>? is_orderby >>| shallowest); + at_ elim_eq_filter Path.(all >>? is_filter >>| shallowest); + at_ push_orderby Path.(all >>? is_orderby >>| shallowest); + simplify; + at_ push_filter Path.(all >>? is_filter >>| shallowest); + at_ + (split_out + (Path.( + all >>? is_relation + >>? matches (function + | Relation r -> String.(r.r_name = "supplier") + | _ -> false) + >>| deepest) + >>= parent) + "s1_suppkey") + Path.(all >>? is_filter >>| shallowest); + fix project; + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ hoist_filter Path.(all >>? is_filter >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ + (precompute_filter "p_type" + @@ List.map ~f:(sprintf "\"%s\"") + [ "TIN"; "COPPER"; "NICKEL"; "BRASS"; "STEEL" ]) + (Path.(all >>? is_param_filter >>| shallowest) >>= child' 0); + at_ row_store (Path.(all >>? is_param_filter >>| deepest) >>= child' 0); + at_ row_store + (Path.(all >>? is_hash_idx >>| deepest) >>= child' 1 >>= child' 0); + project; + simplify; + ] + in + + let xform_3 = + seq_many + [ + fix + (at_ hoist_filter + (Path.(all >>? is_filter >>| shallowest) >>= parent)); + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + fix + (at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent)); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + Branching.at_ + (partition_eq "customer.c_mktsegment") + Path.(all >>? is_collection >>| shallowest) + |> Branching.lower Seq.hd; + fix + (at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent)); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ row_store + Infix.( + Path.(all >>? is_filter >>? not is_param_filter >>| shallowest)); + project; + simplify; + ] + in + + let xform_4 = + seq_many + [ + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + at_ push_orderby Path.(all >>? is_orderby >>| shallowest); + at_ push_filter Path.(all >>? is_filter >>| shallowest); + Branching.at_ elim_cmp_filter Path.(all >>? is_filter >>| shallowest) + |> Branching.lower Seq.hd; + fix @@ at_ push_filter Path.(all >>? is_filter >>| shallowest); + simplify; + at_ push_select Path.(all >>? is_select >>| shallowest); + at_ row_store Path.(all >>? is_filter >>| shallowest); + project; + simplify; + ] + in + + let swap_filter p = seq_many [ at_ hoist_filter p; at_ split_filter p ] in + + let xform_5 = + seq_many + [ + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ elim_groupby_approx Path.(all >>? is_groupby >>| shallowest); + Branching.at_ + (partition_eq "region.r_name") + Path.(all >>? is_collection >>| shallowest) + |> Branching.lower Seq.hd; + swap_filter Path.(all >>? is_filter >>| shallowest); + Branching.at_ elim_cmp_filter + Path.(all >>? is_param_filter >>| shallowest) + |> Branching.lower Seq.hd; + simplify; + at_ push_select Path.(all >>? is_select >>? is_run_time >>| deepest); + at_ row_store Path.(all >>? is_filter >>? is_run_time >>| shallowest); + project; + simplify; + ] + in + + let push_no_param_filter = + fix + @@ at_ push_filter + Infix.(Path.(all >>? is_filter >>? not is_param_filter >>| shallowest)) + in + + let xform_6 = + seq_many + [ + fix @@ at_ split_filter Path.(all >>? is_filter >>| deepest); + at_ push_filter Path.(all >>? is_filter >>| shallowest); + at_ push_filter Path.(all >>? is_filter >>| shallowest >>= child' 0); + at_ hoist_filter Path.(all >>? is_filter >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ hoist_filter Path.(all >>? is_filter >>| shallowest >>= child' 0); + at_ split_filter Path.(all >>? is_filter >>| shallowest >>= child' 0); + Branching.at_ elim_cmp_filter Path.(all >>? is_filter >>| deepest) + |> Branching.lower Seq.hd; + Branching.at_ elim_cmp_filter + Path.(all >>? is_filter >>| shallowest >>= child' 0) + |> Branching.lower Seq.hd; + push_no_param_filter; + at_ row_store Path.(all >>? is_filter >>| deepest); + fix project; + simplify; + ] + in + + let xform_7 = + seq_many + [ + at_ + (partition_domain "param0" "nation.n_name") + Path.(all >>? is_orderby >>| shallowest); + at_ + (partition_domain "param1" "nation.n_name") + Path.(all >>? is_collection >>| shallowest); + at_ row_store Path.(all >>? is_orderby >>| shallowest); + project; + simplify; + ] + in + + let xform_8 = + seq_many + [ + Branching.at_ + (partition_eq "region.r_name") + Path.(all >>? is_orderby >>| shallowest) + |> Branching.lower Seq.hd; + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + at_ push_orderby Path.(all >>? is_orderby >>| shallowest); + push_no_param_filter; + at_ hoist_join_param_filter Path.(all >>? is_join >>| shallowest); + at_ row_store Path.(all >>? is_join >>| shallowest); + project; + simplify; + ] + in + + let xform_9 = + seq_many + [ + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + at_ push_orderby Path.(all >>? is_orderby >>| shallowest); + at_ hoist_filter_extend + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| shallowest); + at_ + (precompute_filter_bv + @@ List.map ~f:(sprintf "\"%s\"") + [ + "black"; + "blue"; + "brown"; + "green"; + "grey"; + "navy"; + "orange"; + "pink"; + "purple"; + "red"; + "white"; + "yellow"; + ]) + Path.(all >>? is_param_filter >>| shallowest); + at_ row_store + (Path.(all >>? is_param_filter >>| shallowest) >>= child' 0); + project; + simplify; + ] + in + + let xform_10 = + seq_many + [ + at_ elim_groupby_flat Path.(all >>? is_groupby >>| shallowest); + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| shallowest); + at_ split_filter Path.(all >>? is_param_filter >>| shallowest); + at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| shallowest); + at_ row_store + (Path.(all >>? is_param_filter >>| shallowest) >>= child' 0); + project; + simplify; + ] + in + + let xform_11 = + seq_many + [ + at_ + (partition_domain "param1" "nation.n_name") + Path.(all >>? is_filter >>| shallowest); + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + at_ elim_subquery Path.(all >>? is_filter >>| shallowest); + at_ row_store (Path.(all >>? is_list >>| shallowest) >>= child' 1); + at_ hoist_param (Path.(all >>? is_depjoin >>| shallowest) >>= child' 0); + at_ row_store + (Path.(all >>? is_depjoin >>| shallowest) >>= child' 0 >>= child' 0); + project; + simplify; + ] + in + + let xform_12 = + seq_many + [ + at_ elim_groupby Path.(all >>? is_groupby >>| shallowest); + push_orderby; + at_ split_filter Path.(all >>? is_param_filter >>| shallowest); + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| shallowest); + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ hoist_filter_agg + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| deepest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| deepest); + at_ split_filter Path.(all >>? is_param_filter >>| deepest); + at_ split_filter Path.(all >>? is_param_filter >>| deepest); + at_ hoist_filter (Path.(all >>? is_param_filter >>| deepest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| deepest); + at_ hoist_filter (Path.(all >>? is_param_filter >>| deepest) >>= parent); + at_ split_filter Path.(all >>? is_param_filter >>| deepest); + Branching.at_ elim_cmp_filter Path.(all >>? is_param_filter >>| deepest) + |> Branching.lower Seq.hd; + simplify; + at_ push_select Path.(all >>? is_select >>? is_run_time >>| shallowest); + at_ row_store + Infix.( + Path.(all >>? is_filter >>? not is_param_filter >>| shallowest)); + project; + simplify; + ] + in + + let xform_14 = + seq_many + [ + at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + Branching.at_ elim_cmp_filter Path.(all >>? is_filter >>| shallowest) + |> Branching.lower Seq.hd; + simplify; + push_select; + at_ row_store Path.(all >>? is_filter >>| shallowest); + project; + simplify; + ] + in + + let xform_15 = + seq_many + [ + at_ + (partition_domain "param1" "lineitem.l_shipdate") + Path.(all >>? is_orderby >>| shallowest); + at_ row_store Path.(all >>? is_orderby >>| shallowest); + at_ hoist_join_filter Path.(all >>? is_join >>| shallowest); + at_ elim_subquery_join Path.(all >>? is_filter >>| shallowest); + project; + simplify; + ] + in + + let xform_16 = + seq_many + [ + fix + @@ at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ hoist_filter_extend + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + Branching.at_ elim_groupby_partial + Path.(all >>? is_groupby >>| shallowest) + |> Branching.lower (fun s -> Seq.nth s 2); + at_ row_store Path.(all >>? is_groupby >>| shallowest); + project; + simplify; + ] + in + + let xform_17 = + seq_many + [ + at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + at_ elim_eq_filter Path.(all >>? is_param_filter >>| shallowest); + apply_to_subqueries + (seq_many + [ + at_ + (partition_domain "p1_partkey" "part.p_partkey") + Path.(all >>? is_select >>| shallowest); + at_ row_store + Path.(all >>? is_select >>? is_run_time >>| deepest); + ]); + at_ row_store Path.(all >>? is_filter >>| deepest); + simplify; + project; + project; + ] + in + + let xform_18 = + seq_many + [ + fix + @@ at_ hoist_filter (Path.(all >>? is_filter >>| shallowest) >>= parent); + apply_to_subqueries + (seq_many + [ + split_filter; + at_ hoist_filter + (Path.(all >>? is_param_filter >>| shallowest) >>= parent); + split_filter; + at_ + (partition_domain "o1_orderkey" "lineitem.l_orderkey") + (Path.(all >>? is_param_filter >>| shallowest) >>= child' 0); + at_ row_store + (Path.(all >>? is_collection >>| shallowest) >>= child' 1); + simplify; + ]); + project; + at_ row_store Path.(all >>? is_orderby >>| shallowest); + project; + simplify; + ] + in + + let xform_19 = + seq_many + [ + at_ hoist_join_param_filter Path.(all >>? is_join >>| shallowest); + at_ elim_disjunct Path.(all >>? is_filter >>| shallowest); + fix + @@ at_ split_filter_params + (Path.(all >>? above is_param_filter >>? is_join >>| deepest) + >>= parent); + fix + @@ at_ row_store + Infix.( + Path.( + all >>? is_filter >>? not is_param_filter >>? is_run_time + >>| shallowest)); + project; + try_ simplify; + ] + in + + let xform_20 = + seq_many + [ + partition_domain "param2" "nation.n_name"; + at_ hoist_filter Path.(all >>? is_join >>| shallowest); + at_ hoist_filter Path.(all >>? is_orderby >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ hoist_filter Path.(all >>? is_filter >>| shallowest); + at_ split_filter Path.(all >>? is_filter >>| shallowest); + at_ + ( apply_to_filter_subquery @@ fun (module C : CONFIG) -> + seq_many + [ + at_ + ( apply_to_filter_subquery @@ fun _ -> + seq_many + [ + Branching.at_ + (partition_eq "part.p_partkey") + Path.(all >>? is_filter >>| shallowest) + |> Branching.lower Seq.hd; + swap_filter Path.(all >>? is_filter >>| shallowest); + first row_store Path.(all >>? is_run_time); + ] ) + Path.(all >>? is_filter >>| shallowest); + at_ + ( apply_to_filter_subquery @@ fun (module C : CONFIG) -> + let open Filter_tactics.Make (C) in + let open Simplify_tactic.Make (C) in + let open Select_tactics.Make (C) in + seq_many + [ + at_ elim_eq_filter Path.(all >>? is_filter >>| deepest); + at_ push_filter + Path.(all >>? is_param_filter >>| shallowest); + Branching.at_ elim_cmp_filter + Path.(all >>? is_param_filter >>| shallowest) + |> Branching.lower Seq.hd; + simplify; + at_ push_select Path.(all >>? is_select >>| shallowest); + simplify; + at_ push_select + Path.(all >>? is_select >>? is_run_time >>| deepest); + project; + first row_store Path.(all >>? is_run_time); + project; + ] ) + (Path.(all >>? is_filter >>| shallowest) >>= child' 0); + (let open Filter_tactics.Make (C) in + at_ elim_eq_filter Path.(all >>? is_filter >>| deepest)); + at_ row_store + Path.(all >>? is_run_time >>? is_filter >>| deepest); + ] ) + Path.(all >>? is_filter >>| shallowest); + first row_store Path.(all >>? is_run_time); + project; + ] + in + let xform_21 = + seq_many + [ + partition_domain "param0" "nation.n_name"; + first row_store Path.(all >>? is_run_time); + simplify; + project; + ] + in + let xform_22 = + seq_many + [ + elim_groupby; + for_all push_select Path.(all >>? is_select); + for_all cse_filter Path.(all >>? is_filter); + for_all push_select Path.(all >>? is_select); + at_ push_filter Path.(all >>? is_filter >>| shallowest); + at_ push_filter Path.(all >>? is_filter >>| shallowest); + swap_filter Path.(all >>? is_filter >>| shallowest); + swap_filter (Path.(all >>? is_filter >>| shallowest) >>= child' 0); + at_ + ( apply_to_filter_subquery @@ fun (module C : CONFIG) -> + let open Ops.Make (C) in + let open Filter_tactics.Make (C) in + seq_many + [ + for_all cse_filter Path.(all >>? is_filter); + first row_store Path.all; + simplify; + project; + ] ) + (Path.(all >>? is_filter >>| shallowest) >>= child' 0); + first row_store Path.(all >>? is_run_time); + elim_subquery; + simplify; + project; + ] + in + + let xform = + match name with + | "1" -> xform_1 + | "2" -> xform_2 + | "3-no" -> xform_3 + | "4" -> xform_4 + | "5-no" -> xform_5 + | "6" -> xform_6 + | "7" -> xform_7 + | "8" -> xform_8 + | "9" -> xform_9 + | "10-no" -> xform_10 + | "11-no" -> xform_11 + | "12" -> xform_12 + | "14" -> xform_14 + | "15" -> xform_15 + | "16-no" -> xform_16 + | "17" -> xform_17 + | "18" -> xform_18 + | "19" -> xform_19 + | "20" -> xform_20 + | "21-no" -> xform_21 + | "22-no" -> xform_22 + | _ -> failwith "unknown query name" + in + let query' = apply xform Path.root query in + Option.iter query' ~f:(A.pp Fmt.stdout) + +let spec = + let open Command.Let_syntax in + [%map_open + let () = Log.param + and () = Ops.param + and () = Db.param + and () = Type_cost.param + and () = Join_opt.param + and () = Groupby_tactics.param + and () = Type.param + and () = Simplify_tactic.param + and params = + flag "param" ~aliases:[ "p" ] + (listed Util.param_and_value) + ~doc:"NAME:TYPE query parameters" + and name = flag "name" (required string) ~doc:"query name" + and ch = + anon (maybe_with_default In_channel.stdin ("query" %: Util.channel)) + in + fun () -> main ~name ~params ~ch] + +let () = + Command.basic spec ~summary:"Apply transformations to a query." + |> Command_unix.run diff --git a/dune b/dune index f687cd1..868dc10 100644 --- a/dune +++ b/dune @@ -5,7 +5,3 @@ (glob_files etc/*.h))) (dirs :standard \ bench) - -; Local Variables: -; mode: tuareg-dune -; End: diff --git a/dune-project b/dune-project index 71429ac..1d5db61 100644 --- a/dune-project +++ b/dune-project @@ -2,6 +2,7 @@ (name castor) (source (github jfeser/castor)) + (license MIT) (authors "Jack Feser ") (maintainers "Jack Feser ") @@ -54,8 +55,8 @@ ounit (expect_test_helpers_core (>= v0.12)))) - (using menhir 2.0) +(using fmt 1.1) (generate_opam_files true) ; Local Variables: diff --git a/lib/dedup_tactics.ml b/lib/dedup_tactics.ml new file mode 100644 index 0000000..001249b --- /dev/null +++ b/lib/dedup_tactics.ml @@ -0,0 +1,56 @@ +open Ast +open Abslayout +module P = Pred.Infix + +module Config = struct + module type S = sig + include Ops.Config.S + include Tactics_util.Config.S + end +end + +module Make (C : Config.S) = struct + module O = Ops.Make (C) + module Tactics_util = Tactics_util.Make (C) + + let to_dedup r = match r.node with Dedup r -> Some r | _ -> None + + let elim_dedup r = + let open Option.Let_syntax in + let%bind r = to_dedup r in + match Cardinality.estimate r with + | Abs_int.Interval (_, h) when h <= 1 -> Some r + | _ -> None + + let elim_dedup = O.of_func elim_dedup ~name:"elim-dedup" + + let lhs_visible lhs rhs = + Set.is_subset + (Set.of_list (module Name) (Schema.schema lhs)) + ~of_:(Set.of_list (module Name) (Schema.schema rhs)) + + let push_dedup r = + let open Option.Let_syntax in + let%bind r = to_dedup r in + match r.node with + | Filter (p, r') -> Some (filter p (dedup r')) + | Dedup r' -> Some (dedup r') + | AScalar _ | AEmpty -> Some r + | AHashIdx h -> return @@ hash_idx' { h with hi_values = dedup h.hi_values } + | AList ({ l_keys = lhs; l_values = rhs; _ } as l) when lhs_visible lhs rhs + -> + return @@ list' { l with l_keys = dedup lhs; l_values = dedup rhs } + | DepJoin ({ d_lhs = lhs; d_rhs = rhs; _ } as d) when lhs_visible lhs rhs -> + return @@ dep_join' { d with d_lhs = dedup lhs; d_rhs = dedup rhs } + | AOrderedIdx o -> + (* TODO: This transform isn't correct unless the right hand sides are + non-overlapping. *) + return + @@ ordered_idx' + { o with oi_keys = dedup o.oi_keys; oi_values = dedup o.oi_values } + | ATuple (ts, Cross) -> Some (tuple (List.map ts ~f:dedup) Cross) + | Select _ -> None + | _ -> None + + let push_dedup = O.of_func push_dedup ~name:"push-dedup" +end diff --git a/lib/explain.ml b/lib/explain.ml new file mode 100644 index 0000000..8dcdddb --- /dev/null +++ b/lib/explain.ml @@ -0,0 +1,28 @@ +open Yojson.Basic +open Postgresql + +type t = { nrows : int; cost : float } + +let explain (conn : Db.t) query = + let open Result.Let_syntax in + let r : result = + (Db.conn conn)#exec (sprintf "explain (format json) %s" query) + in + let%bind json_str = + match r#status with + | Single_tuple | Tuples_ok -> Ok (r#getvalue 0 0) + | _ -> + let status = result_status r#status in + Result.fail + Error.( + create "Postgres error." (status, r#error, query) + [%sexp_of: string * string * string]) + in + let json = from_string json_str in + try + let plan = Util.to_list json |> List.hd_exn |> Util.member "Plan" in + let nrows = Util.member "Plan Rows" plan |> Util.to_int in + let cost = Util.member "Total Cost" plan |> Util.to_number in + Ok { nrows; cost } + with Util.Type_error _ as e -> + Result.fail Error.(of_exn e |> tag ~tag:json_str) diff --git a/lib/filter_tactics.ml b/lib/filter_tactics.ml new file mode 100644 index 0000000..485c970 --- /dev/null +++ b/lib/filter_tactics.ml @@ -0,0 +1,1300 @@ +open Ast +open Collections +module A = Abslayout +module P = Pred.Infix +module V = Visitors +open Match + +(** Enable partitioning when a parameter is used in a range predicate. *) +let enable_partition_cmp = ref false + +module Config = struct + module type My_S = sig + val params : Set.M(Name).t + end + + module type S = sig + include Ops.Config.S + include Simplify_tactic.Config.S + include Tactics_util.Config.S + include My_S + end +end + +module Make (C : Config.S) = struct + open Ops.Make (C) + open Simplify_tactic.Make (C) + module Tactics_util = Tactics_util.Make (C) + module My_C : Config.My_S = C + open My_C + + let fresh_name = Fresh.name Global.fresh + let schema_set r = Schema.schema r |> Set.of_list (module Name) + + (** Split predicates that sit under a binder into the parts that depend on + bound variables and the parts that don't. *) + let split_bound binder p = + List.partition_tf (Pred.conjuncts p) ~f:(fun p' -> + overlaps (Free.pred_free p') (schema_set binder)) + + (** Check that a predicate is supported by a relation (it does not depend on + anything in the context that it did not previously depend on.) *) + let invariant_support orig_bound new_bound pred = + let supported = Set.inter (Free.pred_free pred) orig_bound in + Set.is_subset supported ~of_:new_bound + + let filter_many ps r = + if List.is_empty ps then r else A.filter (Pred.conjoin ps) r + + let hoist_filter r = + let open Option.Let_syntax in + match r.node with + | OrderBy { key; rel } -> + let%map p, r = to_filter rel in + A.filter p (A.order_by key r) + | GroupBy (ps, key, r) -> + let%bind p, r = to_filter r in + if invariant_support (schema_set r) (schema_set (A.group_by ps key r)) p + then Some (A.filter p (A.group_by ps key r)) + else None + | Filter (p', r) -> + let%map p, r = to_filter r in + A.filter (Binop (And, p, p')) r + | Select (ps, r) -> ( + let%bind p, r = to_filter r in + match A.select_kind ps with + | `Scalar -> + if Tactics_util.select_contains (Free.pred_free p) ps r then + Some (A.filter p (A.select ps r)) + else None + | `Agg -> None) + | Join { pred; r1; r2 } -> ( + match (to_filter r1, to_filter r2) with + | Some (p1, r1), Some (p2, r2) -> + Some (filter_many [ p1; p2 ] (A.join pred r1 r2)) + | None, Some (p, r2) -> Some (A.filter p (A.join pred r1 r2)) + | Some (p, r1), None -> Some (A.filter p (A.join pred r1 r2)) + | None, None -> None) + | Dedup r -> + let%map p, r = to_filter r in + A.filter p (A.dedup r) + | AList l -> + let%map p, r = to_filter l.l_values in + A.filter (Pred.unscoped l.l_scope p) (A.list' { l with l_values = r }) + | AHashIdx ({ hi_keys = rk; hi_values = rv; _ } as h) -> + let%map p, r = to_filter rv in + let below, above = split_bound rk p in + let above = List.map above ~f:(Pred.unscoped h.hi_scope) in + filter_many above + @@ A.hash_idx' { h with hi_values = filter_many below r } + | AOrderedIdx o -> + let%map p, r = to_filter o.oi_values in + let below, above = split_bound o.oi_keys p in + let above = List.map above ~f:(Pred.unscoped o.oi_scope) in + filter_many above + @@ A.ordered_idx' { o with oi_values = filter_many below r } + | DepJoin _ | Relation _ | AEmpty | AScalar _ | ATuple _ | Range _ -> None + + let hoist_filter = of_func hoist_filter ~name:"hoist-filter" + + let hoist_filter_agg r = + let open Option.Let_syntax in + match r.node with + | Select (ps, r) -> ( + let%bind p, r = to_filter r in + match A.select_kind ps with + | `Scalar -> None + | `Agg -> + if + Tactics_util.select_contains + (Set.diff (Free.pred_free p) params) + ps r + then Some (A.filter p (A.select ps r)) + else None) + | _ -> None + + let hoist_filter_agg = of_func hoist_filter_agg ~name:"hoist-filter-agg" + + let hoist_filter_extend r = + let open Option.Let_syntax in + match r.node with + | Select (ps, r) -> ( + let%bind p, r = to_filter r in + match A.select_kind ps with + | `Scalar -> + let ext = + Free.pred_free p |> Set.to_list + |> List.filter ~f:(fun f -> + not + (Tactics_util.select_contains + (Set.singleton (module Name) f) + ps r)) + |> List.map ~f:(fun n -> Name n) + in + Some (A.filter p @@ A.select (ps @ ext) r) + | `Agg -> None) + | GroupBy (ps, key, r) -> + let%bind p, r = to_filter r in + let ext = + let key_preds = List.map ~f:(fun n -> Name n) key in + Free.pred_free p |> Set.to_list + |> List.filter ~f:(fun f -> + (not (Set.mem params f)) + && (not + (Tactics_util.select_contains + (Set.singleton (module Name) f) + key_preds r)) + && List.mem key ~equal:[%compare.equal: Name.t] f) + |> List.map ~f:(fun n -> Name n) + in + Some (A.filter p @@ A.group_by (ps @ ext) key r) + | _ -> None + + let hoist_filter_extend = + of_func hoist_filter_extend ~name:"hoist-filter-extend" + + let split_filter r = + match r.node with + | Filter (Binop (And, p, p'), r) -> Some (A.filter p (A.filter p' r)) + | _ -> None + + let split_filter = of_func split_filter ~name:"split-filter" + + let split_filter_params r = + match r.node with + | Filter (p, r) -> + let has_params, no_params = + Pred.conjuncts p + |> List.partition_tf ~f:(fun p -> + Set.inter (Free.pred_free p) params |> Set.is_empty |> not) + in + if List.is_empty has_params || List.is_empty no_params then None + else + Some + (A.filter (Pred.conjoin has_params) + @@ A.filter (Pred.conjoin no_params) r) + | _ -> None + + let split_filter_params = + of_func split_filter_params ~name:"split-filter-params" + + let rec first_ok = function + | Ok x :: _ -> Some x + | _ :: xs -> first_ok xs + | [] -> None + + let qualify rn p = + let visitor = + object + inherit [_] V.endo + method! visit_Name () _ n = Name (Name.copy ~scope:(Some rn) n) + end + in + visitor#visit_pred () p + + let gen_ordered_idx ?lb ?ub p rk rv = + let k = fresh_name "k%d" in + let n = fresh_name "x%d" in + A.ordered_idx + (A.dedup (A.select [ P.as_ p n ] rk)) + k + (A.filter (Binop (Eq, Name (Name.create ~scope:k n), p)) rv) + [ (lb, ub) ] + + (** A predicate `p` is a candidate lookup key into a partitioning of `r` if it + does not depend on any of the fields in `r`. + + TODO: In practice we also want it to have a parameter in it. Is this correct? *) + let is_candidate_key p r = + let pfree = Free.pred_free p in + (not (overlaps (schema_set r) pfree)) && overlaps params pfree + + (** A predicate is a candidate to be matched if all its free variables are + bound by the relation that it is above. *) + let is_candidate_match p r = + Set.is_subset (Free.pred_free p) ~of_:(schema_set r) + + let elim_cmp_filter r = + match r.node with + | Filter (p, r') -> ( + let orig_schema = Schema.schema r' in + + (* Select the comparisons which have a parameter on exactly one side and + partition by the unparameterized side of the comparison. *) + let cmps, rest = + Pred.conjuncts p + |> List.partition_map ~f:(function + | (Binop (Gt, p1, p2) | Binop (Lt, p2, p1)) as p -> + if is_candidate_key p1 r' && is_candidate_match p2 r' then + First (p2, (`Lt, p1)) + else if is_candidate_key p2 r' && is_candidate_match p1 r' + then First (p1, (`Gt, p2)) + else Second p + | (Binop (Ge, p1, p2) | Binop (Le, p2, p1)) as p -> + if is_candidate_key p1 r' && is_candidate_match p2 r' then + First (p2, (`Le, p1)) + else if is_candidate_key p2 r' && is_candidate_match p1 r' + then First (p1, (`Ge, p2)) + else Second p + | p -> Second p) + in + let cmps, rest' = + Map.of_alist_multi (module Pred) cmps + |> Map.to_alist + |> List.map ~f:(fun (key, bounds) -> + let lb, rest = + let open_lb = + List.filter_map bounds ~f:(fun (f, p) -> + match f with `Gt -> Some p | _ -> None) + in + let closed_lb = + List.filter_map bounds ~f:(fun (f, p) -> + match f with `Ge -> Some p | _ -> None) + in + match + (List.length open_lb = 0, List.length closed_lb = 0) + with + | true, true -> (None, []) + | _, true -> + ( Option.map (List.reduce ~f:Pred.max_of open_lb) + ~f:(fun max -> (max, `Open)), + [] ) + | _ -> + ( Option.map + (List.reduce ~f:Pred.max_of (open_lb @ closed_lb)) + ~f:(fun max -> (max, `Closed)), + List.map open_lb ~f:(fun v -> P.(key > v)) ) + in + let ub, rest' = + let open_ub = + List.filter_map bounds ~f:(fun (f, p) -> + match f with `Lt -> Some p | _ -> None) + in + let closed_ub = + List.filter_map bounds ~f:(fun (f, p) -> + match f with `Le -> Some p | _ -> None) + in + match + (List.length open_ub = 0, List.length closed_ub = 0) + with + | true, true -> (None, []) + | _, true -> + ( Option.map (List.reduce ~f:Pred.min_of open_ub) + ~f:(fun p -> (p, `Open)), + [] ) + | _ -> + ( Option.map + (List.reduce ~f:Pred.min_of (open_ub @ closed_ub)) + ~f:(fun p -> (p, `Closed)), + List.map open_ub ~f:(fun v -> P.(key > v)) ) + in + ((key, (lb, ub)), rest @ rest')) + |> List.unzip + in + let rest = rest @ List.concat rest' in + let key, cmps = List.unzip cmps in + let x = + let open Or_error.Let_syntax in + if List.is_empty key then + Or_error.error_string "No candidate keys found." + else + let%map all_keys = Tactics_util.all_values key r' in + let scope = fresh_name "s%d" in + let keys_schema = Schema.schema all_keys in + A.select (Schema.to_select_list orig_schema) + @@ A.ordered_idx all_keys scope + (filter_many + (List.map key ~f:(fun p -> + P.(p = Pred.scoped keys_schema scope p))) + r') + cmps + in + match x with + | Ok r -> Seq.singleton (filter_many rest r) + | Error err -> + Logs.warn (fun m -> m "Elim-cmp: %a" Error.pp err); + Seq.empty) + | _ -> Seq.empty + + let elim_cmp_filter = + Branching.(local elim_cmp_filter ~name:"elim-cmp-filter") + + (** Eliminate a filter with one parameter and one attribute. *) + let elim_simple_filter r = + let open Option.Let_syntax in + let%bind p, r = to_filter r in + let%bind r = if Is_serializable.is_static ~params r then Some r else None in + let p = (p :> Pred.t) and r = (r :> Ast.t) in + let names = Pred.names p |> Set.to_list in + + let is_param = Set.mem params in + let is_field n = + Db.relation_has_field C.conn (Name.name n) |> Option.is_some + in + let%bind _param, attr = + match names with + | [ n1; n2 ] when is_param n1 && is_field n2 -> return (n1, n2) + | [ n2; n1 ] when is_param n1 && is_field n2 -> return (n1, n2) + | _ -> None + in + let scope = Fresh.name Global.fresh "s%d" in + let sattr = Name.scoped scope attr in + let select_list = + Schema.schema r + |> List.filter ~f:(fun n -> Name.O.(n <> attr)) + |> Schema.to_select_list + in + return + @@ A.list (A.dedup @@ A.select [ Name attr ] r) scope + @@ A.tuple + [ + A.filter p @@ A.scalar (Name sattr); + A.select select_list @@ A.filter P.(Name attr = Name sattr) r; + ] + Cross + + let elim_simple_filter = + of_func_pre ~pre:Is_serializable.annotate_stage elim_simple_filter + ~name:"elim-simple-filter" + + let push_filter_cross_tuple stage p rs = + let ps = Pred.conjuncts p in + (* Find the earliest placement for each predicate. *) + let preds = Array.create ~len:(List.length rs) [] in + let rec place_all ps i = + if i >= List.length rs then ps + else + let bnd = + List.nth_exn rs i |> Schema.schema |> Set.of_list (module Name) + in + let pl, up = + List.partition_tf ps ~f:(Tactics_util.is_supported stage bnd) + in + preds.(i) <- pl; + place_all up (i + 1) + in + let rest = place_all ps 0 in + let rs = List.mapi rs ~f:(fun i -> filter_many preds.(i)) in + filter_many rest (A.tuple rs Cross) + + let push_filter_list stage p l = + let rk_bnd = Set.of_list (module Name) (Schema.schema l.l_keys) in + let pushed_key, pushed_val = + Pred.conjuncts p + |> List.partition_map ~f:(fun p -> + if Tactics_util.is_supported stage rk_bnd p then First p + else Second p) + in + A.list + (filter_many pushed_key l.l_keys) + l.l_scope + (filter_many pushed_val l.l_values) + + let push_filter_select stage p ps r = + match A.select_kind ps with + | `Scalar -> + let ctx = + List.filter_map ps ~f:(fun p -> + Option.map (Pred.to_name p) ~f:(fun n -> (n, Pred.remove_as p))) + |> Map.of_alist_exn (module Name) + in + let p' = Pred.subst ctx p in + A.select ps (A.filter p' r) + | `Agg -> + let scalar_ctx = + List.filter_map ps ~f:(fun p -> + if Poly.(Pred.kind p = `Scalar) then + Option.map (Pred.to_name p) ~f:(fun n -> (n, Pred.remove_as p)) + else None) + |> Map.of_alist_exn (module Name) + in + let names = Map.keys scalar_ctx |> Set.of_list (module Name) in + let pushed, unpushed = + Pred.conjuncts p + |> List.partition_map ~f:(fun p -> + if Tactics_util.is_supported stage names p then + First (Pred.subst scalar_ctx p) + else Second p) + in + filter_many unpushed @@ A.select ps @@ filter_many pushed r + + let push_filter r = + let open Option.Let_syntax in + let stage = r.meta#stage in + let r = strip_meta r in + let%bind p, r = to_filter r in + match r.node with + | Filter (p', r') -> Some (A.filter (Binop (And, p, p')) r') + | Dedup r' -> Some (A.dedup (A.filter p r')) + | Select (ps, r') -> Some (push_filter_select stage p ps r') + | ATuple (rs, Concat) -> Some (A.tuple (List.map rs ~f:(A.filter p)) Concat) + | ATuple (rs, Cross) -> Some (push_filter_cross_tuple stage p rs) + (* Lists are a special case because their keys are bound at compile time and + are not available at runtime. *) + | AList l -> Some (push_filter_list stage p l) + | _ -> + let%map rk, scope, rv, mk = + match r.node with + | DepJoin { d_lhs = rk; d_rhs = rv; d_alias } -> + Some (rk, d_alias, rv, A.dep_join) + | AList l -> Some (l.l_keys, l.l_scope, l.l_values, A.list) + | AHashIdx h -> + Some + ( h.hi_keys, + h.hi_scope, + h.hi_values, + fun rk s rv -> + A.hash_idx' + { h with hi_keys = rk; hi_scope = s; hi_values = rv } ) + | AOrderedIdx o -> + Some + ( o.oi_keys, + o.oi_scope, + o.oi_values, + fun rk s rv -> + A.ordered_idx' + { o with oi_keys = rk; oi_scope = s; oi_values = rv } ) + | _ -> None + in + let rk_bnd = Set.of_list (module Name) (Schema.schema rk) in + let pushed_key, pushed_val = + Pred.conjuncts p + |> List.partition_map ~f:(fun p -> + if Tactics_util.is_supported stage rk_bnd p then First p + else Second p) + in + let pushed_val = + List.map pushed_val ~f:(Pred.scoped (Set.to_list rk_bnd) scope) + in + mk (filter_many pushed_key rk) scope (filter_many pushed_val rv) + + let push_filter = + (* NOTE: Simplify is necessary to make push-filter safe under fixpoints. *) + seq' + (of_func_cond ~name:"push-filter" + ~pre:(fun r -> Some (Resolve.resolve_exn ~params r)) + push_filter + ~post:(fun r -> Resolve.resolve ~params r |> Result.ok)) + simplify + + let elim_eq_filter_src = + let src = Logs.Src.create "elim-eq-filter" in + Logs.Src.set_level src (Some Warning); + src + + let contains_not p = + let visitor = + object (self) + inherit [_] V.reduce + inherit [_] Util.disj_monoid + + method! visit_Unop () op p = + match op with Not -> true | _ -> self#visit_pred () p + end + in + visitor#visit_pred () p + + let is_eq_subtree p = + let visitor = + object (self) + inherit [_] V.reduce + inherit [_] Util.conj_monoid + + method! visit_Binop () op p1 p2 = + match op with + | And | Or -> self#visit_pred () p1 && self#visit_pred () p2 + | Eq -> true + | _ -> false + + method! visit_Unop () op p = + match op with Not -> false | _ -> self#visit_pred () p + end + in + visitor#visit_pred () p + + (** Domain computations for predicates containing conjunctions, disjunctions + and equalities. *) + module EqDomain = struct + type domain = + | And of domain * domain + | Or of domain * domain + | Domain of Ast.t + [@@deriving compare] + + type t = domain Map.M(Pred).t + + let intersect d1 d2 = + Map.merge d1 d2 ~f:(fun ~key:_ v -> + let ret = + match v with + | `Both (d1, d2) -> + if [%compare.equal: domain] d1 d2 then d1 else And (d1, d2) + | `Left d | `Right d -> d + in + Some ret) + + let union d1 d2 = + Map.merge d1 d2 ~f:(fun ~key:_ v -> + let ret = + match v with + | `Both (d1, d2) -> + if [%compare.equal: domain] d1 d2 then d1 else Or (d1, d2) + | `Left d | `Right d -> d + in + Some ret) + + let rec of_pred r = + let open Or_error.Let_syntax in + function + | Binop (And, p1, p2) -> + let%bind ds1 = of_pred r p1 in + let%map ds2 = of_pred r p2 in + intersect ds1 ds2 + | Binop (Or, p1, p2) -> + let%bind ds1 = of_pred r p1 in + let%map ds2 = of_pred r p2 in + union ds1 ds2 + | Binop (Eq, p1, p2) -> ( + match + (Tactics_util.all_values [ p1 ] r, Tactics_util.all_values [ p2 ] r) + with + | _, Ok vs2 when is_candidate_key p1 r && is_candidate_match p2 r -> + Ok (Map.singleton (module Pred) p1 (Domain vs2)) + | Ok vs1, _ when is_candidate_key p2 r && is_candidate_match p1 r -> + Ok (Map.singleton (module Pred) p2 (Domain vs1)) + | _, Ok _ | Ok _, _ -> + Or_error.error "No candidate keys." (p1, p2) + [%sexp_of: Pred.t * Pred.t] + | Error e1, Error e2 -> Error (Error.of_list [ e1; e2 ])) + | p -> + Or_error.error "Not part of an equality predicate." p + [%sexp_of: Pred.t] + + let to_ralgebra d = + let schema r = List.hd_exn (Schema.schema r) in + let rec extract = function + | And (d1, d2) -> + let e1 = extract d1 and e2 = extract d2 in + let n1 = schema e1 and n2 = schema e2 in + A.dedup @@ A.select [ Name n1 ] + @@ A.join (Binop (Eq, Name n1, Name n2)) e1 e2 + | Or (d1, d2) -> + let e1 = extract d1 and e2 = extract d2 in + let n1 = schema e1 and n2 = schema e2 and n = fresh_name "x%d" in + A.dedup + @@ A.tuple + [ + A.select [ P.as_ (Name n1) n ] e1; + A.select [ P.as_ (Name n2) n ] e2; + ] + Concat + | Domain d -> + let n = schema d and n' = fresh_name "x%d" in + A.select [ P.as_ (Name n) n' ] d + in + Map.map d ~f:extract + end + + let elim_eq_filter_limit = 3 + + let elim_eq_check_limit n = + if n > elim_eq_filter_limit then ( + Logs.info ~src:elim_eq_filter_src (fun m -> + m "Would need to join too many relations (%d > %d)" n + elim_eq_filter_limit); + None) + else Some () + + let elim_eq_filter r = + let open Option.Let_syntax in + let%bind p, r = to_filter r in + let orig_schema = Schema.schema r in + + (* Extract equalities from the filter predicate. *) + let eqs, rest = + Pred.to_nnf p |> Pred.conjuncts + |> List.partition_map ~f:(fun p -> + match EqDomain.of_pred r p with + | Ok d -> First (p, d) + | Error e -> + Logs.info ~src:elim_eq_filter_src (fun m -> m "%a" Error.pp e); + Second p) + in + + let inner, eqs = List.unzip eqs in + let eqs = List.reduce ~f:EqDomain.intersect eqs + and inner = Pred.conjoin inner in + match eqs with + | None -> + Logs.info ~src:elim_eq_filter_src (fun m -> m "Found no equalities."); + None + | Some eqs -> + let eqs = EqDomain.to_ralgebra eqs in + let key, rels = Map.to_alist eqs |> List.unzip in + + let%map () = elim_eq_check_limit (List.length rels) in + + let r_keys = A.dedup (A.tuple rels Cross) in + let scope = fresh_name "s%d" in + let inner_filter_pred = + let ctx = + Map.map eqs ~f:(fun r -> + Schema.schema r |> List.hd_exn |> Name.scoped scope |> P.name) + in + Pred.subst_tree ctx inner + and select_list = Schema.to_select_list orig_schema in + A.select select_list @@ filter_many rest + @@ A.hash_idx r_keys scope (A.filter inner_filter_pred r) key + + let elim_eq_filter = + seq' (of_func elim_eq_filter ~name:"elim-eq-filter") (try_ filter_const) + + let elim_disjunct r = + let open Option.Let_syntax in + let%bind p, r = to_filter r in + let clauses = Pred.disjuncts p in + if List.length clauses > 1 then + let%bind all_disjoint = + Tactics_util.all_disjoint + (List.map ~f:(Pred.to_static ~params) clauses) + r + |> Or_error.ok + in + if all_disjoint && List.length clauses > 1 then + Some (A.tuple (List.map clauses ~f:(fun p -> A.filter p r)) Concat) + else None + else None + + let elim_disjunct = of_func elim_disjunct ~name:"elim-disjunct" + + let eq_bound n ps = + List.find_map + ~f:(function + | Binop (Eq, Name n', p) when Name.O.(n' = n) -> Some p + | Binop (Eq, p, Name n') when Name.O.(n' = n) -> Some p + | _ -> None) + ps + + let to_lower_bound n ps = + let cmp = + if !enable_partition_cmp then + List.find_map + ~f:(function + | Binop (Lt, Name n', p) when Name.O.(n' = n) -> Some p + | Binop (Le, Name n', p) when Name.O.(n' = n) -> Some p + | Binop (Gt, p, Name n') when Name.O.(n' = n) -> Some p + | Binop (Ge, p, Name n') when Name.O.(n' = n) -> Some p + | Binop (Lt, Binop (Add, Name n', p'), p) when Name.O.(n' = n) -> + Some (Binop (Sub, p, p')) + | Binop (Le, Binop (Add, Name n', p'), p) when Name.O.(n' = n) -> + Some (Binop (Sub, p, p')) + | Binop (Gt, p, Binop (Add, Name n', p')) when Name.O.(n' = n) -> + Some (Binop (Sub, p, p')) + | Binop (Ge, p, Binop (Add, Name n', p')) when Name.O.(n' = n) -> + Some (Binop (Sub, p, p')) + | _ -> None) + ps + else None + in + Option.first_some (eq_bound n ps) cmp + + let to_upper_bound n ps = + let cmp = + if !enable_partition_cmp then + List.find_map + ~f:(function + | Binop (Lt, p, Name n') when Name.O.(n' = n) -> Some p + | Binop (Le, p, Name n') when Name.O.(n' = n) -> Some p + | Binop (Gt, Name n', p) when Name.O.(n' = n) -> Some p + | Binop (Ge, Name n', p) when Name.O.(n' = n) -> Some p + | Binop (Lt, p, Binop (Add, Name n', p')) when Name.O.(n' = n) -> + Some (Binop (Add, p, p')) + | Binop (Le, p, Binop (Add, Name n', p')) when Name.O.(n' = n) -> + Some (Binop (Add, p, p')) + | Binop (Gt, Binop (Add, Name n', p'), p) when Name.O.(n' = n) -> + Some (Binop (Add, p, p')) + | Binop (Ge, Binop (Add, Name n', p'), p) when Name.O.(n' = n) -> + Some (Binop (Add, p, p')) + | _ -> None) + ps + else None + in + Option.first_some (eq_bound n ps) cmp + + let to_range n ps = (to_lower_bound n ps, to_upper_bound n ps) + + let relevant_conjuncts r n = + let visitor = + object + inherit [_] V.reduce as super + inherit [_] Util.list_monoid + + method! visit_Filter () (p, r) = + super#visit_Filter () (p, r) + @ (Pred.conjuncts p + |> List.filter ~f:(fun p -> Set.mem (Pred.names p) n)) + end + in + visitor#visit_t () r + + let subst_no_subquery ctx = + let v = + object + inherit [_] V.endo + + method! visit_Name _ this v = + match Map.find ctx v with Some x -> x | None -> this + + method! visit_Exists _ this _ = this + method! visit_First _ this _ = this + end + in + v#visit_t () + + let partition_with_bounds field aliases lo hi r n = + let open Option.Let_syntax in + let key_name = fresh_name "k%d" in + let%bind keys = + match Pred.to_type field with + | IntT _ | DateT _ -> + let%map vals = Tactics_util.all_values [ field ] r |> Or_error.ok in + let vals = + let val_name = List.hd_exn (Schema.schema vals) in + let select_list = + let alias_binds = + List.filter_map aliases ~f:Fun.id + |> List.map ~f:(fun n -> P.as_ (Name val_name) @@ Name.name n) + in + P.name val_name :: alias_binds + in + A.select select_list vals + and scope = fresh_name "k%d" in + + let open P in + A.dep_join + (A.select [ as_ (Min lo) "lo"; as_ (Max hi) "hi" ] vals) + scope + @@ A.select [ as_ (name (Name.create "range")) key_name ] + @@ A.range + (name (Name.create ~scope "lo")) + (name (Name.create ~scope "hi")) + | StringT _ -> + let%map keys = Tactics_util.all_values [ field ] r |> Or_error.ok in + let select_list = + [ P.(as_ (name (List.hd_exn (Schema.schema keys))) key_name) ] + in + A.select select_list keys + | _ -> None + in + let scope = fresh_name "s%d" in + let r' = + let ctx = + Map.singleton + (module Name) + n + (P.name @@ Name.scoped scope @@ Name.create key_name) + in + subst_no_subquery ctx r + in + if Set.mem (A.names r') n then None + else + return + @@ A.select Schema.(schema r' |> to_select_list) + @@ A.hash_idx keys scope r' [ Name n ] + + let exists_correlated_subquery r n = + let zero = false and plus = ( || ) in + let rec annot in_subquery r = + V.Reduce.annot zero plus (query in_subquery) meta r + and meta _ = zero + and pred in_subquery = function + | Name n' -> in_subquery && [%compare.equal: Name.t] n n' + | Exists r | First r -> annot true r + | p -> V.Reduce.pred zero plus (annot in_subquery) (pred in_subquery) p + and query in_subquery q = + V.Reduce.query zero plus (annot in_subquery) (pred in_subquery) q + in + annot false r + + (** Try to partition a layout on values of an attribute. *) + let partition_on r n = + let open Option.Let_syntax in + let preds = relevant_conjuncts r n in + let key_range = to_range n preds in + let%bind fields = + List.map preds ~f:(fun p -> Set.remove (Pred.names p) n) + |> List.reduce ~f:Set.union |> Option.map ~f:Set.to_list + in + let fields = + let m = Tactics_util.alias_map r in + List.map fields ~f:(fun n -> + (* If n is an alias for a field in a base relation, find the name of + that field. *) + match Map.find m n with + | Some n' -> (n', Some n) + | None -> (Name n, None)) + |> Map.of_alist_multi (module Pred) + |> Map.to_alist + in + + let fail msg = + Logs.debug (fun m -> + m "Partition: %s %a" msg (Fmt.Dump.list Pred.pp) + (List.map ~f:Tuple.T2.get1 fields)) + in + let%bind f, aliases, l, h = + match (fields, key_range) with + | [ (f, aliases) ], (Some l, Some h) -> return (f, aliases, l, h) + | _, (None, _ | _, None) -> + fail "Could not find bounds for fields"; + None + | _ -> + fail "Found too many fields"; + None + in + (* Don't try to partition if there's a subquery that refers to the partition attribute. *) + if + Pred.names f |> Set.to_sequence + |> Sequence.exists ~f:(exists_correlated_subquery r) + then None + else partition_with_bounds f aliases l h r n + + let partition _ r = + Set.to_sequence params |> Seq.filter_map ~f:(partition_on r) + + let partition = Branching.global partition ~name:"partition" + let db_relation n = A.relation (Db.relation C.conn n) + + let partition_eq n = + let eq_preds r = + let visitor = + object + inherit [_] V.reduce + inherit [_] Util.list_monoid + + method! visit_Binop ps op arg1 arg2 = + if [%compare.equal: Pred.Binop.t] op Eq then (arg1, arg2) :: ps + else ps + end + in + visitor#visit_t [] r + in + + let replace_pred r p1 p2 = + let visitor = + object + inherit [_] V.endo as super + + method! visit_pred () p = + let p = super#visit_pred () p in + if [%compare.equal: Pred.t] p p1 then p2 else p + end + in + visitor#visit_t () r + in + + let open A in + let name = A.name_of_string_exn n in + let fresh_name = + Caml.Format.sprintf "%s_%s" (Name.to_var name) + (Fresh.name Global.fresh "%d") + in + let rel = Name.rel_exn name in + Branching.local ~name:"partition-eq" (fun r -> + let eqs = eq_preds r in + (* Any predicate that is compared for equality with the partition + field is a candidate for the hash table key. *) + let keys = + List.filter_map eqs ~f:(fun (p1, p2) -> + match (p1, p2) with + | Name n, _ when String.(Name.name n = Name.name name) -> Some p2 + | _, Name n when String.(Name.name n = Name.name name) -> Some p1 + | _ -> None) + in + let lhs = + dedup + (select + [ As_pred (Name (Name.copy ~scope:None name), fresh_name) ] + (db_relation rel)) + in + let scope = Fresh.name Global.fresh "s%d" in + let pred = + Binop + ( Eq, + Name (Name.copy ~scope:None name), + Name (Name.create fresh_name) ) + |> Pred.scoped (Schema.schema lhs) scope + in + let filtered_rel = filter pred (db_relation rel) in + List.map keys ~f:(fun k -> + (* The predicate that we chose as the key can be replaced by + `fresh_name`. *) + A.select Schema.(schema r |> to_select_list) + @@ hash_idx lhs scope + (replace_pred + (Tactics_util.replace_rel rel filtered_rel r) + k + (Name (Name.create ~scope fresh_name))) + [ k ]) + |> Seq.of_list) + + let partition_domain n_subst n_domain = + let open A in + let n_subst, n_domain = + (name_of_string_exn n_subst, name_of_string_exn n_domain) + in + let rel = Name.rel_exn n_domain in + let lhs = + dedup (select [ Name (Name.unscoped n_domain) ] (db_relation rel)) + in + let scope = Fresh.name Global.fresh "s%d" in + of_func ~name:"partition-domain" @@ fun r -> + let inner_select = + let lhs_schema = Schema.schema lhs in + Schema.schema r + |> List.filter ~f:(fun n -> + not (List.mem lhs_schema ~equal:[%compare.equal: Name.t] n)) + |> Schema.to_select_list + in + Option.return + @@ A.select Schema.(schema r |> to_select_list) + @@ hash_idx lhs scope + (select inner_select + (subst + (Map.singleton + (module Name) + n_subst + (Name (Name.scoped scope n_domain))) + r)) + [ Name n_subst ] + + (** Hoist subqueries out of the filter predicate and make them available by a + depjoin. Allows expensive subqueries to be computed once instead of many + times. *) + let elim_subquery p r = + let open Option.Let_syntax in + let stage = Is_serializable.stage ~params r in + let can_hoist r = + Free.free r + |> Set.for_all ~f:(fun n -> + Set.mem params n + || match stage n with `Compile -> true | _ -> false) + in + let scope = fresh_name "s%d" in + let visitor = + object + inherit Tactics_util.extract_subquery_visitor + method can_hoist = can_hoist + + method fresh_name () = + Name.create ~scope @@ Fresh.name Global.fresh "q%d" + end + in + let rhs, subqueries = visitor#visit_t () @@ Path.get_exn p r in + let subqueries = + List.map subqueries ~f:(fun (n, p) -> + A.select [ As_pred (p, Name.name n) ] + @@ A.scalar (As_pred (Int 0, "dummy"))) + in + let%map lhs = + match subqueries with + | [] -> None + | [ x ] -> Some x + | _ -> Some (A.tuple subqueries Cross) + in + Path.set_exn p r (A.dep_join lhs scope rhs) + + let elim_subquery = global elim_subquery "elim-subquery" + + (** Hoist subqueries out of the filter predicate and make them available by a + join. *) + let elim_subquery_join p r = + let open Option.Let_syntax in + let stage = Is_serializable.stage ~params r in + let can_hoist r = + Free.free r + |> Set.for_all ~f:(fun n -> + Set.mem params n + || match stage n with `Compile -> true | _ -> false) + in + let visitor = + object + inherit Tactics_util.extract_subquery_visitor + method can_hoist = can_hoist + method fresh_name () = Name.create @@ Fresh.name Global.fresh "q%d" + end + in + + let%bind pred, r' = to_filter @@ Path.get_exn p r in + let pred', subqueries = visitor#visit_pred () pred in + let subqueries = + List.filter_map subqueries ~f:(function + | n, First r -> ( + match Schema.schema r with + | [ n' ] -> return @@ A.select [ As_pred (Name n', Name.name n) ] r + | _ -> None) + | _ -> None) + in + let%map rhs = + match subqueries with + | [] -> None + | [ x ] -> Some x + | _ -> Some (A.tuple subqueries Cross) + in + Path.set_exn p r (A.join pred' r' rhs) + + let elim_subquery_join = global elim_subquery_join "elim-subquery-join" + + let elim_correlated_first_subquery r = + let open Option.Let_syntax in + let visitor = + object + inherit Tactics_util.extract_subquery_visitor + val mutable is_first = true + + method can_hoist _ = + if is_first then ( + is_first <- false; + true) + else false + + method fresh_name () = Name.create @@ Fresh.name Global.fresh "q%d" + end + in + let%bind p, r' = to_filter r in + let p', subqueries = visitor#visit_pred () p in + let%bind subquery_name, subquery = + match subqueries with + | [ s ] -> return s + | [] -> None + | _ -> failwith "expected one subquery" + in + let scope = Fresh.name Global.fresh "s%d" in + let schema = Schema.schema r' in + let ctx = + List.map schema ~f:(fun n -> (n, Name (Name.scoped scope n))) + |> Map.of_alist_exn (module Name) + in + let schema = Schema.scoped scope schema in + match subquery with + | Exists _ -> None + | First r -> + return @@ A.dep_join r' scope + @@ A.select (Schema.to_select_list schema) + @@ A.subst ctx @@ A.filter p' + @@ A.select + [ + P.as_ + (P.name @@ List.hd_exn @@ Schema.schema r) + (Name.name subquery_name); + ] + r + | _ -> failwith "not a subquery" + + let elim_correlated_first_subquery = + of_func elim_correlated_first_subquery + ~name:"elim-correlated-first-subquery" + + let elim_correlated_exists_subquery r = + let open Option.Let_syntax in + let%bind p, r' = to_filter r in + let p, subqueries = + Pred.conjuncts p + |> List.partition_map ~f:(function Exists r -> Second r | p -> First p) + in + let%bind p, subquery = + match subqueries with + | [] -> None + | s :: ss -> Some (Pred.conjoin (p @ List.map ss ~f:P.exists), s) + in + let scope = Fresh.name Global.fresh "s%d" in + let schema = Schema.schema r' in + let ctx = + List.map schema ~f:(fun n -> (n, Name (Name.scoped scope n))) + |> Map.of_alist_exn (module Name) + in + let schema = Schema.scoped scope schema in + let unscoped_schema = Schema.unscoped schema in + return @@ A.dep_join r' scope @@ A.dedup @@ A.filter p + @@ A.group_by (Schema.to_select_list unscoped_schema) unscoped_schema + @@ A.select (Schema.to_select_list schema) + @@ A.subst ctx subquery + + let elim_correlated_exists_subquery = + of_func elim_correlated_exists_subquery + ~name:"elim-correlated-exists-subquery" + + let elim_all_correlated_subqueries = + for_all + (first_success + [ elim_correlated_first_subquery; elim_correlated_exists_subquery ]) + Path.(all >>? is_filter >>? is_run_time) + + let unnest = + global + (fun _ r -> Some (Unnest.unnest ~params r |> Ast.strip_meta)) + "unnest" + + let simplify_filter r = + let open Option.Let_syntax in + match r.node with + | Filter (p, r) -> return @@ A.filter (Pred.simplify p) r + | Join { pred; r1; r2 } -> return @@ A.join (Pred.simplify pred) r1 r2 + | _ -> None + + let simplify_filter = of_func ~name:"simplify-filter" simplify_filter + + let precompute_filter_bv args = + let open A in + let values = List.map args ~f:Pred.of_string_exn in + let exception Failed of Error.t in + let run_exn r = + match r.node with + | Filter (p, r') -> + let schema = Schema.schema r' in + let free_vars = + Set.diff (Free.pred_free p) (Set.of_list (module Name) schema) + |> Set.to_list + in + let free_var = + match free_vars with + | [ v ] -> v + | _ -> + let err = + Error.of_string + "Unexpected number of free variables in predicate." + in + raise (Failed err) + in + let witness_name = Fresh.name Global.fresh "wit%d_" in + let witnesses = + List.mapi values ~f:(fun i v -> + As_pred + ( Pred.subst (Map.singleton (module Name) free_var v) p, + sprintf "%s_%d" witness_name i )) + in + let filter_pred = + List.foldi values ~init:p ~f:(fun i else_ v -> + If + ( Binop (Eq, Name free_var, v), + Name (Name.create (sprintf "%s_%d" witness_name i)), + else_ )) + in + let select_list = witnesses @ List.map schema ~f:(fun n -> Name n) in + Some (filter filter_pred (select select_list r')) + | _ -> None + in + of_func ~name:"precompute-filter-bv" @@ fun r -> + try run_exn r with Failed _ -> None + + (** Given a restricted parameter range, precompute a filter that depends on a + single table field. If the parameter is outside the range, then run the + original filter. Otherwise, check the precomputed evidence. *) + let precompute_filter field values = + let open A in + let field, values = + (A.name_of_string_exn field, List.map values ~f:Pred.of_string_exn) + in + let exception Failed of Error.t in + let run_exn r = + match r.node with + | Filter (p, r') -> + let schema = Schema.schema r' in + let free_vars = + Set.diff (Free.pred_free p) (Set.of_list (module Name) schema) + |> Set.to_list + in + let free_var = + match free_vars with + | [ v ] -> v + | _ -> + let err = + Error.of_string + "Unexpected number of free variables in predicate." + in + raise (Failed err) + in + let encoder = + List.foldi values ~init:(Int 0) ~f:(fun i else_ v -> + let witness = + Pred.subst (Map.singleton (module Name) free_var v) p + in + If (witness, Int (i + 1), else_)) + in + let decoder = + List.foldi values ~init:(Int 0) ~f:(fun i else_ v -> + If (Binop (Eq, Name free_var, v), Int (i + 1), else_)) + in + let fresh_name = Fresh.name Global.fresh "p%d_" ^ Name.name field in + let select_list = + As_pred (encoder, fresh_name) + :: List.map schema ~f:(fun n -> Name n) + in + Option.return + @@ filter + (If + ( Binop (Eq, decoder, Int 0), + p, + Binop (Eq, decoder, Name (Name.create fresh_name)) )) + @@ select select_list r' + | _ -> None + in + let f r = try run_exn r with Failed _ -> None in + of_func ~name:"precompute-filter" f + + let cse_filter r = + let open Option.Let_syntax in + let%bind p, r = to_filter r in + let p', binds = Pred.cse p in + if List.is_empty binds then None + else + return @@ A.filter p' + @@ A.select + (List.map binds ~f:(fun (n, p) -> P.as_ p (Name.name n)) + @ Schema.(schema r |> to_select_list)) + @@ r + + let cse_filter = of_func ~name:"cse-filter" cse_filter + + (* let precompute_filter n = + * let exception Failed of Error.t in + * let run_exn r = + * M.annotate_schema r; + * match r.node with + * | Filter (p, r') -> + * let schema = Meta.(find_exn r' schema) in + * let free_vars = + * Set.diff (pred_free p) (Set.of_list (module Name) schema) + * |> Set.to_list + * in + * let free_var = + * match free_vars with + * | [ v ] -> v + * | _ -> + * let err = + * Error.of_string + * "Unexpected number of free variables in predicate." + * in + * raise (Failed err) + * in + * let witness_name = Fresh.name fresh "wit%d_" in + * let witnesses = + * List.mapi values ~f:(fun i v -> + * As_pred + * ( subst_pred (Map.singleton (module Name) free_var v) p, + * sprintf "%s_%d" witness_name i )) + * in + * let filter_pred = + * List.foldi values ~init:p ~f:(fun i else_ v -> + * If + * ( Binop (Eq, Name free_var, v), + * Name (Name.create (sprintf "%s_%d" witness_name i)), + * else_ )) + * in + * let select_list = witnesses @ List.map schema ~f:(fun n -> Name n) in + * Some (filter filter_pred (select select_list r')) + * | _ -> None + * in + * let f r = try run_exn r with Failed _ -> None in + * of_func f ~name:"precompute-filter" *) +end diff --git a/lib/filter_tactics_test.ml b/lib/filter_tactics_test.ml new file mode 100644 index 0000000..f6ef282 --- /dev/null +++ b/lib/filter_tactics_test.ml @@ -0,0 +1,578 @@ +open Abslayout +open Abslayout_load +open Castor_test.Test_util + +module Test_db = struct + module C = struct + let params = + Set.of_list + (module Name) + [ + Name.create ~type_:Prim_type.int_t "param"; + Name.create ~type_:Prim_type.string_t "param1"; + ] + + let conn = Lazy.force test_db_conn + let cost_conn = Lazy.force test_db_conn + end + + module C_tpch = struct + let conn = Lazy.force tpch_conn + let cost_conn = Lazy.force tpch_conn + end + + open Filter_tactics.Make (C) + open Ops.Make (C) + + let load_string ?params s = load_string_exn ?params C.conn s + + let%expect_test "push-filter-comptime" = + let r = + load_string + "alist(r as r1, filter(r1.f = f, alist(r as r2, ascalar(r2.f))))" + in + Option.iter + (apply + (at_ push_filter Path.(all >>? is_filter >>| shallowest)) + Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect + {| alist(r as r1, alist(filter((r1.f = f), r) as r2, ascalar(r2.f))) |}] + + let%expect_test "push-filter-runtime" = + let r = + load_string + "depjoin(r as r1, filter(r1.f = f, alist(r as r2, ascalar(r2.f))))" + in + Option.iter + (apply + (at_ push_filter Path.(all >>? is_filter >>| shallowest)) + Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect + {| depjoin(r as r1, alist(r as r2, filter((r1.f = f), ascalar(r2.f)))) |}] + + let%expect_test "push-filter-support" = + let r = + load_string ~params:C.params + "filter(f > param, ahashidx(select([f], r) as k, ascalar(0 as x), 0))" + in + Option.iter + (apply + (at_ push_filter Path.(all >>? is_filter >>| shallowest)) + Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect + {| + ahashidx(select([f], r) as k, filter((k.f > param), ascalar(0 as x)), 0) |}] + + let%expect_test "push-filter-support" = + let r = + load_string + {| +alist(filter((0 = g), + depjoin(ascalar(0 as f) as k, + select([k.f, g], ascalar(0 as g)))) as k1, ascalar(0 as x)) + +|} + in + Option.iter + (apply + (at_ push_filter Path.(all >>? is_filter >>| shallowest)) + Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect + {| + alist(depjoin(ascalar(0 as f) as k, + filter((0 = g), select([k.f, g], ascalar(0 as g)))) as k1, + ascalar(0 as x)) |}] + + let%expect_test "push-filter-select" = + let r = + load_string "filter(test > 0, select([x as test], ascalar(0 as x)))" + in + Option.iter (apply push_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect {| select([x as test], filter((x > 0), ascalar(0 as x))) |}] + + let%expect_test "push-filter-select" = + let r = + load_string + "filter(a = b, select([(x - 1) as a, (x + 1) as b], ascalar(0 as x)))" + in + Option.iter (apply push_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + select([(x - 1) as a, (x + 1) as b], + filter(((x - 1) = (x + 1)), ascalar(0 as x))) |}] + + let%expect_test "elim-eq-filter" = + let r = + load_string ~params:C.params "depjoin(r as k, filter(k.f = param, r))" + in + Option.iter + (apply + (at_ elim_eq_filter Path.(all >>? is_filter >>| shallowest)) + Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect {| |}] + + let%expect_test "partition" = + let r = load_string ~params:C.params "filter(f = param, r)" in + Sequence.iter + (Branching.apply partition Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect + {| + select([f, g], + ahashidx(depjoin(select([min(f) as lo, max(f) as hi], + select([f], select([f], dedup(select([f], r))))) as k1, + select([range as k0], range(k1.lo, k1.hi))) as s0, + filter((f = s0.k0), r), + param)) + |}] + + let%expect_test "elim-eq-filter" = + let r = + load_string ~params:C.params + "filter(fresh = param, select([f as fresh], r))" + in + Option.iter (apply elim_eq_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + select([fresh], + ahashidx(dedup( + atuple([select([fresh as x0], + dedup(select([fresh], select([f as fresh], r))))], + cross)) as s0, + filter((fresh = s0.x0), select([f as fresh], r)), + param)) |}] + + let%expect_test "elim-eq-filter-approx" = + let r = + load_string ~params:C.params + "filter(fresh = param, select([f as fresh], filter(g = param, r)))" + in + Option.iter (apply elim_eq_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + select([fresh], + ahashidx(dedup( + atuple([select([fresh as x0], + select([f as fresh], dedup(select([f], r))))], + cross)) as s0, + filter((fresh = s0.x0), select([f as fresh], filter((g = param), r))), + param)) |}] + + let%expect_test "elim-eq-filter" = + let r = + load_string ~params:C.params + "filter((fresh = param) && true, select([f as fresh, g], r))" + in + Option.iter (apply elim_eq_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + select([fresh, g], + filter(true, + ahashidx(dedup( + atuple([select([fresh as x0], + dedup(select([fresh], select([f as fresh, g], r))))], + cross)) as s0, + filter((fresh = s0.x0), select([f as fresh, g], r)), + param))) |}] + + let%expect_test "elim-eq-filter" = + let r = + load_string ~params:C.params + "filter((fresh1 = param && fresh2 = (param +1)) || (fresh2 = param && \ + fresh1 = (param +1)), select([f as fresh1, g as fresh2], r))" + in + Option.iter (apply elim_eq_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + select([fresh1, fresh2], + ahashidx(dedup( + atuple([dedup( + atuple([select([x0 as x2], + select([fresh1 as x0], + dedup( + select([fresh1], + select([f as fresh1, g as fresh2], r))))), + select([x1 as x2], + select([fresh2 as x1], + dedup( + select([fresh2], + select([f as fresh1, g as fresh2], r)))))], + concat)), + dedup( + atuple([select([x3 as x5], + select([fresh2 as x3], + dedup( + select([fresh2], + select([f as fresh1, g as fresh2], r))))), + select([x4 as x5], + select([fresh1 as x4], + dedup( + select([fresh1], + select([f as fresh1, g as fresh2], r)))))], + concat))], + cross)) as s0, + filter((((fresh1 = s0.x2) && (fresh2 = s0.x5)) || + ((fresh2 = s0.x2) && (fresh1 = s0.x5))), + select([f as fresh1, g as fresh2], r)), + (param, (param + 1)))) |}] +end + +module Tpch = struct + module C = struct + let params = Set.empty (module Name) + let conn = Lazy.force tpch_conn + let cost_conn = Lazy.force tpch_conn + end + + open Filter_tactics.Make (C) + open Simplify_tactic.Make (C) + open Ops.Make (C) + + let load_string ?params s = load_string_exn ?params C.conn s + + let%expect_test "push-filter-tuple" = + let r = + Abslayout_load.load_string_exn (Lazy.force tpch_conn) + {| +filter((strpos(p_name, "") > 0), + atuple([ascalar(0 as x), + alist(select([s_suppkey], + select([s_suppkey, s_nationkey], supplier)) as s5, + alist(select([l_partkey, l_suppkey, l_quantity, + l_extendedprice, l_discount, o_orderdate, + p_name], + filter(((s5.s_suppkey = l_suppkey)), + select([l_partkey, l_suppkey, l_quantity, + l_extendedprice, l_discount, + o_orderdate, p_name], + join((p_partkey = l_partkey), + join((o_orderkey = l_orderkey), + lineitem, + orders), + part)))) as s4, + atuple([ascalar(s4.l_quantity), + ascalar(s4.l_extendedprice), + ascalar(s4.l_discount), + ascalar(s4.o_orderdate), ascalar(s4.p_name), + alist(select([ps_supplycost], + filter(((ps_partkey = s4.l_partkey) + && + ((ps_suppkey = s4.l_suppkey) + && + (ps_suppkey = s4.l_suppkey))), + partsupp)) as s3, + ascalar(s3.ps_supplycost))], + cross)))], + cross)) +|} + in + Option.iter (apply push_filter Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + atuple([ascalar(0 as x), + filter((strpos(p_name, "") > 0), + alist(select([s_suppkey], supplier) as s5, + alist(select([l_partkey, l_suppkey, l_quantity, l_extendedprice, + l_discount, o_orderdate, p_name], + filter((s5.s_suppkey = l_suppkey), + select([l_partkey, l_suppkey, l_quantity, + l_extendedprice, l_discount, o_orderdate, + p_name], + join((p_partkey = l_partkey), + join((o_orderkey = l_orderkey), lineitem, orders), + part)))) as s4, + atuple([ascalar(s4.l_quantity), ascalar(s4.l_extendedprice), + ascalar(s4.l_discount), ascalar(s4.o_orderdate), + ascalar(s4.p_name), + alist(select([ps_supplycost], + filter(((ps_partkey = s4.l_partkey) && + ((ps_suppkey = s4.l_suppkey) && + (ps_suppkey = s4.l_suppkey))), + partsupp)) as s3, + ascalar(s3.ps_supplycost))], + cross))))], + cross) |}] + + let with_log src f = + Log.setup_log Error; + Logs.Src.set_level src (Some Debug); + Exn.protect ~f ~finally:(fun () -> Logs.Src.set_level src None) + + let%expect_test "elim-subquery" = + let r = + Abslayout_load.load_string_exn (Lazy.force tpch_conn) + {| +filter((0 = + (select([max(total_revenue_i) as tot], + alist(dedup(select([l_suppkey], lineitem)) as k1, + select([sum(agg2) as total_revenue_i], + aorderedidx(dedup(select([l_shipdate], lineitem)) as s41, + filter((count2 > 0), + select([count() as count2, + sum((l_extendedprice * (1 - l_discount))) as agg2], + alist(select([l_extendedprice, l_discount], + filter(((l_suppkey = k1.l_suppkey) && + (l_shipdate = s41.l_shipdate)), + lineitem)) as s42, + atuple([ascalar(s42.l_extendedprice), + ascalar(s42.l_discount)], + cross)))), + >= date("0000-01-01"), < (date("0000-01-01") + month(3)))))))), + ascalar(0 as x)) +|} + in + Option.iter (apply elim_subquery Path.root r) ~f:(Format.printf "%a\n" pp); + [%expect + {| + depjoin(select([(select([max(total_revenue_i) as tot], + alist(dedup(select([l_suppkey], lineitem)) as k1, + select([sum(agg2) as total_revenue_i], + aorderedidx(dedup(select([l_shipdate], lineitem)) as s41, + filter((count2 > 0), + select([count() as count2, + sum((l_extendedprice * (1 - l_discount))) as agg2], + alist(select([l_extendedprice, l_discount], + filter(((l_suppkey = k1.l_suppkey) && + (l_shipdate = s41.l_shipdate)), + lineitem)) as s42, + atuple([ascalar(s42.l_extendedprice), + ascalar(s42.l_discount)], + cross)))), + >= date("0000-01-01"), < (date("0000-01-01") + + month(3))))))) as q0], + ascalar(0 as dummy)) as s0, + filter((0 = s0.q0), ascalar(0 as x))) + |}] + + let%expect_test "hoist-filter" = + let r = + Abslayout_load.load_string_exn (Lazy.force tpch_conn) + {| +aorderedidx(select([l_shipdate, o_orderdate], + join(true, dedup(select([l_shipdate], lineitem)), dedup(select([o_orderdate], orders)))) as s68, + filter((c_mktsegment = ""), + alist(select([c_custkey, c_mktsegment], customer) as s65, + atuple([ascalar(s65.c_mktsegment), + alist(select([l_orderkey, l_discount, l_extendedprice, o_shippriority, o_orderdate], + join(true, + lineitem, + orders)) as s64, + atuple([ascalar(s64.l_orderkey), ascalar(s64.l_discount), + ascalar(s64.l_extendedprice), ascalar(s64.o_shippriority), + ascalar(s64.o_orderdate)], + cross))], + cross))), + > date("0000-01-01"), , , < date("0000-01-01")) +|} + in + Option.iter (apply hoist_filter Path.root r) ~f:(Fmt.pr "%a" pp); + [%expect + {| + filter((c_mktsegment = ""), + aorderedidx(select([l_shipdate, o_orderdate], + join(true, + dedup(select([l_shipdate], lineitem)), + dedup(select([o_orderdate], orders)))) as s68, + alist(select([c_custkey, c_mktsegment], customer) as s65, + atuple([ascalar(s65.c_mktsegment), + alist(select([l_orderkey, l_discount, l_extendedprice, + o_shippriority, o_orderdate], + join(true, lineitem, orders)) as s64, + atuple([ascalar(s64.l_orderkey), ascalar(s64.l_discount), + ascalar(s64.l_extendedprice), + ascalar(s64.o_shippriority), ascalar(s64.o_orderdate)], + cross))], + cross)), + > date("0000-01-01"), , , < date("0000-01-01"))) |}] + + let%expect_test "partition" = + let param1 = Name.create ~type_:Prim_type.string_t "param1" in + let params = Set.of_list (module Name) [ param1 ] in + let r = + load_string_exn (Lazy.force tpch_conn) ~params + {| +orderby([nation, o_year desc], + groupby([nation, o_year, sum(amount) as sum_profit], + [nation, o_year], + select([n_name as nation, to_year(o_orderdate) as o_year, + ((l_extendedprice * (1 - l_discount)) - (ps_supplycost * l_quantity)) as amount], + join((s_suppkey = l_suppkey), + join(((ps_suppkey = l_suppkey) && (ps_partkey = l_partkey)), + join((p_partkey = l_partkey), + filter((strpos(p_name, param1) > 0), part), + join((o_orderkey = l_orderkey), orders, lineitem)), + partsupp), + join((s_nationkey = n_nationkey), supplier, nation))))) +|} + in + Option.iter (partition_on r param1) ~f:(Fmt.pr "%a" pp) + + let%expect_test "elim-correlated-subquery" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +filter((ps_availqty > + (select([(0.5 * sum(l_quantity)) as tot], + filter(((l_partkey = ps_partkey) && + ((l_suppkey = ps_suppkey))), + lineitem)))), + partsupp) +|} + in + apply elim_correlated_first_subquery Path.root r + |> Fmt.pr "%a@." Fmt.(option pp); + [%expect + {| + depjoin(partsupp as s0, + select([s0.ps_partkey, s0.ps_suppkey, s0.ps_availqty, s0.ps_supplycost, + s0.ps_comment], + filter((s0.ps_availqty > q0), + select([tot as q0], + select([(0.5 * sum(l_quantity)) as tot], + filter(((l_partkey = s0.ps_partkey) && (l_suppkey = s0.ps_suppkey)), + lineitem)))))) |}] + + let%expect_test "elim-correlated-subquery" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +filter(exists(filter((ps_partkey = p_partkey), + filter((strpos(p_name, "test") = 1), part))), + partsupp) +|} + in + apply elim_correlated_exists_subquery Path.root r + |> Fmt.pr "%a@." Fmt.(option pp); + [%expect + {| + depjoin(partsupp as s0, + dedup( + filter(true, + groupby([ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment], + [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment], + select([s0.ps_partkey, s0.ps_suppkey, s0.ps_availqty, + s0.ps_supplycost, s0.ps_comment], + filter((s0.ps_partkey = p_partkey), + filter((strpos(p_name, "test") = 1), part))))))) |}] + + let%expect_test "partition" = + let s_suppkey = Name.create "s_suppkey" in + let r = + load_string_exn + ~params:(Set.singleton (module Name) s_suppkey) + (Lazy.force tpch_conn) + {| +filter((s_suppkey = ps_suppkey), + filter(exists(filter((ps_partkey = p_partkey), filter((strpos(p_name, "test") = 1), part))), + filter((ps_availqty > + (select([(0.5 * sum(l_quantity)) as tot], + filter(((l_partkey = ps_partkey) && ((l_suppkey = ps_suppkey))), + lineitem)))), + partsupp))) +|} + in + partition_on r s_suppkey |> Option.iter ~f:(Fmt.pr "%a@." pp) + + let%expect_test "elim-correlated-subquery" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +dedup(select([s_name, s_address], + orderby([s_name], + join((s_nationkey = n_nationkey), + filter((n_name = "test"), nation), + filter(exists(filter((s_suppkey = ps_suppkey), + filter(exists(filter((ps_partkey = p_partkey), + filter((strpos(p_name, "test") = 1), + part))), + filter((ps_availqty > + (select([(0.5 * sum(l_quantity)) as tot], + filter(((l_partkey = ps_partkey) && + ((l_suppkey = ps_suppkey))), + lineitem)))), + partsupp)))), + supplier))))) +|} + in + apply + (seq_many [ elim_all_correlated_subqueries; unnest_and_simplify ]) + Path.root r + |> Fmt.pr "%a@." Fmt.(option pp); + [%expect + {| + dedup( + select([s_name, s_address], + orderby([s_name], + join((s_nationkey = n_nationkey), + filter((n_name = "test"), nation), + select([s_suppkey, s_name, s_address, s_nationkey, s_phone, + s_acctbal, s_comment], + groupby([s1_s_acctbal, s1_s_address, s1_s_comment, s1_s_name, + s1_s_nationkey, s1_s_phone, s1_s_suppkey, s_suppkey, + s_name, s_address, s_nationkey, s_phone, s_acctbal, + s_comment], + [s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, + s_comment, s1_s_acctbal, s1_s_address, s1_s_comment, s1_s_name, + s1_s_nationkey, s1_s_phone, s1_s_suppkey], + select([s1_s_acctbal, s1_s_address, s1_s_comment, s1_s_name, + s1_s_nationkey, s1_s_phone, s1_s_suppkey, + s1_s_suppkey as s_suppkey, s1_s_name as s_name, + s1_s_address as s_address, s1_s_nationkey as s_nationkey, + s1_s_phone as s_phone, s1_s_acctbal as s_acctbal, + s1_s_comment as s_comment], + join(((s1_s_suppkey = ps_suppkey) && true), + dedup( + select([s_acctbal as s1_s_acctbal, + s_address as s1_s_address, + s_comment as s1_s_comment, s_name as s1_s_name, + s_nationkey as s1_s_nationkey, + s_phone as s1_s_phone, s_suppkey as s1_s_suppkey], + supplier)), + select([ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, + ps_comment], + groupby([s3_ps_availqty, s3_ps_comment, s3_ps_partkey, + s3_ps_suppkey, s3_ps_supplycost, ps_partkey, + ps_suppkey, ps_availqty, ps_supplycost, ps_comment], + [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, + ps_comment, s3_ps_availqty, s3_ps_comment, + s3_ps_partkey, s3_ps_suppkey, s3_ps_supplycost], + select([s3_ps_availqty, s3_ps_comment, s3_ps_partkey, + s3_ps_suppkey, s3_ps_supplycost, + s3_ps_partkey as ps_partkey, + s3_ps_suppkey as ps_suppkey, + s3_ps_availqty as ps_availqty, + s3_ps_supplycost as ps_supplycost, + s3_ps_comment as ps_comment], + join(((s3_ps_partkey = p_partkey) && true), + dedup( + select([s4_ps_availqty as s3_ps_availqty, + s4_ps_comment as s3_ps_comment, + s4_ps_partkey as s3_ps_partkey, + s4_ps_suppkey as s3_ps_suppkey, + s4_ps_supplycost as s3_ps_supplycost], + filter((s4_ps_availqty > q2), + groupby([s4_ps_availqty, s4_ps_comment, + s4_ps_partkey, s4_ps_suppkey, + s4_ps_supplycost, + (0.5 * sum(l_quantity)) as q2], + [s4_ps_availqty, s4_ps_comment, + s4_ps_partkey, s4_ps_suppkey, + s4_ps_supplycost], + join((((l_partkey = s4_ps_partkey) && + (l_suppkey = s4_ps_suppkey)) && true), + dedup( + select([ps_availqty as s4_ps_availqty, + ps_comment as s4_ps_comment, + ps_partkey as s4_ps_partkey, + ps_suppkey as s4_ps_suppkey, + ps_supplycost as s4_ps_supplycost], + partsupp)), + lineitem))))), + filter((strpos(p_name, "test") = 1), part))))))))))))) +|}] +end diff --git a/lib/groupby_tactics.ml b/lib/groupby_tactics.ml new file mode 100644 index 0000000..16666e2 --- /dev/null +++ b/lib/groupby_tactics.ml @@ -0,0 +1,129 @@ +open Visitors +open Collections +open Match +module A = Abslayout +module P = Pred.Infix +include (val Log.make ~level:(Some Warning) "castor-opt.groupby-tactics") + +module Config = struct + module type S = sig + include Ops.Config.S + include Tactics_util.Config.S + end +end + +module Make (C : Config.S) = struct + open Ops.Make (C) + open Tactics_util.Make (C) + + (** Eliminate a group by operator without representing duplicate key values. *) + let elim_groupby_flat r = + match r.node with + | GroupBy (ps, key, r) -> ( + let key_name = Fresh.name Global.fresh "k%d" in + let key_preds = List.map key ~f:P.name in + let filter_pred = + List.map key ~f:(fun n -> + Pred.Infix.(name n = name (Name.copy n ~scope:(Some key_name)))) + |> Pred.conjoin + in + let keys = A.dedup (A.select key_preds r) in + (* Try to remove any remaining parameters from the keys relation. *) + match over_approx C.params keys with + | Ok keys -> + let scalars, rest = + let schema = Schema.schema r in + List.partition_tf ps ~f:(fun p -> + match Pred.to_name p with + | Some n -> List.mem schema n ~equal:[%compare.equal: Name.t] + | None -> false) + in + let scalars = + List.map scalars ~f:(fun p -> + A.scalar @@ Pred.scoped key key_name p) + in + Option.return @@ A.list keys key_name + @@ A.tuple + (scalars @ [ A.select rest (A.filter filter_pred r) ]) + Cross + | Error err -> + info (fun m -> m "elim-groupby: %a" Error.pp err); + None) + (* Otherwise, if some keys are computed, fail. *) + | _ -> None + + let elim_groupby_flat = of_func elim_groupby_flat ~name:"elim-groupby-flat" + + let elim_groupby_approx r = + let open Option.Let_syntax in + let%bind ps, key, r = to_groupby r in + let key_name = Fresh.name Global.fresh "k%d" in + let key_preds = List.map key ~f:P.name in + let filter_pred = + List.map key ~f:(fun n -> + Pred.Infix.(name n = name (Name.copy n ~scope:(Some key_name)))) + |> Pred.conjoin + in + (* Try to remove any remaining parameters from the keys relation. *) + let%bind keys = + match all_values_approx key_preds r with + | Ok keys -> return @@ A.dedup keys + | Error err -> + (* Otherwise, if some keys are computed, fail. *) + info (fun m -> m "elim-groupby-approx: %a" Error.pp err); + None + in + return @@ A.list keys key_name (A.select ps (A.filter filter_pred r)) + + let elim_groupby_approx = + of_func elim_groupby_approx ~name:"elim-groupby-approx" + + let elim_groupby = elim_groupby_approx + let db_relation n = A.relation (Db.relation C.conn n) + + let elim_groupby_partial = + let open A in + Branching.local ~name:"elim-groupby-partial" (function + | { node = GroupBy (ps, key, r); _ } -> + Seq.of_list key + |> Seq.map ~f:(fun k -> + let key_name = Fresh.name Global.fresh "k%d" in + let scope = Fresh.name Global.fresh "s%d" in + let rel = + (Option.value_exn + (Db.relation_has_field C.conn (Name.name k))) + .r_name + in + let lhs = + dedup + (select [ As_pred (Name k, key_name) ] (db_relation rel)) + in + let new_key = + List.filter key ~f:(fun k' -> Name.O.(k <> k')) + in + let new_ps = + List.filter ps ~f:(fun p -> + not ([%compare.equal: Pred.t] p (Name k))) + |> List.map ~f:(Pred.scoped (Schema.schema lhs) scope) + in + let filter_pred = + Binop (Eq, Name k, Name (Name.create key_name)) + |> Pred.scoped (Schema.schema lhs) scope + in + let new_r = + replace_rel rel (filter filter_pred (db_relation rel)) r + in + let new_group_by = + if List.is_empty new_key then select new_ps new_r + else group_by new_ps new_key new_r + in + let key_scalar = + let p = + Pred.scoped (Schema.schema lhs) scope + (As_pred (Name (Name.create key_name), Name.name k)) + in + scalar p + in + list lhs scope (tuple [ key_scalar; new_group_by ] Cross)) + | _ -> Seq.empty) +end diff --git a/lib/groupby_tactics_test.ml b/lib/groupby_tactics_test.ml new file mode 100644 index 0000000..395f1b8 --- /dev/null +++ b/lib/groupby_tactics_test.ml @@ -0,0 +1,259 @@ +open Abslayout +open Abslayout_load + +module C = struct + let conn = Db.create "postgresql:///tpch_1k" + let cost_conn = conn + let verbose = false + let validate = false + + let params = + let open Prim_type in + Set.of_list + (module Name) + [ + Name.create ~type_:string_t "param1"; + Name.create ~type_:string_t "param2"; + Name.create ~type_:string_t "param3"; + ] + + let param_ctx = Map.empty (module Name) + let fresh = Fresh.create () + let simplify = None +end + +open C +open Groupby_tactics.Make (C) +open Ops.Make (C) +open Tactics_util.Make (C) + +let with_logs src f = + Logs.(set_reporter (format_reporter ())); + Logs.Src.set_level src (Some Debug); + let ret = f () in + Logs.Src.set_level src (Some Error); + Logs.(set_reporter nop_reporter); + ret + +let%expect_test "" = + let r = + load_string_exn ~params conn + {| +groupby([o_year, + (sum((if (nation_name = param1) then volume else 0.0)) / + sum(volume)) as mkt_share], + [o_year], + select([to_year(o_orderdate) as o_year, + (l_extendedprice * (1 - l_discount)) as volume, + n2_name as nation_name], + join((p_partkey = l_partkey), + join((s_suppkey = l_suppkey), + join((l_orderkey = o_orderkey), + join((o_custkey = c_custkey), + join((c_nationkey = n1_nationkey), + join((n1_regionkey = r_regionkey), + select([n_regionkey as n1_regionkey, n_nationkey as n1_nationkey], + nation), + filter((r_name = param2), region)), + customer), + filter(((o_orderdate >= date("1995-01-01")) && + (o_orderdate <= date("1996-12-31"))), + orders)), + lineitem), + join((s_nationkey = n2_nationkey), + select([n_nationkey as n2_nationkey, n_name as n2_name], + nation), + supplier)), + filter((p_type = param3), part)))) +|} + in + ( with_logs Groupby_tactics.src @@ fun () -> + with_logs Join_elim.src @@ fun () -> + apply elim_groupby Path.root r |> Option.iter ~f:(Fmt.pr "%a@." pp) ); + [%expect + {| + alist(dedup( + select([to_year(o_orderdate) as o_year], + dedup(select([o_orderdate], orders)))) as k0, + select([o_year, + (sum((if (nation_name = param1) then volume else 0.0)) / + sum(volume)) as mkt_share], + filter((o_year = k0.o_year), + select([to_year(o_orderdate) as o_year, + (l_extendedprice * (1 - l_discount)) as volume, + n2_name as nation_name], + join((p_partkey = l_partkey), + join((s_suppkey = l_suppkey), + join((l_orderkey = o_orderkey), + join((o_custkey = c_custkey), + join((c_nationkey = n1_nationkey), + join((n1_regionkey = r_regionkey), + select([n_regionkey as n1_regionkey, + n_nationkey as n1_nationkey], + nation), + filter((r_name = param2), region)), + customer), + filter(((o_orderdate >= date("1995-01-01")) && + (o_orderdate <= date("1996-12-31"))), + orders)), + lineitem), + join((s_nationkey = n2_nationkey), + select([n_nationkey as n2_nationkey, n_name as n2_name], + nation), + supplier)), + filter((p_type = param3), part)))))) |}] + +let%expect_test "" = + let params = + let open Prim_type in + Set.of_list + (module Name) + [ + Name.create ~type_:string_t "param0"; Name.create ~type_:date_t "param1"; + ] + in + let module C = struct + let conn = Db.create "postgresql:///tpch_1k" + let cost_conn = conn + let params = params + end in + let open Groupby_tactics.Make (C) in + let open Ops.Make (C) in + let r = + load_string_exn ~params conn + {| +groupby([l_orderkey, sum((l_extendedprice * (1 - l_discount))) as revenue, o_orderdate, o_shippriority], + [l_orderkey, o_orderdate, o_shippriority], + join((c_custkey = o_custkey), + join((l_orderkey = o_orderkey), + filter((o_orderdate < param1), + alist(orders as s1, + atuple([ascalar(s1.o_orderkey), ascalar(s1.o_custkey), ascalar(s1.o_orderstatus), + ascalar(s1.o_totalprice), ascalar(s1.o_orderdate), + ascalar(s1.o_orderpriority), ascalar(s1.o_clerk), + ascalar(s1.o_shippriority), ascalar(s1.o_comment)], + cross))), + filter((l_shipdate > param1), + alist(lineitem as s2, + atuple([ascalar(s2.l_orderkey), ascalar(s2.l_partkey), ascalar(s2.l_suppkey), + ascalar(s2.l_linenumber), ascalar(s2.l_quantity), + ascalar(s2.l_extendedprice), ascalar(s2.l_discount), + ascalar(s2.l_tax), ascalar(s2.l_returnflag), ascalar(s2.l_linestatus), + ascalar(s2.l_shipdate), ascalar(s2.l_commitdate), + ascalar(s2.l_receiptdate), ascalar(s2.l_shipinstruct), + ascalar(s2.l_shipmode), ascalar(s2.l_comment)], + cross)))), + filter((c_mktsegment = param0), + alist(customer as s0, + atuple([ascalar(s0.c_custkey), ascalar(s0.c_name), ascalar(s0.c_address), + ascalar(s0.c_nationkey), ascalar(s0.c_phone), ascalar(s0.c_acctbal), + ascalar(s0.c_mktsegment), ascalar(s0.c_comment)], + cross))))) +|} + in + apply elim_groupby Path.root r |> Option.iter ~f:(Fmt.pr "%a@." pp); + [%expect + {| + alist(dedup( + dedup( + select([l_orderkey, o_orderdate, o_shippriority], + join((c_custkey = o_custkey), + join((l_orderkey = o_orderkey), + alist(orders as s13, + atuple([ascalar(s13.o_orderkey), ascalar(s13.o_custkey), + ascalar(s13.o_orderstatus), + ascalar(s13.o_totalprice), + ascalar(s13.o_orderdate), + ascalar(s13.o_orderpriority), ascalar(s13.o_clerk), + ascalar(s13.o_shippriority), ascalar(s13.o_comment)], + cross)), + alist(lineitem as s14, + atuple([ascalar(s14.l_orderkey), ascalar(s14.l_partkey), + ascalar(s14.l_suppkey), ascalar(s14.l_linenumber), + ascalar(s14.l_quantity), + ascalar(s14.l_extendedprice), + ascalar(s14.l_discount), ascalar(s14.l_tax), + ascalar(s14.l_returnflag), + ascalar(s14.l_linestatus), ascalar(s14.l_shipdate), + ascalar(s14.l_commitdate), + ascalar(s14.l_receiptdate), + ascalar(s14.l_shipinstruct), + ascalar(s14.l_shipmode), ascalar(s14.l_comment)], + cross))), + alist(customer as s12, + atuple([ascalar(s12.c_custkey), ascalar(s12.c_name), + ascalar(s12.c_address), ascalar(s12.c_nationkey), + ascalar(s12.c_phone), ascalar(s12.c_acctbal), + ascalar(s12.c_mktsegment), ascalar(s12.c_comment)], + cross)))))) as k1, + select([l_orderkey, sum((l_extendedprice * (1 - l_discount))) as revenue, + o_orderdate, o_shippriority], + filter(((l_orderkey = k1.l_orderkey) && + ((o_orderdate = k1.o_orderdate) && + (o_shippriority = k1.o_shippriority))), + join((c_custkey = o_custkey), + join((l_orderkey = o_orderkey), + filter((o_orderdate < param1), + alist(orders as s1, + atuple([ascalar(s1.o_orderkey), ascalar(s1.o_custkey), + ascalar(s1.o_orderstatus), ascalar(s1.o_totalprice), + ascalar(s1.o_orderdate), ascalar(s1.o_orderpriority), + ascalar(s1.o_clerk), ascalar(s1.o_shippriority), + ascalar(s1.o_comment)], + cross))), + filter((l_shipdate > param1), + alist(lineitem as s2, + atuple([ascalar(s2.l_orderkey), ascalar(s2.l_partkey), + ascalar(s2.l_suppkey), ascalar(s2.l_linenumber), + ascalar(s2.l_quantity), ascalar(s2.l_extendedprice), + ascalar(s2.l_discount), ascalar(s2.l_tax), + ascalar(s2.l_returnflag), ascalar(s2.l_linestatus), + ascalar(s2.l_shipdate), ascalar(s2.l_commitdate), + ascalar(s2.l_receiptdate), ascalar(s2.l_shipinstruct), + ascalar(s2.l_shipmode), ascalar(s2.l_comment)], + cross)))), + filter((c_mktsegment = param0), + alist(customer as s0, + atuple([ascalar(s0.c_custkey), ascalar(s0.c_name), + ascalar(s0.c_address), ascalar(s0.c_nationkey), + ascalar(s0.c_phone), ascalar(s0.c_acctbal), + ascalar(s0.c_mktsegment), ascalar(s0.c_comment)], + cross))))))) |}] + +let%expect_test "" = + let params = + Set.of_list (module Name) + @@ List.init 7 ~f:(fun i -> Name.create @@ sprintf "param%d" i) + in + let r = + load_string_exn ~params conn + {| +select([substring(c1_phone, 0, 2) as cntrycode], + filter(((substring(c1_phone, 0, 2) = param0) || + ((substring(c1_phone, 0, 2) = param1) || + ((substring(c1_phone, 0, 2) = param2) || + ((substring(c1_phone, 0, 2) = param3) || + ((substring(c1_phone, 0, 2) = param4) || + ((substring(c1_phone, 0, 2) = param5) || + (substring(c1_phone, 0, 2) = param6))))))), + filter((c1_acctbal > + (select([avg(c_acctbal) as avgbal], + filter(((c_acctbal > 0.0) && + ((substring(c_phone, 0, 2) = param0) || + ((substring(c_phone, 0, 2) = param1) || + ((substring(c_phone, 0, 2) = param2) || + ((substring(c_phone, 0, 2) = param3) || + ((substring(c_phone, 0, 2) = param4) || + ((substring(c_phone, 0, 2) = param5) || + (substring(c_phone, 0, 2) = param6)))))))), + customer)))), + filter(not(exists(filter((o_custkey = c1_custkey), orders))), + select([c_phone as c1_phone, c_acctbal as c1_acctbal, c_custkey as c1_custkey], customer))))) +|} + in + let ps = match r.node with Select (ps, _) -> ps | _ -> assert false in + all_values_approx ps r |> Result.iter ~f:(Fmt.pr "%a@." pp); + [%expect + {| + select([substring(c_phone, 0, 2) as cntrycode], + dedup(select([c_phone], customer))) |}] diff --git a/lib/join_elim_tactics.ml b/lib/join_elim_tactics.ml new file mode 100644 index 0000000..9361972 --- /dev/null +++ b/lib/join_elim_tactics.ml @@ -0,0 +1,181 @@ +open! Base +open Ast +open Schema +open Collections +module P = Pred.Infix +module A = Abslayout +open Match + +module Config = struct + module type S = sig + include Ops.Config.S + include Tactics_util.Config.S + include Simplify_tactic.Config.S + end +end + +module Make (C : Config.S) = struct + open C + open Ops.Make (C) + open Simplify_tactic.Make (C) + module Tactics_util = Tactics_util.Make (C) + + let elim_join_nest r = + let open Option.Let_syntax in + let%bind pred, r1, r2 = to_join r in + let scope = Fresh.name Global.fresh "s%d" in + let pred = Pred.scoped (schema r1) scope pred in + let lhs = + let scalars = + schema r1 |> Schema.scoped scope |> Schema.to_select_list + |> List.map ~f:A.scalar + in + A.tuple scalars Cross + and rhs = A.filter pred r2 in + return @@ A.list r1 scope @@ A.tuple [ lhs; rhs ] Cross + + let elim_join_nest = of_func elim_join_nest ~name:"elim-join-nest" + + let elim_join_hash r = + let open Option.Let_syntax in + let%bind pred, r1, r2 = to_join r in + match pred with + | Binop (Eq, kl, kr) -> + let join_scope = Fresh.name Global.fresh "s%d" + and hash_scope = Fresh.name Global.fresh "s%d" + and r1_schema = schema r1 + and r2_schema = schema r2 in + let layout = + let slist = + let r1_schema = Schema.scoped join_scope r1_schema in + r1_schema @ schema r2 |> List.map ~f:P.name + in + A.dep_join r1 join_scope @@ A.select slist + @@ A.hash_idx + (A.dedup @@ A.select [ kr ] r2) + hash_scope + (A.filter + (Binop (Eq, Pred.scoped r2_schema hash_scope kr, kr)) + r2) + [ Pred.scoped r1_schema join_scope kl ] + in + Some layout + | _ -> None + + let elim_join_hash = of_func elim_join_hash ~name:"elim-join-hash" + + let elim_join_filter r = + let open Option.Let_syntax in + let%map pred, r1, r2 = to_join r in + A.filter pred (A.join (Bool true) r1 r2) + + let elim_join_filter = of_func elim_join_filter ~name:"elim-join-filter" + + let hoist_join_param_filter r = + let open Option.Let_syntax in + let%bind pred, r1, r2 = to_join r in + let has_params p = + not (Set.is_empty @@ Set.inter (Free.pred_free p) C.params) + in + let hoist, keep = Pred.conjuncts pred |> List.partition_tf ~f:has_params in + if List.is_empty hoist then None + else + return + @@ A.filter (Pred.conjoin hoist) + @@ A.join (Pred.conjoin keep) r1 r2 + + let hoist_join_param_filter = + of_func hoist_join_param_filter ~name:"hoist-join-param-filter" + + let hoist_join_filter r = + let open Option.Let_syntax in + let%bind pred, r1, r2 = to_join r in + return @@ A.filter pred @@ A.join (Bool true) r1 r2 + + let hoist_join_filter = of_func hoist_join_filter ~name:"hoist-join-filter" + + let push_join_filter r = + let open Option.Let_syntax in + let stage = r.meta#stage in + let r = strip_meta r in + let%bind p, r, r' = to_join r in + let s = Schema.schema r |> Set.of_list (module Name) + and s' = Schema.schema r' |> Set.of_list (module Name) in + let left, right, above = + Pred.conjuncts p + |> List.partition3_map ~f:(fun p -> + if Tactics_util.is_supported stage s p then `Fst p + else if Tactics_util.is_supported stage s' p then `Snd p + else `Trd p) + in + if List.is_empty left && List.is_empty right then None + else + let r = if List.is_empty left then r else A.filter (Pred.conjoin left) r + and r' = + if List.is_empty right then r' else A.filter (Pred.conjoin right) r' + in + return @@ A.join (Pred.conjoin above) r r' + + let push_join_filter = + seq' + (of_func_cond ~name:"push-join-filter" + ~pre:(fun r -> Some (Resolve.resolve_exn ~params r)) + push_join_filter + ~post:(fun r -> Resolve.resolve ~params r |> Result.ok)) + simplify + + let split_out path pk = + let open A in + let pk = A.name_of_string_exn pk in + let eq = [%compare.equal: Name.t] in + let fresh_name = Fresh.name Global.fresh in + + of_func ~name:"split-out" @@ fun r -> + let open Option.Let_syntax in + let%bind path = path r in + let rel = Path.get_exn path r in + let rel_schema = Schema.schema rel in + + let schema = schema r in + if List.mem schema pk ~equal:eq then + let scope = fresh_name "s%d" in + let alias = Name.scoped scope @@ Name.fresh "x%d" in + (* Partition the schema of the original layout between the two new layouts. *) + let fst_sel_list, snd_sel_list = + List.partition_tf schema ~f:(fun n -> + eq n pk + || not + (List.mem rel_schema n ~equal:(fun n n' -> + String.(Name.(name n = name n'))))) + in + let scope2 = fresh_name "s%d" in + Option.return + @@ dep_join + (select (Schema.to_select_list fst_sel_list) r) + scope2 + (hash_idx + (dedup (select [ As_pred (Name pk, Name.name alias) ] rel)) + scope + (select + (List.map fst_sel_list ~f:(fun n -> + Name (Name.scoped scope2 n)) + @ Schema.to_select_list snd_sel_list) + (filter (Binop (Eq, Name pk, Name alias)) rel)) + [ Name (Name.scoped scope2 pk) ]) + else None + + (* let hoist_join_left r = + * let open Option.Let_syntax in + * let%bind p1, r1, r2 = to_join r in + * let%bind p2, r3, r4 = to_join r1 in + * return A.join p2 (A.join p1) + * let has_params p = + * not (Set.is_empty @@ Set.inter (Free.pred_free p) C.params) + * in + * let hoist, keep = Pred.conjuncts pred |> List.partition_tf ~f:has_params in + * if List.is_empty hoist then None + * else + * return + * @@ A.filter (Pred.conjoin hoist) + * @@ A.join (Pred.conjoin keep) r1 r2 *) +end diff --git a/lib/join_opt.ml b/lib/join_opt.ml new file mode 100644 index 0000000..97a834b --- /dev/null +++ b/lib/join_opt.ml @@ -0,0 +1,625 @@ +open Graph +open Printf +open Collections +open Ast +module A = Abslayout +module P = Pred.Infix +module V = Visitors +include (val Log.make ~level:(Some Warning) "castor-opt.join-opt") + +let max_nest_joins = ref 1 + +let param = + let open Command.Let_syntax in + [%map_open + let () = param + and nest_joins = + flag "max-nest-joins" (optional int) + ~doc:"N maximum number of nested layout joins" + in + Option.iter nest_joins ~f:(fun x -> max_nest_joins := x)] + +let filter p r = { node = Filter (p, r); meta = r.meta } +let join pred r1 r2 = { node = Join { pred; r1; r2 }; meta = r1.meta } + +module Join_graph = struct + module Vertex = struct + module T = struct + type t = + (< stage : Name.t -> [ `Compile | `Run | `No_scope ] + ; resolved : Resolve.resolved > + [@ignore] [@opaque]) + annot + [@@deriving compare, hash, sexp] + + let equal = [%compare.equal: t] + end + + include T + include Comparator.Make (T) + end + + module Edge = struct + module T = struct + type t = + (< stage : Name.t -> [ `Compile | `Run | `No_scope ] + ; resolved : Resolve.resolved > + [@ignore] [@opaque]) + annot + pred + [@@deriving compare, sexp] + end + + include T + include Comparator.Make (T) + + let default = Bool true + end + + module G = Persistent.Graph.ConcreteLabeled (Vertex) (Edge) + include G + include Oper.P (G) + include Oper.Choose (G) + module Dfs = Traverse.Dfs (G) + + let source_relation leaves n = + List.find_map leaves ~f:(fun (r, s) -> if Set.mem s n then Some r else None) + |> Result.of_option + ~error: + (Error.create "No source found for name." + (n, List.map leaves ~f:(fun (_, ns) -> ns)) + [%sexp_of: Name.t * Set.M(Name).t list]) + + let to_string g = sprintf "graph (|V|=%d) (|E|=%d)" (nb_vertex g) (nb_edges g) + + let sexp_of_t g = + fold_edges_e (fun e l -> e :: l) g [] + |> [%sexp_of: (Vertex.t * Edge.t * Vertex.t) list] + + let compare g1 g2 = Sexp.compare ([%sexp_of: t] g1) ([%sexp_of: t] g2) + + let add_or_update_edge g ((v1, l, v2) as e) = + try + let _, l', _ = find_edge g v1 v2 in + add_edge_e g (v1, Binop (And, l, l'), v2) + with Caml.Not_found -> add_edge_e g e + + let vertices g = fold_vertex (fun v l -> v :: l) g [] + + let partition g vs = + let g1, g2 = + let f v (lhs, rhs) = + let in_set = Set.mem vs v in + let lhs = if in_set then remove_vertex lhs v else lhs + and rhs = if in_set then rhs else remove_vertex rhs v in + (lhs, rhs) + in + fold_vertex f g (g, g) + in + let es = + let f ((v1, _, v2) as e) es = + let v1_in = Set.mem vs v1 and v2_in = Set.mem vs v2 in + if (v1_in && not v2_in) || ((not v1_in) && v2_in) then e :: es else es + in + fold_edges_e f g [] + in + (g1, g2, es) + + let is_connected g = + let n = nb_vertex g in + let n = Dfs.fold_component (fun _ i -> i - 1) n g (choose_vertex g) in + n = 0 + + let contract join g = + (* if the edge is to be removed (property = true): + * make a union of the two union-sets of start and end node; + * put this set in the map for all nodes in this set *) + let f edge m = + let s_src, j_src = Map.find_exn m (E.src edge) in + let s_dst, j_dst = Map.find_exn m (E.dst edge) in + let s = Set.union s_src s_dst in + let j = join ~label:(G.E.label edge) j_src j_dst in + Set.fold s ~init:m ~f:(fun m vertex -> Map.set m ~key:vertex ~data:(s, j)) + in + (* initialize map with singleton-sets for every node (of itself) *) + let m = + G.fold_vertex + (fun vertex m -> + Map.set m ~key:vertex + ~data:(Set.singleton (module Vertex) vertex, vertex)) + g + (Map.empty (module Vertex)) + in + G.fold_edges_e f g m |> Map.data |> List.hd_exn |> Tuple.T2.get2 + + let to_ralgebra graph = + if nb_vertex graph = 1 then choose_vertex graph + else contract (fun ~label:p j1 j2 -> join p j1 j2) graph + + (** Collect the leaves of the join tree rooted at r. *) + let rec to_leaves r = + match r.node with + | Join { r1; r2; _ } -> Set.union (to_leaves r1) (to_leaves r2) + | _ -> Set.singleton (module Vertex) r + + type graph_filters = { + graph : t; + leaf_filters : Set.M(Edge).t Map.M(Vertex).t; + top_filters : Edge.t list; + } + + (** Convert a join tree to a join graph. + + Returns a graph where the nodes are the queries at the leaves of the join + tree and the edges are join predicates. *) + let rec to_graph leaves r = + let union_filters f1 f2 = + let merger ~key:_ = function + | `Left x | `Right x -> Some x + | `Both (x, y) -> Some (Set.union x y) + in + Map.merge ~f:merger f1 f2 + in + match r.node with + | Join { r1; r2; pred = p } -> + let x1 = to_graph leaves r1 and x2 = to_graph leaves r2 in + let graph = union x1#graph x2#graph + and leaf_filters = union_filters x1#leaf_filters x2#leaf_filters + and top_filters = x1#top_filters @ x2#top_filters in + let x = + object + method graph = graph + method leaf_filters = leaf_filters + method top_filters = top_filters + end + in + + (* Collect the set of relations that this join depends on. *) + List.fold_left (Pred.conjuncts p) ~init:x ~f:(fun acc p -> + let pred_rels = + Pred.names p |> Set.to_list + |> List.map ~f:(source_relation leaves) + |> Or_error.all + in + match pred_rels with + | Ok [ r ] -> + let leaf_filters = + Map.update acc#leaf_filters r ~f:(function + | Some fs -> Set.add fs p + | None -> Set.singleton (module Edge) p) + in + object + method graph = acc#graph + method leaf_filters = leaf_filters + method top_filters = acc#top_filters + end + | Ok [ r1; r2 ] -> + let graph = add_or_update_edge acc#graph (r1, p, r2) in + object + method graph = graph + method leaf_filters = acc#leaf_filters + method top_filters = acc#top_filters + end + | _ -> + object + method graph = acc#graph + method leaf_filters = acc#leaf_filters + method top_filters = p :: acc#top_filters + end) + | _ -> + object + method graph = empty + method leaf_filters = Map.empty (module Vertex) + method top_filters = [] + end + + let of_abslayout r = + debug (fun m -> m "Planning join for %a." A.pp r); + let leaves = + to_leaves r |> Set.to_list + |> List.map ~f:(fun r -> + (r, Schema.schema r |> Set.of_list (module Name))) + in + let x = to_graph leaves r in + (* Put the filters back onto the leaves of the graph. *) + let graph = + map_vertex + (fun r -> + match Map.find x#leaf_filters r with + | Some preds -> filter (Set.to_list preds |> Pred.conjoin) r + | None -> r) + x#graph + in + object + method graph = graph + method top_filters = x#top_filters + end + + let partition_fold ~init ~f graph = + let vertices = vertices graph |> Array.of_list in + let n = Array.length vertices in + let rec loop acc k = + if k >= n then acc + else + let acc = + Combinat.combinations (List.init n ~f:Fun.id) ~k + |> Iter.fold ~init:acc ~f:(fun acc vs -> + let g1, g2, es = + partition graph + (List.init k ~f:(fun i -> vertices.(vs.(i))) + |> Set.of_list (module Vertex)) + in + if is_connected g1 && is_connected g2 then f acc (g1, g2, es) + else acc) + in + loop acc (k + 1) + in + loop init 1 +end + +module Pareto_set = struct + type 'a t = (float array * 'a) list + + let empty = [] + let singleton c v = [ (c, v) ] + + let dominates x y = + assert (Array.length x = Array.length y); + let n = Array.length x in + let rec loop i le lt = + if i = n then le && lt + else loop (i + 1) Float.(le && x.(i) <= y.(i)) Float.(lt || x.(i) < y.(i)) + in + loop 0 true false + + let rec add s c v = + match s with + | [] -> [ (c, v) ] + | (c', v') :: s' -> + if Array.equal Float.( = ) c c' || dominates c' c then s + else if dominates c c' then add s' c v + else (c', v') :: add s' c v + + let min_elt f s = + List.map s ~f:(fun (c, x) -> (f c, x)) + |> List.min_elt ~compare:(fun (c1, _) (c2, _) -> Float.compare c1 c2) + |> Option.map ~f:(fun (_, x) -> x) + + let of_list l = List.fold_left l ~init:[] ~f:(fun s (c, v) -> add s c v) + let length = List.length + let union_all ss = List.concat ss |> of_list +end + +module Config = struct + module type My_S = sig + val cost_conn : Db.t + val params : Set.M(Name).t + val random : Mcmc.Random_choice.t + end + + module type S = sig + include Ops.Config.S + include Simple_tactics.Config.S + include My_S + end +end + +module Make (Config : Config.S) = struct + open Config + open Ops.Make (Config) + module G = Join_graph + + type t = + | Flat of G.Vertex.t + | Id of G.Vertex.t + | Hash of { lkey : G.Edge.t; lhs : t; rkey : G.Edge.t; rhs : t } + | Nest of { lhs : t; rhs : t; pred : G.Edge.t } + [@@deriving sexp_of] + + let rec num_nest = function + | Flat _ | Id _ -> 0 + | Hash { lhs; rhs; _ } -> num_nest lhs + num_nest rhs + | Nest { lhs; rhs; _ } -> 1 + num_nest lhs + num_nest rhs + + let rec to_ralgebra = function + | Flat r | Id r -> r + | Nest { lhs; rhs; pred } -> join pred (to_ralgebra lhs) (to_ralgebra rhs) + | Hash { lkey; rkey; lhs; rhs } -> + join (Binop (Eq, lkey, rkey)) (to_ralgebra lhs) (to_ralgebra rhs) + + module Cost = struct + let read = function + | Prim_type.(IntT _ | DateT _ | FixedT _ | StringT _) -> 4.0 + | BoolT _ -> 1.0 + | _ -> failwith "Unexpected type." + + let hash = function + | Prim_type.(IntT _ | DateT _ | FixedT _ | BoolT _) -> 40.0 + | StringT _ -> 100.0 + | _ -> failwith "Unexpected type." + + let size = function + | Prim_type.(IntT _ | DateT _ | FixedT _) -> 4.0 + | StringT _ -> 25.0 + | BoolT _ -> 1.0 + | _ -> failwith "Unexpected type." + + (* TODO: Not all lists have 16B headers *) + let list_size = 16.0 + end + + let ntuples r = + let r = to_ralgebra r in + (Explain.explain cost_conn (Sql.of_ralgebra r |> Sql.to_string) + |> Or_error.ok_exn) + .nrows |> Float.of_int + + let estimate_ntuples_parted parts join = + let r = to_ralgebra join in + let s = Schema.schema r |> Set.of_list (module Name) in + + (* Remove parameters from the join nest where possible. *) + let static_r = + let rec annot r = V.Map.annot query r + and query q = V.Map.query annot pred q + and pred p = Pred.to_static ~params p in + annot @@ strip_meta r + in + + (* Generate a group-by using the partition fields. *) + let parts = Set.filter parts ~f:(Set.mem s) in + let parted_r = + let c = P.name (Name.create "c") in + A.( + select + [ + As_pred (Min c, "min"); + As_pred (Max c, "max"); + As_pred (Avg c, "avg"); + ] + @@ group_by [ P.as_ Count "c" ] (Set.to_list parts) static_r) + |> Simplify_tactic.simplify cost_conn + in + let sql = Sql.of_ralgebra parted_r |> Sql.to_string + and schema = Prim_type.[ int_t; int_t; fixed_t ] in + + try + let tups = Db.exec_exn cost_conn schema sql in + + match tups with + | [ Int min; Int max; Fixed avg ] :: _ -> + (min, max, Fixed_point.to_float avg) + | [ Null; Null; Null ] :: _ -> (0, 0, 0.0) + | _ -> + err (fun m -> m "Unexpected tuples: %s" sql); + failwith "Unexpected tuples." + with _ -> (Int.max_value, Int.max_value, Float.max_value) + + let to_parts rhs pred = + let rhs_schema = Schema.schema rhs |> Set.of_list (module Name) in + Pred.names pred |> Set.filter ~f:(Set.mem rhs_schema) + + let rec scan_cost parts r = + let sum = List.sum (module Float) in + match r with + | Flat _ | Id _ -> + let _, _, nt = estimate_ntuples_parted parts r in + sum (Schema_types.types (to_ralgebra r)) ~f:Cost.read *. nt + | Nest { lhs; rhs; pred } -> + let _, _, lhs_nt = estimate_ntuples_parted parts lhs in + let rhs_per_partition_cost = + scan_cost (Set.union (to_parts (to_ralgebra rhs) pred) parts) rhs + in + scan_cost parts lhs +. (lhs_nt *. rhs_per_partition_cost) + | Hash { lkey; lhs; rhs; rkey } -> + let _, _, nt_lhs = estimate_ntuples_parted parts lhs in + let rhs_per_partition_cost = + let pred = Pred.Infix.(lkey = rkey) in + scan_cost (Set.union (to_parts (to_ralgebra rhs) pred) parts) rhs + in + scan_cost parts lhs + +. (nt_lhs *. (Cost.hash (Pred.to_type lkey) +. rhs_per_partition_cost)) + + let rec is_static_join = function + | Id r -> Is_serializable.is_static ~params r + | Flat _ -> true + | Hash { lhs; rhs; _ } | Nest { lhs; rhs; _ } -> + is_static_join lhs && is_static_join rhs + + let leaf_flat r = + let open Option.Let_syntax in + if Is_serializable.is_static ~params r then return @@ Flat r else None + + let leaf_id r = + let open Option.Let_syntax in + return @@ Id r + + let enum_flat_join opt parts pred s1 s2 = + let select_flat s = + List.filter_map (opt parts s) ~f:(fun (_, j) -> + if is_static_join j then Some (to_ralgebra j) else None) + in + List.cartesian_product (select_flat s1) (select_flat s2) + |> List.map ~f:(fun (r1, r2) -> Flat (join pred r1 r2)) + + let enum_hash_join opt parts pred s1 s2 = + let open List.Let_syntax in + let lhs = G.to_ralgebra s1 and rhs = G.to_ralgebra s2 in + let lhs_schema = Schema.schema lhs and rhs_schema = Schema.schema rhs in + (* Figure out which partition a key comes from. *) + let key_side k = + let rs = Pred.names k in + let all_in s = + Set.for_all rs ~f:(List.mem ~equal:[%compare.equal: Name.t] s) + in + if all_in lhs_schema then return (`Lhs (s1, k)) + else if all_in rhs_schema then return (`Rhs (s2, k)) + else ( + debug (fun m -> m "Unknown key %a" Pred.pp k); + []) + in + let%bind k1, k2 = + match pred with + | Binop (Eq, k1, k2) -> return (k1, k2) + | _ -> + debug (fun m -> m "Adding hash join failed."); + [] + in + let%bind s1 = key_side k1 and s2 = key_side k2 in + match (s1, s2) with + | `Lhs (s1, k1), `Rhs (s2, k2) | `Rhs (s2, k2), `Lhs (s1, k1) -> + let rhs_parts = Set.union (to_parts rhs pred) parts in + let lhs_set = opt parts s1 |> List.map ~f:(fun (_, j) -> j) + and rhs_set = + opt rhs_parts s2 + |> List.map ~f:(fun (_, j) -> j) + |> List.filter ~f:is_static_join + in + List.cartesian_product lhs_set rhs_set + |> List.map ~f:(fun (lhs, rhs) -> + Hash { lkey = k1; rkey = k2; lhs; rhs }) + | _ -> + debug (fun m -> + m "Keys come from same partition %a %a" Pred.pp k1 Pred.pp k2); + [] + + let enum_nest_join opt parts pred s1 s2 = + let lhs_parts = Set.union (to_parts (G.to_ralgebra s1) pred) parts + and rhs_parts = Set.union (to_parts (G.to_ralgebra s2) pred) parts in + let lhs_set = + opt lhs_parts s1 + |> List.map ~f:(fun (_, j) -> j) + |> List.filter ~f:is_static_join + and rhs_set = opt rhs_parts s2 |> List.map ~f:(fun (_, j) -> j) in + List.cartesian_product lhs_set rhs_set + |> List.map ~f:(fun (j1, j2) -> Nest { lhs = j1; rhs = j2; pred }) + + let opt_nonrec opt parts s = + info (fun m -> m "Choosing join for space %s." (G.to_string s)); + + let filter_nest_joins = + List.filter ~f:(fun j -> num_nest j <= !max_nest_joins) + in + let add_cost = List.map ~f:(fun j -> ([| scan_cost parts j |], j)) in + + let joins = + if G.nb_vertex s = 1 then + (* Select strategy for the leaves of the join tree. *) + let r = G.choose_vertex s in + [ leaf_flat r; leaf_id r ] + |> List.filter_map ~f:Fun.id |> filter_nest_joins |> add_cost + |> Pareto_set.of_list + else + G.partition_fold s ~init:Pareto_set.empty ~f:(fun cs (s1, s2, es) -> + let pred = Pred.conjoin (List.map es ~f:(fun (_, p, _) -> p)) in + let r = join pred (G.to_ralgebra s1) (G.to_ralgebra s2) in + + let open Mcmc.Random_choice in + let flat_joins = + if rand random "flat-join" (strip_meta r) then + enum_flat_join opt parts pred s1 s2 + else [] + and hash_joins = + if rand random "hash-join" (strip_meta r) then + enum_hash_join opt parts pred s1 s2 + else [] + and nest_joins = + if rand random "nest-join" (strip_meta r) then + enum_nest_join opt parts pred s1 s2 + else [] + in + let all_joins = + flat_joins @ hash_joins @ nest_joins + |> filter_nest_joins |> add_cost + in + + Pareto_set.(union_all [ cs; of_list all_joins ])) + in + info (fun m -> m "Found %d pareto-optimal joins." (Pareto_set.length joins)); + joins + + let opt = + let module Key = struct + type t = Set.M(Name).t * Set.M(G.Vertex).t + [@@deriving compare, hash, sexp_of] + + let create p graph = + let vertices = + G.fold_vertex + (fun v vs -> Set.add vs v) + graph + (Set.empty (module G.Vertex)) + in + (p, vertices) + end in + let tbl = Hashtbl.create (module Key) in + let rec opt p s = + let key = Key.create p s in + match Hashtbl.find tbl key with + | Some v -> v + | None -> + let v = opt_nonrec opt p s in + Hashtbl.add_exn tbl ~key ~data:v; + v + in + opt + + let opt r = + let s = G.of_abslayout r in + let joins = opt (Set.empty (module Name)) s#graph in + object + method joins = joins + method top_filters = s#top_filters + end + + let reshape top_filters j _ = + Some + (A.filter (Pred.conjoin (top_filters :> Pred.t list)) + @@ (to_ralgebra j :> Ast.t)) + + let rec emit_joins = + let open Join_elim_tactics.Make (Config) in + let open Simple_tactics.Make (Config) in + function + | Flat _ -> row_store + | Id _ -> id + | Hash { lhs; rhs; _ } -> + seq_many + [ + at_ (emit_joins lhs) (child 0); + at_ (emit_joins rhs) (child 1); + elim_join_hash; + ] + | Nest { lhs; rhs; _ } -> + seq_many + [ + at_ (emit_joins lhs) (child 0); + at_ (emit_joins rhs) (child 1); + elim_join_nest; + ] + + let transform = + let open Option.Let_syntax in + let f p r = + let%bind r = Resolve.resolve ~params r |> Result.ok in + let r = + Is_serializable.annotate_stage r + |> Visitors.map_meta (fun meta -> + object + method stage = meta#stage + method resolved = meta#meta#resolved + end) + in + let x = opt (Castor.Path.get_exn p r) in + info (fun m -> m "Found %d join options." (Pareto_set.length x#joins)); + let%bind j = Pareto_set.min_elt (fun a -> a.(0)) x#joins in + info (fun m -> m "Chose %a." Sexp.pp_hum ([%sexp_of: t] j)); + let tf = + seq + (local (reshape x#top_filters j) "reshape") + (at_ (emit_joins j) (child 0)) + in + apply (traced tf) p (strip_meta r) + in + global f "join-opt" +end diff --git a/lib/join_opt_test.ml b/lib/join_opt_test.ml new file mode 100644 index 0000000..26982fb --- /dev/null +++ b/lib/join_opt_test.ml @@ -0,0 +1,423 @@ +open Castor_test +open Collections +module A = Abslayout + +let () = Logs.Src.set_level Join_opt.src (Some Warning) + +module Config = struct + let cost_conn = Db.create "postgresql:///tpch_1k" + let conn = cost_conn + let validate = false + let param_ctx = Map.empty (module Name) + let params = Set.empty (module Name) + let verbose = false + let simplify = None + let random = Mcmc.Random_choice.create () +end + +open Ops.Make (Config) +open Join_opt.Make (Config) + +module C = + (val Constructors.Annot.with_default + object + method stage : Name.t -> [ `Compile | `Run | `No_scope ] = + assert false + + method resolved : Resolve.resolved = assert false + end) + +let type_ = Prim_type.IntT { nullable = false } +let c_custkey = Name.create ~type_ "c_custkey" +let c_nationkey = Name.create ~type_ "c_nationkey" +let n_nationkey = Name.create ~type_ "n_nationkey" +let o_custkey = Name.create ~type_ "o_custkey" +let orders = Db.relation Config.cost_conn "orders" +let customer = Db.relation Config.cost_conn "customer" +let nation = Db.relation Config.cost_conn "nation" + +let%expect_test "parted-cost" = + estimate_ntuples_parted (Set.empty (module Name)) (Flat (C.relation orders)) + |> [%sexp_of: int * int * float] |> print_s; + [%expect {| (1000 1000 1000) |}] + +let%expect_test "parted-cost" = + estimate_ntuples_parted + (Set.singleton (module Name) o_custkey) + (Flat (C.relation orders)) + |> [%sexp_of: int * int * float] |> print_s; + [%expect {| (1 2 1.0060362173038229) |}] + +let%expect_test "parted-cost" = + estimate_ntuples_parted + (Set.singleton (module Name) c_custkey) + (Flat (C.relation customer)) + |> [%sexp_of: int * int * float] |> print_s; + [%expect {| (1 1 1) |}] + +let estimate_cost p r = [| scan_cost p r |] + +let%expect_test "cost" = + estimate_cost + (Set.empty (module Name)) + (Flat + C.( + join + (Binop (Eq, Name c_custkey, Name o_custkey)) + (relation orders) (relation customer))) + |> [%sexp_of: float array] |> print_s; + [%expect {| (68000) |}] + +let%expect_test "cost" = + estimate_cost + (Set.empty (module Name)) + C.( + Nest + { + pred = Binop (Eq, Name c_custkey, Name o_custkey); + lhs = Flat (relation customer); + rhs = Flat (relation orders); + }) + |> [%sexp_of: float array] |> print_s; + [%expect {| (67808) |}] + +let%expect_test "cost" = + estimate_cost + (Set.empty (module Name)) + C.( + Flat + (join + (Binop (Eq, Name c_nationkey, Name n_nationkey)) + (relation nation) (relation customer))) + |> [%sexp_of: float array] |> print_s; + estimate_cost + (Set.empty (module Name)) + C.( + Nest + { + pred = Binop (Eq, Name c_nationkey, Name n_nationkey); + lhs = Flat (relation nation); + rhs = Flat (relation customer); + }) + |> [%sexp_of: float array] |> print_s; + estimate_cost + (Set.empty (module Name)) + C.( + Hash + { + lkey = Name c_nationkey; + rkey = Name n_nationkey; + lhs = Flat (relation nation); + rhs = Flat (relation customer); + }) + |> [%sexp_of: float array] |> print_s; + [%expect {| + (47712) + (32208) + (33208) |}] + +let%expect_test "to-from-ralgebra" = + let r = + C.( + join + (Binop (Eq, Name c_nationkey, Name n_nationkey)) + (relation nation) (relation customer)) + in + let s = G.of_abslayout r in + G.to_ralgebra s#graph |> Format.printf "%a" Abslayout.pp; + [%expect {| + join((c_nationkey = n_nationkey), nation, customer) |}] + +let%expect_test "to-from-ralgebra" = + let r = + C.( + join + (Binop (Eq, Name c_custkey, Name o_custkey)) + (relation orders) + (join + (Binop (Eq, Name c_nationkey, Name n_nationkey)) + (relation nation) (relation customer))) + in + let s = G.of_abslayout r in + G.to_ralgebra s#graph |> Format.printf "%a" Abslayout.pp; + [%expect + {| + join((c_custkey = o_custkey), + orders, + join((c_nationkey = n_nationkey), nation, customer)) |}] + +let%expect_test "part-fold" = + let r = + C.( + join + (Binop (Eq, Name c_custkey, Name o_custkey)) + (relation orders) + (join + (Binop (Eq, Name c_nationkey, Name n_nationkey)) + (relation nation) (relation customer))) + in + let s = G.of_abslayout r in + G.partition_fold s#graph ~init:() ~f:(fun () (s1, s2, _) -> + Format.printf "%a@.%a@.---\n" Abslayout.pp (G.to_ralgebra s1) Abslayout.pp + (G.to_ralgebra s2)); + [%expect + {| + join((c_nationkey = n_nationkey), nation, customer) + orders + --- + join((c_custkey = o_custkey), orders, customer) + nation + --- + nation + join((c_custkey = o_custkey), orders, customer) + --- + orders + join((c_nationkey = n_nationkey), nation, customer) + --- |}] + +let opt_test r = (opt r)#joins |> [%sexp_of: (float array * t) list] |> print_s + +let%expect_test "join-opt" = + opt_test + @@ C.join + (Binop (Eq, Name c_nationkey, Name n_nationkey)) + (C.relation nation) (C.relation customer); + [%expect + {| + (((32208) + (Nest + (lhs + (Flat + ((node + (Relation + ((r_name nation) + (r_schema + (((((name n_nationkey) (meta )) (IntT)) + (((name n_name) (meta )) (StringT (padded))) + (((name n_regionkey) (meta )) (IntT)) + (((name n_comment) (meta )) (StringT)))))))) + (meta )))) + (rhs + (Flat + ((node + (Relation + ((r_name customer) + (r_schema + (((((name c_custkey) (meta )) (IntT)) + (((name c_name) (meta )) (StringT)) + (((name c_address) (meta )) (StringT)) + (((name c_nationkey) (meta )) (IntT)) + (((name c_phone) (meta )) (StringT (padded))) + (((name c_acctbal) (meta )) (FixedT)) + (((name c_mktsegment) (meta )) (StringT (padded))) + (((name c_comment) (meta )) (StringT)))))))) + (meta )))) + (pred + (Binop Eq (Name ((name c_nationkey) (meta ))) + (Name ((name n_nationkey) (meta )))))))) |}] + +let%expect_test "join-opt" = + opt_test + @@ C.join (Binop (Eq, Name c_custkey, Name o_custkey)) (C.relation orders) + @@ C.join (Binop (Eq, Name c_nationkey, Name n_nationkey)) (C.relation nation) + @@ C.relation customer; + [%expect + {| + (((69208) + (Hash (lkey (Name ((name n_nationkey) (meta )))) + (lhs + (Flat + ((node + (Relation + ((r_name nation) + (r_schema + (((((name n_nationkey) (meta )) (IntT)) + (((name n_name) (meta )) (StringT (padded))) + (((name n_regionkey) (meta )) (IntT)) + (((name n_comment) (meta )) (StringT)))))))) + (meta )))) + (rkey (Name ((name c_nationkey) (meta )))) + (rhs + (Nest + (lhs + (Flat + ((node + (Relation + ((r_name customer) + (r_schema + (((((name c_custkey) (meta )) (IntT)) + (((name c_name) (meta )) (StringT)) + (((name c_address) (meta )) (StringT)) + (((name c_nationkey) (meta )) (IntT)) + (((name c_phone) (meta )) (StringT (padded))) + (((name c_acctbal) (meta )) (FixedT)) + (((name c_mktsegment) (meta )) (StringT (padded))) + (((name c_comment) (meta )) (StringT)))))))) + (meta )))) + (rhs + (Flat + ((node + (Relation + ((r_name orders) + (r_schema + (((((name o_orderkey) (meta )) (IntT)) + (((name o_custkey) (meta )) (IntT)) + (((name o_orderstatus) (meta )) (StringT (padded))) + (((name o_totalprice) (meta )) (FixedT)) + (((name o_orderdate) (meta )) (DateT)) + (((name o_orderpriority) (meta )) (StringT (padded))) + (((name o_clerk) (meta )) (StringT (padded))) + (((name o_shippriority) (meta )) (IntT)) + (((name o_comment) (meta )) (StringT)))))))) + (meta )))) + (pred + (Binop Eq (Name ((name c_custkey) (meta ))) + (Name ((name o_custkey) (meta )))))))))) |}] + +let%expect_test "" = + let params = + Set.of_list + (module Name) + [ + Name.create ~type_:Prim_type.string_t "k0_n1_name"; + Name.create ~type_:Prim_type.string_t "k0_n2_name"; + Name.create ~type_:Prim_type.date_t "k0_l_year"; + ] + in + let module Config = struct + let cost_conn = Db.create "postgresql:///tpch_1k" + let conn = cost_conn + let params = params + let random = Mcmc.Random_choice.create () + end in + let open Join_opt.Make (Config) in + let open Ops.Make (Config) in + let r = + Abslayout_load.load_string_exn ~params + (Lazy.force Test_util.tpch_conn) + {| +join(((n1_name = k0_n1_name) && + ((n2_name = k0_n2_name) && ((to_year(l_shipdate) = k0_l_year) && (true && (s_suppkey = l_suppkey))))), + join((o_orderkey = l_orderkey), + join((c_custkey = o_custkey), + join((c_nationkey = n2_nationkey), select([n_name as n2_name, n_nationkey as n2_nationkey], nation), customer), + orders), + filter(((l_shipdate >= date("1995-01-01")) && (l_shipdate <= date("1996-12-31"))), lineitem)), + join((s_nationkey = n1_nationkey), select([n_name as n1_name, n_nationkey as n1_nationkey], nation), supplier)) +|} + in + apply transform Path.root r |> Option.iter ~f:(Fmt.pr "%a" A.pp); + [%expect + {| + filter((true && + ((to_year(l_shipdate) = k0_l_year) && + ((n2_name = k0_n2_name) && (n1_name = k0_n1_name)))), + alist(alist(select([n_name as n1_name, n_nationkey as n1_nationkey], + nation) as s0, + atuple([ascalar(s0.n1_name), ascalar(s0.n1_nationkey)], cross)) as s2, + atuple([atuple([ascalar(s2.n1_name), ascalar(s2.n1_nationkey)], cross), + filter((s_nationkey = s2.n1_nationkey), + alist(join((s_suppkey = l_suppkey), + join((c_nationkey = n2_nationkey), + select([n_name as n2_name, + n_nationkey as n2_nationkey], + nation), + join((c_custkey = o_custkey), + join((o_orderkey = l_orderkey), + filter(((l_shipdate >= date("1995-01-01")) && + (l_shipdate <= date("1996-12-31"))), + lineitem), + orders), + customer)), + supplier) as s1, + atuple([ascalar(s1.n2_name), ascalar(s1.n2_nationkey), + ascalar(s1.l_orderkey), ascalar(s1.l_partkey), + ascalar(s1.l_suppkey), ascalar(s1.l_linenumber), + ascalar(s1.l_quantity), ascalar(s1.l_extendedprice), + ascalar(s1.l_discount), ascalar(s1.l_tax), + ascalar(s1.l_returnflag), ascalar(s1.l_linestatus), + ascalar(s1.l_shipdate), ascalar(s1.l_commitdate), + ascalar(s1.l_receiptdate), + ascalar(s1.l_shipinstruct), ascalar(s1.l_shipmode), + ascalar(s1.l_comment), ascalar(s1.o_orderkey), + ascalar(s1.o_custkey), ascalar(s1.o_orderstatus), + ascalar(s1.o_totalprice), ascalar(s1.o_orderdate), + ascalar(s1.o_orderpriority), ascalar(s1.o_clerk), + ascalar(s1.o_shippriority), ascalar(s1.o_comment), + ascalar(s1.c_custkey), ascalar(s1.c_name), + ascalar(s1.c_address), ascalar(s1.c_nationkey), + ascalar(s1.c_phone), ascalar(s1.c_acctbal), + ascalar(s1.c_mktsegment), ascalar(s1.c_comment), + ascalar(s1.s_suppkey), ascalar(s1.s_name), + ascalar(s1.s_address), ascalar(s1.s_nationkey), + ascalar(s1.s_phone), ascalar(s1.s_acctbal), + ascalar(s1.s_comment)], + cross)))], + cross))) |}] + +let%expect_test "" = + let r = + Abslayout_load.load_string_exn + (Lazy.force Test_util.tpch_conn) + {| + join((c_nationkey = n1_nationkey), + join((n1_regionkey = r_regionkey), + select([n_regionkey as n1_regionkey, n_nationkey as n1_nationkey], nation), + filter(true, region)), + customer) +|} + in + apply transform Path.root r |> Option.iter ~f:(Fmt.pr "%a" A.pp); + [%expect + {| + filter(true, + depjoin(alist(filter(true, region) as s3, + atuple([ascalar(s3.r_regionkey), ascalar(s3.r_name), + ascalar(s3.r_comment)], + cross)) as s7, + select([s7.r_regionkey, s7.r_name, s7.r_comment, n1_regionkey, + n1_nationkey, c_custkey, c_name, c_address, c_nationkey, + c_phone, c_acctbal, c_mktsegment, c_comment], + ahashidx(dedup( + select([n1_regionkey], + alist(alist(select([n_regionkey as n1_regionkey, + n_nationkey as n1_nationkey], + nation) as s9, + atuple([ascalar(s9.n1_regionkey), + ascalar(s9.n1_nationkey)], + cross)) as s11, + atuple([atuple([ascalar(s11.n1_regionkey), + ascalar(s11.n1_nationkey)], + cross), + filter((c_nationkey = s11.n1_nationkey), + alist(customer as s10, + atuple([ascalar(s10.c_custkey), + ascalar(s10.c_name), + ascalar(s10.c_address), + ascalar(s10.c_nationkey), + ascalar(s10.c_phone), + ascalar(s10.c_acctbal), + ascalar(s10.c_mktsegment), + ascalar(s10.c_comment)], + cross)))], + cross)))) as s8, + filter((s8.n1_regionkey = n1_regionkey), + alist(alist(select([n_regionkey as n1_regionkey, + n_nationkey as n1_nationkey], + nation) as s4, + atuple([ascalar(s4.n1_regionkey), ascalar(s4.n1_nationkey)], + cross)) as s6, + atuple([atuple([ascalar(s6.n1_regionkey), + ascalar(s6.n1_nationkey)], + cross), + filter((c_nationkey = s6.n1_nationkey), + alist(customer as s5, + atuple([ascalar(s5.c_custkey), ascalar(s5.c_name), + ascalar(s5.c_address), + ascalar(s5.c_nationkey), ascalar(s5.c_phone), + ascalar(s5.c_acctbal), + ascalar(s5.c_mktsegment), + ascalar(s5.c_comment)], + cross)))], + cross))), + s7.r_regionkey)))) |}] diff --git a/lib/list_tactics.ml b/lib/list_tactics.ml new file mode 100644 index 0000000..9df0dca --- /dev/null +++ b/lib/list_tactics.ml @@ -0,0 +1,71 @@ +open Ast +open Collections +module A = Abslayout +module P = Pred.Infix +module V = Visitors +open Match + +module Config = struct + module type S = sig + include Ops.Config.S + + val cost_conn : Db.t + end +end + +module Make (C : Config.S) = struct + open C + open Ops.Make (C) + + let split_list ?(min_factor = 3) r = + let open Option.Let_syntax in + let%bind { l_keys; l_scope; l_values } = to_list r in + let schema = Schema.schema l_keys in + if List.length schema <= 1 then None + else + let%bind counts = + List.map schema ~f:(fun n -> + let%bind result = + A.select [ P.count ] @@ A.dedup @@ A.select [ P.name n ] @@ l_keys + |> Sql.of_ralgebra |> Sql.to_string + |> Db.exec1 cost_conn Prim_type.int_t + |> Or_error.ok + in + match result with Int c :: _ -> return (n, c) | _ -> None) + |> Option.all + in + let counts = + List.sort counts ~compare:(fun (_, c) (_, c') -> [%compare: int] c' c) + in + let%bind split_field = + match counts with + | (n, c) :: (_, c') :: _ -> if c / c' > min_factor then Some n else None + | _ -> None + in + let other_fields = + List.filter schema ~f:(fun n -> + not ([%compare.equal: Name.t] n split_field)) + in + let other_fields_select = Schema.to_select_list other_fields in + let fresh_scope = Fresh.name Global.fresh "s%d" in + return + @@ A.list + (A.dedup @@ A.select [ P.name split_field ] @@ r) + fresh_scope + (A.list + (A.select other_fields_select + @@ A.filter + P.( + name split_field + = name (Name.scoped fresh_scope split_field)) + @@ r) + l_scope + (A.subst + (Map.singleton + (module Name) + (Name.scoped l_scope split_field) + (P.name (Name.scoped fresh_scope split_field))) + l_values)) + + let split_list = of_func split_list ~name:"split-list" +end diff --git a/lib/list_tactics_test.ml b/lib/list_tactics_test.ml new file mode 100644 index 0000000..01a3028 --- /dev/null +++ b/lib/list_tactics_test.ml @@ -0,0 +1,38 @@ +open Abslayout +open Abslayout_load +open Castor_test.Test_util + +module C = struct + let params = Set.empty (module Name) + let conn = Lazy.force tpch_conn + let cost_conn = Lazy.force tpch_conn +end + +open List_tactics.Make (C) +open Ops.Make (C) + +let load_string ?params s = load_string_exn ?params C.conn s + +let%expect_test "" = + let r = + load_string + {| +alist(select([substring(c_phone, 0, 2) as x395, c_acctbal], + filter((c_acctbal > 0.0), customer)) as s34, + atuple([ascalar(s34.x395), ascalar(s34.c_acctbal)], cross)) +|} + in + apply split_list Path.root r |> Option.iter ~f:(Fmt.pr "%a@." pp); + [%expect + {| + alist(dedup( + select([c_acctbal], + alist(select([substring(c_phone, 0, 2) as x395, c_acctbal], + filter((c_acctbal > 0.0), customer)) as s2, + atuple([ascalar(s2.x395), ascalar(s2.c_acctbal)], cross)))) as s0, + alist(select([x395], + filter((c_acctbal = s0.c_acctbal), + alist(select([substring(c_phone, 0, 2) as x395, c_acctbal], + filter((c_acctbal > 0.0), customer)) as s1, + atuple([ascalar(s1.x395), ascalar(s1.c_acctbal)], cross)))) as s34, + atuple([ascalar(s34.x395), ascalar(s0.c_acctbal)], cross))) |}] diff --git a/lib/mcmc.ml b/lib/mcmc.ml new file mode 100644 index 0000000..a95b403 --- /dev/null +++ b/lib/mcmc.ml @@ -0,0 +1,66 @@ +include (val Log.make ~level:(Some Info) "castor-opt.mcmc") + +module Random_choice = struct + module T = struct + type t = { + mutable pairs : ((string * Ast.t) * bool) list; + state : (Random.State.t[@sexp.opaque]); [@compare.ignore] + } + [@@deriving compare, sexp] + end + + include T + + module C = struct + include T + include Comparable.Make (T) + end + + let create ?(seed = 0) () = + { pairs = []; state = Random.State.make [| seed |] } + + let rand rand n r = + match + List.Assoc.find ~equal:[%compare.equal: string * Ast.t] rand.pairs (n, r) + with + | Some v -> v + | None -> + rand.pairs <- ((n, r), true) :: rand.pairs; + true + + let perturb rand = + let i' = Random.State.int rand.state @@ List.length rand.pairs in + { + rand with + pairs = + List.mapi rand.pairs ~f:(fun i (k, v) -> + if i = i' then (k, false) else (k, v)); + } + + let length r = List.length r.pairs +end + +let run ?(max_time = Time.Span.of_min 10.0) eval = + let start_time = Time.now () in + let state = Random_choice.create () in + let score = eval state in + info (fun m -> m "Initial score %f" score); + let rec loop state score = + info (fun m -> m "State space size: %d" (Random_choice.length state)); + if Time.(Span.(diff (now ()) start_time > max_time)) then ( + info (fun m -> m "Out of time. Final score %f" score); + (state, score)) + else + let state' = Random_choice.perturb state in + let score' = eval state' in + let h = Float.(min 1.0 (exp @@ (score - score'))) in + let u = Random.float 1.0 in + if Float.(u < h) then ( + info (fun m -> + m "Transitioning. Old score %f, new score %f" score score'); + loop state' score') + else ( + info (fun m -> m "Staying. Old score %f, new score %f" score score'); + loop state score) + in + loop state score diff --git a/lib/orderby_tactics.ml b/lib/orderby_tactics.ml new file mode 100644 index 0000000..45245ae --- /dev/null +++ b/lib/orderby_tactics.ml @@ -0,0 +1,84 @@ +open Ast +module A = Abslayout + +module Config = struct + module type S = sig + include Ops.Config.S + include Tactics_util.Config.S + + val params : Set.M(Name).t + end +end + +module Make (Config : Config.S) = struct + open Ops.Make (Config) + module Tactics_util = Tactics_util.Make (Config) + + let key_is_supported r key = + let s = Set.of_list (module Name) (Schema.schema r) in + List.for_all key ~f:(fun (p, _) -> + Tactics_util.is_supported r.meta#stage s p) + + module C = (val Constructors.Annot.with_strip_meta (fun () -> object end)) + + let push_orderby_depjoin key mk lhs scope rhs meta = + let open Option.Let_syntax in + let used_names = + let schema_rhs = Schema.schema rhs |> Set.of_list (module Name) in + List.map key ~f:(fun (p, _) -> Free.pred_free p) + |> Set.union_list (module Name) + |> Set.inter schema_rhs |> Set.to_list + in + let eqs = + let unscope n = + match Name.rel n with + | Some s when String.(s = scope) -> Name.unscoped n + | _ -> n + in + Set.map + (module Equiv.Eq) + meta#meta#eq + ~f:(fun (n, n') -> (unscope n, unscope n')) + in + let%map ctx = + Join_elim.translate eqs ~from:used_names ~to_:(Schema.schema lhs) + in + let ctx = + List.map ctx ~f:(fun (n, n') -> (n, Name n')) + |> Map.of_alist_exn (module Name) + in + let key = List.map key ~f:(fun (p, o) -> (Pred.subst ctx p, o)) in + mk (C.order_by key lhs) scope rhs + + let push_orderby r = + let open C in + let orderby_cross_tuple key rs = + match rs with + | r :: rs -> + if key_is_supported r key then + Some (tuple (order_by key r :: List.map ~f:strip_meta rs) Cross) + else None + | _ -> None + in + match r.node with + | OrderBy { key; rel = { node = Select (ps, r); _ } } -> + if key_is_supported r key then Some (select ps (order_by key r)) + else None + | OrderBy { key; rel = { node = Filter (ps, r); _ } } -> + Some (filter ps (order_by key r)) + | OrderBy { key; rel = { node = AHashIdx h; _ } } -> + Some + (hash_idx ?key_layout:h.hi_key_layout h.hi_keys h.hi_scope + (order_by key h.hi_values) h.hi_lookup) + | OrderBy { key; rel = { node = ATuple (rs, Cross); _ } } -> + orderby_cross_tuple key rs + | OrderBy { key; rel = { node = DepJoin d; meta } } -> + push_orderby_depjoin key dep_join d.d_lhs d.d_alias d.d_rhs meta + | OrderBy { key; rel = { node = AList l; meta } } -> + push_orderby_depjoin key list l.l_keys l.l_scope l.l_values meta + | _ -> None + + let push_orderby = + of_func_pre push_orderby ~name:"push-orderby" ~pre:(fun r -> + r |> Equiv.annotate |> Resolve.resolve_exn ~params:Config.params) +end diff --git a/lib/orderby_tactics_test.ml b/lib/orderby_tactics_test.ml new file mode 100644 index 0000000..d996fa8 --- /dev/null +++ b/lib/orderby_tactics_test.ml @@ -0,0 +1,122 @@ +open Castor_test.Test_util +module A = Abslayout + +module Config = struct + let cost_conn = Db.create "postgresql:///tpch_1k" + let conn = cost_conn + let validate = false + let param_ctx = Map.empty (module Name) + let params = Set.empty (module Name) + let verbose = false + let simplify = None +end + +open Orderby_tactics.Make (Config) +module O = Ops.Make (Config) + +let%expect_test "" = + let r = + {| + orderby([s_suppkey desc], + depjoin(alist(dedup(select([l_suppkey as l1_suppkey], lineitem)) as k0, + select([l_suppkey as supplier_no, sum(agg0) as total_revenue], + aorderedidx(dedup(select([l_shipdate], lineitem)) as s4, + filter((count0 > 0), + select([count() as count0, sum((l_extendedprice * (1 - l_discount))) as agg0, + l_suppkey, l_extendedprice, l_discount], + atuple([ascalar(s4.l_shipdate), + alist(select([l_suppkey, l_extendedprice, l_discount], + filter(((l_suppkey = k0.l1_suppkey) && (l_shipdate = s4.l_shipdate)), lineitem)) as s5, + atuple([ascalar(s5.l_suppkey), ascalar(s5.l_extendedprice), ascalar(s5.l_discount)], + cross))], + cross))), + >= date("0000-01-01"), < (date("0000-01-01") + month(3))))) as s1, + select([s_address, s_name, s_phone, s_suppkey, s1.total_revenue], + ahashidx(dedup(select([s_suppkey], supplier)) as s2, + alist(select([s_suppkey, s_name, s_address, s_phone], filter((s2.s_suppkey = s_suppkey), supplier)) as s0, + atuple([ascalar(s0.s_suppkey), ascalar(s0.s_name), ascalar(s0.s_address), ascalar(s0.s_phone)], cross)), + s1.supplier_no)))) +|} + |> Abslayout_load.load_string_exn (Lazy.force tpch_conn) + in + Format.printf "%a" (Fmt.Dump.option A.pp) @@ O.apply push_orderby Path.root r; + [%expect + {| + Some + depjoin(orderby([supplier_no desc], + alist(dedup(select([l_suppkey as l1_suppkey], lineitem)) as k0, + select([l_suppkey as supplier_no, sum(agg0) as total_revenue], + aorderedidx(dedup(select([l_shipdate], lineitem)) as s4, + filter((count0 > 0), + select([count() as count0, + sum((l_extendedprice * (1 - l_discount))) as agg0, + l_suppkey, l_extendedprice, l_discount], + atuple([ascalar(s4.l_shipdate), + alist(select([l_suppkey, l_extendedprice, + l_discount], + filter(((l_suppkey = k0.l1_suppkey) && + (l_shipdate = s4.l_shipdate)), + lineitem)) as s5, + atuple([ascalar(s5.l_suppkey), + ascalar(s5.l_extendedprice), + ascalar(s5.l_discount)], + cross))], + cross))), + >= date("0000-01-01"), < (date("0000-01-01") + month(3)))))) as s1, + select([s_address, s_name, s_phone, s_suppkey, s1.total_revenue], + ahashidx(dedup(select([s_suppkey], supplier)) as s2, + alist(select([s_suppkey, s_name, s_address, s_phone], + filter((s2.s_suppkey = s_suppkey), supplier)) as s0, + atuple([ascalar(s0.s_suppkey), ascalar(s0.s_name), + ascalar(s0.s_address), ascalar(s0.s_phone)], + cross)), + s1.supplier_no))) |}] + +let%expect_test "" = + let r = + {| +orderby([supplier_no desc], + alist(dedup(select([l_suppkey as l1_suppkey], lineitem)) as k0, + select([l_suppkey as supplier_no, sum(agg0) as total_revenue], + aorderedidx(dedup(select([l_shipdate], lineitem)) as s4, + filter((count0 > 0), + select([count() as count0, + sum((l_extendedprice * (1 - l_discount))) as agg0, + l_suppkey, l_extendedprice, l_discount], + atuple([ascalar(s4.l_shipdate), + alist(select([l_suppkey, l_extendedprice, + l_discount], + filter(((l_suppkey = k0.l1_suppkey) && + (l_shipdate = s4.l_shipdate)), + lineitem)) as s5, + atuple([ascalar(s5.l_suppkey), + ascalar(s5.l_extendedprice), + ascalar(s5.l_discount)], + cross))], + cross))), + >= date("0000-01-01"), < (date("0000-01-01") + month(3))))))|} + |> Abslayout_load.load_string_exn (Lazy.force tpch_conn) + in + Format.printf "%a" (Fmt.Dump.option A.pp) @@ O.apply push_orderby Path.root r; + [%expect + {| + Some + alist(orderby([l1_suppkey desc], + dedup(select([l_suppkey as l1_suppkey], lineitem))) as k0, + select([l_suppkey as supplier_no, sum(agg0) as total_revenue], + aorderedidx(dedup(select([l_shipdate], lineitem)) as s4, + filter((count0 > 0), + select([count() as count0, + sum((l_extendedprice * (1 - l_discount))) as agg0, + l_suppkey, l_extendedprice, l_discount], + atuple([ascalar(s4.l_shipdate), + alist(select([l_suppkey, l_extendedprice, l_discount], + filter(((l_suppkey = k0.l1_suppkey) && + (l_shipdate = s4.l_shipdate)), + lineitem)) as s5, + atuple([ascalar(s5.l_suppkey), + ascalar(s5.l_extendedprice), + ascalar(s5.l_discount)], + cross))], + cross))), + >= date("0000-01-01"), < (date("0000-01-01") + month(3))))) |}] diff --git a/lib/select_tactics.ml b/lib/select_tactics.ml new file mode 100644 index 0000000..52d3b0f --- /dev/null +++ b/lib/select_tactics.ml @@ -0,0 +1,237 @@ +open Ast +open Abslayout +open Collections +open Schema +module A = Abslayout +module P = Pred.Infix +module V = Visitors +open Match + +module Config = struct + module type S = sig + include Ops.Config.S + include Tactics_util.Config.S + end +end + +module Make (C : Config.S) = struct + open C + open Ops.Make (C) + open Tactics_util.Make (C) + open Simplify_tactic.Make (C) + + (** Push a select that doesn't contain aggregates. *) + let push_simple_select r = + let open Option.Let_syntax in + let%bind ps, r' = to_select r in + let%bind () = match select_kind ps with `Scalar -> Some () | _ -> None in + match r'.node with + | AList x -> return @@ list' { x with l_values = select ps x.l_values } + | DepJoin d -> return @@ dep_join' { d with d_rhs = select ps d.d_rhs } + | _ -> None + + let push_simple_select = of_func push_simple_select ~name:"push-simple-select" + + (** Extend a list of predicates to include those needed by aggregate `p`. + Returns a name to use in the aggregate. *) + let extend_aggs aggs p = + let aggs = ref aggs in + let add_agg a = + match + List.find !aggs ~f:(fun (_, a') -> [%compare.equal: _ pred] a a') + with + | Some (n, _) -> P.name n + | None -> + let n = + Fresh.name Global.fresh "agg%d" + |> Name.create ~type_:(Pred.to_type a) + in + aggs := (n, a) :: !aggs; + Name n + in + let visitor = + object + inherit [_] V.map + method! visit_Sum () p = Sum (add_agg (Sum p)) + method! visit_Count () = Sum (add_agg Count) + method! visit_Min () p = Min (add_agg (Min p)) + method! visit_Max () p = Max (add_agg (Max p)) + + method! visit_Avg () p = + Binop (Div, Sum (add_agg (Sum p)), Sum (add_agg Count)) + end + in + let p' = visitor#visit_pred () p in + (!aggs, p') + + (* Generate aggregates for collections that act by concatenating their + children. *) + let gen_concat_select_list outer_preds inner_schema = + let outer_aggs, inner_aggs = + List.fold_left outer_preds ~init:([], []) ~f:(fun (op, ip) p -> + let ip, p = extend_aggs ip p in + (op @ [ p ], ip)) + in + let inner_aggs = + List.map inner_aggs ~f:(fun (n, a) -> P.as_ a @@ Name.name n) + in + (* Don't want to project out anything that we might need later. *) + let inner_fields = inner_schema |> List.map ~f:P.name in + (outer_aggs, inner_aggs @ inner_fields) + + (* Look for evidence of a previous pushed select. *) + let already_pushed r' = + try + match Path.get_exn (Path.child Path.root 1) r' with + | { node = Filter (_, { node = Select _; _ }); _ } -> true + | _ -> false + with _ -> false + + let extend_with_tuple ns r = + tuple (List.map ns ~f:(fun n -> scalar @@ P.name n) @ [ r ]) Cross + + let push_select_collection r = + let open Option.Let_syntax in + let%bind ps, r' = to_select r in + let%bind () = match select_kind ps with `Agg -> Some () | _ -> None in + if already_pushed r' then None + else + let%map outer_preds, inner_preds = + match r'.node with + | AHashIdx h -> + let o = List.filter_map ps ~f:Pred.to_name |> List.map ~f:P.name + and i = + (* TODO: This hack works around problems with sql conversion + and lateral joins. *) + let kschema = schema h.hi_keys |> scoped h.hi_scope in + List.filter ps ~f:(function + | Name n -> not (List.mem ~equal:Name.O.( = ) kschema n) + | _ -> true) + in + return (o, i) + | AOrderedIdx { oi_values = rv; _ } + | AList { l_values = rv; _ } + | ATuple (rv :: _, Concat) -> + return @@ gen_concat_select_list ps (schema rv) + | _ -> None + and mk_collection = + match r'.node with + | AHashIdx h -> + let rk = h.hi_keys and rv = h.hi_values in + return @@ fun mk -> + hash_idx' { h with hi_values = mk (schema rk) h.hi_scope rv } + | AOrderedIdx o -> + let rk = o.oi_keys and rv = o.oi_values in + return @@ fun mk -> + ordered_idx' { o with oi_values = mk (schema rk) o.oi_scope rv } + | AList l -> + return @@ fun mk -> + list' { l with l_values = mk [] l.l_scope l.l_values } + | _ -> None + in + select outer_preds @@ mk_collection + @@ fun rk_schema scope rv -> + let inner_preds = + List.map inner_preds ~f:(Pred.scoped rk_schema scope) + |> List.filter ~f:(fun p -> + Pred.to_name p + |> Option.map ~f:(fun n -> + not (List.mem rk_schema n ~equal:[%compare.equal: Name.t])) + |> Option.value ~default:true) + in + select inner_preds rv + + let push_select_collection = + of_func_cond ~pre:Option.return + ~post:(fun r -> Resolve.resolve ~params r |> Result.ok) + push_select_collection ~name:"push-select-collection" + + let push_select_filter r = + let open Option.Let_syntax in + let%bind ps, r' = to_select r in + let%bind p', r'' = to_filter r' in + return @@ A.filter p' @@ A.select ps r'' + + let push_select_filter = + of_func_cond ~name:"push-select-filter" ~pre:Option.return + push_select_filter ~post:(fun r -> Resolve.resolve ~params r |> Result.ok) + + let push_select_depjoin r = + let open Option.Let_syntax in + let%bind ps, r' = to_select r in + let%bind { d_lhs; d_alias; d_rhs } = to_depjoin r' in + return @@ A.dep_join d_lhs d_alias @@ A.select ps d_rhs + + let push_select_depjoin = + of_func ~name:"push-select-depjoin" push_select_depjoin + + let push_select = + seq_many + [ + flatten_select; + first_success + [ push_select_collection; push_select_filter; push_select_depjoin ]; + ] + + let push_subqueries r = + let open Option.Let_syntax in + let%bind ps, r = to_select r in + let visitor = + object + inherit extract_subquery_visitor + method can_hoist _ = true + method fresh_name () = Name.create @@ Fresh.name Global.fresh "q%d" + end + in + let ps, subqueries = List.map ps ~f:(visitor#visit_pred ()) |> List.unzip in + let subqueries = + List.concat subqueries + |> List.map ~f:(fun (n, p) -> As_pred (p, Name.name n)) + in + return @@ A.select ps + @@ A.select ((Schema.schema r |> Schema.to_select_list) @ subqueries) r + + let push_subqueries = of_func push_subqueries ~name:"push-subqueries" + + let split_pred_left r = + let open Option.Let_syntax in + let%bind ps, r = to_select r in + let name = Name.fresh "x%d" in + match ps with + | [ Binop (op, p, p') ] -> + return + @@ A.select [ Binop (op, Name name, p') ] + @@ A.select [ As_pred (p, Name.name name) ] + @@ r + | [ As_pred (Binop (op, p, p'), n) ] -> + return + @@ A.select [ As_pred (Binop (op, Name name, p'), n) ] + @@ A.select [ As_pred (p, Name.name name) ] + @@ r + | _ -> None + + let split_pred_left = of_func split_pred_left ~name:"split-pred-left" + + let hoist_param = + let open A in + let f r = + match r.node with + | Select + ( [ + As_pred + ( First + { + node = Select ([ As_pred (Binop (op, p1, p2), _) ], r'); + _; + }, + n ); + ], + r ) -> + let fresh_id = Fresh.name Global.fresh "const%d" in + Option.return + @@ select [ As_pred (Binop (op, Name (Name.create fresh_id), p2), n) ] + @@ select [ As_pred (First (select [ p1 ] r'), fresh_id) ] r + | _ -> None + in + of_func ~name:"hoist-param" f +end diff --git a/lib/select_tactics_test.ml b/lib/select_tactics_test.ml new file mode 100644 index 0000000..d0868b9 --- /dev/null +++ b/lib/select_tactics_test.ml @@ -0,0 +1,103 @@ +open Abslayout +open Abslayout_load +open Select_tactics +open Castor_test.Test_util + +module C = struct + let params = Set.empty (module Name) + let fresh = Fresh.create () + let verbose = false + let validate = true + let param_ctx = Map.empty (module Name) + let conn = Lazy.force tpch_conn + let cost_conn = conn + let simplify = None +end + +open Make (C) +open Ops.Make (C) + +let () = + Log.setup_stderr (); + Logs.Src.set_level Check.src (Some Error) + +let%expect_test "push-select-index" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +select([sum(o_totalprice) as revenue], + aorderedidx(select([o_orderdate], dedup(select([o_orderdate], orders))) as s1, + filter(o_orderdate = s1.o_orderdate, orders), + >= date("0001-01-01"), < (date("0001-01-01") + year(1)))) +|} + in + let r' = Option.value_exn (apply push_select Path.root r) in + Format.printf "%a\n" pp r'; + [%expect + {| + select([sum(agg0) as revenue], + aorderedidx(select([o_orderdate], dedup(select([o_orderdate], orders))) as s1, + select([sum(o_totalprice) as agg0, o_orderkey, o_custkey, o_orderstatus, + o_totalprice, o_orderpriority, o_clerk, o_shippriority, o_comment], + filter((o_orderdate = s1.o_orderdate), orders)), + >= date("0001-01-01"), < (date("0001-01-01") + year(1)))) |}] + +let%expect_test "" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +select([substring(c1_phone, 0, 2) as x669, c1_phone, c1_acctbal, c1_custkey], + filter((c1_acctbal > + (select([avg(c_acctbal) as avgbal], + filter(((c_acctbal > 0.0) && + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || (substring(c_phone, 0, 2) = "")))))))), + customer)))), + filter(not(exists(filter((o_custkey = c1_custkey), orders))), + select([c_phone as c1_phone, c_acctbal as c1_acctbal, c_custkey as c1_custkey], customer)))) +|} + in + apply push_select_filter Path.root r |> Option.iter ~f:(Fmt.pr "%a@." pp); + [%expect + {| + filter((c1_acctbal > + (select([avg(c_acctbal) as avgbal], + filter(((c_acctbal > 0.0) && + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + (substring(c_phone, 0, 2) = "")))))))), + customer)))), + select([substring(c1_phone, 0, 2) as x669, c1_phone, c1_acctbal, c1_custkey], + filter(not(exists(filter((o_custkey = c1_custkey), orders))), + select([c_phone as c1_phone, c_acctbal as c1_acctbal, + c_custkey as c1_custkey], + customer)))) |}] + +let%expect_test "" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +select([substring(c1_phone, 0, 2) as x669, c1_phone, c1_custkey], + filter((c1_acctbal > + (select([avg(c_acctbal) as avgbal], + filter(((c_acctbal > 0.0) && + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || + ((substring(c_phone, 0, 2) = "") || (substring(c_phone, 0, 2) = "")))))))), + customer)))), + filter(not(exists(filter((o_custkey = c1_custkey), orders))), + select([c_phone as c1_phone, c_acctbal as c1_acctbal, c_custkey as c1_custkey], customer)))) +|} + in + apply push_select_filter Path.root r |> Option.iter ~f:(Fmt.pr "%a@." pp) diff --git a/lib/simple_tactics.ml b/lib/simple_tactics.ml new file mode 100644 index 0000000..69e5a30 --- /dev/null +++ b/lib/simple_tactics.ml @@ -0,0 +1,29 @@ +open Ast +open Collections +module A = Abslayout + +module Config = struct + module type S = sig + include Ops.Config.S + end +end + +module Make (Config : Config.S) = struct + open Config + open Ops.Make (Config) + + let row_store r = + (* Relation has no free variables that are bound at runtime. *) + if Is_serializable.is_static ~params r then + let scope = Fresh.name Global.fresh "s%d" in + let scalars = + Schema.schema r |> Schema.scoped scope + |> List.map ~f:(fun n -> A.scalar (Name n)) + in + Some (A.list (strip_meta r) scope (A.tuple scalars Cross)) + else None + + let row_store = + of_func_pre row_store ~pre:Is_serializable.annotate_stage + ~name:"to-row-store" +end diff --git a/lib/simple_tactics_test.ml b/lib/simple_tactics_test.ml new file mode 100644 index 0000000..4baabbd --- /dev/null +++ b/lib/simple_tactics_test.ml @@ -0,0 +1,39 @@ +open Abslayout +open Simple_tactics +open Castor_test.Test_util + +module C = struct + let params = + Set.singleton (module Name) (Name.create ~type_:Prim_type.int_t "param") + + let fresh = Fresh.create () + let verbose = false + let validate = false + let param_ctx = Map.empty (module Name) + let conn = Lazy.force test_db_conn + let cost_conn = Lazy.force test_db_conn + let simplify = None +end + +open Make (C) +open Ops.Make (C) + +let load_string ?params s = Abslayout_load.load_string_exn ?params C.conn s + +let%expect_test "row-store-comptime" = + let r = load_string "alist(r as r1, filter(r1.f = f, r))" in + Option.iter + (apply (at_ row_store Path.(all >>? is_filter >>| shallowest)) Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect + {| + alist(r as r1, + alist(filter((r1.f = f), r) as s0, + atuple([ascalar(s0.f), ascalar(s0.g)], cross))) |}] + +let%expect_test "row-store-runtime" = + let r = load_string "depjoin(r as r1, filter(r1.f = f, r))" in + Option.iter + (apply (at_ row_store Path.(all >>? is_filter >>| shallowest)) Path.root r) + ~f:(Format.printf "%a\n" pp); + [%expect {| |}] diff --git a/lib/string_tactics.ml b/lib/string_tactics.ml new file mode 100644 index 0000000..4a1e54c --- /dev/null +++ b/lib/string_tactics.ml @@ -0,0 +1,127 @@ +open Ast +open Abslayout +open Collections +module V = Visitors +module A = Abslayout +module P = Pred.Infix + +module Config = struct + module type My_s = sig + val conn : Db.t + val params : Set.M(Name).t + end + + module type S = sig + include My_s + include Ops.Config.S + end +end + +module Make (C : Config.S) = struct + module Ops = Ops.Make (C) + open Ops + module C : Config.My_s = C + open C + + let subst_rel ~key ~data r = + let visitor = + object + inherit [_] V.map + + method! visit_Relation () rel = + if [%compare.equal: Relation.t] key rel then data.node + else Relation rel + end + in + visitor#visit_t () r + + let dictionary_encode _ r = + let preds_v = + object + inherit [_] V.reduce + inherit [_] Util.list_monoid + + method! visit_Binop () op a1 a2 = + match (op, Pred.to_type a1, Pred.to_type a2) with + | Eq, StringT _, StringT _ -> [ (a1, a2) ] + | _ -> [] + end + in + preds_v#visit_t () r + |> List.filter_map ~f:(function + | Name n1, Name n2 -> ( + match + ( Db.relation_has_field conn (Name.name n1), + Db.relation_has_field conn (Name.name n2) ) + with + | Some rel, None -> + if Set.mem params n2 then Some (n1, rel, n2) else None + | None, Some rel -> + if Set.mem params n1 then Some (n2, rel, n1) else None + | _ -> None) + | _ -> None) + |> List.map ~f:(fun (key, rel, lookup) -> + let count_name = Fresh.name Global.fresh "c%d" in + let encoded_name = Fresh.name Global.fresh "x%d" in + let encoded_lookup_name = Fresh.name Global.fresh "x%d" in + let map_name = Fresh.name Global.fresh "m%d" in + let mapping = + select + [ + As_pred (Row_number, encoded_name); Name (Name.create map_name); + ] + (order_by + [ (Name (Name.create encoded_name), Desc) ] + (dedup + (select [ As_pred (Name key, map_name) ] (relation rel)))) + in + let encoded_rel = + join + (Binop (Eq, Name key, Name (Name.create map_name))) + (relation rel) mapping + in + let encoded_lookup = + let scope = Fresh.name Global.fresh "s%d" in + select + [ + As_pred + ( If + ( Binop (Gt, Name (Name.create count_name), Int 0), + Name (Name.create encoded_name), + Int (-1) ), + encoded_lookup_name ); + ] + (select + [ + As_pred (Count, count_name); Name (Name.create encoded_name); + ] + (hash_idx + (dedup + (select [ As_pred (Name key, map_name) ] (relation rel))) + scope + (select + [ Name (Name.create encoded_name) ] + (A.filter + P.( + Name (Name.create ~scope map_name) + = Name (Name.create map_name)) + mapping)) + [ Name lookup ])) + in + let scope = Fresh.name Global.fresh "s%d" in + dep_join encoded_lookup scope + (subst_rel ~key:rel ~data:encoded_rel + (subst + (Map.of_alist_exn + (module Name) + [ + (key, P.name @@ Name.create encoded_name); + ( lookup, + P.name @@ Name.create ~scope encoded_lookup_name ); + ]) + r))) + |> Seq.of_list + + let dictionary_encode = + Branching.(global dictionary_encode ~name:"dictionary-encode") +end diff --git a/lib/tactics_util.ml b/lib/tactics_util.ml new file mode 100644 index 0000000..c4dff30 --- /dev/null +++ b/lib/tactics_util.ml @@ -0,0 +1,285 @@ +open Ast +open Abslayout +open Collections +open Schema +module V = Visitors +module A = Abslayout +module P = Pred.Infix + +module Config = struct + module type S = sig + val conn : Db.t + val cost_conn : Db.t + val params : Set.M(Name).t + end +end + +module Make (Config : Config.S) = struct + open Config + + (** Remove all references to names in params while ensuring that the resulting + relation overapproximates the original. *) + let over_approx params r = + let visitor = + object (self) + inherit [_] V.map as super + + method! visit_Filter () (p, r) = + if Set.is_empty (Set.inter (Pred.names p) params) then + super#visit_Filter () (p, r) + else (self#visit_t () r).node + + method! visit_Select () (ps, r) = + match A.select_kind ps with + | `Agg -> Select (ps, r) + | `Scalar -> Select (ps, self#visit_t () r) + + method! visit_GroupBy () (ps, ks, r) = GroupBy (ps, ks, r) + end + in + let r = visitor#visit_t () r in + let remains = Set.inter (Free.free r) params in + if Set.is_empty remains then Ok r + else + Or_error.error "Failed to remove all parameters." remains + [%sexp_of: Set.M(Name).t] + + (** Precise selection of all valuations of a list of predicates from a relation. + *) + let all_values_precise ps r = + if Set.is_empty (Free.free r) then Ok (A.dedup (A.select ps r)) + else Or_error.errorf "Predicate contains free variables." + + let all_values_approx_1 ps r = + let open Or_error.Let_syntax in + let%map r' = over_approx params r in + A.dedup @@ A.select ps r' + + let rec closure m = + let m' = Map.map m ~f:(Pred.subst m) in + if [%compare.equal: Pred.t Map.M(Name).t] m m' then m else closure m' + + let group_by m ~f l = + List.fold_left ~init:(Map.empty m) + ~f:(fun m e -> Map.add_multi m ~key:(f e) ~data:e) + l + + let alias_map r = aliases r |> closure + + (** Collect a map from names to defining expressions from a relation. *) + let rec aliases r = + let plus = + Map.merge ~f:(fun ~key:_ -> function + | `Left r | `Right r -> Some r + | `Both (r1, r2) -> + if Pred.O.(r1 = r2) then Some r1 + else failwith "Multiple relations with same alias") + and zero = Map.empty (module Name) + and one k v = Map.singleton (module Name) k v in + match r.node with + | Select (ps, r) -> ( + match select_kind ps with + | `Scalar -> + List.fold_left ps ~init:(aliases r) ~f:(fun m p -> + match p with + | As_pred (p, n) -> plus (one (Name.create n) p) m + | _ -> m) + | `Agg -> zero) + | Filter (_, r) | Dedup r -> aliases r + | _ -> zero + + let all_values_attr n = + let open Option.Let_syntax in + let%bind rel = Db.relation_has_field cost_conn (Name.name n) in + return @@ A.select [ Name n ] @@ A.relation rel + + (** Approximate selection of all valuations of a list of predicates from a + relation. Works if the relation is parameterized, but only when the + predicates do not depend on those parameters. *) + let all_values_approx_2 ps r = + let open Or_error.Let_syntax in + (* Otherwise, if all grouping keys are from named relations, select all + possible grouping keys. *) + let alias_map = aliases r in + (* Find the definition of each key and collect all the names in that + definition. If they all come from base relations, then we can enumerate + the keys. *) + let orig_names = List.map ps ~f:Pred.to_name in + let ps = List.map ps ~f:(Pred.subst alias_map) in + + (* Try to substitute names that don't come from base relations with equivalent names that do. *) + let subst = + Equiv.eqs r |> Set.to_list + |> List.filter_map ~f:(fun (n, n') -> + match + ( Db.relation_has_field cost_conn (Name.name n), + Db.relation_has_field cost_conn (Name.name n') ) + with + | None, None | Some _, Some _ -> None + | Some _, None -> Some (n', n) + | None, Some _ -> Some (n, n')) + |> Map.of_alist_reduce (module Name) ~f:(fun n _ -> n) + |> Map.map ~f:P.name + in + + let preds = List.map ps ~f:(Pred.subst subst) in + + (* Collect the relations referred to by the predicate list. *) + let%bind rels = + List.map preds ~f:(fun p -> + List.map + (Pred.names p |> Set.to_list) + ~f:(fun n -> + match Db.relation_has_field cost_conn (Name.name n) with + | Some r -> Ok (r, n) + | None -> + Or_error.error "Name does not come from base relation." n + [%sexp_of: Name.t]) + |> Or_error.all) + |> Or_error.all + in + + let joined_rels = + List.concat rels + |> List.map ~f:(fun (r, n) -> (r.Relation.r_name, n)) + |> Map.of_alist_multi (module String) + |> Map.to_alist + |> List.map ~f:(fun (r, ns) -> + dedup + @@ select (Select_list.of_list @@ List.map ns ~f:P.name) + @@ relation (Db.relation cost_conn r)) + |> List.reduce ~f:(join (Bool true)) + in + + let sel_list = + List.map2_exn orig_names preds ~f:(fun n p -> + match n with Some n -> P.as_ p (Name.name n) | None -> p) + |> Select_list.of_list + in + match joined_rels with + | Some r -> Ok (select sel_list r) + | None -> Or_error.errorf "No relations found." + + let all_values_approx ps r = + if List.length ps = 1 then all_values_approx_2 ps r + else + match all_values_approx_1 ps r with + | Ok r' -> Ok r' + | Error _ -> all_values_approx_2 ps r + + let all_values ps r = + match all_values_precise ps r with + | Ok r' -> Ok r' + | Error _ -> all_values_approx ps r + + (** Check that a predicate is fully supported by a relation (it does not + depend on anything in the context.) *) + let is_supported stage bound pred = + Set.for_all (Free.pred_free pred) ~f:(fun n -> + Set.mem bound n + (* TODO: We assume that compile time names that are bound in the context + are ok, but this might not be true? *) + || (match Map.find stage n with + | Some `Compile -> true + | Some `Run -> false + | None -> + Logs.warn (fun m -> m "Missing stage on %a" Name.pp n); + false) + && Option.is_some (Name.rel n)) + + (** Remove names from a selection list. *) + let select_out ns r = + let ns = List.map ns ~f:Name.unscoped in + select + (schema r + |> List.filter ~f:(fun n' -> + not (List.mem ~equal:Name.O.( = ) ns (Name.unscoped n'))) + |> List.map ~f:P.name) + r + + let select_contains names ps r = + Set.( + is_empty + (diff + (inter names (of_list (module Name) (schema r))) + (of_list (module Name) (List.filter_map ~f:Pred.to_name ps)))) + + let rec all_pairs = function + | [] -> [] + | x :: xs -> List.map xs ~f:(fun x' -> (x, x')) @ all_pairs xs + + let rec disjoin = + let open Ast in + function + | [] -> Bool false | [ p ] -> p | p :: ps -> Binop (Or, p, disjoin ps) + + (** For a set of predicates, check whether more than one predicate is true at + any time. *) + let all_disjoint ps r = + let open Or_error.Let_syntax in + if List.length ps <= 1 then return true + else + let%map tups = + let sql = + let pred = + all_pairs ps |> List.map ~f:(fun (p, p') -> P.(p && p')) |> disjoin + in + filter pred @@ r + |> Unnest.unnest ~params:(Set.empty (module Name)) + |> Sql.of_ralgebra |> Sql.to_string + in + Log.debug (fun m -> m "All disjoint sql: %s" sql); + Db.run conn sql + in + List.length tups = 0 + + let replace_rel rel new_rel r = + let visitor = + object + inherit [_] Visitors.endo + + method! visit_Relation () r' { r_name = rel'; _ } = + if String.(rel = rel') then new_rel.Ast.node else r' + end + in + visitor#visit_t () r + + (** Visitor for extracting subqueries from a query, leaving behind names so + that the subqueries can later be bound to the names. + + can_hoist should be overwritten to determine whether the subquery can be + removed *) + class virtual extract_subquery_visitor = + object (self : 'self) + inherit [_] Visitors.mapreduce + inherit [_] Util.list_monoid + method virtual can_hoist : Ast.t -> bool + method virtual fresh_name : unit -> Name.t + + method! visit_AList () l = + let rv, ret = self#visit_t () l.l_values in + (AList { l with l_values = rv }, ret) + + method! visit_AHashIdx () h = + let hi_values, ret = self#visit_t () h.hi_values in + (AHashIdx { h with hi_values }, ret) + + method! visit_AOrderedIdx () o = + let rv, ret = self#visit_t () o.oi_values in + (AOrderedIdx { o with oi_values = rv }, ret) + + method! visit_AScalar () x = (AScalar x, []) + + method! visit_Exists () r = + if self#can_hoist r then + let name = self#fresh_name () in + (Name name, [ (name, Ast.Exists r) ]) + else (Exists r, []) + + method! visit_First () r = + if self#can_hoist r then + let name = self#fresh_name () in + (Name name, [ (name, First r) ]) + else (First r, []) + end +end diff --git a/lib/transform.ml b/lib/transform.ml new file mode 100644 index 0000000..d0a5b76 --- /dev/null +++ b/lib/transform.ml @@ -0,0 +1,381 @@ +open Collections +open Ast +module R = Resolve +module V = Visitors + +module Config = struct + module type S = sig + val conn : Db.t + val cost_conn : Db.t + val params : Set.M(Name).t + val cost_timeout : float option + val random : Mcmc.Random_choice.t + end +end + +module Make (Config : Config.S) = struct + open Config + module O = Ops.Make (Config) + open O + module Filter_tactics = Filter_tactics.Make (Config) + module Simple_tactics = Simple_tactics.Make (Config) + module Join_opt = Join_opt.Make (Config) + module Simplify_tactic = Simplify_tactic.Make (Config) + module Select_tactics = Select_tactics.Make (Config) + module Groupby_tactics = Groupby_tactics.Make (Config) + module Join_elim_tactics = Join_elim_tactics.Make (Config) + module Tactics_util = Tactics_util.Make (Config) + module Dedup_tactics = Dedup_tactics.Make (Config) + module Orderby_tactics = Orderby_tactics.Make (Config) + module Cost = Type_cost.Make (Config) + + let try_random tf = + global + (fun p r -> + if Mcmc.Random_choice.rand random tf.name (Path.get_exn p r) then + apply tf p r + else Some r) + "try-random" + + let try_random_branch tf = + Branching.global ~name:"try-random" (fun p r -> + if Mcmc.Random_choice.rand random (Branching.name tf) (Path.get_exn p r) + then Branching.apply tf p r + else Seq.singleton r) + + let is_serializable r p = + Is_serializable.is_serializeable ~params ~path:p r |> Result.is_ok + + let is_spine_serializable r p = + Is_serializable.is_spine_serializeable ~params ~path:p r |> Result.is_ok + + let has_params r p = Path.get_exn p r |> Free.free |> overlaps params + let has_free r p = not (Set.is_empty (Free.free (Path.get_exn p r))) + + let push_all_runtime_filters = + for_all Filter_tactics.push_filter Path.(all >>? is_run_time >>? is_filter) + + let push_static_filters = + for_all Filter_tactics.push_filter + Path.(all >>? is_run_time >>? is_filter >>? Infix.not is_param_filter) + + let hoist_all_filters = + for_all Filter_tactics.hoist_filter Path.(all >>? is_filter >> O.parent) + + let elim_param_filter tf test = + (* Eliminate comparison filters. *) + fix @@ traced + @@ seq_many + [ + (* Hoist parameterized filters as far up as possible. *) + for_all Filter_tactics.hoist_filter + (Path.all >>? is_param_filter >> parent); + Branching.( + seq_many + [ + unroll_fix @@ O.traced + @@ O.at_ Filter_tactics.push_filter + Path.(all >>? test >>? is_run_time >>| shallowest); + (* Eliminate a comparison filter. *) + choose + (for_all (try_random_branch tf) + Path.(all >>? test >>? is_run_time)) + id; + lift + (O.seq_many + [ + push_all_runtime_filters; + O.for_all Simple_tactics.row_store + Path.(all >>? is_run_time >>? is_relation); + push_all_runtime_filters; + fix Simplify_tactic.project; + Simplify_tactic.simplify; + ]); + ] + |> lower (min Cost.cost)); + ] + + let try_partition tf = + traced ~name:"try-partition" + @@ Branching.( + seq_many + [ + choose id (try_random_branch @@ traced Filter_tactics.partition); + lift tf; + ] + |> lower (min Cost.cost)) + + let try_ tf rest = + Branching.(seq (choose (lift tf) id) (lift rest) |> lower (min Cost.cost)) + + let try_many tfs rest = + Branching.( + seq (choose_many (List.map ~f:lift tfs)) (lift rest) + |> lower (min Cost.cost)) + + let is_serializable' r = + let bad_runtime_op = + Path.( + all >>? is_run_time + >>? Infix.( + is_join || is_groupby || is_orderby || is_dedup || is_relation)) + r + |> Seq.is_empty |> not + in + let mis_bound_params = + Path.(all >>? is_compile_time) r + |> Seq.for_all ~f:(fun p -> + not (overlaps (Free.free (Path.get_exn p r)) params)) + |> not + in + if bad_runtime_op then Error (Error.of_string "Bad runtime operation.") + else if mis_bound_params then + Error (Error.of_string "Parameters referenced at compile time.") + else Ok () + + let is_serializable'' r = Result.is_ok @@ is_serializable' r + + let _elim_subqueries = + seq_many + [ + Filter_tactics.elim_all_correlated_subqueries; + Simplify_tactic.unnest_and_simplify; + ] + + let cse = + traced ~name:"cse" + @@ seq_many' + [ + for_all Join_elim_tactics.push_join_filter Path.(all >>? is_join); + for_all' Filter_tactics.cse_filter Path.(all >>? is_filter); + ] + + let opt = + let open Infix in + seq_many + [ + seq_many + [ + (* Simplify predicates. *) + traced ~name:"simplify-preds" + @@ for_all Filter_tactics.simplify_filter Path.all; + (* CSE *) + seq_many' + [ + cse; + for_all Select_tactics.push_select Path.(all >>? is_select); + for_all Simple_tactics.row_store + Path.( + all >>? is_run_time >>? not has_params + >>? not is_serializable + >>? not (contains is_collection)); + ]; + (* Eliminate groupby operators. *) + traced ~name:"elim-groupby" + @@ fix + @@ seq_many + [ + first Groupby_tactics.elim_groupby Path.(all >>? is_groupby); + fix push_static_filters; + for_all Join_elim_tactics.push_join_filter + Path.(all >>? is_join); + ]; + (* Hoist parameterized filters as far up as possible. *) + traced ~name:"hoist-param-filters" + @@ try_random + @@ seq_many + [ + for_all Join_elim_tactics.hoist_join_param_filter + Path.(all >>? is_join); + for_all Filter_tactics.hoist_filter + Path.(all >>? is_param_filter >> O.parent); + ]; + try_random + @@ traced ~name:"elim-simple-filter" + @@ at_ Filter_tactics.elim_simple_filter + Path.(all >>? is_expensive_filter >>| shallowest); + (* Eliminate unparameterized join nests. Try using join optimization and + using a simple row store. *) + traced ~name:"elim-join-nests" + @@ try_many + [ + traced ~name:"elim-join-nests-opt" + @@ try_random + @@ for_all Join_opt.transform + Path.(all >>? is_join >>? is_run_time); + traced ~name:"elim-join-nests-flat" + @@ try_random + @@ at_ Simple_tactics.row_store + Path.( + all >>? is_join >>? is_run_time >>? not has_free + >>| shallowest); + id; + ] + (seq_many + [ + try_random @@ traced @@ Filter_tactics.elim_subquery; + try_random @@ push_all_runtime_filters; + Simplify_tactic.project; + traced ~name:"elim-join-filter" + @@ at_ Join_elim_tactics.elim_join_filter + Path.(all >>? is_join >>| shallowest); + try_ + (traced ~name:"elim-disjunct" + (seq_many + [ + hoist_all_filters; + first Filter_tactics.elim_disjunct + Path.(all >>? is_filter >>? is_run_time); + push_all_runtime_filters; + ])) + (seq_many + [ + (* Push constant filters *) + traced ~name:"push-constant-filters" + @@ for_all Filter_tactics.push_filter + Castor.Path.(all >>? is_const_filter); + (* Push orderby operators into compile time position if possible. *) + traced ~name:"push-orderby" + @@ for_all Orderby_tactics.push_orderby + Path.(all >>? is_orderby >>? is_run_time); + (* Eliminate comparison filters. *) + traced ~name:"elim-cmp-filters" + @@ elim_param_filter Filter_tactics.elim_cmp_filter + is_param_cmp_filter; + (* Eliminate equality filters. *) + traced ~name:"elim-eq-filters" + @@ elim_param_filter + (Branching.lift Filter_tactics.elim_eq_filter) + is_param_filter; + traced ~name:"push-all-unparam-filters" + @@ push_all_runtime_filters; + (* Push aggregate selects until they can be eliminated into + a row store. If the elimination doesn't happen, no change + is made. *) + traced ~name:"push-agg-select" + @@ for_all_disjoint + (until' + (at_ + (traced Select_tactics.push_select) + Path.( + all >>? is_agg_select >>? is_run_time + >>| deepest)) + (at_ + (traced Simple_tactics.row_store) + Path.( + all >>? is_agg_select >>? is_run_time + >>| deepest))) + Path.(all >>? is_agg_select >>? is_run_time); + (* Eliminate all unparameterized relations. *) + traced ~name:"elim-unparam-relations" + @@ fix + @@ seq_many + [ + at_ Simple_tactics.row_store + Path.( + all >>? is_run_time >>? not has_params + >>? not is_serializable + >>? not (contains is_collection) + >>| shallowest); + push_all_runtime_filters; + ]; + traced ~name:"push-all-unparam-filters" + @@ push_all_runtime_filters; + (* Push orderby operators into compile time position if possible. *) + traced ~name:"push-orderby-into-ctime" + @@ for_all Orderby_tactics.push_orderby + Path.(all >>? is_orderby >>? is_run_time) + (* Last-ditch tactic to eliminate orderby. *); + traced ~name:"final-orderby-elim" + @@ for_all Simple_tactics.row_store + Path.(all >>? is_orderby >>? is_run_time); + (* Try throwing away structure if it reduces overall cost. *) + (traced ~name:"drop-structure" + @@ Branching.( + seq_many + [ + choose id + (seq_many + [ + for_all + (lift Simple_tactics.row_store) + Path.( + all >>? is_run_time + >>? not has_params + >>? not is_scalar); + lift push_all_runtime_filters; + ]); + filter is_spine_serializable; + ] + |> lower (min Cost.cost))); + (* Cleanup*) + traced ~name:"cleanup" @@ fix + @@ seq_many + [ + for_all Select_tactics.push_simple_select + Path.(all >>? is_select); + for_all Dedup_tactics.push_dedup + Path.(all >>? is_dedup); + for_all Dedup_tactics.elim_dedup + Path.(all >>? is_dedup); + ]; + traced ~name:"project" + @@ fix Simplify_tactic.project; + traced ~name:"prf" @@ push_all_runtime_filters; + traced ~name:"simp" @@ Simplify_tactic.simplify; + traced @@ filter is_serializable''; + ]); + ]); + ]; + ] + + let is_serializable = is_serializable' +end + +exception Optimize_failure of Ast.t + +let rec optimize_exn (module C : Config.S) r = + (* Optimize outer query. *) + let module T = Make (C) in + let module O = Ops.Make (C) in + let r = + match O.apply T.(try_partition opt) Path.root r with + | Some r -> r + | None -> raise (Optimize_failure r) + in + optimize_subqueries (module C : Config.S) r + +(* Recursively optimize subqueries. *) +and optimize_subqueries (module C : Config.S) r = + let visitor = + object (self : 'a) + inherit [_] V.map + + method visit_subquery r = + let module C = struct + include C + + let params = Set.union params (Free.free r) + end in + optimize_exn (module C) r + + method! visit_Exists () r = Exists (self#visit_subquery r) + method! visit_First () r = First (self#visit_subquery r) + + method! visit_AList () l = + AList { l with l_values = self#visit_t () l.l_values } + + method! visit_AOrderedIdx () o = + AOrderedIdx { o with oi_values = self#visit_t () o.oi_values } + + method! visit_AHashIdx () h = + AHashIdx { h with hi_values = self#visit_t () h.hi_values } + + method! visit_AScalar () v = AScalar v + end + in + visitor#visit_t () r + +let optimize (module C : Config.S) r = + try Either.First (optimize_exn (module C) r) + with Optimize_failure r' -> Second r' diff --git a/lib/transform.mli b/lib/transform.mli new file mode 100644 index 0000000..4ac18b1 --- /dev/null +++ b/lib/transform.mli @@ -0,0 +1,17 @@ +open Ast + +module Config : sig + module type S = sig + val conn : Db.t + val cost_conn : Db.t + val params : Set.M(Name).t + val cost_timeout : float option + val random : Mcmc.Random_choice.t + end +end + +module Make (Config : Config.S) : sig + val is_serializable : 'a annot -> unit Or_error.t +end + +val optimize : (module Config.S) -> Ast.t -> (Ast.t, Ast.t) Either.t diff --git a/lib/type_cost.ml b/lib/type_cost.ml new file mode 100644 index 0000000..d72051a --- /dev/null +++ b/lib/type_cost.ml @@ -0,0 +1,62 @@ +open Ast +open Abslayout_load +open Type +module I = Abs_int +include (val Log.make ~level:(Some Warning) "castor-opt.type-cost") + +module Config = struct + module type S = sig + val params : Set.M(Name).t + val cost_timeout : float option + val cost_conn : Db.t + end +end + +module Make (Config : Config.S) = struct + open Config + + let rec read = function + | StringT { nchars = Top; _ } -> (* TODO: Fix this... *) I.Interval (5, 50) + | (NullT | EmptyT | IntT _ | DateT _ | FixedT _ | BoolT _ | StringT _) as t + -> + len t + | ListT (elem_t, m) -> I.(read elem_t * m.count) + | FuncT ([ t ], _) -> read t + | FuncT ([ t1; t2 ], _) -> I.(read t1 * read t2) + | FuncT _ -> failwith "Unexpected function." + | TupleT (elem_ts, _) -> List.sum (module I) elem_ts ~f:read + | HashIdxT (_, vt, _) -> I.(join zero (read vt)) + | OrderedIdxT (_, vt, _) -> I.(join zero (read vt)) + + let cost kind r = + info (fun m -> m "Computing cost of:@, %a." Abslayout.pp r); + let out = + let open Result.Let_syntax in + let%bind layout = load_layout ~params cost_conn r in + let%bind type_ = + Parallel.type_of ?timeout:cost_timeout cost_conn (strip_meta layout) + in + let c = read type_ in + match kind with + | `Min -> I.inf c + | `Max -> I.sup c + | `Avg -> + let%bind l = I.inf c in + let%map h = I.sup c in + l + ((h - l) / 2) + in + match out with + | Ok x -> + let x = Float.of_int x in + info (fun m -> m "Found cost %f." x); + x + | Error e -> + warn (fun m -> + m "Computing cost failed: %a" + (Resolve.pp_err @@ Parallel.pp_err @@ Fmt.nop) + e); + Float.max_value + + let cost ?(kind = `Avg) = + Memo.general ~hashable:(Hashtbl.Hashable.of_key (module Ast)) (cost kind) +end diff --git a/lib/type_cost_tests.ml b/lib/type_cost_tests.ml new file mode 100644 index 0000000..d67bb15 --- /dev/null +++ b/lib/type_cost_tests.ml @@ -0,0 +1,156 @@ +open Abslayout_load +open Castor_test.Test_util + +open Type_cost.Make (struct + let params = Set.empty (module Name) + let cost_timeout = Some 60.0 + let cost_conn = Lazy.force tpch_conn +end) + +let%expect_test "" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +select([s1_acctbal, s1_name, n1_name, p1_partkey, p1_mfgr, s1_address, + s1_phone, s1_comment], + ahashidx(depjoin(select([min(p_size) as lo, max(p_size) as hi], + dedup(select([p_size], part))) as k1, + select([range as k0], range(k1.lo, k1.hi))) as s0, + select([s1_acctbal, s1_name, n1_name, p1_partkey, p1_mfgr, s1_address, + s1_phone, s1_comment], + alist(filter((p1_size = s0.k0), + orderby([s1_acctbal desc, n1_name, s1_name, p1_partkey], + join((((r1_name = r_name) && + (((ps_partkey = ps1_partkey) && + (ps1_supplycost = min_cost)) && + (ps1_supplycost = min_cost))) && true), + join((n1_regionkey = r1_regionkey), + select([r_name as r1_name, r_regionkey as r1_regionkey], + region), + join((s1_nationkey = n1_nationkey), + select([n_name as n1_name, n_nationkey as n1_nationkey, + n_regionkey as n1_regionkey], + nation), + join((s1_suppkey = ps1_suppkey), + select([s_nationkey as s1_nationkey, + s_suppkey as s1_suppkey, + s_acctbal as s1_acctbal, s_name as s1_name, + s_address as s1_address, s_phone as s1_phone, + s_comment as s1_comment], + supplier), + join((p1_partkey = ps1_partkey), + select([p_size as p1_size, p_type as p1_type, + p_partkey as p1_partkey, p_mfgr as p1_mfgr], + part), + select([ps_supplycost as ps1_supplycost, + ps_partkey as ps1_partkey, + ps_suppkey as ps1_suppkey], + partsupp))))), + depjoin(dedup( + select([r_name, ps_partkey], + join((s_suppkey = ps_suppkey), + join((s_nationkey = n_nationkey), + join((n_regionkey = r_regionkey), + nation, + region), + supplier), + partsupp))) as k2, + select([r_name, ps_partkey, + min(ps_supplycost) as min_cost], + join((((r_name = k2.r_name) && + (ps_partkey = k2.ps_partkey)) && + (s_suppkey = ps_suppkey)), + join((s_nationkey = n_nationkey), + join((n_regionkey = r_regionkey), nation, region), + supplier), + partsupp)))))) as s1, + filter(((r1_name = "") && + (strpos(p1_type, "") = + ((strlen(p1_type) - strlen("")) + 1))), + atuple([ascalar(s1.r1_name), ascalar(s1.n1_name), + ascalar(s1.s1_acctbal), ascalar(s1.s1_name), + ascalar(s1.s1_address), ascalar(s1.s1_phone), + ascalar(s1.s1_comment), ascalar(s1.p1_type), + ascalar(s1.p1_partkey), ascalar(s1.p1_mfgr)], + cross)))), + 0)) +|} + in + Fmt.pr "%f" (cost r); + [%expect {| 7398.000000 |}] + +let%expect_test "" = + let r = + load_string_exn (Lazy.force tpch_conn) + {| +select([n1_name, n2_name, l_year, revenue], + alist(orderby([n1_name, n2_name, l_year], + dedup( + select([n1_name, n2_name, to_year(l_shipdate) as l_year], + join(((s_suppkey = l_suppkey) && true), + join((o_orderkey = l_orderkey), + join((c_custkey = o_custkey), + join((c_nationkey = n2_nationkey), + select([n_name as n2_name, n_nationkey as n2_nationkey], nation), + customer), + orders), + filter(((l_shipdate >= date("1995-01-01")) && (l_shipdate <= date("1996-12-31"))), lineitem)), + join((s_nationkey = n1_nationkey), + select([n_name as n1_name, n_nationkey as n1_nationkey], nation), + supplier))))) as k0, + select([n1_name, n2_name, l_year, revenue], + ahashidx(dedup( + atuple([dedup( + atuple([select([n_name as x27], dedup(select([n_name], nation))), + select([n_name as x27], dedup(select([n_name], nation)))], + concat)), + dedup( + atuple([select([n_name as x30], dedup(select([n_name], nation))), + select([n_name as x30], dedup(select([n_name], nation)))], + concat))], + cross)) as s7, + alist(filter((count0 > 0), + select([count() as count0, n1_name, n2_name, to_year(l_shipdate) as l_year, + sum((l_extendedprice * (1 - l_discount))) as revenue], + atuple([ascalar(s7.x27), ascalar(s7.x30), + alist(filter((true && (n2_name = k0.n2_name)), + select([n_name as n2_name, n_nationkey as n2_nationkey], nation)) as s4, + filter((((n1_name = s7.x27) && (n2_name = s7.x30)) || + ((n1_name = s7.x30) && (n2_name = s7.x27))), + atuple([ascalar(s4.n2_name), + alist(select([l_suppkey, l_shipdate, l_discount, l_extendedprice], + filter(((to_year(l_shipdate) = k0.l_year) && + (c_nationkey = s4.n2_nationkey)), + select([l_suppkey, l_shipdate, + l_discount, + l_extendedprice, + c_nationkey], + join((c_custkey = o_custkey), + join((o_orderkey = l_orderkey), + filter(((l_shipdate >= date("1995-01-01")) && + (l_shipdate <= date("1996-12-31"))), + lineitem), + orders), + customer)))) as s3, + atuple([atuple([ascalar(s3.l_shipdate), + ascalar(s3.l_discount), + ascalar(s3.l_extendedprice)], + cross), + alist(select([n1_name], + join((((n1_name = k0.n1_name) && (s_suppkey = s3.l_suppkey)) + && (s_nationkey = n1_nationkey)), + select([n_name as n1_name, n_nationkey as n1_nationkey], + nation), + supplier)) as s2, + ascalar(s2.n1_name))], + cross))], + cross)))], + cross))) as s8, + atuple([ascalar(s8.count0), ascalar(s8.n1_name), ascalar(s8.n2_name), + ascalar(s8.l_year), ascalar(s8.revenue)], + cross)), + ("", ""))))) +|} + in + Fmt.pr "%f" (cost r); + [%expect {| 15151.000000 |}]