Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the broadcast joins issues caused by InputFileBlockRule[databricks] #9673

Merged
merged 8 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,3 +1026,74 @@ def do_join(spark):
right = spark.read.parquet(data_path2)
return right.filter("cast(id2 as bigint) % 3 = 4").join(left, left.id == right.id, "inner")
assert_gpu_and_cpu_are_equal_collect(do_join, bloom_filter_confs)


@ignore_order(local=True)
@allow_non_gpu("ProjectExec", "FilterExec", "BroadcastHashJoinExec", "ColumnarToRowExec", "BroadcastExchangeExec")
@pytest.mark.parametrize("disable_build", [True, False])
def test_broadcast_hash_join_fix_fallback_by_inputfile(spark_tmp_path, disable_build):
data_path_parquet = spark_tmp_path + "/parquet"
data_path_orc = spark_tmp_path + "/orc"
# The smaller one (orc) will be the build side (a broadcast)
with_cpu_session(lambda spark: spark.range(100).write.orc(data_path_orc))
with_cpu_session(lambda spark: spark.range(10000).withColumn("id2", col("id") + 10)
.write.parquet(data_path_parquet))
def do_join(spark):
left = spark.read.parquet(data_path_parquet)
right = spark.read.orc(data_path_orc)
return left.join(broadcast(right), "id", "inner")\
.selectExpr("*", "input_file_block_length()")

if disable_build:
# To reproduce the error
# '''
# java.lang.IllegalStateException: the broadcast must be on the GPU too
# at com.nvidia.spark.rapids.shims.GpuBroadcastJoinMeta.verifyBuildSideWasReplaced...
# '''
scan_name = 'OrcScan'
else:
# An additional case that the exec contains the input file expression is not disabled
# by InputFileBlockRule mistakenly. When the stream side scan runs on CPU, but the
# build side scan runs on GPU, the InputFileBlockRule will not put the exec on
# CPU, leading to wrong output.
scan_name = 'ParquetScan'
assert_gpu_and_cpu_are_equal_collect(
do_join,
conf={"spark.sql.autoBroadcastJoinThreshold": "10M",
"spark.sql.sources.useV1SourceList": "",
"spark.rapids.sql.input." + scan_name: False})


@ignore_order(local=True)
@allow_non_gpu("ProjectExec", "BroadcastNestedLoopJoinExec", "ColumnarToRowExec", "BroadcastExchangeExec")
@pytest.mark.parametrize("disable_build", [True, False])
def test_broadcast_nested_join_fix_fallback_by_inputfile(spark_tmp_path, disable_build):
data_path_parquet = spark_tmp_path + "/parquet"
data_path_orc = spark_tmp_path + "/orc"
# The smaller one (orc) will be the build side (a broadcast)
with_cpu_session(lambda spark: spark.range(50).write.orc(data_path_orc))
with_cpu_session(lambda spark: spark.range(500).withColumn("id2", col("id") + 10)
.write.parquet(data_path_parquet))
def do_join(spark):
left = spark.read.parquet(data_path_parquet)
right = spark.read.orc(data_path_orc)
return left.crossJoin(broadcast(right)).selectExpr("*", "input_file_block_length()")

if disable_build:
# To reproduce the error
# '''
# java.lang.IllegalStateException: the broadcast must be on the GPU too
# at com.nvidia.spark.rapids.shims.GpuBroadcastJoinMeta.verifyBuildSideWasReplaced...
# '''
scan_name = 'OrcScan'
else:
# An additional case that the exec contains the input file expression is not disabled
# by InputFileBlockRule mistakenly. When the stream side scan runs on CPU, but the
# build side scan runs on GPU, the InputFileBlockRule will not put the exec on
# CPU, leading to wrong output.
scan_name = 'ParquetScan'
assert_gpu_and_cpu_are_equal_collect(
do_join,
conf={"spark.sql.autoBroadcastJoinThreshold": "-1",
"spark.sql.sources.useV1SourceList": "",
"spark.rapids.sql.input." + scan_name: False})
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import com.nvidia.spark.rapids.shims.{GpuBatchScanExec, SparkShimImpl}

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
Expand All @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedC
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExecBase, DropTableExec, ShowTablesExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, Exchange, ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashedRelationBroadcastMode}
import org.apache.spark.sql.rapids.{GpuDataSourceScanExec, GpuFileSourceScanExec, GpuInputFileBlockLength, GpuInputFileBlockStart, GpuInputFileName, GpuShuffleEnv, GpuTaskMetrics}
import org.apache.spark.sql.rapids.{GpuDataSourceScanExec, GpuFileSourceScanExec, GpuShuffleEnv, GpuTaskMetrics}
import org.apache.spark.sql.rapids.execution.{ExchangeMappingCache, GpuBroadcastExchangeExec, GpuBroadcastExchangeExecBase, GpuBroadcastToRowExec, GpuCustomShuffleReaderExec, GpuHashJoin, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -329,30 +329,16 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
case _ => false
}



/**
* Because we cannot change the executors in spark itself we need to try and account for
* the ones that might have issues with coalesce here.
*/
private def disableCoalesceUntilInput(plan: SparkPlan): Boolean = {
plan.expressions.exists(GpuTransitionOverrides.checkHasInputFileExpressions)
}

private def disableScanUntilInput(exec: Expression): Boolean = {
exec match {
case _: InputFileName => true
case _: InputFileBlockStart => true
case _: InputFileBlockLength => true
case _: GpuInputFileName => true
case _: GpuInputFileBlockStart => true
case _: GpuInputFileBlockLength => true
case e => e.children.exists(disableScanUntilInput)
}
InputFileBlockRule.hasInputFileExpression(plan)
}

private def disableScanUntilInput(plan: SparkPlan): Boolean = {
plan.expressions.exists(disableScanUntilInput)
InputFileBlockRule.hasInputFileExpression(plan)
}

// This walks from the output to the input to look for any uses of InputFileName,
Expand Down Expand Up @@ -841,15 +827,4 @@ object GpuTransitionOverrides {
}
}

/**
* Check the Expression is or has Input File expressions.
* @param exec expression to check
* @return true or false
*/
def checkHasInputFileExpressions(exec: Expression): Boolean = exec match {
case _: InputFileName => true
case _: InputFileBlockStart => true
case _: InputFileBlockLength => true
case e => e.children.exists(checkHasInputFileExpressions)
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,36 +15,33 @@
*/
package com.nvidia.spark.rapids

import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.{GpuInputFileBlockLength, GpuInputFileBlockStart, GpuInputFileName}

/**
* InputFileBlockRule is to prevent the SparkPlans
* [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU
*
* See https://github.com/NVIDIA/spark-rapids/issues/3333
* A rule prevents the plans [SparkPlan (with first input_file_xxx expression), FileScan)
* from running on GPU.
* For more details, please go to https://github.com/NVIDIA/spark-rapids/issues/3333.
*/
object InputFileBlockRule {
private type PlanMeta = SparkPlanMeta[SparkPlan]

private def checkHasInputFileExpressions(plan: SparkPlan): Boolean = {
plan.expressions.exists(GpuTransitionOverrides.checkHasInputFileExpressions)
}

// Apply the rule on SparkPlanMeta
def apply(plan: SparkPlanMeta[SparkPlan]) = {
/**
* key: the SparkPlanMeta where has the first input_file_xxx expression
* value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan)
*/
val resultOps = LinkedHashMap[SparkPlanMeta[SparkPlan], ArrayBuffer[SparkPlanMeta[SparkPlan]]]()
def apply(plan: PlanMeta): Unit = {
// key: the SparkPlanMeta where has the first input_file_xxx expression
// value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan)
val resultOps = mutable.LinkedHashMap[PlanMeta, ArrayBuffer[PlanMeta]]()
recursivelyResolve(plan, None, resultOps)

// If we've found some chains, we should prevent the transition.
resultOps.foreach { item =>
item._2.foreach(p => p.inputFilePreventsRunningOnGpu())
resultOps.foreach { case (_, metas) =>
metas.foreach(_.willNotWorkOnGpu("GPU plans may get incorrect file name" +
", or file start or file length from a CPU scan"))
}
}

Expand All @@ -54,39 +51,51 @@ object InputFileBlockRule {
* @param key the SparkPlanMeta with the first input_file_xxx
* @param resultOps the found SparkPlan chain
*/
private def recursivelyResolve(
plan: SparkPlanMeta[SparkPlan],
key: Option[SparkPlanMeta[SparkPlan]],
resultOps: LinkedHashMap[SparkPlanMeta[SparkPlan],
ArrayBuffer[SparkPlanMeta[SparkPlan]]]): Unit = {

private def recursivelyResolve(plan: PlanMeta, key: Option[PlanMeta],
resultOps: mutable.LinkedHashMap[PlanMeta, ArrayBuffer[PlanMeta]]): Unit = {
plan.wrapped match {
case _: ShuffleExchangeExec => // Exchange will invalid the input_file_xxx
case _: ShuffleExchangeLike => // Exchange will invalid the input_file_xxx
key.map(p => resultOps.remove(p)) // Remove the chain from Map
plan.childPlans.foreach(p => recursivelyResolve(p, None, resultOps))
case _: FileSourceScanExec | _: BatchScanExec =>
if (plan.canThisBeReplaced) { // FileScan can be replaced
key.map(p => resultOps.remove(p)) // Remove the chain from Map
}
case _: BroadcastExchangeLike =>
// noop: Don't go any further, the file info cannot come from a broadcast.
case _: LeafExecNode => // We've reached the LeafNode but without any FileScan
key.map(p => resultOps.remove(p)) // Remove the chain from Map
case _ =>
val newKey = if (key.isDefined) {
// The node is in the middle of chain [SparkPlan with input_file_xxx, FileScan)
resultOps.getOrElseUpdate(key.get, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan
resultOps.getOrElseUpdate(key.get, new ArrayBuffer[PlanMeta]) += plan
key
} else { // There is no parent Node who has input_file_xxx
if (checkHasInputFileExpressions(plan.wrapped)) {
// Current node has input_file_xxx. Mark it as the first Node with input_file_xxx
resultOps.getOrElseUpdate(plan, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan
} else { // There is no parent node who has input_file_xxx
if (hasInputFileExpression(plan.wrapped)) {
// Current node has input_file_xxx. Mark it as the first node with input_file_xxx
resultOps.getOrElseUpdate(plan, new ArrayBuffer[PlanMeta]) += plan
Some(plan)
} else {
None
}
}

plan.childPlans.foreach(p => recursivelyResolve(p, newKey, resultOps))
}
}

private def hasInputFileExpression(expr: Expression): Boolean = expr match {
case _: InputFileName => true
case _: InputFileBlockStart => true
case _: InputFileBlockLength => true
case _: GpuInputFileName => true
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm why do we still need to return true given it's already converted to Gpu case? Given the reason mentioned above is GPU plans may get incorrect file name or file start or file length from a CPU scan.

Copy link
Collaborator Author

@firestarman firestarman Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be used for two stages during the overiding process. The stage after inserting transitions for row and column may get a InputFileName or a GpuInputFileName.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Concerning this issue, we will never get a GpuInputFileName since plan conversion does not happen.

case _: GpuInputFileBlockStart => true
case _: GpuInputFileBlockLength => true
case e => e.children.exists(hasInputFileExpression)
}

/** Whether a plan has any InputFile{Name, BlockStart, BlockLength} expression. */
def hasInputFileExpression(plan: SparkPlan): Boolean = {
plan.expressions.exists(hasInputFileExpression)
}

}
27 changes: 19 additions & 8 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.execution.python.AggregateInPandasExec
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinMetaBase, GpuBroadcastNestedLoopJoinMetaBase}
import org.apache.spark.sql.types.DataType

trait DataFromReplacementRule {
Expand Down Expand Up @@ -170,13 +172,6 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
childRunnableCmds.foreach(_.recursiveSparkPlanRemoved())
}

final def inputFilePreventsRunningOnGpu(): Unit = {
if (canThisBeReplaced) {
willNotWorkOnGpu("Removed by InputFileBlockRule preventing plans " +
"[SparkPlan(with input_file_xxx), FileScan) running on GPU")
}
}

/**
* Call this to indicate that this should not be replaced with a GPU enabled version
* @param because why it should not be replaced.
Expand Down Expand Up @@ -672,6 +667,17 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
}
}

private def fixUpBroadcastJoins(): Unit = {
childPlans.foreach(_.fixUpBroadcastJoins())
wrapped match {
case _: BroadcastHashJoinExec =>
this.asInstanceOf[GpuBroadcastHashJoinMetaBase].checkTagForBuildSide()
case _: BroadcastNestedLoopJoinExec =>
this.asInstanceOf[GpuBroadcastNestedLoopJoinMetaBase].checkTagForBuildSide()
case _ => // noop
}
}

/**
* Run rules that happen for the entire tree after it has been tagged initially.
*/
Expand All @@ -693,7 +699,7 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
// So input_file_xxx in the following GPU operators will get empty value.
// InputFileBlockRule is to prevent the SparkPlans
// [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU
InputFileBlockRule.apply(this.asInstanceOf[SparkPlanMeta[SparkPlan]])
InputFileBlockRule(this.asInstanceOf[SparkPlanMeta[SparkPlan]])

// 2) For shuffles, avoid replacing the shuffle if the child is not going to be replaced.
fixUpExchangeOverhead()
Expand All @@ -702,6 +708,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
// WriteFilesExec is a new operator from Spark version 340,
// Did not extract a shim code for simplicity
tagChildAccordingToParent(this.asInstanceOf[SparkPlanMeta[SparkPlan]], "WriteFilesExec")

// 4) InputFileBlockRule may change the meta of broadcast join and its child plans,
// and this change may cause mismatch between the join and its build side
// BroadcastExchangeExec, leading to errors. Need to fix the mismatch.
fixUpBroadcastJoins()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ abstract class GpuBroadcastHashJoinMetaBase(
}
}

// Called in runAfterTagRules for a special post tagging for this broadcast join.
def checkTagForBuildSide(): Unit = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make more sense to move this into GpuBroadcastJoinMeta?

Copy link
Collaborator Author

@firestarman firestarman Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not do that because there are 4 shims for GpuBroadcastJoinMeta, which means I need to duplicate this code 4 times. The current option looks much simpler, only two times.

val Seq(leftChild, rightChild) = childPlans
val buildSideMeta = buildSide match {
case GpuBuildLeft => leftChild
case GpuBuildRight => rightChild
}
// Check both of the conditions to avoid duplicate reason string.
if (!canThisBeReplaced && canBuildSideBeReplaced(buildSideMeta)) {
buildSideMeta.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
}
if (canThisBeReplaced && !canBuildSideBeReplaced(buildSideMeta)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}
}

def convertToGpu(): GpuExec
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ abstract class GpuBroadcastNestedLoopJoinMetaBase(
"the BroadcastNestedLoopJoin this feeds is not on the GPU")
}
}

// Called in runAfterTagRules for a special post tagging for this broadcast join.
def checkTagForBuildSide(): Unit = {
val Seq(leftChild, rightChild) = childPlans
val buildSideMeta = gpuBuildSide match {
case GpuBuildLeft => leftChild
case GpuBuildRight => rightChild
}
// Check both of the conditions to avoid duplicate reason string.
if (!canThisBeReplaced && canBuildSideBeReplaced(buildSideMeta)) {
buildSideMeta.willNotWorkOnGpu("the BroadcastNestedLoopJoin this feeds is not on the GPU")
}
if (canThisBeReplaced && !canBuildSideBeReplaced(buildSideMeta)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}
}
}

/**
Expand Down