diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index 42776a6cab0..1acce133fdb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -284,10 +284,29 @@ object GpuDeviceManager extends Logging { private var memoryEventHandler: DeviceMemoryEventHandler = _ - private def initializeRmm(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = { - if (!Rmm.isInitialized) { - val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) + private def initializeSpillAndMemoryEvents(conf: RapidsConf): Unit = { + SpillFramework.initialize(conf) + + memoryEventHandler = new DeviceMemoryEventHandler( + SpillFramework.stores.deviceStore, + conf.gpuOomDumpDir, + conf.gpuOomMaxRetries) + + if (conf.sparkRmmStateEnable) { + val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) { + null + } else { + conf.sparkRmmDebugLocation + } + RmmSpark.setEventHandler(memoryEventHandler, debugLoc) + } else { + logWarning("SparkRMM retry has been disabled") + Rmm.setEventHandler(memoryEventHandler) + } + } + private def initializeRmmGpuPool(gpuId: Int, conf: RapidsConf): Unit = { + if (!Rmm.isInitialized) { val poolSize = conf.chunkedPackPoolSize chunkedPackMemoryResource = if (poolSize > 0) { @@ -391,30 +410,10 @@ object GpuDeviceManager extends Logging { } } - SpillFramework.initialize(conf) - - memoryEventHandler = new DeviceMemoryEventHandler( - SpillFramework.stores.deviceStore, - conf.gpuOomDumpDir, - conf.gpuOomMaxRetries) - - if (conf.sparkRmmStateEnable) { - val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) { - null - } else { - conf.sparkRmmDebugLocation - } - RmmSpark.setEventHandler(memoryEventHandler, debugLoc) - } else { - logWarning("SparkRMM retry has been disabled") - Rmm.setEventHandler(memoryEventHandler) - } - GpuShuffleEnv.init(conf) } } - private def initializeOffHeapLimits(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = { - val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) + private def initializePinnedPoolAndOffHeapLimits(gpuId: Int, conf: RapidsConf): Unit = { val setCuioDefaultResource = conf.pinnedPoolCuioDefault val (pinnedSize, nonPinnedLimit) = if (conf.offHeapLimitEnabled) { logWarning("OFF HEAP MEMORY LIMITS IS ENABLED. " + @@ -508,8 +507,13 @@ object GpuDeviceManager extends Logging { "Cannot initialize memory due to previous shutdown failing") } else if (singletonMemoryInitialized == Uninitialized) { val gpu = gpuId.getOrElse(findGpuAndAcquire()) - initializeRmm(gpu, rapidsConf) - initializeOffHeapLimits(gpu, rapidsConf) + val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) + initializePinnedPoolAndOffHeapLimits(gpu, conf) + initializeRmmGpuPool(gpu, conf) + // we want to initialize this last because we want to take advantage + // of pinned memory if it is configured + initializeSpillAndMemoryEvents(conf) + GpuShuffleEnv.init(conf) singletonMemoryInitialized = Initialized } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index 3e0e2afa9c7..ea51bed5af5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -131,18 +131,23 @@ object GpuShuffleEnv extends Logging { def isRapidsShuffleAvailable(conf: RapidsConf): Boolean = { // the driver has `mgr` defined when this is checked val sparkEnv = SparkEnv.get - val isRapidsManager = sparkEnv.shuffleManager.isInstanceOf[RapidsShuffleManagerLike] - if (isRapidsManager) { - validateRapidsShuffleManager(sparkEnv.shuffleManager.getClass.getName) + if (sparkEnv == null) { + // we may hit this in some tests that don't need to use the RAPIDS shuffle manager. + false + } else { + val isRapidsManager = sparkEnv.shuffleManager.isInstanceOf[RapidsShuffleManagerLike] + if (isRapidsManager) { + validateRapidsShuffleManager(sparkEnv.shuffleManager.getClass.getName) + } + // executors have `env` defined when this is checked + // in tests + val isConfiguredInEnv = Option(env).exists(_.isRapidsShuffleConfigured) + (isConfiguredInEnv || isRapidsManager) && + (conf.isMultiThreadedShuffleManagerMode || + (conf.isGPUShuffle && !isExternalShuffleEnabled && + !isSparkAuthenticateEnabled)) && + conf.isSqlExecuteOnGPU } - // executors have `env` defined when this is checked - // in tests - val isConfiguredInEnv = Option(env).exists(_.isRapidsShuffleConfigured) - (isConfiguredInEnv || isRapidsManager) && - (conf.isMultiThreadedShuffleManagerMode || - (conf.isGPUShuffle && !isExternalShuffleEnabled && - !isSparkAuthenticateEnabled)) && - conf.isSqlExecuteOnGPU } def useGPUShuffle(conf: RapidsConf): Boolean = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala index deb0860f337..ea9a26b443f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BatchWithPartitionDataSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * Unit tests for utility methods in [[ BatchWithPartitionDataUtils ]] */ -class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQueryCompareTestSuite { +class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase { test("test splitting partition data into groups") { val maxGpuColumnSizeBytes = 1000L @@ -55,48 +55,46 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery // This test uses single-row partition values that should throw a GpuSplitAndRetryOOM exception // when a retry is forced. val maxGpuColumnSizeBytes = 1000L - withGpuSparkSession(_ => { - val (_, partValues, _, partSchema) = getSamplePartitionData - closeOnExcept(buildBatch(getSampleValueData)) { valueBatch => - val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch, - Array(1), partValues.take(1), partSchema, maxGpuColumnSizeBytes) - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - withResource(resultBatchIter) { _ => - assertThrows[GpuSplitAndRetryOOM] { - resultBatchIter.next() - } + val (_, partValues, _, partSchema) = getSamplePartitionData + closeOnExcept(buildBatch(getSampleValueData)) { valueBatch => + val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch, + Array(1), partValues.take(1), partSchema, maxGpuColumnSizeBytes) + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + withResource(resultBatchIter) { _ => + assertThrows[GpuSplitAndRetryOOM] { + resultBatchIter.next() } } - }) + } } test("test adding partition values to batch with OOM split and retry") { // This test should split the input batch and process them when a retry is forced. val maxGpuColumnSizeBytes = 1000L - withGpuSparkSession(_ => { - val (partCols, partValues, partRows, partSchema) = getSamplePartitionData - withResource(buildBatch(getSampleValueData)) { valueBatch => - withResource(buildBatch(partCols)) { partBatch => - withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch => - // we incRefCounts here because `addPartitionValuesToBatch` takes ownership of - // `valueBatch`, but we are keeping it alive since its columns are part of - // `expectedBatch` - GpuColumnVector.incRefCounts(valueBatch) - val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch, - partRows, partValues, partSchema, maxGpuColumnSizeBytes) - withResource(resultBatchIter) { _ => - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - // Assert that the final count of rows matches expected batch - // We also need to close each batch coming from `resultBatchIter`. - val rowCounts = resultBatchIter.map(withResource(_){_.numRows()}).sum - assert(rowCounts == expectedBatch.numRows()) - } + val (partCols, partValues, partRows, partSchema) = getSamplePartitionData + withResource(buildBatch(getSampleValueData)) { valueBatch => + withResource(buildBatch(partCols)) { partBatch => + withResource(GpuColumnVector.combineColumns(valueBatch, partBatch)) { expectedBatch => + // we incRefCounts here because `addPartitionValuesToBatch` takes ownership of + // `valueBatch`, but we are keeping it alive since its columns are part of + // `expectedBatch` + GpuColumnVector.incRefCounts(valueBatch) + val resultBatchIter = BatchWithPartitionDataUtils.addPartitionValuesToBatch(valueBatch, + partRows, partValues, partSchema, maxGpuColumnSizeBytes) + withResource(resultBatchIter) { _ => + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + // Assert that the final count of rows matches expected batch + // We also need to close each batch coming from `resultBatchIter`. + val rowCounts = resultBatchIter.map(withResource(_) { + _.numRows() + }).sum + assert(rowCounts == expectedBatch.numRows()) } } } - }) + } } private def getSamplePartitionData: (Array[Array[String]], Array[InternalRow], Array[Long], @@ -140,4 +138,4 @@ class BatchWithPartitionDataSuite extends RmmSparkRetrySuiteBase with SparkQuery GpuColumnVector.from(ColumnVector.fromStrings(v: _*), StringType)) new ColumnarBatch(colVectors.toArray, numRows) } -} +} \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ShufflePartitionerRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ShufflePartitionerRetrySuite.scala index fc9c85112e5..874b1dbf655 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ShufflePartitionerRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ShufflePartitionerRetrySuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.RmmSpark -import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder, SpecificInternalRow} import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -36,49 +35,45 @@ class ShufflePartitionerRetrySuite extends RmmSparkRetrySuiteBase { } private def testRoundRobinPartitioner(partNum: Int) = { - TestUtils.withGpuSparkSession(new SparkConf()) { _ => - val rrp = GpuRoundRobinPartitioning(partNum) - // batch will be closed within columnarEvalAny - val batch = buildBatch - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - var ret: Array[(ColumnarBatch, Int)] = null - try { - ret = rrp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]] - assert(partNum === ret.size) - } finally { - if (ret != null) { - ret.map(_._1).safeClose() - } + val rrp = GpuRoundRobinPartitioning(partNum) + // batch will be closed within columnarEvalAny + val batch = buildBatch + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + var ret: Array[(ColumnarBatch, Int)] = null + try { + ret = rrp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]] + assert(partNum === ret.size) + } finally { + if (ret != null) { + ret.map(_._1).safeClose() } } } test("GPU range partition with retry") { - TestUtils.withGpuSparkSession(new SparkConf()) { _ => - // Initialize range bounds - val fieldTypes: Array[DataType] = Array(IntegerType) - val bounds = new SpecificInternalRow(fieldTypes) - bounds.setInt(0, 3) - // Initialize GPU sorter - val ref = GpuBoundReference(0, IntegerType, nullable = true)(ExprId(0), "a") - val sortOrder = SortOrder(ref, Ascending) - val attrs = AttributeReference(ref.name, ref.dataType, ref.nullable)() - val gpuSorter = new GpuSorter(Seq(sortOrder), Array(attrs)) + // Initialize range bounds + val fieldTypes: Array[DataType] = Array(IntegerType) + val bounds = new SpecificInternalRow(fieldTypes) + bounds.setInt(0, 3) + // Initialize GPU sorter + val ref = GpuBoundReference(0, IntegerType, nullable = true)(ExprId(0), "a") + val sortOrder = SortOrder(ref, Ascending) + val attrs = AttributeReference(ref.name, ref.dataType, ref.nullable)() + val gpuSorter = new GpuSorter(Seq(sortOrder), Array(attrs)) - val rp = GpuRangePartitioner(Array.apply(bounds), gpuSorter) - // batch will be closed within columnarEvalAny - val batch = buildBatch - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - var ret: Array[(ColumnarBatch, Int)] = null - try { - ret = rp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]] - assert(ret.length === 2) - } finally { - if (ret != null) { - ret.map(_._1).safeClose() - } + val rp = GpuRangePartitioner(Array.apply(bounds), gpuSorter) + // batch will be closed within columnarEvalAny + val batch = buildBatch + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + var ret: Array[(ColumnarBatch, Int)] = null + try { + ret = rp.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]] + assert(ret.length === 2) + } finally { + if (ret != null) { + ret.map(_._1).safeClose() } } } diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRDDConverterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRDDConverterSuite.scala index 2852bbd41de..adc5e8b33b0 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRDDConverterSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRDDConverterSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,10 +27,9 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.util.MapData import org.apache.spark.sql.types._ -class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite - with RmmSparkRetrySuiteBase { +class InternalColumnarRDDConverterSuite extends RmmSparkRetrySuiteBase { - def compareMapAndMapDate[K,V](map: collection.Map[K, V], mapData: MapData): Assertion = { + def compareMapAndMapDate[K, V](map: collection.Map[K, V], mapData: MapData): Assertion = { assert(map.size == mapData.numElements()) val outputMap = mutable.Map[Any, Any]() // Only String now, TODO: support other data types in Map @@ -66,12 +65,12 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite } test("transform boolean, byte, short, int, float, long, double, date, timestamp data" + - " back and forth between Row and Columnar") { + " back and forth between Row and Columnar") { val schema = StructType(Seq( StructField("Boolean", BooleanType), StructField("BinaryNotNull", BooleanType, nullable = false), StructField("Byte", ByteType), - StructField("ByteNotNull",ByteType, nullable = false), + StructField("ByteNotNull", ByteType, nullable = false), StructField("Short", ShortType), StructField("ShortNotNull", ShortType, nullable = false), StructField("Int", IntegerType), @@ -86,8 +85,8 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite StructField("DateNotNull", DateType, nullable = false), StructField("Timestamp", TimestampType), StructField("TimestampNotNull", TimestampType, nullable = false), - StructField("Decimal", DecimalType(20,10)), - StructField("DecimalNotNull", DecimalType(20,10), nullable = false) + StructField("Decimal", DecimalType(20, 10)), + StructField("DecimalNotNull", DecimalType(20, 10), nullable = false) )) val numRows = 100 val rows = GpuBatchUtilsSuite.createExternalRows(schema, numRows) @@ -113,7 +112,7 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite } else { if (f.dataType.isInstanceOf[DecimalType]) { assert(input.get(i) == output.get(i, f.dataType) - .asInstanceOf[Decimal].toJavaBigDecimal) + .asInstanceOf[Decimal].toJavaBigDecimal) } else { assert(input.get(i) == output.get(i, f.dataType)) } @@ -272,7 +271,7 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite assert(outputStructRow.isNullAt(2)) } else { assert(inputStructRow.getSeq(2) sameElements outputStructRow.getArray(2) - .toDoubleArray()) + .toDoubleArray()) } } } @@ -280,7 +279,11 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite } } } +} +// these tests are separated out because they are really just spark tests, and they should skip +// the initialization done in `RmmSparkRetrySuiteBase` +class InternalColumnarRDDConverterSparkSessionSuite extends SparkQueryCompareTestSuite { test("InternalColumnarRddConverter should extractRDDTable RDD[ColumnarBatch]") { withGpuSparkSession(spark => { val path = TestResourceFinder.getResourcePath("disorder-read-schema.parquet") @@ -308,5 +311,4 @@ class InternalColumnarRDDConverterSuite extends SparkQueryCompareTestSuite assert(result.forall(_ == true)) }, new SparkConf().set("spark.rapids.sql.test.allowedNonGpu", "DeserializeToObjectExec")) } - }