Skip to content

Commit

Permalink
[CELEBORN-1123] Support fallback to non-columnar shuffle for schema that cannot be obtained from shuffle dependency (#2110)
Browse files Browse the repository at this point in the history
  • Loading branch information
SteNicholas authored Nov 27, 2023
1 parent 870fdc8 commit c67c6e4
Show file tree
Hide file tree
Showing 14 changed files with 91 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleDependency<K, V, C> dep;
private final Partitioner partitioner;
private final ShuffleWriteMetricsReporter writeMetrics;
private final int stageId;
private final int shuffleId;
private final int mapId;
private final TaskContext taskContext;
Expand Down Expand Up @@ -132,6 +133,7 @@ public HashBasedShuffleWriter(
throws IOException {
this.mapId = taskContext.partitionId();
this.dep = handle.dependency();
this.stageId = taskContext.stageId();
this.shuffleId = dep.shuffleId();
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
Expand Down Expand Up @@ -185,7 +187,7 @@ public HashBasedShuffleWriter(
columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor();
this.schema = SparkUtils.getSchema(dep);
this.celebornBatchBuilders = new CelebornBatchBuilder[numPartitions];
this.isColumnarShuffle = CelebornBatchBuilder.supportsColumnarType(schema);
this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
}
}

Expand All @@ -194,6 +196,8 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
try {
if (canUseFastWrite()) {
if (isColumnarShuffle) {
logger.info(
"Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId);
fastColumnarWrite0(records);
} else {
fastWrite0(records);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.concurrent.atomic.LongAdder;

import scala.Tuple2;
Expand Down Expand Up @@ -162,10 +161,17 @@ public static <K, C> ShuffleReader<K, C> getReader(
private static final DynFields.UnboundField<StructType> SCHEMA_FIELD =
DynFields.builder().hiddenImpl(ShuffleDependency.class, "schema").defaultAlwaysNull().build();

public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) throws IOException {
StructType schema = SCHEMA_FIELD.bind(dep).get();
public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) {
StructType schema = null;
try {
schema = SCHEMA_FIELD.bind(dep).get();
} catch (Exception e) {
LOG.error("Failed to bind shuffle dependency of shuffle {}.", dep.shuffleId(), e);
}
if (schema == null) {
throw new IOException("Failed to get Schema, columnar shuffle won't work properly.");
LOG.warn(
"Failed to get Schema of shuffle {}, columnar shuffle won't work properly.",
dep.shuffleId());
}
return schema;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,11 @@ class CelebornShuffleReader[K, C](
var serializerInstance = dep.serializer.newInstance()
if (conf.columnarShuffleEnabled) {
val schema = SparkUtils.getSchema(dep)
if (CelebornBatchBuilder.supportsColumnarType(
schema)) {
val dataSize = SparkUtils.getDataSize(
dep.serializer.asInstanceOf[UnsafeRowSerializer])
if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) {
logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.")
val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer])
serializerInstance = new CelebornColumnarBatchSerializer(
schema,
conf.columnarShuffleBatchSize,
conf.columnarShuffleDictionaryEnabled,
conf.columnarShuffleOffHeapEnabled,
dataSize).newInstance()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ abstract class CelebornBatchBuilder {

def writeRow(row: InternalRow): Unit

def getRowCnt(): Int
def getRowCnt: Int

def int2ByteArray(i: Int): Array[Byte] = {
val result = new Array[Byte](4)
Expand All @@ -46,7 +46,7 @@ object CelebornBatchBuilder {
f.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | StringType => true
case dt: DecimalType => true
case _: DecimalType => true
case _ => false
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,9 @@ abstract class CelebornBasicColumnAccessor[JvmType](
columnType.extract(buffer, row, ordinal)
}

protected def underlyingBuffer = buffer
protected def underlyingBuffer: ByteBuffer = buffer
}

/**
 * Column accessor that reads from `buffer` using the `CELEBORN_NULL` column type, with element
 * type `Any`, and mixes in `CelebornNullableColumnAccessor`.
 *
 * NOTE(review): this listing is a rendered diff without +/- markers; the enclosing hunk shrinks
 * by the size of this declaration, so it appears to be on the removed side -- confirm against
 * the commit.
 *
 * @param buffer byte buffer handed to the `CelebornBasicColumnAccessor` base class
 */
class CelebornNullColumnAccessor(buffer: ByteBuffer)
extends CelebornBasicColumnAccessor[Any](buffer, CELEBORN_NULL)
with CelebornNullableColumnAccessor

abstract class CelebornNativeColumnAccessor[T <: AtomicType](
override protected val buffer: ByteBuffer,
override protected val columnType: NativeCelebornColumnType[T])
Expand Down Expand Up @@ -112,7 +108,6 @@ private[sql] object CelebornColumnAccessor {
val buf = buffer.order(ByteOrder.nativeOrder)

dataType match {
case NullType => new CelebornNullColumnAccessor(buf)
case BooleanType => new CelebornBooleanColumnAccessor(buf)
case ByteType => new CelebornByteColumnAccessor(buf)
case ShortType => new CelebornShortColumnAccessor(buf)
Expand All @@ -135,7 +130,7 @@ private[sql] object CelebornColumnAccessor {
columnAccessor match {
case nativeAccessor: CelebornNativeColumnAccessor[_] =>
nativeAccessor.decompress(columnVector, numRows)
case d: CelebornDecimalColumnAccessor =>
case _: CelebornDecimalColumnAccessor =>
(0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _))
case _ =>
throw new RuntimeException("Not support non-primitive type now")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ class CelebornBasicColumnBuilder[JvmType](
}
}

/**
 * Column builder for the `CELEBORN_NULL` column type with element type `Any`. It collects
 * statistics through a `CelebornObjectColumnStats(NullType)` instance and mixes in
 * `CelebornNullableColumnBuilder`.
 *
 * NOTE(review): this listing is a rendered diff without +/- markers; the enclosing hunk shrinks
 * by the size of this declaration, so it appears to be on the removed side -- confirm against
 * the commit.
 */
class CelebornNullColumnBuilder
extends CelebornBasicColumnBuilder[Any](new CelebornObjectColumnStats(NullType), CELEBORN_NULL)
with CelebornNullableColumnBuilder

abstract class CelebornComplexColumnBuilder[JvmType](
columnStats: CelebornColumnStats,
columnType: CelebornColumnType[JvmType])
Expand Down Expand Up @@ -318,7 +314,6 @@ class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType)
}

object CelebornColumnBuilder {
val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L

def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
if (orig.remaining >= size) {
Expand All @@ -343,7 +338,6 @@ object CelebornColumnBuilder {
encodingEnabled: Boolean,
encoder: Encoder[_ <: AtomicType]): CelebornColumnBuilder = {
val builder: CelebornColumnBuilder = dataType match {
case NullType => new CelebornNullColumnBuilder
case ByteType => new CelebornByteColumnBuilder
case BooleanType => new CelebornBooleanColumnBuilder
case ShortType => new CelebornShortColumnBuilder
Expand All @@ -367,7 +361,7 @@ object CelebornColumnBuilder {
new CelebornCompactDecimalColumnBuilder(dt)
case dt: DecimalType => new CelebornDecimalColumnBuilder(dt)
case other =>
throw new Exception(s"not support type: $other")
throw new Exception(s"Unsupported type: $other")
}

builder.initialize(rowCnt, columnName, encodingEnabled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS
val value = row.getBoolean(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -79,15 +79,15 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS
}

final private[columnar] class CelebornByteColumnStats extends CelebornColumnStats {
protected var upper = Byte.MinValue
protected var lower = Byte.MaxValue
protected var upper: Byte = Byte.MinValue
protected var lower: Byte = Byte.MaxValue

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -103,15 +103,15 @@ final private[columnar] class CelebornByteColumnStats extends CelebornColumnStat
}

final private[columnar] class CelebornShortColumnStats extends CelebornColumnStats {
protected var upper = Short.MinValue
protected var lower = Short.MaxValue
protected var upper: Short = Short.MinValue
protected var lower: Short = Short.MaxValue

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -127,15 +127,15 @@ final private[columnar] class CelebornShortColumnStats extends CelebornColumnSta
}

final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats {
protected var upper = Int.MinValue
protected var lower = Int.MaxValue
protected var upper: Int = Int.MinValue
protected var lower: Int = Int.MaxValue

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -151,15 +151,15 @@ final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats
}

final private[columnar] class CelebornLongColumnStats extends CelebornColumnStats {
protected var upper = Long.MinValue
protected var lower = Long.MaxValue
protected var upper: Long = Long.MinValue
protected var lower: Long = Long.MaxValue

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -175,15 +175,15 @@ final private[columnar] class CelebornLongColumnStats extends CelebornColumnStat
}

final private[columnar] class CelebornFloatColumnStats extends CelebornColumnStats {
protected var upper = Float.MinValue
protected var lower = Float.MaxValue
protected var upper: Float = Float.MinValue
protected var lower: Float = Float.MaxValue

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -199,15 +199,15 @@ final private[columnar] class CelebornFloatColumnStats extends CelebornColumnSta
}

final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnStats {
protected var upper = Double.MinValue
protected var lower = Double.MaxValue
protected var upper: Double = Double.MinValue
protected var lower: Double = Double.MaxValue

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -223,16 +223,16 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt
}

final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats {
protected var upper: UTF8String = null
protected var lower: UTF8String = null
protected var upper: UTF8String = _
protected var lower: UTF8String = _

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
val size = CELEBORN_STRING.actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -247,34 +247,19 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

/**
 * Statistics collector that sizes values with `CELEBORN_BINARY`.
 *
 * Only size and counts are tracked; no lower/upper bound is kept, so `collectedStatistics`
 * reports `null` in the first two slots (the slots sibling stats classes fill with `lower` and
 * `upper`).
 *
 * NOTE(review): `sizeInBytes`, `count` and `nullCount` are not declared here -- presumably
 * inherited from `CelebornColumnStats`; confirm. This listing is a rendered diff without +/-
 * markers, and the enclosing hunk shrinks by the size of this class, so it appears to be on the
 * removed side -- confirm against the commit.
 */
final private[columnar] class CelebornBinaryColumnStats extends CelebornColumnStats {
// Non-null value: add its CELEBORN_BINARY.actualSize to sizeInBytes and increment count.
// Null value: delegate to gatherNullStats (defined outside this block).
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val size = CELEBORN_BINARY.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats
}
}

// Returns (null, null, nullCount, count, sizeInBytes); the two leading nulls stand in for bounds.
override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}

final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int)
extends CelebornColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)

protected var upper: Decimal = null
protected var lower: Decimal = null
protected var upper: Decimal = _
protected var lower: Decimal = _

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getDecimal(ordinal, precision, scale)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -294,21 +279,3 @@ final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale:
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

/**
 * Statistics collector for an arbitrary `dataType`, sized through the column type obtained from
 * `CelebornColumnType(dataType)`.
 *
 * Only size and counts are tracked; no lower/upper bound is kept, so `collectedStatistics`
 * reports `null` in the first two slots (the slots sibling stats classes fill with `lower` and
 * `upper`).
 *
 * NOTE(review): `sizeInBytes`, `count` and `nullCount` are not declared here -- presumably
 * inherited from `CelebornColumnStats`; confirm. This listing is a rendered diff without +/-
 * markers, and the enclosing hunk shrinks by the size of this class, so it appears to be on the
 * removed side -- confirm against the commit.
 *
 * @param dataType data type whose column type is used to measure each value's size
 */
final private[columnar] class CelebornObjectColumnStats(dataType: DataType)
extends CelebornColumnStats {
// Column type looked up once per instance and reused for every actualSize call.
val columnType = CelebornColumnType(dataType)

// Non-null value: add columnType.actualSize to sizeInBytes and increment count.
// Null value: delegate to gatherNullStats (defined outside this block).
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val size = columnType.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats
}
}

// Returns (null, null, nullCount, count, sizeInBytes); the two leading nulls stand in for bounds.
override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}
Loading

0 comments on commit c67c6e4

Please sign in to comment.