diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 21e4a7de81fc..2a6c77f9631e 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -110,6 +110,12 @@ class JoinFuzzer { numGroups(_numGroups) {} }; + static core::PlanNodePtr tryFlipJoinSides(const core::HashJoinNode& joinNode); + static core::PlanNodePtr tryFlipJoinSides( + const core::MergeJoinNode& joinNode); + static core::PlanNodePtr tryFlipJoinSides( + const core::NestedLoopJoinNode& joinNode); + private: static VectorFuzzer::Options getFuzzerOptions() { VectorFuzzer::Options opts; @@ -135,6 +141,10 @@ class JoinFuzzer { // Randomly pick a join type to test. core::JoinType pickJoinType(); + template + static std::pair tryFlipJoinSidesHelper( + const TNode& joinNode); + // Makes the query plan with default settings in JoinFuzzer and value inputs // for both probe and build sides. // @@ -605,67 +615,100 @@ std::optional tryFlipJoinType(core::JoinType joinType) { } } +template +std::pair +JoinFuzzer::tryFlipJoinSidesHelper(const TNode& joinNode) { + core::PlanNodePtr left = joinNode.sources()[0]; + core::PlanNodePtr right = joinNode.sources()[1]; + if (auto leftJoinInput = + std::dynamic_pointer_cast(joinNode.sources()[0])) { + left = JoinFuzzer::tryFlipJoinSides(*leftJoinInput); + } + if (auto rightJoinInput = + std::dynamic_pointer_cast(joinNode.sources()[1])) { + right = JoinFuzzer::tryFlipJoinSides(*rightJoinInput); + } + return make_pair(left, right); +} + // Returns a plan with flipped join sides of the input hash join node. If the -// join type doesn't allow flipping, returns a nullptr. -core::PlanNodePtr tryFlipJoinSides(const core::HashJoinNode& joinNode) { +// inputs of the join node are other hash join nodes, recursively flip the join +// sides of those join nodes as well. If the join type doesn't allow flipping, +// returns a nullptr. +core::PlanNodePtr JoinFuzzer::tryFlipJoinSides( + const core::HashJoinNode& joinNode) { // Null-aware right semi project join doesn't support filter. if (joinNode.filter() && joinNode.joinType() == core::JoinType::kLeftSemiProject && joinNode.isNullAware()) { return nullptr; } + auto flippedJoinType = tryFlipJoinType(joinNode.joinType()); - if (!flippedJoinType.has_value()) { + if (!flippedJoinType) { return nullptr; } + auto [left, right] = + JoinFuzzer::tryFlipJoinSidesHelper(joinNode); return std::make_shared( joinNode.id(), - flippedJoinType.value(), + *flippedJoinType, joinNode.isNullAware(), joinNode.rightKeys(), joinNode.leftKeys(), joinNode.filter(), - joinNode.sources()[1], - joinNode.sources()[0], + right, + left, joinNode.outputType()); } // Returns a plan with flipped join sides of the input merge join node. If the +// inputs of the join node are other merge join nodes, recursively flip the join +// sides of those join nodes as well. If the // join type doesn't allow flipping, returns a nullptr. -core::PlanNodePtr tryFlipJoinSides(const core::MergeJoinNode& joinNode) { +core::PlanNodePtr JoinFuzzer::tryFlipJoinSides( + const core::MergeJoinNode& joinNode) { // Merge join only supports inner and left join, so only inner join can be // flipped. if (joinNode.joinType() != core::JoinType::kInner) { return nullptr; } - auto flippedJoinType = core::JoinType::kInner; + + auto [left, right] = + JoinFuzzer::tryFlipJoinSidesHelper(joinNode); return std::make_shared( joinNode.id(), - flippedJoinType, + core::JoinType::kInner, joinNode.rightKeys(), joinNode.leftKeys(), joinNode.filter(), - joinNode.sources()[1], - joinNode.sources()[0], + right, + left, joinNode.outputType()); } // Returns a plan with flipped join sides of the input nested loop join node. If -// the join type doesn't allow flipping, returns a nullptr. -core::PlanNodePtr tryFlipJoinSides(const core::NestedLoopJoinNode& joinNode) { +// the inputs of the join node are other nested loop join nodes, recursively +// flip the join sides of those join nodes as well. If the join type doesn't +// allow flipping, returns a nullptr. +core::PlanNodePtr JoinFuzzer::tryFlipJoinSides( + const core::NestedLoopJoinNode& joinNode) { auto flippedJoinType = tryFlipJoinType(joinNode.joinType()); - if (!flippedJoinType.has_value()) { + if (!flippedJoinType) { return nullptr; } + auto [left, right] = + JoinFuzzer::tryFlipJoinSidesHelper(joinNode); + return std::make_shared( joinNode.id(), flippedJoinType.value(), joinNode.joinCondition(), - joinNode.sources()[1], - joinNode.sources()[0], + right, + left, joinNode.outputType()); } @@ -819,7 +862,7 @@ void addFlippedJoinPlan( int32_t numGroups = 0) { auto joinNode = std::dynamic_pointer_cast(plan); VELOX_CHECK_NOT_NULL(joinNode); - if (auto flippedPlan = tryFlipJoinSides(*joinNode)) { + if (auto flippedPlan = JoinFuzzer::tryFlipJoinSides(*joinNode)) { plans.push_back(JoinFuzzer::PlanWithSplits{ flippedPlan, probeScanId,