diff --git a/pkg/ccl/changefeedccl/BUILD.bazel b/pkg/ccl/changefeedccl/BUILD.bazel index 9bc54f08efaa..95b61c19b47b 100644 --- a/pkg/ccl/changefeedccl/BUILD.bazel +++ b/pkg/ccl/changefeedccl/BUILD.bazel @@ -288,6 +288,7 @@ go_test( "//pkg/server/telemetry", "//pkg/settings/cluster", "//pkg/spanconfig", + "//pkg/spanconfig/spanconfigjob", "//pkg/spanconfig/spanconfigptsreader", "//pkg/sql", "//pkg/sql/catalog", diff --git a/pkg/ccl/changefeedccl/cdctest/nemeses.go b/pkg/ccl/changefeedccl/cdctest/nemeses.go index 2d278f4c96bf..ba5f566c3a76 100644 --- a/pkg/ccl/changefeedccl/cdctest/nemeses.go +++ b/pkg/ccl/changefeedccl/cdctest/nemeses.go @@ -19,6 +19,53 @@ import ( "github.com/cockroachdb/errors" ) +type ChangefeedOption struct { + FullTableName bool + Format string + KeyInValue bool +} + +func newChangefeedOption(testName string) ChangefeedOption { + isCloudstorage := strings.Contains(testName, "cloudstorage") + isWebhook := strings.Contains(testName, "webhook") + cfo := ChangefeedOption{ + FullTableName: rand.Intn(2) < 1, + + // Because key_in_value is on by default for cloudstorage and webhook sinks, + // the key in the value is extracted and removed from the test feed + // messages (see extractKeyFromJSONValue function). + // TODO: (#138749) enable testing key_in_value for cloudstorage + // and webhook sinks + KeyInValue: !isCloudstorage && !isWebhook && rand.Intn(2) < 1, + Format: "json", + } + + if isCloudstorage && rand.Intn(2) < 1 { + cfo.Format = "parquet" + } + + return cfo +} + +func (co ChangefeedOption) String() string { + return fmt.Sprintf("full_table_name=%t,key_in_value=%t,format=%s", + co.FullTableName, co.KeyInValue, co.Format) +} + +func (cfo ChangefeedOption) OptionString() string { + options := "" + if cfo.Format == "parquet" { + options = ", format=parquet" + } + if cfo.FullTableName { + options = options + ", full_table_name" + } + if cfo.KeyInValue { + options = options + ", key_in_value" + } + return options +} + type NemesesOption struct { EnableFpValidator bool EnableSQLSmith bool @@ -36,7 +83,8 @@ var NemesesOptions = []NemesesOption{ } func (no NemesesOption) String() string { - return fmt.Sprintf("fp_validator=%t,sql_smith=%t", no.EnableFpValidator, no.EnableSQLSmith) + return fmt.Sprintf("fp_validator=%t,sql_smith=%t", + no.EnableFpValidator, no.EnableSQLSmith) } // RunNemesis runs a jepsen-style validation of whether a changefeed meets our @@ -50,8 +98,7 @@ func (no NemesesOption) String() string { func RunNemesis( f TestFeedFactory, db *gosql.DB, - isSinkless bool, - isCloudstorage bool, + testName string, withLegacySchemaChanger bool, rng *rand.Rand, nOp NemesesOption, @@ -69,6 +116,8 @@ func RunNemesis( ctx := context.Background() eventPauseCount := 10 + + isSinkless := strings.Contains(testName, "sinkless") if isSinkless { // Disable eventPause for sinkless changefeeds because we currently do not // have "correct" pause and unpause mechanisms for changefeeds that aren't @@ -199,11 +248,13 @@ func RunNemesis( } } - withFormatParquet := "" - if isCloudstorage && rand.Intn(2) < 1 { - withFormatParquet = ", format=parquet" - } - foo, err := f.Feed(fmt.Sprintf(`CREATE CHANGEFEED FOR foo WITH updated, resolved, diff %s`, withFormatParquet)) + cfo := newChangefeedOption(testName) + changefeedStatement := fmt.Sprintf( + `CREATE CHANGEFEED FOR foo WITH updated, resolved, diff%s`, + cfo.OptionString(), + ) + log.Infof(ctx, "Using changefeed options: %s", changefeedStatement) + foo, err := f.Feed(changefeedStatement) if err != nil { return nil, err } @@ -218,7 
+269,8 @@ func RunNemesis( if _, err := db.Exec(createFprintStmtBuf.String()); err != nil { return nil, err } - baV, err := NewBeforeAfterValidator(db, `foo`) + + baV, err := NewBeforeAfterValidator(db, `foo`, cfo) if err != nil { return nil, err } @@ -817,7 +869,7 @@ func noteFeedMessage(a fsm.Args) error { } ns.availableRows-- log.Infof(a.Ctx, "%s->%s", m.Key, m.Value) - return ns.v.NoteRow(m.Partition, string(m.Key), string(m.Value), ts) + return ns.v.NoteRow(m.Partition, string(m.Key), string(m.Value), ts, m.Topic) } } } diff --git a/pkg/ccl/changefeedccl/cdctest/validator.go b/pkg/ccl/changefeedccl/cdctest/validator.go index 1fe47f0afb5f..f7a2322a62c0 100644 --- a/pkg/ccl/changefeedccl/cdctest/validator.go +++ b/pkg/ccl/changefeedccl/cdctest/validator.go @@ -23,7 +23,7 @@ import ( // guarantees in a single table. type Validator interface { // NoteRow accepts a changed row entry. - NoteRow(partition string, key, value string, updated hlc.Timestamp) error + NoteRow(partition, key, value string, updated hlc.Timestamp, topic string) error // NoteResolved accepts a resolved timestamp entry. NoteResolved(partition string, resolved hlc.Timestamp) error // Failures returns any violations seen so far. @@ -64,7 +64,7 @@ var _ StreamValidator = &orderValidator{} type noOpValidator struct{} // NoteRow accepts a changed row entry. -func (v *noOpValidator) NoteRow(string, string, string, hlc.Timestamp) error { return nil } +func (v *noOpValidator) NoteRow(string, string, string, hlc.Timestamp, string) error { return nil } // NoteResolved accepts a resolved timestamp entry. func (v *noOpValidator) NoteResolved(string, hlc.Timestamp) error { return nil } @@ -125,7 +125,9 @@ func (v *orderValidator) GetValuesForKeyBelowTimestamp( } // NoteRow implements the Validator interface. -func (v *orderValidator) NoteRow(partition string, key, value string, updated hlc.Timestamp) error { +func (v *orderValidator) NoteRow( + partition, key, value string, updated hlc.Timestamp, topic string, +) error { if prev, ok := v.partitionForKey[key]; ok && prev != partition { v.failures = append(v.failures, fmt.Sprintf( `key [%s] received on two partitions: %s and %s`, key, prev, partition, @@ -189,6 +191,8 @@ type beforeAfterValidator struct { table string primaryKeyCols []string resolved map[string]hlc.Timestamp + fullTableName bool + keyInValue bool failures []string } @@ -196,7 +200,9 @@ type beforeAfterValidator struct { // NewBeforeAfterValidator returns a Validator verifies that the "before" and // "after" fields in each row agree with the source table when performing AS OF // SYSTEM TIME lookups before and at the row's timestamp. -func NewBeforeAfterValidator(sqlDB *gosql.DB, table string) (Validator, error) { +func NewBeforeAfterValidator( + sqlDB *gosql.DB, table string, option ChangefeedOption, +) (Validator, error) { primaryKeyCols, err := fetchPrimaryKeyCols(sqlDB, table) if err != nil { return nil, errors.Wrap(err, "fetchPrimaryKeyCols failed") @@ -205,6 +211,8 @@ func NewBeforeAfterValidator(sqlDB *gosql.DB, table string) (Validator, error) { return &beforeAfterValidator{ sqlDB: sqlDB, table: table, + fullTableName: option.FullTableName, + keyInValue: option.KeyInValue, primaryKeyCols: primaryKeyCols, resolved: make(map[string]hlc.Timestamp), }, nil @@ -212,8 +220,21 @@ func NewBeforeAfterValidator(sqlDB *gosql.DB, table string) (Validator, error) { // NoteRow implements the Validator interface. 
func (v *beforeAfterValidator) NoteRow( - partition string, key, value string, updated hlc.Timestamp, + partition, key, value string, updated hlc.Timestamp, topic string, ) error { + if v.fullTableName { + if topic != fmt.Sprintf(`d.public.%s`, v.table) { + v.failures = append(v.failures, fmt.Sprintf( + "topic %s does not match expected table d.public.%s", topic, v.table, + )) + } + } else { + if topic != v.table { + v.failures = append(v.failures, fmt.Sprintf( + "topic %s does not match expected table %s", topic, v.table, + )) + } + } keyJSON, err := json.ParseJSON(key) if err != nil { return err @@ -230,6 +251,26 @@ func (v *beforeAfterValidator) NoteRow( return err } + if v.keyInValue { + keyString := keyJSON.String() + keyInValueJSON, err := valueJSON.FetchValKey("key") + if err != nil { + return err + } + + if keyInValueJSON == nil { + v.failures = append(v.failures, fmt.Sprintf( + "no key in value, expected key value %s", keyString)) + } else { + keyInValueString := keyInValueJSON.String() + if keyInValueString != keyString { + v.failures = append(v.failures, fmt.Sprintf( + "key in value %s does not match expected key value %s", + keyInValueString, keyString)) + } + } + } + afterJSON, err := valueJSON.FetchValKey("after") if err != nil { return err @@ -451,7 +492,7 @@ func (v *FingerprintValidator) DBFunc( // NoteRow implements the Validator interface. func (v *FingerprintValidator) NoteRow( - ignoredPartition string, key, value string, updated hlc.Timestamp, + partition, key, value string, updated hlc.Timestamp, topic string, ) error { if v.firstRowTimestamp.IsEmpty() || updated.Less(v.firstRowTimestamp) { v.firstRowTimestamp = updated @@ -663,9 +704,11 @@ func (v *FingerprintValidator) Failures() []string { type Validators []Validator // NoteRow implements the Validator interface. -func (vs Validators) NoteRow(partition string, key, value string, updated hlc.Timestamp) error { +func (vs Validators) NoteRow( + partition, key, value string, updated hlc.Timestamp, topic string, +) error { for _, v := range vs { - if err := v.NoteRow(partition, key, value, updated); err != nil { + if err := v.NoteRow(partition, key, value, updated, topic); err != nil { return err } } @@ -707,10 +750,12 @@ func NewCountValidator(v Validator) *CountValidator { } // NoteRow implements the Validator interface. -func (v *CountValidator) NoteRow(partition string, key, value string, updated hlc.Timestamp) error { +func (v *CountValidator) NoteRow( + partition, key, value string, updated hlc.Timestamp, topic string, +) error { v.NumRows++ v.rowsSinceResolved++ - return v.v.NoteRow(partition, key, value, updated) + return v.v.NoteRow(partition, key, value, updated, topic) } // NoteResolved implements the Validator interface. diff --git a/pkg/ccl/changefeedccl/cdctest/validator_test.go b/pkg/ccl/changefeedccl/cdctest/validator_test.go index 922a50be3f86..ae8fa8ba85ad 100644 --- a/pkg/ccl/changefeedccl/cdctest/validator_test.go +++ b/pkg/ccl/changefeedccl/cdctest/validator_test.go @@ -24,9 +24,13 @@ func ts(i int64) hlc.Timestamp { return hlc.Timestamp{WallTime: i} } -func noteRow(t *testing.T, v Validator, partition, key, value string, updated hlc.Timestamp) { +func noteRow( + t *testing.T, v Validator, partition, key, value string, updated hlc.Timestamp, topic string, +) { t.Helper() - if err := v.NoteRow(partition, key, value, updated); err != nil { + // None of the validators in this file include assertions about the topic + // name, so it's ok to pass in an empty string for topic. 
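// Note that beforeAfterValidator (validator.go) does compare the topic against
// the watched table (`d.public.<table>` with full_table_name, the bare table
// name otherwise) and, with key_in_value, also expects the value JSON to embed
// a "key" field matching the message key; that is why the beforeAfterValidator
// subtests below pass `foo` as the topic. A rough sketch of a message that
// satisfies both checks, assuming db is a *gosql.DB backed by table foo in
// database d and updated is a timestamp captured from that cluster:
//
//	v, _ := NewBeforeAfterValidator(db, `foo`, ChangefeedOption{
//		FullTableName: true, KeyInValue: true, Format: "json",
//	})
//	_ = v.NoteRow(`p`, `[1]`, `{"key": [1], "after": {"k": 1, "v": 1}}`, updated, `d.public.foo`)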
+ if err := v.NoteRow(partition, key, value, updated, topic); err != nil { t.Fatal(err) } } @@ -57,23 +61,23 @@ func TestOrderValidator(t *testing.T) { }) t.Run(`dupe okay`, func(t *testing.T) { v := NewOrderValidator(`t1`) - noteRow(t, v, `p1`, `k1`, ignored, ts(1)) - noteRow(t, v, `p1`, `k1`, ignored, ts(2)) - noteRow(t, v, `p1`, `k1`, ignored, ts(1)) + noteRow(t, v, `p1`, `k1`, ignored, ts(1), `foo`) + noteRow(t, v, `p1`, `k1`, ignored, ts(2), `foo`) + noteRow(t, v, `p1`, `k1`, ignored, ts(1), `foo`) assertValidatorFailures(t, v) }) t.Run(`key on two partitions`, func(t *testing.T) { v := NewOrderValidator(`t1`) - noteRow(t, v, `p1`, `k1`, ignored, ts(2)) - noteRow(t, v, `p2`, `k1`, ignored, ts(1)) + noteRow(t, v, `p1`, `k1`, ignored, ts(2), `foo`) + noteRow(t, v, `p2`, `k1`, ignored, ts(1), `foo`) assertValidatorFailures(t, v, `key [k1] received on two partitions: p1 and p2`, ) }) t.Run(`new key with lower timestamp`, func(t *testing.T) { v := NewOrderValidator(`t1`) - noteRow(t, v, `p1`, `k1`, ignored, ts(2)) - noteRow(t, v, `p1`, `k1`, ignored, ts(1)) + noteRow(t, v, `p1`, `k1`, ignored, ts(2), `foo`) + noteRow(t, v, `p1`, `k1`, ignored, ts(1), `foo`) assertValidatorFailures(t, v, `topic t1 partition p1: saw new row timestamp 1.0000000000 after 2.0000000000 was seen`, ) @@ -82,12 +86,12 @@ func TestOrderValidator(t *testing.T) { v := NewOrderValidator(`t1`) noteResolved(t, v, `p2`, ts(3)) // Okay because p2 saw the resolved timestamp but p1 didn't. - noteRow(t, v, `p1`, `k1`, ignored, ts(1)) + noteRow(t, v, `p1`, `k1`, ignored, ts(1), `foo`) noteResolved(t, v, `p1`, ts(3)) // This one is not okay. - noteRow(t, v, `p1`, `k1`, ignored, ts(2)) + noteRow(t, v, `p1`, `k1`, ignored, ts(2), `foo`) // Still okay because we've seen it before. - noteRow(t, v, `p1`, `k1`, ignored, ts(1)) + noteRow(t, v, `p1`, `k1`, ignored, ts(1), `foo`) assertValidatorFailures(t, v, `topic t1 partition p1`+ `: saw new row timestamp 2.0000000000 after 3.0000000000 was resolved`, @@ -95,6 +99,12 @@ func TestOrderValidator(t *testing.T) { }) } +var standardChangefeedOptions = ChangefeedOption{ + FullTableName: false, + KeyInValue: false, + Format: "json", +} + func TestBeforeAfterValidator(t *testing.T) { defer leaktest.AfterTest(t)() @@ -130,97 +140,115 @@ func TestBeforeAfterValidator(t *testing.T) { } t.Run(`empty`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) + require.NoError(t, err) + assertValidatorFailures(t, v) + }) + t.Run(`fullTableName`, func(t *testing.T) { + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, ChangefeedOption{ + FullTableName: true, + KeyInValue: false, + Format: "json", + }) + require.NoError(t, err) + assertValidatorFailures(t, v) + }) + t.Run(`key_in_value`, func(t *testing.T) { + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, ChangefeedOption{ + FullTableName: false, + KeyInValue: true, + Format: "json", + }) require.NoError(t, err) assertValidatorFailures(t, v) }) t.Run(`during initial`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) // "before" is ignored if missing. 
- noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) // However, if provided, it is validated. - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}, "before": {"k":1,"v":1}}`, ts[2]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}, "before": {"k":1,"v":1}}`, ts[2], `foo`) assertValidatorFailures(t, v) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":1,"v":3}}`, ts[3]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":1,"v":3}}`, ts[3], `foo`) assertValidatorFailures(t, v, `"before" field did not agree with row at `+ts[3].Prev().AsOfSystemTime()+ `: SELECT count(*) = 1 FROM foo AS OF SYSTEM TIME '`+ts[3].Prev().AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 AND to_json(v)::TEXT = $2 [1 3]`) }) t.Run(`missing before`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "before" should have been provided. - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) assertValidatorFailures(t, v, `"before" field did not agree with row at `+ts[2].Prev().AsOfSystemTime()+ `: SELECT count(*) = 0 FROM foo AS OF SYSTEM TIME '`+ts[2].Prev().AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 [1]`) }) t.Run(`incorrect before`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "before" provided with wrong value. - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":5,"v":10}}`, ts[3]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":5,"v":10}}`, ts[3], `foo`) assertValidatorFailures(t, v, `"before" field did not agree with row at `+ts[3].Prev().AsOfSystemTime()+ `: SELECT count(*) = 1 FROM foo AS OF SYSTEM TIME '`+ts[3].Prev().AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 AND to_json(v)::TEXT = $2 [5 10]`) }) t.Run(`unnecessary before`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "before" provided but should not have been. - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}, "before": {"k":1,"v":1}}`, ts[1]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}, "before": {"k":1,"v":1}}`, ts[1], `foo`) assertValidatorFailures(t, v, `"before" field did not agree with row at `+ts[1].Prev().AsOfSystemTime()+ `: SELECT count(*) = 1 FROM foo AS OF SYSTEM TIME '`+ts[1].Prev().AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 AND to_json(v)::TEXT = $2 [1 1]`) }) t.Run(`missing after`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "after" should have been provided. 
- noteRow(t, v, `p`, `[1]`, `{"before": {"k":1,"v":1}}`, ts[2]) + noteRow(t, v, `p`, `[1]`, `{"before": {"k":1,"v":1}}`, ts[2], `foo`) assertValidatorFailures(t, v, `"after" field did not agree with row at `+ts[2].AsOfSystemTime()+ `: SELECT count(*) = 0 FROM foo AS OF SYSTEM TIME '`+ts[2].AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 [1]`) }) t.Run(`incorrect after`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "after" provided with wrong value. - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":5}, "before": {"k":1,"v":2}}`, ts[3]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":5}, "before": {"k":1,"v":2}}`, ts[3], `foo`) assertValidatorFailures(t, v, `"after" field did not agree with row at `+ts[3].AsOfSystemTime()+ `: SELECT count(*) = 1 FROM foo AS OF SYSTEM TIME '`+ts[3].AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 AND to_json(v)::TEXT = $2 [1 5]`) }) t.Run(`unnecessary after`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "after" provided but should not have been. - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":1,"v":3}}`, ts[4]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":1,"v":3}}`, ts[4], `foo`) assertValidatorFailures(t, v, `"after" field did not agree with row at `+ts[4].AsOfSystemTime()+ `: SELECT count(*) = 1 FROM foo AS OF SYSTEM TIME '`+ts[4].AsOfSystemTime()+ `' WHERE to_json(k)::TEXT = $1 AND to_json(v)::TEXT = $2 [1 3]`) }) t.Run(`incorrect before and after`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) // "before" and "after" both provided with wrong value. 
- noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":5}, "before": {"k":1,"v":4}}`, ts[3]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":5}, "before": {"k":1,"v":4}}`, ts[3], `foo`) assertValidatorFailures(t, v, `"after" field did not agree with row at `+ts[3].AsOfSystemTime()+ `: SELECT count(*) = 1 FROM foo AS OF SYSTEM TIME '`+ts[3].AsOfSystemTime()+ @@ -230,19 +258,19 @@ func TestBeforeAfterValidator(t *testing.T) { `' WHERE to_json(k)::TEXT = $1 AND to_json(v)::TEXT = $2 [1 4]`) }) t.Run(`correct`, func(t *testing.T) { - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) - noteRow(t, v, `p`, `[1]`, `{}`, ts[0]) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}, "before": null}`, ts[1]) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}, "before": {"k":1,"v":1}}`, ts[2]) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":1,"v":2}}`, ts[3]) - noteRow(t, v, `p`, `[1]`, `{ "before": {"k":1,"v":3}}`, ts[4]) - noteRow(t, v, `p`, `[1]`, `{"after": null, "before": {"k":1,"v":3}}`, ts[4]) - noteRow(t, v, `p`, `[2]`, `{}`, ts[1]) - noteRow(t, v, `p`, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2]) - noteRow(t, v, `p`, `[2]`, `{"after": {"k":2,"v":2}, "before": null}`, ts[2]) + noteRow(t, v, `p`, `[1]`, `{}`, ts[0], `foo`) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":1}, "before": null}`, ts[1], `foo`) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":2}, "before": {"k":1,"v":1}}`, ts[2], `foo`) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1,"v":3}, "before": {"k":1,"v":2}}`, ts[3], `foo`) + noteRow(t, v, `p`, `[1]`, `{ "before": {"k":1,"v":3}}`, ts[4], `foo`) + noteRow(t, v, `p`, `[1]`, `{"after": null, "before": {"k":1,"v":3}}`, ts[4], `foo`) + noteRow(t, v, `p`, `[2]`, `{}`, ts[1], `foo`) + noteRow(t, v, `p`, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2], `foo`) + noteRow(t, v, `p`, `[2]`, `{"after": {"k":2,"v":2}, "before": null}`, ts[2], `foo`) assertValidatorFailures(t, v) }) } @@ -269,10 +297,10 @@ func TestBeforeAfterValidatorForGeometry(t *testing.T) { t.Fatal(err) } } - v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`) + v, err := NewBeforeAfterValidator(sqlDBRaw, `foo`, standardChangefeedOptions) require.NoError(t, err) assertValidatorFailures(t, v) - noteRow(t, v, `p`, `[1]`, `{"after": {"k":1, "geom":{"coordinates": [1,2], "type": "Point"}}}`, ts[0]) + noteRow(t, v, `p`, `[1]`, `{"after": {"k":1, "geom":{"coordinates": [1,2], "type": "Point"}}}`, ts[0], `foo`) } func TestFingerprintValidator(t *testing.T) { @@ -326,7 +354,7 @@ func TestFingerprintValidator(t *testing.T) { sqlDB.Exec(t, createTableStmt(`wrong_data`)) v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `wrong_data`, []string{`p`}, testColumns) require.NoError(t, err) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":10}}`, ts[1]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":10}}`, ts[1], `foo`) noteResolved(t, v, `p`, ts[1]) assertValidatorFailures(t, v, `fingerprints did not match at `+ts[1].AsOfSystemTime()+ @@ -340,14 +368,14 @@ func TestFingerprintValidator(t *testing.T) { if err := v.NoteResolved(`p`, ts[0]); err != nil { t.Fatal(err) } - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) noteResolved(t, v, `p`, ts[1]) - noteRow(t, v, 
ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) - noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2], `foo`) noteResolved(t, v, `p`, ts[2]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3], `foo`) noteResolved(t, v, `p`, ts[3]) - noteRow(t, v, ignored, `[1]`, `{"after": null}`, ts[4]) + noteRow(t, v, ignored, `[1]`, `{"after": null}`, ts[4], `foo`) noteResolved(t, v, `p`, ts[4]) noteResolved(t, v, `p`, ts[5]) assertValidatorFailures(t, v) @@ -356,11 +384,11 @@ func TestFingerprintValidator(t *testing.T) { sqlDB.Exec(t, createTableStmt(`rows_unsorted`)) v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `rows_unsorted`, []string{`p`}, testColumns) require.NoError(t, err) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) - noteRow(t, v, ignored, `[1]`, `{"after": null}`, ts[4]) - noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3], `foo`) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) + noteRow(t, v, ignored, `[1]`, `{"after": null}`, ts[4], `foo`) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2], `foo`) noteResolved(t, v, `p`, ts[5]) assertValidatorFailures(t, v) }) @@ -371,9 +399,9 @@ func TestFingerprintValidator(t *testing.T) { noteResolved(t, v, `p`, ts[0]) // Intentionally missing {"k":1,"v":1} at ts[1]. // Insert a fake row since we don't fingerprint earlier than the first seen row. - noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2].Prev()) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) - noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2]) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2].Prev(), `foo`) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2], `foo`) noteResolved(t, v, `p`, ts[2].Prev()) assertValidatorFailures(t, v, `fingerprints did not match at `+ts[2].Prev().AsOfSystemTime()+ @@ -385,11 +413,11 @@ func TestFingerprintValidator(t *testing.T) { v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `missed_middle`, []string{`p`}, testColumns) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) // Intentionally missing {"k":1,"v":2} at ts[2]. 
- noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2]) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2], `foo`) noteResolved(t, v, `p`, ts[2]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3], `foo`) noteResolved(t, v, `p`, ts[3]) assertValidatorFailures(t, v, `fingerprints did not match at `+ts[2].AsOfSystemTime()+ @@ -403,9 +431,9 @@ func TestFingerprintValidator(t *testing.T) { v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `missed_end`, []string{`p`}, testColumns) require.NoError(t, err) noteResolved(t, v, `p`, ts[0]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) - noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[2], `foo`) // Intentionally missing {"k":1,"v":3} at ts[3]. noteResolved(t, v, `p`, ts[3]) assertValidatorFailures(t, v, @@ -417,8 +445,8 @@ func TestFingerprintValidator(t *testing.T) { sqlDB.Exec(t, createTableStmt(`initial_scan`)) v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `initial_scan`, []string{`p`}, testColumns) require.NoError(t, err) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3]) - noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[3]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":3}}`, ts[3], `foo`) + noteRow(t, v, ignored, `[2]`, `{"after": {"k":2,"v":2}}`, ts[3], `foo`) noteResolved(t, v, `p`, ts[3]) assertValidatorFailures(t, v) }) @@ -434,7 +462,7 @@ func TestFingerprintValidator(t *testing.T) { sqlDB.Exec(t, createTableStmt(`resolved_unsorted`)) v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `resolved_unsorted`, []string{`p`}, testColumns) require.NoError(t, err) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) noteResolved(t, v, `p`, ts[1]) noteResolved(t, v, `p`, ts[1]) noteResolved(t, v, `p`, ts[0]) @@ -444,8 +472,8 @@ func TestFingerprintValidator(t *testing.T) { sqlDB.Exec(t, createTableStmt(`two_partitions`)) v, err := NewFingerprintValidator(sqlDBRaw, `foo`, `two_partitions`, []string{`p0`, `p1`}, testColumns) require.NoError(t, err) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1]) - noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2]) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":1}}`, ts[1], `foo`) + noteRow(t, v, ignored, `[1]`, `{"after": {"k":1,"v":2}}`, ts[2], `foo`) // Intentionally missing {"k":2,"v":2}. 
noteResolved(t, v, `p0`, ts[2]) noteResolved(t, v, `p0`, ts[4]) @@ -478,7 +506,7 @@ func TestValidators(t *testing.T) { NewOrderValidator(`t2`), } noteResolved(t, v, `p1`, ts(2)) - noteRow(t, v, `p1`, `k1`, ignored, ts(1)) + noteRow(t, v, `p1`, `k1`, ignored, ts(1), `foo`) assertValidatorFailures(t, v, `topic t1 partition p1`+ `: saw new row timestamp 1.0000000000 after 2.0000000000 was resolved`, diff --git a/pkg/ccl/changefeedccl/changefeed_dist.go b/pkg/ccl/changefeedccl/changefeed_dist.go index d2db50babc54..ee3c2f092483 100644 --- a/pkg/ccl/changefeedccl/changefeed_dist.go +++ b/pkg/ccl/changefeedccl/changefeed_dist.go @@ -149,7 +149,7 @@ func fetchTableDescriptors( ) error { targetDescs = make([]catalog.TableDescriptor, 0, targets.NumUniqueTables()) if err := txn.KV().SetFixedTimestamp(ctx, ts); err != nil { - return err + return errors.Wrapf(err, "setting timestamp for table descriptor fetch") } // Note that all targets are currently guaranteed to have a Table ID // and lie within the primary index span. Deduplication is important @@ -157,7 +157,7 @@ func fetchTableDescriptors( return targets.EachTableID(func(id catid.DescID) error { tableDesc, err := descriptors.ByIDWithoutLeased(txn.KV()).WithoutNonPublic().Get().Table(ctx, id) if err != nil { - return err + return errors.Wrapf(err, "fetching table descriptor %d", id) } targetDescs = append(targetDescs, tableDesc) return nil diff --git a/pkg/ccl/changefeedccl/nemeses_test.go b/pkg/ccl/changefeedccl/nemeses_test.go index 4c2bd7975b6b..3386784cf7c7 100644 --- a/pkg/ccl/changefeedccl/nemeses_test.go +++ b/pkg/ccl/changefeedccl/nemeses_test.go @@ -8,7 +8,6 @@ package changefeedccl import ( "math" "regexp" - "strings" "testing" "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdctest" @@ -35,12 +34,8 @@ func TestChangefeedNemeses(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(s.DB) withLegacySchemaChanger := maybeDisableDeclarativeSchemaChangesForTest(t, sqlDB) - // TODO(dan): Ugly hack to disable `eventPause` in sinkless feeds. See comment in - // `RunNemesis` for details. - isSinkless := strings.Contains(t.Name(), "sinkless") - isCloudstorage := strings.Contains(t.Name(), "cloudstorage") - v, err := cdctest.RunNemesis(f, s.DB, isSinkless, isCloudstorage, withLegacySchemaChanger, rng, nop) + v, err := cdctest.RunNemesis(f, s.DB, t.Name(), withLegacySchemaChanger, rng, nop) if err != nil { t.Fatalf("%+v", err) } diff --git a/pkg/ccl/changefeedccl/protected_timestamps_test.go b/pkg/ccl/changefeedccl/protected_timestamps_test.go index 8d854322e6fc..1611348ba81f 100644 --- a/pkg/ccl/changefeedccl/protected_timestamps_test.go +++ b/pkg/ccl/changefeedccl/protected_timestamps_test.go @@ -26,7 +26,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptpb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/spanconfig" + "github.com/cockroachdb/cockroach/pkg/spanconfig/spanconfigjob" "github.com/cockroachdb/cockroach/pkg/spanconfig/spanconfigptsreader" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/catalog/bootstrap" @@ -454,16 +456,41 @@ func TestPTSRecordProtectsTargetsAndSystemTables(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - s, db, stopServer := startTestFullServer(t, feedTestOptions{}) + ctx := context.Background() + + // Useful for debugging. 
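// (SetVModule raises the verbose logging level for just the named files, for
// example reconciler.go at level 3, so the span config reconciliation and MVCC
// GC queue decisions exercised by this test get logged without turning up
// verbosity globally.)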
+ require.NoError(t, log.SetVModule("spanconfigstore=2,store=2,reconciler=3,mvcc_gc_queue=2,kvaccessor=2")) + + settings := cluster.MakeTestingClusterSettings() + spanconfigjob.ReconciliationJobCheckpointInterval.Override(ctx, &settings.SV, 1*time.Second) + + // Keep track of where the spanconfig reconciler is up to. + lastReconcilerCheckpoint := atomic.Value{} + lastReconcilerCheckpoint.Store(hlc.Timestamp{}) + s, db, stopServer := startTestFullServer(t, feedTestOptions{ + knobsFn: func(knobs *base.TestingKnobs) { + if knobs.SpanConfig == nil { + knobs.SpanConfig = &spanconfig.TestingKnobs{} + } + scKnobs := knobs.SpanConfig.(*spanconfig.TestingKnobs) + scKnobs.JobOnCheckpointInterceptor = func(lastCheckpoint hlc.Timestamp) error { + now := hlc.Timestamp{WallTime: time.Now().UnixNano()} + t.Logf("reconciler checkpoint %s (%s)", lastCheckpoint, now.GoTime().Sub(lastCheckpoint.GoTime())) + lastReconcilerCheckpoint.Store(lastCheckpoint) + return nil + } + scKnobs.SQLWatcherCheckpointNoopsEveryDurationOverride = 1 * time.Second + }, + settings: settings, + }) + defer stopServer() execCfg := s.ExecutorConfig().(sql.ExecutorConfig) sqlDB := sqlutils.MakeSQLRunner(db) - sqlDB.Exec(t, `ALTER DATABASE system CONFIGURE ZONE USING gc.ttlseconds = 1`) sqlDB.Exec(t, "CREATE TABLE foo (a INT, b STRING)") sqlDB.Exec(t, `CREATE USER test`) sqlDB.Exec(t, `GRANT admin TO test`) ts := s.Clock().Now() - ctx := context.Background() fooDescr := cdctest.GetHydratedTableDescriptor(t, s.ExecutorConfig(), "d", "foo") var targets changefeedbase.Targets @@ -471,12 +498,30 @@ func TestPTSRecordProtectsTargetsAndSystemTables(t *testing.T) { TableID: fooDescr.GetID(), }) + // We need to give our PTS record a legit job ID so the protected ts + // reconciler doesn't delete it, so start up a dummy changefeed job and use its id. + registry := s.JobRegistry().(*jobs.Registry) + dummyJobDone := make(chan struct{}) + defer close(dummyJobDone) + registry.TestingWrapResumerConstructor(jobspb.TypeChangefeed, + func(raw jobs.Resumer) jobs.Resumer { + return &fakeResumer{done: dummyJobDone} + }) + var jobID jobspb.JobID + sqlDB.QueryRow(t, `CREATE CHANGEFEED FOR TABLE foo INTO 'null://'`).Scan(&jobID) + waitForJobStatus(sqlDB, t, jobID, `running`) + // Lay protected timestamp record. - ptr := createProtectedTimestampRecord(ctx, s.Codec(), 42, targets, ts) + ptr := createProtectedTimestampRecord(ctx, s.Codec(), jobID, targets, ts) require.NoError(t, execCfg.InternalDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { return execCfg.ProtectedTimestampProvider.WithTxn(txn).Protect(ctx, ptr) })) + // Set GC TTL to a small value to make the tables GC'd. We need to set this + // *after* we set the PTS record so that we dont GC the tables before + // the PTS is applied/picked up. 
+ sqlDB.Exec(t, `ALTER DATABASE system CONFIGURE ZONE USING gc.ttlseconds = 1`) + // The following code was shameless stolen from // TestShowTenantFingerprintsProtectsTimestamp which almost // surely copied it from the 2-3 other tests that have @@ -512,7 +557,7 @@ func TestPTSRecordProtectsTargetsAndSystemTables(t *testing.T) { var rangeID int64 row.Scan(&rangeID) refreshPTSReaderCache(s.Clock().Now(), tableName, databaseName) - t.Logf("enqueuing range %d for mvccGC", rangeID) + t.Logf("enqueuing range %d (table %s.%s) for mvccGC", rangeID, tableName, databaseName) sqlDB.Exec(t, `SELECT crdb_internal.kv_enqueue_replica($1, 'mvccGC', true)`, rangeID) } @@ -526,7 +571,21 @@ func TestPTSRecordProtectsTargetsAndSystemTables(t *testing.T) { // Change the user's password to update the users table. sqlDB.Exec(t, `ALTER USER test WITH PASSWORD 'testpass'`) + // Sleep for enough time to pass the configured GC threshold (1 second). time.Sleep(2 * time.Second) + + // Wait for the spanconfigs to be reconciled. + now := hlc.Timestamp{WallTime: time.Now().UnixNano()} + t.Logf("waiting for spanconfigs to be reconciled") + testutils.SucceedsWithin(t, func() error { + lastCheckpoint := lastReconcilerCheckpoint.Load().(hlc.Timestamp) + if lastCheckpoint.Less(now) { + return errors.Errorf("last checkpoint %s is not less than now %s", lastCheckpoint, now) + } + t.Logf("last reconciler checkpoint ok at %s", lastCheckpoint) + return nil + }, 1*time.Minute) + // If you want to GC all system tables: // // tabs := systemschema.MakeSystemTables() @@ -535,6 +594,7 @@ func TestPTSRecordProtectsTargetsAndSystemTables(t *testing.T) { // gcTestTableRange("system", t.GetName()) // } // } + t.Logf("GC'ing system tables") gcTestTableRange("system", "descriptor") gcTestTableRange("system", "zones") gcTestTableRange("system", "comments") diff --git a/pkg/ccl/changefeedccl/validations_test.go b/pkg/ccl/changefeedccl/validations_test.go index ad19df9f81a1..c0fec8a03009 100644 --- a/pkg/ccl/changefeedccl/validations_test.go +++ b/pkg/ccl/changefeedccl/validations_test.go @@ -87,7 +87,7 @@ func TestCatchupScanOrdering(t *testing.T) { if err != nil { t.Fatal(err) } - err = v.NoteRow(m.Partition, string(m.Key), string(m.Value), updated) + err = v.NoteRow(m.Partition, string(m.Key), string(m.Value), updated, m.Topic) if err != nil { t.Fatal(err) } diff --git a/pkg/cmd/roachtest/tests/cdc.go b/pkg/cmd/roachtest/tests/cdc.go index 3cdde1b96141..7359b210bc87 100644 --- a/pkg/cmd/roachtest/tests/cdc.go +++ b/pkg/cmd/roachtest/tests/cdc.go @@ -898,7 +898,7 @@ func runCDCBank(ctx context.Context, t test.Test, c cluster.Cluster) { partitionStr := strconv.Itoa(int(m.Partition)) if len(m.Key) > 0 { - if err := v.NoteRow(partitionStr, string(m.Key), string(m.Value), updated); err != nil { + if err := v.NoteRow(partitionStr, string(m.Key), string(m.Value), updated, m.Topic); err != nil { return err } } else { @@ -926,7 +926,11 @@ func runCDCBank(ctx context.Context, t test.Test, c cluster.Cluster) { if err != nil { return errors.Wrap(err, "error creating validator") } - baV, err := cdctest.NewBeforeAfterValidator(db, `bank.bank`) + baV, err := cdctest.NewBeforeAfterValidator(db, `bank.bank`, cdctest.ChangefeedOption{ + FullTableName: false, + KeyInValue: false, + Format: "json", + }) if err != nil { return err } @@ -953,7 +957,7 @@ func runCDCBank(ctx context.Context, t test.Test, c cluster.Cluster) { partitionStr := strconv.Itoa(int(m.Partition)) if len(m.Key) > 0 { startTime := timeutil.Now() - if err := v.NoteRow(partitionStr, 
string(m.Key), string(m.Value), updated); err != nil { + if err := v.NoteRow(partitionStr, string(m.Key), string(m.Value), updated, m.Topic); err != nil { return err } timeSpentValidatingRows += timeutil.Since(startTime) @@ -3890,7 +3894,7 @@ func (c *topicConsumer) validateMessage(partition int32, m *sarama.ConsumerMessa return err } default: - err := c.validator.NoteRow(partitionStr, string(m.Key), string(m.Value), updated) + err := c.validator.NoteRow(partitionStr, string(m.Key), string(m.Value), updated, m.Topic) if err != nil { return err } diff --git a/pkg/cmd/roachtest/tests/mixed_version_cdc.go b/pkg/cmd/roachtest/tests/mixed_version_cdc.go index 0ff218bc0555..ee110a6619ed 100644 --- a/pkg/cmd/roachtest/tests/mixed_version_cdc.go +++ b/pkg/cmd/roachtest/tests/mixed_version_cdc.go @@ -301,7 +301,7 @@ func (cmvt *cdcMixedVersionTester) validate( partitionStr := strconv.Itoa(int(m.Partition)) if len(m.Key) > 0 { - if err := cmvt.validator.NoteRow(partitionStr, string(m.Key), string(m.Value), updated); err != nil { + if err := cmvt.validator.NoteRow(partitionStr, string(m.Key), string(m.Value), updated, m.Topic); err != nil { return err } } else { diff --git a/pkg/crosscluster/logical/BUILD.bazel b/pkg/crosscluster/logical/BUILD.bazel index e84738aa1213..60304f80292b 100644 --- a/pkg/crosscluster/logical/BUILD.bazel +++ b/pkg/crosscluster/logical/BUILD.bazel @@ -162,6 +162,7 @@ go_test( "//pkg/testutils/skip", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", + "//pkg/util", "//pkg/util/allstacks", "//pkg/util/hlc", "//pkg/util/leaktest", diff --git a/pkg/crosscluster/logical/dead_letter_queue_test.go b/pkg/crosscluster/logical/dead_letter_queue_test.go index b341bb30ab45..ac7192731408 100644 --- a/pkg/crosscluster/logical/dead_letter_queue_test.go +++ b/pkg/crosscluster/logical/dead_letter_queue_test.go @@ -15,6 +15,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdcevent" "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdctest" "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/changefeedbase" + "github.com/cockroachdb/cockroach/pkg/crosscluster/replicationtestutils" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/kv/kvserver" @@ -505,6 +506,7 @@ func testEndToEndDLQ(t *testing.T, mode string) { dbA.Exec(t, "SET CLUSTER SETTING logical_replication.consumer.retry_queue_duration = '100ms'") dbA.Exec(t, "SET CLUSTER SETTING logical_replication.consumer.retry_queue_backoff = '1ms'") + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) type testCase struct { tableName string diff --git a/pkg/crosscluster/logical/logical_replication_job.go b/pkg/crosscluster/logical/logical_replication_job.go index 467c8d4b985d..b72eff050b4a 100644 --- a/pkg/crosscluster/logical/logical_replication_job.go +++ b/pkg/crosscluster/logical/logical_replication_job.go @@ -190,13 +190,18 @@ func (r *logicalReplicationResumer) ingest( if err != nil { return err } - if err := r.job.NoTxn().Update(ctx, func(txn isql.Txn, md jobs.JobMetadata, ju *jobs.JobUpdater) error { - ldrProg := md.Progress.Details.(*jobspb.Progress_LogicalReplication).LogicalReplication - ldrProg.PartitionConnUris = planInfo.partitionPgUrls - ju.UpdateProgress(md.Progress) - return nil - }); err != nil { - return err + + // If the routing mode is gateway, we don't want to checkpoint addresses + // since they may not be in the same network. 
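// Gateway routing is requested by the user on the source URI rather than
// inferred, so a sketch of how a caller opts in (this mirrors the
// TestLogicalReplicationGatewayRoute test added in this change, where
// sourceURL stands in for the configured connection URL):
//
//	q := sourceURL.Query()
//	q.Set(streamclient.RoutingModeKey, string(streamclient.RoutingModeGateway))
//	sourceURL.RawQuery = q.Encode()
//
// With that mode set, the per-node partition URIs are not persisted into the
// job progress below, and the test asserts PartitionConnUris stays empty.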
+ if uris[0].RoutingMode() != streamclient.RoutingModeGateway { + if err := r.job.NoTxn().Update(ctx, func(txn isql.Txn, md jobs.JobMetadata, ju *jobs.JobUpdater) error { + ldrProg := md.Progress.Details.(*jobspb.Progress_LogicalReplication).LogicalReplication + ldrProg.PartitionConnUris = planInfo.partitionPgUrls + ju.UpdateProgress(md.Progress) + return nil + }); err != nil { + return err + } } // Update the local progress copy as it was just updated. progress = r.job.Progress().Details.(*jobspb.Progress_LogicalReplication).LogicalReplication diff --git a/pkg/crosscluster/logical/logical_replication_job_test.go b/pkg/crosscluster/logical/logical_replication_job_test.go index 4556a6ae9ff7..4f79b32e8425 100644 --- a/pkg/crosscluster/logical/logical_replication_job_test.go +++ b/pkg/crosscluster/logical/logical_replication_job_test.go @@ -10,6 +10,7 @@ import ( "context" gosql "database/sql" "fmt" + "net" "net/url" "slices" "strings" @@ -49,6 +50,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/allstacks" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/leaktest" @@ -159,8 +161,8 @@ func TestLogicalStreamIngestionJobNameResolution(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, testClusterBaseClusterArgs, 1) defer server.Stopper().Stop(ctx) - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -205,7 +207,6 @@ func TestLogicalStreamIngestionJob(t *testing.T) { t.Run(mode, func(t *testing.T) { testLogicalStreamIngestionJobBasic(t, mode) }) - } } @@ -270,18 +271,8 @@ func testLogicalStreamIngestionJobBasic(t *testing.T, mode string) { dbA.Exec(t, "INSERT INTO tab VALUES (1, 'hello')") dbB.Exec(t, "INSERT INTO tab VALUES (1, 'goodbye')") - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() - - // Swap one of the URLs to external:// to verify this indirection works. - // TODO(dt): this create should support placeholder for URI. - dbB.Exec(t, "CREATE EXTERNAL CONNECTION a AS '"+dbAURL.String()+"'") - dbAURL = url.URL{ - Scheme: "external", - Host: "a", - } + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var ( jobAID jobspb.JobID @@ -324,18 +315,8 @@ func TestLogicalStreamIngestionJobWithCursor(t *testing.T) { dbA.Exec(t, "INSERT INTO tab VALUES (1, 'hello')") dbB.Exec(t, "INSERT INTO tab VALUES (1, 'goodbye')") - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() - - // Swap one of the URLs to external:// to verify this indirection works. - // TODO(dt): this create should support placeholder for URI. 
- dbB.Exec(t, "CREATE EXTERNAL CONNECTION a AS '"+dbAURL.String()+"'") - dbAURL = url.URL{ - Scheme: "external", - Host: "a", - } + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var ( jobAID jobspb.JobID @@ -390,8 +371,7 @@ func TestCreateTables(t *testing.T) { defer tc.Stopper().Stop(ctx) sqlA := sqlDBs[0] - aURL, cleanup := srv.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + aURL := replicationtestutils.GetReplicationUri(t, srv, srv, serverutils.DBName("a")) t.Run("basic", func(t *testing.T) { // Ensure the offline scan replicates index spans. @@ -519,18 +499,8 @@ func TestLogicalStreamIngestionAdvancePTS(t *testing.T) { dbA.Exec(t, "INSERT INTO tab VALUES (1, 'hello')") dbB.Exec(t, "INSERT INTO tab VALUES (1, 'goodbye')") - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() - - // Swap one of the URLs to external:// to verify this indirection works. - // TODO(dt): this create should support placeholder for URI. - dbB.Exec(t, "CREATE EXTERNAL CONNECTION a AS '"+dbAURL.String()+"'") - dbAURL = url.URL{ - Scheme: "external", - Host: "a", - } + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var ( jobAID jobspb.JobID @@ -565,8 +535,7 @@ func TestLogicalStreamIngestionCancelUpdatesProducerJob(t *testing.T) { dbA.Exec(t, "SET CLUSTER SETTING physical_replication.producer.stream_liveness_track_frequency='50ms'") - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) var jobBID jobspb.JobID dbB.QueryRow(t, "CREATE LOGICAL REPLICATION STREAM FROM TABLE tab ON $1 INTO TABLE tab", dbAURL.String()).Scan(&jobBID) @@ -595,8 +564,7 @@ func TestRestoreFromLDR(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, args, 1) defer server.Stopper().Stop(ctx) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) var jobBID jobspb.JobID dbA.Exec(t, "INSERT INTO tab VALUES (1, 'hello')") @@ -626,8 +594,7 @@ func TestImportIntoLDR(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, args, 1) defer server.Stopper().Stop(ctx) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) var jobBID jobspb.JobID dbA.Exec(t, "INSERT INTO tab VALUES (1, 'hello')") @@ -651,8 +618,7 @@ func TestLogicalStreamIngestionErrors(t *testing.T) { server := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{}) defer server.Stopper().Stop(ctx) s := server.Server(0).ApplicationLayer() - url, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + url := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) urlA := url.String() _, err := server.Conns[0].Exec("CREATE DATABASE a") @@ -715,9 +681,8 @@ family f2(other_payload, v2)) serverASQL.Exec(t, "INSERT INTO tab_with_cf(pk, payload, other_payload) VALUES (1, 'hello', 'ruroh1')") - serverAURL, cleanup := s.PGUrl(t) + serverAURL := replicationtestutils.GetReplicationUri(t, s, s) serverAURL.Path = "a" - defer cleanup() var jobBID jobspb.JobID serverBSQL.QueryRow(t, "CREATE 
LOGICAL REPLICATION STREAM FROM TABLE tab_with_cf ON $1 INTO TABLE tab_with_cf WITH MODE = validated", serverAURL.String()).Scan(&jobBID) @@ -748,9 +713,7 @@ func TestLogicalReplicationWithPhantomDelete(t *testing.T) { tc, s, serverASQL, serverBSQL := setupLogicalTestServer(t, ctx, testClusterBaseClusterArgs, 1) defer tc.Stopper().Stop(ctx) - serverAURL, cleanup := s.PGUrl(t) - serverAURL.Path = "a" - defer cleanup() + serverAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) for _, mode := range []string{"validated", "immediate"} { t.Run(mode, func(t *testing.T) { @@ -789,10 +752,8 @@ func TestFilterRangefeedInReplicationStream(t *testing.T) { dbA, dbB, dbC := dbs[0], dbs[1], dbs[2] - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var jobAID, jobBID, jobCID jobspb.JobID @@ -857,8 +818,7 @@ func TestRandomTables(t *testing.T) { var tableName, streamStartStmt string rng, _ := randutil.NewPseudoRand() - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) // Keep retrying until the random table satisfies all the static checks // we make when creating the replication stream. @@ -1013,8 +973,7 @@ func TestPreviouslyInterestingTables(t *testing.T) { baseTableName := "rand_table" rng, _ := randutil.NewPseudoRand() numInserts := 20 - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) for i, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tableName := fmt.Sprintf("%s%d", baseTableName, i) @@ -1104,10 +1063,8 @@ func TestLogicalAutoReplan(t *testing.T) { serverutils.SetClusterSetting(t, server, "logical_replication.replan_flow_threshold", 0) serverutils.SetClusterSetting(t, server, "logical_replication.replan_flow_frequency", time.Millisecond*500) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var ( jobAID jobspb.JobID @@ -1170,8 +1127,7 @@ func TestLogicalJobResiliency(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, clusterArgs, 3) defer server.Stopper().Stop(ctx) - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) CreateScatteredTable(t, dbB, 2, "B") @@ -1222,10 +1178,8 @@ func TestHeartbeatCancel(t *testing.T) { serverutils.SetClusterSetting(t, server, "logical_replication.consumer.heartbeat_frequency", time.Second*1) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var ( jobAID jobspb.JobID @@ -1270,8 +1224,7 @@ func TestMultipleSourcesIntoSingleDest(t *testing.T) { server, s, runners, dbNames := 
setupServerWithNumDBs(t, ctx, clusterArgs, 1, 3) defer server.Stopper().Stop(ctx) - PGURLs, cleanup := GetPGURLs(t, s, dbNames) - defer cleanup() + PGURLs := GetPGURLs(t, s, dbNames) dbA, dbB, dbC := runners[0], runners[1], runners[2] @@ -1358,8 +1311,7 @@ func TestFourWayReplication(t *testing.T) { server, s, runners, dbNames := setupServerWithNumDBs(t, ctx, clusterArgs, 1, numDBs) defer server.Stopper().Stop(ctx) - PGURLs, cleanup := GetPGURLs(t, s, dbNames) - defer cleanup() + PGURLs := GetPGURLs(t, s, dbNames) // Each row is a DB, each column is a jobID from another DB to that target DB jobIDs := make([][]jobspb.JobID, numDBs) @@ -1415,8 +1367,7 @@ func TestForeignKeyConstraints(t *testing.T) { server, s, dbA, _ := setupLogicalTestServer(t, ctx, clusterArgs, 1) defer server.Stopper().Stop(ctx) - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) dbA.Exec(t, "CREATE TABLE test(a int primary key, b int)") @@ -1563,22 +1514,13 @@ func CreateScatteredTable(t *testing.T, db *sqlutils.SQLRunner, numNodes int, db }, timeout) } -func GetPGURLs( - t *testing.T, s serverutils.ApplicationLayerInterface, dbNames []string, -) ([]url.URL, func()) { +func GetPGURLs(t *testing.T, s serverutils.ApplicationLayerInterface, dbNames []string) []url.URL { result := []url.URL{} - cleanups := []func(){} for _, name := range dbNames { - resultURL, cleanup := s.PGUrl(t, serverutils.DBName(name)) + resultURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName(name)) result = append(result, resultURL) - cleanups = append(cleanups, cleanup) - } - - return result, func() { - for _, f := range cleanups { - f() - } } + return result } func WaitUntilReplicatedTime( @@ -1728,18 +1670,8 @@ func TestLogicalStreamIngestionJobWithFallbackUDF(t *testing.T) { dbB.Exec(t, lwwFunc) dbA.Exec(t, lwwFunc) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() - - // Swap one of the URLs to external:// to verify this indirection works. - // TODO(dt): this create should support placeholder for URI. 
- dbB.Exec(t, "CREATE EXTERNAL CONNECTION a AS '"+dbAURL.String()+"'") - dbAURL = url.URL{ - Scheme: "external", - Host: "a", - } + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var ( jobAID jobspb.JobID @@ -1857,15 +1789,8 @@ func TestShowLogicalReplicationJobs(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, testClusterBaseClusterArgs, 1) defer server.Stopper().Stop(ctx) - dbAURL, cleanup := s.PGUrl(t, - serverutils.DBName("a"), - serverutils.UserPassword(username.RootUser, "password")) - defer cleanup() - - dbBURL, cleanupB := s.PGUrl(t, - serverutils.DBName("b"), - serverutils.UserPassword(username.RootUser, "password")) - defer cleanupB() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a"), serverutils.UserPassword(username.RootUser, "password")) + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b"), serverutils.UserPassword(username.RootUser, "password")) var ( jobAID jobspb.JobID @@ -2005,8 +1930,7 @@ func TestUserPrivileges(t *testing.T) { server, s, dbA, _ := setupLogicalTestServer(t, ctx, clusterArgs, 1) defer server.Stopper().Stop(ctx) - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) // Create user with no privileges dbA.Exec(t, fmt.Sprintf("CREATE USER %s", username.TestUser)) @@ -2084,8 +2008,7 @@ func TestLogicalReplicationSchemaChanges(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, clusterArgs, 1) defer server.Stopper().Stop(ctx) - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) var jobAID jobspb.JobID dbA.QueryRow(t, "CREATE LOGICAL REPLICATION STREAM FROM TABLE tab ON $1 INTO TABLE tab", dbBURL.String()).Scan(&jobAID) @@ -2126,11 +2049,7 @@ func TestUserDefinedTypes(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, clusterArgs, 1) defer server.Stopper().Stop(ctx) - _, cleanupA := s.PGUrl(t, serverutils.DBName("a")) - defer cleanupA() - - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) // Create the same user-defined type both tables. dbA.Exec(t, "CREATE TYPE my_enum AS ENUM ('one', 'two', 'three')") @@ -2175,6 +2094,52 @@ func TestUserDefinedTypes(t *testing.T) { } } +func TestLogicalReplicationGatewayRoute(t *testing.T) { + defer leaktest.AfterTest(t)() + + // Create a blackhole so we can claim a port and black hole any connections + // routed there. + blackhole, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer func() { + require.NoError(t, blackhole.Close()) + }() + + t.Log("blackhole listening on", blackhole.Addr()) + // Set the SQL advertise addr to something unroutable so that we know the + // config connection url was used for all streams. 
+ args := testClusterBaseClusterArgs + args.ServerArgs.Knobs.Streaming = &sql.StreamingTestingKnobs{ + OnGetSQLInstanceInfo: func(node *roachpb.NodeDescriptor) *roachpb.NodeDescriptor { + copy := *node + copy.SQLAddress = util.UnresolvedAddr{ + NetworkField: "tcp", + AddressField: blackhole.Addr().String(), + } + return &copy + }, + } + ts, s, runners, dbs := setupServerWithNumDBs(t, context.Background(), args, 1, 2) + defer ts.Stopper().Stop(context.Background()) + + url, cleanup := s.PGUrl(t, serverutils.DBName(dbs[1])) + defer cleanup() + + q := url.Query() + q.Set(streamclient.RoutingModeKey, string(streamclient.RoutingModeGateway)) + url.RawQuery = q.Encode() + + var jobID jobspb.JobID + runners[0].QueryRow(t, "CREATE LOGICAL REPLICATION STREAM FROM TABLE tab ON $1 INTO TABLE tab", url.String()).Scan(&jobID) + runners[1].Exec(t, "INSERT INTO tab VALUES (1, 'hello')") + + now := s.Clock().Now() + WaitUntilReplicatedTime(t, now, runners[0], jobID) + + progress := jobutils.GetJobProgress(t, runners[0], jobID) + require.Empty(t, progress.Details.(*jobspb.Progress_LogicalReplication).LogicalReplication.PartitionConnUris) +} + // TestLogicalReplicationCreationChecks verifies that we check that the table // schemas are compatible when creating the replication stream. func TestLogicalReplicationCreationChecks(t *testing.T) { @@ -2195,8 +2160,7 @@ func TestLogicalReplicationCreationChecks(t *testing.T) { server, s, dbA, dbB := setupLogicalTestServer(t, ctx, clusterArgs, 1) defer server.Stopper().Stop(ctx) - dbBURL, cleanupB := s.PGUrl(t, serverutils.DBName("b")) - defer cleanupB() + dbBURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("b")) // Column families are not allowed. dbA.Exec(t, "ALTER TABLE tab ADD COLUMN new_col INT NOT NULL CREATE FAMILY f1") diff --git a/pkg/crosscluster/logical/udf_row_processor_test.go b/pkg/crosscluster/logical/udf_row_processor_test.go index 391552e69552..214b1497e657 100644 --- a/pkg/crosscluster/logical/udf_row_processor_test.go +++ b/pkg/crosscluster/logical/udf_row_processor_test.go @@ -85,8 +85,7 @@ func TestUDFWithRandomTables(t *testing.T) { sqlA, tableName, numInserts, nil) require.NoError(t, err) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) streamStartStmt := fmt.Sprintf("CREATE LOGICAL REPLICATION STREAM FROM TABLE %[1]s ON $1 INTO TABLE %[1]s WITH FUNCTION repl_apply FOR TABLE %[1]s", tableName) var jobBID jobspb.JobID @@ -127,8 +126,7 @@ func TestUDFInsertOnly(t *testing.T) { $$ LANGUAGE plpgsql `) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) streamStartStmt := fmt.Sprintf("CREATE LOGICAL REPLICATION STREAM FROM TABLE %[1]s ON $1 INTO TABLE %[1]s WITH DEFAULT FUNCTION = 'funcs.repl_apply'", tableName) var jobBID jobspb.JobID @@ -177,8 +175,7 @@ func TestUDFPreviousValue(t *testing.T) { $$ LANGUAGE plpgsql `) - dbAURL, cleanup := s.PGUrl(t, serverutils.DBName("a")) - defer cleanup() + dbAURL := replicationtestutils.GetReplicationUri(t, s, s, serverutils.DBName("a")) streamStartStmt := fmt.Sprintf("CREATE LOGICAL REPLICATION STREAM FROM TABLE %[1]s ON $1 INTO TABLE %[1]s WITH FUNCTION repl_apply FOR TABLE %[1]s", tableName) var jobBID jobspb.JobID diff --git a/pkg/crosscluster/physical/BUILD.bazel b/pkg/crosscluster/physical/BUILD.bazel index 35ec16478b67..c54f9119c611 100644 ---
a/pkg/crosscluster/physical/BUILD.bazel +++ b/pkg/crosscluster/physical/BUILD.bazel @@ -178,6 +178,7 @@ go_test( "//pkg/testutils/sqlutils", "//pkg/testutils/storageutils", "//pkg/testutils/testcluster", + "//pkg/util", "//pkg/util/ctxgroup", "//pkg/util/duration", "//pkg/util/hlc", diff --git a/pkg/crosscluster/physical/alter_replication_job_test.go b/pkg/crosscluster/physical/alter_replication_job_test.go index c1eb3b1d7045..c93df515c740 100644 --- a/pkg/crosscluster/physical/alter_replication_job_test.go +++ b/pkg/crosscluster/physical/alter_replication_job_test.go @@ -8,7 +8,6 @@ package physical import ( "context" "fmt" - "net/url" "testing" "time" @@ -586,8 +585,7 @@ func TestAlterTenantStartReplicationAfterRestore(t *testing.T) { enforcedGC.ts = afterBackup enforcedGC.Unlock() - u, cleanupURLA := sqlutils.PGUrl(t, srv.SQLAddr(), t.Name(), url.User(username.RootUser)) - defer cleanupURLA() + u := replicationtestutils.GetReplicationUri(t, srv, srv, serverutils.User(username.RootUser)) db.Exec(t, "RESTORE TENANT 3 FROM LATEST IN 'nodelocal://1/t' WITH TENANT = '5', TENANT_NAME = 't2'") db.Exec(t, "ALTER TENANT t2 START REPLICATION OF t1 ON $1", u.String()) diff --git a/pkg/crosscluster/physical/replication_random_client_test.go b/pkg/crosscluster/physical/replication_random_client_test.go index e8e2d9e1c114..4606c071dfb5 100644 --- a/pkg/crosscluster/physical/replication_random_client_test.go +++ b/pkg/crosscluster/physical/replication_random_client_test.go @@ -82,7 +82,7 @@ func (sv *streamClientValidator) noteRow( ) error { sv.mu.Lock() defer sv.mu.Unlock() - return sv.NoteRow(partition, key, value, updated) + return sv.NoteRow(partition, key, value, updated, "" /* topic */) } func (sv *streamClientValidator) noteResolved(partition string, resolved hlc.Timestamp) error { diff --git a/pkg/crosscluster/physical/replication_stream_e2e_test.go b/pkg/crosscluster/physical/replication_stream_e2e_test.go index 4594ae64d3c1..2f379c099439 100644 --- a/pkg/crosscluster/physical/replication_stream_e2e_test.go +++ b/pkg/crosscluster/physical/replication_stream_e2e_test.go @@ -679,6 +679,7 @@ func TestTenantStreamingMultipleNodes(t *testing.T) { testutils.RunTrueAndFalse(t, "fromSystem", func(t *testing.T, sys bool) { args := replicationtestutils.DefaultTenantStreamingClustersArgs args.MultitenantSingleClusterNumNodes = 3 + args.RoutingMode = streamclient.RoutingModeNode // Track the number of unique addresses that were connected to clientAddresses := make(map[string]struct{}) @@ -787,6 +788,7 @@ func TestStreamingAutoReplan(t *testing.T) { ctx := context.Background() args := replicationtestutils.DefaultTenantStreamingClustersArgs args.MultitenantSingleClusterNumNodes = 1 + args.RoutingMode = streamclient.RoutingModeNode retryErrorChan := make(chan error) turnOffReplanning := make(chan struct{}) @@ -802,7 +804,6 @@ func TestStreamingAutoReplan(t *testing.T) { clientAddresses[addr] = struct{}{} }, AfterRetryIteration: func(err error) { - if err != nil && !alreadyReplanned.Load() { retryErrorChan <- err <-turnOffReplanning diff --git a/pkg/crosscluster/physical/stream_ingestion_dist.go b/pkg/crosscluster/physical/stream_ingestion_dist.go index 37269066e602..853d5707b758 100644 --- a/pkg/crosscluster/physical/stream_ingestion_dist.go +++ b/pkg/crosscluster/physical/stream_ingestion_dist.go @@ -124,22 +124,25 @@ func startDistIngestion( return err } - err = ingestionJob.NoTxn().Update(ctx, func(txn isql.Txn, md jobs.JobMetadata, ju *jobs.JobUpdater) error { - // Persist the initial Stream Addresses 
to the jobs table before execution begins. - if len(planner.initialPartitionPgUrls) == 0 { - return jobs.MarkAsPermanentJobError(errors.AssertionFailedf( - "attempted to persist an empty list of partition connection uris")) - } - md.Progress.GetStreamIngest().PartitionConnUris = make([]string, len(planner.initialPartitionPgUrls)) - for i := range planner.initialPartitionPgUrls { - md.Progress.GetStreamIngest().PartitionConnUris[i] = planner.initialPartitionPgUrls[i].Serialize() + if planner.initialPartitionPgUrls[0].RoutingMode() != streamclient.RoutingModeGateway { + err = ingestionJob.NoTxn().Update(ctx, func(txn isql.Txn, md jobs.JobMetadata, ju *jobs.JobUpdater) error { + // Persist the initial Stream Addresses to the jobs table before execution begins. + if len(planner.initialPartitionPgUrls) == 0 { + return jobs.MarkAsPermanentJobError(errors.AssertionFailedf( + "attempted to persist an empty list of partition connection uris")) + } + md.Progress.GetStreamIngest().PartitionConnUris = make([]string, len(planner.initialPartitionPgUrls)) + for i := range planner.initialPartitionPgUrls { + md.Progress.GetStreamIngest().PartitionConnUris[i] = planner.initialPartitionPgUrls[i].Serialize() + } + ju.UpdateProgress(md.Progress) + return nil + }) + if err != nil { + return errors.Wrap(err, "failed to update job progress") } - ju.UpdateProgress(md.Progress) - return nil - }) - if err != nil { - return errors.Wrap(err, "failed to update job progress") } + jobsprofiler.StorePlanDiagram(ctx, execCtx.ExecCfg().DistSQLSrv.Stopper, planner.initialPlan, execCtx.ExecCfg().InternalDB, ingestionJob.ID()) diff --git a/pkg/crosscluster/physical/stream_ingestion_job_test.go b/pkg/crosscluster/physical/stream_ingestion_job_test.go index 7d54e52349f3..31ef1ded3d16 100644 --- a/pkg/crosscluster/physical/stream_ingestion_job_test.go +++ b/pkg/crosscluster/physical/stream_ingestion_job_test.go @@ -9,6 +9,7 @@ import ( "context" gosql "database/sql" "fmt" + "net" "net/url" "testing" @@ -17,6 +18,7 @@ import ( _ "github.com/cockroachdb/cockroach/pkg/crosscluster/producer" "github.com/cockroachdb/cockroach/pkg/crosscluster/replicationtestutils" "github.com/cockroachdb/cockroach/pkg/crosscluster/replicationutils" + "github.com/cockroachdb/cockroach/pkg/crosscluster/streamclient" "github.com/cockroachdb/cockroach/pkg/jobs" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/keys" @@ -31,6 +33,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/leaktest" @@ -129,26 +132,12 @@ func TestTenantStreamingFailback(t *testing.T) { sqlA := sqlutils.MakeSQLRunner(aDB) sqlB := sqlutils.MakeSQLRunner(bDB) - serverAURL, cleanupURLA := sqlutils.PGUrl(t, serverA.SQLAddr(), t.Name(), url.User(username.RootUser)) - defer cleanupURLA() - serverBURL, cleanupURLB := sqlutils.PGUrl(t, serverB.SQLAddr(), t.Name(), url.User(username.RootUser)) - defer cleanupURLB() - - for _, s := range []string{ - "SET CLUSTER SETTING kv.rangefeed.enabled = true", - "SET CLUSTER SETTING kv.rangefeed.closed_timestamp_refresh_interval = '200ms'", - "SET CLUSTER SETTING kv.closed_timestamp.target_duration = '100ms'", - "SET CLUSTER SETTING kv.closed_timestamp.side_transport_interval = '50ms'", - - "SET 
CLUSTER SETTING physical_replication.consumer.heartbeat_frequency = '1s'", - "SET CLUSTER SETTING physical_replication.consumer.job_checkpoint_frequency = '100ms'", - "SET CLUSTER SETTING physical_replication.consumer.minimum_flush_interval = '10ms'", - "SET CLUSTER SETTING physical_replication.consumer.failover_signal_poll_interval = '100ms'", - "SET CLUSTER SETTING spanconfig.reconciliation_job.checkpoint_interval = '100ms'", - } { - sqlA.Exec(t, s) - sqlB.Exec(t, s) - } + serverAURL := replicationtestutils.GetReplicationUri(t, serverA, serverB, serverutils.User(username.RootUser)) + serverBURL := replicationtestutils.GetReplicationUri(t, serverB, serverA, serverutils.User(username.RootUser)) + + replicationtestutils.ConfigureDefaultSettings(t, sqlA) + replicationtestutils.ConfigureDefaultSettings(t, sqlB) + compareAtTimetamp := func(ts string) { fingerprintQueryFmt := "SELECT fingerprint FROM [SHOW EXPERIMENTAL_FINGERPRINTS FROM TENANT %s] AS OF SYSTEM TIME %s" var fingerprintF int64 @@ -156,7 +145,6 @@ func TestTenantStreamingFailback(t *testing.T) { var fingerprintG int64 sqlB.QueryRow(t, fmt.Sprintf(fingerprintQueryFmt, "g", ts)).Scan(&fingerprintG) require.Equal(t, fingerprintF, fingerprintG, "fingerprint mismatch at %s", ts) - } // The overall test plan looks like: @@ -673,3 +661,59 @@ func waitUntilTenantServerStopped( return nil }) } + +func TestPhysicalReplicationGatewayRoute(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Create a blackhole so we can claim a port and black hole any connections + // routed there. + blackhole, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer func() { + require.NoError(t, blackhole.Close()) + }() + + srv, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + DefaultTestTenant: base.TestControlsTenantsExplicitly, + Knobs: base.TestingKnobs{ + Streaming: &sql.StreamingTestingKnobs{ + OnGetSQLInstanceInfo: func(node *roachpb.NodeDescriptor) *roachpb.NodeDescriptor { + copy := *node + copy.SQLAddress = util.UnresolvedAddr{ + NetworkField: "tcp", + AddressField: blackhole.Addr().String(), + } + return &copy + }, + }, + }, + }) + defer srv.Stopper().Stop(context.Background()) + + systemDB := sqlutils.MakeSQLRunner(db) + + replicationtestutils.ConfigureDefaultSettings(t, systemDB) + + // Create the source tenant and start service + systemDB.Exec(t, "CREATE VIRTUAL CLUSTER source") + systemDB.Exec(t, "ALTER VIRTUAL CLUSTER source START SERVICE SHARED") + + serverURL, cleanup := srv.PGUrl(t) + defer cleanup() + + q := serverURL.Query() + q.Set(streamclient.RoutingModeKey, string(streamclient.RoutingModeGateway)) + serverURL.RawQuery = q.Encode() + + // Create the destination tenant by replicating the source cluster + systemDB.Exec(t, "CREATE VIRTUAL CLUSTER target FROM REPLICATION OF source ON $1", serverURL.String()) + + _, jobID := replicationtestutils.GetStreamJobIds(t, context.Background(), systemDB, "target") + + now := srv.Clock().Now() + replicationtestutils.WaitUntilReplicatedTime(t, now, systemDB, jobspb.JobID(jobID)) + + progress := jobutils.GetJobProgress(t, systemDB, jobspb.JobID(jobID)) + require.Empty(t, progress.Details.(*jobspb.Progress_StreamIngest).StreamIngest.PartitionConnUris) +} diff --git a/pkg/crosscluster/physical/stream_ingestion_planning_test.go b/pkg/crosscluster/physical/stream_ingestion_planning_test.go index 98ca5a9a1609..6dfc18c9782c 100644 --- a/pkg/crosscluster/physical/stream_ingestion_planning_test.go +++
b/pkg/crosscluster/physical/stream_ingestion_planning_test.go @@ -7,10 +7,10 @@ package physical import ( "context" - "net/url" "testing" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/crosscluster/replicationtestutils" "github.com/cockroachdb/cockroach/pkg/jobs" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -43,8 +43,7 @@ func TestCreateTenantFromReplicationUsingID(t *testing.T) { sqlA := sqlutils.MakeSQLRunner(aDB) sqlB := sqlutils.MakeSQLRunner(bDB) - serverAURL, cleanupURLA := sqlutils.PGUrl(t, serverA.SQLAddr(), t.Name(), url.User(username.RootUser)) - defer cleanupURLA() + serverAURL := replicationtestutils.GetReplicationUri(t, serverA, serverB, serverutils.User(username.RootUser)) verifyCreatedTenant := func(t *testing.T, db *sqlutils.SQLRunner, id int64, fn func()) { const query = "SELECT count(*), count(CASE WHEN id = $1 THEN 1 END) FROM system.tenants" diff --git a/pkg/crosscluster/producer/replication_manager.go b/pkg/crosscluster/producer/replication_manager.go index 781e5e0d8789..02ca6ce7001a 100644 --- a/pkg/crosscluster/producer/replication_manager.go +++ b/pkg/crosscluster/producer/replication_manager.go @@ -43,6 +43,7 @@ type replicationStreamManagerImpl struct { resolver resolver.SchemaResolver txn descs.Txn sessionID clusterunique.ID + knobs *sql.StreamingTestingKnobs } // StartReplicationStream implements streaming.ReplicationStreamManager interface. @@ -215,7 +216,7 @@ func (r *replicationStreamManagerImpl) PlanLogicalReplication( } } - spec, err := buildReplicationStreamSpec(ctx, r.evalCtx, tenID, false, spans, useStreaksInLDR.Get(&r.evalCtx.Settings.SV)) + spec, err := r.buildReplicationStreamSpec(ctx, r.evalCtx, tenID, false, spans, useStreaksInLDR.Get(&r.evalCtx.Settings.SV)) if err != nil { return nil, err } @@ -271,7 +272,7 @@ func (r *replicationStreamManagerImpl) GetPhysicalReplicationStreamSpec( if err := r.checkLicense(); err != nil { return nil, err } - return getPhysicalReplicationStreamSpec(ctx, r.evalCtx, r.txn, streamID) + return r.getPhysicalReplicationStreamSpec(ctx, r.evalCtx, r.txn, streamID) } // CompleteReplicationStream implements ReplicationStreamManager interface. @@ -290,7 +291,7 @@ func (r *replicationStreamManagerImpl) SetupSpanConfigsStream( if err := r.checkLicense(); err != nil { return nil, err } - return setupSpanConfigsStream(ctx, r.evalCtx, r.txn, tenantName) + return r.setupSpanConfigsStream(ctx, r.evalCtx, r.txn, tenantName) } func (r *replicationStreamManagerImpl) DebugGetProducerStatuses( @@ -351,7 +352,11 @@ func newReplicationStreamManagerWithPrivilegesCheck( privilege.REPLICATION); err != nil { return nil, err } - return &replicationStreamManagerImpl{evalCtx: evalCtx, txn: txn, sessionID: sessionID, resolver: sc}, nil + + execCfg := evalCtx.Planner.ExecutorConfig().(*sql.ExecutorConfig) + knobs := execCfg.StreamingTestingKnobs + + return &replicationStreamManagerImpl{evalCtx: evalCtx, txn: txn, sessionID: sessionID, resolver: sc, knobs: knobs}, nil } func (r *replicationStreamManagerImpl) checkLicense() error { diff --git a/pkg/crosscluster/producer/stream_lifetime.go b/pkg/crosscluster/producer/stream_lifetime.go index 7378995adcbf..d42438b1a4f5 100644 --- a/pkg/crosscluster/producer/stream_lifetime.go +++ b/pkg/crosscluster/producer/stream_lifetime.go @@ -257,7 +257,7 @@ func heartbeatReplicationStream( } // getPhysicalReplicationStreamSpec gets a replication stream specification for the specified stream. 
-func getPhysicalReplicationStreamSpec( +func (r *replicationStreamManagerImpl) getPhysicalReplicationStreamSpec( ctx context.Context, evalCtx *eval.Context, txn isql.Txn, streamID streampb.StreamID, ) (*streampb.ReplicationStreamSpec, error) { jobExecCtx := evalCtx.JobExecContext.(sql.JobExecContext) @@ -274,11 +274,10 @@ func getPhysicalReplicationStreamSpec( if j.Status() != jobs.StatusRunning { return nil, jobIsNotRunningError(jobID, j.Status(), "create stream spec") } - return buildReplicationStreamSpec(ctx, evalCtx, details.TenantID, false, details.Spans, true) - + return r.buildReplicationStreamSpec(ctx, evalCtx, details.TenantID, false, details.Spans, true) } -func buildReplicationStreamSpec( +func (r *replicationStreamManagerImpl) buildReplicationStreamSpec( ctx context.Context, evalCtx *eval.Context, tenantID roachpb.TenantID, @@ -326,6 +325,9 @@ func buildReplicationStreamSpec( if err != nil { return nil, err } + if r.knobs != nil && r.knobs.OnGetSQLInstanceInfo != nil { + nodeInfo = r.knobs.OnGetSQLInstanceInfo(nodeInfo) + } res.Partitions = append(res.Partitions, streampb.ReplicationStreamSpec_Partition{ NodeID: roachpb.NodeID(sp.SQLInstanceID), SQLAddress: nodeInfo.SQLAddress, @@ -379,7 +381,7 @@ func completeReplicationStream( }) } -func setupSpanConfigsStream( +func (r *replicationStreamManagerImpl) setupSpanConfigsStream( ctx context.Context, evalCtx *eval.Context, txn isql.Txn, tenantName roachpb.TenantName, ) (eval.ValueGenerator, error) { @@ -392,8 +394,8 @@ func setupSpanConfigsStream( execConfig := evalCtx.Planner.ExecutorConfig().(*sql.ExecutorConfig) spanConfigName := systemschema.SpanConfigurationsTableName - if knobs := execConfig.StreamingTestingKnobs; knobs != nil && knobs.MockSpanConfigTableName != nil { - spanConfigName = knobs.MockSpanConfigTableName + if r.knobs != nil && r.knobs.MockSpanConfigTableName != nil { + spanConfigName = r.knobs.MockSpanConfigTableName } if err := sql.DescsTxn(ctx, execConfig, func(ctx context.Context, txn isql.Txn, col *descs.Collection) error { diff --git a/pkg/crosscluster/replicationtestutils/BUILD.bazel b/pkg/crosscluster/replicationtestutils/BUILD.bazel index d4a2fc9d0e0b..ce61d37605c4 100644 --- a/pkg/crosscluster/replicationtestutils/BUILD.bazel +++ b/pkg/crosscluster/replicationtestutils/BUILD.bazel @@ -9,6 +9,7 @@ go_library( "replication_helpers.go", "span_config_helpers.go", "testutils.go", + "uri_util.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/crosscluster/replicationtestutils", visibility = ["//visibility:public"], @@ -49,6 +50,7 @@ go_library( "//pkg/util", "//pkg/util/ctxgroup", "//pkg/util/hlc", + "//pkg/util/metamorphic", "//pkg/util/protoutil", "//pkg/util/randutil", "//pkg/util/retry", diff --git a/pkg/crosscluster/replicationtestutils/testutils.go b/pkg/crosscluster/replicationtestutils/testutils.go index e55f4ae4ff73..6b8d99e76df8 100644 --- a/pkg/crosscluster/replicationtestutils/testutils.go +++ b/pkg/crosscluster/replicationtestutils/testutils.go @@ -18,6 +18,7 @@ import ( apd "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/crosscluster/replicationutils" + "github.com/cockroachdb/cockroach/pkg/crosscluster/streamclient" "github.com/cockroachdb/cockroach/pkg/jobs" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" @@ -77,6 +78,8 @@ type TenantStreamingClustersArgs struct { NoMetamorphicExternalConnection bool ExternalIODir string + + RoutingMode streamclient.RoutingMode } var 
DefaultTenantStreamingClustersArgs = TenantStreamingClustersArgs{ @@ -93,11 +96,11 @@ var DefaultTenantStreamingClustersArgs = TenantStreamingClustersArgs{ `) }, SrcNumNodes: 1, - SrcClusterSettings: defaultSrcClusterSetting, + SrcClusterSettings: DefaultClusterSettings, DestTenantName: roachpb.TenantName("destination"), DestTenantID: roachpb.MustMakeTenantID(2), DestNumNodes: 1, - DestClusterSettings: defaultDestClusterSetting, + DestClusterSettings: DefaultClusterSettings, } type TenantStreamingClusters struct { @@ -451,6 +454,12 @@ func CreateMultiTenantStreamingCluster( cluster, url, cleanup := startC2CTestCluster(ctx, t, serverArgs, args.MultitenantSingleClusterNumNodes, args.MultiTenantSingleClusterTestRegions) + if args.RoutingMode != "" { + query := url.Query() + query.Set(streamclient.RoutingModeKey, string(args.RoutingMode)) + url.RawQuery = query.Encode() + } + rng, _ := randutil.NewPseudoRand() destNodeIdx := args.MultitenantSingleClusterNumNodes - 1 @@ -495,6 +504,11 @@ func CreateTenantStreamingClusters( g.GoCtx(func(ctx context.Context) error { // Start the source cluster. srcCluster, srcURL, srcCleanup = startC2CTestCluster(ctx, t, serverArgs, args.SrcNumNodes, args.SrcClusterTestRegions) + if args.RoutingMode != "" { + query := srcURL.Query() + query.Set(streamclient.RoutingModeKey, string(args.RoutingMode)) + srcURL.RawQuery = query.Encode() + } return nil }) @@ -626,33 +640,32 @@ func CreateScatteredTable(t *testing.T, c *TenantStreamingClusters, numNodes int }, timeout) } -var defaultSrcClusterSetting = map[string]string{ - `kv.rangefeed.enabled`: `true`, +var DefaultClusterSettings = map[string]string{ + `bulkio.stream_ingestion.failover_signal_poll_interval`: `'100ms'`, + `bulkio.stream_ingestion.minimum_flush_interval`: `'10ms'`, + `jobs.registry.interval.adopt`: `'1s'`, + `kv.bulk_io_write.small_write_size`: `'1'`, + `kv.closed_timestamp.side_transport_interval`: `'50ms'`, // Speed up the rangefeed. These were set by squinting at the settings set in // the changefeed integration tests. `kv.closed_timestamp.target_duration`: `'100ms'`, `kv.rangefeed.closed_timestamp_refresh_interval`: `'200ms'`, - `kv.closed_timestamp.side_transport_interval`: `'50ms'`, - // Large timeout makes test to not fail with unexpected timeout failures. - `stream_replication.stream_liveness_track_frequency`: `'2s'`, - `stream_replication.min_checkpoint_frequency`: `'1s'`, + `kv.rangefeed.enabled`: `true`, // Finer grain checkpoints to keep replicated time close to present. `physical_replication.producer.timestamp_granularity`: `'100ms'`, - // Make all AddSSTable operation to trigger AddSSTable events. - `kv.bulk_io_write.small_write_size`: `'1'`, - `jobs.registry.interval.adopt`: `'1s'`, // Speed up span reconciliation `spanconfig.reconciliation_job.checkpoint_interval`: `'100ms'`, + `stream_replication.consumer_heartbeat_frequency`: `'1s'`, + `stream_replication.job_checkpoint_frequency`: `'100ms'`, + `stream_replication.min_checkpoint_frequency`: `'1s'`, + // Large timeout makes test to not fail with unexpected timeout failures. 
+ `stream_replication.stream_liveness_track_frequency`: `'2s'`, } -var defaultDestClusterSetting = map[string]string{ - `stream_replication.consumer_heartbeat_frequency`: `'1s'`, - `stream_replication.job_checkpoint_frequency`: `'100ms'`, - `bulkio.stream_ingestion.minimum_flush_interval`: `'10ms'`, - `bulkio.stream_ingestion.failover_signal_poll_interval`: `'100ms'`, - `jobs.registry.interval.adopt`: `'1s'`, - `spanconfig.reconciliation_job.checkpoint_interval`: `'100ms'`, - `kv.rangefeed.enabled`: `true`, +func ConfigureDefaultSettings(t *testing.T, sqlRunner *sqlutils.SQLRunner) { + for key, val := range DefaultClusterSettings { + sqlRunner.Exec(t, fmt.Sprintf("SET CLUSTER SETTING %s = %s;", key, val)) + } } func ConfigureClusterSettings(setting map[string]string) []string { diff --git a/pkg/crosscluster/replicationtestutils/uri_util.go b/pkg/crosscluster/replicationtestutils/uri_util.go new file mode 100644 index 000000000000..bc5e3ad1f7ca --- /dev/null +++ b/pkg/crosscluster/replicationtestutils/uri_util.go @@ -0,0 +1,50 @@ +package replicationtestutils + +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +import ( + "fmt" + "math/rand" + "net/url" + "testing" + + "github.com/cockroachdb/cockroach/pkg/crosscluster/streamclient" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/metamorphic" + "github.com/stretchr/testify/require" +) + +// TODO make the sql address unconnectable if useGatewayRoutingMode is true +var useGatewayRoutingMode = metamorphic.ConstantWithTestBool("stream-use-gateway-routing-mode", false) +var useExternalConnection = metamorphic.ConstantWithTestBool("stream-use-external-connection", true) + +func GetReplicationUri( + t *testing.T, + sourceCluster serverutils.ApplicationLayerInterface, + destCluster serverutils.ApplicationLayerInterface, + sourceConnOptions ...serverutils.SQLConnOption, +) url.URL { + sourceURI, cleanup := sourceCluster.PGUrl(t, sourceConnOptions...) 
+ t.Cleanup(cleanup) + + if useGatewayRoutingMode { + query := sourceURI.Query() + query.Set(streamclient.RoutingModeKey, string(streamclient.RoutingModeGateway)) + sourceURI.RawQuery = query.Encode() + } + + if useExternalConnection { + conn := destCluster.SQLConn(t) + defer conn.Close() + + externalUri := url.URL{Scheme: "external", Host: fmt.Sprintf("replication-uri-%d", rand.Int63())} + _, err := conn.Exec(fmt.Sprintf("CREATE EXTERNAL CONNECTION '%s' AS '%s'", externalUri.Host, sourceURI.String())) + require.NoError(t, err) + return externalUri + } + + return sourceURI +} diff --git a/pkg/crosscluster/streamclient/partitioned_stream_client.go b/pkg/crosscluster/streamclient/partitioned_stream_client.go index a52c8b0bb707..bd9db8ed07ff 100644 --- a/pkg/crosscluster/streamclient/partitioned_stream_client.go +++ b/pkg/crosscluster/streamclient/partitioned_stream_client.go @@ -182,10 +182,17 @@ func (p *partitionedStreamClient) createTopology( SourceTenantID: spec.SourceTenantID, } for _, sp := range spec.Partitions { - nodeUri, err := p.clusterUri.ResolveNode(sp.SQLAddress) - if err != nil { - return Topology{}, err + var connUri ClusterUri + if p.clusterUri.RoutingMode() == RoutingModeGateway { + connUri = p.clusterUri + } else { + var err error + connUri, err = MakeClusterUriForNode(p.clusterUri, sp.SQLAddress) + if err != nil { + return Topology{}, err + } } + rawSpec, err := protoutil.Marshal(sp.SourcePartition) if err != nil { return Topology{}, err @@ -194,7 +201,7 @@ func (p *partitionedStreamClient) createTopology( ID: sp.NodeID.String(), SubscriptionToken: SubscriptionToken(rawSpec), SrcInstanceID: int(sp.NodeID), - ConnUri: nodeUri, + ConnUri: connUri, SrcLocality: sp.Locality, Spans: sp.SourcePartition.Spans, }) diff --git a/pkg/crosscluster/streamclient/uri.go b/pkg/crosscluster/streamclient/uri.go index a93f3c1d2833..7f4cc5430beb 100644 --- a/pkg/crosscluster/streamclient/uri.go +++ b/pkg/crosscluster/streamclient/uri.go @@ -89,6 +89,12 @@ func ParseClusterUri(uri string) (ClusterUri, error) { if !allowedConfigUriSchemes[url.Scheme] { return ClusterUri{}, errors.Newf("stream replication from scheme %q is unsupported", url.Scheme) } + if url.Query().Has(RoutingModeKey) { + mode := url.Query().Get(RoutingModeKey) + if mode != string(RoutingModeNode) && mode != string(RoutingModeGateway) { + return ClusterUri{}, errors.Newf("unknown crdb_route value %q", mode) + } + } return ClusterUri{uri: *url}, nil } @@ -105,14 +111,21 @@ func MakeTestClusterUri(url url.URL) ClusterUri { return ClusterUri{uri: url} } -func (sa *ClusterUri) ResolveNode(hostname util.UnresolvedAddr) (ClusterUri, error) { - host, port, err := net.SplitHostPort(hostname.AddressField) +// MakeClusterUriForNode creates a new ClusterUri with the node address set to the given +// address. MakeClusterUriForNode will return an error if the uri has routing mode +// gateway. 
+func MakeClusterUriForNode(uri ClusterUri, nodeAddress util.UnresolvedAddr) (ClusterUri, error) { + if uri.RoutingMode() == RoutingModeGateway { + return ClusterUri{}, errors.Newf("cannot set node address on gateway uri %s", uri.Redacted()) + } + + host, port, err := net.SplitHostPort(nodeAddress.AddressField) if err != nil { return ClusterUri{}, err } - copy := sa.uri - copy.Host = net.JoinHostPort(host, port) - return ClusterUri{uri: copy}, nil + copy := uri + copy.uri.Host = net.JoinHostPort(host, port) + return copy, nil } func (sa *ClusterUri) Serialize() string { @@ -137,3 +150,32 @@ func redactUrl(u url.URL) string { u.RawQuery = "redacted" return u.String() } + +const RoutingModeKey = "crdb_route" + +type RoutingMode string + +const ( + // RoutingModeNode is the default routing mode for LDR and PCR. The + // configuration uri is used to connect to the cluster and build a dist sql + // plan for the stream producers. The processors in the destination cluster + // then connect directly to the nodes described by the source cluster's plan. + RoutingModeNode RoutingMode = "node" + // RoutingModeGateway is a routing mode that replaces the default node + // routing mode. Processors in the destination cluster will connect to the + // configured uri instead of the per-node uris returned by the source + // cluster's plan. This allows for LDR and PCR to be used in situations where + // the source cluster nodes are not directly routable from the destination + // nodes. + RoutingModeGateway RoutingMode = "gateway" +) + +// RoutingMode returns the routing mode specified in the uri. If no routing +// mode is specified, the default routing mode is returned. The routing mode is +// validated by the ClusterUri constructor. +func (c *ClusterUri) RoutingMode() RoutingMode { + if key := c.uri.Query().Get(RoutingModeKey); key != "" { + return RoutingMode(key) + } + return RoutingModeNode +} diff --git a/pkg/crosscluster/streamclient/uri_test.go b/pkg/crosscluster/streamclient/uri_test.go index 36e330fc2ca6..d76139c6a825 100644 --- a/pkg/crosscluster/streamclient/uri_test.go +++ b/pkg/crosscluster/streamclient/uri_test.go @@ -50,6 +50,9 @@ func TestParseClusterUri(t *testing.T) { tests := []testCase{ {uri: "postgres://foo", err: ""}, {uri: "postgresql://foo", err: ""}, + {uri: "postgresql://foo?crdb_route=node", err: ""}, + {uri: "postgresql://foo?crdb_route=gateway", err: ""}, + {uri: "postgresql://foo?crdb_route=ohhno", err: "unknown crdb_route value \"ohhno\""}, {uri: "randomgen://foo", err: ""}, {uri: "external://foo", err: "external uri \"external://foo\" must be resolved before constructing a cluster uri"}, {uri: "ohhno://foo", err: "stream replication from scheme \"ohhno\" is unsupported"}, diff --git a/pkg/kv/kvserver/replicate_queue_test.go b/pkg/kv/kvserver/replicate_queue_test.go index 5b221d7cd463..893466b802ee 100644 --- a/pkg/kv/kvserver/replicate_queue_test.go +++ b/pkg/kv/kvserver/replicate_queue_test.go @@ -2140,7 +2140,8 @@ func iterateOverAllStores( // the range log where the added replica type is a LEARNER.
func TestPromoteNonVoterInAddVoter(t *testing.T) { defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) + scope := log.Scope(t) + defer scope.Close(t) // This test is slow under stress/race and can time out when upreplicating / // rebalancing to ensure all stores have the same range count initially, due @@ -2149,6 +2150,8 @@ func TestPromoteNonVoterInAddVoter(t *testing.T) { skip.UnderDeadlock(t) skip.UnderRace(t) + defer testutils.StartExecTrace(t, scope.GetDirectory()).Finish(t) + ctx := context.Background() // Create 7 stores: 3 in Region 1, 2 in Region 2, and 2 in Region 3. diff --git a/pkg/spanconfig/spanconfigjob/job.go b/pkg/spanconfig/spanconfigjob/job.go index 7d48cbff1f20..90d642f97b31 100644 --- a/pkg/spanconfig/spanconfigjob/job.go +++ b/pkg/spanconfig/spanconfigjob/job.go @@ -30,7 +30,7 @@ type resumer struct { var _ jobs.Resumer = (*resumer)(nil) -var reconciliationJobCheckpointInterval = settings.RegisterDurationSetting( +var ReconciliationJobCheckpointInterval = settings.RegisterDurationSetting( settings.ApplicationLevel, "spanconfig.reconciliation_job.checkpoint_interval", "the frequency at which the span config reconciliation job checkpoints itself", @@ -104,17 +104,17 @@ func (r *resumer) Resume(ctx context.Context, execCtxI interface{}) (jobErr erro syncutil.Mutex util.EveryN }{} - persistCheckpointsMu.EveryN = util.Every(reconciliationJobCheckpointInterval.Get(settingValues)) + persistCheckpointsMu.EveryN = util.Every(ReconciliationJobCheckpointInterval.Get(settingValues)) - reconciliationJobCheckpointInterval.SetOnChange(settingValues, func(ctx context.Context) { + ReconciliationJobCheckpointInterval.SetOnChange(settingValues, func(ctx context.Context) { persistCheckpointsMu.Lock() defer persistCheckpointsMu.Unlock() - persistCheckpointsMu.EveryN = util.Every(reconciliationJobCheckpointInterval.Get(settingValues)) + persistCheckpointsMu.EveryN = util.Every(ReconciliationJobCheckpointInterval.Get(settingValues)) }) checkpointingDisabled := false shouldSkipRetry := false - var onCheckpointInterceptor func() error + var onCheckpointInterceptor func(lastCheckpoint hlc.Timestamp) error retryOpts := retry.Options{ InitialBackoff: 5 * time.Second, @@ -140,7 +140,7 @@ func (r *resumer) Resume(ctx context.Context, execCtxI interface{}) (jobErr erro started := timeutil.Now() if err := rc.Reconcile(ctx, lastCheckpoint, r.job.Session(), func() error { if onCheckpointInterceptor != nil { - if err := onCheckpointInterceptor(); err != nil { + if err := onCheckpointInterceptor(lastCheckpoint); err != nil { return err } } diff --git a/pkg/spanconfig/spanconfigkvaccessor/kvaccessor.go b/pkg/spanconfig/spanconfigkvaccessor/kvaccessor.go index ca0b5a22a7c8..fc9e87f1f7ed 100644 --- a/pkg/spanconfig/spanconfigkvaccessor/kvaccessor.go +++ b/pkg/spanconfig/spanconfigkvaccessor/kvaccessor.go @@ -163,6 +163,8 @@ func (k *KVAccessor) UpdateSpanConfigRecords( toUpsert []spanconfig.Record, minCommitTS, maxCommitTS hlc.Timestamp, ) error { + log.VInfof(ctx, 2, "kv accessor updating span configs: toDelete=%+v, toUpsert=%+v, minCommitTS=%s, maxCommitTS=%s", toDelete, toUpsert, minCommitTS, maxCommitTS) + if k.optionalTxn != nil { return k.updateSpanConfigRecordsWithTxn(ctx, toDelete, toUpsert, k.optionalTxn, minCommitTS, maxCommitTS) } diff --git a/pkg/spanconfig/spanconfigmanager/manager_test.go b/pkg/spanconfig/spanconfigmanager/manager_test.go index d9ee96394557..69986b235511 100644 --- a/pkg/spanconfig/spanconfigmanager/manager_test.go +++ 
b/pkg/spanconfig/spanconfigmanager/manager_test.go @@ -303,7 +303,7 @@ func TestReconciliationJobErrorAndRecovery(t *testing.T) { ManagerDisableJobCreation: true, // disable the automatic job creation JobDisableInternalRetry: true, SQLWatcherCheckpointNoopsEveryDurationOverride: 100 * time.Millisecond, - JobOnCheckpointInterceptor: func() error { + JobOnCheckpointInterceptor: func(_ hlc.Timestamp) error { mu.Lock() defer mu.Unlock() @@ -388,7 +388,7 @@ func TestReconciliationUsesRightCheckpoint(t *testing.T) { }, ManagerDisableJobCreation: true, // disable the automatic job creation SQLWatcherCheckpointNoopsEveryDurationOverride: 10 * time.Millisecond, - JobOnCheckpointInterceptor: func() error { + JobOnCheckpointInterceptor: func(_ hlc.Timestamp) error { select { case err := <-errCh: return err diff --git a/pkg/spanconfig/spanconfigreconciler/reconciler.go b/pkg/spanconfig/spanconfigreconciler/reconciler.go index 9e2c9d43ee98..122cd92fcaba 100644 --- a/pkg/spanconfig/spanconfigreconciler/reconciler.go +++ b/pkg/spanconfig/spanconfigreconciler/reconciler.go @@ -461,6 +461,10 @@ func updateSpanConfigRecords( } return err // not a retryable error, bubble up } + + if log.V(3) { + log.Infof(ctx, "successfully updated span config records: deleted = %+#v; upserted = %+#v", toDelete, toUpsert) + } return nil // we performed the update; we're done here } return nil diff --git a/pkg/spanconfig/spanconfigstore/store.go b/pkg/spanconfig/spanconfigstore/store.go index 20b78331473d..922b702d0c99 100644 --- a/pkg/spanconfig/spanconfigstore/store.go +++ b/pkg/spanconfig/spanconfigstore/store.go @@ -360,7 +360,7 @@ func (s *Store) maybeLogUpdate(ctx context.Context, update *spanconfig.Update) e // Log if there is a SpanConfig change in any field other than // ProtectedTimestamps to avoid logging PTS updates. - if found && curSpanConfig.HasConfigurationChange(nextSC) { + if log.V(2) || (found && curSpanConfig.HasConfigurationChange(nextSC)) { log.KvDistribution.Infof(ctx, "changing the spanconfig for span:%+v from:%+v to:%+v", target, curSpanConfig, nextSC) diff --git a/pkg/spanconfig/testing_knobs.go b/pkg/spanconfig/testing_knobs.go index 2226c9ecb20d..ad968d3a6bed 100644 --- a/pkg/spanconfig/testing_knobs.go +++ b/pkg/spanconfig/testing_knobs.go @@ -49,7 +49,7 @@ type TestingKnobs struct { // JobPersistCheckpointInterceptor, if set, is invoked before the // reconciliation job persists checkpoints. - JobOnCheckpointInterceptor func() error + JobOnCheckpointInterceptor func(lastCheckpoint hlc.Timestamp) error // KVSubscriberRangeFeedKnobs control lifecycle events for the rangefeed // underlying the KVSubscriber. 
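Example (illustrative sketch, not part of the patch): with the widened JobOnCheckpointInterceptor signature above, a test can observe the checkpoint timestamp the reconciliation job is about to resume from. Only the knob field and its new signature come from this change; the package placement, helper name, and assertion are assumptions for the example.

package spanconfigmanager_test // hypothetical placement for the example

import (
	"github.com/cockroachdb/cockroach/pkg/spanconfig"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
	"github.com/cockroachdb/errors"
)

// exampleCheckpointKnobs is a hypothetical helper showing one way to consume
// the lastCheckpoint argument now passed to the interceptor.
func exampleCheckpointKnobs() *spanconfig.TestingKnobs {
	var mu syncutil.Mutex
	var lastSeen hlc.Timestamp
	return &spanconfig.TestingKnobs{
		JobOnCheckpointInterceptor: func(lastCheckpoint hlc.Timestamp) error {
			mu.Lock()
			defer mu.Unlock()
			// An empty lastCheckpoint means the job is performing a full
			// reconciliation; otherwise it resumes from this timestamp, which
			// should never move backwards across invocations.
			if lastCheckpoint.Less(lastSeen) {
				return errors.AssertionFailedf(
					"checkpoint regressed: %s < %s", lastCheckpoint, lastSeen)
			}
			lastSeen = lastCheckpoint
			return nil
		},
	}
}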
diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index 6119f0899d21..a8028c2282b1 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -1940,6 +1940,8 @@ type StreamingTestingKnobs struct { SpanConfigRangefeedCacheKnobs *rangefeedcache.TestingKnobs + OnGetSQLInstanceInfo func(cluster *roachpb.NodeDescriptor) *roachpb.NodeDescriptor + FailureRate uint32 } diff --git a/pkg/sql/logictest/logic.go b/pkg/sql/logictest/logic.go index 787f984b4b1d..89ca74238dee 100644 --- a/pkg/sql/logictest/logic.go +++ b/pkg/sql/logictest/logic.go @@ -3495,17 +3495,28 @@ func (t *logicTest) unexpectedError(sql string, pos string, err error) (bool, er return false, fmt.Errorf("%s: %s\nexpected success, but found\n%s", pos, sql, formatErr(err)) } +var uniqueHashPattern = regexp.MustCompile(`UNIQUE.*USING\s+HASH`) + func (t *logicTest) execStatement(stmt logicStatement) (bool, error) { db := t.db t.noticeBuffer = nil if *showSQL { t.outf("%s;", stmt.sql) } - execSQL, changed := randgen.ApplyString(t.rng, stmt.sql, randgen.ColumnFamilyMutator) - if changed { - log.Infof(context.Background(), "Rewrote test statement:\n%s", execSQL) - if *showSQL { - t.outf("rewrote:\n%s\n", execSQL) + execSQL := stmt.sql + // TODO(#65929, #107398): Don't mutate column families for CREATE TABLE + // statements with unique, hash-sharded indexes. The altered AST will be + // reserialized with a UNIQUE constraint, not a UNIQUE INDEX, which may not + // be parsable because constraints do not support all the options that + // indexes do. + if !uniqueHashPattern.MatchString(stmt.sql) { + var changed bool + execSQL, changed = randgen.ApplyString(t.rng, execSQL, randgen.ColumnFamilyMutator) + if changed { + log.Infof(context.Background(), "Rewrote test statement:\n%s", execSQL) + if *showSQL { + t.outf("rewrote:\n%s\n", execSQL) + } } } @@ -3535,8 +3546,6 @@ func (t *logicTest) execStatement(stmt logicStatement) (bool, error) { return t.finishExecStatement(stmt, execSQL, res, err) } -var uniqueHashPattern = regexp.MustCompile(`UNIQUE.*USING\s+HASH`) - func (t *logicTest) finishExecStatement( stmt logicStatement, execSQL string, res gosql.Result, err error, ) (bool, error) { diff --git a/pkg/sql/opt/props/histogram.go b/pkg/sql/opt/props/histogram.go index ddda074f426f..762b6e29d2f9 100644 --- a/pkg/sql/opt/props/histogram.go +++ b/pkg/sql/opt/props/histogram.go @@ -336,10 +336,11 @@ func (h *Histogram) filter( // used for comparison and are not stored, and two spans are never // built and referenced simultaneously. var sb spanBuilder + sb.init(prefix) { // Limit the scope of firstBucket to avoid referencing it below after // sb.makeSpanFromBucket has been called again. 
- firstBucket := sb.makeSpanFromBucket(ctx, &iter, prefix) + firstBucket := sb.makeSpanFromBucket(ctx, &iter) for spanIndex < spanCount { span := getSpan(spanIndex) if firstBucket.StartsAfter(&keyCtx, span) { @@ -357,7 +358,7 @@ func (h *Histogram) filter( span := getSpan(spanIndex) bucIndex := sort.Search(bucketCount, func(i int) bool { iter.setIdx(i) - bucket := sb.makeSpanFromBucket(ctx, &iter, prefix) + bucket := sb.makeSpanFromBucket(ctx, &iter) if desc { return span.StartsAfter(&keyCtx, &bucket) } @@ -382,7 +383,7 @@ func (h *Histogram) filter( } if spanCount == 1 && bucIndex < bucketCount-1 { iter.setIdx(bucIndex + 1) - bucket := sb.makeSpanFromBucket(ctx, &iter, prefix) + bucket := sb.makeSpanFromBucket(ctx, &iter) if !desc && bucket.StartsAfter(&keyCtx, span) || desc && !bucket.StartsAfter(&keyCtx, span) { newBucketCount = 2 @@ -406,7 +407,7 @@ func (h *Histogram) filter( // Convert the bucket to a span in order to take advantage of the // constraint library. - left := sb.makeSpanFromBucket(ctx, &iter, prefix) + left := sb.makeSpanFromBucket(ctx, &iter) right := getSpan(spanIndex) if left.StartsAfter(&keyCtx, right) { @@ -425,7 +426,7 @@ func (h *Histogram) filter( continue } - filteredBucket := iter.b + filteredBucket := *iter.b if filteredSpan.Compare(&keyCtx, &left) != 0 { // The bucket was cut off in the middle. Get the resulting filtered // bucket. @@ -476,7 +477,7 @@ func (h *Histogram) filter( filtered.addEmptyBucket(ctx, iter.inclusiveLowerBound(ctx), desc) } else if lastBucket := filtered.buckets[len(filtered.buckets)-1]; lastBucket.NumRange != 0 { iter.setIdx(0) - span := sb.makeSpanFromBucket(ctx, &iter, prefix) + span := sb.makeSpanFromBucket(ctx, &iter) ub := h.getPrevUpperBound(ctx, span.EndKey(), span.EndBoundary(), colOffset) filtered.addEmptyBucket(ctx, ub, desc) } @@ -567,16 +568,16 @@ func (h *Histogram) getPrevUpperBound( } func (h *Histogram) addEmptyBucket(ctx context.Context, upperBound tree.Datum, desc bool) { - h.addBucket(ctx, &cat.HistogramBucket{UpperBound: upperBound}, desc) + h.addBucket(ctx, cat.HistogramBucket{UpperBound: upperBound}, desc) } -func (h *Histogram) addBucket(ctx context.Context, bucket *cat.HistogramBucket, desc bool) { +func (h *Histogram) addBucket(ctx context.Context, bucket cat.HistogramBucket, desc bool) { // Check whether we can combine this bucket with the previous bucket. if len(h.buckets) != 0 { lastBucket := &h.buckets[len(h.buckets)-1] - lower, higher := lastBucket, bucket + lower, higher := lastBucket, &bucket if desc { - lower, higher = bucket, lastBucket + lower, higher = &bucket, lastBucket } if lower.NumRange == 0 && lower.NumEq == 0 && higher.NumRange == 0 { lastBucket.NumEq = higher.NumEq @@ -592,7 +593,7 @@ func (h *Histogram) addBucket(ctx context.Context, bucket *cat.HistogramBucket, return } } - h.buckets = append(h.buckets, *bucket) + h.buckets = append(h.buckets, bucket) } // ApplySelectivity returns a histogram with the given selectivity applied. If @@ -740,6 +741,15 @@ type spanBuilder struct { endScratch []tree.Datum } +func (sb *spanBuilder) init(prefix []tree.Datum) { + n := len(prefix) + 1 + d := make([]tree.Datum, 2*n) + copy(d, prefix) + copy(d[n:], prefix) + sb.startScratch = d[:n:n] + sb.endScratch = d[n:] +} + // makeSpanFromBucket constructs a constraint.Span from iter's current histogram // bucket. // @@ -747,7 +757,7 @@ type spanBuilder struct { // on the same spanBuilder. This is because it reuses scratch slices in the // spanBuilder to reduce allocations when building span keys. 
func (sb *spanBuilder) makeSpanFromBucket( - ctx context.Context, iter *histogramIter, prefix []tree.Datum, + ctx context.Context, iter *histogramIter, ) (span constraint.Span) { start, startBoundary := iter.lowerBound() end, endBoundary := iter.upperBound() @@ -762,10 +772,8 @@ func (sb *spanBuilder) makeSpanFromBucket( startBoundary = constraint.IncludeBoundary endBoundary = constraint.IncludeBoundary } - sb.startScratch = append(sb.startScratch[:0], prefix...) - sb.startScratch = append(sb.startScratch, start) - sb.endScratch = append(sb.endScratch[:0], prefix...) - sb.endScratch = append(sb.endScratch, end) + sb.startScratch[len(sb.startScratch)-1] = start + sb.endScratch[len(sb.endScratch)-1] = end span.Init( constraint.MakeCompositeKey(sb.startScratch...), startBoundary, @@ -808,7 +816,7 @@ func (sb *spanBuilder) makeSpanFromBucket( // we use the heuristic that NumRange is reduced by half. func getFilteredBucket( iter *histogramIter, keyCtx *constraint.KeyContext, filteredSpan *constraint.Span, colOffset int, -) *cat.HistogramBucket { +) cat.HistogramBucket { spanLowerBound := filteredSpan.StartKey().Value(colOffset) spanUpperBound := filteredSpan.EndKey().Value(colOffset) bucketLowerBound := iter.inclusiveLowerBound(keyCtx.Ctx) @@ -915,7 +923,7 @@ func getFilteredBucket( if iter.desc { upperBound = spanLowerBound } - return &cat.HistogramBucket{ + return cat.HistogramBucket{ NumEq: numEq, NumRange: numRange, DistinctRange: distinctCountRange, diff --git a/pkg/sql/opt/props/histogram_test.go b/pkg/sql/opt/props/histogram_test.go index ecb4da360fa6..f5153ec6d36f 100644 --- a/pkg/sql/opt/props/histogram_test.go +++ b/pkg/sql/opt/props/histogram_test.go @@ -438,8 +438,8 @@ func TestFilterBucket(t *testing.T) { // the second bucket. iter.setIdx(1) b := getFilteredBucket(&iter, &keyCtx, span, colOffset) - roundBucket(b) - return b, nil + roundBucket(&b) + return &b, nil } runTest := func(h *Histogram, testData []testCase, colOffset int, typs ...types.Family) { diff --git a/pkg/testutils/BUILD.bazel b/pkg/testutils/BUILD.bazel index 5fb1beba8f2e..f0fcdb14104c 100644 --- a/pkg/testutils/BUILD.bazel +++ b/pkg/testutils/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "backup.go", "dir.go", "error.go", + "exectrace.go", "files.go", "hook.go", "keys.go", @@ -33,6 +34,7 @@ go_library( "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_dataexmachina_dev_side_eye_go//sideeye", + "@com_github_petermattis_goid//:goid", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/testutils/exectrace.go b/pkg/testutils/exectrace.go new file mode 100644 index 000000000000..c9c4e60dfad5 --- /dev/null +++ b/pkg/testutils/exectrace.go @@ -0,0 +1,71 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package testutils + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime/trace" + + "github.com/petermattis/goid" +) + +type ActiveExecTrace struct { + name string + file *os.File + reg *trace.Region +} + +// Finish stops the ongoing execution trace, if there is one, and closes the +// file. It must be called only once. 
+func (a *ActiveExecTrace) Finish(t interface { + Failed() bool + Logf(string, ...interface{}) +}) { + if a == nil { + return + } + a.reg.End() + trace.Stop() + _ = a.file.Close() + if !t.Failed() { + _ = os.Remove(a.file.Name()) + } else { + t.Logf("execution trace written to %s", a.file.Name()) + } +} + +// StartExecTrace starts a Go execution trace and returns a handle that allows +// stopping it. If a trace cannot be started, this is logged and nil is returned. +// It is valid to stop a nil ActiveExecTrace. +// +// This helper is intended to instrument tests for which an execution trace is +// desired on the next failure. +func StartExecTrace( + t interface { + Name() string + Logf(string, ...interface{}) + }, dir string, +) *ActiveExecTrace { + path := filepath.Join(dir, fmt.Sprintf("exectrace_goid_%d.bin", goid.Get())) + f, err := os.Create(path) + if err != nil { + t.Logf("could not create file for execution trace: %s", err) + return nil + } + if err := trace.Start(f); err != nil { + _ = f.Close() + t.Logf("could not start execution trace: %s", err) + return nil + } + return &ActiveExecTrace{ + name: t.Name(), + file: f, + reg: trace.StartRegion(context.Background(), t.Name()), + } +}
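Example (illustrative, not part of the patch): how a test wires up the new StartExecTrace helper so the execution trace is kept only on failure, mirroring the TestPromoteNonVoterInAddVoter change above. The package and test name here are assumptions.

package kvserver_test // hypothetical placement for the example

import (
	"testing"

	"github.com/cockroachdb/cockroach/pkg/testutils"
	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
	"github.com/cockroachdb/cockroach/pkg/util/log"
)

func TestSomethingWithExecTrace(t *testing.T) {
	defer leaktest.AfterTest(t)()
	// Keep a handle on the log scope so its directory can hold the trace file.
	scope := log.Scope(t)
	defer scope.Close(t)

	// StartExecTrace begins a Go execution trace in the scope's directory.
	// Finish removes the trace file if the test passes and logs its path if
	// the test fails; a nil handle (trace could not be started) is safe.
	defer testutils.StartExecTrace(t, scope.GetDirectory()).Finish(t)

	// ... test body under investigation ...
}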