#276 Improve unit test coverage of the JDBC source.
yruslan committed Nov 3, 2023
1 parent f616991 commit 928b052
Showing 2 changed files with 77 additions and 7 deletions.
@@ -17,12 +17,31 @@
package za.co.absa.pramen.core.source

import com.typesafe.config.{Config, ConfigFactory}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.wordspec.AnyWordSpec
import za.co.absa.pramen.api.Source
import za.co.absa.pramen.api.{Query, Source}
import za.co.absa.pramen.core.ExternalChannelFactoryReflect
import za.co.absa.pramen.core.base.SparkTestBase
import za.co.absa.pramen.core.fixtures.RelationalDbFixture
import za.co.absa.pramen.core.reader.{TableReaderJdbc, TableReaderJdbcNative}
import za.co.absa.pramen.core.samples.RdbExampleTable

import java.time.LocalDate

class JdbcSourceSuite extends AnyWordSpec with BeforeAndAfterAll with SparkTestBase with RelationalDbFixture {

override def beforeAll(): Unit = {
super.beforeAll()

RdbExampleTable.Company.initTable(getConnection)
}

override protected def afterAll(): Unit = {
RdbExampleTable.Company.dropTable(getConnection)

super.afterAll()
}

class JdbcSourceSuite extends AnyWordSpec with SparkTestBase {
private val conf: Config = ConfigFactory.parseString(
s"""
| pramen {
@@ -31,13 +50,18 @@ class JdbcSourceSuite extends AnyWordSpec with SparkTestBase {
| name = "jdbc1"
| factory.class = "za.co.absa.pramen.core.source.JdbcSource"
| jdbc {
| driver = "driver1"
| connection.string = "url1"
| user = "user1"
| password = "password1"
| driver = "$driver"
| connection.string = "$url"
| user = "$user"
| password = "$password"
| }
|
| has.information.date.column = true
| save.timestamps.as.dates = true
| correct.decimals.in.schema = true
| correct.decimals.fix.precision = true
| enable.schema.metadata = true
|
| has.information.date.column = false
| information.date.column = "INFO_DATE"
| information.date.type = "date"
| information.date.app.format = "yyyy-MM-DD"
@@ -178,4 +202,45 @@ class JdbcSourceSuite extends AnyWordSpec with SparkTestBase {

}

"getReader" should {
"return JDBC table reader when a table is specified" in {
val srcConfig = conf.getConfigList("pramen.sources")
val src1Config = srcConfig.get(0)
val src = ExternalChannelFactoryReflect.fromConfig[Source](src1Config, "pramen.sources.0", "source").asInstanceOf[JdbcSource]
val query = Query.Table("company")

val reader = src.getReader(query)

val df = reader.getData(query, infoDateBegin = LocalDate.now(), infoDateEnd = LocalDate.now(), Nil)

assert(reader.isInstanceOf[TableReaderJdbc])
assert(df.schema.fields(1).metadata.getLong("maxLength") == 50L)
}

"return JDBC Native table reader when a SQL query is specified" in {
val srcConfig = conf.getConfigList("pramen.sources")
val src1Config = srcConfig.get(0)
val src = ExternalChannelFactoryReflect.fromConfig[Source](src1Config, "pramen.sources.0", "source").asInstanceOf[JdbcSource]
val query = Query.Sql("SELECT * FROM company")

val reader = src.getReader(query)
val df = reader.getData(query, infoDateBegin = LocalDate.now(), infoDateEnd = LocalDate.now(), Nil)

assert(reader.isInstanceOf[TableReaderJdbcNative])
assert(df.schema.fields(1).metadata.getLong("maxLength") == 50L)
}

"throw an exception on unknown query type" in {
val srcConfig = conf.getConfigList("pramen.sources")
val src1Config = srcConfig.get(0)
val src = ExternalChannelFactoryReflect.fromConfig[Source](src1Config, "pramen.sources.0", "source").asInstanceOf[JdbcSource]

val ex = intercept[IllegalArgumentException] {
src.getReader(Query.Path("/dummy"))
}

assert(ex.getMessage.contains("Unexpected 'path' spec for the JDBC reader. Only 'table' or 'sql' are supported. Config path: pramen.sources.0"))
}
}

}
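Note for readers skimming the diff: the new "getReader" tests pin down how JdbcSource dispatches on the query type — a Query.Table yields a TableReaderJdbc, a Query.Sql yields a TableReaderJdbcNative, and a Query.Path is rejected with the message asserted above. Below is a minimal sketch of that dispatch, assuming only the Query types used in the tests; it illustrates the tested behavior, not the actual JdbcSource implementation (the real readers are built from the full source config), and the helper name pickReaderKind is made up for the example.

import za.co.absa.pramen.api.Query

// Sketch only: mirrors the behavior asserted by the tests above.
def pickReaderKind(query: Query, configPath: String): String = query match {
  case _: Query.Table => "TableReaderJdbc"        // table query -> JDBC table reader
  case _: Query.Sql   => "TableReaderJdbcNative"  // SQL query -> JDBC Native reader
  case _              => throw new IllegalArgumentException(
    s"Unexpected query spec for the JDBC reader. Only 'table' or 'sql' are supported. Config path: $configPath")
}

For example, pickReaderKind(Query.Path("/dummy"), "pramen.sources.0") throws, in line with the last test above (the real error message also names the offending spec, e.g. 'path').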
@@ -39,6 +39,11 @@ class JdbcSparkUtilsSuite extends AnyWordSpec with BeforeAndAfterAll with SparkT
RdbExampleTable.Company.initTable(getConnection)
}

override protected def afterAll(): Unit = {
RdbExampleTable.Company.dropTable(getConnection)
super.afterAll()
}

"addMetadataFromJdbc" should {
"add varchar metadata to Spark fields" in {
val connectionOptions = JdbcSparkUtils.getJdbcOptions(url, jdbcConfig, RdbExampleTable.Company.tableName, Map.empty)
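Both suites assert df.schema.fields(1).metadata.getLong("maxLength") == 50L, i.e. that a VARCHAR length read from the JDBC metadata ends up in the Spark field metadata. A rough sketch of that idea using only standard Spark APIs is below; it is not the JdbcSparkUtils.addMetadataFromJdbc code itself, and the helper name and example column are illustrative.

import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType}

// Illustrative sketch: copy a VARCHAR length (here 50, as in the assertions above)
// into a string field's metadata so it can be read back with getLong("maxLength").
def withMaxLength(field: StructField, maxLength: Long): StructField = {
  val meta = new MetadataBuilder()
    .withMetadata(field.metadata)   // keep any existing metadata on the field
    .putLong("maxLength", maxLength)
    .build()
  field.copy(metadata = meta)
}

val schema   = StructType(Seq(StructField("NAME", StringType)))
val enriched = StructType(schema.fields.map(withMaxLength(_, 50L)))
// enriched("NAME").metadata.getLong("maxLength") == 50L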
