From 9f06a409e770098f8c71f9779b08899cad95df4b Mon Sep 17 00:00:00 2001 From: DavePearce Date: Mon, 22 Jul 2024 14:58:07 +1200 Subject: [PATCH] Fix for `raise_ifs` This was not correctly raising `if` expressions for certain expression forms. This also adds relevant tests as well. --- src/transformer/ifs.rs | 20 ++++++++++---------- tests/issue219_a.lisp | 5 +++++ tests/issue219_b.lisp | 4 ++++ tests/issue219_c.lisp | 4 ++++ tests/issue219_d.lisp | 4 ++++ tests/models.rs | 36 ++++++++++++++++++++++++++++++++++++ 6 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 tests/issue219_a.lisp create mode 100644 tests/issue219_b.lisp create mode 100644 tests/issue219_c.lisp create mode 100644 tests/issue219_d.lisp diff --git a/src/transformer/ifs.rs b/src/transformer/ifs.rs index 7f449a1..fb3c3f5 100644 --- a/src/transformer/ifs.rs +++ b/src/transformer/ifs.rs @@ -126,7 +126,11 @@ fn raise_ifs(mut e: Node) -> Node { .fold(true, |b, e| b && !matches!(e.e(), Expression::Void))); // match func { - Intrinsic::Add + Intrinsic::Neg + | Intrinsic::Inv + | Intrinsic::Normalize + | Intrinsic::Exp + | Intrinsic::Add | Intrinsic::Sub | Intrinsic::Mul | Intrinsic::VectorAdd @@ -147,7 +151,7 @@ fn raise_ifs(mut e: Node) -> Node { let new_then = func.unchecked_call(&then_args).unwrap(); let mut new_args = vec![cond, new_then]; // Pull out false branch (if applicable): - // (func a b (if cond then else) c) + // (func a b (if cond c d) e) // ==> (if !cond (func a b d e)) if let Some(arg_else) = args_if.get(2).cloned() { let mut else_args = args.clone(); @@ -163,13 +167,7 @@ fn raise_ifs(mut e: Node) -> Node { } e } - Intrinsic::IfZero - | Intrinsic::IfNotZero - | Intrinsic::Neg - | Intrinsic::Inv - | Intrinsic::Normalize - | Intrinsic::Exp - | Intrinsic::Begin => e, + Intrinsic::IfZero | Intrinsic::IfNotZero | Intrinsic::Begin => e, } } Expression::List(xs) => { @@ -330,7 +328,9 @@ pub fn expand_ifs(cs: &mut ConstraintSet) { // Raise ifs for c in cs.constraints.iter_mut() { if let Constraint::Vanishes { expr, .. } = c { - *expr = Box::new(raise_ifs(*expr.clone())); + let nexpr = raise_ifs(*expr.clone()); + // Replace old expression with new + *expr = Box::new(nexpr); } } for c in cs.constraints.iter_mut() { diff --git a/tests/issue219_a.lisp b/tests/issue219_a.lisp new file mode 100644 index 0000000..a84979e --- /dev/null +++ b/tests/issue219_a.lisp @@ -0,0 +1,5 @@ +(defcolumns X) + +(defconstraint Constraint () + (neq! (if (is-zero 1) X X) X) +) diff --git a/tests/issue219_b.lisp b/tests/issue219_b.lisp new file mode 100644 index 0000000..92348b1 --- /dev/null +++ b/tests/issue219_b.lisp @@ -0,0 +1,4 @@ +(defcolumns ST X) + +(defconstraint constraint-test () + (if-not-zero ST (is-not-zero! (if (vanishes! X) 1 1)))) diff --git a/tests/issue219_c.lisp b/tests/issue219_c.lisp new file mode 100644 index 0000000..4b90f6c --- /dev/null +++ b/tests/issue219_c.lisp @@ -0,0 +1,4 @@ +(defcolumns ST X) + +(defconstraint constraint-test () + (if-not-zero ST (vanishes! (- 1 (if (vanishes! X) 1 1))))) diff --git a/tests/issue219_d.lisp b/tests/issue219_d.lisp new file mode 100644 index 0000000..73265db --- /dev/null +++ b/tests/issue219_d.lisp @@ -0,0 +1,4 @@ +(defcolumns ST X Y) + +(defconstraint constraint-test () + (if-not-zero ST (vanishes! (- 1 (if (vanishes! X) Y 1))))) diff --git a/tests/models.rs b/tests/models.rs index f3e63da..bf2128e 100644 --- a/tests/models.rs +++ b/tests/models.rs @@ -240,6 +240,26 @@ static MODELS: &[Model] = &[ cols: &["ST", "X"], oracle: Some(issue241_b_oracle), }, + Model { + name: "issue219_a", + cols: &["X"], + oracle: Some(|_| false), + }, + Model { + name: "issue219_b", + cols: &["ST", "X"], + oracle: Some(|_| true), + }, + Model { + name: "issue219_c", + cols: &["ST", "X"], + oracle: Some(|_| true), + }, + Model { + name: "issue219_d", + cols: &["ST", "X", "Y"], + oracle: Some(issue219_d_oracle), + }, ]; // =================================================================== @@ -350,3 +370,19 @@ fn issue241_b_oracle(tr: &Trace) -> bool { } true } + +// =================================================================== +// Issue 219 +// =================================================================== + +#[allow(non_snake_case)] +fn issue219_d_oracle(tr: &Trace) -> bool { + let (ST, X, Y) = (tr.col("ST"), tr.col("X"), tr.col("Y")); + + for k in 0..tr.height() { + if ST[k] != 0 && X[k] == 0 && Y[k] != 1 { + return false; + } + } + true +}