[CELEBORN-1123] Support fallback to non-columnar shuffle for schema that cannot be obtained from shuffle dependency #2101

Closed · wants to merge 1 commit into from
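In short: `CustomShuffleDependencyUtils.getSchema` no longer throws when the `ShuffleDependency#schema` field cannot be read (for example on vanilla Spark, which has no such field). It returns null instead, and both the columnar writer and the columnar reader treat a null schema as "use the regular row-based shuffle". A minimal sketch of the resulting guard, mirroring the writer and reader changes below (illustrative only, not code from this patch):

    StructType schema = CustomShuffleDependencyUtils.getSchema(dep);   // may now be null
    boolean columnar = schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
    // columnar == false  =>  the writer delegates to the row-based HashBasedShuffleWriter path
    //                        and the reader falls back to its non-columnar serializer.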
ColumnarHashBasedShuffleWriter.java
@@ -21,6 +21,8 @@

import scala.Product2;

import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.serializer.Serializer;
@@ -32,21 +34,28 @@
import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchCodeGenBuild;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;

@Private
public class ColumnarHashBasedShuffleWriter<K, V, C> extends HashBasedShuffleWriter<K, V, C> {

private CelebornBatchBuilder[] celebornBatchBuilders;
private StructType schema;
private Serializer depSerializer;
private boolean isColumnarShuffle = false;
private int columnarShuffleBatchSize;
private boolean columnarShuffleCodeGenEnabled;
private boolean columnarShuffleDictionaryEnabled;
private double columnarShuffleDictionaryMaxFactor;
private static final Logger logger =
LoggerFactory.getLogger(ColumnarHashBasedShuffleWriter.class);

private final int stageId;
private final int shuffleId;
private final CelebornBatchBuilder[] celebornBatchBuilders;
private final StructType schema;
private final Serializer depSerializer;
private final boolean isColumnarShuffle;
private final int columnarShuffleBatchSize;
private final boolean columnarShuffleCodeGenEnabled;
private final boolean columnarShuffleDictionaryEnabled;
private final double columnarShuffleDictionaryMaxFactor;

public ColumnarHashBasedShuffleWriter(
CelebornShuffleHandle<K, V, C> handle,
@@ -61,17 +70,21 @@ public ColumnarHashBasedShuffleWriter(
columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled();
columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled();
columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor();
this.schema = CustomShuffleDependencyUtils.getSchema(handle.dependency());
ShuffleDependency<?, ?, ?> shuffleDependency = handle.dependency();
this.stageId = taskContext.stageId();
this.shuffleId = shuffleDependency.shuffleId();
this.schema = CustomShuffleDependencyUtils.getSchema(shuffleDependency);
this.depSerializer = handle.dependency().serializer();
this.celebornBatchBuilders =
new CelebornBatchBuilder[handle.dependency().partitioner().numPartitions()];
this.isColumnarShuffle = CelebornBatchBuilder.supportsColumnarType(schema);
this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
}

@Override
protected void fastWrite0(scala.collection.Iterator iterator)
throws IOException, InterruptedException {
if (isColumnarShuffle) {
logger.info("Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId);
fastColumnarWrite0(iterator);
} else {
super.fastWrite0(iterator);
@@ -141,4 +154,9 @@ private void closeColumnarWrite() throws IOException {
}
}
}

@VisibleForTesting
public boolean isColumnarShuffle() {
return isColumnarShuffle;
}
}
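The `@VisibleForTesting` accessor makes the fallback observable in tests. A hypothetical sketch, where `createColumnarWriter(...)` stands in for however a test constructs the writer (it is not an API added by this patch):

    // A dependency without the schema field (vanilla Spark) must not enable columnar shuffle.
    ColumnarHashBasedShuffleWriter<?, ?, ?> writer = createColumnarWriter(vanillaSparkDependency);
    assertFalse(writer.isColumnarShuffle());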
CustomShuffleDependencyUtils.java
@@ -17,26 +17,35 @@

package org.apache.spark.shuffle.celeborn;

import java.io.IOException;

import org.apache.spark.ShuffleDependency;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.reflect.DynFields;

public class CustomShuffleDependencyUtils {

private static final Logger logger = LoggerFactory.getLogger(CustomShuffleDependencyUtils.class);

/**
* Columnar Shuffle requires a field, `ShuffleDependency#schema`, which does not exist in vanilla
* Spark.
*/
private static final DynFields.UnboundField<StructType> SCHEMA_FIELD =
DynFields.builder().hiddenImpl(ShuffleDependency.class, "schema").defaultAlwaysNull().build();

public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) throws IOException {
StructType schema = SCHEMA_FIELD.bind(dep).get();
public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) {
StructType schema = null;
try {
schema = SCHEMA_FIELD.bind(dep).get();
} catch (Exception e) {
logger.error("Failed to bind shuffle dependency of shuffle {}.", dep.shuffleId(), e);
}
if (schema == null) {
throw new IOException("Failed to get Schema, columnar shuffle won't work properly.");
logger.warn(
"Failed to get Schema of shuffle {}, columnar shuffle won't work properly.",
dep.shuffleId());
}
return schema;
}
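The `SCHEMA_FIELD` lookup is built with `defaultAlwaysNull()`, so on vanilla Spark (where `ShuffleDependency` has no `schema` field) `get()` simply yields null; the try/catch guards against other reflective failures. A conceptually equivalent plain-reflection sketch, not the actual `DynFields` implementation:

    static StructType readSchemaReflectively(ShuffleDependency<?, ?, ?> dep) {
      try {
        java.lang.reflect.Field f = dep.getClass().getDeclaredField("schema");
        f.setAccessible(true);
        return (StructType) f.get(dep);
      } catch (ReflectiveOperationException e) {
        return null; // same observable effect: the caller falls back to row-based shuffle
      }
    }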
CelebornColumnarShuffleReader.scala
@@ -46,14 +46,11 @@ class CelebornColumnarShuffleReader[K, C](

override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
if (CelebornBatchBuilder.supportsColumnarType(
schema)) {
val dataSize = SparkUtils.getDataSize(
dep.serializer.asInstanceOf[UnsafeRowSerializer])
if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) {
logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.")
val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer])
new CelebornColumnarBatchSerializer(
schema,
conf.columnarShuffleBatchSize,
conf.columnarShuffleDictionaryEnabled,
conf.columnarShuffleOffHeapEnabled,
dataSize).newInstance()
} else {
CelebornBatchBuilder.scala
@@ -28,7 +28,7 @@ abstract class CelebornBatchBuilder {

def writeRow(row: InternalRow): Unit

def getRowCnt(): Int
def getRowCnt: Int

def int2ByteArray(i: Int): Array[Byte] = {
val result = new Array[Byte](4)
@@ -46,7 +46,7 @@ object CelebornBatchBuilder {
f.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | StringType => true
case dt: DecimalType => true
case _: DecimalType => true
case _ => false
})
}
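For reference, `supportsColumnarType` only accepts schemas whose fields are the primitive types listed above or decimals; anything else takes the row-based path. Illustrative values (assumed Spark `StructType`/`DataTypes` usage, not code from this patch):

    StructType ok = new StructType()
        .add("id", DataTypes.LongType)
        .add("name", DataTypes.StringType);
    StructType notOk = new StructType()
        .add("tags", DataTypes.createArrayType(DataTypes.StringType));
    // CelebornBatchBuilder.supportsColumnarType(ok)     => true
    // CelebornBatchBuilder.supportsColumnarType(notOk)  => false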
CelebornColumnAccessor.scala
@@ -61,13 +61,9 @@ abstract class CelebornBasicColumnAccessor[JvmType](
columnType.extract(buffer, row, ordinal)
}

protected def underlyingBuffer = buffer
protected def underlyingBuffer: ByteBuffer = buffer
}

class CelebornNullColumnAccessor(buffer: ByteBuffer)
extends CelebornBasicColumnAccessor[Any](buffer, CELEBORN_NULL)
with CelebornNullableColumnAccessor

abstract class CelebornNativeColumnAccessor[T <: AtomicType](
override protected val buffer: ByteBuffer,
override protected val columnType: NativeCelebornColumnType[T])
@@ -112,7 +108,6 @@ private[sql] object CelebornColumnAccessor {
val buf = buffer.order(ByteOrder.nativeOrder)

dataType match {
case NullType => new CelebornNullColumnAccessor(buf)
case BooleanType => new CelebornBooleanColumnAccessor(buf)
case ByteType => new CelebornByteColumnAccessor(buf)
case ShortType => new CelebornShortColumnAccessor(buf)
Expand All @@ -135,7 +130,7 @@ private[sql] object CelebornColumnAccessor {
columnAccessor match {
case nativeAccessor: CelebornNativeColumnAccessor[_] =>
nativeAccessor.decompress(columnVector, numRows)
case d: CelebornDecimalColumnAccessor =>
case _: CelebornDecimalColumnAccessor =>
(0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _))
case _ =>
throw new RuntimeException("Not support non-primitive type now")
CelebornColumnBuilder.scala
@@ -88,10 +88,6 @@ class CelebornBasicColumnBuilder[JvmType](
}
}

class CelebornNullColumnBuilder
extends CelebornBasicColumnBuilder[Any](new CelebornObjectColumnStats(NullType), CELEBORN_NULL)
with CelebornNullableColumnBuilder

abstract class CelebornComplexColumnBuilder[JvmType](
columnStats: CelebornColumnStats,
columnType: CelebornColumnType[JvmType])
@@ -318,7 +314,6 @@ class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType)
}

object CelebornColumnBuilder {
val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L

def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
if (orig.remaining >= size) {
@@ -343,7 +338,6 @@
encodingEnabled: Boolean,
encoder: Encoder[_ <: AtomicType]): CelebornColumnBuilder = {
val builder: CelebornColumnBuilder = dataType match {
case NullType => new CelebornNullColumnBuilder
case ByteType => new CelebornByteColumnBuilder
case BooleanType => new CelebornBooleanColumnBuilder
case ShortType => new CelebornShortColumnBuilder
@@ -367,7 +361,7 @@
new CelebornCompactDecimalColumnBuilder(dt)
case dt: DecimalType => new CelebornDecimalColumnBuilder(dt)
case other =>
throw new Exception(s"not support type: $other")
throw new Exception(s"Unsupported type: $other")
}

builder.initialize(rowCnt, columnName, encodingEnabled)
CelebornColumnStats.scala
@@ -63,7 +63,7 @@ final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnS
val value = row.getBoolean(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -87,7 +87,7 @@ final private[columnar] class CelebornByteColumnStats extends CelebornColumnStat
val value = row.getByte(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -111,7 +111,7 @@ final private[columnar] class CelebornShortColumnStats extends CelebornColumnSta
val value = row.getShort(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -135,7 +135,7 @@ final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats
val value = row.getInt(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -159,7 +159,7 @@ final private[columnar] class CelebornLongColumnStats extends CelebornColumnStat
val value = row.getLong(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -183,7 +183,7 @@ final private[columnar] class CelebornFloatColumnStats extends CelebornColumnSta
val value = row.getFloat(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -207,7 +207,7 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt
val value = row.getDouble(ordinal)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -223,16 +223,16 @@ final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnSt
}

final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats {
protected var upper: UTF8String = null
protected var lower: UTF8String = null
protected var upper: UTF8String = _
protected var lower: UTF8String = _

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
val size = CELEBORN_STRING.actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats
gatherNullStats()
}
}

@@ -247,34 +247,19 @@ final private[columnar] class CelebornStringColumnStats extends CelebornColumnSt
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

final private[columnar] class CelebornBinaryColumnStats extends CelebornColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val size = CELEBORN_BINARY.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats
}
}

override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}

final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int)
extends CelebornColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)

protected var upper: Decimal = null
protected var lower: Decimal = null
protected var upper: Decimal = _
protected var lower: Decimal = _

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getDecimal(ordinal, precision, scale)
gatherValueStats(value)
} else {
gatherNullStats
gatherNullStats()
}
}

Expand All @@ -294,21 +279,3 @@ final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale:
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

final private[columnar] class CelebornObjectColumnStats(dataType: DataType)
extends CelebornColumnStats {
val columnType = CelebornColumnType(dataType)

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val size = columnType.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats
}
}

override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}