Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 19, 2024
1 parent ceeaad9 commit c29f17e
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
// author: Kevin Laeufer <[email protected]>

use baa::BitVecOps;
use egg::{define_language, ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, RecExpr, Rewrite, Subst, Var};
use egg::{
define_language, ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, RecExpr,
Rewrite, Subst, Var,
};
use patronus::expr::*;
use std::cmp::{max, Ordering};
use std::fmt::{Display, Formatter};
use std::str::FromStr;


define_language! {
/// Intermediate expression language for bit vector arithmetic rewrites.
/// Inspired by: "ROVER: RTL Optimization via Verified E-Graph Rewriting" (TCAD'24)
Expand Down Expand Up @@ -381,7 +383,6 @@ fn get_child_widths(root: usize, expressions: &[Arith], out: &mut Vec<WidthInt>)
debug_assert!(expr.children().is_empty(), "{expr:?}")
}
}

}
}

Expand Down Expand Up @@ -454,7 +455,12 @@ fn extend(
w_in: WidthInt,
signed: bool,
) -> ExprRef {
debug_assert_eq!(expr.get_bv_type(ctx).unwrap(), w_in, "{}", expr.serialize_to_str(ctx));
debug_assert_eq!(
expr.get_bv_type(ctx).unwrap(),
w_in,
"{}",
expr.serialize_to_str(ctx)
);
match w_out.cmp(&w_in) {
Ordering::Less => unreachable!("cannot extend from {w_in} to {w_out}"),
Ordering::Equal => expr,
Expand All @@ -463,8 +469,6 @@ fn extend(
}
}



pub struct ArithRewrite {
name: String,
/// most general lhs pattern
Expand Down Expand Up @@ -551,8 +555,6 @@ impl ArithRewrite {
}
}



type EGraph = egg::EGraph<Arith, ()>;

/// Finds a width or sign constant in the e-class referred to by the substitution
Expand All @@ -570,7 +572,6 @@ fn get_const_width_or_sign(egraph: &EGraph, subst: &Subst, v: Var) -> WidthInt {
.expect("failed to find constant width")
}


/// Checks that input and output widths of operations are consistent.
fn check_width_consistency(pattern: &Pattern<Arith>) {
let exprs = pattern.ast.as_ref();
Expand All @@ -581,12 +582,20 @@ fn check_width_consistency(pattern: &Pattern<Arith>) {
let a_width_id = usize::from(expr.children()[1]);
let a_id = usize::from(expr.children()[3]);
if let Some(a_op_out_width_id) = get_output_width_id(&exprs[a_id]) {
assert_eq!(a_width_id, a_op_out_width_id, "In `{expr}`, subexpression `{}` has inconsistent width: {} != {}", &exprs[a_id], &exprs[a_width_id], &exprs[a_op_out_width_id]);
assert_eq!(
a_width_id, a_op_out_width_id,
"In `{expr}`, subexpression `{}` has inconsistent width: {} != {}",
&exprs[a_id], &exprs[a_width_id], &exprs[a_op_out_width_id]
);
}
let b_width_id = usize::from(expr.children()[4]);
let b_id = usize::from(expr.children()[6]);
if let Some(b_op_out_width_id) = get_output_width_id(&exprs[b_id]) {
assert_eq!(b_width_id, b_op_out_width_id, "In `{expr}`, subexpression `{}` has inconsistent width: {} != {}", &exprs[b_id], &exprs[b_width_id], &exprs[b_op_out_width_id]);
assert_eq!(
b_width_id, b_op_out_width_id,
"In `{expr}`, subexpression `{}` has inconsistent width: {} != {}",
&exprs[b_id], &exprs[b_width_id], &exprs[b_op_out_width_id]
);
}
}
}
Expand Down

0 comments on commit c29f17e

Please sign in to comment.