From 9eddcadbbc4e01bacf59d55f5079cd8b6c5f268b Mon Sep 17 00:00:00 2001 From: zhengtao Date: Thu, 16 Jan 2025 20:32:54 +0800 Subject: [PATCH] add ut for error condition --- .../service/deploy/worker/Controller.scala | 34 ++++++++-------- .../service/deploy/worker/WorkerSuite.scala | 40 +++++++++++++++---- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 0678e8357e0..6dc8c8539d7 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -760,25 +760,23 @@ private[deploy] class Controller( val (commitStartWaitTime, context) = epochWaitTimeEntry.getValue try { val commitInfo = shuffleCommitInfos.get(shuffleKey).get(epoch) - if (commitInfo != null) { - commitInfo.synchronized { - if (commitInfo.status == CommitInfo.COMMIT_FINISHED) { - context.reply(commitInfo.response) + commitInfo.synchronized { + if (commitInfo.status == CommitInfo.COMMIT_FINISHED) { + context.reply(commitInfo.response) + epochIterator.remove() + } else { + if (currentTime - commitStartWaitTime >= shuffleCommitTimeout) { + val replyResponse = CommitFilesResponse( + StatusCode.COMMIT_FILE_EXCEPTION, + List.empty.asJava, + List.empty.asJava, + commitInfo.response.failedPrimaryIds, + commitInfo.response.failedReplicaIds) + shuffleCommitInfos.get(shuffleKey).put( + epoch, + new CommitInfo(replyResponse, CommitInfo.COMMIT_FINISHED)) + context.reply(replyResponse) epochIterator.remove() - } else { - if (currentTime - commitStartWaitTime >= shuffleCommitTimeout) { - val replyResponse = CommitFilesResponse( - StatusCode.COMMIT_FILE_EXCEPTION, - List.empty.asJava, - List.empty.asJava, - commitInfo.response.failedPrimaryIds, - commitInfo.response.failedReplicaIds) - shuffleCommitInfos.get(shuffleKey).put( - epoch, - new CommitInfo(replyResponse, CommitInfo.COMMIT_FINISHED)) - context.reply(replyResponse) - epochIterator.remove() - } } } } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala index 9b43e3efbc8..c28aa9f58da 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala @@ -190,6 +190,7 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { val epoch0: Long = 0 val epoch1: Long = 1 val epoch2: Long = 2 + val epoch3: Long = 3 val startWaitTime = System.currentTimeMillis() // update an INPROCESS commitInfo @@ -207,9 +208,12 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { shuffleKey, JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]()) val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey) - epochWaitTimeMap.putIfAbsent(epoch0, (startWaitTime, context)) + epochWaitTimeMap.put(epoch0, (startWaitTime, context)) } + assert(shuffleCommitTime.get(shuffleKey).get(epoch0)._1 == startWaitTime) + assert(epochCommitMap.get(epoch0).status == CommitInfo.COMMIT_INPROCESS) + // update an INPROCESS commitInfo val response1 = CommitFilesResponse( null, @@ -225,9 +229,12 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { shuffleKey, JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]()) val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey) - epochWaitTimeMap.putIfAbsent(epoch1, (startWaitTime, context)) + epochWaitTimeMap.put(epoch1, (startWaitTime, context)) } + assert(shuffleCommitTime.get(shuffleKey).get(epoch1)._1 == startWaitTime) + assert(epochCommitMap.get(epoch1).status == CommitInfo.COMMIT_INPROCESS) + // update an FINISHED commitInfo val response2 = CommitFilesResponse( StatusCode.SUCCESS, @@ -235,7 +242,7 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { replicaIds.asJava, List.empty.asJava, List.empty.asJava) - epochCommitMap.putIfAbsent(epoch2, new CommitInfo(response2, CommitInfo.COMMIT_FINISHED)) + epochCommitMap.put(epoch2, new CommitInfo(response2, CommitInfo.COMMIT_FINISHED)) val commitInfo2 = epochCommitMap.get(epoch2) commitInfo2.synchronized { @@ -244,16 +251,26 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]()) val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey) // epoch2 is already timeout - epochWaitTimeMap.putIfAbsent(epoch2, (startWaitTime, context)) + epochWaitTimeMap.put(epoch2, (startWaitTime, context)) } - assert(shuffleCommitTime.get(shuffleKey).get(epoch0)._1 == startWaitTime) - assert(epochCommitMap.get(epoch0).status == CommitInfo.COMMIT_INPROCESS) - assert(shuffleCommitTime.get(shuffleKey).get(epoch1)._1 == startWaitTime) - assert(epochCommitMap.get(epoch1).status == CommitInfo.COMMIT_INPROCESS) assert(shuffleCommitTime.get(shuffleKey).get(epoch2)._1 == startWaitTime) assert(epochCommitMap.get(epoch2).status == CommitInfo.COMMIT_FINISHED) + // add a new shuffleKey2 to shuffleCommitTime but not to shuffleCommitInfos + val shuffleKey2 = "2" + shuffleCommitTime.putIfAbsent( + shuffleKey2, + JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]()) + shuffleCommitTime.get(shuffleKey2).put(epoch0, (startWaitTime, context)) + assert(shuffleCommitTime.containsKey(shuffleKey2)) + assert(!shuffleCommitInfos.containsKey(shuffleKey2)) + + // add an epoch to shuffleCommitTime but not to shuffleCommitInfos + shuffleCommitTime.get(shuffleKey).put(epoch3, (startWaitTime, context)) + assert(shuffleCommitTime.get(shuffleKey).get(epoch3)._1 == startWaitTime) + assert(!shuffleCommitInfos.get(shuffleKey).containsKey(epoch3)) + // update status of epoch1 to FINISHED epochCommitMap.get(epoch1).status = CommitInfo.COMMIT_FINISHED assert(epochCommitMap.get(epoch1).status == CommitInfo.COMMIT_FINISHED) @@ -262,6 +279,13 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach { controller.checkCommitTimeout(shuffleCommitTime) assert(epochCommitMap.get(epoch0).status == CommitInfo.COMMIT_INPROCESS) + // shuffleCommitTime will be removed when shuffleCommitInfos contains no shuffleKey + assert(!shuffleCommitTime.containsKey(shuffleKey2)) + assert(!shuffleCommitInfos.containsKey(shuffleKey2)) + + // epoch will be removed when shuffleCommitInfos contains no epoch + assert(!shuffleCommitTime.get(shuffleKey).containsKey(epoch3)) + // FINISHED status of epoch1 will be removed from shuffleCommitTime assert(shuffleCommitTime.get(shuffleKey).get(epoch1) == null)