[CELEBORN-1660] Using map for workers to find worker fast
### What changes were proposed in this pull request?

Use a map for workers so that a worker can be found quickly by its uniqueId.

### Why are the changes needed?

For a large Celeborn cluster, looking up a single worker by scanning the worker set can be slow, for example in the following code paths (a minimal lookup sketch follows the list):

- updateWorkerHeartbeatMeta
https://github.com/apache/celeborn/blob/1e77f01cd317b1dc885965d6053b391db1d42bc7/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java#L222

- handleWorkerLost
https://github.com/apache/celeborn/blob/1e77f01cd317b1dc885965d6053b391db1d42bc7/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala#L762-L765
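
As an illustration only, here is a minimal, self-contained sketch of the lookup change. The simplified `WorkerInfo` below is a hypothetical stand-in for Celeborn's class, not the real one; it only carries the fields needed to show the difference between scanning a set and looking up a map keyed by uniqueId.

```java
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public class WorkerLookupSketch {
  // Simplified stand-in for Celeborn's WorkerInfo (hypothetical, for illustration only).
  static final class WorkerInfo {
    final String host;
    final int rpcPort;
    WorkerInfo(String host, int rpcPort) { this.host = host; this.rpcPort = rpcPort; }
    String toUniqueId() { return host + ":" + rpcPort; }
    @Override public boolean equals(Object o) {
      return o instanceof WorkerInfo && ((WorkerInfo) o).toUniqueId().equals(toUniqueId());
    }
    @Override public int hashCode() { return toUniqueId().hashCode(); }
  }

  // Before: each heartbeat or worker-lost event scans the whole set, O(n) per lookup.
  static Optional<WorkerInfo> lookupInSet(Set<WorkerInfo> workers, WorkerInfo target) {
    return workers.stream().filter(w -> w.equals(target)).findFirst();
  }

  // After: a single hash lookup keyed by the worker's uniqueId, O(1) per lookup.
  static Optional<WorkerInfo> lookupInMap(Map<String, WorkerInfo> workersMap, WorkerInfo target) {
    return Optional.ofNullable(workersMap.get(target.toUniqueId()));
  }

  public static void main(String[] args) {
    Map<String, WorkerInfo> workersMap = new ConcurrentHashMap<>();
    WorkerInfo w = new WorkerInfo("host-1", 9097);
    workersMap.put(w.toUniqueId(), w);
    System.out.println(lookupInMap(workersMap, w).isPresent()); // prints: true
  }
}
```

In the actual patch the map is created via JavaUtils.newConcurrentHashMap(), keyed by WorkerInfo.toUniqueId(), and call sites that previously iterated `workers` now iterate `workersMap.values()`.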
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UT.

Closes #2870 from turboFei/worksMap.

Lead-authored-by: Wang, Fei <[email protected]>
Co-authored-by: Fei Wang <[email protected]>
Signed-off-by: mingji <[email protected]>
2 people authored and FMX committed Nov 1, 2024
1 parent 2b026a3 commit e2f640c
Showing 6 changed files with 99 additions and 99 deletions.
master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java
@@ -64,7 +64,7 @@ public abstract class AbstractMetaManager implements IMetadataHandler {
public final Map<String, Set<Integer>> registeredAppAndShuffles =
JavaUtils.newConcurrentHashMap();
public final Set<String> hostnameSet = ConcurrentHashMap.newKeySet();
public final Set<WorkerInfo> workers = ConcurrentHashMap.newKeySet();
public final Map<String, WorkerInfo> workersMap = JavaUtils.newConcurrentHashMap();

public final ConcurrentHashMap<WorkerInfo, Long> lostWorkers = JavaUtils.newConcurrentHashMap();
public final ConcurrentHashMap<WorkerInfo, WorkerEventInfo> workerEventInfos =
@@ -170,8 +170,8 @@ public void updateWorkerLostMeta(
WorkerInfo worker = new WorkerInfo(host, rpcPort, pushPort, fetchPort, replicatePort);
workerLostEvents.add(worker);
// remove worker from workers
synchronized (workers) {
workers.remove(worker);
synchronized (workersMap) {
workersMap.remove(worker.toUniqueId());
lostWorkers.put(worker, System.currentTimeMillis());
}
excludedWorkers.remove(worker);
@@ -182,15 +182,15 @@ public void updateWorkerRemoveMeta(
String host, int rpcPort, int pushPort, int fetchPort, int replicatePort) {
WorkerInfo worker = new WorkerInfo(host, rpcPort, pushPort, fetchPort, replicatePort);
// remove worker from workers
synchronized (workers) {
workers.remove(worker);
synchronized (workersMap) {
workersMap.remove(worker.toUniqueId());
lostWorkers.put(worker, System.currentTimeMillis());
}
excludedWorkers.remove(worker);
}

public void removeWorkersUnavailableInfoMeta(List<WorkerInfo> unavailableWorkers) {
synchronized (workers) {
synchronized (workersMap) {
for (WorkerInfo workerInfo : unavailableWorkers) {
if (lostWorkers.containsKey(workerInfo)) {
lostWorkers.remove(workerInfo);
@@ -219,8 +219,8 @@ public void updateWorkerHeartbeatMeta(
host, rpcPort, pushPort, fetchPort, replicatePort, -1, disks, userResourceConsumption);
AtomicLong availableSlots = new AtomicLong();
LOG.debug("update worker {}:{} heartbeat {}", host, rpcPort, disks);
synchronized (workers) {
Optional<WorkerInfo> workerInfo = workers.stream().filter(w -> w.equals(worker)).findFirst();
synchronized (workersMap) {
Optional<WorkerInfo> workerInfo = Optional.ofNullable(workersMap.get(worker.toUniqueId()));
workerInfo.ifPresent(
info -> {
info.updateThenGetDiskInfos(disks, Option.apply(estimatedPartitionSize));
@@ -287,10 +287,8 @@ public void updateRegisterWorkerMeta(
workerInfo.networkLocation_$eq(rackResolver.resolve(host).getNetworkLocation());
}
workerInfo.updateDiskMaxSlots(estimatedPartitionSize);
synchronized (workers) {
if (!workers.contains(workerInfo)) {
workers.add(workerInfo);
}
synchronized (workersMap) {
workersMap.putIfAbsent(workerInfo.toUniqueId(), workerInfo);
shutdownWorkers.remove(workerInfo);
lostWorkers.remove(workerInfo);
excludedWorkers.remove(workerInfo);
@@ -315,7 +313,7 @@ public void writeMetaInfoToFile(File file) throws IOException, RuntimeException
manuallyExcludedWorkers,
workerLostEvents,
appHeartbeatTime,
workers,
new HashSet(workersMap.values()),
partitionTotalWritten.sum(),
partitionTotalFileCount.sum(),
appDiskUsageMetric.snapShots(),
@@ -381,7 +379,7 @@ public void restoreMetaFromFile(File file) throws IOException {
.collect(Collectors.toList());
scala.collection.immutable.Map<String, Node> resolveMap =
rackResolver.resolveToMap(workerHostList);
workers.addAll(
workersMap.putAll(
workerInfoSet.stream()
.peek(
workerInfo -> {
@@ -391,7 +389,7 @@ public void restoreMetaFromFile(File file) throws IOException {
resolveMap.get(workerInfo.host()).get().getNetworkLocation());
}
})
.collect(Collectors.toSet()));
.collect(Collectors.toMap(WorkerInfo::toUniqueId, w -> w)));

snapshotMetaInfo
.getLostWorkersMap()
@@ -437,19 +435,19 @@ public void restoreMetaFromFile(File file) throws IOException {
LOG.info("Successfully restore meta info from snapshot {}", file.getAbsolutePath());
LOG.info(
"Worker size: {}, Registered shuffle size: {}. Worker excluded list size: {}. Manually Excluded list size: {}",
workers.size(),
workersMap.size(),
registeredAppAndShuffles.size(),
excludedWorkers.size(),
manuallyExcludedWorkers.size());
workers.forEach(workerInfo -> LOG.info(workerInfo.toString()));
workersMap.values().forEach(workerInfo -> LOG.info(workerInfo.toString()));
registeredAppAndShuffles.forEach(
(appId, shuffleId) -> LOG.info("RegisteredShuffle {}-{}", appId, shuffleId));
}

private void cleanUpState() {
registeredAppAndShuffles.clear();
hostnameSet.clear();
workers.clear();
workersMap.clear();
lostWorkers.clear();
appHeartbeatTime.clear();
excludedWorkers.clear();
@@ -464,7 +462,7 @@ private void cleanUpState() {
}

public void updateMetaByReportWorkerUnavailable(List<WorkerInfo> failedWorkers) {
synchronized (this.workers) {
synchronized (this.workersMap) {
shutdownWorkers.addAll(failedWorkers);
}
}
@@ -473,7 +471,7 @@ public void updateWorkerEventMeta(int workerEventTypeValue, List<WorkerInfo> wor
long eventTime = System.currentTimeMillis();
ResourceProtos.WorkerEventType eventType =
ResourceProtos.WorkerEventType.forNumber(workerEventTypeValue);
synchronized (this.workers) {
synchronized (this.workersMap) {
for (WorkerInfo workerInfo : workerInfoList) {
WorkerEventInfo workerEventInfo = workerEventInfos.get(workerInfo);
LOG.info("Received worker event: {} for worker: {}", eventType, workerInfo.toUniqueId());
@@ -489,7 +487,7 @@ public void updateWorkerEventMeta(int workerEventTypeValue, List<WorkerInfo> wor
}

public void updateMetaByReportWorkerDecommission(List<WorkerInfo> workers) {
synchronized (this.workers) {
synchronized (this.workersMap) {
decommissionWorkers.addAll(workers);
}
}
@@ -520,19 +518,19 @@ public void updatePartitionSize() {
"Celeborn cluster estimated partition size changed from {} to {}",
Utils.bytesToString(oldEstimatedPartitionSize),
Utils.bytesToString(estimatedPartitionSize));
workers.stream()
.filter(
worker ->
!excludedWorkers.contains(worker) && !manuallyExcludedWorkers.contains(worker))
.forEach(workerInfo -> workerInfo.updateDiskMaxSlots(estimatedPartitionSize));

HashSet<WorkerInfo> workers = new HashSet(workersMap.values());
excludedWorkers.forEach(workers::remove);
manuallyExcludedWorkers.forEach(workers::remove);
workers.forEach(workerInfo -> workerInfo.updateDiskMaxSlots(estimatedPartitionSize));
}

public boolean isWorkerAvailable(WorkerInfo workerInfo) {
return !excludedWorkers.contains(workerInfo)
return (workerInfo.getWorkerStatus().getState() == PbWorkerStatus.State.Normal
&& !workerEventInfos.containsKey(workerInfo))
&& !excludedWorkers.contains(workerInfo)
&& !shutdownWorkers.contains(workerInfo)
&& !manuallyExcludedWorkers.contains(workerInfo)
&& (!workerEventInfos.containsKey(workerInfo)
&& workerInfo.getWorkerStatus().getState() == PbWorkerStatus.State.Normal);
&& !manuallyExcludedWorkers.contains(workerInfo);
}

public void updateApplicationMeta(ApplicationMeta applicationMeta) {
master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -224,13 +224,13 @@ private[celeborn] class Master(
masterSource.addGauge(MasterSource.REGISTERED_SHUFFLE_COUNT) { () =>
statusSystem.registeredShuffleCount
}
masterSource.addGauge(MasterSource.WORKER_COUNT) { () => statusSystem.workers.size }
masterSource.addGauge(MasterSource.WORKER_COUNT) { () => statusSystem.workersMap.size }
masterSource.addGauge(MasterSource.LOST_WORKER_COUNT) { () => statusSystem.lostWorkers.size }
masterSource.addGauge(MasterSource.EXCLUDED_WORKER_COUNT) { () =>
statusSystem.excludedWorkers.size + statusSystem.manuallyExcludedWorkers.size
}
masterSource.addGauge(MasterSource.AVAILABLE_WORKER_COUNT) { () =>
statusSystem.workers.asScala.count { w =>
statusSystem.workersMap.values().asScala.count { w =>
statusSystem.isWorkerAvailable(w)
}
}
@@ -242,7 +242,7 @@ private[celeborn] class Master(
}
masterSource.addGauge(MasterSource.PARTITION_SIZE) { () => statusSystem.estimatedPartitionSize }
masterSource.addGauge(MasterSource.ACTIVE_SHUFFLE_SIZE) { () =>
statusSystem.workers.parallelStream()
statusSystem.workersMap.values().parallelStream()
.mapToLong(new ToLongFunction[WorkerInfo]() {
override def applyAsLong(value: WorkerInfo): Long =
value.userResourceConsumption.values().parallelStream()
@@ -252,7 +252,7 @@ private[celeborn] class Master(
}).sum()
}
masterSource.addGauge(MasterSource.ACTIVE_SHUFFLE_FILE_COUNT) { () =>
statusSystem.workers.parallelStream()
statusSystem.workersMap.values().parallelStream()
.mapToLong(new ToLongFunction[WorkerInfo]() {
override def applyAsLong(value: WorkerInfo): Long =
value.userResourceConsumption.values().parallelStream()
@@ -263,11 +263,11 @@ private[celeborn] class Master(
}

masterSource.addGauge(MasterSource.DEVICE_CELEBORN_TOTAL_CAPACITY) { () =>
statusSystem.workers.asScala.toList.map(_.totalSpace()).sum
statusSystem.workersMap.values().asScala.toList.map(_.totalSpace()).sum
}

masterSource.addGauge(MasterSource.DEVICE_CELEBORN_FREE_CAPACITY) { () =>
statusSystem.workers.asScala.toList.map(_.totalActualUsableSpace()).sum
statusSystem.workersMap.values().asScala.toList.map(_.totalActualUsableSpace()).sum
}

masterSource.addGauge(MasterSource.IS_ACTIVE_MASTER) { () => isMasterActive }
@@ -596,7 +596,7 @@ private[celeborn] class Master(
return
}

statusSystem.workers.asScala.foreach { worker =>
statusSystem.workersMap.values().asScala.foreach { worker =>
if (worker.lastHeartbeat < currentTime - workerHeartbeatTimeoutMs
&& !statusSystem.workerLostEvents.contains(worker)) {
logWarning(s"Worker ${worker.readableAddress()} timeout! Trigger WorkerLost event.")
@@ -635,18 +635,18 @@ private[celeborn] class Master(
if (HAHelper.getAppTimeoutDeadline(statusSystem) > currentTime) {
return
}
statusSystem.appHeartbeatTime.keySet().asScala.foreach { key =>
if (statusSystem.appHeartbeatTime.get(key) < currentTime - appHeartbeatTimeoutMs) {
logWarning(s"Application $key timeout, trigger applicationLost event.")
statusSystem.appHeartbeatTime.asScala.foreach { case (appId, heartbeatTime) =>
if (heartbeatTime < currentTime - appHeartbeatTimeoutMs) {
logWarning(s"Application $appId timeout, trigger applicationLost event.")
val requestId = MasterClient.genRequestId()
var res = self.askSync[ApplicationLostResponse](ApplicationLost(key, requestId))
var res = self.askSync[ApplicationLostResponse](ApplicationLost(appId, requestId))
var retry = 1
while (res.status != StatusCode.SUCCESS && retry <= 3) {
res = self.askSync[ApplicationLostResponse](ApplicationLost(key, requestId))
res = self.askSync[ApplicationLostResponse](ApplicationLost(appId, requestId))
retry += 1
}
if (retry > 3) {
logWarning(s"Handle ApplicationLost event for $key failed more than 3 times!")
logWarning(s"Handle ApplicationLost event for $appId failed more than 3 times!")
}
}
}
@@ -667,7 +667,7 @@ private[celeborn] class Master(
workerStatus: WorkerStatus,
requestId: String): Unit = {
val targetWorker = new WorkerInfo(host, rpcPort, pushPort, fetchPort, replicatePort)
val registered = statusSystem.workers.asScala.contains(targetWorker)
val registered = statusSystem.workersMap.containsKey(targetWorker.toUniqueId())
if (!registered) {
logWarning(s"Received heartbeat from unknown worker " +
s"$host:$rpcPort:$pushPort:$fetchPort:$replicatePort.")
@@ -758,10 +758,7 @@ private[celeborn] class Master(
-1,
new util.HashMap[String, DiskInfo](),
JavaUtils.newConcurrentHashMap[UserIdentifier, ResourceConsumption]())
val worker: WorkerInfo = statusSystem.workers
.asScala
.find(_ == targetWorker)
.orNull
val worker: WorkerInfo = statusSystem.workersMap.get(targetWorker.toUniqueId())
if (worker == null) {
logWarning(s"Unknown worker $host:$rpcPort:$pushPort:$fetchPort:$replicatePort" +
s" for WorkerLost handler!")
@@ -806,7 +803,7 @@ private[celeborn] class Master(
return
}

if (statusSystem.workers.contains(workerToRegister)) {
if (statusSystem.workersMap.containsKey(workerToRegister.toUniqueId())) {
logWarning(s"Receive RegisterWorker while worker" +
s" ${workerToRegister.toString()} already exists, re-register.")
statusSystem.handleRegisterWorker(
@@ -908,7 +905,7 @@ private[celeborn] class Master(
// offer slots
val slots =
masterSource.sample(MasterSource.OFFER_SLOTS_TIME, s"offerSlots-${Random.nextInt()}") {
statusSystem.workers.synchronized {
statusSystem.workersMap.synchronized {
if (slotsAssignPolicy == SlotsAssignPolicy.LOADAWARE) {
SlotsAllocator.offerSlotsLoadAware(
selectedWorkers,
@@ -1121,24 +1118,24 @@ private[celeborn] class Master(
fileCount,
System.currentTimeMillis(),
requestId)
// unknown workers will retain in needCheckedWorkerList
needCheckedWorkerList.removeAll(statusSystem.workers)
val unknownWorkers = needCheckedWorkerList.asScala.filterNot(w =>
statusSystem.workersMap.containsKey(w.toUniqueId())).asJava
if (shouldResponse) {
// UserResourceConsumption and DiskInfo are eliminated from WorkerInfo
// during serialization of HeartbeatFromApplicationResponse
var availableWorksSentToClient = new util.ArrayList[WorkerInfo]()
if (needAvailableWorkers) {
availableWorksSentToClient = new util.ArrayList[WorkerInfo](
statusSystem.workers.asScala.filter(worker =>
statusSystem.isWorkerAvailable(worker)).asJava)
statusSystem.workersMap.values().asScala.filter(worker =>
statusSystem.isWorkerAvailable(worker)).toList.asJava)
}
var appRelatedShuffles =
val appRelatedShuffles =
statusSystem.registeredAppAndShuffles.getOrDefault(appId, Collections.emptySet())
context.reply(HeartbeatFromApplicationResponse(
StatusCode.SUCCESS,
new util.ArrayList(
(statusSystem.excludedWorkers.asScala ++ statusSystem.manuallyExcludedWorkers.asScala).asJava),
needCheckedWorkerList,
unknownWorkers,
new util.ArrayList[WorkerInfo](
(statusSystem.shutdownWorkers.asScala ++ statusSystem.decommissionWorkers.asScala).asJava),
availableWorksSentToClient,
@@ -1215,7 +1212,7 @@ private[celeborn] class Master(
// TODO: Support calculate topN app resource consumption.
private def computeUserResourceConsumption(
userIdentifier: UserIdentifier): ResourceConsumption = {
val resourceConsumption = statusSystem.workers.asScala.flatMap {
val resourceConsumption = statusSystem.workersMap.values().asScala.flatMap {
workerInfo => workerInfo.userResourceConsumption.asScala.get(userIdentifier)
}.foldRight(ResourceConsumption(0, 0, 0, 0))(_ add _)
resourceConsumption
@@ -1249,7 +1246,7 @@ private[celeborn] class Master(

private def workersAvailable(
tmpExcludedWorkerList: Set[WorkerInfo] = Set.empty): util.List[WorkerInfo] = {
statusSystem.workers.asScala.filter { w =>
statusSystem.workersMap.values().asScala.filter { w =>
statusSystem.isWorkerAvailable(w) && !tmpExcludedWorkerList.contains(w)
}.toList.asJava
}
@@ -1282,7 +1279,7 @@ private[celeborn] class Master(
}

private def getWorkers: String = {
statusSystem.workers.asScala.mkString("\n")
statusSystem.workersMap.values().asScala.mkString("\n")
}

override def handleWorkerEvent(
Expand Down Expand Up @@ -1411,7 +1408,8 @@ private[celeborn] class Master(
",")} and remove ${removeWorkers.map(_.readableAddress).mkString(",")}.\n")
}
val unknownExcludedWorkers =
(addWorkers ++ removeWorkers).filter(!statusSystem.workers.contains(_))
(addWorkers ++ removeWorkers).filterNot(w =>
statusSystem.workersMap.containsKey(w.toUniqueId()))
if (unknownExcludedWorkers.nonEmpty) {
sb.append(
s"Unknown workers ${unknownExcludedWorkers.map(_.readableAddress).mkString(",")}." +
@@ -51,7 +51,7 @@ class WorkerResource extends ApiRequestContext {
@GET
def workers: WorkersResponse = {
new WorkersResponse()
.workers(statusSystem.workers.asScala.map(ApiUtils.workerData).toSeq.asJava)
.workers(statusSystem.workersMap.values().asScala.map(ApiUtils.workerData).toSeq.asJava)
.lostWorkers(statusSystem.lostWorkers.asScala.toSeq.sortBy(_._2)
.map(kv =>
new WorkerTimestampData().worker(ApiUtils.workerData(kv._1)).timestamp(kv._2)).asJava)
@@ -134,7 +134,8 @@ class WorkerResource extends ApiRequestContext {
s"eventType(${request.getEventType}) and workers(${request.getWorkers}) are required")
}
val workers = request.getWorkers.asScala.map(ApiUtils.toWorkerInfo).toSeq
val (filteredWorkers, unknownWorkers) = workers.partition(statusSystem.workers.contains)
val (filteredWorkers, unknownWorkers) =
workers.partition(w => statusSystem.workersMap.containsKey(w.toUniqueId()))
if (filteredWorkers.isEmpty) {
throw new BadRequestException(
s"None of the workers are known: ${unknownWorkers.map(_.readableAddress).mkString(", ")}")