Skip to content

Commit

Permalink
Fix: Incorrect Handling of Perspective Selectors (#204)
Browse files Browse the repository at this point in the history
The problem was that lists were not being raised correctly out of
expressions and, in particular, this was causing a problem with
perspective selectors.
  • Loading branch information
DavePearce authored Jun 18, 2024
1 parent 918e95f commit 7a8d41b
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 91 deletions.
2 changes: 1 addition & 1 deletion src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ impl ColumnSet {
}

pub fn maybe_insert_column(&mut self, column: Column) -> Option<ColumnRef> {
if let Some(id) = self.cols.get(&column.handle) {
if let Some(_) = self.cols.get(&column.handle) {
None
} else {
let id = self._cols.len();
Expand Down
3 changes: 2 additions & 1 deletion src/compiler/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ impl FuncVerifier<Node> for Intrinsic {
| Intrinsic::VectorAdd
| Intrinsic::VectorSub
| Intrinsic::VectorMul => {
for (i, arg) in args.iter().enumerate() {
for (_, arg) in args.iter().enumerate() {
if arg.is_list() {
bail!("unexpected list operand for {}", self.to_string())
}
Expand Down Expand Up @@ -1734,6 +1734,7 @@ pub(crate) fn reduce_toplevel(
// Perspectives are just multiplicative coefficients, and are
// controlled exceptions to the usual loobean typing rules
let body_type = body.t();
println!("APPLYING PERSPECTIVE");
Intrinsic::Mul
.unchecked_call(&[persp_guard, body])
.with_context(|| anyhow!("constraint {}", name))?
Expand Down
1 change: 0 additions & 1 deletion src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use itertools::Itertools;
use log::*;
use logging_timer::time;
use owo_colors::OwoColorize;
use rayon::prelude::*;
use std::{cmp::Ordering, collections::HashSet};

use crate::{
Expand Down
36 changes: 0 additions & 36 deletions src/exporters/wizardiop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,42 +278,6 @@ fn reg_mangle_ith(cs: &ConstraintSet, c: &ColumnRef, i: usize) -> Result<String>
.unwrap_or_else(|| Handle::new("", format!("{}_#{}", reg_id, i)).mangle()))
}

fn reg(cs: &ConstraintSet, c: &Handle) -> Result<Handle> {
let reg_id = cs
.columns
.by_handle(c)?
.register
.ok_or_else(|| anyhow!("column {} has no backing register", c.pretty()))?;
let reg = &cs
.columns
.registers
.get(reg_id)
.ok_or_else(|| anyhow!("register {} for column {} does not exist", reg_id, c))?;
Ok(reg
.handle
.as_ref()
.cloned()
.unwrap_or_else(|| Handle::new(&c.module, reg_id.to_string())))
}

fn reg_splatter(cs: &ConstraintSet, c: &Handle, i: usize) -> Result<Handle> {
let reg_id = cs
.columns
.by_handle(c)?
.register
.ok_or_else(|| anyhow!("column {} has no backing register", c.pretty()))?;
let reg = &cs
.columns
.registers
.get(reg_id)
.ok_or_else(|| anyhow!("register {} for column {} does not exist", reg_id, c))?;
Ok(reg
.handle
.as_ref()
.map(|h| h.iota(i))
.unwrap_or_else(|| Handle::new(&c.module, reg_id.to_string())))
}

#[derive(Serialize, Debug)]
struct WiopColumn {
go_id: String,
Expand Down
8 changes: 0 additions & 8 deletions src/transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,6 @@ fn expression_to_name(e: &Node, prefix: &str) -> String {
format!("C/{}[{}]", prefix, e)
}

/// Wraps `ex` into a `List` if it is not already one.
fn wrap(ex: Node) -> Node {
match ex.e() {
Expression::List(_) => ex,
_ => Node::from_expr(Expression::List(vec![ex])),
}
}

fn flatten_list(mut e: Node) -> Node {
match e.e_mut() {
Expression::List(ref mut xs) => {
Expand Down
186 changes: 147 additions & 39 deletions src/transformer/ifs.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use anyhow::Result;
use num_traits::Zero;

use crate::compiler::{Conditioning, Constraint, ConstraintSet, Expression, Intrinsic, Node};
use crate::compiler::{Constraint, ConstraintSet, Expression, Intrinsic, Node};

use super::{flatten_list, wrap};
use super::{flatten_list};

/// Expand if conditions, assuming they are roughly in "top-most"
/// positions. That is, we can have arbitrary nested if `List` and
Expand Down Expand Up @@ -36,11 +36,6 @@ fn do_expand_ifs(e: &mut Node) -> Result<()> {
if matches!(func, Intrinsic::IfZero | Intrinsic::IfNotZero) {
let cond = args[0].clone();
let if_not_zero = matches!(func, Intrinsic::IfNotZero);
assert!(if if_not_zero {
matches!(cond.t().c(), Conditioning::Boolean | Conditioning::None)
} else {
matches!(cond.t().c(), Conditioning::Loobean | Conditioning::None)
});

// If the condition reduces to a constant, we can determine the result
if let Ok(constant_cond) = cond.pure_eval() {
Expand All @@ -58,50 +53,37 @@ fn do_expand_ifs(e: &mut Node) -> Result<()> {
}
}
} else {
// Construct condition for then branch, and
// condition for else branch.
let conds = {
// Multiplier for if-non-zero branch.
let cond_not_zero = cond.clone();
// Multiplier for if-zero branch.
let cond_zero = Intrinsic::Sub.unchecked_call(&[
Node::one(),
Intrinsic::Normalize.unchecked_call(&[cond.clone()])?,
])?;
// Set ordering based on function itself.
if if_not_zero {
[cond_not_zero, cond_zero]
} else {
[cond_zero, cond_not_zero]
}
};

// Order the then/else blocks
let then_else = vec![args.get(1), args.get(2)]
.into_iter()
.enumerate()
// Only keep the non-empty branches
.filter_map(|(i, ex)| ex.map(|ex| (i, ex)))
// Ensure branches are wrapped in lists
.map(|(i, ex)| (i, wrap(ex.clone())))
// Map the corresponding then/else operations on the branches
.flat_map(|(i, exs)| {
if let Expression::List(exs) = exs.e() {
exs.iter()
.map(|ex: &Node| {
ex.flat_map(&|e| {
Intrinsic::Mul
.unchecked_call(&[conds[i].clone(), e.clone()])
.unwrap()
})
})
.collect::<Vec<_>>()
} else {
unreachable!()
}
})
.flatten()
.collect::<Vec<_>>();
*e = if then_else.len() == 1 {
then_else[0].clone()
} else {
Node::from_expr(Expression::List(then_else))
}
// Apply condition to body.
let then_else : Node = match (args.get(1),args.get(2)) {
(Some(e), None) => {
let then_cond = conds[0].clone();
Intrinsic::Mul.unchecked_call(&[then_cond, e.clone()]).unwrap()
}
(None, Some(e)) => {
let else_cond = conds[1].clone();
Intrinsic::Mul.unchecked_call(&[else_cond, e.clone()]).unwrap()
}
(_,_) => unreachable!()
};
// Finally, replace existing node.
*e = then_else.clone();
};
}
}
Expand Down Expand Up @@ -196,6 +178,114 @@ fn raise_ifs(mut e: Node) -> Node {
}
}

/// Pull `lists` out of nested positions and into top-most
/// positions. Specifically, something like this:
///
/// ```lisp
/// (defconstraint test () (if A (begin B C)))
/// ```
///
/// Has the nested `list` raised into the following position:
///
/// ```lisp
/// (defconstraint test () (begin (if A B) (if A C)))
/// ```
///
/// The purpose of this is to sanitize the structure of expressions
/// conditions to make their subsequent translation easier.
fn raise_lists(node: &Node) -> Vec<Node> {
match node.e() {
Expression::List(xs) => {
let mut exprs = Vec::new();
for x in xs {
exprs.extend(raise_lists(x));
}
exprs
}
Expression::Funcall { func, args } if args.len() > 0 => {
match func {
Intrinsic::IfZero if args.len() > 2 => {
let mut out = Vec::new();
// if-then
raise_binary(&args[0],&args[1],func,&mut out);
// if-else
raise_binary(&args[0],&args[2],&Intrinsic::IfNotZero,&mut out);
// done
out
}
Intrinsic::IfNotZero if args.len() > 2 => {
let mut out = Vec::new();
// if-then
raise_binary(&args[0],&args[1],func,&mut out);
// if-else
raise_binary(&args[0],&args[2],&Intrinsic::IfZero,&mut out);
// done
out
}
Intrinsic::Begin => unreachable!(),
_ => {
// More challenging because we have to compute the cross
// product.
let mut out = Vec::new();
raise_intrinsic(args,func,&mut out,&mut Vec::new());
out
}
}
}
_ => vec![node.clone()],
}
}

/// Enumerate all atomic invocations of this intrinsic by expanding
/// the cross-product of all arguments. To understand this, consider:
///
/// ```lisp
/// (* (begin A B) (begin X Y))
/// ```
///
/// This is considered "non-atomic" because it contains lists within.
/// This is expanded into the following distinct invocations:
///
/// ```lisp
/// (* A X)
/// (* B X)
/// (* A Y)
/// (* B Y)
/// ```
///
/// This method is responsible for enumerating the argument
/// combinations.
fn raise_intrinsic(args: &[Node], f: &Intrinsic, out: &mut Vec<Node>, acc: &mut Vec<Node>) {
let n = acc.len();
//
if n == args.len() {
out.push(Node::from_expr(f.raw_call(acc)));
} else {
// Raise nth expression
let raised_args = raise_lists(&args[n]);
// Continue
for e in raised_args {
acc.push(e);
raise_intrinsic(args,f,out,acc);
acc.pop();
}
}
// Done
}

/// Special case of `raise_intrinsic` for binary operands.
fn raise_binary(lhs: &Node, rhs: &Node, f: &Intrinsic, out: &mut Vec<Node>) {
let raised_lhs = raise_lists(lhs);
let raised_rhs = raise_lists(rhs);
// Simple cross product
for l in raised_lhs {
for r in &raised_rhs {
let l_r_expr = f.raw_call(&[l.clone(),r.clone()]);
out.push(Node::from_expr(l_r_expr));
}
}
}

/// Responsible for lowering `if` expressions into a multiplication
/// over the normalised condition. For example, this constraint:
///
Expand All @@ -216,6 +306,24 @@ fn raise_ifs(mut e: Node) -> Node {
/// it is evaluated at compile time and the entire `if` expression is
/// eliminated.
pub fn expand_ifs(cs: &mut ConstraintSet) {
// Raise lists
for c in cs.constraints.iter_mut() {
if let Constraint::Vanishes { expr, .. } = c {
let mut exprs = raise_lists(&*expr);
// Construct new expression
let nexpr = if exprs.len() == 1 {
// Optimise case where only a single expression, as we
// don't need a list in this case.
exprs.pop().unwrap()
} else {
// When there are multiple expressions, use a list.
Node::from_expr(Expression::List(exprs))
};
// Replace old expression with new
*expr = Box::new(nexpr);
}
}
// Raise ifs
for c in cs.constraints.iter_mut() {
if let Constraint::Vanishes { expr, .. } = c {
*expr = Box::new(raise_ifs(*expr.clone()));
Expand Down
6 changes: 1 addition & 5 deletions src/transformer/nhood.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
use anyhow::{bail, Result};
use owo_colors::OwoColorize;
use std::collections::HashMap;

use crate::{
column::{Column, Computation},
compiler::{
ColumnRef, Constraint, ConstraintSet, Domain, Intrinsic, Kind, Magma, Node, RawMagma,
ColumnRef, Constraint, ConstraintSet, Intrinsic, Kind, Node, RawMagma,
},
pretty::Base,
structs::Handle,
};

Expand Down

0 comments on commit 7a8d41b

Please sign in to comment.