Skip to content

Commit

Permalink
Fix the broadcast joins issues caused by InputFileBlockRule[databrick…
Browse files Browse the repository at this point in the history
…s] (#9673)

InputFileBlockRule may change the meta of a broadcast join and its child plans, and this change may break the rule of the broadcast join running on GPU, leading to errors. Because GPU broadcast joins require the build side BroadcastExchangeExec running on GPU, and similarly if BroadcastExchangeExec runs on CPU, the broadcast joins should also run on CPU.

Change made:

Optimize the InputFileBlockRule by skipping the BroadcastExchangeLike because the file info cannot come from a broadcast. (This idea is from #9473)
Check the tagging for broadcast joins again after applying the InputFileBlockRule to fix the potential break.
Some API refactor, moving all input file related methods into the InputFileBlockRule object.
---------

Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman authored Nov 20, 2023
1 parent 15e58aa commit 9ed98c8
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 70 deletions.
71 changes: 71 additions & 0 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,3 +1056,74 @@ def do_join(spark):
right = spark.read.parquet(data_path2)
return right.filter("cast(id2 as bigint) % 3 = 4").join(left, left.id == right.id, "inner")
assert_gpu_and_cpu_are_equal_collect(do_join, bloom_filter_confs)


@ignore_order(local=True)
@allow_non_gpu("ProjectExec", "FilterExec", "BroadcastHashJoinExec", "ColumnarToRowExec", "BroadcastExchangeExec")
@pytest.mark.parametrize("disable_build", [True, False])
def test_broadcast_hash_join_fix_fallback_by_inputfile(spark_tmp_path, disable_build):
data_path_parquet = spark_tmp_path + "/parquet"
data_path_orc = spark_tmp_path + "/orc"
# The smaller one (orc) will be the build side (a broadcast)
with_cpu_session(lambda spark: spark.range(100).write.orc(data_path_orc))
with_cpu_session(lambda spark: spark.range(10000).withColumn("id2", col("id") + 10)
.write.parquet(data_path_parquet))
def do_join(spark):
left = spark.read.parquet(data_path_parquet)
right = spark.read.orc(data_path_orc)
return left.join(broadcast(right), "id", "inner")\
.selectExpr("*", "input_file_block_length()")

if disable_build:
# To reproduce the error
# '''
# java.lang.IllegalStateException: the broadcast must be on the GPU too
# at com.nvidia.spark.rapids.shims.GpuBroadcastJoinMeta.verifyBuildSideWasReplaced...
# '''
scan_name = 'OrcScan'
else:
# An additional case that the exec contains the input file expression is not disabled
# by InputFileBlockRule mistakenly. When the stream side scan runs on CPU, but the
# build side scan runs on GPU, the InputFileBlockRule will not put the exec on
# CPU, leading to wrong output.
scan_name = 'ParquetScan'
assert_gpu_and_cpu_are_equal_collect(
do_join,
conf={"spark.sql.autoBroadcastJoinThreshold": "10M",
"spark.sql.sources.useV1SourceList": "",
"spark.rapids.sql.input." + scan_name: False})


@ignore_order(local=True)
@allow_non_gpu("ProjectExec", "BroadcastNestedLoopJoinExec", "ColumnarToRowExec", "BroadcastExchangeExec")
@pytest.mark.parametrize("disable_build", [True, False])
def test_broadcast_nested_join_fix_fallback_by_inputfile(spark_tmp_path, disable_build):
data_path_parquet = spark_tmp_path + "/parquet"
data_path_orc = spark_tmp_path + "/orc"
# The smaller one (orc) will be the build side (a broadcast)
with_cpu_session(lambda spark: spark.range(50).write.orc(data_path_orc))
with_cpu_session(lambda spark: spark.range(500).withColumn("id2", col("id") + 10)
.write.parquet(data_path_parquet))
def do_join(spark):
left = spark.read.parquet(data_path_parquet)
right = spark.read.orc(data_path_orc)
return left.crossJoin(broadcast(right)).selectExpr("*", "input_file_block_length()")

if disable_build:
# To reproduce the error
# '''
# java.lang.IllegalStateException: the broadcast must be on the GPU too
# at com.nvidia.spark.rapids.shims.GpuBroadcastJoinMeta.verifyBuildSideWasReplaced...
# '''
scan_name = 'OrcScan'
else:
# An additional case that the exec contains the input file expression is not disabled
# by InputFileBlockRule mistakenly. When the stream side scan runs on CPU, but the
# build side scan runs on GPU, the InputFileBlockRule will not put the exec on
# CPU, leading to wrong output.
scan_name = 'ParquetScan'
assert_gpu_and_cpu_are_equal_collect(
do_join,
conf={"spark.sql.autoBroadcastJoinThreshold": "-1",
"spark.sql.sources.useV1SourceList": "",
"spark.rapids.sql.input." + scan_name: False})
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import com.nvidia.spark.rapids.shims.{GpuBatchScanExec, SparkShimImpl}

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
Expand All @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedC
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExecBase, DropTableExec, ShowTablesExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, Exchange, ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashedRelationBroadcastMode}
import org.apache.spark.sql.rapids.{GpuDataSourceScanExec, GpuFileSourceScanExec, GpuInputFileBlockLength, GpuInputFileBlockStart, GpuInputFileName, GpuShuffleEnv, GpuTaskMetrics}
import org.apache.spark.sql.rapids.{GpuDataSourceScanExec, GpuFileSourceScanExec, GpuShuffleEnv, GpuTaskMetrics}
import org.apache.spark.sql.rapids.execution.{ExchangeMappingCache, GpuBroadcastExchangeExec, GpuBroadcastExchangeExecBase, GpuBroadcastToRowExec, GpuCustomShuffleReaderExec, GpuHashJoin, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -329,30 +329,16 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
case _ => false
}



/**
* Because we cannot change the executors in spark itself we need to try and account for
* the ones that might have issues with coalesce here.
*/
private def disableCoalesceUntilInput(plan: SparkPlan): Boolean = {
plan.expressions.exists(GpuTransitionOverrides.checkHasInputFileExpressions)
}

private def disableScanUntilInput(exec: Expression): Boolean = {
exec match {
case _: InputFileName => true
case _: InputFileBlockStart => true
case _: InputFileBlockLength => true
case _: GpuInputFileName => true
case _: GpuInputFileBlockStart => true
case _: GpuInputFileBlockLength => true
case e => e.children.exists(disableScanUntilInput)
}
InputFileBlockRule.hasInputFileExpression(plan)
}

private def disableScanUntilInput(plan: SparkPlan): Boolean = {
plan.expressions.exists(disableScanUntilInput)
InputFileBlockRule.hasInputFileExpression(plan)
}

// This walks from the output to the input to look for any uses of InputFileName,
Expand Down Expand Up @@ -841,15 +827,4 @@ object GpuTransitionOverrides {
}
}

/**
* Check the Expression is or has Input File expressions.
* @param exec expression to check
* @return true or false
*/
def checkHasInputFileExpressions(exec: Expression): Boolean = exec match {
case _: InputFileName => true
case _: InputFileBlockStart => true
case _: InputFileBlockLength => true
case e => e.children.exists(checkHasInputFileExpressions)
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,36 +15,33 @@
*/
package com.nvidia.spark.rapids

import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.{GpuInputFileBlockLength, GpuInputFileBlockStart, GpuInputFileName}

/**
* InputFileBlockRule is to prevent the SparkPlans
* [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU
*
* See https://github.com/NVIDIA/spark-rapids/issues/3333
* A rule prevents the plans [SparkPlan (with first input_file_xxx expression), FileScan)
* from running on GPU.
* For more details, please go to https://github.com/NVIDIA/spark-rapids/issues/3333.
*/
object InputFileBlockRule {
private type PlanMeta = SparkPlanMeta[SparkPlan]

private def checkHasInputFileExpressions(plan: SparkPlan): Boolean = {
plan.expressions.exists(GpuTransitionOverrides.checkHasInputFileExpressions)
}

// Apply the rule on SparkPlanMeta
def apply(plan: SparkPlanMeta[SparkPlan]) = {
/**
* key: the SparkPlanMeta where has the first input_file_xxx expression
* value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan)
*/
val resultOps = LinkedHashMap[SparkPlanMeta[SparkPlan], ArrayBuffer[SparkPlanMeta[SparkPlan]]]()
def apply(plan: PlanMeta): Unit = {
// key: the SparkPlanMeta where has the first input_file_xxx expression
// value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan)
val resultOps = mutable.LinkedHashMap[PlanMeta, ArrayBuffer[PlanMeta]]()
recursivelyResolve(plan, None, resultOps)

// If we've found some chains, we should prevent the transition.
resultOps.foreach { item =>
item._2.foreach(p => p.inputFilePreventsRunningOnGpu())
resultOps.foreach { case (_, metas) =>
metas.foreach(_.willNotWorkOnGpu("GPU plans may get incorrect file name" +
", or file start or file length from a CPU scan"))
}
}

Expand All @@ -54,39 +51,51 @@ object InputFileBlockRule {
* @param key the SparkPlanMeta with the first input_file_xxx
* @param resultOps the found SparkPlan chain
*/
private def recursivelyResolve(
plan: SparkPlanMeta[SparkPlan],
key: Option[SparkPlanMeta[SparkPlan]],
resultOps: LinkedHashMap[SparkPlanMeta[SparkPlan],
ArrayBuffer[SparkPlanMeta[SparkPlan]]]): Unit = {

private def recursivelyResolve(plan: PlanMeta, key: Option[PlanMeta],
resultOps: mutable.LinkedHashMap[PlanMeta, ArrayBuffer[PlanMeta]]): Unit = {
plan.wrapped match {
case _: ShuffleExchangeExec => // Exchange will invalid the input_file_xxx
case _: ShuffleExchangeLike => // Exchange will invalid the input_file_xxx
key.map(p => resultOps.remove(p)) // Remove the chain from Map
plan.childPlans.foreach(p => recursivelyResolve(p, None, resultOps))
case _: FileSourceScanExec | _: BatchScanExec =>
if (plan.canThisBeReplaced) { // FileScan can be replaced
key.map(p => resultOps.remove(p)) // Remove the chain from Map
}
case _: BroadcastExchangeLike =>
// noop: Don't go any further, the file info cannot come from a broadcast.
case _: LeafExecNode => // We've reached the LeafNode but without any FileScan
key.map(p => resultOps.remove(p)) // Remove the chain from Map
case _ =>
val newKey = if (key.isDefined) {
// The node is in the middle of chain [SparkPlan with input_file_xxx, FileScan)
resultOps.getOrElseUpdate(key.get, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan
resultOps.getOrElseUpdate(key.get, new ArrayBuffer[PlanMeta]) += plan
key
} else { // There is no parent Node who has input_file_xxx
if (checkHasInputFileExpressions(plan.wrapped)) {
// Current node has input_file_xxx. Mark it as the first Node with input_file_xxx
resultOps.getOrElseUpdate(plan, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan
} else { // There is no parent node who has input_file_xxx
if (hasInputFileExpression(plan.wrapped)) {
// Current node has input_file_xxx. Mark it as the first node with input_file_xxx
resultOps.getOrElseUpdate(plan, new ArrayBuffer[PlanMeta]) += plan
Some(plan)
} else {
None
}
}

plan.childPlans.foreach(p => recursivelyResolve(p, newKey, resultOps))
}
}

private def hasInputFileExpression(expr: Expression): Boolean = expr match {
case _: InputFileName => true
case _: InputFileBlockStart => true
case _: InputFileBlockLength => true
case _: GpuInputFileName => true
case _: GpuInputFileBlockStart => true
case _: GpuInputFileBlockLength => true
case e => e.children.exists(hasInputFileExpression)
}

/** Whether a plan has any InputFile{Name, BlockStart, BlockLength} expression. */
def hasInputFileExpression(plan: SparkPlan): Boolean = {
plan.expressions.exists(hasInputFileExpression)
}

}
27 changes: 19 additions & 8 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.execution.python.AggregateInPandasExec
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinMetaBase, GpuBroadcastNestedLoopJoinMetaBase}
import org.apache.spark.sql.types.DataType

trait DataFromReplacementRule {
Expand Down Expand Up @@ -170,13 +172,6 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
childRunnableCmds.foreach(_.recursiveSparkPlanRemoved())
}

final def inputFilePreventsRunningOnGpu(): Unit = {
if (canThisBeReplaced) {
willNotWorkOnGpu("Removed by InputFileBlockRule preventing plans " +
"[SparkPlan(with input_file_xxx), FileScan) running on GPU")
}
}

/**
* Call this to indicate that this should not be replaced with a GPU enabled version
* @param because why it should not be replaced.
Expand Down Expand Up @@ -672,6 +667,17 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
}
}

private def fixUpBroadcastJoins(): Unit = {
childPlans.foreach(_.fixUpBroadcastJoins())
wrapped match {
case _: BroadcastHashJoinExec =>
this.asInstanceOf[GpuBroadcastHashJoinMetaBase].checkTagForBuildSide()
case _: BroadcastNestedLoopJoinExec =>
this.asInstanceOf[GpuBroadcastNestedLoopJoinMetaBase].checkTagForBuildSide()
case _ => // noop
}
}

/**
* Run rules that happen for the entire tree after it has been tagged initially.
*/
Expand All @@ -693,7 +699,7 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
// So input_file_xxx in the following GPU operators will get empty value.
// InputFileBlockRule is to prevent the SparkPlans
// [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU
InputFileBlockRule.apply(this.asInstanceOf[SparkPlanMeta[SparkPlan]])
InputFileBlockRule(this.asInstanceOf[SparkPlanMeta[SparkPlan]])

// 2) For shuffles, avoid replacing the shuffle if the child is not going to be replaced.
fixUpExchangeOverhead()
Expand All @@ -702,6 +708,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
// WriteFilesExec is a new operator from Spark version 340,
// Did not extract a shim code for simplicity
tagChildAccordingToParent(this.asInstanceOf[SparkPlanMeta[SparkPlan]], "WriteFilesExec")

// 4) InputFileBlockRule may change the meta of broadcast join and its child plans,
// and this change may cause mismatch between the join and its build side
// BroadcastExchangeExec, leading to errors. Need to fix the mismatch.
fixUpBroadcastJoins()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ abstract class GpuBroadcastHashJoinMetaBase(
}
}

// Called in runAfterTagRules for a special post tagging for this broadcast join.
def checkTagForBuildSide(): Unit = {
val Seq(leftChild, rightChild) = childPlans
val buildSideMeta = buildSide match {
case GpuBuildLeft => leftChild
case GpuBuildRight => rightChild
}
// Check both of the conditions to avoid duplicate reason string.
if (!canThisBeReplaced && canBuildSideBeReplaced(buildSideMeta)) {
buildSideMeta.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
}
if (canThisBeReplaced && !canBuildSideBeReplaced(buildSideMeta)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}
}

def convertToGpu(): GpuExec
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,22 @@ abstract class GpuBroadcastNestedLoopJoinMetaBase(
"the BroadcastNestedLoopJoin this feeds is not on the GPU")
}
}

// Called in runAfterTagRules for a special post tagging for this broadcast join.
def checkTagForBuildSide(): Unit = {
val Seq(leftChild, rightChild) = childPlans
val buildSideMeta = gpuBuildSide match {
case GpuBuildLeft => leftChild
case GpuBuildRight => rightChild
}
// Check both of the conditions to avoid duplicate reason string.
if (!canThisBeReplaced && canBuildSideBeReplaced(buildSideMeta)) {
buildSideMeta.willNotWorkOnGpu("the BroadcastNestedLoopJoin this feeds is not on the GPU")
}
if (canThisBeReplaced && !canBuildSideBeReplaced(buildSideMeta)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}
}
}

/**
Expand Down

0 comments on commit 9ed98c8

Please sign in to comment.