From b832a30bf4ac35c73e10beeb8267a46eab7c8d7b Mon Sep 17 00:00:00 2001 From: zhengtao Date: Wed, 15 Jan 2025 19:40:35 +0800 Subject: [PATCH] change shuffleKey judgement --- .../service/deploy/worker/Controller.scala | 72 +++++++++++-------- .../service/deploy/worker/WorkerSuite.scala | 63 +++++++++++++++- 2 files changed, 104 insertions(+), 31 deletions(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index c92d34dba01..f8317f94627 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -456,7 +456,7 @@ private[deploy] class Controller( JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]()) val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey) val WaitTimestamp = System.currentTimeMillis() - epochWaitTimeMap.putIfAbsent(epoch, (WaitTimestamp, context)) + epochWaitTimeMap.put(epoch, (WaitTimestamp, context)) return } else { logInfo(s"Start commitFiles for $shuffleKey") @@ -746,40 +746,52 @@ private[deploy] class Controller( String, ConcurrentHashMap[Long, (Long, RpcCallContext)]]): Unit = { - shuffleCommitTime.asScala.foreach { - case (shuffleKey, epochWaitTimeMap) => - if (!shuffleCommitInfos.containsKey(shuffleKey)) { - shuffleCommitTime.remove(shuffleKey) - } else { - epochWaitTimeMap.asScala.foreach { case (epoch, (waitTime, context)) => - val commitInfo = shuffleCommitInfos.get(shuffleKey).get(epoch) - if (commitInfo == null) { - epochWaitTimeMap.remove(epoch) - } else { - commitInfo.synchronized { - if (commitInfo.status == CommitInfo.COMMIT_FINISHED) { - context.reply(commitInfo.response) - epochWaitTimeMap.remove(epoch) - } else { - val currentTime = System.currentTimeMillis() - if (currentTime - waitTime >= shuffleCommitTimeout) { - val replyResponse = CommitFilesResponse( - StatusCode.COMMIT_FILE_EXCEPTION, - List.empty.asJava, - List.empty.asJava, - commitInfo.response.failedPrimaryIds, - commitInfo.response.failedReplicaIds) - shuffleCommitInfos.get(shuffleKey).put( - epoch, - new CommitInfo(replyResponse, CommitInfo.COMMIT_FINISHED)) - context.reply(replyResponse) - epochWaitTimeMap.remove(epoch) - } + val commitTimeIterator = shuffleCommitTime.entrySet().iterator() + while (commitTimeIterator.hasNext) { + val timeMapEntry = commitTimeIterator.next() + val shuffleKey = timeMapEntry.getKey + val epochWaitTimeMap = timeMapEntry.getValue + val epochIterator = epochWaitTimeMap.entrySet().iterator() + + while (epochIterator.hasNext && shuffleCommitInfos.containsKey(shuffleKey)) { + val epochWaitTimeEntry = epochIterator.next() + val epoch = epochWaitTimeEntry.getKey + val (waitTime, context) = epochWaitTimeEntry.getValue + val commitInfo = shuffleCommitInfos.get(shuffleKey).get(epoch) + if (commitInfo != null) { + try { + commitInfo.synchronized { + if (commitInfo.status == CommitInfo.COMMIT_FINISHED) { + context.reply(commitInfo.response) + epochIterator.remove() + } else { + val currentTime = System.currentTimeMillis() + if (currentTime - waitTime >= shuffleCommitTimeout) { + val replyResponse = CommitFilesResponse( + StatusCode.COMMIT_FILE_EXCEPTION, + List.empty.asJava, + List.empty.asJava, + commitInfo.response.failedPrimaryIds, + commitInfo.response.failedReplicaIds) + shuffleCommitInfos.get(shuffleKey).put( + epoch, + new CommitInfo(replyResponse, CommitInfo.COMMIT_FINISHED)) + context.reply(replyResponse) + epochIterator.remove() } } } + } catch { + case error: Exception => + logError( + s"Exception occurs when checkCommitTimeout for shuffleKey-epoch:$shuffleKey-$epoch, error: $error") } } + } + if (!shuffleCommitInfos.containsKey(shuffleKey)) { + logWarning(s"Shuffle $shuffleKey commit expired when checkCommitTimeout.") + commitTimeIterator.remove() + } } } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala index 662ea74c357..10cd015efe6 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala @@ -21,12 +21,14 @@ import java.io.File import java.nio.file.{Files, Paths} import java.util import java.util.{HashSet => JHashSet} +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.junit.Assert import org.mockito.MockitoSugar._ -import org.scalatest.BeforeAndAfterEach +import org.scalatest.{shortstacks, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.common.CelebornConf @@ -276,4 +278,63 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { assert(shuffleCommitTime.get(shuffleKey).get(epoch2) == null) assert(epochCommitMap.get(epoch2).response.status == StatusCode.SUCCESS) } + + test("test check") { + val shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, Long]] = + JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, Long]]() + + val shuffleCommitTime: ConcurrentHashMap[String, ConcurrentHashMap[Long, Long]] = + JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, Long]]() + + shuffleCommitInfos.put("1", JavaUtils.newConcurrentHashMap[Long, Long]()) + val epochCommitMap1 = shuffleCommitInfos.get("1") + epochCommitMap1.put(1L, 1L) + + shuffleCommitInfos.put("2", JavaUtils.newConcurrentHashMap[Long, Long]()) + val epochCommitMap2 = shuffleCommitInfos.get("2") + epochCommitMap2.put(2L, 2L) + + shuffleCommitInfos.put("3", JavaUtils.newConcurrentHashMap[Long, Long]()) + val epochCommitMap3 = shuffleCommitInfos.get("3") + epochCommitMap3.put(3L, 3L) + + shuffleCommitTime.put("1", JavaUtils.newConcurrentHashMap[Long, Long]()) + val commitTimeMap1 = shuffleCommitTime.get("1") + commitTimeMap1.put(1L, 1L) + + shuffleCommitTime.put("2", JavaUtils.newConcurrentHashMap[Long, Long]()) + val commitTimeMap2 = shuffleCommitTime.get("2") + commitTimeMap2.put(2L, 2L) + commitTimeMap2.put(4L, 4L) + + shuffleCommitTime.put("3", JavaUtils.newConcurrentHashMap[Long, Long]()) + val commitTimeMap3 = shuffleCommitTime.get("3") + commitTimeMap3.put(3L, 3L) + + assert(shuffleCommitInfos.size() == 3) + + val commitTimeIterator = shuffleCommitTime.entrySet().iterator() + val testList = new ArrayBuffer[Long]() + val res = ArrayBuffer[Long](1L, 3L) + while (commitTimeIterator.hasNext) { + val timeMapEntry = commitTimeIterator.next() + val shuffleKey = timeMapEntry.getKey + val epochWaitTimeMap = timeMapEntry.getValue + val epochIterator = epochWaitTimeMap.entrySet().iterator() + + shuffleCommitInfos.remove("2") + while (epochIterator.hasNext && shuffleCommitInfos.containsKey(shuffleKey)) { + val epochWaitTimeEntry = epochIterator.next() + val epoch = epochWaitTimeEntry.getKey + val value = epochWaitTimeEntry.getValue + testList.append(epoch) + } + if (!shuffleCommitInfos.containsKey(shuffleKey)) { + commitTimeIterator.remove() + } + } + assert(testList.equals(res)) + assert(shuffleCommitTime.size() == 2) + + } }