From daee6e9b4abcc4aad59137320ba9543accd3c479 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Fri, 20 Dec 2024 09:54:05 -0500 Subject: [PATCH] egraph: add custom to_dot --- patronus-egraphs/.gitignore | 2 + patronus-egraphs/Cargo.toml | 3 +- patronus-egraphs/src/arithmetic.rs | 114 +++++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 9 deletions(-) create mode 100644 patronus-egraphs/.gitignore diff --git a/patronus-egraphs/.gitignore b/patronus-egraphs/.gitignore new file mode 100644 index 0000000..705f6e6 --- /dev/null +++ b/patronus-egraphs/.gitignore @@ -0,0 +1,2 @@ +/*.pdf +/*.dot diff --git a/patronus-egraphs/Cargo.toml b/patronus-egraphs/Cargo.toml index fbb2aac..faa928a 100644 --- a/patronus-egraphs/Cargo.toml +++ b/patronus-egraphs/Cargo.toml @@ -13,5 +13,4 @@ rust-version.workspace = true patronus = { path = "../patronus" } egg.workspace = true baa.workspace = true -lazy_static = "1.5.0" -regex = "1.11.1" +rustc-hash.workspace = true diff --git a/patronus-egraphs/src/arithmetic.rs b/patronus-egraphs/src/arithmetic.rs index aa54ba5..4dea176 100644 --- a/patronus-egraphs/src/arithmetic.rs +++ b/patronus-egraphs/src/arithmetic.rs @@ -8,8 +8,10 @@ use egg::{ Rewrite, Subst, Var, }; use patronus::expr::*; +use rustc_hash::FxHashMap; use std::cmp::{max, Ordering}; use std::fmt::{Display, Formatter}; +use std::io::Write; use std::str::FromStr; define_language! { @@ -539,7 +541,10 @@ impl ArithRewrite { let condition = move |egraph: &mut EGraph, _, subst: &Subst| { let values: Vec = vars .iter() - .map(|v| get_const_width_or_sign(egraph, subst[*v])) + .map(|v| { + get_const_width_or_sign(egraph, subst[*v]) + .expect("failed to find constant width") + }) .collect(); cond(values.as_slice()) }; @@ -577,7 +582,7 @@ type EGraph = egg::EGraph; /// Finds a width or sign constant in the e-class referred to by the substitution /// and returns its value. Errors if no such constant can be found. -fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> WidthInt { +fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> Option { egraph[id] .nodes .iter() @@ -585,14 +590,13 @@ fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> WidthInt { Arith::Width(w) => Some((*w).into()), Arith::Sign(s) => Some((*s).into()), Arith::WidthMaxPlus1([a, b]) => { - let a = get_const_width_or_sign(egraph, *a); - let b = get_const_width_or_sign(egraph, *b); + let a = get_const_width_or_sign(egraph, *a).expect("failed to find constant width"); + let b = get_const_width_or_sign(egraph, *b).expect("failed to find constant width"); Some(max(a, b) + 1) } _ => None, }) .next() - .expect("failed to find constant width") } /// Checks that input and output widths of operations are consistent. @@ -651,6 +655,102 @@ pub fn create_egg_rewrites() -> Vec> { .unwrap_or(vec![]) } +fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> { + use std::process::{Command, Stdio}; + let mut child = Command::new("dot") + .args(["-Tpdf", "-o", filename]) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .spawn()?; + let stdin = child.stdin.as_mut().expect("Failed to open stdin"); + write_to_dot(stdin, egraph)?; + match child.wait()?.code() { + Some(0) => Ok(()), + Some(e) => panic!("dot program returned error code {}", e), + None => panic!("dot program was killed by a signal"), + } +} + +/// Reimplements egg's `to_dot` functionality. +/// This is necessary because we do not want to show the Width nodes in the graph, because +/// otherwise it becomes very confusing. +fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> { + writeln!(out, "digraph egraph {{")?; + + // set compound=true to enable edges to clusters + writeln!(out, " compound=true")?; + writeln!(out, " clusterrank=local")?; + + // create a map from e-class id to width + let widths = FxHashMap::from_iter( + egraph + .classes() + .flat_map(|class| get_const_width_or_sign(egraph, class.id).map(|w| (class.id, w))), + ); + + // define all the nodes, clustered by eclass + for class in egraph.classes() { + if !widths.contains_key(&class.id) { + writeln!(out, " subgraph cluster_{} {{", class.id)?; + writeln!(out, " style=dotted")?; + for (i, node) in class.iter().enumerate() { + writeln!(out, " {}.{}[label = \"{}\"]", class.id, i, node)?; + } + writeln!(out, " }}")?; + } + } + + for class in egraph.classes() { + if !widths.contains_key(&class.id) { + for (i_in_class, node) in class.iter().enumerate() { + let mut arg_i = 0; + node.for_each(|child| { + if !widths.contains_key(&child) { + // write the edge to the child, but clip it to the eclass with lhead + let (anchor, label) = dot_edge(arg_i, node.len()); + let child_leader = egraph.find(child); + + if child_leader == class.id { + writeln!( + out, + // {}.0 to pick an arbitrary node in the cluster + " {}.{}{} -> {}.{}:n [lhead = cluster_{}, {}]", + class.id, i_in_class, anchor, class.id, i_in_class, class.id, label + ) + .unwrap(); + } else { + writeln!( + out, + // {}.0 to pick an arbitrary node in the cluster + " {}.{}{} -> {}.0 [lhead = cluster_{}, {}]", + class.id, i_in_class, anchor, child, child_leader, label + ) + .unwrap(); + } + } + arg_i += 1; + }); + } + } + } + + write!(out, "}}") +} + +fn dot_edge(i: usize, len: usize) -> (String, String) { + assert!(i < len); + let s = |s: &str| s.to_string(); + match (len, i) { + (1, 0) => (s(""), s("")), + (2, 0) => (s(":sw"), s("")), + (2, 1) => (s(":se"), s("")), + (3, 0) => (s(":sw"), s("")), + (3, 1) => (s(":s"), s("")), + (3, 2) => (s(":se"), s("")), + (_, _) => (s(""), format!("label={}", i)), + } +} + #[cfg(test)] mod tests { use super::*; @@ -713,9 +813,9 @@ mod tests { let spec_class = runner.roots[0]; let impl_class = runner.roots[1]; - // println!("{spec_class} {impl_class}"); + println!("{spec_class} {impl_class}"); - // runner.egraph.dot().to_pdf("graph.pdf").unwrap(); + // to_pdf("graph.pdf", &runner.egraph).unwrap(); } #[test]