Skip to content

Commit

Permalink
Small fixes to tests and congruence (#223)
Browse files Browse the repository at this point in the history
* fix tests, small cleaning in congruence algorithm.

* further simplification

* more minor improvements to congruence

* update CHANGES.md

* old
  • Loading branch information
SimonGuilloud authored Oct 4, 2024
1 parent 8af257a commit cc73785
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 73 deletions.
4 changes: 3 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Change List

## 2024-10-04
Repair two broken tests, small improvements to congruence tactic.

## 2024-10-03
Update to Scala 3.5.1. Silence warnings from scallion and other erroneous ones.

## 2024-07-22
Resealed the `Proof` trait following a fix of the relevant compiler bug [scala/scala3#19031](https://github.com/scala/scala3/issues/19031).

Updated to Scala 3.4.2, and relevant minor syntax changes from `-rewrite -source 3.4-migration`.

## 2024-04-12
Expand Down
134 changes: 66 additions & 68 deletions lisa-sets/src/main/scala/lisa/automation/Congruence.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,17 @@ import scala.collection.mutable

class EGraphTerms() {

type ENode = Term | Formula



val termMap = mutable.Map[Term, Set[Term]]()
val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]()
var termWorklist = List[Term]()
val termParentsT = mutable.Map[Term, mutable.Set[AppliedFunctional]]()
val termParentsF = mutable.Map[Term, mutable.Set[AppliedPredicate]]()
val termUF = new UnionFind[Term]()




val formulaMap = mutable.Map[Formula, Set[Formula]]()
val formulaParents = mutable.Map[Formula, mutable.Set[AppliedConnector]]()
var formulaWorklist = List[Formula]()
val formulaUF = new UnionFind[Formula]()



def find(id: Term): Term = termUF.find(id)
def find(id: Formula): Formula = formulaUF.find(id)

trait TermStep
case class TermExternal(between: (Term, Term)) extends TermStep
Expand Down Expand Up @@ -316,62 +308,62 @@ class EGraphTerms() {

def makeSingletonEClass(node:Term): Term = {
termUF.add(node)
termMap(node) = Set(node)
termParents(node) = mutable.Set()
termParentsT(node) = mutable.Set()
termParentsF(node) = mutable.Set()
node
}
def makeSingletonEClass(node:Formula): Formula = {
formulaUF.add(node)
formulaMap(node) = Set(node)
formulaParents(node) = mutable.Set()
node
}

def classOf(id: Term): Set[Term] = termMap(id)
def classOf(id: Formula): Set[Formula] = formulaMap(id)
def idEq(id1: Term, id2: Term): Boolean = find(id1) == find(id2)
def idEq(id1: Formula, id2: Formula): Boolean = find(id1) == find(id2)

def idEq(id1: Term, id2: Term): Boolean = termUF.find(id1) == termUF.find(id2)
def idEq(id1: Formula, id2: Formula): Boolean = formulaUF.find(id1) == formulaUF.find(id2)


def canonicalize(node: Term): Term = node match
case AppliedFunctional(label, args) =>
AppliedFunctional(label, args.map(termUF.find.asInstanceOf))
AppliedFunctional(label, args.map(t => find(t)))
case _ => node


def canonicalize(node: Formula): Formula = {
node match
case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(termUF.find))
case AppliedConnector(label, args) => AppliedConnector(label, args.map(formulaUF.find))
case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(find))
case AppliedConnector(label, args) => AppliedConnector(label, args.map(find))
case node => node
}

def add(node: Term): Term =
if termMap.contains(node) then return node
if termUF.parent.contains(node) then return node
makeSingletonEClass(node)
codes(node) = codes.size
node match
case node @ AppliedFunctional(_, args) =>
args.foreach(child =>
add(child)
termParents(child).add(node)
termParentsT(find(child)).add(node)
)
node
case _ => node
case _ => ()
termSigs(canSig(node)) = node
node

def add(node: Formula): Formula =
if formulaMap.contains(node) then return node
if formulaUF.parent.contains(node) then return node
makeSingletonEClass(node)
node match
case node @ AppliedPredicate(_, args) =>
args.foreach(child =>
add(child)
termParents(child).add(node)
termParentsF(find(child)).add(node)
)
node
case node @ AppliedConnector(_, args) =>
args.foreach(child =>
add(child)
formulaParents(child).add(node)
formulaParents(find(child)).add(node)
)
node
case _ => node
Expand All @@ -392,75 +384,81 @@ class EGraphTerms() {
mergeWithStep(id1, id2, FormulaExternal((id1, id2)))
}

type Sig = (TermLabel[?]|Term, List[Int])
val termSigs = mutable.Map[Sig, Term]()
val codes = mutable.Map[Term, Int]()

def canSig(node: Term): Sig = node match
case AppliedFunctional(label, args) =>
(label, args.map(a => codes(find(a))).toList)
case _ => (node, List())

protected def mergeWithStep(id1: Term, id2: Term, step: TermStep): Unit = {
if termUF.find(id1) == termUF.find(id2) then ()
if find(id1) == find(id2) then ()
else
termProofMap((id1, id2)) = step
val newSet = termMap(termUF.find(id1)) ++ termMap(termUF.find(id2))
val newparents = termParents(termUF.find(id1)) ++ termParents(termUF.find(id2))
val parentsT1 = termParentsT(find(id1))
val parentsF1 = termParentsF(find(id1))

val parentsT2 = termParentsT(find(id2))
val parentsF2 = termParentsF(find(id2))
val preSigs : Map[Term, Sig] = parentsT1.map(t => (t, canSig(t))).toMap
codes(find(id2)) = codes(find(id1)) //assume parents(find(id1)) >= parents(find(id2))
termUF.union(id1, id2)
val newId1 = termUF.find(id1)
val newId2 = termUF.find(id2)
termMap(newId1) = newSet
termMap(newId2) = newSet
termParents(newId1) = newparents
termParents(newId2) = newparents
val newId = find(id1)

val id = termUF.find(id2)
termWorklist = id :: termWorklist
val cause = (id1, id2)
val termSeen = mutable.Map[Term, AppliedFunctional]()
val formulaSeen = mutable.Map[Formula, AppliedPredicate]()
newparents.foreach {
var formWorklist = List[(Formula, Formula, FormulaStep)]()
var termWorklist = List[(Term, Term, TermStep)]()

parentsT2.foreach {
case pTerm: AppliedFunctional =>
val canonicalPTerm = canonicalize(pTerm)
if termSeen.contains(canonicalPTerm) then
val qTerm = termSeen(canonicalPTerm)
Some((pTerm, qTerm, cause))
mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm)))
val canonicalPTerm = canSig(pTerm)
if termSigs.contains(canonicalPTerm) then
val qTerm = termSigs(canonicalPTerm)
termWorklist = (pTerm, qTerm, TermCongruence((pTerm, qTerm))) :: termWorklist
else
termSeen(canonicalPTerm) = pTerm
termSigs(canonicalPTerm) = pTerm
}
(parentsF2 ++ parentsF1).foreach {
case pFormula: AppliedPredicate =>
val canonicalPFormula = canonicalize(pFormula)
if formulaSeen.contains(canonicalPFormula) then
val qFormula = formulaSeen(canonicalPFormula)

Some((pFormula, qFormula, cause))
mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist
else
formulaSeen(canonicalPFormula) = pFormula
}
termParents(id) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set)
termParentsT(newId) = termParentsT(id1)
termParentsT(newId).addAll(termParentsT(id2))
termParentsF(newId) = formulaSeen.values.to(mutable.Set)
formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) }
termWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) }
}

protected def mergeWithStep(id1: Formula, id2: Formula, step: FormulaStep): Unit =
if formulaUF.find(id1) == formulaUF.find(id2) then ()
if find(id1) == find(id2) then ()
else
formulaProofMap((id1, id2)) = step
val newSet = formulaMap(formulaUF.find(id1)) ++ formulaMap(formulaUF.find(id2))
val newparents = formulaParents(formulaUF.find(id1)) ++ formulaParents(formulaUF.find(id2))
val newparents = formulaParents(find(id1)) ++ formulaParents(find(id2))
formulaUF.union(id1, id2)
val newId1 = formulaUF.find(id1)
val newId2 = formulaUF.find(id2)
formulaMap(newId1) = newSet
formulaMap(newId2) = newSet
formulaParents(newId1) = newparents
formulaParents(newId2) = newparents
val id = formulaUF.find(id2)
formulaWorklist = id :: formulaWorklist
val cause = (id1, id2)
val newId = find(id1)

val formulaSeen = mutable.Map[Formula, AppliedConnector]()
var formWorklist = List[(Formula, Formula, FormulaStep)]()

newparents.foreach {
case pFormula: AppliedConnector =>
val canonicalPFormula = canonicalize(pFormula)
if formulaSeen.contains(canonicalPFormula) then
val qFormula = formulaSeen(canonicalPFormula)
Some((pFormula, qFormula, cause))
mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist
//mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
else
formulaSeen(canonicalPFormula) = pFormula
}
formulaParents(id) = formulaSeen.values.to(mutable.Set)
formulaParents(newId) = formulaSeen.values.to(mutable.Set)
formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) }


def proveTerm(using lib: Library, proof: lib.Proof)(id1: Term, id2:Term, base: Sequent): proof.ProofTacticJudgement =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ object Comprehensions {
}

def filter(using _proof: Proof, name: sourcecode.Name)(filter: (Term ** 1) |-> Formula): Comprehension { val proof: _proof.type } = {
if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.allSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof")
if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.freeSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof")
val id = name.value
inline def _filter = filter
inline def _t = t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object SetTheory2 extends lisa.Main {
thenHave(in(x, A) |- ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))) by Weakening
thenHave(in(x, A) |- (z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))) by RightForall
thenHave(in(x, A) |- (y, (z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by RightForall
//thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate
thenHave(in(x, A) ==> (y, (z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate
thenHave((x, in(x, A) ==> (y, (z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))))) by RightForall
thenHave(thesis) by Restate

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.scalatest.funsuite.AnyFunSuite

class CongruenceTest extends AnyFunSuite with lisa.TestMain {


given lib: lisa.SetTheoryLibrary.type = lisa.SetTheoryLibrary

val a = variable
Expand Down Expand Up @@ -254,8 +255,6 @@ class CongruenceTest extends AnyFunSuite with lisa.TestMain {
assert(egraph.idEq(fx, x))
assert(egraph.idEq(x, fx))

assert(egraph.explain(fx, x) == Some(List(egraph.TermCongruence(fx, fffx), egraph.TermCongruence(fffx, ffffffffx), egraph.TermExternal(ffffffffx, x))))

}


Expand Down

0 comments on commit cc73785

Please sign in to comment.