Skip to content

Commit

Permalink
egraph: example equality working
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 20, 2024
1 parent 51a4246 commit ac59fff
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 42 deletions.
26 changes: 24 additions & 2 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ define_language! {
">>>" = ArithmeticRightShift([Id; 7]),
// operations on widths
"max+1" = WidthMaxPlus1([Id; 2]),
"wlsh" = WidthLeftShift([Id; 2]),
Width(WidthValue),
Sign(Sign),
// not a width, but a value constant
Expand All @@ -35,8 +36,18 @@ define_language! {
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct WidthValue(WidthInt);

fn eval_width_max_plus_1(a: WidthInt, b: WidthInt) -> WidthInt {
max(a, b) + 1
pub(crate) fn eval_width_max_plus_1(wa: WidthInt, wb: WidthInt) -> WidthInt {
max(wa, wb) + 1
}

pub(crate) fn eval_width_left_shift(wa: WidthInt, wb: WidthInt) -> WidthInt {
if wb >= WidthInt::BITS {
// very very very large width, not what you want
WidthInt::MAX
} else {
let max_shift: WidthInt = (1 << wb) - 1;
wa + max_shift
}
}

// this allows us to use ArithWidthConst as an argument to ctx.bit_vec_val
Expand Down Expand Up @@ -146,6 +157,7 @@ impl Analysis<Arith> for WidthConstantFold {
match expr {
&Arith::Width(w) => Some(w.0),
Arith::WidthMaxPlus1([a, b]) => Some(eval_width_max_plus_1(x(a)?, x(b)?)),
Arith::WidthLeftShift([a, b]) => Some(eval_width_left_shift(x(a)?, x(b)?)),
_ => None,
}
}
Expand Down Expand Up @@ -330,6 +342,11 @@ pub fn from_arith(ctx: &mut Context, expr: &RecExpr<Arith>) -> ExprRef {
let b = get_u64(ctx, stack.pop().unwrap()) as WidthInt;
ctx.bit_vec_val(eval_width_max_plus_1(a, b), 32)
}
Arith::WidthLeftShift(_) => {
let a = get_u64(ctx, stack.pop().unwrap()) as WidthInt;
let b = get_u64(ctx, stack.pop().unwrap()) as WidthInt;
ctx.bit_vec_val(eval_width_left_shift(a, b), 32)
}
Arith::Width(width) => ctx.bit_vec_val(*width, 32),
Arith::Sign(sign) => ctx.bit_vec_val(*sign, 1),
Arith::Const(value) => {
Expand Down Expand Up @@ -384,6 +401,11 @@ fn get_width(root: usize, expressions: &[Arith]) -> WidthInt {
let b = get_width(usize::from(*b), expressions);
eval_width_max_plus_1(a, b)
}
Arith::WidthLeftShift([a, b]) => {
let a = get_width(usize::from(*a), expressions);
let b = get_width(usize::from(*b), expressions);
eval_width_left_shift(a, b)
}
other => todo!("calculate width for {other:?}"),
}
}
Expand Down
2 changes: 2 additions & 0 deletions patronus-egraphs/src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
// set compound=true to enable edges to clusters
writeln!(out, " compound=true")?;
writeln!(out, " clusterrank=local")?;
// more spacing
writeln!(out, " ranksep=2")?;

// create a map from e-class id to width
let widths = FxHashMap::from_iter(
Expand Down
50 changes: 10 additions & 40 deletions patronus-egraphs/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ introspect them in order to check re-write conditions or debug matches.
!*/

use crate::arithmetic::{eval_width_left_shift, eval_width_max_plus_1};
use crate::{get_const_width_or_sign, is_bin_op, Arith, EGraph, WidthConstantFold};
use egg::{
ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, Searcher, Subst, Var,
Expand Down Expand Up @@ -55,7 +56,8 @@ pub fn create_rewrites() -> Vec<ArithRewrite> {
// we do not want (b + c) to wrap, because in that case the result would always be zero
// the value being shifted has to be consistently signed or unsigned
"(<< ?wo ?wa ?sa ?a ?wbc unsign (+ ?wbc ?wb unsign ?b ?wc unsign ?c))" =>
"(<< ?wo ?wo ?sa (<< ?wo ?wa ?sa ?a ?wb unsign ?b) ?wc unsign ?c)";
// RHS: we set wab to the minimum not to overflow
"(<< ?wo (wlsh ?wa ?wb) ?sa (<< (wlsh ?wa ?wb) ?wa ?sa ?a ?wb unsign ?b) ?wc unsign ?c)";
// ?wbc >= max(wb, wc) + 1
if["?wbc", "?wb", "?wc"], |w| w[0] >= (max(w[1], w[2]) + 1)),
// a * 2 <=> a + a
Expand All @@ -69,8 +71,8 @@ pub fn create_rewrites() -> Vec<ArithRewrite> {
arith_rewrite!("left-shift-mult";
// TODO: currently all signs are forced to unsigned
"(<< ?wo ?wab unsign (* ?wab ?wa unsign ?a ?wb unsign ?b) ?wc unsign ?c)" =>
// we set the width of (a << c) to the result width to satisfy wac >= wo
"(* ?wo ?wo unsign (<< ?wo ?wa unsign ?a ?wc unsign ?c) ?wb unsign ?b)";
// RHS: we set wac to the minimum not to overflow
"(* ?wo (wlsh ?wa ?wc) unsign (<< (wlsh ?wa ?wc) ?wa unsign ?a ?wc unsign ?c) ?wb unsign ?b)";
// we want to determine that there is no overflow
// lhs: wab >= wa + wb && wo >= wab + max_shift(wc)
// rhs: wac >= wa + max_shift(c) && wo >= wac + wb
Expand All @@ -80,7 +82,7 @@ pub fn create_rewrites() -> Vec<ArithRewrite> {

/// Determines if there is no overflow possible for this addition.
fn add_no_ov(wo: WidthInt, wa: WidthInt, wb: WidthInt) -> bool {
wo >= max(wa, wb) + 1
wo >= eval_width_max_plus_1(wa, wb)
}

/// Determines if there is no overflow possible for this multiplication.
Expand All @@ -90,13 +92,7 @@ fn mul_no_ov(wo: WidthInt, wa: WidthInt, wb: WidthInt) -> bool {

/// Determines if there is no overflow possible for this left shift.
fn lsh_no_ov(wo: WidthInt, wa: WidthInt, wb: WidthInt) -> bool {
if wb >= WidthInt::BITS {
// avoid overflow
false
} else {
let max_shift: WidthInt = (1 << wb) - 1;
wo >= wa + max_shift
}
wo >= eval_width_left_shift(wa, wb)
}

pub struct ArithRewrite {
Expand Down Expand Up @@ -298,7 +294,7 @@ pub fn create_egg_rewrites() -> Vec<Rewrite> {
mod tests {
use super::*;
use crate::arithmetic::verification_fig_1;
use crate::{to_arith, to_dot, to_pdf};
use crate::to_arith;
use patronus::expr::{Context, SerializableIrNode};
#[test]
fn test_data_path_verification_fig_1_rewrites() {
Expand All @@ -307,45 +303,19 @@ mod tests {
let spec_e = to_arith(&ctx, spec);
let impl_e = to_arith(&ctx, implementation);

println!("{spec_e}");
println!("{impl_e}");

// run egraph operations
let egg_rewrites = create_egg_rewrites();
let runner = egg::Runner::default()
.with_expr(&spec_e)
.with_expr(&impl_e)
.with_iter_limit(10)
.run(&egg_rewrites);

runner.print_report();

let spec_class = runner.egraph.find(runner.roots[0]);
let impl_class = runner.egraph.find(runner.roots[1]);
println!("{spec_class} {impl_class}");

let left_shift_mult = create_rewrites()
.into_iter()
.find(|r| r.name == "left-shift-mult")
.unwrap();
println!("{}", left_shift_mult.patterns().0);
let r = left_shift_mult.find_lhs_matches(&runner.egraph);
for m in r {
println!("{m:?}");
}

to_pdf("graph.pdf", &runner.egraph).unwrap();
to_dot("graph.dot", &runner.egraph).unwrap();
runner.egraph.dot().to_pdf("full_graph.pdf").unwrap();
runner.egraph.dot().to_dot("full_graph.dot").unwrap();

// investigating eclass 26 and 13 which should ideally be the same
println!("{}", inspect_e_class(&runner.egraph, 26));
println!("{}", inspect_e_class(&runner.egraph, 13));
println!("{}", inspect_e_class(&runner.egraph, 25));
println!("{}", inspect_e_class(&runner.egraph, 12));
assert_eq!(spec_class, impl_class, "should prove equality!");
}

#[allow(dead_code)]
fn inspect_e_class(egraph: &EGraph, id: usize) -> String {
let nodes = egraph[id.into()]
.nodes
Expand Down

0 comments on commit ac59fff

Please sign in to comment.