Skip to content

Commit

Permalink
[CELEBORN-1123] Support fallback to non-columnar shuffle for schema that cannot be obtained from shuffle dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
gaochao0509 committed Nov 17, 2023
1 parent 7263f64 commit 5a8926c
Show file tree
Hide file tree
Showing 18 changed files with 224 additions and 226 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import scala.Product2;

import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.serializer.Serializer;
Expand All @@ -32,21 +34,28 @@
import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchCodeGenBuild;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;

@Private
public class ColumnarHashBasedShuffleWriter<K, V, C> extends HashBasedShuffleWriter<K, V, C> {

private CelebornBatchBuilder[] celebornBatchBuilders;
private StructType schema;
private Serializer depSerializer;
private boolean isColumnarShuffle = false;
private int columnarShuffleBatchSize;
private boolean columnarShuffleCodeGenEnabled;
private boolean columnarShuffleDictionaryEnabled;
private double columnarShuffleDictionaryMaxFactor;
private static final Logger logger =
LoggerFactory.getLogger(ColumnarHashBasedShuffleWriter.class);

private final int stageId;
private final int shuffleId;
private final CelebornBatchBuilder[] celebornBatchBuilders;
private final StructType schema;
private final Serializer depSerializer;
private final boolean isColumnarShuffle;
private final int columnarShuffleBatchSize;
private final boolean columnarShuffleCodeGenEnabled;
private final boolean columnarShuffleDictionaryEnabled;
private final double columnarShuffleDictionaryMaxFactor;

public ColumnarHashBasedShuffleWriter(
CelebornShuffleHandle<K, V, C> handle,
Expand All @@ -61,17 +70,21 @@ public ColumnarHashBasedShuffleWriter(
columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled();
columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled();
columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor();
this.schema = CustomShuffleDependencyUtils.getSchema(handle.dependency());
ShuffleDependency<?, ?, ?> shuffleDependency = handle.dependency();
this.stageId = taskContext.stageId();
this.shuffleId = shuffleDependency.shuffleId();
this.schema = CustomShuffleDependencyUtils.getSchema(shuffleDependency);
this.depSerializer = handle.dependency().serializer();
this.celebornBatchBuilders =
new CelebornBatchBuilder[handle.dependency().partitioner().numPartitions()];
this.isColumnarShuffle = CelebornBatchBuilder.supportsColumnarType(schema);
this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
}

@Override
protected void fastWrite0(scala.collection.Iterator iterator)
throws IOException, InterruptedException {
if (isColumnarShuffle) {
logger.info("Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId);
fastColumnarWrite0(iterator);
} else {
super.fastWrite0(iterator);
Expand Down Expand Up @@ -141,4 +154,9 @@ private void closeColumnarWrite() throws IOException {
}
}
}

/**
 * Reports whether this writer takes the columnar write path.
 *
 * <p>Exposed so that tests can inspect the decision; the writer itself reads the field directly.
 *
 * @return the current value of the {@code isColumnarShuffle} field
 */
@VisibleForTesting
public boolean isColumnarShuffle() {
  return this.isColumnarShuffle;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,35 @@

package org.apache.spark.shuffle.celeborn;

import java.io.IOException;

import org.apache.spark.ShuffleDependency;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.reflect.DynFields;

public class CustomShuffleDependencyUtils {

private static final Logger logger = LoggerFactory.getLogger(CustomShuffleDependencyUtils.class);

/**
* Columnar Shuffle requires a field, `ShuffleDependency#schema`, which does not exist in vanilla
* Spark.
*/
private static final DynFields.UnboundField<StructType> SCHEMA_FIELD =
DynFields.builder().hiddenImpl(ShuffleDependency.class, "schema").defaultAlwaysNull().build();

public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) throws IOException {
StructType schema = SCHEMA_FIELD.bind(dep).get();
public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) {
StructType schema = null;
try {
schema = SCHEMA_FIELD.bind(dep).get();
} catch (Exception e) {
logger.error("Failed to bind shuffle dependency of shuffle {}.", dep.shuffleId(), e);
}
if (schema == null) {
throw new IOException("Failed to get Schema, columnar shuffle won't work properly.");
logger.warn(
"Failed to get Schema of shuffle {}, columnar shuffle won't work properly.",
dep.shuffleId());
}
return schema;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,11 @@ class CelebornColumnarShuffleReader[K, C](

override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
if (CelebornBatchBuilder.supportsColumnarType(
schema)) {
val dataSize = SparkUtils.getDataSize(
dep.serializer.asInstanceOf[UnsafeRowSerializer])
if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) {
logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.")
val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer])
new CelebornColumnarBatchSerializer(
schema,
conf.columnarShuffleBatchSize,
conf.columnarShuffleDictionaryEnabled,
conf.columnarShuffleOffHeapEnabled,
dataSize).newInstance()
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ abstract class CelebornBatchBuilder {

def writeRow(row: InternalRow): Unit

def getRowCnt(): Int
def getRowCnt: Int

def int2ByteArray(i: Int): Array[Byte] = {
val result = new Array[Byte](4)
Expand All @@ -46,7 +46,7 @@ object CelebornBatchBuilder {
f.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | StringType => true
case dt: DecimalType => true
case _: DecimalType => true
case _ => false
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,9 @@ abstract class CelebornBasicColumnAccessor[JvmType](
columnType.extract(buffer, row, ordinal)
}

protected def underlyingBuffer = buffer
protected def underlyingBuffer: ByteBuffer = buffer
}

/**
 * Column accessor that `CelebornColumnAccessor` instantiates for `NullType`.
 *
 * It hands `buffer` and the `CELEBORN_NULL` column type to the basic accessor and mixes in
 * `CelebornNullableColumnAccessor`.
 *
 * @param buffer byte buffer the accessor reads from
 */
class CelebornNullColumnAccessor(buffer: ByteBuffer)
extends CelebornBasicColumnAccessor[Any](buffer, CELEBORN_NULL)
with CelebornNullableColumnAccessor

abstract class CelebornNativeColumnAccessor[T <: AtomicType](
override protected val buffer: ByteBuffer,
override protected val columnType: NativeCelebornColumnType[T])
Expand Down Expand Up @@ -112,7 +108,6 @@ private[sql] object CelebornColumnAccessor {
val buf = buffer.order(ByteOrder.nativeOrder)

dataType match {
case NullType => new CelebornNullColumnAccessor(buf)
case BooleanType => new CelebornBooleanColumnAccessor(buf)
case ByteType => new CelebornByteColumnAccessor(buf)
case ShortType => new CelebornShortColumnAccessor(buf)
Expand All @@ -135,7 +130,7 @@ private[sql] object CelebornColumnAccessor {
columnAccessor match {
case nativeAccessor: CelebornNativeColumnAccessor[_] =>
nativeAccessor.decompress(columnVector, numRows)
case d: CelebornDecimalColumnAccessor =>
case _: CelebornDecimalColumnAccessor =>
(0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _))
case _ =>
throw new RuntimeException("Not support non-primitive type now")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ class CelebornBasicColumnBuilder[JvmType](
}
}

/**
 * Column builder that `CelebornColumnBuilder` instantiates for `NullType`.
 *
 * It pairs the `CELEBORN_NULL` column type with a `CelebornObjectColumnStats` created for
 * `NullType`, and mixes in `CelebornNullableColumnBuilder`.
 */
class CelebornNullColumnBuilder
extends CelebornBasicColumnBuilder[Any](new CelebornObjectColumnStats(NullType), CELEBORN_NULL)
with CelebornNullableColumnBuilder

abstract class CelebornComplexColumnBuilder[JvmType](
columnStats: CelebornColumnStats,
columnType: CelebornColumnType[JvmType])
Expand Down Expand Up @@ -318,7 +314,6 @@ class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType)
}

object CelebornColumnBuilder {
val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L

def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
if (orig.remaining >= size) {
Expand All @@ -343,7 +338,6 @@ object CelebornColumnBuilder {
encodingEnabled: Boolean,
encoder: Encoder[_ <: AtomicType]): CelebornColumnBuilder = {
val builder: CelebornColumnBuilder = dataType match {
case NullType => new CelebornNullColumnBuilder
case ByteType => new CelebornByteColumnBuilder
case BooleanType => new CelebornBooleanColumnBuilder
case ShortType => new CelebornShortColumnBuilder
Expand All @@ -367,7 +361,7 @@ object CelebornColumnBuilder {
new CelebornCompactDecimalColumnBuilder(dt)
case dt: DecimalType => new CelebornDecimalColumnBuilder(dt)
case other =>
throw new Exception(s"not support type: $other")
throw new Exception(s"Unsupported type: $other")
}

builder.initialize(rowCnt, columnName, encodingEnabled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS
val value = row.getBoolean(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -87,7 +87,7 @@ final private[columnar] class CelebornByteColumnStats extends CelebornColumnStat
val value = row.getByte(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -111,7 +111,7 @@ final private[columnar] class CelebornShortColumnStats extends CelebornColumnSta
val value = row.getShort(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -135,7 +135,7 @@ final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats
val value = row.getInt(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -159,7 +159,7 @@ final private[columnar] class CelebornLongColumnStats extends CelebornColumnStat
val value = row.getLong(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -183,7 +183,7 @@ final private[columnar] class CelebornFloatColumnStats extends CelebornColumnSta
val value = row.getFloat(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -207,7 +207,7 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt
val value = row.getDouble(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -223,16 +223,16 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt
}

final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats {
protected var upper: UTF8String = null
protected var lower: UTF8String = null
protected var upper: UTF8String = _
protected var lower: UTF8String = _

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
val size = CELEBORN_STRING.actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -247,34 +247,19 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

/**
 * Column statistics collector for binary columns.
 *
 * No lower/upper bound is tracked: the first two slots of `collectedStatistics` are `null`.
 */
final private[columnar] class CelebornBinaryColumnStats extends CelebornColumnStats {

  /**
   * Folds the value at `ordinal` of `row` into the running statistics.
   *
   * For a non-null value, its actual size (as measured by `CELEBORN_BINARY`) is added to
   * `sizeInBytes` and `count` is incremented; null handling is delegated to `gatherNullStats()`.
   *
   * @param row     row holding the value
   * @param ordinal index of the column within `row`
   */
  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
    if (row.isNullAt(ordinal)) {
      // Invoked with explicit parentheses, matching the other column stats classes.
      gatherNullStats()
    } else {
      sizeInBytes += CELEBORN_BINARY.actualSize(row, ordinal)
      count += 1
    }
  }

  /** Returns `[lower, upper, nullCount, count, sizeInBytes]` with both bounds left as `null`. */
  override def collectedStatistics: Array[Any] =
    Array[Any](null, null, nullCount, count, sizeInBytes)
}

final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int)
extends CelebornColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)

protected var upper: Decimal = null
protected var lower: Decimal = null
protected var upper: Decimal = _
protected var lower: Decimal = _

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getDecimal(ordinal, precision, scale)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -294,21 +279,3 @@ final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale:
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

/**
 * Column statistics collector that measures values through the `CelebornColumnType` resolved
 * from `dataType`.
 *
 * No lower/upper bound is tracked: the first two slots of `collectedStatistics` are `null`.
 *
 * @param dataType data type used to look up the column type
 */
final private[columnar] class CelebornObjectColumnStats(dataType: DataType)
  extends CelebornColumnStats {

  /** Column type resolved once from `dataType`; used to measure each value's size. */
  val columnType = CelebornColumnType(dataType)

  /**
   * Folds the value at `ordinal` of `row` into the running statistics.
   *
   * For a non-null value, its actual size (as measured by `columnType`) is added to
   * `sizeInBytes` and `count` is incremented; null handling is delegated to `gatherNullStats()`.
   *
   * @param row     row holding the value
   * @param ordinal index of the column within `row`
   */
  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
    if (row.isNullAt(ordinal)) {
      // Invoked with explicit parentheses, matching the other column stats classes.
      gatherNullStats()
    } else {
      sizeInBytes += columnType.actualSize(row, ordinal)
      count += 1
    }
  }

  /** Returns `[lower, upper, nullCount, count, sizeInBytes]` with both bounds left as `null`. */
  override def collectedStatistics: Array[Any] =
    Array[Any](null, null, nullCount, count, sizeInBytes)
}
Loading

0 comments on commit 5a8926c

Please sign in to comment.