Skip to content

Commit

Permalink
[CELEBORN-1109] Cache RegisterShuffleResponse to improve the processi…
Browse files Browse the repository at this point in the history
…ng speed of LifecycleManager

### What changes were proposed in this pull request?
Cache RegisterShuffleResponse to improve the processing speed of LifecycleManager

### Why are the changes needed?
While processing a registerShuffle request, constructing the RegisterShuffleResponse instance and serializing it can consume a significant amount of time. When the LifecycleManager has to process a large number of registerShuffle requests simultaneously, its response time is delayed. Caching the serialized response therefore improves the LifecycleManager's processing performance.

![image](https://github.com/apache/incubator-celeborn/assets/107825064/06d3cb3c-156a-46c7-a08d-fefa18b26e40)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Closes #2070 from kerwin-zk/issue-1109.

Authored-by: xiyu.zk <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
  • Loading branch information
kerwin-zk authored and waitinfuture committed Nov 7, 2023
1 parent 52eddc5 commit ffbbe25
Showing 1 changed file with 79 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

package org.apache.celeborn.client

import java.nio.ByteBuffer
import java.util
import java.util.{function, List => JList}
import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit}
import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledFuture, TimeUnit}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import com.google.common.annotations.VisibleForTesting
import com.google.common.cache.{Cache, CacheBuilder}

import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.client.listener.WorkerStatusListener
Expand All @@ -39,6 +41,7 @@ import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP
import org.apache.celeborn.common.protocol.message.ControlMessages._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc._
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils}
// This import can be removed once Celeborn no longer supports Scala 2.11
import org.apache.celeborn.common.util.FunctionConverter._
Expand Down Expand Up @@ -77,6 +80,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private val userIdentifier: UserIdentifier = IdentityProvider.instantiate(conf).provide()
private val availableStorageTypes = conf.availableStorageTypes

// Settings for the RPC response cache below, read once from the client configuration.
private val rpcCacheSize = conf.clientRpcCacheSize
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime

// Cache of already-serialized RegisterShuffleResponse messages, keyed by shuffle id.
// It is consulted and populated only for REDUCE-partition shuffles answered over a
// RemoteNettyRpcCallContext, so repeated RegisterShuffle requests for the same shuffle
// can reuse the serialized bytes instead of rebuilding and re-serializing the response.
// Entries are dropped after `rpcCacheExpireTime` milliseconds without access, and the
// cache is bounded to `rpcCacheSize` entries.
// NOTE(review): the same ByteBuffer instance is handed to every reply served from the
// cache; this relies on the RPC callback not mutating the buffer's position/limit — confirm.
// NOTE(review): the asInstanceOf is presumably required because Scala's Int does not
// satisfy the builder's reference-type bound on keys — confirm before removing it.
private val registerShuffleResponseRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
.concurrencyLevel(rpcCacheConcurrencyLevel)
.expireAfterAccess(rpcCacheExpireTime, TimeUnit.MILLISECONDS)
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]

/**
 * Looks up the per-worker partition location info recorded for a shuffle.
 *
 * @param shuffleId id of the shuffle to look up
 * @return whatever `shuffleAllocatedWorkers` holds for `shuffleId`
 *         (presumably `null` when the shuffle is unknown — verify against the map type)
 */
@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] = {
  shuffleAllocatedWorkers.get(shuffleId)
}
Expand Down Expand Up @@ -316,21 +329,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
// If shuffle is registered, reply this shuffle's partition location and return.
// Else add this request to registeringShuffleRequest.
if (registeredShuffle.contains(shuffleId)) {
val initialLocs = workerSnapshots(shuffleId)
.values()
.asScala
.flatMap(_.getAllPrimaryLocationsWithMinEpoch())
.filter(p =>
(partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
.toArray
val rpcContext: RpcCallContext = context.context
partitionType match {
case PartitionType.MAP => processMapTaskReply(
case PartitionType.MAP =>
processMapTaskReply(
shuffleId,
context.context,
rpcContext,
partitionId,
initialLocs)
getInitialLocs(shuffleId, p => p.getId == partitionId))
case PartitionType.REDUCE =>
context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) {
context.reply(RegisterShuffleResponse(
StatusCode.SUCCESS,
getInitialLocs(shuffleId, p => p.getEpoch == 0)))
} else {
val cachedMsg = registerShuffleResponseRpcCache.get(
shuffleId,
new Callable[ByteBuffer]() {
override def call(): ByteBuffer = {
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(
RegisterShuffleResponse(
StatusCode.SUCCESS,
getInitialLocs(shuffleId, p => p.getEpoch == 0)))
}
})
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
}
case _ =>
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
Expand All @@ -345,6 +369,17 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
}

/**
 * Collects the primary partition locations of a shuffle that satisfy a predicate.
 *
 * The locations are taken from every worker entry returned by `workerSnapshots(shuffleId)`,
 * using each entry's `getAllPrimaryLocationsWithMinEpoch()`.
 *
 * @param shuffleId               id of the shuffle whose locations are collected
 * @param partitionLocationFilter predicate a location must satisfy to be included
 * @return the matching locations as an array
 */
def getInitialLocs(
    shuffleId: Int,
    partitionLocationFilter: PartitionLocation => Boolean): Array[PartitionLocation] = {
  // Per-worker location info for this shuffle.
  val locationInfos = workerSnapshots(shuffleId).values().asScala
  // Flatten every worker's min-epoch primary locations into a single collection.
  val primaryLocations = locationInfos.flatMap { info =>
    info.getAllPrimaryLocationsWithMinEpoch()
  }
  primaryLocations.filter(location => partitionLocationFilter(location)).toArray
}

def processMapTaskReply(
shuffleId: Int,
context: RpcCallContext,
Expand All @@ -365,8 +400,24 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}

// Reply to all RegisterShuffle request for current shuffle id.
def reply(response: PbRegisterShuffleResponse): Unit = {
def replyRegisterShuffle(response: PbRegisterShuffleResponse): Unit = {
registeringShuffleRequest.synchronized {
val serializedMsg: Option[ByteBuffer] = partitionType match {
case PartitionType.REDUCE =>
context.context match {
case remoteContext: RemoteNettyRpcCallContext =>
if (response.getStatus == StatusCode.SUCCESS.getValue) {
Option(remoteContext.nettyEnv.serialize(
response))
} else {
Option.empty
}

case _ => Option.empty
}
case _ => Option.empty
}

registeringShuffleRequest.asScala
.get(shuffleId)
.foreach(_.asScala.foreach(context => {
Expand All @@ -387,7 +438,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
// otherwise the original exception message will be lost
context.reply(response)
}
case PartitionType.REDUCE => context.reply(response)
case PartitionType.REDUCE =>
if (context.context.isInstanceOf[
LocalNettyRpcCallContext] || response.getStatus != StatusCode.SUCCESS.getValue) {
context.reply(response)
} else {
registerShuffleResponseRpcCache.put(shuffleId, serializedMsg.get)
context.context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(
serializedMsg.get)
}
case _ =>
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
Expand All @@ -404,11 +463,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
res.status match {
case StatusCode.REQUEST_FAILED =>
logInfo(s"OfferSlots RPC request failed for $shuffleId!")
reply(RegisterShuffleResponse(StatusCode.REQUEST_FAILED, Array.empty))
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.REQUEST_FAILED, Array.empty))
return
case StatusCode.SLOT_NOT_AVAILABLE =>
logInfo(s"OfferSlots for $shuffleId failed!")
reply(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
return
case StatusCode.SUCCESS =>
logInfo(s"OfferSlots for $shuffleId Success!Slots Info: ${res.workerResource}")
Expand Down Expand Up @@ -455,7 +514,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
// If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return.
if (!reserveSlotsSuccess) {
logError(s"reserve buffer for $shuffleId failed, reply to all.")
reply(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
} else {
logInfo(s"ReserveSlots for $shuffleId success with details:$slots!")
// Fourth, register shuffle success, update status
Expand All @@ -475,7 +534,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
val allPrimaryPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
reply(RegisterShuffleResponse(StatusCode.SUCCESS, allPrimaryPartitionLocations))
replyRegisterShuffle(RegisterShuffleResponse(
StatusCode.SUCCESS,
allPrimaryPartitionLocations))
}
}

Expand Down

0 comments on commit ffbbe25

Please sign in to comment.