From 09f68d23ef1cdf78cc0af8e51c2285d4fe1be615 Mon Sep 17 00:00:00 2001 From: shreyakhajanchi <92910380+shreyakhajanchi@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:39:40 +0530 Subject: [PATCH] custom transformation implementation (#2040) * custom transformation implementation * adding a test * addressing comments --- v2/spanner-to-sourcedb/pom.xml | 6 + .../v2/templates/SpannerToSourceDb.java | 56 +++++- .../v2/templates/constants/Constants.java | 7 + .../dbutils/dml/MySQLDMLGenerator.java | 31 +++- .../processor/InputRecordProcessor.java | 45 ++++- .../templates/models/DMLGeneratorRequest.java | 15 ++ .../templates/transforms/SourceWriterFn.java | 43 ++++- .../transforms/SourceWriterTransform.java | 30 +++- .../dbutils/dml/MySQLDMLGeneratorTest.java | 27 +++ .../transforms/SourceWriterFnTest.java | 155 ++++++++++++++-- .../test/resources/customTransformation.json | 169 ++++++++++++++++++ 11 files changed, 548 insertions(+), 36 deletions(-) create mode 100644 v2/spanner-to-sourcedb/src/test/resources/customTransformation.json diff --git a/v2/spanner-to-sourcedb/pom.xml b/v2/spanner-to-sourcedb/pom.xml index 841405fb9a..fd7896b923 100644 --- a/v2/spanner-to-sourcedb/pom.xml +++ b/v2/spanner-to-sourcedb/pom.xml @@ -82,6 +82,12 @@ beam-it-jdbc test + + com.google.cloud.teleport.v2 + spanner-custom-shard + ${project.version} + test + com.google.cloud.teleport diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java index e2cfa9a009..d33ec69b4e 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java @@ -29,6 +29,7 @@ import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; 
import com.google.cloud.teleport.v2.spanner.migrations.spanner.SpannerSchema; +import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; import com.google.cloud.teleport.v2.spanner.migrations.utils.SecretManagerAccessorImpl; import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader; import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardFileReader; @@ -367,6 +368,53 @@ public interface Options extends PipelineOptions, StreamingOptions { String getSourceType(); void setSourceType(String value); + + @TemplateParameter.GcsReadFile( + order = 25, + optional = true, + description = "Custom transformation jar location in Cloud Storage", + helpText = + "Custom jar location in Cloud Storage that contains the custom transformation logic for processing records" + + " in reverse replication.") + @Default.String("") + String getTransformationJarPath(); + + void setTransformationJarPath(String value); + + @TemplateParameter.Text( + order = 26, + optional = true, + description = "Custom class name for transformation", + helpText = + "Fully qualified class name having the custom transformation logic. It is a" + + " mandatory field in case transformationJarPath is specified") + @Default.String("") + String getTransformationClassName(); + + void setTransformationClassName(String value); + + @TemplateParameter.Text( + order = 27, + optional = true, + description = "Custom parameters for transformation", + helpText = + "String containing any custom parameters to be passed to the custom transformation class.") + @Default.String("") + String getTransformationCustomParameters(); + + void setTransformationCustomParameters(String value); + + @TemplateParameter.Text( + order = 28, + optional = true, + description = "Directory name for holding filtered records", + helpText = + "Records skipped from reverse replication are written to this directory. 
Default" + " directory name is filteredEvents.") + @Default.String("filteredEvents") + String getFilterEventsDirectoryName(); + + void setFilterEventsDirectoryName(String value); } /** @@ -541,6 +589,11 @@ public static PipelineResult run(Options options) { } else { mergedRecords = dlqRecords; } + CustomTransformation customTransformation = + CustomTransformation.builder( + options.getTransformationJarPath(), options.getTransformationClassName()) + .setCustomParameters(options.getTransformationCustomParameters()) + .build(); SourceWriterTransform.Result sourceWriterOutput = mergedRecords .apply( @@ -578,7 +631,8 @@ public static PipelineResult run(Options options) { options.getShadowTablePrefix(), options.getSkipDirectoryName(), connectionPoolSizePerWorker, - options.getSourceType())); + options.getSourceType(), + customTransformation)); PCollection> dlqPermErrorRecords = reconsumedElements diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java index 31ed2b8428..1368a46fe3 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java @@ -59,6 +59,9 @@ public class Constants { // The Tag for skipped records public static final TupleTag SKIPPED_TAG = new TupleTag() {}; + // The Tag for records filtered via custom transformation. 
+ public static final TupleTag FILTERED_TAG = new TupleTag() {}; + // Message written to the file for skipped records public static final String SKIPPED_TAG_MESSAGE = "Skipped record from reverse replication"; @@ -72,4 +75,8 @@ public class Constants { public static final String DEFAULT_SHARD_ID = "single_shard"; public static final String SOURCE_MYSQL = "mysql"; + + // Message written to the file for filtered records + public static final String FILTERED_TAG_MESSAGE = + "Filtered record from custom transformation in reverse replication"; } diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java index 8774fedfa9..c06917bf87 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java @@ -82,7 +82,8 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequ sourceTable, dmlGeneratorRequest.getNewValuesJson(), dmlGeneratorRequest.getKeyValuesJson(), - dmlGeneratorRequest.getSourceDbTimezoneOffset()); + dmlGeneratorRequest.getSourceDbTimezoneOffset(), + dmlGeneratorRequest.getCustomTransformationResponse()); if (pkcolumnNameValues == null) { LOG.warn( "Cannot reverse replicate for table {} without primary key, skipping the record", @@ -194,7 +195,8 @@ private static DMLGeneratorResponse generateUpsertStatement( sourceTable, dmlGeneratorRequest.getNewValuesJson(), dmlGeneratorRequest.getKeyValuesJson(), - dmlGeneratorRequest.getSourceDbTimezoneOffset()); + dmlGeneratorRequest.getSourceDbTimezoneOffset(), + dmlGeneratorRequest.getCustomTransformationResponse()); return getUpsertStatement( sourceTable.getName(), sourceTable.getPrimaryKeySet(), @@ -207,7 +209,8 @@ private static Map 
getColumnValues( SourceTable sourceTable, JSONObject newValuesJson, JSONObject keyValuesJson, - String sourceDbTimezoneOffset) { + String sourceDbTimezoneOffset, + Map customTransformationResponse) { Map response = new HashMap<>(); /* @@ -224,6 +227,10 @@ private static Map getColumnValues( as the column will be stored with default/null values */ Set sourcePKs = sourceTable.getPrimaryKeySet(); + Set customTransformColumns = null; + if (customTransformationResponse != null) { + customTransformColumns = customTransformationResponse.keySet(); + } for (Map.Entry entry : sourceTable.getColDefs().entrySet()) { SourceColumnDefinition sourceColDef = entry.getValue(); @@ -231,6 +238,10 @@ private static Map getColumnValues( if (sourcePKs.contains(colName)) { continue; // we only need non-primary keys } + if (customTransformColumns != null && customTransformColumns.contains(colName)) { + response.put(colName, customTransformationResponse.get(colName).toString()); + continue; + } String colId = entry.getKey(); SpannerColumnDefinition spannerColDef = spannerTable.getColDefs().get(colId); @@ -272,7 +283,8 @@ private static Map getPkColumnValues( SourceTable sourceTable, JSONObject newValuesJson, JSONObject keyValuesJson, - String sourceDbTimezoneOffset) { + String sourceDbTimezoneOffset, + Map customTransformationResponse) { Map response = new HashMap<>(); /* Get all primary key col ids from source table @@ -286,6 +298,10 @@ private static Map getPkColumnValues( if the column does not exist in any of the JSON - return null */ ColumnPK[] sourcePKs = sourceTable.getPrimaryKeys(); + Set customTransformColumns = null; + if (customTransformationResponse != null) { + customTransformColumns = customTransformationResponse.keySet(); + } for (int i = 0; i < sourcePKs.length; i++) { ColumnPK currentSourcePK = sourcePKs[i]; @@ -298,6 +314,13 @@ private static Map getPkColumnValues( sourceColDef.getName()); return null; } + if (customTransformColumns != null + && 
customTransformColumns.contains(sourceColDef.getName())) { + response.put( + sourceColDef.getName(), + customTransformationResponse.get(sourceColDef.getName()).toString()); + continue; + } String spannerColumnName = spannerColDef.getName(); String columnValue = ""; if (keyValuesJson.has(spannerColumnName)) { diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java index 6fba8c3fe2..9bdfe2bcda 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java @@ -15,7 +15,12 @@ */ package com.google.cloud.teleport.v2.templates.dbutils.processor; +import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException; +import com.google.cloud.teleport.v2.spanner.migrations.convertors.ChangeEventToMapConvertor; import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; +import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer; +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationRequest; +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.IDao; import com.google.cloud.teleport.v2.templates.dbutils.dml.IDMLGenerator; @@ -23,10 +28,12 @@ import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Map; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.Metrics; 
import org.apache.commons.lang3.exception.ExceptionUtils; +import org.joda.time.Duration; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,14 +42,18 @@ public class InputRecordProcessor { private static final Logger LOG = LoggerFactory.getLogger(InputRecordProcessor.class); + private static final Distribution applyCustomTransformationResponseTimeMetric = + Metrics.distribution( + InputRecordProcessor.class, "apply_custom_transformation_impl_latency_ms"); - public static void processRecord( + public static boolean processRecord( TrimmedShardedDataChangeRecord spannerRecord, Schema schema, IDao dao, String shardId, String sourceDbTimezoneOffset, - IDMLGenerator dmlGenerator) + IDMLGenerator dmlGenerator, + ISpannerMigrationTransformer spannerToSourceTransformer) throws Exception { try { @@ -53,17 +64,43 @@ public static void processRecord( String newValueJsonStr = spannerRecord.getMod().getNewValuesJson(); JSONObject newValuesJson = new JSONObject(newValueJsonStr); JSONObject keysJson = new JSONObject(keysJsonStr); + Map customTransformationResponse = null; + if (spannerToSourceTransformer != null) { + org.joda.time.Instant startTimestamp = org.joda.time.Instant.now(); + Map mapRequest = + ChangeEventToMapConvertor.combineJsonObjects(keysJson, newValuesJson); + MigrationTransformationRequest migrationTransformationRequest = + new MigrationTransformationRequest(tableName, mapRequest, shardId, modType); + MigrationTransformationResponse migrationTransformationResponse = null; + try { + migrationTransformationResponse = + spannerToSourceTransformer.toSourceRow(migrationTransformationRequest); + } catch (Exception e) { + throw new InvalidTransformationException(e); + } + org.joda.time.Instant endTimestamp = org.joda.time.Instant.now(); + applyCustomTransformationResponseTimeMetric.update( + new Duration(startTimestamp, endTimestamp).getMillis()); + if (migrationTransformationResponse.isEventFiltered()) { + 
Metrics.counter(InputRecordProcessor.class, "filtered_events_" + shardId).inc(); + return true; + } + if (migrationTransformationResponse != null) { + customTransformationResponse = migrationTransformationResponse.getResponseRow(); + } + } DMLGeneratorRequest dmlGeneratorRequest = new DMLGeneratorRequest.Builder( modType, tableName, newValuesJson, keysJson, sourceDbTimezoneOffset) .setSchema(schema) + .setCustomTransformationResponse(customTransformationResponse) .build(); DMLGeneratorResponse dmlGeneratorResponse = dmlGenerator.getDMLStatement(dmlGeneratorRequest); if (dmlGeneratorResponse.getDmlStatement().isEmpty()) { LOG.warn("DML statement is empty for table: " + tableName); - return; + return false; } dao.write(dmlGeneratorResponse.getDmlStatement()); @@ -79,7 +116,7 @@ public static void processRecord( long replicationLag = ChronoUnit.SECONDS.between(commitTsInst, instTime); lagMetric.update(replicationLag); // update the lag metric - + return false; } catch (Exception e) { LOG.error( "The exception while processing shardId: {} is {} ", diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java index 661f05038d..3db153c51e 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/DMLGeneratorRequest.java @@ -16,6 +16,7 @@ package com.google.cloud.teleport.v2.templates.models; import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; +import java.util.Map; import org.json.JSONObject; /** @@ -51,6 +52,8 @@ public class DMLGeneratorRequest { // The timezone offset of the source database, used for handling timezone-specific data. 
private final String sourceDbTimezoneOffset; + private Map customTransformationResponse; + public DMLGeneratorRequest(Builder builder) { this.modType = builder.modType; this.spannerTableName = builder.spannerTableName; @@ -58,6 +61,7 @@ public DMLGeneratorRequest(Builder builder) { this.newValuesJson = builder.newValuesJson; this.keyValuesJson = builder.keyValuesJson; this.sourceDbTimezoneOffset = builder.sourceDbTimezoneOffset; + this.customTransformationResponse = builder.customTransformationResponse; } public String getModType() { @@ -84,6 +88,10 @@ public String getSourceDbTimezoneOffset() { return sourceDbTimezoneOffset; } + public Map getCustomTransformationResponse() { + return customTransformationResponse; + } + public static class Builder { private final String modType; private final String spannerTableName; @@ -91,6 +99,7 @@ public static class Builder { private final JSONObject keyValuesJson; private final String sourceDbTimezoneOffset; private Schema schema; + private Map customTransformationResponse; public Builder( String modType, @@ -110,6 +119,12 @@ public Builder setSchema(Schema schema) { return this; } + public Builder setCustomTransformationResponse( + Map customTransformationResponse) { + this.customTransformationResponse = customTransformationResponse; + return this; + } + public DMLGeneratorRequest build() { return new DMLGeneratorRequest(this); } diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java index f4af6d9780..6b00511bf8 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java @@ -23,10 +23,14 @@ import com.google.cloud.teleport.v2.spanner.ddl.Ddl; import 
com.google.cloud.teleport.v2.spanner.ddl.IndexColumn; import com.google.cloud.teleport.v2.spanner.ddl.Table; +import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException; import com.google.cloud.teleport.v2.spanner.migrations.convertors.ChangeEventSpannerConvertor; import com.google.cloud.teleport.v2.spanner.migrations.exceptions.ChangeEventConvertorException; import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; +import com.google.cloud.teleport.v2.spanner.migrations.utils.CustomTransformationImplFetcher; +import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer; import com.google.cloud.teleport.v2.templates.changestream.ChangeStreamErrorRecord; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.constants.Constants; @@ -75,6 +79,9 @@ public class SourceWriterFn extends DoFn shards; @@ -86,6 +93,8 @@ public class SourceWriterFn extends DoFn shards, @@ -96,7 +105,8 @@ public SourceWriterFn( String shadowTablePrefix, String skipDirName, int maxThreadPerDataflowWorker, - String source) { + String source, + CustomTransformation customTransformation) { this.schema = schema; this.sourceDbTimezoneOffset = sourceDbTimezoneOffset; @@ -107,6 +117,7 @@ public SourceWriterFn( this.skipDirName = skipDirName; this.maxThreadPerDataflowWorker = maxThreadPerDataflowWorker; this.source = source; + this.customTransformation = customTransformation; } // for unit testing purposes @@ -124,6 +135,12 @@ public void setSourceProcessor(SourceProcessor sourceProcessor) { this.sourceProcessor = sourceProcessor; } + // for unit testing purposes + public void setSpannerToSourceTransformer( + ISpannerMigrationTransformer spannerToSourceTransformer) { + this.spannerToSourceTransformer = 
spannerToSourceTransformer; + } + /** Setup function connects to Cloud Spanner. */ @Setup public void setup() throws UnsupportedSourceException { @@ -132,6 +149,8 @@ public void setup() throws UnsupportedSourceException { sourceProcessor = SourceProcessorFactory.createSourceProcessor(source, shards, maxThreadPerDataflowWorker); spannerDao = new SpannerDao(spannerConfig); + spannerToSourceTransformer = + CustomTransformationImplFetcher.getCustomTransformationLogicImpl(customTransformation); } /** Teardown function disconnects from the Cloud Spanner. */ @@ -184,13 +203,18 @@ public void processElement(ProcessContext c) { if (!isSourceAhead) { IDao sourceDao = sourceProcessor.getSourceDao(shardId); - InputRecordProcessor.processRecord( - spannerRec, - schema, - sourceDao, - shardId, - sourceDbTimezoneOffset, - sourceProcessor.getDmlGenerator()); + boolean isEventFiltered = + InputRecordProcessor.processRecord( + spannerRec, + schema, + sourceDao, + shardId, + sourceDbTimezoneOffset, + sourceProcessor.getDmlGenerator(), + spannerToSourceTransformer); + if (isEventFiltered) { + outputWithTag(c, Constants.FILTERED_TAG, Constants.FILTERED_TAG_MESSAGE, spannerRec); + } spannerDao.updateShadowTable( getShadowTableMutation( @@ -206,6 +230,9 @@ public void processElement(ProcessContext c) { } com.google.cloud.Timestamp timestamp = com.google.cloud.Timestamp.now(); c.output(Constants.SUCCESS_TAG, timestamp.toString()); + } catch (InvalidTransformationException ex) { + invalidTransformationException.inc(); + outputWithTag(c, Constants.PERMANENT_ERROR_TAG, ex.getMessage(), spannerRec); } catch (ChangeEventConvertorException ex) { outputWithTag(c, Constants.PERMANENT_ERROR_TAG, ex.getMessage(), spannerRec); } catch (SpannerException diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java index 
e80f617eeb..ef9ddbfaac 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterTransform.java @@ -19,6 +19,7 @@ import com.google.cloud.teleport.v2.spanner.ddl.Ddl; import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.constants.Constants; import com.google.common.base.Preconditions; @@ -52,6 +53,7 @@ public class SourceWriterTransform private final String skipDirName; private final int maxThreadPerDataflowWorker; private final String source; + private final CustomTransformation customTransformation; public SourceWriterTransform( List shards, @@ -62,7 +64,8 @@ public SourceWriterTransform( String shadowTablePrefix, String skipDirName, int maxThreadPerDataflowWorker, - String source) { + String source, + CustomTransformation customTransformation) { this.schema = schema; this.sourceDbTimezoneOffset = sourceDbTimezoneOffset; @@ -73,6 +76,7 @@ public SourceWriterTransform( this.skipDirName = skipDirName; this.maxThreadPerDataflowWorker = maxThreadPerDataflowWorker; this.source = source; + this.customTransformation = customTransformation; } @Override @@ -91,18 +95,21 @@ public SourceWriterTransform.Result expand( this.shadowTablePrefix, this.skipDirName, this.maxThreadPerDataflowWorker, - this.source)) + this.source, + this.customTransformation)) .withOutputTags( Constants.SUCCESS_TAG, TupleTagList.of(Constants.PERMANENT_ERROR_TAG) .and(Constants.RETRYABLE_ERROR_TAG) - .and(Constants.SKIPPED_TAG))); + .and(Constants.SKIPPED_TAG) + .and(Constants.FILTERED_TAG))); return Result.create( 
sourceWriteResults.get(Constants.SUCCESS_TAG), sourceWriteResults.get(Constants.PERMANENT_ERROR_TAG), sourceWriteResults.get(Constants.RETRYABLE_ERROR_TAG), - sourceWriteResults.get(Constants.SKIPPED_TAG)); + sourceWriteResults.get(Constants.SKIPPED_TAG), + sourceWriteResults.get(Constants.FILTERED_TAG)); } /** Container class for the results of this transform. */ @@ -113,13 +120,18 @@ private static Result create( PCollection successfulSourceWrites, PCollection permanentErrors, PCollection retryableErrors, - PCollection skippedSourceWrites) { + PCollection skippedSourceWrites, + PCollection filteredWrites) { Preconditions.checkNotNull(successfulSourceWrites); Preconditions.checkNotNull(permanentErrors); Preconditions.checkNotNull(retryableErrors); Preconditions.checkNotNull(skippedSourceWrites); return new AutoValue_SourceWriterTransform_Result( - successfulSourceWrites, permanentErrors, retryableErrors, skippedSourceWrites); + successfulSourceWrites, + permanentErrors, + retryableErrors, + skippedSourceWrites, + filteredWrites); } public abstract PCollection successfulSourceWrites(); @@ -130,6 +142,8 @@ private static Result create( public abstract PCollection skippedSourceWrites(); + public abstract PCollection filteredWrites(); + @Override public void finishSpecifyingOutput( String transformName, PInput input, PTransform transform) { @@ -151,7 +165,9 @@ public Map, PValue> expand() { Constants.RETRYABLE_ERROR_TAG, retryableErrors(), Constants.SKIPPED_TAG, - skippedSourceWrites()); + skippedSourceWrites(), + Constants.FILTERED_TAG, + filteredWrites()); } } } diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java index b01d139a6d..2d110e80ae 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java +++ 
b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java @@ -1088,6 +1088,33 @@ public void testSpannerColDefsNull() { assertTrue(sql.isEmpty()); } + @Test + public void customTransformationMatch() { + Schema schema = SessionFileReader.read("src/test/resources/customTransformation.json"); + String tableName = "Singers"; + String newValuesString = "{\"FirstName\":\"kk\",\"LastName\":\"ll\"}"; + JSONObject newValuesJson = new JSONObject(newValuesString); + String keyValueString = "{\"SingerId\":\"999\"}"; + JSONObject keyValuesJson = new JSONObject(keyValueString); + String modType = "INSERT"; + Map customTransformation = new HashMap<>(); + customTransformation.put("FullName", "\'kk ll\'"); + customTransformation.put("SingerId", "1"); + + MySQLDMLGenerator mySQLDMLGenerator = new MySQLDMLGenerator(); + DMLGeneratorResponse dmlGeneratorResponse = + mySQLDMLGenerator.getDMLStatement( + new DMLGeneratorRequest.Builder( + modType, tableName, newValuesJson, keyValuesJson, "+00:00") + .setSchema(schema) + .setCustomTransformationResponse(customTransformation) + .build()); + String sql = dmlGeneratorResponse.getDmlStatement(); + + assertTrue(sql.contains("`FullName` = 'kk ll'")); + assertTrue(sql.contains("VALUES (1,'kk ll')")); + } + public static Schema getSchemaObject() { Map syntheticPKeys = new HashMap(); Map srcSchema = new HashMap(); diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java index f2bac7247d..039429b0ce 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFnTest.java @@ -15,6 +15,7 @@ */ package 
com.google.cloud.teleport.v2.templates.transforms; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.contains; import static org.mockito.ArgumentMatchers.eq; @@ -29,6 +30,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.cloud.Timestamp; import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException; import com.google.cloud.teleport.v2.spanner.migrations.schema.ColumnPK; import com.google.cloud.teleport.v2.spanner.migrations.schema.NameAndCols; import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema; @@ -38,7 +40,10 @@ import com.google.cloud.teleport.v2.spanner.migrations.schema.SpannerTable; import com.google.cloud.teleport.v2.spanner.migrations.schema.SyntheticPKey; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader; +import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer; +import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse; import com.google.cloud.teleport.v2.templates.changestream.ChangeStreamErrorRecord; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.constants.Constants; @@ -63,6 +68,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runners.MethodSorters; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -76,6 +82,7 @@ public class SourceWriterFnTest { @Mock HashMap mockDaoMap; @Mock private SpannerConfig mockSpannerConfig; @Mock private DoFn.ProcessContext processContext; + @Mock private ISpannerMigrationTransformer 
mockSpannerMigrationTransformer; private static Gson gson = new Gson(); private Shard testShard; @@ -148,7 +155,8 @@ public void testSourceIsAhead() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -175,7 +183,8 @@ public void testSourceIsAheadWithSameCommitTimestamp() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -201,7 +210,8 @@ public void testSourceIsBehind() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -213,6 +223,118 @@ public void testSourceIsBehind() throws Exception { verify(mockSpannerDao, atLeast(1)).updateShadowTable(any()); } + @Test + public void testCustomTransformationException() throws Exception { + TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA"); + record.setShard("shardA"); + when(processContext.element()).thenReturn(KV.of(1L, record)); + when(mockSpannerMigrationTransformer.toSourceRow(any())) + .thenThrow(new InvalidTransformationException("some exception")); + CustomTransformation customTransformation = + CustomTransformation.builder("jarPath", "classPath").build(); + SourceWriterFn sourceWriterFn = + new SourceWriterFn( + ImmutableList.of(testShard), + testSchema, + mockSpannerConfig, + testSourceDbTimezoneOffset, + testDdl, + "shadow_", + "skip", + 500, + "mysql", + customTransformation); + ObjectMapper mapper = new ObjectMapper(); + mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + sourceWriterFn.setObjectMapper(mapper); + 
sourceWriterFn.setSourceProcessor(sourceProcessor); + sourceWriterFn.setSpannerDao(mockSpannerDao); + sourceWriterFn.setSpannerToSourceTransformer(mockSpannerMigrationTransformer); + sourceWriterFn.processElement(processContext); + verify(mockSpannerDao, atLeast(1)).getShadowTableRecord(any(), any()); + String jsonRec = gson.toJson(record, TrimmedShardedDataChangeRecord.class); + ChangeStreamErrorRecord errorRecord = + new ChangeStreamErrorRecord( + jsonRec, + "com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException: some exception"); + verify(processContext, atLeast(1)) + .output( + Constants.PERMANENT_ERROR_TAG, gson.toJson(errorRecord, ChangeStreamErrorRecord.class)); + } + + @Test + public void testCustomTransformationApplied() throws Exception { + TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA"); + record.setShard("shardA"); + when(processContext.element()).thenReturn(KV.of(1L, record)); + when(mockSpannerMigrationTransformer.toSourceRow(any())) + .thenReturn(new MigrationTransformationResponse(Map.of("id", "45"), false)); + CustomTransformation customTransformation = + CustomTransformation.builder("jarPath", "classPath").build(); + SourceWriterFn sourceWriterFn = + new SourceWriterFn( + ImmutableList.of(testShard), + testSchema, + mockSpannerConfig, + testSourceDbTimezoneOffset, + testDdl, + "shadow_", + "skip", + 500, + "mysql", + customTransformation); + ObjectMapper mapper = new ObjectMapper(); + mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + sourceWriterFn.setObjectMapper(mapper); + sourceWriterFn.setSourceProcessor(sourceProcessor); + sourceWriterFn.setSpannerDao(mockSpannerDao); + sourceWriterFn.setSpannerToSourceTransformer(mockSpannerMigrationTransformer); + sourceWriterFn.processElement(processContext); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); + verify(mockSpannerDao, atLeast(1)).getShadowTableRecord(any(), any()); + verify(mockSqlDao, 
atLeast(1)).write(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue().contains("INSERT INTO `parent1`(`id`) VALUES (45)")); + verify(mockSpannerDao, atLeast(1)).updateShadowTable(any()); + } + + @Test + public void testCustomTransformationFiltered() throws Exception { + TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA"); + record.setShard("shardA"); + when(processContext.element()).thenReturn(KV.of(1L, record)); + when(mockSpannerMigrationTransformer.toSourceRow(any())) + .thenReturn(new MigrationTransformationResponse(null, true)); + CustomTransformation customTransformation = + CustomTransformation.builder("jarPath", "classPath").build(); + SourceWriterFn sourceWriterFn = + new SourceWriterFn( + ImmutableList.of(testShard), + testSchema, + mockSpannerConfig, + testSourceDbTimezoneOffset, + testDdl, + "shadow_", + "skip", + 500, + "mysql", + customTransformation); + ObjectMapper mapper = new ObjectMapper(); + mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); + sourceWriterFn.setObjectMapper(mapper); + sourceWriterFn.setSourceProcessor(sourceProcessor); + sourceWriterFn.setSpannerDao(mockSpannerDao); + sourceWriterFn.setSpannerToSourceTransformer(mockSpannerMigrationTransformer); + sourceWriterFn.processElement(processContext); + verify(mockSpannerDao, atLeast(1)).getShadowTableRecord(any(), any()); + verify(mockSqlDao, org.mockito.Mockito.never()).write(any()); + verify(mockSpannerDao, atLeast(0)).updateShadowTable(any()); /* TODO: vacuous, pin the count */ + String jsonRec = gson.toJson(record, TrimmedShardedDataChangeRecord.class); + ChangeStreamErrorRecord errorRecord = + new ChangeStreamErrorRecord(jsonRec, Constants.FILTERED_TAG_MESSAGE); + verify(processContext, atLeast(1)) + .output(Constants.FILTERED_TAG, gson.toJson(errorRecord, ChangeStreamErrorRecord.class)); + } + @Test public void testNoShard() throws Exception { TrimmedShardedDataChangeRecord record = getParent1TrimmedDataChangeRecord("shardA"); @@ -227,7 +349,8 @@ public void 
testNoShard() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -257,7 +380,8 @@ public void testSkipShard() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -285,7 +409,8 @@ public void testPermanentError() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -317,7 +442,8 @@ public void testRetryableError() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -345,7 +471,8 @@ public void testRetryableErrorForForeignKey() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -375,7 +502,8 @@ public void testRetryableErrorConnectionFailure() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -405,7 +533,8 @@ public void testPermanentConnectionFailure() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -435,7 +564,8 @@ public void testPermanentGenericException() 
throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); @@ -464,7 +594,8 @@ public void testDMLEmpty() throws Exception { "shadow_", "skip", 500, - "mysql"); + "mysql", + null); ObjectMapper mapper = new ObjectMapper(); mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); sourceWriterFn.setObjectMapper(mapper); diff --git a/v2/spanner-to-sourcedb/src/test/resources/customTransformation.json b/v2/spanner-to-sourcedb/src/test/resources/customTransformation.json new file mode 100644 index 0000000000..320d184e7c --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/resources/customTransformation.json @@ -0,0 +1,169 @@ +{ + "SessionName": "NewSession", + "EditorName": "", + "DatabaseType": "mysql", + "DatabaseName": "ui_demo", + "Dialect": "google_standard_sql", + "Notes": null, + "Tags": null, + "SpSchema": { + "t1": { + "Name": "Singers", + "ColIds": [ + "c5", + "c6", + "c7" + ], + "ColDefs": { + "c5": { + "Name": "SingerId", + "T": { + "Name": "INT64", + "Len": 0, + "IsArray": false + }, + "NotNull": true, + "Comment": "From: SingerId int", + "Id": "c5" + }, + "c6": { + "Name": "FirstName", + "T": { + "Name": "STRING", + "Len": 1024, + "IsArray": false + }, + "NotNull": false, + "Comment": "From: FirstName varchar(1024)", + "Id": "c6" + }, + "c7": { + "Name": "LastName", + "T": { + "Name": "STRING", + "Len": 1024, + "IsArray": false + }, + "NotNull": false, + "Comment": "From: LastName varchar(1024)", + "Id": "c7" + } + }, + "PrimaryKeys": [ + { + "ColId": "c5", + "Desc": false, + "Order": 1 + } + ], + "ForeignKeys": null, + "Indexes": [ + { + "Name": "ind1", + "TableId": "t1", + "Unique": false, + "Keys": [ + { + "ColId": "c5", + "Desc": false, + "Order": 1 + } + ], + "Id": "i9", + "StoredColumnIds": null + } + ], + "ParentId": "", + "Comment": "Spanner schema for source table Singers", + 
"Id": "t1" + } + }, + "SyntheticPKeys": {}, + "SrcSchema": { + "t1": { + "Name": "Singers", + "Schema": "ui_demo", + "ColIds": [ + "c5", + "c6" + ], + "ColDefs": { + "c5": { + "Name": "SingerId", + "Type": { + "Name": "int", + "Mods": null, + "ArrayBounds": null + }, + "NotNull": true, + "Ignored": { + "Check": false, + "Identity": false, + "Default": false, + "Exclusion": false, + "ForeignKey": false, + "AutoIncrement": false + }, + "Id": "c5" + }, + "c6": { + "Name": "FullName", + "Type": { + "Name": "varchar", + "Mods": [ + 1024 + ], + "ArrayBounds": null + }, + "NotNull": false, + "Ignored": { + "Check": false, + "Identity": false, + "Default": false, + "Exclusion": false, + "ForeignKey": false, + "AutoIncrement": false + }, + "Id": "c6" + } + }, + "PrimaryKeys": [ + { + "ColId": "c5", + "Desc": false, + "Order": 1 + } + ], + "ForeignKeys": null, + "Indexes": [ + { + "Name": "ind1", + "Unique": false, + "Keys": [ + { + "ColId": "c5", + "Desc": false, + "Order": 1 + } + ], + "Id": "i9", + "StoredColumnIds": null + } + ], + "Id": "t1" + } + }, + "SchemaIssues": { + "t1": { + "c5": [ + 13, + 18 + ] + } + }, + "Location": {}, + "TimezoneOffset": "+00:00", + "SpDialect": "google_standard_sql", + "UniquePKey": {}, + "Rules": [] +} \ No newline at end of file