Skip to content

Commit

Permalink
change shuffleKey judgement
Browse files Browse the repository at this point in the history
  • Loading branch information
zaynt4606 committed Jan 15, 2025
1 parent 24aa9d6 commit b832a30
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ private[deploy] class Controller(
JavaUtils.newConcurrentHashMap[Long, (Long, RpcCallContext)]())
val epochWaitTimeMap = shuffleCommitTime.get(shuffleKey)
val WaitTimestamp = System.currentTimeMillis()
epochWaitTimeMap.putIfAbsent(epoch, (WaitTimestamp, context))
epochWaitTimeMap.put(epoch, (WaitTimestamp, context))
return
} else {
logInfo(s"Start commitFiles for $shuffleKey")
Expand Down Expand Up @@ -746,40 +746,52 @@ private[deploy] class Controller(
String,
ConcurrentHashMap[Long, (Long, RpcCallContext)]]): Unit = {

shuffleCommitTime.asScala.foreach {
case (shuffleKey, epochWaitTimeMap) =>
if (!shuffleCommitInfos.containsKey(shuffleKey)) {
shuffleCommitTime.remove(shuffleKey)
} else {
epochWaitTimeMap.asScala.foreach { case (epoch, (waitTime, context)) =>
val commitInfo = shuffleCommitInfos.get(shuffleKey).get(epoch)
if (commitInfo == null) {
epochWaitTimeMap.remove(epoch)
} else {
commitInfo.synchronized {
if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
context.reply(commitInfo.response)
epochWaitTimeMap.remove(epoch)
} else {
val currentTime = System.currentTimeMillis()
if (currentTime - waitTime >= shuffleCommitTimeout) {
val replyResponse = CommitFilesResponse(
StatusCode.COMMIT_FILE_EXCEPTION,
List.empty.asJava,
List.empty.asJava,
commitInfo.response.failedPrimaryIds,
commitInfo.response.failedReplicaIds)
shuffleCommitInfos.get(shuffleKey).put(
epoch,
new CommitInfo(replyResponse, CommitInfo.COMMIT_FINISHED))
context.reply(replyResponse)
epochWaitTimeMap.remove(epoch)
}
val commitTimeIterator = shuffleCommitTime.entrySet().iterator()
while (commitTimeIterator.hasNext) {
val timeMapEntry = commitTimeIterator.next()
val shuffleKey = timeMapEntry.getKey
val epochWaitTimeMap = timeMapEntry.getValue
val epochIterator = epochWaitTimeMap.entrySet().iterator()

while (epochIterator.hasNext && shuffleCommitInfos.containsKey(shuffleKey)) {
val epochWaitTimeEntry = epochIterator.next()
val epoch = epochWaitTimeEntry.getKey
val (waitTime, context) = epochWaitTimeEntry.getValue
val commitInfo = shuffleCommitInfos.get(shuffleKey).get(epoch)
if (commitInfo != null) {
try {
commitInfo.synchronized {
if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
context.reply(commitInfo.response)
epochIterator.remove()
} else {
val currentTime = System.currentTimeMillis()
if (currentTime - waitTime >= shuffleCommitTimeout) {
val replyResponse = CommitFilesResponse(
StatusCode.COMMIT_FILE_EXCEPTION,
List.empty.asJava,
List.empty.asJava,
commitInfo.response.failedPrimaryIds,
commitInfo.response.failedReplicaIds)
shuffleCommitInfos.get(shuffleKey).put(
epoch,
new CommitInfo(replyResponse, CommitInfo.COMMIT_FINISHED))
context.reply(replyResponse)
epochIterator.remove()
}
}
}
} catch {
case error: Exception =>
logError(
s"Exception occurs when checkCommitTimeout for shuffleKey-epoch:$shuffleKey-$epoch, error: $error")
}
}
}
if (!shuffleCommitInfos.containsKey(shuffleKey)) {
logWarning(s"Shuffle $shuffleKey commit expired when checkCommitTimeout.")
commitTimeIterator.remove()
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import java.io.File
import java.nio.file.{Files, Paths}
import java.util
import java.util.{HashSet => JHashSet}
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.junit.Assert
import org.mockito.MockitoSugar._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.{shortstacks, BeforeAndAfterEach}
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.common.CelebornConf
Expand Down Expand Up @@ -276,4 +278,63 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach {
assert(shuffleCommitTime.get(shuffleKey).get(epoch2) == null)
assert(epochCommitMap.get(epoch2).response.status == StatusCode.SUCCESS)
}

test("test check") {
val shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, Long]] =
JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, Long]]()

val shuffleCommitTime: ConcurrentHashMap[String, ConcurrentHashMap[Long, Long]] =
JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, Long]]()

shuffleCommitInfos.put("1", JavaUtils.newConcurrentHashMap[Long, Long]())
val epochCommitMap1 = shuffleCommitInfos.get("1")
epochCommitMap1.put(1L, 1L)

shuffleCommitInfos.put("2", JavaUtils.newConcurrentHashMap[Long, Long]())
val epochCommitMap2 = shuffleCommitInfos.get("2")
epochCommitMap2.put(2L, 2L)

shuffleCommitInfos.put("3", JavaUtils.newConcurrentHashMap[Long, Long]())
val epochCommitMap3 = shuffleCommitInfos.get("3")
epochCommitMap3.put(3L, 3L)

shuffleCommitTime.put("1", JavaUtils.newConcurrentHashMap[Long, Long]())
val commitTimeMap1 = shuffleCommitTime.get("1")
commitTimeMap1.put(1L, 1L)

shuffleCommitTime.put("2", JavaUtils.newConcurrentHashMap[Long, Long]())
val commitTimeMap2 = shuffleCommitTime.get("2")
commitTimeMap2.put(2L, 2L)
commitTimeMap2.put(4L, 4L)

shuffleCommitTime.put("3", JavaUtils.newConcurrentHashMap[Long, Long]())
val commitTimeMap3 = shuffleCommitTime.get("3")
commitTimeMap3.put(3L, 3L)

assert(shuffleCommitInfos.size() == 3)

val commitTimeIterator = shuffleCommitTime.entrySet().iterator()
val testList = new ArrayBuffer[Long]()
val res = ArrayBuffer[Long](1L, 3L)
while (commitTimeIterator.hasNext) {
val timeMapEntry = commitTimeIterator.next()
val shuffleKey = timeMapEntry.getKey
val epochWaitTimeMap = timeMapEntry.getValue
val epochIterator = epochWaitTimeMap.entrySet().iterator()

shuffleCommitInfos.remove("2")
while (epochIterator.hasNext && shuffleCommitInfos.containsKey(shuffleKey)) {
val epochWaitTimeEntry = epochIterator.next()
val epoch = epochWaitTimeEntry.getKey
val value = epochWaitTimeEntry.getValue
testList.append(epoch)
}
if (!shuffleCommitInfos.containsKey(shuffleKey)) {
commitTimeIterator.remove()
}
}
assert(testList.equals(res))
assert(shuffleCommitTime.size() == 2)

}
}

0 comments on commit b832a30

Please sign in to comment.