Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate between EGraph types #306

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 209 additions & 1 deletion src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::*;
use std::{
borrow::BorrowMut,
fmt::{self, Debug, Display},
marker::PhantomData,
};

#[cfg(feature = "serde-1")]
Expand Down Expand Up @@ -532,7 +533,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}

/// Creates a [`Dot`] to visualize this egraph. See [`Dot`].
///
pub fn dot(&self) -> Dot<L, N> {
Dot {
egraph: self,
Expand All @@ -542,6 +542,214 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

/// Translates `EGraph<L, A>` into `EGraph<L2, A2>`. For common cases, you don't
/// need to implement this manually. See the provided [`SimpleLanguageMapper`].
pub trait LanguageMapper<L, A>
where
L: Language,
A: Analysis<L>,
{
/// The target language to translate into.
type L2: Language;

/// The target analysis to transate into.
type A2: Analysis<Self::L2>;

/// Translate a node of `L` into a node of `L2`.
fn map_node(&self, node: L) -> Self::L2;

/// Translate `L::Discriminant` into `L2::Discriminant`
fn map_discriminant(
&self,
discriminant: L::Discriminant,
) -> <Self::L2 as Language>::Discriminant;

/// Translate an analysis of type `A` into an analysis of `A2`.
fn map_analysis(&self, analysis: A) -> Self::A2;

/// Translate `A::Data` into `A2::Data`.
fn map_data(&self, data: A::Data) -> <Self::A2 as Analysis<Self::L2>>::Data;

/// Translate an [`EClass`] over `L` into an [`EClass`] over `L2`.
fn map_eclass(
&self,
src_eclass: EClass<L, A::Data>,
) -> EClass<Self::L2, <Self::A2 as Analysis<Self::L2>>::Data> {
EClass {
id: src_eclass.id,
nodes: src_eclass
.nodes
.into_iter()
.map(|l| self.map_node(l))
.collect(),
data: self.map_data(src_eclass.data),
parents: src_eclass
.parents
.into_iter()
.map(|(l, id)| (self.map_node(l), id))
.collect(),
}
}

/// Map an `EGraph` over `L` into an `EGraph` over `L2`.
fn map_egraph(&self, src_egraph: EGraph<L, A>) -> EGraph<Self::L2, Self::A2> {
let kv_map = |(k, v): (L, Id)| (self.map_node(k), v);
EGraph {
analysis: self.map_analysis(src_egraph.analysis),
explain: None,
unionfind: src_egraph.unionfind,
memo: src_egraph.memo.into_iter().map(kv_map).collect(),
pending: src_egraph.pending.into_iter().map(kv_map).collect(),
analysis_pending: src_egraph
.analysis_pending
.into_iter()
.map(kv_map)
.collect(),
classes: src_egraph
.classes
.into_iter()
.map(|(id, eclass)| (id, self.map_eclass(eclass)))
.collect(),
classes_by_op: src_egraph
.classes_by_op
.into_iter()
.map(|(k, v)| (self.map_discriminant(k), v))
.collect(),
clean: src_egraph.clean,
}
}
}

/// An implementation of [`LanguageMapper`] that can convert an [`EGraph`] over one
/// language into an [`EGraph`] over a different language in common cases.
///
/// Specifically, you can use this if have
/// [`conversion`](https://doc.rust-lang.org/1.76.0/core/convert/index.html)
/// implemented between your source and target language, as well as your source and
/// target analysis.
///
/// Here is an example of how to use this. Consider a case where you have a newtype
/// wrapper over an existing language type:
///
/// ```rust
/// use egg::*;
///
/// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// struct MyLang(SymbolLang);
/// # impl Language for MyLang {
/// # type Discriminant = <SymbolLang as Language>::Discriminant;
/// #
/// # fn matches(&self, other: &Self) -> bool {
/// # self.0.matches(&other.0)
/// # }
/// #
/// # fn children(&self) -> &[Id] {
/// # self.0.children()
/// # }
/// #
/// # fn children_mut(&mut self) -> &mut [Id] {
/// # self.0.children_mut()
/// # }
/// #
/// # fn discriminant(&self) -> Self::Discriminant {
/// # self.0.discriminant()
/// # }
/// # }
///
/// // some external library function
/// pub fn external(egraph: EGraph<SymbolLang, ()>) { }
///
/// fn do_thing(egraph: EGraph<MyLang, ()>) {
/// // how do I call external?
/// external(todo!())
/// }
/// ```
///
/// By providing an implementation of `From<MyLang> for SymbolLang`, we can
/// construct `SimpleLanguageMapper` and use it to translate our [`EGraph`] into the
/// right type.
///
/// ```rust
/// # use egg::*;
/// # #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// # struct MyLang(SymbolLang);
/// # impl Language for MyLang {
/// # type Discriminant = <SymbolLang as Language>::Discriminant;
/// #
/// # fn matches(&self, other: &Self) -> bool {
/// # self.0.matches(&other.0)
/// # }
/// #
/// # fn children(&self) -> &[Id] {
/// # self.0.children()
/// # }
/// #
/// # fn children_mut(&mut self) -> &mut [Id] {
/// # self.0.children_mut()
/// # }
/// #
/// # fn discriminant(&self) -> Self::Discriminant {
/// # self.0.discriminant()
/// # }
/// # }
/// # pub fn external(egraph: EGraph<SymbolLang, ()>) { }
/// impl From<MyLang> for SymbolLang {
/// fn from(value: MyLang) -> Self {
/// value.0
/// }
/// }
///
/// fn do_thing(egraph: EGraph<MyLang, ()>) {
/// external(SimpleLanguageMapper::default().map_egraph(egraph))
/// }
/// ```
///
/// Note that we do not need to provide any conversion for the analysis, because it
/// is the same in both source and target e-graphs.
pub struct SimpleLanguageMapper<L2, A2> {
_phantom: PhantomData<(L2, A2)>,
}

impl<L, A> Default for SimpleLanguageMapper<L, A> {
fn default() -> Self {
SimpleLanguageMapper {
_phantom: PhantomData::default(),
}
}
}

impl<L, A, L2, A2> LanguageMapper<L, A> for SimpleLanguageMapper<L2, A2>
where
L: Language,
A: Analysis<L>,
L2: Language + From<L>,
A2: Analysis<L2> + From<A>,
<L2 as Language>::Discriminant: From<<L as Language>::Discriminant>,
<A2 as Analysis<L2>>::Data: From<<A as Analysis<L>>::Data>,
{
type L2 = L2;
type A2 = A2;

fn map_node(&self, node: L) -> Self::L2 {
node.into()
}

fn map_discriminant(
&self,
discriminant: <L as Language>::Discriminant,
) -> <Self::L2 as Language>::Discriminant {
discriminant.into()
}

fn map_analysis(&self, analysis: A) -> Self::A2 {
analysis.into()
}

fn map_data(&self, data: <A as Analysis<L>>::Data) -> <Self::A2 as Analysis<Self::L2>>::Data {
data.into()
}
}

/// Given an `Id` using the `egraph[id]` syntax, retrieve the e-class.
impl<L: Language, N: Analysis<L>> std::ops::Index<Id> for EGraph<L, N> {
type Output = EClass<L, N::Data>;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub(crate) use {explain::Explain, unionfind::UnionFind};
pub use {
dot::Dot,
eclass::EClass,
egraph::EGraph,
egraph::{EGraph, LanguageMapper, SimpleLanguageMapper},
explain::{
Explanation, FlatExplanation, FlatTerm, Justification, TreeExplanation, TreeTerm,
UnionEqualities,
Expand Down
28 changes: 27 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt;
use std::{fmt, iter::FromIterator};
use symbolic_expressions::Sexp;

use fmt::{Debug, Display, Formatter};
Expand Down Expand Up @@ -172,3 +172,29 @@ where
r
}
}

impl<T> IntoIterator for UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
type Item = T;

type IntoIter = <std::collections::VecDeque<T> as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.queue.into_iter()
}
}

impl<A> FromIterator<A> for UniqueQueue<A>
where
A: Eq + std::hash::Hash + Clone,
{
fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
let mut queue = UniqueQueue::default();
for t in iter {
queue.insert(t);
}
queue
}
}
Loading