[GLUTEN-4587][VL] Add config to force fallback on scan of complex type (#4778)

* [VL] Add config to force fallback scan of complex type

* add UT

* set fallback default to be true

* update UTs based on new config
yma11 authored Mar 1, 2024
1 parent 22d9fe3 commit 1ffc9cf
Showing 8 changed files with 385 additions and 308 deletions.
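
The new flag is a regular Spark SQL conf, so it can be toggled per session or per query. Below is a minimal spark-shell style sketch of opting back into Velox scans of complex types; only the config key (taken from the updated tests) and its default of true (per the commit messages) come from this commit, the rest is illustrative.

// spark-shell style sketch; nothing here beyond the key name and its default is from this commit.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("complex-type-scan-fallback-demo").getOrCreate()

// Default after this commit: scans of array/map/struct columns fall back to vanilla Spark.
// Setting the flag to false keeps offloading such scans to Velox, as the updated UTs do below.
spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled", "false")
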
@@ -85,6 +85,29 @@ object BackendSettings extends BackendSettingsApi {
}
}

val parquetTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = {
case StructField(_, arrayType: ArrayType, _, _) =>
arrayType.simpleString + " is forced to fallback."
case StructField(_, mapType: MapType, _, _) =>
mapType.simpleString + " is forced to fallback."
case StructField(_, structType: StructType, _, _) =>
structType.simpleString + " is forced to fallback."
}
val orcTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = {
case StructField(_, ByteType, _, _) => "ByteType not support"
case StructField(_, arrayType: ArrayType, _, _) =>
arrayType.simpleString + " is forced to fallback."
case StructField(_, mapType: MapType, _, _) =>
mapType.simpleString + " is forced to fallback."
case StructField(_, structType: StructType, _, _) =>
structType.simpleString + " is forced to fallback."
case StructField(_, stringType: StringType, _, metadata)
if CharVarcharUtils
.getRawTypeString(metadata)
.getOrElse(stringType.catalogString) != stringType.catalogString =>
CharVarcharUtils.getRawTypeString(metadata) + " not support"
case StructField(_, TimestampType, _, _) => "TimestampType not support"
}
format match {
case ParquetReadFormat =>
val typeValidator: PartialFunction[StructField, String] = {
@@ -103,7 +126,11 @@ object BackendSettings extends BackendSettingsApi {
if mapType.valueType.isInstanceOf[ArrayType] =>
"ArrayType as Value in MapType"
}
validateTypes(typeValidator)
if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
validateTypes(typeValidator)
} else {
validateTypes(parquetTypeValidatorWithComplexTypeFallback)
}
case DwrfReadFormat => ValidationResult.ok
case OrcReadFormat =>
if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
@@ -130,7 +157,11 @@ object BackendSettings extends BackendSettingsApi {
CharVarcharUtils.getRawTypeString(metadata) + " not support"
case StructField(_, TimestampType, _, _) => "TimestampType not support"
}
validateTypes(typeValidator)
if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
validateTypes(typeValidator)
} else {
validateTypes(orcTypeValidatorWithComplexTypeFallback)
}
}
case _ => ValidationResult.notOk(s"Unsupported file format for $format.")
}
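
validateTypes itself is not part of this diff; judging from the branches above, it applies one of these PartialFunction[StructField, String] validators to the fields of the scanned schema and reports any match as a reason to fall back. A self-contained sketch of that pattern follows, with collectFallbackReasons as a hypothetical stand-in for validateTypes.

import org.apache.spark.sql.types._

// Same shape as parquetTypeValidatorWithComplexTypeFallback above: any complex field
// yields a human-readable fallback reason.
val complexTypeFallback: PartialFunction[StructField, String] = {
  case StructField(_, arrayType: ArrayType, _, _) => arrayType.simpleString + " is forced to fallback."
  case StructField(_, mapType: MapType, _, _) => mapType.simpleString + " is forced to fallback."
  case StructField(_, structType: StructType, _, _) => structType.simpleString + " is forced to fallback."
}

// Hypothetical stand-in for validateTypes: collect the reasons produced for a schema.
def collectFallbackReasons(
    schema: StructType,
    validator: PartialFunction[StructField, String]): Seq[String] =
  schema.fields.toSeq.collect(validator)

val demoSchema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("tags", ArrayType(StringType)),
  StructField("attrs", MapType(StringType, StringType))))

collectFallbackReasons(demoSchema, complexTypeFallback)
// Seq("array<string> is forced to fallback.", "map<string,string> is forced to fallback.")
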
@@ -325,102 +325,108 @@ class VeloxOrcDataTypeValidationSuite extends VeloxWholeStageTransformerSuite {
}

test("Array type") {
// Validation: BatchScan.
runQueryAndCompare("select array from type1") {
checkOperatorMatch[BatchScanExecTransformer]
}
withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "false")) {
// Validation: BatchScan.
runQueryAndCompare("select array from type1") {
checkOperatorMatch[BatchScanExecTransformer]
}

// Validation: BatchScan Project Aggregate Expand Sort Limit
runQueryAndCompare(
"select int, array from type1 " +
" group by grouping sets(int, array) sort by array, int limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}
// Validation: BatchScan Project Aggregate Expand Sort Limit
runQueryAndCompare(
"select int, array from type1 " +
" group by grouping sets(int, array) sort by array, int limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}

// Validation: BroadHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "10M")
runQueryAndCompare(
"select type1.array from type1," +
" type2 where type1.array = type2.array") { _ => }
// Validation: BroadHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "10M")
runQueryAndCompare(
"select type1.array from type1," +
" type2 where type1.array = type2.array") { _ => }

// Validation: ShuffledHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
runQueryAndCompare(
"select type1.array from type1," +
" type2 where type1.array = type2.array") { _ => }
// Validation: ShuffledHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
runQueryAndCompare(
"select type1.array from type1," +
" type2 where type1.array = type2.array") { _ => }
}
}

test("Map type") {
// Validation: BatchScan Project Limit
runQueryAndCompare("select map from type1 limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}
// Validation: BatchScan Project Aggregate Sort Limit
// TODO validate Expand operator support map type ?
runQueryAndCompare(
"select map['key'] from type1 group by map['key']" +
" sort by map['key'] limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}
withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "false")) {
// Validation: BatchScan Project Limit
runQueryAndCompare("select map from type1 limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}
// Validation: BatchScan Project Aggregate Sort Limit
// TODO validate Expand operator support map type ?
runQueryAndCompare(
"select map['key'] from type1 group by map['key']" +
" sort by map['key'] limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}

// Validation: BroadHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "10M")
runQueryAndCompare(
"select type1.map['key'] from type1," +
" type2 where type1.map['key'] = type2.map['key']") { _ => }
// Validation: BroadHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "10M")
runQueryAndCompare(
"select type1.map['key'] from type1," +
" type2 where type1.map['key'] = type2.map['key']") { _ => }

// Validation: ShuffledHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
runQueryAndCompare(
"select type1.map['key'] from type1," +
" type2 where type1.map['key'] = type2.map['key']") { _ => }
// Validation: ShuffledHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
runQueryAndCompare(
"select type1.map['key'] from type1," +
" type2 where type1.map['key'] = type2.map['key']") { _ => }
}
}

test("Struct type") {
// Validation: BatchScan Project Limit
runQueryAndCompare("select struct from type1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}
// Validation: BatchScan Project Aggregate Sort Limit
// TODO validate Expand operator support Struct type ?
runQueryAndCompare(
"select int, struct.struct_1 from type1 " +
"sort by struct.struct_1 limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
assert(executedPlan.exists(plan => plan.isInstanceOf[ProjectExecTransformer]))
}
}
withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "false")) {
// Validation: BatchScan Project Limit
runQueryAndCompare("select struct from type1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}
// Validation: BatchScan Project Aggregate Sort Limit
// TODO validate Expand operator support Struct type ?
runQueryAndCompare(
"select int, struct.struct_1 from type1 " +
"sort by struct.struct_1 limit 1") {
df =>
{
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
assert(executedPlan.exists(plan => plan.isInstanceOf[ProjectExecTransformer]))
}
}

// Validation: BroadHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "10M")
runQueryAndCompare(
"select type1.struct.struct_1 from type1," +
" type2 where type1.struct.struct_1 = type2.struct.struct_1") { _ => }
// Validation: BroadHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "10M")
runQueryAndCompare(
"select type1.struct.struct_1 from type1," +
" type2 where type1.struct.struct_1 = type2.struct.struct_1") { _ => }

// Validation: ShuffledHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
runQueryAndCompare(
"select type1.struct.struct_1 from type1," +
" type2 where type1.struct.struct_1 = type2.struct.struct_1") { _ => }
// Validation: ShuffledHashJoin, Filter, Project
super.sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
runQueryAndCompare(
"select type1.struct.struct_1 from type1," +
" type2 where type1.struct.struct_1 = type2.struct.struct_1") { _ => }
}
}

test("Decimal type") {
(Remaining diff content not loaded.)
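
The updated suites pin spark.gluten.sql.complexType.scan.fallback.enabled to false so the existing assertions keep exercising the offloaded scan. A hypothetical companion test for the new default (true), reusing the suite's runQueryAndCompare and getExecutedPlan helpers, would assert the opposite, i.e. that the Velox scan transformer no longer appears in the plan.

  test("Array type scan falls back by default") {
    // Hypothetical check, not part of this commit: with the conf at its new
    // default (true), the complex-type scan should not be offloaded to Velox.
    withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "true")) {
      runQueryAndCompare("select array from type1") {
        df =>
          {
            val executedPlan = getExecutedPlan(df)
            assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
          }
      }
    }
  }
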
