Skip to content

Commit

Permalink
refine uts
Browse files Browse the repository at this point in the history
  • Loading branch information
FMX committed Jan 20, 2025
1 parent 497a33e commit e8f428e
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 93 deletions.
5 changes: 5 additions & 0 deletions client-spark/spark-3/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,10 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -58,34 +58,20 @@ class CelebornShuffleReader[K, C](
shuffleIdTracker: ExecutorShuffleIdTracker)
extends ShuffleReader[K, C] with Logging {

val mockReader = conf.testMockShuffleReader

private val dep =
if (mockReader) {
null
} else {
handle.dependency
}
private val dep = handle.dependency

@VisibleForTesting
val shuffleClient =
if (mockReader) {
new DummyShuffleClient(conf, Files.createTempFile("test", "mockfile").toFile)
} else {
ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
handle.extension)
}
val shuffleClient = ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
handle.extension)

private val exceptionRef = new AtomicReference[IOException]
private val throwsFetchFailure =
if (mockReader) conf.clientFetchThrowsFetchFailure else handle.throwsFetchFailure
private val encodedAttemptId =
if (mockReader) 0 else SparkCommonUtils.getEncodedAttemptNumber(context)
private val throwsFetchFailure = handle.throwsFetchFailure
private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context)

override def read(): Iterator[Product2[K, C]] = {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@

package org.apache.spark.shuffle.celeborn

import java.nio.file.Files
import java.util.concurrent.TimeoutException

import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, TaskContext}
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.mockito.Mockito._
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.DummyShuffleClient
import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.CelebornIOException
import org.apache.celeborn.common.identity.UserIdentifier

class CelebornShuffleReaderSuite extends AnyFunSuite {

Expand All @@ -35,12 +39,24 @@ class CelebornShuffleReaderSuite extends AnyFunSuite {
* test the method `checkAndReportFetchFailureForUpdateFileGroupFailure`
*/
test("CELEBORN-1838 test check report fetch failure exceptions ") {
val handler = Mockito.mock(classOf[CelebornShuffleHandle[Int, Int, Int]])
val dependency = Mockito.mock(classOf[ShuffleDependency[Int, Int, Int]])
val handler = new CelebornShuffleHandle[Int, Int, Int](
"APP",
"HOST1",
1,
UserIdentifier.apply("a", "b"),
0,
true,
1,
dependency)
val context = Mockito.mock(classOf[TaskContext])
val metricReporter = Mockito.mock(classOf[ShuffleReadMetricsReporter])
val conf = new CelebornConf()
conf.set("celeborn.test.client.mockShuffleReader", "true")
conf.set("celeborn.client.spark.fetch.throwsFetchFailure", "true")

val tmpFile = Files.createTempFile("test", ".tmp").toFile
mockStatic(classOf[ShuffleClient]).when(() =>
ShuffleClient.get(any(), any(), any(), any(), any(), any())).thenReturn(
new DummyShuffleClient(conf, tmpFile))

val shuffleReader =
new CelebornShuffleReader[Int, Int](handler, 0, 0, 0, 0, context, conf, metricReporter, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ class ReducePartitionCommitHandler(
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime

private val mockGetReducerFileGroupDelay = conf.testMockGetReducerFileGroupDelay
private val mockGetReducerFileGroupDelayedShuffle = conf.testMockGetReducerFileGroupDelayedShuffle
private val mockGetReducerFileGroupDelayedShuffleProbabilty =
conf.testMockGetReducerFileGroupDelayedProbability

// noinspection UnstableApiUsage
private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
.concurrencyLevel(rpcCacheConcurrencyLevel)
Expand Down Expand Up @@ -309,16 +304,6 @@ class ReducePartitionCommitHandler(
}
})

if (mockGetReducerFileGroupDelay.isDefined && mockGetReducerFileGroupDelayedShuffle.get == shuffleId) {
val maxDelay = mockGetReducerFileGroupDelay.get
val probability = mockGetReducerFileGroupDelayedShuffleProbabilty.get
if (Random.nextDouble() < probability) {
// Int.Max means 35791 hours, will not try to set any configurations like this
val delayTime = Random.nextInt(maxDelay.toInt)
logInfo(s"Mock get reducer file group delayed, ${shuffleId} ${delayTime}")
Thread.sleep(delayTime)
}
}
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1247,13 +1247,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def testMockCommitFilesFailure: Boolean = get(TEST_MOCK_COMMIT_FILES_FAILURE)
def testMockShuffleLost: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_LOST)
def testMockShuffleLostShuffle: Int = get(TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE)
def testMockShuffleReader: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_READER)
def testMockGetReducerFileGroupDelay: Option[Long] =
get(TEST_CLIENT_MOCK_GETREDUCERFILEGROUP_DELAY)
def testMockGetReducerFileGroupDelayedShuffle: Option[Int] =
get(TEST_CLIENT_MOCK_GETREDUCERFILEGROUP_DELAYED_SHUFFLE)
def testMockGetReducerFileGroupDelayedProbability: Option[Double] =
get(TEST_CLIENT_MOCK_GETREDUCERFILEGROUP_DELAY_PROBABILTY)
def testPushPrimaryDataTimeout: Boolean = get(TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT)
def testPushReplicaDataTimeout: Boolean = get(TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT)
def testRetryRevive: Boolean = get(TEST_CLIENT_RETRY_REVIVE)
Expand Down Expand Up @@ -3745,47 +3738,6 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(0)

val TEST_CLIENT_MOCK_SHUFFLE_READER: ConfigEntry[Boolean] =
buildConf("celeborn.test.client.mockShuffleReader")
.internal
.categories("test", "client")
.doc("Mock shuffle reader for shuffle")
.version("0.5.3")
.internal
.booleanConf
.createWithDefault(false)

val TEST_CLIENT_MOCK_GETREDUCERFILEGROUP_DELAY: OptionalConfigEntry[Long] =
buildConf("celeborn.test.client.mockGetReducerFileGroupDelay")
.internal
.categories("test", "client")
.doc("Mock get reducer file group response delay")
.version("0.5.4")
.internal
.timeConf(TimeUnit.MILLISECONDS)
.createOptional

val TEST_CLIENT_MOCK_GETREDUCERFILEGROUP_DELAYED_SHUFFLE: OptionalConfigEntry[Int] =
buildConf("celeborn.test.client.mockGetReducerFileGroupDelayedShuffle")
.internal
.categories("test", "client")
.doc("Mock get reducer file group response delay for certain shuffle")
.version("0.5.4")
.internal
.intConf
.createOptional

val TEST_CLIENT_MOCK_GETREDUCERFILEGROUP_DELAY_PROBABILTY: OptionalConfigEntry[Double] =
buildConf("celeborn.test.client.mockGetReducerFileGroupDelayProbability")
.internal
.categories("test", "client")
.doc("Mock get reducer file group response delay probability")
.version("0.5.4")
.internal
.doubleConf
.checkValue(v => v >= 0 && v <= 1, "must be in range [0, 1]")
.createOptional

val CLIENT_PUSH_REPLICATE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.push.replicate.enabled")
.withAlternative("celeborn.push.replicate.enabled")
Expand Down
2 changes: 1 addition & 1 deletion project/CelebornBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ trait SparkClientProjects {
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % "provided",
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
) ++ commonUnitTestDependencies
) ++ commonUnitTestDependencies ++ Seq(Dependencies.mockitoInline % "test")
)
}

Expand Down

0 comments on commit e8f428e

Please sign in to comment.