From 657e15a44b3569aa57b0c7ae7e69af3565dc6c4b Mon Sep 17 00:00:00 2001 From: zhengtao Date: Tue, 10 Dec 2024 15:14:48 +0800 Subject: [PATCH] fix PushDataHandler reference count error --- .../deploy/worker/PushDataHandler.scala | 129 +++++++++--------- 1 file changed, 68 insertions(+), 61 deletions(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index b35f730bdf9..ec173a2436b 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -262,28 +262,28 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler return } val writePromise = Promise[Unit]() - // for primary, send data to replica if (doReplicate) { + val peer = location.getPeer + val peerWorker = new WorkerInfo( + peer.getHost, + peer.getRpcPort, + peer.getPushPort, + peer.getFetchPort, + peer.getReplicatePort) + if (unavailablePeers.containsKey(peerWorker)) { + // pushData.body().release() + fileWriter.decrementPendingWrites() + workerSource.incCounter(WorkerSource.REPLICATE_DATA_CREATE_CONNECTION_FAIL_COUNT) + logError( + s"PushData replication failed caused by unavailable peer for partitionLocation: $location") + callbackWithTimer.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA)) + return + } + pushData.body().retain() replicateThreadPool.submit(new Runnable { override def run(): Unit = { - val peer = location.getPeer - val peerWorker = new WorkerInfo( - peer.getHost, - peer.getRpcPort, - peer.getPushPort, - peer.getFetchPort, - peer.getReplicatePort) - if (unavailablePeers.containsKey(peerWorker)) { - pushData.body().release() - workerSource.incCounter(WorkerSource.REPLICATE_DATA_CREATE_CONNECTION_FAIL_COUNT) - logError( - s"PushData replication failed caused by 
unavailable peer for partitionLocation: $location") - callbackWithTimer.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA)) - return - } - // Handle the response from replica val wrappedCallback = new RpcResponseCallback() { override def onSuccess(response: ByteBuffer): Unit = { @@ -323,24 +323,28 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } override def onFailure(e: Throwable): Unit = { - logError(s"PushData replication failed for partitionLocation: $location", e) - // 1. Throw PUSH_DATA_WRITE_FAIL_REPLICA by replica peer worker - // 2. Throw PUSH_DATA_TIMEOUT_REPLICA by TransportResponseHandler - // 3. Throw IOException by channel, convert to PUSH_DATA_CONNECTION_EXCEPTION_REPLICA - if (e.getMessage.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT) - callbackWithTimer.onFailure(e) - } else if (e.getMessage.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT) - callbackWithTimer.onFailure(e) - } else if (ExceptionUtils.connectFail(e.getMessage)) { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT) - callbackWithTimer.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA)) - } else { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT) - callbackWithTimer.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA)) + Try(Await.result(writePromise.future, Duration.Inf)) match { + case _ => + logError(s"PushData replication failed for partitionLocation: $location", e) + // 1. Throw PUSH_DATA_WRITE_FAIL_REPLICA by replica peer worker + // 2. Throw PUSH_DATA_TIMEOUT_REPLICA by TransportResponseHandler + // 3. 
Throw IOException by channel, convert to PUSH_DATA_CONNECTION_EXCEPTION_REPLICA + if (e.getMessage.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) { + workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT) + callbackWithTimer.onFailure(e) + } else if (e.getMessage.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) { + workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT) + callbackWithTimer.onFailure(e) + } else if (ExceptionUtils.connectFail(e.getMessage)) { + workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT) + callbackWithTimer.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA)) + } else { + workerSource.incCounter( + WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT) + callbackWithTimer.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA)) + } } } } @@ -360,7 +364,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler logError( s"PushData replication failed during connecting peer for partitionLocation: $location", e) - callbackWithTimer.onFailure( + wrappedCallback.onFailure( new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA)) } } @@ -546,9 +550,6 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler val writePromise = Promise[Unit]() // for primary, send data to replica if (doReplicate) { - pushMergedData.body().retain() - replicateThreadPool.submit(new Runnable { - override def run(): Unit = { val location = partitionIdToLocations.head._2 val peer = location.getPeer val peerWorker = new WorkerInfo( @@ -558,7 +559,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler peer.getFetchPort, peer.getReplicatePort) if (unavailablePeers.containsKey(peerWorker)) { - pushMergedData.body().release() + fileWriters.foreach(_.decrementPendingWrites()) 
workerSource.incCounter(WorkerSource.REPLICATE_DATA_CREATE_CONNECTION_FAIL_COUNT) logError( s"PushMergedData replication failed caused by unavailable peer for partitionLocation: $location") @@ -566,7 +567,9 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA)) return } - + pushMergedData.body().retain() + replicateThreadPool.submit(new Runnable { + override def run(): Unit = { // Handle the response from replica val wrappedCallback = new RpcResponseCallback() { override def onSuccess(response: ByteBuffer): Unit = { @@ -600,24 +603,28 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } override def onFailure(e: Throwable): Unit = { - logError(s"PushMergedData replicate failed for partitionLocation: $location", e) - // 1. Throw PUSH_DATA_WRITE_FAIL_REPLICA by replica peer worker - // 2. Throw PUSH_DATA_TIMEOUT_REPLICA by TransportResponseHandler - // 3. 
Throw IOException by channel, convert to PUSH_DATA_CONNECTION_EXCEPTION_REPLICA - if (e.getMessage.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT) - callbackWithTimer.onFailure(e) - } else if (e.getMessage.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT) - callbackWithTimer.onFailure(e) - } else if (ExceptionUtils.connectFail(e.getMessage)) { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT) - callbackWithTimer.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA)) - } else { - workerSource.incCounter(WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT) - callbackWithTimer.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA)) + Try(Await.result(writePromise.future, Duration.Inf)) match { + case _ => + logError(s"PushMergedData replicate failed for partitionLocation: $location", e) + // 1. Throw PUSH_DATA_WRITE_FAIL_REPLICA by replica peer worker + // 2. Throw PUSH_DATA_TIMEOUT_REPLICA by TransportResponseHandler + // 3. 
Throw IOException by channel, convert to PUSH_DATA_CONNECTION_EXCEPTION_REPLICA + if (e.getMessage.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) { + workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT) + callbackWithTimer.onFailure(e) + } else if (e.getMessage.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) { + workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT) + callbackWithTimer.onFailure(e) + } else if (ExceptionUtils.connectFail(e.getMessage)) { + workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT) + callbackWithTimer.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA)) + } else { + workerSource.incCounter( + WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT) + callbackWithTimer.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA)) + } } } } @@ -642,7 +649,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler logError( s"PushMergedData replication failed during connecting peer for partitionLocation: $location", e) - callbackWithTimer.onFailure( + wrappedCallback.onFailure( new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA)) } }