diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorDatabricks.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorDatabricks.scala new file mode 100644 index 000000000..0d2d6447d --- /dev/null +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorDatabricks.scala @@ -0,0 +1,127 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.sql + +import org.apache.spark.sql.jdbc.JdbcDialects +import org.slf4j.LoggerFactory +import za.co.absa.pramen.api.offset.OffsetValue +import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase} +import za.co.absa.pramen.core.sql.dialects.DatabricksDialect + +import java.time.format.DateTimeFormatter +import java.time.{LocalDate, LocalDateTime} + +object SqlGeneratorDatabricks { + private val log = LoggerFactory.getLogger(this.getClass) + + /** + * This is required for Spark to be able to handle data that comes from Databricks JDBC drivers + */ + lazy val registerDialect: Boolean = { + log.info(s"Registering Databricks dialect...") + JdbcDialects.registerDialect(DatabricksDialect) + true + } +} + +class SqlGeneratorDatabricks(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig) { + private val dateFormatterApp = DateTimeFormatter.ofPattern(sqlConfig.dateFormatApp) + + override val beginEndEscapeChars: (Char, Char) = ('`', '`') + + SqlGeneratorDatabricks.registerDialect + + override def getDtable(sql: String): String = { + s"($sql) tbl" + } + + override def getCountQuery(tableName: String): String = { + s"SELECT COUNT(*) FROM ${escape(tableName)}" + } + + override def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = { + val where = getWhere(infoDateBegin, infoDateEnd) + s"SELECT COUNT(*) FROM ${escape(tableName)} WHERE $where" + } + + override def getCountQueryForSql(filteredSql: String): String = { + s"SELECT COUNT(*) FROM ($filteredSql) AS query" + } + + override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = { + s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}" + } + + override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = { + val where = getWhere(infoDateBegin, infoDateEnd) + s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit)}" + } + + override def getWhere(dateBegin: LocalDate, dateEnd: LocalDate): String = { + val dateBeginLit = getDateLiteral(dateBegin) + val dateEndLit = getDateLiteral(dateEnd) + + val dateTypes: Array[SqlColumnType] = Array(SqlColumnType.DATETIME) + + val infoDateColumnAdjusted = + if (dateTypes.contains(sqlConfig.infoDateType)) { + s"CAST($infoDateColumn AS DATE)" + } else { + infoDateColumn + } + + if (dateBeginLit == dateEndLit) { + s"$infoDateColumnAdjusted = $dateBeginLit" + } else { + s"$infoDateColumnAdjusted >= $dateBeginLit AND $infoDateColumnAdjusted <= $dateEndLit" + } + } + + override def getDateLiteral(date: LocalDate): String = { + sqlConfig.infoDateType match { + case SqlColumnType.DATE => + val dateStr = DateTimeFormatter.ISO_LOCAL_DATE.format(date) + s"to_date('$dateStr')" + case SqlColumnType.DATETIME => + val dateStr = DateTimeFormatter.ISO_LOCAL_DATE.format(date) + s"to_date('$dateStr')" + case SqlColumnType.STRING => + val dateStr = dateFormatterApp.format(date) + s"'$dateStr'" + case SqlColumnType.NUMBER => + val dateStr = dateFormatterApp.format(date) + s"$dateStr" + } + } + + override def getOffsetWhereCondition(column: String, condition: String, offset: OffsetValue): String = { + offset match { + case OffsetValue.DateTimeValue(ts) => + val ldt = LocalDateTime.ofInstant(ts, sqlConfig.serverTimeZone) + val tsLiteral = timestampGenericDbFormatter.format(ldt) + s"$column $condition '$tsLiteral'" + case OffsetValue.IntegralValue(value) => + s"$column $condition $value" + case OffsetValue.StringValue(value) => + s"$column $condition '$value'" + } + } + + private def getLimit(limit: Option[Int]): String = { + limit.map(n => s" LIMIT $n").getOrElse("") + } +} diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorLoader.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorLoader.scala index 3fdee94d8..7c3847d25 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorLoader.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/SqlGeneratorLoader.scala @@ -53,6 +53,7 @@ object SqlGeneratorLoader { case "com.simba.spark.jdbc.Driver" => new SqlGeneratorHive(sqlConfig) case "org.hsqldb.jdbc.JDBCDriver" => new SqlGeneratorHsqlDb(sqlConfig) case "com.ibm.db2.jcc.DB2Driver" => new SqlGeneratorDb2(sqlConfig) + case "com.databricks.client.jdbc.Driver" => new SqlGeneratorDatabricks(sqlConfig) case d => log.warn(s"Unsupported JDBC driver: '$d'. Trying to use a generic SQL generator.") new SqlGeneratorGeneric(sqlConfig) diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/dialects/DatabricksDialect.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/dialects/DatabricksDialect.scala new file mode 100644 index 000000000..fa4c96c1f --- /dev/null +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/sql/dialects/DatabricksDialect.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.sql.dialects + +import org.apache.spark.sql.jdbc.JdbcDialect +import org.apache.spark.sql.types.{DataType, MetadataBuilder} +import org.slf4j.LoggerFactory + +/** + * This is required for Spark to be able to handle data that comes from Databricks JDBC drivers + */ +object DatabricksDialect extends JdbcDialect { + private val logger = LoggerFactory.getLogger(this.getClass) + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:databricks") + + override def quoteIdentifier(colName: String): String = { + colName.split('.').map(sub => s"`$sub`").mkString(".") + } + + override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + super.getCatalystType(sqlType, typeName, size, md) +} diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/sql/SqlGeneratorDatabricksSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/sql/SqlGeneratorDatabricksSuite.scala new file mode 100644 index 000000000..b3e85b873 --- /dev/null +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/sql/SqlGeneratorDatabricksSuite.scala @@ -0,0 +1,192 @@ +/* + * Copyright 2022 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.pramen.core.tests.sql + +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.pramen.api.offset.OffsetValue +import za.co.absa.pramen.api.sql.{QuotingPolicy, SqlColumnType, SqlGenerator, SqlGeneratorBase} +import za.co.absa.pramen.core.mocks.DummySqlConfigFactory + +import java.time.{Instant, LocalDate} + +class SqlGeneratorDatabricksSuite extends AnyWordSpec { + + import za.co.absa.pramen.core.sql.SqlGeneratorLoader._ + + private val sqlConfigDate = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATE, infoDateColumn = "D") + private val sqlConfigEscape = DummySqlConfigFactory.getDummyConfig(infoDateColumn = "Info date", identifierQuotingPolicy = QuotingPolicy.Always) + private val sqlConfigDateTime = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.DATETIME, infoDateColumn = "D") + private val sqlConfigString = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.STRING, infoDateColumn = "D") + private val sqlConfigNumber = DummySqlConfigFactory.getDummyConfig(infoDateType = SqlColumnType.NUMBER, infoDateColumn = "D", dateFormatApp = "yyyyMMdd") + private val columns = Seq("A", "D", "Column with spaces") + + private val date1 = LocalDate.of(2020, 8, 17) + private val date2 = LocalDate.of(2020, 8, 30) + + val driver = "com.databricks.client.jdbc.Driver" + + val gen: SqlGenerator = getSqlGenerator(driver, sqlConfigDate) + val genStr: SqlGenerator = getSqlGenerator(driver, sqlConfigString) + val genNum: SqlGenerator = getSqlGenerator(driver, sqlConfigNumber) + val genDateTime: SqlGenerator = getSqlGenerator(driver, sqlConfigDateTime) + val genEscaped: SqlGenerator = getSqlGenerator(driver, sqlConfigEscape) + val genEscaped2: SqlGenerator = getSqlGenerator(driver, DummySqlConfigFactory.getDummyConfig(infoDateColumn = "`Info date`", identifierQuotingPolicy = QuotingPolicy.Auto)) + + "generate count queries without date ranges" in { + assert(gen.getCountQuery("A") == "SELECT COUNT(*) FROM A") + } + + "generate data queries without date ranges" in { + assert(gen.getDataQuery("A", Nil, None) == "SELECT * FROM A") + } + + "generate data queries when list of columns is specified" in { + assert(genEscaped.getDataQuery("A", columns, None) == "SELECT `A`, `D`, `Column with spaces` FROM `A`") + } + + "generate data queries with limit clause date ranges" in { + assert(gen.getDataQuery("A", Nil, Some(100)) == "SELECT * FROM A LIMIT 100") + } + + "generate ranged count queries" when { + "date is in DATE format" in { + assert(gen.getCountQuery("A", date1, date1) == + "SELECT COUNT(*) FROM A WHERE D = to_date('2020-08-17')") + assert(gen.getCountQuery("A", date1, date2) == + "SELECT COUNT(*) FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')") + } + + "date is in DATETIME format" in { + assert(genDateTime.getCountQuery("A", date1, date1) == + "SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')") + assert(genDateTime.getCountQuery("A", date1, date2) == + "SELECT COUNT(*) FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')") + } + + "date is in STRING format" in { + assert(genStr.getCountQuery("A", date1, date1) == + "SELECT COUNT(*) FROM A WHERE D = '2020-08-17'") + assert(genStr.getCountQuery("A", date1, date2) == + "SELECT COUNT(*) FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'") + } + + "date is in NUMBER format" in { + assert(genNum.getCountQuery("A", date1, date1) == + "SELECT COUNT(*) FROM A WHERE D = 20200817") + assert(genNum.getCountQuery("A", date1, date2) == + "SELECT COUNT(*) FROM A WHERE D >= 20200817 AND D <= 20200830") + } + + "the table name and column name need to be escaped" in { + assert(genEscaped.getCountQuery("Input Table", date1, date1) == + "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')") + assert(genEscaped.getCountQuery("Input Table", date1, date2) == + "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')") + } + + "the table name and column name already escaped" in { + assert(genEscaped2.getCountQuery("Input Table", date1, date1) == + "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` = to_date('2020-08-17')") + assert(genEscaped2.getCountQuery("Input Table", date1, date2) == + "SELECT COUNT(*) FROM `Input Table` WHERE `Info date` >= to_date('2020-08-17') AND `Info date` <= to_date('2020-08-30')") + } + } + + "generate ranged data queries" when { + "date is in DATE format" in { + assert(gen.getDataQuery("A", date1, date1, Nil, None) == + "SELECT * FROM A WHERE D = to_date('2020-08-17')") + assert(gen.getDataQuery("A", date1, date2, Nil, None) == + "SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30')") + } + + "date is in DATETIME format" in { + assert(genDateTime.getDataQuery("A", date1, date1, Nil, None) == + "SELECT * FROM A WHERE CAST(D AS DATE) = to_date('2020-08-17')") + assert(genDateTime.getDataQuery("A", date1, date2, Nil, None) == + "SELECT * FROM A WHERE CAST(D AS DATE) >= to_date('2020-08-17') AND CAST(D AS DATE) <= to_date('2020-08-30')") + } + + "date is in STRING format" in { + assert(genStr.getDataQuery("A", date1, date1, Nil, None) == + "SELECT * FROM A WHERE D = '2020-08-17'") + assert(genStr.getDataQuery("A", date1, date2, Nil, None) == + "SELECT * FROM A WHERE D >= '2020-08-17' AND D <= '2020-08-30'") + } + + "date is in NUMBER format" in { + assert(genNum.getDataQuery("A", date1, date1, Nil, None) == + "SELECT * FROM A WHERE D = 20200817") + assert(genNum.getDataQuery("A", date1, date2, Nil, None) == + "SELECT * FROM A WHERE D >= 20200817 AND D <= 20200830") + } + + "with limit records" in { + assert(gen.getDataQuery("A", date1, date1, Nil, Some(100)) == + "SELECT * FROM A WHERE D = to_date('2020-08-17') LIMIT 100") + assert(gen.getDataQuery("A", date1, date2, Nil, Some(100)) == + "SELECT * FROM A WHERE D >= to_date('2020-08-17') AND D <= to_date('2020-08-30') LIMIT 100") + } + } + + "getCountQueryForSql" should { + "generate count queries for an SQL subquery" in { + assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query") + } + } + + "getDtable" should { + "return the original table when a table is provided" in { + assert(gen.getDtable("A") == "(A) tbl") + } + + "wrapped query without alias for SQL queries " in { + assert(gen.getDtable("SELECT A FROM B") == "(SELECT A FROM B) tbl") + } + } + + "quote" should { + "escape each subfields separately" in { + val actual = gen.quote("System User.`Table Name`") + + assert(actual == "`System User`.`Table Name`") + } + } + + "getOffsetWhereCondition" should { + "return the correct condition for integral offsets" in { + val actual = gen.asInstanceOf[SqlGeneratorBase] + .getOffsetWhereCondition("offset", "<", OffsetValue.IntegralValue(1)) + + assert(actual == "offset < 1") + } + + "return the correct condition for datetime offsets" in { + val actual = gen.asInstanceOf[SqlGeneratorBase] + .getOffsetWhereCondition("offset", ">", OffsetValue.DateTimeValue(Instant.ofEpochMilli(1727761000))) + + assert(actual == "offset > '1970-01-21 01:56:01.000'") + } + + "return the correct condition for string offsets" in { + val actual = gen.asInstanceOf[SqlGeneratorBase] + .getOffsetWhereCondition("offset", ">=", OffsetValue.StringValue("AAA")) + + assert(actual == "offset >= 'AAA'") + } + } +}