From 5a8926ce4384a9f276fd5008314cf846a451d09d Mon Sep 17 00:00:00 2001 From: gaochao0509 <1623735386@qq.com> Date: Wed, 15 Nov 2023 14:48:16 +0800 Subject: [PATCH] [CELEBORN-1123] Support fallback to non-columnar shuffle for schema that cannot be obtained from shuffle dependency --- .../ColumnarHashBasedShuffleWriter.java | 38 ++++++--- .../CustomShuffleDependencyUtils.java | 19 +++-- .../CelebornColumnarShuffleReader.scala | 9 +- .../columnar/CelebornBatchBuilder.scala | 4 +- .../columnar/CelebornColumnAccessor.scala | 9 +- .../columnar/CelebornColumnBuilder.scala | 8 +- .../columnar/CelebornColumnStats.scala | 59 +++---------- .../columnar/CelebornColumnType.scala | 83 +++++-------------- .../CelebornColumnarBatchBuilder.scala | 26 +----- .../CelebornColumnarBatchCodeGenBuild.scala | 15 ---- .../CelebornColumnarBatchSerializer.scala | 13 +-- .../CelebornCompressibleColumnBuilder.scala | 2 +- .../columnar/CelebornCompressionScheme.scala | 9 +- .../columnar/CelebornCompressionSchemes.scala | 16 ++-- .../ColumnarHashBasedShuffleWriterSuiteJ.java | 63 ++++++++++++-- .../CelebornColumnarShuffleReaderSuite.scala | 59 ++++++++++++- .../celeborn/CelebornShuffleReader.scala | 2 +- .../CelebornShuffleWriterSuiteBase.java | 16 ++-- 18 files changed, 224 insertions(+), 226 deletions(-) diff --git a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java index 2943d310df5..be5d78c5023 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java +++ b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java @@ -21,6 +21,8 @@ import scala.Product2; +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.ShuffleDependency; import org.apache.spark.TaskContext; import org.apache.spark.annotation.Private; import org.apache.spark.serializer.Serializer; @@ -32,6 +34,8 @@ import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchCodeGenBuild; import org.apache.spark.sql.execution.metric.SQLMetric; import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; @@ -39,14 +43,19 @@ @Private public class ColumnarHashBasedShuffleWriter extends HashBasedShuffleWriter { - private CelebornBatchBuilder[] celebornBatchBuilders; - private StructType schema; - private Serializer depSerializer; - private boolean isColumnarShuffle = false; - private int columnarShuffleBatchSize; - private boolean columnarShuffleCodeGenEnabled; - private boolean columnarShuffleDictionaryEnabled; - private double columnarShuffleDictionaryMaxFactor; + private static final Logger logger = + LoggerFactory.getLogger(ColumnarHashBasedShuffleWriter.class); + + private final int stageId; + private final int shuffleId; + private final CelebornBatchBuilder[] celebornBatchBuilders; + private final StructType schema; + private final Serializer depSerializer; + private final boolean isColumnarShuffle; + private final int columnarShuffleBatchSize; + private final boolean columnarShuffleCodeGenEnabled; + private final boolean columnarShuffleDictionaryEnabled; + private final double columnarShuffleDictionaryMaxFactor; public ColumnarHashBasedShuffleWriter( 
CelebornShuffleHandle handle, @@ -61,17 +70,21 @@ public ColumnarHashBasedShuffleWriter( columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled(); columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled(); columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor(); - this.schema = CustomShuffleDependencyUtils.getSchema(handle.dependency()); + ShuffleDependency shuffleDependency = handle.dependency(); + this.stageId = taskContext.stageId(); + this.shuffleId = shuffleDependency.shuffleId(); + this.schema = CustomShuffleDependencyUtils.getSchema(shuffleDependency); this.depSerializer = handle.dependency().serializer(); this.celebornBatchBuilders = new CelebornBatchBuilder[handle.dependency().partitioner().numPartitions()]; - this.isColumnarShuffle = CelebornBatchBuilder.supportsColumnarType(schema); + this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema); } @Override protected void fastWrite0(scala.collection.Iterator iterator) throws IOException, InterruptedException { if (isColumnarShuffle) { + logger.info("Using fast columnar write for shuffle {} in stage {}.", shuffleId, stageId); fastColumnarWrite0(iterator); } else { super.fastWrite0(iterator); @@ -141,4 +154,9 @@ private void closeColumnarWrite() throws IOException { } } } + + @VisibleForTesting + public boolean isColumnarShuffle() { + return isColumnarShuffle; + } } diff --git a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/CustomShuffleDependencyUtils.java b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/CustomShuffleDependencyUtils.java index ace98601b45..b5a646b062d 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/CustomShuffleDependencyUtils.java +++ b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/CustomShuffleDependencyUtils.java @@ -17,15 +17,17 @@ package org.apache.spark.shuffle.celeborn; -import java.io.IOException; - import org.apache.spark.ShuffleDependency; import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.celeborn.reflect.DynFields; public class CustomShuffleDependencyUtils { + private static final Logger logger = LoggerFactory.getLogger(CustomShuffleDependencyUtils.class); + /** * Columnar Shuffle requires a field, `ShuffleDependency#schema`, which does not exist in vanilla * Spark.
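(The hunk below is the heart of the fallback: getSchema now resolves the optional field reflectively and degrades to a null result instead of throwing an IOException. As a rough standalone sketch of the same lookup-or-null pattern using plain java.lang.reflect — the helper name schemaOrNull is illustrative and not part of the patch; types are those imported in the file above:

    // Resolve the optional `schema` field; yield null on vanilla Spark so the
    // caller can fall back to row-based shuffle instead of failing the task.
    static StructType schemaOrNull(ShuffleDependency<?, ?, ?> dep) {
      try {
        java.lang.reflect.Field field = ShuffleDependency.class.getDeclaredField("schema");
        field.setAccessible(true);
        return (StructType) field.get(dep);
      } catch (ReflectiveOperationException e) {
        return null; // field absent or inaccessible
      }
    }

DynFields' defaultAlwaysNull() expresses the same contract declaratively.)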
@@ -33,10 +35,17 @@ public class CustomShuffleDependencyUtils { private static final DynFields.UnboundField SCHEMA_FIELD = DynFields.builder().hiddenImpl(ShuffleDependency.class, "schema").defaultAlwaysNull().build(); - public static StructType getSchema(ShuffleDependency dep) throws IOException { - StructType schema = SCHEMA_FIELD.bind(dep).get(); + public static StructType getSchema(ShuffleDependency dep) { + StructType schema = null; + try { + schema = SCHEMA_FIELD.bind(dep).get(); + } catch (Exception e) { + logger.error("Failed to get schema field from the shuffle dependency of shuffle {}.", dep.shuffleId(), e); + } if (schema == null) { - throw new IOException("Failed to get Schema, columnar shuffle won't work properly."); + logger.warn( + "Failed to get schema of shuffle {}, falling back to non-columnar shuffle.", + dep.shuffleId()); } return schema; } diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index e6f3cdccb2c..f47f9880ca8 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -46,14 +46,11 @@ class CelebornColumnarShuffleReader[K, C]( override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) - if (CelebornBatchBuilder.supportsColumnarType( - schema)) { - val dataSize = SparkUtils.getDataSize( - dep.serializer.asInstanceOf[UnsafeRowSerializer]) + if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) { + logInfo(s"Creating columnar batch serializer for columnar shuffle ${dep.shuffleId}.") + val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer]) new CelebornColumnarBatchSerializer( schema, - conf.columnarShuffleBatchSize, - conf.columnarShuffleDictionaryEnabled, conf.columnarShuffleOffHeapEnabled, dataSize).newInstance() } else { diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala index 7ae77fec030..bc93c10b5f0 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala @@ -28,7 +28,7 @@ abstract class CelebornBatchBuilder { def writeRow(row: InternalRow): Unit - def getRowCnt(): Int + def getRowCnt: Int def int2ByteArray(i: Int): Array[Byte] = { val result = new Array[Byte](4) @@ -46,7 +46,7 @@ object CelebornBatchBuilder { f.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType => true - case dt: DecimalType => true + case _: DecimalType => true case _ => false }) } diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala index 064bbefc6af..a75c8d32a89 100644 ---
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala @@ -61,13 +61,9 @@ abstract class CelebornBasicColumnAccessor[JvmType]( columnType.extract(buffer, row, ordinal) } - protected def underlyingBuffer = buffer + protected def underlyingBuffer: ByteBuffer = buffer } -class CelebornNullColumnAccessor(buffer: ByteBuffer) - extends CelebornBasicColumnAccessor[Any](buffer, CELEBORN_NULL) - with CelebornNullableColumnAccessor - abstract class CelebornNativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeCelebornColumnType[T]) @@ -112,7 +108,6 @@ private[sql] object CelebornColumnAccessor { val buf = buffer.order(ByteOrder.nativeOrder) dataType match { - case NullType => new CelebornNullColumnAccessor(buf) case BooleanType => new CelebornBooleanColumnAccessor(buf) case ByteType => new CelebornByteColumnAccessor(buf) case ShortType => new CelebornShortColumnAccessor(buf) @@ -135,7 +130,7 @@ private[sql] object CelebornColumnAccessor { columnAccessor match { case nativeAccessor: CelebornNativeColumnAccessor[_] => nativeAccessor.decompress(columnVector, numRows) - case d: CelebornDecimalColumnAccessor => + case _: CelebornDecimalColumnAccessor => (0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _)) case _ => throw new RuntimeException("Not support non-primitive type now") diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala index 0abfdd0cd4e..f65a5fd8653 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala @@ -88,10 +88,6 @@ class CelebornBasicColumnBuilder[JvmType]( } } -class CelebornNullColumnBuilder - extends CelebornBasicColumnBuilder[Any](new CelebornObjectColumnStats(NullType), CELEBORN_NULL) - with CelebornNullableColumnBuilder - abstract class CelebornComplexColumnBuilder[JvmType]( columnStats: CelebornColumnStats, columnType: CelebornColumnType[JvmType]) @@ -318,7 +314,6 @@ class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType) } object CelebornColumnBuilder { - val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = { if (orig.remaining >= size) { @@ -343,7 +338,6 @@ object CelebornColumnBuilder { encodingEnabled: Boolean, encoder: Encoder[_ <: AtomicType]): CelebornColumnBuilder = { val builder: CelebornColumnBuilder = dataType match { - case NullType => new CelebornNullColumnBuilder case ByteType => new CelebornByteColumnBuilder case BooleanType => new CelebornBooleanColumnBuilder case ShortType => new CelebornShortColumnBuilder @@ -367,7 +361,7 @@ object CelebornColumnBuilder { new CelebornCompactDecimalColumnBuilder(dt) case dt: DecimalType => new CelebornDecimalColumnBuilder(dt) case other => - throw new Exception(s"not support type: $other") + throw new Exception(s"Unsupported type: $other") } builder.initialize(rowCnt, columnName, encodingEnabled) diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala index 6c2aa0f7b83..b0b9f61db9a 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala @@ -63,7 +63,7 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS val value = row.getBoolean(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -87,7 +87,7 @@ final private[columnar] class CelebornByteColumnStats extends CelebornColumnStat val value = row.getByte(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -111,7 +111,7 @@ final private[columnar] class CelebornShortColumnStats extends CelebornColumnSta val value = row.getShort(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -135,7 +135,7 @@ final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats val value = row.getInt(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -159,7 +159,7 @@ final private[columnar] class CelebornLongColumnStats extends CelebornColumnStat val value = row.getLong(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -183,7 +183,7 @@ final private[columnar] class CelebornFloatColumnStats extends CelebornColumnSta val value = row.getFloat(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -207,7 +207,7 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt val value = row.getDouble(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -223,8 +223,8 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt } final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats { - protected var upper: UTF8String = null - protected var lower: UTF8String = null + protected var upper: UTF8String = _ + protected var lower: UTF8String = _ override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { @@ -232,7 +232,7 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt val size = CELEBORN_STRING.actualSize(row, ordinal) gatherValueStats(value, size) } else { - gatherNullStats + gatherNullStats() } } @@ -247,34 +247,19 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt Array[Any](lower, upper, nullCount, count, sizeInBytes) } -final private[columnar] class CelebornBinaryColumnStats extends CelebornColumnStats { - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (!row.isNullAt(ordinal)) { - val size = CELEBORN_BINARY.actualSize(row, ordinal) - sizeInBytes += size - count += 1 - } else { - gatherNullStats - } - } - - override def collectedStatistics: Array[Any] = - Array[Any](null, null, nullCount, count, sizeInBytes) -} - final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int) extends CelebornColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) - protected var upper: Decimal = null - protected var lower: Decimal = null + protected var upper: Decimal = _ + protected 
var lower: Decimal = _ override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDecimal(ordinal, precision, scale) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -294,21 +279,3 @@ final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: override def collectedStatistics: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) } - -final private[columnar] class CelebornObjectColumnStats(dataType: DataType) - extends CelebornColumnStats { - val columnType = CelebornColumnType(dataType) - - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (!row.isNullAt(ordinal)) { - val size = columnType.actualSize(row, ordinal) - sizeInBytes += size - count += 1 - } else { - gatherNullStats - } - } - - override def collectedStatistics: Array[Any] = - Array[Any](null, null, nullCount, count, sizeInBytes) -} diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala index d1d5461a431..69cf10a2ea4 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -177,26 +175,10 @@ sealed abstract private[columnar] class CelebornColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[columnar] object CELEBORN_NULL extends CelebornColumnType[Any] { - - override def dataType: DataType = NullType - override def defaultSize: Int = 0 - override def append(v: Any, buffer: ByteBuffer): Unit = {} - override def extract(buffer: ByteBuffer): Any = null - override def setField(row: InternalRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) - override def getField(row: InternalRow, ordinal: Int): Any = null -} - abstract private[columnar] class NativeCelebornColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) - extends CelebornColumnType[T#InternalType] { - - /** - * Scala TypeTag. Can be used to create primitive arrays and hash tables. 
- */ - def scalaTag: TypeTag[dataType.InternalType] = dataType.tag -} + extends CelebornColumnType[T#InternalType] {} private[columnar] object CELEBORN_INT extends NativeCelebornColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { @@ -428,26 +410,28 @@ private[columnar] trait DirectCopyCelebornColumnType[JvmType] extends CelebornCo // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { - if (row.isInstanceOf[MutableUnsafeRow]) { - val numBytes = buffer.getInt - val cursor = buffer.position() - buffer.position(cursor + numBytes) - row.asInstanceOf[MutableUnsafeRow].writer.write( - ordinal, - buffer.array(), - buffer.arrayOffset() + cursor, - numBytes) - } else { - setField(row, ordinal, extract(buffer)) + row match { + case r: MutableUnsafeRow => + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + r.writer.write( + ordinal, + buffer.array(), + buffer.arrayOffset() + cursor, + numBytes) + case _ => + setField(row, ordinal, extract(buffer)) } } // copy the bytes from UnsafeRow to ByteBuffer override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - if (row.isInstanceOf[UnsafeRow]) { - row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) - } else { - super.append(row, ordinal, buffer) + row match { + case r: UnsafeRow => + r.writeFieldTo(ordinal, buffer) + case _ => + super.append(row, ordinal, buffer) } } } @@ -472,10 +456,11 @@ private[columnar] object CELEBORN_STRING } override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = { - if (row.isInstanceOf[MutableUnsafeRow]) { - row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) - } else { - row.update(ordinal, value.clone()) + row match { + case r: MutableUnsafeRow => + r.writer.write(ordinal, value) + case _ => + row.update(ordinal, value.clone()) } } @@ -617,26 +602,6 @@ sealed abstract private[columnar] class ByteArrayCelebornColumnType[JvmType](val } } -private[columnar] object CELEBORN_BINARY extends ByteArrayCelebornColumnType[Array[Byte]](16) { - - def dataType: DataType = BinaryType - - override def setField(row: InternalRow, ordinal: Int, value: Array[Byte]): Unit = { - row.update(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - row.getBinary(ordinal) - } - - override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getBinary(ordinal).length + 4 - } - - def serialize(value: Array[Byte]): Array[Byte] = value - def deserialize(bytes: Array[Byte]): Array[Byte] = bytes -} - private[columnar] case class CELEBORN_LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayCelebornColumnType[Decimal](12) { @@ -673,7 +638,6 @@ private[columnar] object CELEBORN_LARGE_DECIMAL { private[columnar] object CelebornColumnType { def apply(dataType: DataType): CelebornColumnType[_] = { dataType match { - case NullType => CELEBORN_NULL case BooleanType => CELEBORN_BOOLEAN case ByteType => CELEBORN_BYTE case ShortType => CELEBORN_SHORT @@ -682,7 +646,6 @@ private[columnar] object CelebornColumnType { case FloatType => CELEBORN_FLOAT case DoubleType => CELEBORN_DOUBLE case StringType => CELEBORN_STRING - case BinaryType => CELEBORN_BINARY case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => CELEBORN_COMPACT_MINI_DECIMAL(dt) case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala index 159b15e327a..23a81370d45 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala @@ -30,7 +30,8 @@ class CelebornColumnarBatchBuilder( encodingEnabled: Boolean = false) extends CelebornBatchBuilder { var rowCnt = 0 - val typeConversion: PartialFunction[DataType, NativeCelebornColumnType[_ <: AtomicType]] = { + private val typeConversion + : PartialFunction[DataType, NativeCelebornColumnType[_ <: AtomicType]] = { case IntegerType => CELEBORN_INT case LongType => CELEBORN_LONG case StringType => CELEBORN_STRING @@ -45,7 +46,7 @@ class CelebornColumnarBatchBuilder( case _ => null } - val encodersArr: Array[Encoder[_ <: AtomicType]] = schema.map { attribute => + private val encodersArr: Array[Encoder[_ <: AtomicType]] = schema.map { attribute => val nativeColumnType = typeConversion(attribute.dataType) if (nativeColumnType == null) { null @@ -63,7 +64,6 @@ class CelebornColumnarBatchBuilder( var columnBuilders: Array[CelebornColumnBuilder] = _ def newBuilders(): Unit = { - totalSize = 0 rowCnt = 0 var i = -1 columnBuilders = schema.map { attribute => @@ -100,8 +100,6 @@ class CelebornColumnarBatchBuilder( giantBuffer.toByteArray } - var totalSize = 0 - def writeRow(row: InternalRow): Unit = { var i = 0 while (i < row.numFields) { @@ -111,21 +109,5 @@ class CelebornColumnarBatchBuilder( rowCnt += 1 } - def getTotalSize(): Int = { - var i = 0 - var tempTotalSize = 0 - while (i < schema.length) { - columnBuilders(i) match { - case builder: CelebornCompressibleColumnBuilder[_] => - tempTotalSize += builder.getTotalSize.toInt - case builder: CelebornNullableColumnBuilder => tempTotalSize += builder.getTotalSize.toInt - case _ => - } - i += 1 - } - totalSize = tempTotalSize + 4 + 4 * schema.length - totalSize - } - - def getRowCnt(): Int = rowCnt + def getRowCnt: Int = rowCnt } diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala index 1c15d163a2b..e510e645259 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala @@ -102,21 +102,6 @@ class CelebornColumnarBatchCodeGenBuild { val writeRowCode = new mutable.StringBuilder() for (index <- schema.indices) { schema.fields(index).dataType match { - case NullType => - initCode.append( - s""" - | ${classOf[CelebornNullColumnBuilder].getName} b$index; - """.stripMargin) - buildCode.append( - s""" - | b$index = new ${classOf[CelebornNullColumnBuilder].getName}(); - | builder.initialize($batchSize, "${schema.fields(index).name}", false); - """.stripMargin) - writeCode.append(genWriteCode(index)) - writeRowCode.append( - s""" - | b$index.appendFrom(row, $index); - """.stripMargin) case ByteType => initCode.append( s""" diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala index c4be15c0e13..f9c08a0f6a1 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala @@ -34,24 +34,18 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} class CelebornColumnarBatchSerializer( schema: StructType, - columnBatchSize: Int, - encodingEnabled: Boolean, offHeapColumnVectorEnabled: Boolean, dataSize: SQLMetric = null) extends Serializer with Serializable { override def newInstance(): SerializerInstance = new CelebornColumnarBatchSerializerInstance( schema, - columnBatchSize, - encodingEnabled, offHeapColumnVectorEnabled, dataSize) override def supportsRelocationOfSerializedObjects: Boolean = true } -private class CelebornColumnarBatchSerializerInstance( +class CelebornColumnarBatchSerializerInstance( schema: StructType, - columnBatchSize: Int, - encodingEnabled: Boolean, offHeapColumnVectorEnabled: Boolean, dataSize: SQLMetric) extends SerializerInstance { @@ -93,7 +87,8 @@ private class CelebornColumnarBatchSerializerInstance( } } - val toUnsafe: UnsafeProjection = UnsafeProjection.create(schema.fields.map(f => f.dataType)) + private val toUnsafe: UnsafeProjection = + UnsafeProjection.create(schema.fields.map(f => f.dataType)) override def deserializeStream(in: InputStream): DeserializationStream = { val numFields = schema.fields.length @@ -160,7 +155,7 @@ private class CelebornColumnarBatchSerializerInstance( try { dIn.readInt() } catch { - case e: EOFException => + case _: EOFException => dIn.close() EOF } diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala index 2d87856c199..6b7d5b50564 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala @@ -29,7 +29,7 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType] this: CelebornNativeColumnBuilder[T] with WithCelebornCompressionSchemes => - var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType) + private var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType) def init(encoder: Encoder[T]): Unit = { compressionEncoder = encoder diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala index a6ba31176c4..1e7ebae0e28 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.sql.execution.columnar -import java.nio.{ByteBuffer, ByteOrder} +import java.nio.ByteBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.WritableColumnVector @@ -76,11 +76,4 @@ object CelebornCompressionScheme { typeId, throw new UnsupportedOperationException(s"Unrecognized compression scheme type ID: $typeId")) } - - def columnHeaderSize(columnBuffer: ByteBuffer): Int = { - val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder) - val nullCount = header.getInt() - // null count + null positions - 4 + 4 * nullCount - } } diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala index 316e213c8ca..c2dfb53c211 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala @@ -33,7 +33,7 @@ case object CelebornPassThrough extends CelebornCompressionScheme { override def supports(columnType: CelebornColumnType[_]): Boolean = true override def encoder[T <: AtomicType](columnType: NativeCelebornColumnType[T]): Encoder[T] = { - new this.CelebornEncoder[T](columnType) + new this.CelebornEncoder[T]() } override def decoder[T <: AtomicType]( @@ -42,7 +42,7 @@ case object CelebornPassThrough extends CelebornCompressionScheme { new this.CelebornDecoder(buffer, columnType) } - class CelebornEncoder[T <: AtomicType](columnType: NativeCelebornColumnType[T]) + class CelebornEncoder[T <: AtomicType]() extends Encoder[T] { override def uncompressedSize: Int = 0 @@ -247,7 +247,7 @@ case object CelebornDictionaryEncoding extends CelebornCompressionScheme { override val typeId = 1 // 32K unique values allowed - var MAX_DICT_SIZE = Short.MaxValue + var MAX_DICT_SIZE: Short = Short.MaxValue override def decoder[T <: AtomicType]( buffer: ByteBuffer, @@ -277,7 +277,7 @@ case object CelebornDictionaryEncoding extends CelebornCompressionScheme { // Total number of elements. 
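// (reset to zero by cleanBatch() between batches)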
private var count = 0 - def cleanBatch: Unit = { + def cleanBatch(): Unit = { count = 0 _uncompressedSize = 0 } @@ -341,11 +341,11 @@ case object CelebornDictionaryEncoding extends CelebornCompressionScheme { buffer: ByteBuffer, columnType: NativeCelebornColumnType[T]) extends Decoder[T] { - val elementNum = ByteBufferHelper.getInt(buffer) + private val elementNum: Int = ByteBufferHelper.getInt(buffer) private val dictionary: Array[Any] = new Array[Any](elementNum) - private var intDictionary: Array[Int] = null - private var longDictionary: Array[Long] = null - private var stringDictionary: Array[String] = null + private var intDictionary: Array[Int] = _ + private var longDictionary: Array[Long] = _ + private var stringDictionary: Array[String] = _ columnType.dataType match { case _: IntegerType => diff --git a/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java b/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java index 2cb6d3548ef..e481b6181cc 100644 --- a/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java @@ -17,7 +17,15 @@ package org.apache.spark.shuffle.celeborn; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.util.UUID; + +import org.apache.spark.HashPartitioner; import org.apache.spark.TaskContext; +import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; @@ -28,28 +36,59 @@ import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.sql.types.StructType; +import org.junit.Test; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.apache.celeborn.client.DummyShuffleClient; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; public class ColumnarHashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase { - private StructType schema = + private final StructType schema = new StructType().add("key", IntegerType$.MODULE$).add("value", StringType$.MODULE$); + @Test + public void createColumnarShuffleWriter() throws Exception { + Mockito.doReturn(new HashPartitioner(numPartitions)).when(dependency).partitioner(); + final CelebornConf conf = new CelebornConf(); + final File tempFile = new File(tempDir, UUID.randomUUID().toString()); + final DummyShuffleClient client = new DummyShuffleClient(conf, tempFile); + client.initReducePartitionMap(shuffleId, numPartitions, 1); +
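(The two cases below exercise the writer-side gate this patch adds. Condensed, the decision mirrors the isColumnarShuffle assignment in ColumnarHashBasedShuffleWriter — the helper name useColumnarShuffle is illustrative only:

    // Columnar batching applies only when a schema could be obtained from the
    // dependency and every field type is supported by the batch builder.
    static boolean useColumnarShuffle(StructType schema) {
      return schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
    }

With a KryoSerializer dependency exposing no schema, the gate is false and rows flow through the parent HashBasedShuffleWriter path.)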
+ // Create ColumnarHashBasedShuffleWriter with a handle whose dependency has a null schema. + Mockito.doReturn(new KryoSerializer(sparkConf)).when(dependency).serializer(); + ShuffleWriter writer = + createShuffleWriterWithoutSchema( + new CelebornShuffleHandle<>( + "appId", "host", 0, this.userIdentifier, 0, 10, this.dependency), + taskContext, + conf, + client, + metrics.shuffleWriteMetrics()); + assertTrue(writer instanceof ColumnarHashBasedShuffleWriter); + assertFalse(((ColumnarHashBasedShuffleWriter) writer).isColumnarShuffle()); + + // Create ColumnarHashBasedShuffleWriter with a handle whose dependency has a non-null schema. + Mockito.doReturn(new UnsafeRowSerializer(2, null)).when(dependency).serializer(); + writer = + createShuffleWriter( + new CelebornShuffleHandle<>( + "appId", "host", 0, this.userIdentifier, 0, 10, this.dependency), + taskContext, + conf, + client, + metrics.shuffleWriteMetrics()); + assertTrue(((ColumnarHashBasedShuffleWriter) writer).isColumnarShuffle()); + } + @Override protected SerializerInstance newSerializerInstance(Serializer serializer) { if (serializer instanceof UnsafeRowSerializer && CelebornBatchBuilder.supportsColumnarType(schema)) { CelebornConf conf = new CelebornConf(); - return new CelebornColumnarBatchSerializer( - schema, - conf.columnarShuffleBatchSize(), - conf.columnarShuffleDictionaryEnabled(), - conf.columnarShuffleOffHeapEnabled(), - null) + return new CelebornColumnarBatchSerializer(schema, conf.columnarShuffleOffHeapEnabled(), null) .newInstance(); } else { return serializer.newInstance(); @@ -72,4 +111,14 @@ protected ShuffleWriter createShuffleWriter( handle, context, conf, client, metrics, SendBufferPool.get(1, 30, 60)); } } + + private ShuffleWriter createShuffleWriterWithoutSchema( + CelebornShuffleHandle handle, + TaskContext context, + CelebornConf conf, + ShuffleClient client, + ShuffleWriteMetricsReporter metrics) { + return SparkUtils.createColumnarHashBasedShuffleWriter( + handle, context, conf, client, metrics, SendBufferPool.get(1, 30, 60)); + } } diff --git a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index ec57c192b16..5a14d021906 100644 --- a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.shuffle.celeborn +import org.apache.spark.{ShuffleDependency, SparkConf} +import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} +import org.apache.spark.sql.execution.UnsafeRowSerializer +import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.junit.Test import org.mockito.{MockedStatic, Mockito} @@ -37,9 +42,9 @@ class CelebornColumnarShuffleReaderSuite { 10, null) - var shuffleClientClass: MockedStatic[ShuffleClient] = null + var shuffleClient: MockedStatic[ShuffleClient] = null try { - shuffleClientClass = Mockito.mockStatic(classOf[ShuffleClient]) + shuffleClient = Mockito.mockStatic(classOf[ShuffleClient]) val shuffleReader = SparkUtils.createColumnarShuffleReader( handle, 0, @@ -51,8 +56,58 @@ class CelebornColumnarShuffleReaderSuite { 10, null) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { - if (shuffleClientClass != null) { - shuffleClientClass.close() + if (shuffleClient != null) { + shuffleClient.close() + } + } + }
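(Static mocks registered with Mockito.mockStatic stay active on the current thread until closed, which is why the Scala test below releases both of them in its finally block. For comparison, a Java-side test would usually scope the registration with try-with-resources — an illustration, not part of the patch:

    try (MockedStatic<ShuffleClient> mocked = Mockito.mockStatic(ShuffleClient.class)) {
      // static stubs on ShuffleClient are active only inside this block
    }

MockedStatic implements AutoCloseable, so leaving one open leaks the stubbing into later tests on the same thread.)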
+ + @Test + def columnarShuffleReaderNewSerializerInstance(): Unit = { + var shuffleClient: MockedStatic[ShuffleClient] = null + var dependencyUtils: MockedStatic[CustomShuffleDependencyUtils] = null + try { + shuffleClient = Mockito.mockStatic(classOf[ShuffleClient]) + val shuffleReader = SparkUtils.createColumnarShuffleReader( + new CelebornShuffleHandle[Int, String, String]( + "appId", + "host", + 0, + new UserIdentifier("mock", "mock"), + 0, + 10, + null), + 0, + 10, + 0, + 10, + null, + new CelebornConf(), + null) + val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) + Mockito.when(shuffleDependency.shuffleId).thenReturn(0) + Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( + new SparkConf(false))) + + // CelebornColumnarShuffleReader creates a new serializer instance from a dependency whose schema is null. + var serializerInstance = shuffleReader.newSerializerInstance(shuffleDependency) + assert(serializerInstance.getClass == classOf[KryoSerializerInstance]) + + // CelebornColumnarShuffleReader creates a new serializer instance from a dependency whose schema is non-null. + dependencyUtils = Mockito.mockStatic(classOf[CustomShuffleDependencyUtils]) + dependencyUtils.when(() => + CustomShuffleDependencyUtils.getSchema(shuffleDependency)).thenReturn(new StructType().add( + "key", + IntegerType).add("value", StringType)) + Mockito.when(shuffleDependency.serializer).thenReturn(new UnsafeRowSerializer(2, null)) + serializerInstance = shuffleReader.newSerializerInstance(shuffleDependency) + assert(serializerInstance.getClass == classOf[CelebornColumnarBatchSerializerInstance]) + } finally { + if (shuffleClient != null) { + shuffleClient.close() + } + if (dependencyUtils != null) { + dependencyUtils.close() + } } } } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 5ec0fed9b79..d83df1a5b76 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -192,7 +192,7 @@ class CelebornShuffleReader[K, C]( } } - protected def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { + def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { dep.serializer.newInstance() } diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java index d8a3d4986e3..f4109753ea4 100644 --- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java +++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java @@ -94,23 +94,23 @@ public abstract class CelebornShuffleWriterSuiteBase { private final String appId = "appId"; private final String host = "host"; private final int port = 0; - private final int shuffleId = 0; + protected final int shuffleId = 0; - private final UserIdentifier userIdentifier = new UserIdentifier("mock", "mock"); + protected final UserIdentifier userIdentifier = new UserIdentifier("mock", "mock"); private final int numMaps = 10; - private final int numPartitions = 10; - private final SparkConf sparkConf
= new SparkConf(false); + protected final int numPartitions = 10; + protected final SparkConf sparkConf = new SparkConf(false); private final BlockManagerId bmId = BlockManagerId.apply("execId", "host", 1, None$.empty()); private final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(UnifiedMemoryManager.apply(sparkConf, 1), 0); @Mock(answer = Answers.RETURNS_SMART_NULLS) - private TaskContext taskContext = null; + protected TaskContext taskContext = null; @Mock(answer = Answers.RETURNS_SMART_NULLS) - private ShuffleDependency dependency = null; + protected ShuffleDependency dependency = null; @Mock(answer = Answers.RETURNS_SMART_NULLS) private SparkEnv env = null; @@ -118,9 +118,9 @@ public abstract class CelebornShuffleWriterSuiteBase { @Mock(answer = Answers.RETURNS_SMART_NULLS) private BlockManager blockManager = null; - private TaskMetrics metrics = null; + protected TaskMetrics metrics = null; - private static File tempDir = null; + protected static File tempDir = null; @BeforeClass public static void beforeAll() {