Skip to content

Commit

Permalink
smt: implement get_value
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 13, 2024
1 parent 511e0e9 commit 3681e3c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 50 deletions.
152 changes: 104 additions & 48 deletions patronus/src/smt/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
use crate::expr::{ArrayType, Context, ExprRef, Type, TypeCheck, WidthInt};
use regex::bytes::RegexSet;
use rustc_hash::FxHashMap;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum SmtParserError {
#[error("[smt] get-value response: {0}")]
GetValueResponse(String),
#[error("[smt] expected an expression but got: {0}")]
ExpectedExpr(String),
#[error("[smt] expected a type but got: {0}")]
Expand Down Expand Up @@ -47,40 +50,121 @@ type Result<T> = std::result::Result<T, SmtParserError>;
type SymbolTable = FxHashMap<String, ExprRef>;

pub fn parse_expr(ctx: &mut Context, st: &mut SymbolTable, input: &[u8]) -> Result<ExprRef> {
let lexer = Lexer::new(input);
let (expr, extra_closing_parens) = parse_expr_internal(ctx, st, lexer)?;
if extra_closing_parens == 0 {
Ok(expr)
} else {
Err(SmtParserError::MissingOpen)
}
}

fn parse_expr_internal(
ctx: &mut Context,
st: &mut SymbolTable,
lexer: Lexer,
) -> Result<(ExprRef, u64)> {
use ParserItem::*;
let mut stack: Vec<ParserItem> = Vec::with_capacity(64);
let lexer = Lexer::new(input);
// keep track of how many closing parenthesis without an opening one are encountered
let mut orphan_closing_count = 0u64;
for token in lexer {
match token {
Token::Open => {
if orphan_closing_count > 0 {
return Err(SmtParserError::MissingOpen);
}
stack.push(Open);
}
Token::Close => {
// find the closest Open
let open_pos = match stack.iter().rev().position(|i| matches!(i, Open)) {
Some(p) => stack.len() - 1 - p,
None => return Err(SmtParserError::MissingOpen),
};
let pattern = &stack[open_pos + 1..];
let result = parse_pattern(ctx, st, pattern)?;
stack.truncate(open_pos);
stack.push(result);
if let Some(p) = stack.iter().rev().position(|i| matches!(i, Open)) {
let open_pos = stack.len() - 1 - p;
let pattern = &stack[open_pos + 1..];
let result = parse_pattern(ctx, st, pattern)?;
stack.truncate(open_pos);
stack.push(result);
} else {
// no matching open parenthesis
orphan_closing_count += 1;
}
}
Token::Value(value) => {
if orphan_closing_count > 0 {
return Err(SmtParserError::MissingOpen);
}
// we eagerly parse number literals, but we do not make decisions on symbols yet
stack.push(early_parse_number_literals(ctx, value)?);
}
Token::EscapedValue(value) => stack.push(PExpr(lookup_sym(st, value)?)),
Token::EscapedValue(value) => {
if orphan_closing_count > 0 {
return Err(SmtParserError::MissingOpen);
}
stack.push(PExpr(lookup_sym(st, value)?))
}
}
}

if let [PExpr(e)] = stack.as_slice() {
Ok(*e)
Ok((*e, orphan_closing_count))
} else {
todo!("error message!")
}
}

/// Extracts the value expression from SMT solver responses of the form ((... value))
/// We expect value to be self contained and thus no symbol table should be necessary.
pub fn parse_get_value_response(ctx: &mut Context, input: &[u8]) -> Result<ExprRef> {
let mut lexer = Lexer::new(input);

// skip `(`
let open_one = lexer.next() == Some(Token::Open);
let open_two = lexer.next() == Some(Token::Open);
if !open_one || !open_two {
return Err(SmtParserError::GetValueResponse(
"expected two opening parentheses".to_string(),
));
}

// skip next expr
if !skip_expr(&mut lexer) {
return Err(SmtParserError::GetValueResponse(
"failed to find first expression".to_string(),
));
}

// parse next expr
let mut st = FxHashMap::default();
let (expr, extra_closing_parens) = parse_expr_internal(ctx, &mut st, lexer)?;
match extra_closing_parens.cmp(&2) {
Ordering::Less => Err(SmtParserError::GetValueResponse(
"expected two closing parentheses".to_string(),
)),
Ordering::Equal => Ok(expr),
Ordering::Greater => Err(SmtParserError::MissingOpen),
}
}

fn skip_expr(lexer: &mut Lexer) -> bool {
let mut open_count = 0u64;
while let Some(token) = lexer.next() {
match token {
Token::Open => {
open_count += 1;
}
Token::Close => {
open_count -= 1;
if open_count == 0 {
return true;
}
}
_ => return true,
}
}
// reached end of tokens
false
}

fn lookup_sym(st: &SymbolTable, name: &[u8]) -> Result<ExprRef> {
let name = std::str::from_utf8(name)?;
match st.get(name) {
Expand Down Expand Up @@ -137,14 +221,6 @@ fn expr(st: &SymbolTable, item: &ParserItem<'_>) -> Result<ExprRef> {
}
}

/// errors if the item cannot be directly converted to a type
fn tpe(item: &ParserItem<'_>) -> Result<Type> {
match item {
ParserItem::PType(t) => Ok(*t),
other => Err(SmtParserError::ExpectedType(format!("{other:?}"))),
}
}

fn early_parse_number_literals<'a>(ctx: &mut Context, value: &'a [u8]) -> Result<ParserItem<'a>> {
if let Some(match_id) = NUM_LIT_REGEX.matches(value).into_iter().next() {
match match_id {
Expand Down Expand Up @@ -195,34 +271,6 @@ impl<'a> Debug for ParserItem<'a> {
}
}

const RESERVED_WORDS: &[&str] = &[
"_",
"!",
"as",
"let",
"exists",
"forall",
"match",
"par",
"BINARY",
"DECIMAL",
"HEXADECIMAL",
"NUMERAL",
"STRING",
];

const BV_LIB_SYMBOL: &[&str] = &[
"BitVec", "concat", "extract", // op1 is from:
"bvnot", "bvneg", // op2 is from:
"bvand", "bvor", "bvadd", "bvmul", "bvudiv", "bvurem", "bvshl", "bvlshr", //
"bvult",
];

const COMMANDS: &[&str] = &[
"assert",
// TODO
];

lazy_static! {
static ref NUM_LIT_REGEX: RegexSet = RegexSet::new([
r"^#b[01]+$", // binary
Expand All @@ -241,6 +289,7 @@ struct Lexer<'a> {
pos: usize,
}

#[derive(Eq, PartialEq)]
enum Token<'a> {
Open,
Close,
Expand Down Expand Up @@ -358,10 +407,17 @@ mod tests {
let mut ctx = Context::default();
let a = ctx.bv_symbol("a", 2);
let mut symbols = FxHashMap::from_iter([("a".to_string(), a)]);
let expr = parse_expr(&mut ctx, &mut symbols, &mut "(bvand a #b00)".as_bytes()).unwrap();
let expr = parse_expr(&mut ctx, &mut symbols, "(bvand a #b00)".as_bytes()).unwrap();
assert_eq!(expr, ctx.build(|c| c.and(a, c.bit_vec_val(0, 2))));
}

#[test]
fn test_get_value_parser() {
let mut ctx = Context::default();
let expr = parse_get_value_response(&mut ctx, "((a #b011))".as_bytes()).unwrap();
assert_eq!(expr, ctx.bit_vec_val(3, 3));
}

#[test]
fn test_parse_smt_array_const_and_store() {
let mut ctx = Context::default();
Expand Down
2 changes: 1 addition & 1 deletion patronus/src/smt/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::expr::{Context, Expr, ExprRef, ForEachChild, SerializableIrNode, Type, TypeCheck};
use crate::expr::{Context, Expr, ExprRef, ForEachChild, Type, TypeCheck};
use crate::smt::solver::SmtCommand;
use baa::BitVecOps;
use std::io::Write;
Expand Down
6 changes: 5 additions & 1 deletion patronus/src/smt/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// author: Kevin Laeufer <[email protected]>

use crate::expr::{Context, ExprRef};
use crate::smt::parser::{parse_get_value_response, SmtParserError};
use crate::smt::serialize::serialize_cmd;
use std::io::{BufRead, BufReader, BufWriter};
use std::io::{Read, Write};
Expand All @@ -22,6 +23,8 @@ pub enum Error {
SolverDead(String),
#[error("[smt] {0} returned an unexpected response:\n{1}")]
UnexpectedResponse(String, String),
#[error("[smt] failed to parse a response")]
Parser(#[from] SmtParserError),
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -322,7 +325,8 @@ impl<R: Write + Send> SolverContext for SmtLibSolverCtx<R> {
self.write_cmd(Some(ctx), &SmtCommand::GetValue(e))?;
self.read_response()?;
let response = self.response.trim();
todo!("parse get-value response: {response}")
let expr = parse_get_value_response(ctx, response.as_bytes())?;
Ok(expr)
}
}

Expand Down

0 comments on commit 3681e3c

Please sign in to comment.