Skip to content

Commit

Permalink
Added SqlGeneratorDatabricks with DatabricksDialect and databricks dr…
Browse files Browse the repository at this point in the history
…iver processing
  • Loading branch information
ValeriiKhalimendik committed Jan 16, 2025
1 parent c2ae2e4 commit 065561b
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.sql

import org.apache.spark.sql.jdbc.JdbcDialects
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.offset.OffsetValue
import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}
import za.co.absa.pramen.core.sql.dialects.DatabricksDialect

import java.time.format.DateTimeFormatter
import java.time.{LocalDate, LocalDateTime}

/**
  * Companion of the Databricks SQL generator. It owns the registration of
  * [[DatabricksDialect]] with Spark's JDBC dialect registry.
  */
object SqlGeneratorDatabricks {
  private val logger = LoggerFactory.getLogger(this.getClass)

  /**
    * Registers the Databricks dialect so that Spark is able to handle data that comes from
    * Databricks JDBC drivers.
    *
    * This is a `lazy val`: the registration body is evaluated on first access, and the
    * resulting `true` is what every access returns.
    */
  lazy val registerDialect: Boolean = {
    logger.info("Registering Databricks dialect...")
    JdbcDialects.registerDialect(DatabricksDialect)
    true
  }
}

/**
  * SQL generator used for the Databricks JDBC driver.
  *
  * Identifiers are escaped with backticks. Date literals for DATE and DATETIME information date
  * columns are rendered as `to_date('yyyy-MM-dd')`, while STRING and NUMBER columns use the
  * application date format taken from the configuration.
  *
  * @param sqlConfig the SQL configuration; this class reads `infoDateType`, `dateFormatApp`
  *                  and `serverTimeZone` from it.
  */
class SqlGeneratorDatabricks(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig) {
  private val dateFormatterApp = DateTimeFormatter.ofPattern(sqlConfig.dateFormatApp)

  override val beginEndEscapeChars: (Char, Char) = ('`', '`')

  // Accessing the lazy val triggers the dialect registration when a generator is constructed.
  SqlGeneratorDatabricks.registerDialect

  /** Wraps the given SQL (or table name) in parentheses and gives it the alias `tbl`. */
  override def getDtable(sql: String): String = "(" + sql + ") tbl"

  /** Builds a query counting all records of the table. */
  override def getCountQuery(tableName: String): String = {
    val table = escape(tableName)
    s"SELECT COUNT(*) FROM $table"
  }

  /** Builds a query counting the records whose information date is within the inclusive range. */
  override def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = {
    val condition = getWhere(infoDateBegin, infoDateEnd)
    val table = escape(tableName)
    s"SELECT COUNT(*) FROM $table WHERE $condition"
  }

  /** Builds a query counting the records returned by the given SQL, used as a subquery. */
  override def getCountQueryForSql(filteredSql: String): String =
    "SELECT COUNT(*) FROM (" + filteredSql + ") AS query"

  /** Builds a query selecting the requested columns of the table, with an optional LIMIT. */
  override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
    val projection = columnExpr(columns)
    val table = escape(tableName)
    val limitClause = getLimit(limit)
    s"SELECT $projection FROM $table$limitClause"
  }

  /**
    * Builds a query selecting the requested columns for the inclusive information date range,
    * with an optional LIMIT.
    */
  override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
    val condition = getWhere(infoDateBegin, infoDateEnd)
    val projection = columnExpr(columns)
    val table = escape(tableName)
    val limitClause = getLimit(limit)
    s"SELECT $projection FROM $table WHERE $condition$limitClause"
  }

  /**
    * Builds the information date condition. When both literals are equal, a single equality
    * is produced; otherwise an inclusive `>= ... AND <= ...` range is produced.
    */
  override def getWhere(dateBegin: LocalDate, dateEnd: LocalDate): String = {
    val fromLiteral = getDateLiteral(dateBegin)
    val toLiteral = getDateLiteral(dateEnd)

    // A DATETIME column is cast to DATE so that it is compared against date literals.
    val dateColumn = sqlConfig.infoDateType match {
      case SqlColumnType.DATETIME => s"CAST($infoDateColumn AS DATE)"
      case _                      => infoDateColumn
    }

    if (fromLiteral != toLiteral) {
      s"$dateColumn >= $fromLiteral AND $dateColumn <= $toLiteral"
    } else {
      s"$dateColumn = $fromLiteral"
    }
  }

  /** Renders the date as an SQL literal that matches the configured information date type. */
  override def getDateLiteral(date: LocalDate): String = {
    sqlConfig.infoDateType match {
      case SqlColumnType.DATE | SqlColumnType.DATETIME =>
        val isoDate = DateTimeFormatter.ISO_LOCAL_DATE.format(date)
        s"to_date('$isoDate')"
      case SqlColumnType.STRING =>
        "'" + dateFormatterApp.format(date) + "'"
      case SqlColumnType.NUMBER =>
        dateFormatterApp.format(date)
    }
  }

  /**
    * Builds a `column condition literal` expression for the given offset value.
    * Timestamps are converted to the configured server time zone and quoted, strings are
    * quoted, and integral values are emitted as is.
    */
  override def getOffsetWhereCondition(column: String, condition: String, offset: OffsetValue): String = {
    val literal = offset match {
      case OffsetValue.DateTimeValue(ts) =>
        val serverTime = LocalDateTime.ofInstant(ts, sqlConfig.serverTimeZone)
        "'" + timestampGenericDbFormatter.format(serverTime) + "'"
      case OffsetValue.IntegralValue(value) =>
        s"$value"
      case OffsetValue.StringValue(value) =>
        s"'$value'"
    }

    s"$column $condition $literal"
  }

  /** Returns the ` LIMIT n` suffix, or an empty string when no limit is requested. */
  private def getLimit(limit: Option[Int]): String =
    limit.fold("")(n => s" LIMIT $n")
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ object SqlGeneratorLoader {
case "com.simba.spark.jdbc.Driver" => new SqlGeneratorHive(sqlConfig)
case "org.hsqldb.jdbc.JDBCDriver" => new SqlGeneratorHsqlDb(sqlConfig)
case "com.ibm.db2.jcc.DB2Driver" => new SqlGeneratorDb2(sqlConfig)
case "com.databricks.client.jdbc.Driver" => new SqlGeneratorDatabricks(sqlConfig)
case d =>
log.warn(s"Unsupported JDBC driver: '$d'. Trying to use a generic SQL generator.")
new SqlGeneratorGeneric(sqlConfig)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.sql.dialects

import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.{DataType, MetadataBuilder}
import org.slf4j.LoggerFactory

/**
  * A Spark JDBC dialect for connections whose URL starts with `jdbc:databricks`.
  *
  * This is required for Spark to be able to handle data that comes from Databricks JDBC drivers.
  */
object DatabricksDialect extends JdbcDialect {
  private val logger = LoggerFactory.getLogger(this.getClass)

  /** Accepts every JDBC URL that begins with the `jdbc:databricks` prefix. */
  override def canHandle(url: String): Boolean = {
    url.startsWith("jdbc:databricks")
  }

  /**
    * Surrounds each dot-separated part of the identifier with backticks and joins the parts
    * back together with dots.
    */
  override def quoteIdentifier(colName: String): String = {
    val quotedParts = for (part <- colName.split('.')) yield "`" + part + "`"
    quotedParts.mkString(".")
  }

  /** Delegates to the type mapping inherited from [[JdbcDialect]]. */
  override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    super.getCatalystType(sqlType, typeName, size, md)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.core.tests.sql

import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.api.offset.OffsetValue
import za.co.absa.pramen.api.sql.{QuotingPolicy, SqlColumnType, SqlGenerator, SqlGeneratorBase}
import za.co.absa.pramen.core.mocks.DummySqlConfigFactory

import java.time.{Instant, LocalDate}

/**
  * Checks the SQL text produced by the generator that `SqlGeneratorLoader.getSqlGenerator`
  * returns for the Databricks JDBC driver class name.
  */
class SqlGeneratorDatabricksSuite extends AnyWordSpec {

  import za.co.absa.pramen.core.sql.SqlGeneratorLoader._

  // One configuration per information date type, plus one that always quotes identifiers.
  private val configDate = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATE, infoDateColumn = "D")
  private val configAlwaysQuote = DummySqlConfigFactory.getDummyConfig(infoDateColumn = "Info date", identifierQuotingPolicy = QuotingPolicy.Always)
  private val configDateTime = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATETIME, infoDateColumn = "D")
  private val configString = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.STRING, infoDateColumn = "D")
  private val configNumber = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.NUMBER, infoDateColumn = "D", dateFormatApp = "yyyyMMdd")
  private val selectedColumns = Seq("A", "D", "Column with spaces")

  private val dayFrom = LocalDate.of(2020, 8, 17)
  private val dayTo = LocalDate.of(2020, 8, 30)

  val driver = "com.databricks.client.jdbc.Driver"

  val gen: SqlGenerator = getSqlGenerator(driver, configDate)
  val genStr: SqlGenerator = getSqlGenerator(driver, configString)
  val genNum: SqlGenerator = getSqlGenerator(driver, configNumber)
  val genDateTime: SqlGenerator = getSqlGenerator(driver, configDateTime)
  val genEscaped: SqlGenerator = getSqlGenerator(driver, configAlwaysQuote)
  val genEscaped2: SqlGenerator = getSqlGenerator(driver, DummySqlConfigFactory.getDummyConfig(infoDateColumn = "`Info date`", identifierQuotingPolicy = QuotingPolicy.Auto))

  /** Invokes `getOffsetWhereCondition` on `gen` for the column named `offset`. */
  private def offsetCondition(condition: String, offset: OffsetValue): String =
    gen.asInstanceOf[SqlGeneratorBase].getOffsetWhereCondition("offset", condition, offset)

  "generate count queries without date ranges" in {
    val query = gen.getCountQuery("A")

    assert(query == "SELECT COUNT(*) FROM A")
  }

  "generate data queries without date ranges" in {
    val query = gen.getDataQuery("A", Nil, None)

    assert(query == "SELECT * FROM A")
  }

  "generate data queries when list of columns is specified" in {
    val query = genEscaped.getDataQuery("A", selectedColumns, None)

    assert(query == "SELECT `A`, `D`, `Column with spaces` FROM `A`")
  }

  "generate data queries with limit clause date ranges" in {
    val query = gen.getDataQuery("A", Nil, Some(100))

    assert(query == "SELECT * FROM A LIMIT 100")
  }

  "generate ranged count queries" when {
    "date is in DATE format" in {
      val singleDay = gen.getCountQuery("A", dayFrom, dayFrom)
      assert(singleDay == "SELECT COUNT(*) FROM A WHERE D = to_date('2020-08-17')")

      val range = gen.getCountQuery("A", dayFrom, dayTo)
      assert(range == "SELECT COUNT(*) FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')")
    }

    "date is in DATETIME format" in {
      val singleDay = genDateTime.getCountQuery("A", dayFrom, dayFrom)
      assert(singleDay == "SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')")

      val range = genDateTime.getCountQuery("A", dayFrom, dayTo)
      assert(range == "SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')")
    }

    "date is in STRING format" in {
      val singleDay = genStr.getCountQuery("A", dayFrom, dayFrom)
      assert(singleDay == "SELECT COUNT(*) FROM A WHERE D = '2020-08-17'")

      val range = genStr.getCountQuery("A", dayFrom, dayTo)
      assert(range == "SELECT COUNT(*) FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'")
    }

    "date is in NUMBER format" in {
      val singleDay = genNum.getCountQuery("A", dayFrom, dayFrom)
      assert(singleDay == "SELECT COUNT(*) FROM A WHERE D = 20200817")

      val range = genNum.getCountQuery("A", dayFrom, dayTo)
      assert(range == "SELECT COUNT(*) FROM A WHERE D >= 20200817 AND D <= 20200830")
    }

    "the table name and column name need to be escaped" in {
      val singleDay = genEscaped.getCountQuery("Input Table", dayFrom, dayFrom)
      assert(singleDay == "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')")

      val range = genEscaped.getCountQuery("Input Table", dayFrom, dayTo)
      assert(range == "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')")
    }

    "the table name and column name already escaped" in {
      val singleDay = genEscaped2.getCountQuery("Input Table", dayFrom, dayFrom)
      assert(singleDay == "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')")

      val range = genEscaped2.getCountQuery("Input Table", dayFrom, dayTo)
      assert(range == "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')")
    }
  }

  "generate ranged data queries" when {
    "date is in DATE format" in {
      val singleDay = gen.getDataQuery("A", dayFrom, dayFrom, Nil, None)
      assert(singleDay == "SELECT * FROM A WHERE D = to_date('2020-08-17')")

      val range = gen.getDataQuery("A", dayFrom, dayTo, Nil, None)
      assert(range == "SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')")
    }

    "date is in DATETIME format" in {
      val singleDay = genDateTime.getDataQuery("A", dayFrom, dayFrom, Nil, None)
      assert(singleDay == "SELECT * FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')")

      val range = genDateTime.getDataQuery("A", dayFrom, dayTo, Nil, None)
      assert(range == "SELECT * FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')")
    }

    "date is in STRING format" in {
      val singleDay = genStr.getDataQuery("A", dayFrom, dayFrom, Nil, None)
      assert(singleDay == "SELECT * FROM A WHERE D = '2020-08-17'")

      val range = genStr.getDataQuery("A", dayFrom, dayTo, Nil, None)
      assert(range == "SELECT * FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'")
    }

    "date is in NUMBER format" in {
      val singleDay = genNum.getDataQuery("A", dayFrom, dayFrom, Nil, None)
      assert(singleDay == "SELECT * FROM A WHERE D = 20200817")

      val range = genNum.getDataQuery("A", dayFrom, dayTo, Nil, None)
      assert(range == "SELECT * FROM A WHERE D >= 20200817 AND D <= 20200830")
    }

    "with limit records" in {
      val singleDay = gen.getDataQuery("A", dayFrom, dayFrom, Nil, Some(100))
      assert(singleDay == "SELECT * FROM A WHERE D = to_date('2020-08-17') LIMIT 100")

      val range = gen.getDataQuery("A", dayFrom, dayTo, Nil, Some(100))
      assert(range == "SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30') LIMIT 100")
    }
  }

  "getCountQueryForSql" should {
    "generate count queries for an SQL subquery" in {
      val query = gen.getCountQueryForSql("SELECT A FROM B")

      assert(query == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
    }
  }

  "getDtable" should {
    "return the original table when a table is provided" in {
      val dtable = gen.getDtable("A")

      assert(dtable == "(A) tbl")
    }

    "wrapped query without alias for SQL queries " in {
      val dtable = gen.getDtable("SELECT A FROM B")

      assert(dtable == "(SELECT A FROM B) tbl")
    }
  }

  "quote" should {
    "escape each subfields separately" in {
      val quoted = gen.quote("System User.`Table Name`")

      assert(quoted == "`System User`.`Table Name`")
    }
  }

  "getOffsetWhereCondition" should {
    "return the correct condition for integral offsets" in {
      val condition = offsetCondition("<", OffsetValue.IntegralValue(1))

      assert(condition == "offset < 1")
    }

    "return the correct condition for datetime offsets" in {
      val condition = offsetCondition(">", OffsetValue.DateTimeValue(Instant.ofEpochMilli(1727761000)))

      assert(condition == "offset > '1970-01-21 01:56:01.000'")
    }

    "return the correct condition for string offsets" in {
      val condition = offsetCondition(">=", OffsetValue.StringValue("AAA"))

      assert(condition == "offset >= 'AAA'")
    }
  }
}

0 comments on commit 065561b

Please sign in to comment.