Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
winningsix committed Nov 28, 2023
1 parent 78e5804 commit 2ad72c8
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.python.AggregateInPandasExec
import org.apache.spark.sql.rapids.TimeZoneDB
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinMetaBase, GpuBroadcastNestedLoopJoinMetaBase}
import org.apache.spark.sql.types.{ArrayType, DataType, DateType, MapType, StringType, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.types.{ArrayType, DataType, DateType, MapType, StringType, StructType}

trait DataFromReplacementRule {
val operationName: String
Expand Down Expand Up @@ -1150,16 +1150,12 @@ abstract class BaseExprMeta[INPUT <: Expression](
}

// Mostly based on Spark's existing [[Cast.needsTimeZone]] method. Two changes are made:
// 1. Backport commit https://github.com/apache/spark/pull/40524 merged since Spark 3.5
// 2. Existing `needsTimezone``` doesn't consider complex types to string which is timezone
// 1. Override the date-related cases based on https://github.com/apache/spark/pull/40524
// 2. The existing `needsTimeZone` doesn't consider casts from complex types to string,
// which are timezone related (incl. struct/map/list to string).
private[this] def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match {
case (StringType, TimestampType) => true
case (TimestampType, StringType) => true
case (DateType, TimestampType) => true
case (TimestampType, DateType) => true
case (TimestampType, TimestampNTZType) => true
case (TimestampNTZType, TimestampType) => true
case (StringType, DateType) => false
case (DateType, StringType) => false
case (ArrayType(fromType, _), StringType) => needsTimeZone(fromType, to)
case (MapType(fromKey, fromValue, _), StringType) =>
needsTimeZone(fromKey, to) || needsTimeZone(fromValue, to)
Expand All @@ -1168,16 +1164,9 @@ abstract class BaseExprMeta[INPUT <: Expression](
case fromField =>
needsTimeZone(fromField.dataType, to)
}
case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType)
case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).exists {
case (fromField, toField) =>
needsTimeZone(fromField.dataType, toField.dataType)
}
case _ => false
// Avoid copying the full implementation here; otherwise a shim would be needed for
// TimestampNTZ, which exists only since Spark 3.4.0.
case _ => Cast.needsTimeZone(from, to)
}

// Level 3 timezone checking flag, need to override to true when supports timezone in functions
Expand Down

0 comments on commit 2ad72c8

Please sign in to comment.