Skip to content

Commit

Permalink
egraph: add custom to_dot
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 20, 2024
1 parent 5b11c4a commit daee6e9
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 9 deletions.
2 changes: 2 additions & 0 deletions patronus-egraphs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/*.pdf
/*.dot
3 changes: 1 addition & 2 deletions patronus-egraphs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,4 @@ rust-version.workspace = true
patronus = { path = "../patronus" }
egg.workspace = true
baa.workspace = true
lazy_static = "1.5.0"
regex = "1.11.1"
rustc-hash.workspace = true
114 changes: 107 additions & 7 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use egg::{
Rewrite, Subst, Var,
};
use patronus::expr::*;
use rustc_hash::FxHashMap;
use std::cmp::{max, Ordering};
use std::fmt::{Display, Formatter};
use std::io::Write;
use std::str::FromStr;

define_language! {
Expand Down Expand Up @@ -539,7 +541,10 @@ impl ArithRewrite {
let condition = move |egraph: &mut EGraph, _, subst: &Subst| {
let values: Vec<WidthInt> = vars
.iter()
.map(|v| get_const_width_or_sign(egraph, subst[*v]))
.map(|v| {
get_const_width_or_sign(egraph, subst[*v])
.expect("failed to find constant width")
})
.collect();
cond(values.as_slice())
};
Expand Down Expand Up @@ -577,22 +582,21 @@ type EGraph = egg::EGraph<Arith, ()>;

/// Finds a width or sign constant in the e-class referred to by the substitution
/// and returns its value. Errors if no such constant can be found.
fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> WidthInt {
fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> Option<WidthInt> {
egraph[id]
.nodes
.iter()
.flat_map(|n| match n {
Arith::Width(w) => Some((*w).into()),
Arith::Sign(s) => Some((*s).into()),
Arith::WidthMaxPlus1([a, b]) => {
let a = get_const_width_or_sign(egraph, *a);
let b = get_const_width_or_sign(egraph, *b);
let a = get_const_width_or_sign(egraph, *a).expect("failed to find constant width");
let b = get_const_width_or_sign(egraph, *b).expect("failed to find constant width");
Some(max(a, b) + 1)
}
_ => None,
})
.next()
.expect("failed to find constant width")
}

/// Checks that input and output widths of operations are consistent.
Expand Down Expand Up @@ -651,6 +655,102 @@ pub fn create_egg_rewrites() -> Vec<Rewrite<Arith, ()>> {
.unwrap_or(vec![])
}

fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
use std::process::{Command, Stdio};
let mut child = Command::new("dot")
.args(["-Tpdf", "-o", filename])
.stdin(Stdio::piped())
.stdout(Stdio::null())
.spawn()?;
let stdin = child.stdin.as_mut().expect("Failed to open stdin");
write_to_dot(stdin, egraph)?;
match child.wait()?.code() {
Some(0) => Ok(()),
Some(e) => panic!("dot program returned error code {}", e),
None => panic!("dot program was killed by a signal"),
}
}

/// Reimplements egg's `to_dot` functionality.
/// This is necessary because we do not want to show the Width nodes in the graph, because
/// otherwise it becomes very confusing.
fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
writeln!(out, "digraph egraph {{")?;

// set compound=true to enable edges to clusters
writeln!(out, " compound=true")?;
writeln!(out, " clusterrank=local")?;

// create a map from e-class id to width
let widths = FxHashMap::from_iter(
egraph
.classes()
.flat_map(|class| get_const_width_or_sign(egraph, class.id).map(|w| (class.id, w))),
);

// define all the nodes, clustered by eclass
for class in egraph.classes() {
if !widths.contains_key(&class.id) {
writeln!(out, " subgraph cluster_{} {{", class.id)?;
writeln!(out, " style=dotted")?;
for (i, node) in class.iter().enumerate() {
writeln!(out, " {}.{}[label = \"{}\"]", class.id, i, node)?;
}
writeln!(out, " }}")?;
}
}

for class in egraph.classes() {
if !widths.contains_key(&class.id) {
for (i_in_class, node) in class.iter().enumerate() {
let mut arg_i = 0;
node.for_each(|child| {
if !widths.contains_key(&child) {
// write the edge to the child, but clip it to the eclass with lhead
let (anchor, label) = dot_edge(arg_i, node.len());
let child_leader = egraph.find(child);

if child_leader == class.id {
writeln!(
out,
// {}.0 to pick an arbitrary node in the cluster
" {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]",
class.id, i_in_class, anchor, class.id, i_in_class, class.id, label
)
.unwrap();
} else {
writeln!(
out,
// {}.0 to pick an arbitrary node in the cluster
" {}.{}{} -> {}.0 [lhead = cluster_{}, {}]",
class.id, i_in_class, anchor, child, child_leader, label
)
.unwrap();
}
}
arg_i += 1;
});
}
}
}

write!(out, "}}")
}

fn dot_edge(i: usize, len: usize) -> (String, String) {
assert!(i < len);
let s = |s: &str| s.to_string();
match (len, i) {
(1, 0) => (s(""), s("")),
(2, 0) => (s(":sw"), s("")),
(2, 1) => (s(":se"), s("")),
(3, 0) => (s(":sw"), s("")),
(3, 1) => (s(":s"), s("")),
(3, 2) => (s(":se"), s("")),
(_, _) => (s(""), format!("label={}", i)),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -713,9 +813,9 @@ mod tests {

let spec_class = runner.roots[0];
let impl_class = runner.roots[1];
// println!("{spec_class} {impl_class}");
println!("{spec_class} {impl_class}");

// runner.egraph.dot().to_pdf("graph.pdf").unwrap();
// to_pdf("graph.pdf", &runner.egraph).unwrap();
}

#[test]
Expand Down

0 comments on commit daee6e9

Please sign in to comment.