From cc7378519ad1f1c92f89e81e1b94fd10c6eed4a5 Mon Sep 17 00:00:00 2001 From: SimonGuilloud Date: Fri, 4 Oct 2024 12:45:54 +0200 Subject: [PATCH] Small fixes to tests and congruence (#223) * fix tests, small cleaning in congruence algorithm. * further simplification * more minor improvements to congruence * update CHANGES.md * old --- CHANGES.md | 4 +- .../scala/lisa/automation/Congruence.scala | 134 +++++++++--------- .../lisa/maths/settheory/Comprehensions.scala | 2 +- .../lisa/maths/settheory/SetTheory2.scala | 2 +- .../lisa/automation/CongruenceTest.scala | 3 +- 5 files changed, 72 insertions(+), 73 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index fdcd76e10..537835d10 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,11 +1,13 @@ # Change List +## 2024-10-04 +Repair two broken tests, small improvements to congruence tactic. + ## 2024-10-03 Update to Scala 3.5.1. Silence warnings from scallion and other erroneous ones. ## 2024-07-22 Resealed the `Proof` trait following a fix of the relevant compiler bug [scala/scala3#19031](https://github.com/scala/scala3/issues/19031). - Updated to Scala 3.4.2, and relevant minor syntax changes from `-rewrite -source 3.4-migration`. ## 2024-04-12 diff --git a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala index ccf25295f..8c60de410 100644 --- a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala +++ b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala @@ -240,25 +240,17 @@ import scala.collection.mutable class EGraphTerms() { - type ENode = Term | Formula - - - - val termMap = mutable.Map[Term, Set[Term]]() - val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]() - var termWorklist = List[Term]() + val termParentsT = mutable.Map[Term, mutable.Set[AppliedFunctional]]() + val termParentsF = mutable.Map[Term, mutable.Set[AppliedPredicate]]() val termUF = new UnionFind[Term]() - - - val formulaMap = mutable.Map[Formula, Set[Formula]]() val formulaParents = mutable.Map[Formula, mutable.Set[AppliedConnector]]() - var formulaWorklist = List[Formula]() val formulaUF = new UnionFind[Formula]() - + def find(id: Term): Term = termUF.find(id) + def find(id: Formula): Formula = formulaUF.find(id) trait TermStep case class TermExternal(between: (Term, Term)) extends TermStep @@ -316,62 +308,62 @@ class EGraphTerms() { def makeSingletonEClass(node:Term): Term = { termUF.add(node) - termMap(node) = Set(node) - termParents(node) = mutable.Set() + termParentsT(node) = mutable.Set() + termParentsF(node) = mutable.Set() node } def makeSingletonEClass(node:Formula): Formula = { formulaUF.add(node) - formulaMap(node) = Set(node) formulaParents(node) = mutable.Set() node } - def classOf(id: Term): Set[Term] = termMap(id) - def classOf(id: Formula): Set[Formula] = formulaMap(id) + def idEq(id1: Term, id2: Term): Boolean = find(id1) == find(id2) + def idEq(id1: Formula, id2: Formula): Boolean = find(id1) == find(id2) - def idEq(id1: Term, id2: Term): Boolean = termUF.find(id1) == termUF.find(id2) - def idEq(id1: Formula, id2: Formula): Boolean = formulaUF.find(id1) == formulaUF.find(id2) + def canonicalize(node: Term): Term = node match case AppliedFunctional(label, args) => - AppliedFunctional(label, args.map(termUF.find.asInstanceOf)) + AppliedFunctional(label, args.map(t => find(t))) case _ => node def canonicalize(node: Formula): Formula = { node match - case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(termUF.find)) - case AppliedConnector(label, args) => AppliedConnector(label, args.map(formulaUF.find)) + case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(find)) + case AppliedConnector(label, args) => AppliedConnector(label, args.map(find)) case node => node } def add(node: Term): Term = - if termMap.contains(node) then return node + if termUF.parent.contains(node) then return node makeSingletonEClass(node) + codes(node) = codes.size node match case node @ AppliedFunctional(_, args) => args.foreach(child => add(child) - termParents(child).add(node) + termParentsT(find(child)).add(node) ) - node - case _ => node + case _ => () + termSigs(canSig(node)) = node + node def add(node: Formula): Formula = - if formulaMap.contains(node) then return node + if formulaUF.parent.contains(node) then return node makeSingletonEClass(node) node match case node @ AppliedPredicate(_, args) => args.foreach(child => add(child) - termParents(child).add(node) + termParentsF(find(child)).add(node) ) node case node @ AppliedConnector(_, args) => args.foreach(child => add(child) - formulaParents(child).add(node) + formulaParents(find(child)).add(node) ) node case _ => node @@ -392,75 +384,81 @@ class EGraphTerms() { mergeWithStep(id1, id2, FormulaExternal((id1, id2))) } + type Sig = (TermLabel[?]|Term, List[Int]) + val termSigs = mutable.Map[Sig, Term]() + val codes = mutable.Map[Term, Int]() + + def canSig(node: Term): Sig = node match + case AppliedFunctional(label, args) => + (label, args.map(a => codes(find(a))).toList) + case _ => (node, List()) + protected def mergeWithStep(id1: Term, id2: Term, step: TermStep): Unit = { - if termUF.find(id1) == termUF.find(id2) then () + if find(id1) == find(id2) then () else termProofMap((id1, id2)) = step - val newSet = termMap(termUF.find(id1)) ++ termMap(termUF.find(id2)) - val newparents = termParents(termUF.find(id1)) ++ termParents(termUF.find(id2)) + val parentsT1 = termParentsT(find(id1)) + val parentsF1 = termParentsF(find(id1)) + + val parentsT2 = termParentsT(find(id2)) + val parentsF2 = termParentsF(find(id2)) + val preSigs : Map[Term, Sig] = parentsT1.map(t => (t, canSig(t))).toMap + codes(find(id2)) = codes(find(id1)) //assume parents(find(id1)) >= parents(find(id2)) termUF.union(id1, id2) - val newId1 = termUF.find(id1) - val newId2 = termUF.find(id2) - termMap(newId1) = newSet - termMap(newId2) = newSet - termParents(newId1) = newparents - termParents(newId2) = newparents + val newId = find(id1) - val id = termUF.find(id2) - termWorklist = id :: termWorklist - val cause = (id1, id2) - val termSeen = mutable.Map[Term, AppliedFunctional]() val formulaSeen = mutable.Map[Formula, AppliedPredicate]() - newparents.foreach { + var formWorklist = List[(Formula, Formula, FormulaStep)]() + var termWorklist = List[(Term, Term, TermStep)]() + + parentsT2.foreach { case pTerm: AppliedFunctional => - val canonicalPTerm = canonicalize(pTerm) - if termSeen.contains(canonicalPTerm) then - val qTerm = termSeen(canonicalPTerm) - Some((pTerm, qTerm, cause)) - mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm))) + val canonicalPTerm = canSig(pTerm) + if termSigs.contains(canonicalPTerm) then + val qTerm = termSigs(canonicalPTerm) + termWorklist = (pTerm, qTerm, TermCongruence((pTerm, qTerm))) :: termWorklist else - termSeen(canonicalPTerm) = pTerm + termSigs(canonicalPTerm) = pTerm + } + (parentsF2 ++ parentsF1).foreach { case pFormula: AppliedPredicate => val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) - - Some((pFormula, qFormula, cause)) - mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) + formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist else formulaSeen(canonicalPFormula) = pFormula } - termParents(id) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set) + termParentsT(newId) = termParentsT(id1) + termParentsT(newId).addAll(termParentsT(id2)) + termParentsF(newId) = formulaSeen.values.to(mutable.Set) + formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } + termWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } } protected def mergeWithStep(id1: Formula, id2: Formula, step: FormulaStep): Unit = - if formulaUF.find(id1) == formulaUF.find(id2) then () + if find(id1) == find(id2) then () else formulaProofMap((id1, id2)) = step - val newSet = formulaMap(formulaUF.find(id1)) ++ formulaMap(formulaUF.find(id2)) - val newparents = formulaParents(formulaUF.find(id1)) ++ formulaParents(formulaUF.find(id2)) + val newparents = formulaParents(find(id1)) ++ formulaParents(find(id2)) formulaUF.union(id1, id2) - val newId1 = formulaUF.find(id1) - val newId2 = formulaUF.find(id2) - formulaMap(newId1) = newSet - formulaMap(newId2) = newSet - formulaParents(newId1) = newparents - formulaParents(newId2) = newparents - val id = formulaUF.find(id2) - formulaWorklist = id :: formulaWorklist - val cause = (id1, id2) + val newId = find(id1) + val formulaSeen = mutable.Map[Formula, AppliedConnector]() + var formWorklist = List[(Formula, Formula, FormulaStep)]() + newparents.foreach { case pFormula: AppliedConnector => val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) - Some((pFormula, qFormula, cause)) - mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) + formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist + //mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) else formulaSeen(canonicalPFormula) = pFormula } - formulaParents(id) = formulaSeen.values.to(mutable.Set) + formulaParents(newId) = formulaSeen.values.to(mutable.Set) + formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } def proveTerm(using lib: Library, proof: lib.Proof)(id1: Term, id2:Term, base: Sequent): proof.ProofTacticJudgement = diff --git a/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala b/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala index 2ebbdc47c..55a2a2d98 100644 --- a/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala +++ b/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala @@ -177,7 +177,7 @@ object Comprehensions { } def filter(using _proof: Proof, name: sourcecode.Name)(filter: (Term ** 1) |-> Formula): Comprehension { val proof: _proof.type } = { - if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.allSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof") + if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.freeSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof") val id = name.value inline def _filter = filter inline def _t = t diff --git a/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala b/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala index f9599302e..2957a9ec4 100644 --- a/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala +++ b/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala @@ -45,7 +45,7 @@ object SetTheory2 extends lisa.Main { thenHave(in(x, A) |- ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))) by Weakening thenHave(in(x, A) |- ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))) by RightForall thenHave(in(x, A) |- ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by RightForall - //thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate + thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate thenHave(∀(x, in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))))) by RightForall thenHave(thesis) by Restate diff --git a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala index 5e50502d5..34bc77eee 100644 --- a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala +++ b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala @@ -7,6 +7,7 @@ import org.scalatest.funsuite.AnyFunSuite class CongruenceTest extends AnyFunSuite with lisa.TestMain { + given lib: lisa.SetTheoryLibrary.type = lisa.SetTheoryLibrary val a = variable @@ -254,8 +255,6 @@ class CongruenceTest extends AnyFunSuite with lisa.TestMain { assert(egraph.idEq(fx, x)) assert(egraph.idEq(x, fx)) - assert(egraph.explain(fx, x) == Some(List(egraph.TermCongruence(fx, fffx), egraph.TermCongruence(fffx, ffffffffx), egraph.TermExternal(ffffffffx, x)))) - }