diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index dee8d897a65..4ddb8e98dc1 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -70,6 +70,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private final ShuffleDependency dep; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; + private final int stageId; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -132,6 +133,7 @@ public HashBasedShuffleWriter( throws IOException { this.mapId = taskContext.partitionId(); this.dep = handle.dependency(); + this.stageId = taskContext.stageId(); this.shuffleId = dep.shuffleId(); SerializerInstance serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); @@ -185,7 +187,7 @@ public HashBasedShuffleWriter( columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor(); this.schema = SparkUtils.getSchema(dep); this.celebornBatchBuilders = new CelebornBatchBuilder[numPartitions]; - this.isColumnarShuffle = CelebornBatchBuilder.supportsColumnarType(schema); + this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema); } } @@ -194,6 +196,8 @@ public void write(scala.collection.Iterator> records) throws IOEx try { if (canUseFastWrite()) { if (isColumnarShuffle) { + logger.info( + "Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId); fastColumnarWrite0(records); } else { fastWrite0(records); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 8f8600b9a92..7f4a541cbd5 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.celeborn; -import java.io.IOException; import java.util.concurrent.atomic.LongAdder; import scala.Tuple2; @@ -162,10 +161,17 @@ public static ShuffleReader getReader( private static final DynFields.UnboundField SCHEMA_FIELD = DynFields.builder().hiddenImpl(ShuffleDependency.class, "schema").defaultAlwaysNull().build(); - public static StructType getSchema(ShuffleDependency dep) throws IOException { - StructType schema = SCHEMA_FIELD.bind(dep).get(); + public static StructType getSchema(ShuffleDependency dep) { + StructType schema = null; + try { + schema = SCHEMA_FIELD.bind(dep).get(); + } catch (Exception e) { + LOG.error("Failed to bind shuffle dependency of shuffle {}.", dep.shuffleId(), e); + } if (schema == null) { - throw new IOException("Failed to get Schema, columnar shuffle won't work properly."); + LOG.warn( + "Failed to get Schema of shuffle {}, columnar shuffle won't work properly.", + dep.shuffleId()); } return schema; } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index ecbf51f80f4..e27cfbfc19e 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -62,14 +62,11 @@ class CelebornShuffleReader[K, C]( var serializerInstance = dep.serializer.newInstance() if (conf.columnarShuffleEnabled) { val schema = SparkUtils.getSchema(dep) - if (CelebornBatchBuilder.supportsColumnarType( - schema)) { - val dataSize = SparkUtils.getDataSize( - dep.serializer.asInstanceOf[UnsafeRowSerializer]) + if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) { + logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.") + val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer]) serializerInstance = new CelebornColumnarBatchSerializer( schema, - conf.columnarShuffleBatchSize, - conf.columnarShuffleDictionaryEnabled, conf.columnarShuffleOffHeapEnabled, dataSize).newInstance() } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala index 7ae77fec030..bc93c10b5f0 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala @@ -28,7 +28,7 @@ abstract class CelebornBatchBuilder { def writeRow(row: InternalRow): Unit - def getRowCnt(): Int + def getRowCnt: Int def int2ByteArray(i: Int): Array[Byte] = { val result = new Array[Byte](4) @@ -46,7 +46,7 @@ object CelebornBatchBuilder { f.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType => true - case dt: DecimalType => true + case _: DecimalType => true case _ => false }) } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala index 064bbefc6af..a75c8d32a89 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala @@ -61,13 +61,9 @@ abstract class CelebornBasicColumnAccessor[JvmType]( columnType.extract(buffer, row, ordinal) } - protected def underlyingBuffer = buffer + protected def underlyingBuffer: ByteBuffer = buffer } -class CelebornNullColumnAccessor(buffer: ByteBuffer) - extends CelebornBasicColumnAccessor[Any](buffer, CELEBORN_NULL) - with CelebornNullableColumnAccessor - abstract class CelebornNativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeCelebornColumnType[T]) @@ -112,7 +108,6 @@ private[sql] object CelebornColumnAccessor { val buf = buffer.order(ByteOrder.nativeOrder) dataType match { - case NullType => new CelebornNullColumnAccessor(buf) case BooleanType => new CelebornBooleanColumnAccessor(buf) case ByteType => new CelebornByteColumnAccessor(buf) case ShortType => new CelebornShortColumnAccessor(buf) @@ -135,7 +130,7 @@ private[sql] object CelebornColumnAccessor { columnAccessor match { case nativeAccessor: CelebornNativeColumnAccessor[_] => nativeAccessor.decompress(columnVector, numRows) - case d: CelebornDecimalColumnAccessor => + case _: CelebornDecimalColumnAccessor => (0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _)) case _ => throw new RuntimeException("Not support non-primitive type now") diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala index 0abfdd0cd4e..f65a5fd8653 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala @@ -88,10 +88,6 @@ class CelebornBasicColumnBuilder[JvmType]( } } -class CelebornNullColumnBuilder - extends CelebornBasicColumnBuilder[Any](new CelebornObjectColumnStats(NullType), CELEBORN_NULL) - with CelebornNullableColumnBuilder - abstract class CelebornComplexColumnBuilder[JvmType]( columnStats: CelebornColumnStats, columnType: CelebornColumnType[JvmType]) @@ -318,7 +314,6 @@ class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType) } object CelebornColumnBuilder { - val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = { if (orig.remaining >= size) { @@ -343,7 +338,6 @@ object CelebornColumnBuilder { encodingEnabled: Boolean, encoder: Encoder[_ <: AtomicType]): CelebornColumnBuilder = { val builder: CelebornColumnBuilder = dataType match { - case NullType => new CelebornNullColumnBuilder case ByteType => new CelebornByteColumnBuilder case BooleanType => new CelebornBooleanColumnBuilder case ShortType => new CelebornShortColumnBuilder @@ -367,7 +361,7 @@ object CelebornColumnBuilder { new CelebornCompactDecimalColumnBuilder(dt) case dt: DecimalType => new CelebornDecimalColumnBuilder(dt) case other => - throw new Exception(s"not support type: $other") + throw new Exception(s"Unsupported type: $other") } builder.initialize(rowCnt, columnName, encodingEnabled) diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala index 6c2aa0f7b83..80e883bf0f7 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala @@ -63,7 +63,7 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS val value = row.getBoolean(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -79,15 +79,15 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS } final private[columnar] class CelebornByteColumnStats extends CelebornColumnStats { - protected var upper = Byte.MinValue - protected var lower = Byte.MaxValue + protected var upper: Byte = Byte.MinValue + protected var lower: Byte = Byte.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -103,15 +103,15 @@ final private[columnar] class CelebornByteColumnStats extends CelebornColumnStat } final private[columnar] class CelebornShortColumnStats extends CelebornColumnStats { - protected var upper = Short.MinValue - protected var lower = Short.MaxValue + protected var upper: Short = Short.MinValue + protected var lower: Short = Short.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -127,15 +127,15 @@ final private[columnar] class CelebornShortColumnStats extends CelebornColumnSta } final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats { - protected var upper = Int.MinValue - protected var lower = Int.MaxValue + protected var upper: Int = Int.MinValue + protected var lower: Int = Int.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -151,15 +151,15 @@ final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats } final private[columnar] class CelebornLongColumnStats extends CelebornColumnStats { - protected var upper = Long.MinValue - protected var lower = Long.MaxValue + protected var upper: Long = Long.MinValue + protected var lower: Long = Long.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -175,15 +175,15 @@ final private[columnar] class CelebornLongColumnStats extends CelebornColumnStat } final private[columnar] class CelebornFloatColumnStats extends CelebornColumnStats { - protected var upper = Float.MinValue - protected var lower = Float.MaxValue + protected var upper: Float = Float.MinValue + protected var lower: Float = Float.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -199,15 +199,15 @@ final private[columnar] class CelebornFloatColumnStats extends CelebornColumnSta } final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnStats { - protected var upper = Double.MinValue - protected var lower = Double.MaxValue + protected var upper: Double = Double.MinValue + protected var lower: Double = Double.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -223,8 +223,8 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt } final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats { - protected var upper: UTF8String = null - protected var lower: UTF8String = null + protected var upper: UTF8String = _ + protected var lower: UTF8String = _ override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { @@ -232,7 +232,7 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt val size = CELEBORN_STRING.actualSize(row, ordinal) gatherValueStats(value, size) } else { - gatherNullStats + gatherNullStats() } } @@ -247,34 +247,19 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt Array[Any](lower, upper, nullCount, count, sizeInBytes) } -final private[columnar] class CelebornBinaryColumnStats extends CelebornColumnStats { - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (!row.isNullAt(ordinal)) { - val size = CELEBORN_BINARY.actualSize(row, ordinal) - sizeInBytes += size - count += 1 - } else { - gatherNullStats - } - } - - override def collectedStatistics: Array[Any] = - Array[Any](null, null, nullCount, count, sizeInBytes) -} - final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int) extends CelebornColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) - protected var upper: Decimal = null - protected var lower: Decimal = null + protected var upper: Decimal = _ + protected var lower: Decimal = _ override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDecimal(ordinal, precision, scale) gatherValueStats(value) } else { - gatherNullStats + gatherNullStats() } } @@ -294,21 +279,3 @@ final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: override def collectedStatistics: Array[Any] = Array[Any](lower, upper, nullCount, count, sizeInBytes) } - -final private[columnar] class CelebornObjectColumnStats(dataType: DataType) - extends CelebornColumnStats { - val columnType = CelebornColumnType(dataType) - - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (!row.isNullAt(ordinal)) { - val size = columnType.actualSize(row, ordinal) - sizeInBytes += size - count += 1 - } else { - gatherNullStats - } - } - - override def collectedStatistics: Array[Any] = - Array[Any](null, null, nullCount, count, sizeInBytes) -} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala index d1d5461a431..69cf10a2ea4 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -177,26 +175,10 @@ sealed abstract private[columnar] class CelebornColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[columnar] object CELEBORN_NULL extends CelebornColumnType[Any] { - - override def dataType: DataType = NullType - override def defaultSize: Int = 0 - override def append(v: Any, buffer: ByteBuffer): Unit = {} - override def extract(buffer: ByteBuffer): Any = null - override def setField(row: InternalRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) - override def getField(row: InternalRow, ordinal: Int): Any = null -} - abstract private[columnar] class NativeCelebornColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) - extends CelebornColumnType[T#InternalType] { - - /** - * Scala TypeTag. Can be used to create primitive arrays and hash tables. - */ - def scalaTag: TypeTag[dataType.InternalType] = dataType.tag -} + extends CelebornColumnType[T#InternalType] {} private[columnar] object CELEBORN_INT extends NativeCelebornColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { @@ -428,26 +410,28 @@ private[columnar] trait DirectCopyCelebornColumnType[JvmType] extends CelebornCo // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { - if (row.isInstanceOf[MutableUnsafeRow]) { - val numBytes = buffer.getInt - val cursor = buffer.position() - buffer.position(cursor + numBytes) - row.asInstanceOf[MutableUnsafeRow].writer.write( - ordinal, - buffer.array(), - buffer.arrayOffset() + cursor, - numBytes) - } else { - setField(row, ordinal, extract(buffer)) + row match { + case r: MutableUnsafeRow => + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + r.writer.write( + ordinal, + buffer.array(), + buffer.arrayOffset() + cursor, + numBytes) + case _ => + setField(row, ordinal, extract(buffer)) } } // copy the bytes from UnsafeRow to ByteBuffer override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { - if (row.isInstanceOf[UnsafeRow]) { - row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) - } else { - super.append(row, ordinal, buffer) + row match { + case r: UnsafeRow => + r.writeFieldTo(ordinal, buffer) + case _ => + super.append(row, ordinal, buffer) } } } @@ -472,10 +456,11 @@ private[columnar] object CELEBORN_STRING } override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = { - if (row.isInstanceOf[MutableUnsafeRow]) { - row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) - } else { - row.update(ordinal, value.clone()) + row match { + case r: MutableUnsafeRow => + r.writer.write(ordinal, value) + case _ => + row.update(ordinal, value.clone()) } } @@ -617,26 +602,6 @@ sealed abstract private[columnar] class ByteArrayCelebornColumnType[JvmType](val } } -private[columnar] object CELEBORN_BINARY extends ByteArrayCelebornColumnType[Array[Byte]](16) { - - def dataType: DataType = BinaryType - - override def setField(row: InternalRow, ordinal: Int, value: Array[Byte]): Unit = { - row.update(ordinal, value) - } - - override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - row.getBinary(ordinal) - } - - override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getBinary(ordinal).length + 4 - } - - def serialize(value: Array[Byte]): Array[Byte] = value - def deserialize(bytes: Array[Byte]): Array[Byte] = bytes -} - private[columnar] case class CELEBORN_LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayCelebornColumnType[Decimal](12) { @@ -673,7 +638,6 @@ private[columnar] object CELEBORN_LARGE_DECIMAL { private[columnar] object CelebornColumnType { def apply(dataType: DataType): CelebornColumnType[_] = { dataType match { - case NullType => CELEBORN_NULL case BooleanType => CELEBORN_BOOLEAN case ByteType => CELEBORN_BYTE case ShortType => CELEBORN_SHORT @@ -682,7 +646,6 @@ private[columnar] object CelebornColumnType { case FloatType => CELEBORN_FLOAT case DoubleType => CELEBORN_DOUBLE case StringType => CELEBORN_STRING - case BinaryType => CELEBORN_BINARY case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => CELEBORN_COMPACT_MINI_DECIMAL(dt) case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala index 159b15e327a..ab6f600721b 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala @@ -30,7 +30,8 @@ class CelebornColumnarBatchBuilder( encodingEnabled: Boolean = false) extends CelebornBatchBuilder { var rowCnt = 0 - val typeConversion: PartialFunction[DataType, NativeCelebornColumnType[_ <: AtomicType]] = { + private val typeConversion + : PartialFunction[DataType, NativeCelebornColumnType[_ <: AtomicType]] = { case IntegerType => CELEBORN_INT case LongType => CELEBORN_LONG case StringType => CELEBORN_STRING @@ -45,7 +46,7 @@ class CelebornColumnarBatchBuilder( case _ => null } - val encodersArr: Array[Encoder[_ <: AtomicType]] = schema.map { attribute => + private val encodersArr: Array[Encoder[_ <: AtomicType]] = schema.map { attribute => val nativeColumnType = typeConversion(attribute.dataType) if (nativeColumnType == null) { null @@ -63,14 +64,13 @@ class CelebornColumnarBatchBuilder( var columnBuilders: Array[CelebornColumnBuilder] = _ def newBuilders(): Unit = { - totalSize = 0 rowCnt = 0 var i = -1 columnBuilders = schema.map { attribute => i += 1 encodersArr(i) match { case encoder: CelebornDictionaryEncoding.CelebornEncoder[_] if !encoder.overflow => - encoder.cleanBatch + encoder.cleanBatch() case _ => } CelebornColumnBuilder( @@ -100,8 +100,6 @@ class CelebornColumnarBatchBuilder( giantBuffer.toByteArray } - var totalSize = 0 - def writeRow(row: InternalRow): Unit = { var i = 0 while (i < row.numFields) { @@ -111,21 +109,5 @@ class CelebornColumnarBatchBuilder( rowCnt += 1 } - def getTotalSize(): Int = { - var i = 0 - var tempTotalSize = 0 - while (i < schema.length) { - columnBuilders(i) match { - case builder: CelebornCompressibleColumnBuilder[_] => - tempTotalSize += builder.getTotalSize.toInt - case builder: CelebornNullableColumnBuilder => tempTotalSize += builder.getTotalSize.toInt - case _ => - } - i += 1 - } - totalSize = tempTotalSize + 4 + 4 * schema.length - totalSize - } - - def getRowCnt(): Int = rowCnt + def getRowCnt: Int = rowCnt } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala index 1c15d163a2b..e510e645259 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala @@ -102,21 +102,6 @@ class CelebornColumnarBatchCodeGenBuild { val writeRowCode = new mutable.StringBuilder() for (index <- schema.indices) { schema.fields(index).dataType match { - case NullType => - initCode.append( - s""" - | ${classOf[CelebornNullColumnBuilder].getName} b$index; - """.stripMargin) - buildCode.append( - s""" - | b$index = new ${classOf[CelebornNullColumnBuilder].getName}(); - | builder.initialize($batchSize, "${schema.fields(index).name}", false); - """.stripMargin) - writeCode.append(genWriteCode(index)) - writeRowCode.append( - s""" - | b$index.appendFrom(row, $index); - """.stripMargin) case ByteType => initCode.append( s""" diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala index c4be15c0e13..3018c0edf5b 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala @@ -34,15 +34,11 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} class CelebornColumnarBatchSerializer( schema: StructType, - columnBatchSize: Int, - encodingEnabled: Boolean, offHeapColumnVectorEnabled: Boolean, dataSize: SQLMetric = null) extends Serializer with Serializable { override def newInstance(): SerializerInstance = new CelebornColumnarBatchSerializerInstance( schema, - columnBatchSize, - encodingEnabled, offHeapColumnVectorEnabled, dataSize) override def supportsRelocationOfSerializedObjects: Boolean = true @@ -50,8 +46,6 @@ class CelebornColumnarBatchSerializer( private class CelebornColumnarBatchSerializerInstance( schema: StructType, - columnBatchSize: Int, - encodingEnabled: Boolean, offHeapColumnVectorEnabled: Boolean, dataSize: SQLMetric) extends SerializerInstance { @@ -93,7 +87,8 @@ private class CelebornColumnarBatchSerializerInstance( } } - val toUnsafe: UnsafeProjection = UnsafeProjection.create(schema.fields.map(f => f.dataType)) + private val toUnsafe: UnsafeProjection = + UnsafeProjection.create(schema.fields.map(f => f.dataType)) override def deserializeStream(in: InputStream): DeserializationStream = { val numFields = schema.fields.length @@ -160,7 +155,7 @@ private class CelebornColumnarBatchSerializerInstance( try { dIn.readInt() } catch { - case e: EOFException => + case _: EOFException => dIn.close() EOF } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala index 2d87856c199..a0cc2be2a1f 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala @@ -29,7 +29,7 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType] this: CelebornNativeColumnBuilder[T] with WithCelebornCompressionSchemes => - var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType) + private var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType) def init(encoder: Encoder[T]): Unit = { compressionEncoder = encoder @@ -46,7 +46,7 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType] // the row to become unaligned, thus causing crashes. Until a way of fixing the compression // is found to also allow aligned accesses this must be disabled for SPARK. - protected def isWorthCompressing(encoder: Encoder[T]) = { + protected def isWorthCompressing(encoder: Encoder[T]): Boolean = { CelebornCompressibleColumnBuilder.unaligned && encoder.compressionRatio < 0.8 } @@ -103,5 +103,5 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType] } object CelebornCompressibleColumnBuilder { - val unaligned = Platform.unaligned() + val unaligned: Boolean = Platform.unaligned() } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala index a6ba31176c4..1e7ebae0e28 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.columnar -import java.nio.{ByteBuffer, ByteOrder} +import java.nio.ByteBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.WritableColumnVector @@ -76,11 +76,4 @@ object CelebornCompressionScheme { typeId, throw new UnsupportedOperationException(s"Unrecognized compression scheme type ID: $typeId")) } - - def columnHeaderSize(columnBuffer: ByteBuffer): Int = { - val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder) - val nullCount = header.getInt() - // null count + null positions - 4 + 4 * nullCount - } } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala index 316e213c8ca..c2dfb53c211 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala @@ -33,7 +33,7 @@ case object CelebornPassThrough extends CelebornCompressionScheme { override def supports(columnType: CelebornColumnType[_]): Boolean = true override def encoder[T <: AtomicType](columnType: NativeCelebornColumnType[T]): Encoder[T] = { - new this.CelebornEncoder[T](columnType) + new this.CelebornEncoder[T]() } override def decoder[T <: AtomicType]( @@ -42,7 +42,7 @@ case object CelebornPassThrough extends CelebornCompressionScheme { new this.CelebornDecoder(buffer, columnType) } - class CelebornEncoder[T <: AtomicType](columnType: NativeCelebornColumnType[T]) + class CelebornEncoder[T <: AtomicType]() extends Encoder[T] { override def uncompressedSize: Int = 0 @@ -247,7 +247,7 @@ case object CelebornDictionaryEncoding extends CelebornCompressionScheme { override val typeId = 1 // 32K unique values allowed - var MAX_DICT_SIZE = Short.MaxValue + var MAX_DICT_SIZE: Short = Short.MaxValue override def decoder[T <: AtomicType]( buffer: ByteBuffer, @@ -277,7 +277,7 @@ case object CelebornDictionaryEncoding extends CelebornCompressionScheme { // Total number of elements. private var count = 0 - def cleanBatch: Unit = { + def cleanBatch(): Unit = { count = 0 _uncompressedSize = 0 } @@ -341,11 +341,11 @@ case object CelebornDictionaryEncoding extends CelebornCompressionScheme { buffer: ByteBuffer, columnType: NativeCelebornColumnType[T]) extends Decoder[T] { - val elementNum = ByteBufferHelper.getInt(buffer) + private val elementNum: Int = ByteBufferHelper.getInt(buffer) private val dictionary: Array[Any] = new Array[Any](elementNum) - private var intDictionary: Array[Int] = null - private var longDictionary: Array[Long] = null - private var stringDictionary: Array[String] = null + private var intDictionary: Array[Int] = _ + private var longDictionary: Array[Long] = _ + private var stringDictionary: Array[String] = _ columnType.dataType match { case _: IntegerType =>