From 09f68d23ef1cdf78cc0af8e51c2285d4fe1be615 Mon Sep 17 00:00:00 2001
From: shreyakhajanchi <92910380+shreyakhajanchi@users.noreply.github.com>
Date: Wed, 4 Dec 2024 10:39:40 +0530
Subject: [PATCH] custom transformation implementation (#2040)
* custom transformation implementation
* adding a test
* addressing comments
---
v2/spanner-to-sourcedb/pom.xml | 6 +
.../v2/templates/SpannerToSourceDb.java | 56 +++++-
.../v2/templates/constants/Constants.java | 7 +
.../dbutils/dml/MySQLDMLGenerator.java | 31 +++-
.../processor/InputRecordProcessor.java | 45 ++++-
.../templates/models/DMLGeneratorRequest.java | 15 ++
.../templates/transforms/SourceWriterFn.java | 43 ++++-
.../transforms/SourceWriterTransform.java | 30 +++-
.../dbutils/dml/MySQLDMLGeneratorTest.java | 27 +++
.../transforms/SourceWriterFnTest.java | 155 ++++++++++++++--
.../test/resources/customTransformation.json | 169 ++++++++++++++++++
11 files changed, 548 insertions(+), 36 deletions(-)
create mode 100644 v2/spanner-to-sourcedb/src/test/resources/customTransformation.json
diff --git a/v2/spanner-to-sourcedb/pom.xml b/v2/spanner-to-sourcedb/pom.xml
index 841405fb9a..fd7896b923 100644
--- a/v2/spanner-to-sourcedb/pom.xml
+++ b/v2/spanner-to-sourcedb/pom.xml
@@ -82,6 +82,12 @@
beam-it-jdbc
test
+
+ com.google.cloud.teleport.v2
+ spanner-custom-shard
+ ${project.version}
+ test
+
com.google.cloud.teleport
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java
index e2cfa9a009..d33ec69b4e 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java
@@ -29,6 +29,7 @@
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard;
import com.google.cloud.teleport.v2.spanner.migrations.spanner.SpannerSchema;
+import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation;
import com.google.cloud.teleport.v2.spanner.migrations.utils.SecretManagerAccessorImpl;
import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader;
import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardFileReader;
@@ -367,6 +368,53 @@ public interface Options extends PipelineOptions, StreamingOptions {
String getSourceType();
void setSourceType(String value);
+
+ @TemplateParameter.GcsReadFile(
+ order = 25,
+ optional = true,
+ description = "Custom transformation jar location in Cloud Storage",
+ helpText =
+ "Custom jar location in Cloud Storage that contains the custom transformation logic for processing records"
+ + " in reverse replication.")
+ @Default.String("")
+ String getTransformationJarPath();
+
+ void setTransformationJarPath(String value);
+
+ @TemplateParameter.Text(
+ order = 26,
+ optional = true,
+ description = "Custom class name for transformation",
+ helpText =
+ "Fully qualified class name having the custom transformation logic. It is a"
+ + " mandatory field in case transformationJarPath is specified")
+ @Default.String("")
+ String getTransformationClassName();
+
+ void setTransformationClassName(String value);
+
+ @TemplateParameter.Text(
+ order = 27,
+ optional = true,
+ description = "Custom parameters for transformation",
+ helpText =
+ "String containing any custom parameters to be passed to the custom transformation class.")
+ @Default.String("")
+ String getTransformationCustomParameters();
+
+ void setTransformationCustomParameters(String value);
+
+ @TemplateParameter.Text(
+ order = 28,
+ optional = true,
+ description = "Directory name for holding filtered records",
+ helpText =
+ "Records skipped from reverse replication are written to this directory. Default"
+ + " directory name is skip.")
+ @Default.String("filteredEvents")
+ String getFilterEventsDirectoryName();
+
+ void setFilterEventsDirectoryName(String value);
}
/**
@@ -541,6 +589,11 @@ public static PipelineResult run(Options options) {
} else {
mergedRecords = dlqRecords;
}
+ CustomTransformation customTransformation =
+ CustomTransformation.builder(
+ options.getTransformationJarPath(), options.getTransformationClassName())
+ .setCustomParameters(options.getTransformationCustomParameters())
+ .build();
SourceWriterTransform.Result sourceWriterOutput =
mergedRecords
.apply(
@@ -578,7 +631,8 @@ public static PipelineResult run(Options options) {
options.getShadowTablePrefix(),
options.getSkipDirectoryName(),
connectionPoolSizePerWorker,
- options.getSourceType()));
+ options.getSourceType(),
+ customTransformation));
PCollection> dlqPermErrorRecords =
reconsumedElements
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java
index 31ed2b8428..1368a46fe3 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java
@@ -59,6 +59,9 @@ public class Constants {
// The Tag for skipped records
public static final TupleTag SKIPPED_TAG = new TupleTag() {};
+ // The Tag for records filtered via custom transformation.
+ public static final TupleTag FILTERED_TAG = new TupleTag() {};
+
// Message written to the file for skipped records
public static final String SKIPPED_TAG_MESSAGE = "Skipped record from reverse replication";
@@ -72,4 +75,8 @@ public class Constants {
public static final String DEFAULT_SHARD_ID = "single_shard";
public static final String SOURCE_MYSQL = "mysql";
+
+ // Message written to the file for filtered records
+ public static final String FILTERED_TAG_MESSAGE =
+ "Filtered record from custom transformation in reverse replication";
}
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java
index 8774fedfa9..c06917bf87 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java
@@ -82,7 +82,8 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequ
sourceTable,
dmlGeneratorRequest.getNewValuesJson(),
dmlGeneratorRequest.getKeyValuesJson(),
- dmlGeneratorRequest.getSourceDbTimezoneOffset());
+ dmlGeneratorRequest.getSourceDbTimezoneOffset(),
+ dmlGeneratorRequest.getCustomTransformationResponse());
if (pkcolumnNameValues == null) {
LOG.warn(
"Cannot reverse replicate for table {} without primary key, skipping the record",
@@ -194,7 +195,8 @@ private static DMLGeneratorResponse generateUpsertStatement(
sourceTable,
dmlGeneratorRequest.getNewValuesJson(),
dmlGeneratorRequest.getKeyValuesJson(),
- dmlGeneratorRequest.getSourceDbTimezoneOffset());
+ dmlGeneratorRequest.getSourceDbTimezoneOffset(),
+ dmlGeneratorRequest.getCustomTransformationResponse());
return getUpsertStatement(
sourceTable.getName(),
sourceTable.getPrimaryKeySet(),
@@ -207,7 +209,8 @@ private static Map getColumnValues(
SourceTable sourceTable,
JSONObject newValuesJson,
JSONObject keyValuesJson,
- String sourceDbTimezoneOffset) {
+ String sourceDbTimezoneOffset,
+ Map customTransformationResponse) {
Map response = new HashMap<>();
/*
@@ -224,6 +227,10 @@ private static Map getColumnValues(
as the column will be stored with default/null values
*/
Set sourcePKs = sourceTable.getPrimaryKeySet();
+ Set customTransformColumns = null;
+ if (customTransformationResponse != null) {
+ customTransformColumns = customTransformationResponse.keySet();
+ }
for (Map.Entry entry : sourceTable.getColDefs().entrySet()) {
SourceColumnDefinition sourceColDef = entry.getValue();
@@ -231,6 +238,10 @@ private static Map getColumnValues(
if (sourcePKs.contains(colName)) {
continue; // we only need non-primary keys
}
+ if (customTransformColumns != null && customTransformColumns.contains(colName)) {
+ response.put(colName, customTransformationResponse.get(colName).toString());
+ continue;
+ }
String colId = entry.getKey();
SpannerColumnDefinition spannerColDef = spannerTable.getColDefs().get(colId);
@@ -272,7 +283,8 @@ private static Map getPkColumnValues(
SourceTable sourceTable,
JSONObject newValuesJson,
JSONObject keyValuesJson,
- String sourceDbTimezoneOffset) {
+ String sourceDbTimezoneOffset,
+ Map customTransformationResponse) {
Map response = new HashMap<>();
/*
Get all primary key col ids from source table
@@ -286,6 +298,10 @@ private static Map getPkColumnValues(
if the column does not exist in any of the JSON - return null
*/
ColumnPK[] sourcePKs = sourceTable.getPrimaryKeys();
+ Set customTransformColumns = null;
+ if (customTransformationResponse != null) {
+ customTransformColumns = customTransformationResponse.keySet();
+ }
for (int i = 0; i < sourcePKs.length; i++) {
ColumnPK currentSourcePK = sourcePKs[i];
@@ -298,6 +314,13 @@ private static Map getPkColumnValues(
sourceColDef.getName());
return null;
}
+ if (customTransformColumns != null
+ && customTransformColumns.contains(sourceColDef.getName())) {
+ response.put(
+ sourceColDef.getName(),
+ customTransformationResponse.get(sourceColDef.getName()).toString());
+ continue;
+ }
String spannerColumnName = spannerColDef.getName();
String columnValue = "";
if (keyValuesJson.has(spannerColumnName)) {
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java
index 6fba8c3fe2..9bdfe2bcda 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java
@@ -15,7 +15,12 @@
*/
package com.google.cloud.teleport.v2.templates.dbutils.processor;
+import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException;
+import com.google.cloud.teleport.v2.spanner.migrations.convertors.ChangeEventToMapConvertor;
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
+import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer;
+import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationRequest;
+import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse;
import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord;
import com.google.cloud.teleport.v2.templates.dbutils.dao.source.IDao;
import com.google.cloud.teleport.v2.templates.dbutils.dml.IDMLGenerator;
@@ -23,10 +28,12 @@
import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
+import java.util.Map;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.joda.time.Duration;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -35,14 +42,18 @@
public class InputRecordProcessor {
private static final Logger LOG = LoggerFactory.getLogger(InputRecordProcessor.class);
+ private static final Distribution applyCustomTransformationResponseTimeMetric =
+ Metrics.distribution(
+ InputRecordProcessor.class, "apply_custom_transformation_impl_latency_ms");
- public static void processRecord(
+ public static boolean processRecord(
TrimmedShardedDataChangeRecord spannerRecord,
Schema schema,
IDao dao,
String shardId,
String sourceDbTimezoneOffset,
- IDMLGenerator dmlGenerator)
+ IDMLGenerator dmlGenerator,
+ ISpannerMigrationTransformer spannerToSourceTransformer)
throws Exception {
try {
@@ -53,17 +64,43 @@ public static void processRecord(
String newValueJsonStr = spannerRecord.getMod().getNewValuesJson();
JSONObject newValuesJson = new JSONObject(newValueJsonStr);
JSONObject keysJson = new JSONObject(keysJsonStr);
+ Map customTransformationResponse = null;
+ if (spannerToSourceTransformer != null) {
+ org.joda.time.Instant startTimestamp = org.joda.time.Instant.now();
+ Map mapRequest =
+ ChangeEventToMapConvertor.combineJsonObjects(keysJson, newValuesJson);
+ MigrationTransformationRequest migrationTransformationRequest =
+ new MigrationTransformationRequest(tableName, mapRequest, shardId, modType);
+ MigrationTransformationResponse migrationTransformationResponse = null;
+ try {
+ migrationTransformationResponse =
+ spannerToSourceTransformer.toSourceRow(migrationTransformationRequest);
+ } catch (Exception e) {
+ throw new InvalidTransformationException(e);
+ }
+ org.joda.time.Instant endTimestamp = org.joda.time.Instant.now();
+ applyCustomTransformationResponseTimeMetric.update(
+ new Duration(startTimestamp, endTimestamp).getMillis());
+ if (migrationTransformationResponse.isEventFiltered()) {
+ Metrics.counter(InputRecordProcessor.class, "filtered_events_" + shardId).inc();
+ return true;
+ }
+ if (migrationTransformationResponse != null) {
+ customTransformationResponse = migrationTransformationResponse.getResponseRow();
+ }
+ }
DMLGeneratorRequest dmlGeneratorRequest =
new DMLGeneratorRequest.Builder(
modType, tableName, newValuesJson, keysJson, sourceDbTimezoneOffset)
.setSchema(schema)
+ .setCustomTransformationResponse(customTransformationResponse)
.build();
DMLGeneratorResponse dmlGeneratorResponse = dmlGenerator.getDMLStatement(dmlGeneratorRequest);
if (dmlGeneratorResponse.getDmlStatement().isEmpty()) {
LOG.warn("DML statement is empty for table: " + tableName);
- return;
+ return false;
}
dao.write(dmlGeneratorResponse.getDmlStatement());
@@ -79,7 +116,7 @@ public static void processRecord(
long replicationLag = ChronoUnit.SECONDS.between(commitTsInst, instTime);
lagMetric.update(replicationLag); // update the lag metric
-
+ return false;
} catch (Exception e) {
LOG.error(
"The exception while processing shardId: {} is {} ",
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java
index 661f05038d..3db153c51e 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java
@@ -16,6 +16,7 @@
package com.google.cloud.teleport.v2.templates.models;
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
+import java.util.Map;
import org.json.JSONObject;
/**
@@ -51,6 +52,8 @@ public class DMLGeneratorRequest {
// The timezone offset of the source database, used for handling timezone-specific data.
private final String sourceDbTimezoneOffset;
+ private Map customTransformationResponse;
+
public DMLGeneratorRequest(Builder builder) {
this.modType = builder.modType;
this.spannerTableName = builder.spannerTableName;
@@ -58,6 +61,7 @@ public DMLGeneratorRequest(Builder builder) {
this.newValuesJson = builder.newValuesJson;
this.keyValuesJson = builder.keyValuesJson;
this.sourceDbTimezoneOffset = builder.sourceDbTimezoneOffset;
+ this.customTransformationResponse = builder.customTransformationResponse;
}
public String getModType() {
@@ -84,6 +88,10 @@ public String getSourceDbTimezoneOffset() {
return sourceDbTimezoneOffset;
}
+ public Map getCustomTransformationResponse() {
+ return customTransformationResponse;
+ }
+
public static class Builder {
private final String modType;
private final String spannerTableName;
@@ -91,6 +99,7 @@ public static class Builder {
private final JSONObject keyValuesJson;
private final String sourceDbTimezoneOffset;
private Schema schema;
+ private Map customTransformationResponse;
public Builder(
String modType,
@@ -110,6 +119,12 @@ public Builder setSchema(Schema schema) {
return this;
}
+ public Builder setCustomTransformationResponse(
+ Map customTransformationResponse) {
+ this.customTransformationResponse = customTransformationResponse;
+ return this;
+ }
+
public DMLGeneratorRequest build() {
return new DMLGeneratorRequest(this);
}
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java
index f4af6d9780..6b00511bf8 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java
@@ -23,10 +23,14 @@
import com.google.cloud.teleport.v2.spanner.ddl.Ddl;
import com.google.cloud.teleport.v2.spanner.ddl.IndexColumn;
import com.google.cloud.teleport.v2.spanner.ddl.Table;
+import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException;
import com.google.cloud.teleport.v2.spanner.migrations.convertors.ChangeEventSpannerConvertor;
import com.google.cloud.teleport.v2.spanner.migrations.exceptions.ChangeEventConvertorException;
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard;
+import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation;
+import com.google.cloud.teleport.v2.spanner.migrations.utils.CustomTransformationImplFetcher;
+import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer;
import com.google.cloud.teleport.v2.templates.changestream.ChangeStreamErrorRecord;
import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord;
import com.google.cloud.teleport.v2.templates.constants.Constants;
@@ -75,6 +79,9 @@ public class SourceWriterFn extends DoFn shards;
@@ -86,6 +93,8 @@ public class SourceWriterFn extends DoFn shards,
@@ -96,7 +105,8 @@ public SourceWriterFn(
String shadowTablePrefix,
String skipDirName,
int maxThreadPerDataflowWorker,
- String source) {
+ String source,
+ CustomTransformation customTransformation) {
this.schema = schema;
this.sourceDbTimezoneOffset = sourceDbTimezoneOffset;
@@ -107,6 +117,7 @@ public SourceWriterFn(
this.skipDirName = skipDirName;
this.maxThreadPerDataflowWorker = maxThreadPerDataflowWorker;
this.source = source;
+ this.customTransformation = customTransformation;
}
// for unit testing purposes
@@ -124,6 +135,12 @@ public void setSourceProcessor(SourceProcessor sourceProcessor) {
this.sourceProcessor = sourceProcessor;
}
+ // for unit testing purposes
+ public void setSpannerToSourceTransformer(
+ ISpannerMigrationTransformer spannerToSourceTransformer) {
+ this.spannerToSourceTransformer = spannerToSourceTransformer;
+ }
+
/** Setup function connects to Cloud Spanner. */
@Setup
public void setup() throws UnsupportedSourceException {
@@ -132,6 +149,8 @@ public void setup() throws UnsupportedSourceException {
sourceProcessor =
SourceProcessorFactory.createSourceProcessor(source, shards, maxThreadPerDataflowWorker);
spannerDao = new SpannerDao(spannerConfig);
+ spannerToSourceTransformer =
+ CustomTransformationImplFetcher.getCustomTransformationLogicImpl(customTransformation);
}
/** Teardown function disconnects from the Cloud Spanner. */
@@ -184,13 +203,18 @@ public void processElement(ProcessContext c) {
if (!isSourceAhead) {
IDao sourceDao = sourceProcessor.getSourceDao(shardId);
- InputRecordProcessor.processRecord(
- spannerRec,
- schema,
- sourceDao,
- shardId,
- sourceDbTimezoneOffset,
- sourceProcessor.getDmlGenerator());
+ boolean isEventFiltered =
+ InputRecordProcessor.processRecord(
+ spannerRec,
+ schema,
+ sourceDao,
+ shardId,
+ sourceDbTimezoneOffset,
+ sourceProcessor.getDmlGenerator(),
+ spannerToSourceTransformer);
+ if (isEventFiltered) {
+ outputWithTag(c, Constants.FILTERED_TAG, Constants.FILTERED_TAG_MESSAGE, spannerRec);
+ }
spannerDao.updateShadowTable(
getShadowTableMutation(
@@ -206,6 +230,9 @@ public void processElement(ProcessContext c) {
}
com.google.cloud.Timestamp timestamp = com.google.cloud.Timestamp.now();
c.output(Constants.SUCCESS_TAG, timestamp.toString());
+ } catch (InvalidTransformationException ex) {
+ invalidTransformationException.inc();
+ outputWithTag(c, Constants.PERMANENT_ERROR_TAG, ex.getMessage(), spannerRec);
} catch (ChangeEventConvertorException ex) {
outputWithTag(c, Constants.PERMANENT_ERROR_TAG, ex.getMessage(), spannerRec);
} catch (SpannerException
diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java
index e80f617eeb..ef9ddbfaac 100644
--- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java
+++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java
@@ -19,6 +19,7 @@
import com.google.cloud.teleport.v2.spanner.ddl.Ddl;
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard;
+import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation;
import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord;
import com.google.cloud.teleport.v2.templates.constants.Constants;
import com.google.common.base.Preconditions;
@@ -52,6 +53,7 @@ public class SourceWriterTransform
private final String skipDirName;
private final int maxThreadPerDataflowWorker;
private final String source;
+ private final CustomTransformation customTransformation;
public SourceWriterTransform(
List shards,
@@ -62,7 +64,8 @@ public SourceWriterTransform(
String shadowTablePrefix,
String skipDirName,
int maxThreadPerDataflowWorker,
- String source) {
+ String source,
+ CustomTransformation customTransformation) {
this.schema = schema;
this.sourceDbTimezoneOffset = sourceDbTimezoneOffset;
@@ -73,6 +76,7 @@ public SourceWriterTransform(
this.skipDirName = skipDirName;
this.maxThreadPerDataflowWorker = maxThreadPerDataflowWorker;
this.source = source;
+ this.customTransformation = customTransformation;
}
@Override
@@ -91,18 +95,21 @@ public SourceWriterTransform.Result expand(
this.shadowTablePrefix,
this.skipDirName,
this.maxThreadPerDataflowWorker,
- this.source))
+ this.source,
+ this.customTransformation))
.withOutputTags(
Constants.SUCCESS_TAG,
TupleTagList.of(Constants.PERMANENT_ERROR_TAG)
.and(Constants.RETRYABLE_ERROR_TAG)
- .and(Constants.SKIPPED_TAG)));
+ .and(Constants.SKIPPED_TAG)
+ .and(Constants.FILTERED_TAG)));
return Result.create(
sourceWriteResults.get(Constants.SUCCESS_TAG),
sourceWriteResults.get(Constants.PERMANENT_ERROR_TAG),
sourceWriteResults.get(Constants.RETRYABLE_ERROR_TAG),
- sourceWriteResults.get(Constants.SKIPPED_TAG));
+ sourceWriteResults.get(Constants.SKIPPED_TAG),
+ sourceWriteResults.get(Constants.FILTERED_TAG));
}
/** Container class for the results of this transform. */
@@ -113,13 +120,18 @@ private static Result create(
PCollection successfulSourceWrites,
PCollection permanentErrors,
PCollection retryableErrors,
- PCollection skippedSourceWrites) {
+ PCollection skippedSourceWrites,
+ PCollection filteredWrites) {
Preconditions.checkNotNull(successfulSourceWrites);
Preconditions.checkNotNull(permanentErrors);
Preconditions.checkNotNull(retryableErrors);
Preconditions.checkNotNull(skippedSourceWrites);
return new AutoValue_SourceWriterTransform_Result(
- successfulSourceWrites, permanentErrors, retryableErrors, skippedSourceWrites);
+ successfulSourceWrites,
+ permanentErrors,
+ retryableErrors,
+ skippedSourceWrites,
+ filteredWrites);
}
public abstract PCollection successfulSourceWrites();
@@ -130,6 +142,8 @@ private static Result create(
public abstract PCollection skippedSourceWrites();
+ public abstract PCollection filteredWrites();
+
@Override
public void finishSpecifyingOutput(
String transformName, PInput input, PTransform, ?> transform) {
@@ -151,7 +165,9 @@ public Map, PValue> expand() {
Constants.RETRYABLE_ERROR_TAG,
retryableErrors(),
Constants.SKIPPED_TAG,
- skippedSourceWrites());
+ skippedSourceWrites(),
+ Constants.FILTERED_TAG,
+ filteredWrites());
}
}
}
diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java
index b01d139a6d..2d110e80ae 100644
--- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java
+++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java
@@ -1088,6 +1088,33 @@ public void testSpannerColDefsNull() {
assertTrue(sql.isEmpty());
}
+ @Test
+ public void customTransformationMatch() {
+ Schema schema = SessionFileReader.read("src/test/resources/customTransformation.json");
+ String tableName = "Singers";
+ String newValuesString = "{\"FirstName\":\"kk\",\"LastName\":\"ll\"}";
+ JSONObject newValuesJson = new JSONObject(newValuesString);
+ String keyValueString = "{\"SingerId\":\"999\"}";
+ JSONObject keyValuesJson = new JSONObject(keyValueString);
+ String modType = "INSERT";
+ Map customTransformation = new HashMap<>();
+ customTransformation.put("FullName", "\'kk ll\'");
+ customTransformation.put("SingerId", "1");
+
+ MySQLDMLGenerator mySQLDMLGenerator = new MySQLDMLGenerator();
+ DMLGeneratorResponse dmlGeneratorResponse =
+ mySQLDMLGenerator.getDMLStatement(
+ new DMLGeneratorRequest.Builder(
+ modType, tableName, newValuesJson, keyValuesJson, "+00:00")
+ .setSchema(schema)
+ .setCustomTransformationResponse(customTransformation)
+ .build());
+ String sql = dmlGeneratorResponse.getDmlStatement();
+
+ assertTrue(sql.contains("`FullName` = 'kk ll'"));
+ assertTrue(sql.contains("VALUES (1,'kk ll')"));
+ }
+
public static Schema getSchemaObject() {
Map syntheticPKeys = new HashMap();
Map srcSchema = new HashMap();
diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java
index f2bac7247d..039429b0ce 100644
--- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java
+++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java
@@ -15,6 +15,7 @@
*/
package com.google.cloud.teleport.v2.templates.transforms;
+import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.contains;
import static org.mockito.ArgumentMatchers.eq;
@@ -29,6 +30,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.cloud.Timestamp;
import com.google.cloud.teleport.v2.spanner.ddl.Ddl;
+import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException;
import com.google.cloud.teleport.v2.spanner.migrations.schema.ColumnPK;
import com.google.cloud.teleport.v2.spanner.migrations.schema.NameAndCols;
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
@@ -38,7 +40,10 @@
import com.google.cloud.teleport.v2.spanner.migrations.schema.SpannerTable;
import com.google.cloud.teleport.v2.spanner.migrations.schema.SyntheticPKey;
import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard;
+import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation;
import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader;
+import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer;
+import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse;
import com.google.cloud.teleport.v2.templates.changestream.ChangeStreamErrorRecord;
import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord;
import com.google.cloud.teleport.v2.templates.constants.Constants;
@@ -63,6 +68,7 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.runners.MethodSorters;
+import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@@ -76,6 +82,7 @@ public class SourceWriterFnTest {
@Mock HashMap mockDaoMap;
@Mock private SpannerConfig mockSpannerConfig;
@Mock private DoFn.ProcessContext processContext;
+ @Mock private ISpannerMigrationTransformer mockSpannerMigrationTransformer;
private static Gson gson = new Gson();
private Shard testShard;
@@ -148,7 +155,8 @@ public void testSourceIsAhead() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -175,7 +183,8 @@ public void testSourceIsAheadWithSameCommitTimestamp() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -201,7 +210,8 @@ public void testSourceIsBehind() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -213,6 +223,118 @@ public void testSourceIsBehind() throws Exception {
verify(mockSpannerDao, atLeast(1)).updateShadowTable(any());
}
+ @Test
+ public void testCustomTransformationException() throws Exception {
+ TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA");
+ record.setShard("shardA");
+ when(processContext.element()).thenReturn(KV.of(1L, record));
+ when(mockSpannerMigrationTransformer.toSourceRow(any()))
+ .thenThrow(new InvalidTransformationException("some exception"));
+ CustomTransformation customTransformation =
+ CustomTransformation.builder("jarPath", "classPath").build();
+ SourceWriterFn sourceWriterFn =
+ new SourceWriterFn(
+ ImmutableList.of(testShard),
+ testSchema,
+ mockSpannerConfig,
+ testSourceDbTimezoneOffset,
+ testDdl,
+ "shadow_",
+ "skip",
+ 500,
+ "mysql",
+ customTransformation);
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
+ sourceWriterFn.setObjectMapper(mapper);
+ sourceWriterFn.setSourceProcessor(sourceProcessor);
+ sourceWriterFn.setSpannerDao(mockSpannerDao);
+ sourceWriterFn.setSpannerToSourceTransformer(mockSpannerMigrationTransformer);
+ sourceWriterFn.processElement(processContext);
+ verify(mockSpannerDao, atLeast(1)).getShadowTableRecord(any(), any());
+ String jsonRec = gson.toJson(record, TrimmedShardedDataChangeRecord.class);
+ ChangeStreamErrorRecord errorRecord =
+ new ChangeStreamErrorRecord(
+ jsonRec,
+ "com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException: some exception");
+ verify(processContext, atLeast(1))
+ .output(
+ Constants.PERMANENT_ERROR_TAG, gson.toJson(errorRecord, ChangeStreamErrorRecord.class));
+ }
+
+ @Test
+ public void testCustomTransformationApplied() throws Exception {
+ TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA");
+ record.setShard("shardA");
+ when(processContext.element()).thenReturn(KV.of(1L, record));
+ when(mockSpannerMigrationTransformer.toSourceRow(any()))
+ .thenReturn(new MigrationTransformationResponse(Map.of("id", "45"), false));
+ CustomTransformation customTransformation =
+ CustomTransformation.builder("jarPath", "classPath").build();
+ SourceWriterFn sourceWriterFn =
+ new SourceWriterFn(
+ ImmutableList.of(testShard),
+ testSchema,
+ mockSpannerConfig,
+ testSourceDbTimezoneOffset,
+ testDdl,
+ "shadow_",
+ "skip",
+ 500,
+ "mysql",
+ customTransformation);
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
+ sourceWriterFn.setObjectMapper(mapper);
+ sourceWriterFn.setSourceProcessor(sourceProcessor);
+ sourceWriterFn.setSpannerDao(mockSpannerDao);
+ sourceWriterFn.setSpannerToSourceTransformer(mockSpannerMigrationTransformer);
+ sourceWriterFn.processElement(processContext);
+ ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class);
+ verify(mockSpannerDao, atLeast(1)).getShadowTableRecord(any(), any());
+ verify(mockSqlDao, atLeast(1)).write(argumentCaptor.capture());
+ assertTrue(argumentCaptor.getValue().contains("INSERT INTO `parent1`(`id`) VALUES (45)"));
+ verify(mockSpannerDao, atLeast(1)).updateShadowTable(any());
+ }
+
+ @Test
+ public void testCustomTransformationFiltered() throws Exception {
+ TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA");
+ record.setShard("shardA");
+ when(processContext.element()).thenReturn(KV.of(1L, record));
+ when(mockSpannerMigrationTransformer.toSourceRow(any()))
+ .thenReturn(new MigrationTransformationResponse(null, true));
+ CustomTransformation customTransformation =
+ CustomTransformation.builder("jarPath", "classPath").build();
+ SourceWriterFn sourceWriterFn =
+ new SourceWriterFn(
+ ImmutableList.of(testShard),
+ testSchema,
+ mockSpannerConfig,
+ testSourceDbTimezoneOffset,
+ testDdl,
+ "shadow_",
+ "skip",
+ 500,
+ "mysql",
+ customTransformation);
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
+ sourceWriterFn.setObjectMapper(mapper);
+ sourceWriterFn.setSourceProcessor(sourceProcessor);
+ sourceWriterFn.setSpannerDao(mockSpannerDao);
+ sourceWriterFn.setSpannerToSourceTransformer(mockSpannerMigrationTransformer);
+ sourceWriterFn.processElement(processContext);
+ verify(mockSpannerDao, atLeast(1)).getShadowTableRecord(any(), any());
+ verify(mockSqlDao, atLeast(0)).write(any());
+ verify(mockSpannerDao, atLeast(0)).updateShadowTable(any());
+ String jsonRec = gson.toJson(record, TrimmedShardedDataChangeRecord.class);
+ ChangeStreamErrorRecord errorRecord =
+ new ChangeStreamErrorRecord(jsonRec, Constants.FILTERED_TAG_MESSAGE);
+ verify(processContext, atLeast(1))
+ .output(Constants.FILTERED_TAG, gson.toJson(errorRecord, ChangeStreamErrorRecord.class));
+ }
+
@Test
public void testNoShard() throws Exception {
TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA");
@@ -227,7 +349,8 @@ public void testNoShard() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -257,7 +380,8 @@ public void testSkipShard() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -285,7 +409,8 @@ public void testPermanentError() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -317,7 +442,8 @@ public void testRetryableError() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -345,7 +471,8 @@ public void testRetryableErrorForForeignKey() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -375,7 +502,8 @@ public void testRetryableErrorConnectionFailure() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -405,7 +533,8 @@ public void testPermanentConnectionFailure() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -435,7 +564,8 @@ public void testPermanentGenericException() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
@@ -464,7 +594,8 @@ public void testDMLEmpty() throws Exception {
"shadow_",
"skip",
500,
- "mysql");
+ "mysql",
+ null);
ObjectMapper mapper = new ObjectMapper();
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
sourceWriterFn.setObjectMapper(mapper);
diff --git a/v2/spanner-to-sourcedb/src/test/resources/customTransformation.json b/v2/spanner-to-sourcedb/src/test/resources/customTransformation.json
new file mode 100644
index 0000000000..320d184e7c
--- /dev/null
+++ b/v2/spanner-to-sourcedb/src/test/resources/customTransformation.json
@@ -0,0 +1,169 @@
+{
+ "SessionName": "NewSession",
+ "EditorName": "",
+ "DatabaseType": "mysql",
+ "DatabaseName": "ui_demo",
+ "Dialect": "google_standard_sql",
+ "Notes": null,
+ "Tags": null,
+ "SpSchema": {
+ "t1": {
+ "Name": "Singers",
+ "ColIds": [
+ "c5",
+ "c6",
+ "c7"
+ ],
+ "ColDefs": {
+ "c5": {
+ "Name": "SingerId",
+ "T": {
+ "Name": "INT64",
+ "Len": 0,
+ "IsArray": false
+ },
+ "NotNull": true,
+ "Comment": "From: SingerId int",
+ "Id": "c5"
+ },
+ "c6": {
+ "Name": "FirstName",
+ "T": {
+ "Name": "STRING",
+ "Len": 1024,
+ "IsArray": false
+ },
+ "NotNull": false,
+ "Comment": "From: FirstName varchar(1024)",
+ "Id": "c6"
+ },
+ "c7": {
+ "Name": "LastName",
+ "T": {
+ "Name": "STRING",
+ "Len": 1024,
+ "IsArray": false
+ },
+ "NotNull": false,
+ "Comment": "From: LastName varchar(1024)",
+ "Id": "c7"
+ }
+ },
+ "PrimaryKeys": [
+ {
+ "ColId": "c5",
+ "Desc": false,
+ "Order": 1
+ }
+ ],
+ "ForeignKeys": null,
+ "Indexes": [
+ {
+ "Name": "ind1",
+ "TableId": "t1",
+ "Unique": false,
+ "Keys": [
+ {
+ "ColId": "c5",
+ "Desc": false,
+ "Order": 1
+ }
+ ],
+ "Id": "i9",
+ "StoredColumnIds": null
+ }
+ ],
+ "ParentId": "",
+ "Comment": "Spanner schema for source table Singers",
+ "Id": "t1"
+ }
+ },
+ "SyntheticPKeys": {},
+ "SrcSchema": {
+ "t1": {
+ "Name": "Singers",
+ "Schema": "ui_demo",
+ "ColIds": [
+ "c5",
+ "c6"
+ ],
+ "ColDefs": {
+ "c5": {
+ "Name": "SingerId",
+ "Type": {
+ "Name": "int",
+ "Mods": null,
+ "ArrayBounds": null
+ },
+ "NotNull": true,
+ "Ignored": {
+ "Check": false,
+ "Identity": false,
+ "Default": false,
+ "Exclusion": false,
+ "ForeignKey": false,
+ "AutoIncrement": false
+ },
+ "Id": "c5"
+ },
+ "c6": {
+ "Name": "FullName",
+ "Type": {
+ "Name": "varchar",
+ "Mods": [
+ 1024
+ ],
+ "ArrayBounds": null
+ },
+ "NotNull": false,
+ "Ignored": {
+ "Check": false,
+ "Identity": false,
+ "Default": false,
+ "Exclusion": false,
+ "ForeignKey": false,
+ "AutoIncrement": false
+ },
+ "Id": "c6"
+ }
+ },
+ "PrimaryKeys": [
+ {
+ "ColId": "c5",
+ "Desc": false,
+ "Order": 1
+ }
+ ],
+ "ForeignKeys": null,
+ "Indexes": [
+ {
+ "Name": "ind1",
+ "Unique": false,
+ "Keys": [
+ {
+ "ColId": "c5",
+ "Desc": false,
+ "Order": 1
+ }
+ ],
+ "Id": "i9",
+ "StoredColumnIds": null
+ }
+ ],
+ "Id": "t1"
+ }
+ },
+ "SchemaIssues": {
+ "t1": {
+ "c5": [
+ 13,
+ 18
+ ]
+ }
+ },
+ "Location": {},
+ "TimezoneOffset": "+00:00",
+ "SpDialect": "google_standard_sql",
+ "UniquePKey": {},
+ "Rules": []
+}
\ No newline at end of file