-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added SqlGeneratorDatabricks with DatabricksDialect and databricks dr…
…iver processing
- Loading branch information
1 parent
c2ae2e4
commit 1e7b868
Showing
4 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
127 changes: 127 additions & 0 deletions
127
pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorDatabricks.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright 2022 ABSA Group Limited | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package za.co.absa.pramen.core.sql | ||
|
||
import org.apache.spark.sql.jdbc.JdbcDialects | ||
import org.slf4j.LoggerFactory | ||
import za.co.absa.pramen.api.offset.OffsetValue | ||
import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase} | ||
import za.co.absa.pramen.core.sql.dialects.DatabricksDialect | ||
|
||
import java.time.format.DateTimeFormatter | ||
import java.time.{LocalDate, LocalDateTime} | ||
|
||
object SqlGeneratorDatabricks { | ||
private val log = LoggerFactory.getLogger(this.getClass) | ||
|
||
/** | ||
* This is required for Spark to be able to handle data that comes from Databricks JDBC drivers | ||
*/ | ||
lazy val registerDialect: Boolean = { | ||
log.info(s"Registering Databricks dialect...") | ||
JdbcDialects.registerDialect(DatabricksDialect) | ||
true | ||
} | ||
} | ||
|
||
class SqlGeneratorDatabricks(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig) { | ||
private val dateFormatterApp = DateTimeFormatter.ofPattern(sqlConfig.dateFormatApp) | ||
|
||
override val beginEndEscapeChars: (Char, Char) = ('`', '`') | ||
|
||
SqlGeneratorDatabricks.registerDialect | ||
|
||
override def getDtable(sql: String): String = { | ||
s"($sql) tbl" | ||
} | ||
|
||
override def getCountQuery(tableName: String): String = { | ||
s"SELECT COUNT(*) FROM ${escape(tableName)}" | ||
} | ||
|
||
override def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = { | ||
val where = getWhere(infoDateBegin, infoDateEnd) | ||
s"SELECT COUNT(*) FROM ${escape(tableName)} WHERE $where" | ||
} | ||
|
||
override def getCountQueryForSql(filteredSql: String): String = { | ||
s"SELECT COUNT(*) FROM ($filteredSql) AS query" | ||
} | ||
|
||
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = { | ||
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}" | ||
} | ||
|
||
override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = { | ||
val where = getWhere(infoDateBegin, infoDateEnd) | ||
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit)}" | ||
} | ||
|
||
override def getWhere(dateBegin: LocalDate, dateEnd: LocalDate): String = { | ||
val dateBeginLit = getDateLiteral(dateBegin) | ||
val dateEndLit = getDateLiteral(dateEnd) | ||
|
||
val dateTypes: Array[SqlColumnType] = Array(SqlColumnType.DATETIME) | ||
|
||
val infoDateColumnAdjusted = | ||
if (dateTypes.contains(sqlConfig.infoDateType)) { | ||
s"CAST($infoDateColumn AS DATE)" | ||
} else { | ||
infoDateColumn | ||
} | ||
|
||
if (dateBeginLit == dateEndLit) { | ||
s"$infoDateColumnAdjusted = $dateBeginLit" | ||
} else { | ||
s"$infoDateColumnAdjusted >= $dateBeginLit AND $infoDateColumnAdjusted <= $dateEndLit" | ||
} | ||
} | ||
|
||
override def getDateLiteral(date: LocalDate): String = { | ||
sqlConfig.infoDateType match { | ||
case SqlColumnType.DATE => | ||
val dateStr = DateTimeFormatter.ISO_LOCAL_DATE.format(date) | ||
s"to_date('$dateStr')" | ||
case SqlColumnType.DATETIME => | ||
val dateStr = DateTimeFormatter.ISO_LOCAL_DATE.format(date) | ||
s"to_date('$dateStr')" | ||
case SqlColumnType.STRING => | ||
val dateStr = dateFormatterApp.format(date) | ||
s"'$dateStr'" | ||
case SqlColumnType.NUMBER => | ||
val dateStr = dateFormatterApp.format(date) | ||
s"$dateStr" | ||
} | ||
} | ||
|
||
override def getOffsetWhereCondition(column: String, condition: String, offset: OffsetValue): String = { | ||
offset match { | ||
case OffsetValue.DateTimeValue(ts) => | ||
val ldt = LocalDateTime.ofInstant(ts, sqlConfig.serverTimeZone) | ||
val tsLiteral = timestampGenericDbFormatter.format(ldt) | ||
s"$column $condition '$tsLiteral'" | ||
case OffsetValue.IntegralValue(value) => | ||
s"$column $condition $value" | ||
case OffsetValue.StringValue(value) => | ||
s"$column $condition '$value'" | ||
} | ||
} | ||
|
||
private def getLimit(limit: Option[Int]): String = { | ||
limit.map(n => s" LIMIT $n").getOrElse("") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
pramen/core/src/main/scala/za/co/absa/pramen/core/sql/dialects/DatabricksDialect.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright 2022 ABSA Group Limited | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package za.co.absa.pramen.core.sql.dialects | ||
|
||
import org.apache.spark.sql.jdbc.JdbcDialect | ||
import org.apache.spark.sql.types.{DataType, MetadataBuilder} | ||
import org.slf4j.LoggerFactory | ||
|
||
/** | ||
* This is required for Spark to be able to handle data that comes from Databricks JDBC drivers | ||
*/ | ||
object DatabricksDialect extends JdbcDialect { | ||
private val logger = LoggerFactory.getLogger(this.getClass) | ||
|
||
override def canHandle(url: String): Boolean = url.startsWith("jdbc:databricks") | ||
|
||
override def quoteIdentifier(colName: String): String = { | ||
colName.split('.').map(sub => s"`$sub`").mkString(".") | ||
} | ||
|
||
override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = | ||
super.getCatalystType(sqlType, typeName, size, md) | ||
} |
192 changes: 192 additions & 0 deletions
192
...en/core/src/test/scala/za/co/absa/pramen/core/tests/sql/SqlGeneratorDatabricksSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
/* | ||
* Copyright 2022 ABSA Group Limited | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package za.co.absa.pramen.core.tests.sql | ||
|
||
import org.scalatest.wordspec.AnyWordSpec | ||
import za.co.absa.pramen.api.offset.OffsetValue | ||
import za.co.absa.pramen.api.sql.{QuotingPolicy, SqlColumnType, SqlGenerator, SqlGeneratorBase} | ||
import za.co.absa.pramen.core.mocks.DummySqlConfigFactory | ||
|
||
import java.time.{Instant, LocalDate} | ||
|
||
class SqlGeneratorDatabricksSuite extends AnyWordSpec { | ||
|
||
import za.co.absa.pramen.core.sql.SqlGeneratorLoader._ | ||
|
||
private val sqlConfigDate = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATE, infoDateColumn = "D") | ||
private val sqlConfigEscape = DummySqlConfigFactory.getDummyConfig(infoDateColumn = "Info date", identifierQuotingPolicy = QuotingPolicy.Always) | ||
private val sqlConfigDateTime = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATETIME, infoDateColumn = "D") | ||
private val sqlConfigString = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.STRING, infoDateColumn = "D") | ||
private val sqlConfigNumber = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.NUMBER, infoDateColumn = "D", dateFormatApp = "yyyyMMdd") | ||
private val columns = Seq("A", "D", "Column with spaces") | ||
|
||
private val date1 = LocalDate.of(2020, 8, 17) | ||
private val date2 = LocalDate.of(2020, 8, 30) | ||
|
||
val driver = "com.databricks.client.jdbc.Driver" | ||
|
||
val gen: SqlGenerator = getSqlGenerator(driver, sqlConfigDate) | ||
val genStr: SqlGenerator = getSqlGenerator(driver, sqlConfigString) | ||
val genNum: SqlGenerator = getSqlGenerator(driver, sqlConfigNumber) | ||
val genDateTime: SqlGenerator = getSqlGenerator(driver, sqlConfigDateTime) | ||
val genEscaped: SqlGenerator = getSqlGenerator(driver, sqlConfigEscape) | ||
val genEscaped2: SqlGenerator = getSqlGenerator(driver, DummySqlConfigFactory.getDummyConfig(infoDateColumn = "`Info date`", identifierQuotingPolicy = QuotingPolicy.Auto)) | ||
|
||
"generate count queries without date ranges" in { | ||
assert(gen.getCountQuery("A") == "SELECT COUNT(*) FROM A") | ||
} | ||
|
||
"generate data queries without date ranges" in { | ||
assert(gen.getDataQuery("A", Nil, None) == "SELECT * FROM A") | ||
} | ||
|
||
"generate data queries when list of columns is specified" in { | ||
assert(genEscaped.getDataQuery("A", columns, None) == "SELECT `A`, `D`, `Column with spaces` FROM `A`") | ||
} | ||
|
||
"generate data queries with limit clause date ranges" in { | ||
assert(gen.getDataQuery("A", Nil, Some(100)) == "SELECT * FROM A LIMIT 100") | ||
} | ||
|
||
"generate ranged count queries" when { | ||
"date is in DATE format" in { | ||
assert(gen.getCountQuery("A", date1, date1) == | ||
"SELECT COUNT(*) FROM A WHERE D = to_date('2020-08-17')") | ||
assert(gen.getCountQuery("A", date1, date2) == | ||
"SELECT COUNT(*) FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')") | ||
} | ||
|
||
"date is in DATETIME format" in { | ||
assert(genDateTime.getCountQuery("A", date1, date1) == | ||
"SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')") | ||
assert(genDateTime.getCountQuery("A", date1, date2) == | ||
"SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')") | ||
} | ||
|
||
"date is in STRING format" in { | ||
assert(genStr.getCountQuery("A", date1, date1) == | ||
"SELECT COUNT(*) FROM A WHERE D = '2020-08-17'") | ||
assert(genStr.getCountQuery("A", date1, date2) == | ||
"SELECT COUNT(*) FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'") | ||
} | ||
|
||
"date is in NUMBER format" in { | ||
assert(genNum.getCountQuery("A", date1, date1) == | ||
"SELECT COUNT(*) FROM A WHERE D = 20200817") | ||
assert(genNum.getCountQuery("A", date1, date2) == | ||
"SELECT COUNT(*) FROM A WHERE D >= 20200817 AND D <= 20200830") | ||
} | ||
|
||
"the table name and column name need to be escaped" in { | ||
assert(genEscaped.getCountQuery("Input Table", date1, date1) == | ||
"SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')") | ||
assert(genEscaped.getCountQuery("Input Table", date1, date2) == | ||
"SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')") | ||
} | ||
|
||
"the table name and column name already escaped" in { | ||
assert(genEscaped2.getCountQuery("Input Table", date1, date1) == | ||
"SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')") | ||
assert(genEscaped2.getCountQuery("Input Table", date1, date2) == | ||
"SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')") | ||
} | ||
} | ||
|
||
"generate ranged data queries" when { | ||
"date is in DATE format" in { | ||
assert(gen.getDataQuery("A", date1, date1, Nil, None) == | ||
"SELECT * FROM A WHERE D = to_date('2020-08-17')") | ||
assert(gen.getDataQuery("A", date1, date2, Nil, None) == | ||
"SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')") | ||
} | ||
|
||
"date is in DATETIME format" in { | ||
assert(genDateTime.getDataQuery("A", date1, date1, Nil, None) == | ||
"SELECT * FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')") | ||
assert(genDateTime.getDataQuery("A", date1, date2, Nil, None) == | ||
"SELECT * FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')") | ||
} | ||
|
||
"date is in STRING format" in { | ||
assert(genStr.getDataQuery("A", date1, date1, Nil, None) == | ||
"SELECT * FROM A WHERE D = '2020-08-17'") | ||
assert(genStr.getDataQuery("A", date1, date2, Nil, None) == | ||
"SELECT * FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'") | ||
} | ||
|
||
"date is in NUMBER format" in { | ||
assert(genNum.getDataQuery("A", date1, date1, Nil, None) == | ||
"SELECT * FROM A WHERE D = 20200817") | ||
assert(genNum.getDataQuery("A", date1, date2, Nil, None) == | ||
"SELECT * FROM A WHERE D >= 20200817 AND D <= 20200830") | ||
} | ||
|
||
"with limit records" in { | ||
assert(gen.getDataQuery("A", date1, date1, Nil, Some(100)) == | ||
"SELECT * FROM A WHERE D = to_date('2020-08-17') LIMIT 100") | ||
assert(gen.getDataQuery("A", date1, date2, Nil, Some(100)) == | ||
"SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30') LIMIT 100") | ||
} | ||
} | ||
|
||
"getCountQueryForSql" should { | ||
"generate count queries for an SQL subquery" in { | ||
assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query") | ||
} | ||
} | ||
|
||
"getDtable" should { | ||
"return the original table when a table is provided" in { | ||
assert(gen.getDtable("A") == "(A) tbl") | ||
} | ||
|
||
"wrapped query without alias for SQL queries " in { | ||
assert(gen.getDtable("SELECT A FROM B") == "(SELECT A FROM B) tbl") | ||
} | ||
} | ||
|
||
"quote" should { | ||
"escape each subfields separately" in { | ||
val actual = gen.quote("System User.`Table Name`") | ||
|
||
assert(actual == "`System User`.`Table Name`") | ||
} | ||
} | ||
|
||
"getOffsetWhereCondition" should { | ||
"return the correct condition for integral offsets" in { | ||
val actual = gen.asInstanceOf[SqlGeneratorBase] | ||
.getOffsetWhereCondition("offset", "<", OffsetValue.IntegralValue(1)) | ||
|
||
assert(actual == "offset < 1") | ||
} | ||
|
||
"return the correct condition for datetime offsets" in { | ||
val actual = gen.asInstanceOf[SqlGeneratorBase] | ||
.getOffsetWhereCondition("offset", ">", OffsetValue.DateTimeValue(Instant.ofEpochMilli(1727761000))) | ||
|
||
assert(actual == "offset > '1970-01-21 01:56:01.000'") | ||
} | ||
|
||
"return the correct condition for string offsets" in { | ||
val actual = gen.asInstanceOf[SqlGeneratorBase] | ||
.getOffsetWhereCondition("offset", ">=", OffsetValue.StringValue("AAA")) | ||
|
||
assert(actual == "offset >= 'AAA'") | ||
} | ||
} | ||
} |