refactor: review comments
shekhar-rudder committed Jan 8, 2025
1 parent bf0f51b commit f74c0de
Showing 4 changed files with 98 additions and 122 deletions.
@@ -2,6 +2,7 @@ package snowpipestreaming
 
 import (
 	"context"
+	"errors"
 	"fmt"
 
 	"github.com/rudderlabs/rudder-go-kit/logger"
@@ -12,29 +13,11 @@ import (
 	whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
 )
 
-type errCode int
-
-const (
-	errBackoff errCode = iota
-	errAuthz
+var (
+	errAuthz   = errors.New("snowpipe authorization error")
+	errBackoff = errors.New("snowpipe backoff error")
 )
 
-type snowflakeConnectionErr struct {
-	code errCode
-	err  error
-}
-
-func newSnowflakeConnectionErr(code errCode, err error) *snowflakeConnectionErr {
-	return &snowflakeConnectionErr{
-		code: code,
-		err:  err,
-	}
-}
-
-func (sae *snowflakeConnectionErr) Error() string {
-	return sae.err.Error()
-}
-
 // initializeChannelWithSchema creates a new channel for the given table if it doesn't exist.
 // If the channel already exists, it checks for new columns and adds them to the table.
 // It returns the channel response after creating or recreating the channel.
@@ -89,7 +72,7 @@ func (m *Manager) addColumns(ctx context.Context, namespace, tableName string, c
 		snowflakeManager.Cleanup(ctx)
 	}()
 	if err = snowflakeManager.AddColumns(ctx, tableName, columns); err != nil {
-		return newSnowflakeConnectionErr(errAuthz, fmt.Errorf("adding column: %w", err))
+		return fmt.Errorf("adding column: %w, %w", errAuthz, err)
 	}
 	return nil
 }
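
Reviewer note: the dual-%w form used here relies on multi-error wrapping, available since Go 1.20. A minimal standalone sketch (the `cause` error is hypothetical) showing why callers can still branch on the sentinel:

```go
package main

import (
	"errors"
	"fmt"
)

var errAuthz = errors.New("snowpipe authorization error")

func main() {
	// Stand-in for whatever AddColumns would actually return.
	cause := errors.New("permission denied")
	// Since Go 1.20, fmt.Errorf may wrap more than one error with %w.
	err := fmt.Errorf("adding column: %w, %w", errAuthz, cause)

	fmt.Println(errors.Is(err, errAuthz)) // true: callers can branch on the sentinel
	fmt.Println(errors.Is(err, cause))    // true: the original cause is preserved
}
```

This is what lets the Upload hunks below replace errors.As on *snowflakeConnectionErr with plain errors.Is checks.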
@@ -180,10 +163,10 @@ func (m *Manager) handleSchemaError(
 		snowflakeManager.Cleanup(ctx)
 	}()
 	if err := snowflakeManager.CreateSchema(ctx); err != nil {
-		return nil, newSnowflakeConnectionErr(errAuthz, fmt.Errorf("creating schema: %w", err))
+		return nil, fmt.Errorf("creating schema: %w, %w", errAuthz, err)
 	}
 	if err := snowflakeManager.CreateTable(ctx, channelReq.TableConfig.Table, eventSchema); err != nil {
-		return nil, newSnowflakeConnectionErr(errAuthz, fmt.Errorf("creating table: %w", err))
+		return nil, fmt.Errorf("creating table: %w, %w", errAuthz, err)
 	}
 	return m.api.CreateChannel(ctx, channelReq)
 }
@@ -208,7 +191,7 @@ func (m *Manager) handleTableError(
 		snowflakeManager.Cleanup(ctx)
 	}()
 	if err := snowflakeManager.CreateTable(ctx, channelReq.TableConfig.Table, eventSchema); err != nil {
-		return nil, newSnowflakeConnectionErr(errAuthz, fmt.Errorf("creating table: %w", err))
+		return nil, fmt.Errorf("creating table: %w, %w", errAuthz, err)
 	}
 	return m.api.CreateChannel(ctx, channelReq)
 }
@@ -248,8 +231,8 @@ func (m *Manager) deleteChannel(ctx context.Context, tableName, channelID string
 }
 
 func (m *Manager) createSnowflakeManager(ctx context.Context, namespace string) (manager.Manager, error) {
-	if m.authzBackoff.isInBackoff() {
-		return nil, newSnowflakeConnectionErr(errBackoff, fmt.Errorf("skipping snowflake manager creation due to backoff"))
+	if m.isInBackoff() {
+		return nil, fmt.Errorf("skipping snowflake manager creation due to backoff: %w", errBackoff)
 	}
 	modelWarehouse := whutils.ModelWarehouse{
 		WorkspaceID: m.destination.WorkspaceID,
@@ -13,6 +13,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/cenkalti/backoff/v4"
 	"github.com/hashicorp/go-retryablehttp"
 	jsoniter "github.com/json-iterator/go"
 	"github.com/samber/lo"
@@ -38,19 +39,12 @@
 
 var json = jsoniter.ConfigCompatibleWithStandardLibrary
 
-type systemClock struct{}
-
-func (t systemClock) Now() time.Time {
-	return timeutil.Now()
-}
-
 func New(
 	conf *config.Config,
 	log logger.Logger,
 	statsFactory stats.Stats,
 	destination *backendconfig.DestinationT,
 ) *Manager {
-	clock := systemClock{}
 	m := &Manager{
 		appConfig: conf,
 		logger: log.Child("snowpipestreaming").Withn(
@@ -60,7 +54,7 @@
 		),
 		statsFactory: statsFactory,
 		destination:  destination,
-		now:          clock.Now,
+		now:          timeutil.Now,
 		channelCache: sync.Map{},
 		polledImportInfoMap: make(map[string]*importInfo),
 	}
@@ -76,7 +70,9 @@
 	m.config.client.retryMax = conf.GetInt("SnowpipeStreaming.Client.retryMax", 5)
 	m.config.instanceID = conf.GetString("INSTANCE_ID", "1")
 	m.config.maxBufferCapacity = conf.GetReloadableInt64Var(512*bytesize.KB, bytesize.B, "SnowpipeStreaming.maxBufferCapacity")
-	m.authzBackoff = newAuthzBackoff(conf.GetDuration("SnowpipeStreaming.backoffDuration", 1, time.Second), clock)
+	m.config.backoff.initialInterval = conf.GetReloadableDurationVar(1, time.Second, "SnowpipeStreaming.backoffInitialIntervalInSeconds")
+	m.config.backoff.multiplier = conf.GetReloadableFloat64Var(2.0, "SnowpipeStreaming.backoffMultiplier")
+	m.config.backoff.maxInterval = conf.GetReloadableDurationVar(1, time.Hour, "SnowpipeStreaming.backoffMaxIntervalInHours")
 
 	tags := stats.Tags{
 		"module": "batch_router",
@@ -110,12 +106,12 @@ func New(
 		snowpipeapi.New(m.appConfig, m.statsFactory, m.config.client.url, m.requestDoer),
 		destination,
 	)
-	m.managerCreator = func(mCtx context.Context, modelWarehouse whutils.ModelWarehouse, conf *config.Config, logger logger.Logger, stats stats.Stats) (manager.Manager, error) {
-		sf, err := manager.New(whutils.SnowpipeStreaming, conf, logger, stats)
+	m.managerCreator = func(ctx context.Context, modelWarehouse whutils.ModelWarehouse, conf *config.Config, logger logger.Logger, statsFactory stats.Stats) (manager.Manager, error) {
+		sf, err := manager.New(whutils.SnowpipeStreaming, conf, logger, statsFactory)
 		if err != nil {
 			return nil, fmt.Errorf("creating snowflake manager: %w", err)
 		}
-		err = sf.Setup(mCtx, modelWarehouse, whutils.NewNoOpUploader())
+		err = sf.Setup(ctx, modelWarehouse, whutils.NewNoOpUploader())
 		if err != nil {
 			return nil, fmt.Errorf("setting up snowflake manager: %w", err)
 		}
@@ -142,6 +138,10 @@ func (m *Manager) retryableClient() *retryablehttp.Client {
 	return client
 }
 
+func (m *Manager) Now() time.Time {
+	return m.now()
+}
+
 func (m *Manager) Transform(job *jobsdb.JobT) (string, error) {
 	return common.GetMarshalledData(string(job.EventPayload), job.JobID)
 }
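
Reviewer note: the exported Now wrapper presumably exists so that *Manager itself can serve as the clock for cenkalti/backoff (setBackOff below passes m to backoff.WithClockProvider). Assuming backoff/v4's exported Clock interface, that contract could be pinned with a compile-time assertion:

```go
// Compile-time check (sketch, not in the diff): Manager's
// Now() time.Time method satisfies backoff.Clock, so tests can
// bend time by swapping the unexported m.now field.
var _ backoff.Clock = (*Manager)(nil)
```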
@@ -173,15 +173,13 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
 
 	discardsChannel, err := m.initializeChannelWithSchema(ctx, asyncDest.Destination.ID, &destConf, discardsTable(), discardsSchema())
 	if err != nil {
-		var sfConnectionErr *snowflakeConnectionErr
-		if errors.As(err, &sfConnectionErr) {
-			if sfConnectionErr.code == errAuthz {
-				m.authzBackoff.set()
+		if errors.Is(err, errAuthz) || errors.Is(err, errBackoff) {
+			if errors.Is(err, errAuthz) {
+				m.setBackOff()
 			}
 			return m.failedJobs(asyncDest, err.Error())
-		} else {
-			return m.abortJobs(asyncDest, fmt.Errorf("failed to prepare discards channel: %w", err).Error())
 		}
+		return m.abortJobs(asyncDest, fmt.Errorf("failed to prepare discards channel: %w", err).Error())
 	}
 	m.logger.Infon("Prepared discards channel")

@@ -218,12 +216,11 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
 	for _, info := range uploadInfos {
 		imInfo, discardImInfo, err := m.sendEventsToSnowpipe(ctx, asyncDest.Destination.ID, &destConf, info)
 		if err != nil {
-			var sfConnectionErr *snowflakeConnectionErr
-			if errors.As(err, &sfConnectionErr) {
-				if sfConnectionErr.code == errAuthz && !isBackoffSet {
-					m.authzBackoff.set()
-				}
+			if errors.Is(err, errAuthz) || errors.Is(err, errBackoff) {
 				shouldResetBackoff = false
+				if errors.Is(err, errAuthz) && !isBackoffSet {
+					m.setBackOff()
+				}
 			}
 			m.logger.Warnn("Failed to send events to Snowpipe",
 				logger.NewStringField("table", info.tableName),
@@ -245,7 +242,7 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
 		}
 	}
 	if shouldResetBackoff {
-		m.authzBackoff.reset()
+		m.resetBackoff()
 	}
 	if discardImportInfo != nil {
 		importInfos = append(importInfos, discardImportInfo)
@@ -286,7 +283,7 @@ func (m *Manager) eventsFromFile(fileName string, eventsCount int) ([]*event, er
 
 	events := make([]*event, 0, eventsCount)
 
-	formattedTS := m.now().Format(misc.RFC3339Milli)
+	formattedTS := m.Now().Format(misc.RFC3339Milli)
 	scanner := bufio.NewScanner(file)
 	scanner.Buffer(nil, int(m.config.maxBufferCapacity.Load()))

@@ -330,7 +327,7 @@ func (m *Manager) sendEventsToSnowpipe(
 	}
 	log.Infon("Prepared channel", logger.NewStringField("channelID", channelResponse.ChannelID))
 
-	formattedTS := m.now().Format(misc.RFC3339Milli)
+	formattedTS := m.Now().Format(misc.RFC3339Milli)
 	var discardInfos []discardInfo
 	for _, tableEvent := range info.events {
 		discardInfos = append(discardInfos, getDiscardedRecordsFromEvent(tableEvent, channelResponse.SnowpipeSchema, info.tableName, formattedTS)...)
@@ -600,3 +597,34 @@ func (m *Manager) GetUploadStats(input common.GetUploadStatsInput) common.GetUpl
 		},
 	}
 }
+
+func (m *Manager) isInBackoff() bool {
+	if m.backoff.next.IsZero() {
+		return false
+	}
+	return m.Now().Before(m.backoff.next)
+}
+
+func (m *Manager) resetBackoff() {
+	m.backoff.next = time.Time{}
+	m.backoff.attempts = 0
+}
+
+func (m *Manager) setBackOff() {
+	b := backoff.NewExponentialBackOff(
+		backoff.WithInitialInterval(m.config.backoff.initialInterval.Load()),
+		backoff.WithMultiplier(m.config.backoff.multiplier.Load()),
+		backoff.WithClockProvider(m),
+		backoff.WithRandomizationFactor(0),
+		backoff.WithMaxElapsedTime(0),
+		backoff.WithMaxInterval(m.config.backoff.maxInterval.Load()),
+	)
+	b.Reset()
+	m.backoff.attempts++
+
+	var d time.Duration
+	for index := int64(0); index < int64(m.backoff.attempts); index++ {
+		d = b.NextBackOff()
+	}
+	m.backoff.next = m.Now().Add(d)
+}
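
Reviewer note: with randomization disabled and MaxElapsedTime zeroed, the attempts-loop above is deterministic, so the nth interval has a closed form. A sketch under those assumptions (not part of the change):

```go
import (
	"math"
	"time"
)

// nthInterval mirrors what the loop in setBackOff computes when
// RandomizationFactor is 0: the nth NextBackOff call yields
// initial * multiplier^(n-1), capped at maxInterval.
func nthInterval(initial time.Duration, multiplier float64, maxInterval time.Duration, n int) time.Duration {
	d := time.Duration(float64(initial) * math.Pow(multiplier, float64(n-1)))
	if d > maxInterval {
		return maxInterval
	}
	return d
}
```

With the defaults wired in New (1s initial, 2.0 multiplier, 1h cap), the third consecutive authorization failure would schedule the next attempt 4s out.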
@@ -20,6 +20,7 @@ import (
 	"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common"
 	internalapi "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/api"
 	"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model"
+	"github.com/rudderlabs/rudder-server/utils/timeutil"
 	"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
 	"github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake"
 	whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
@@ -59,22 +60,14 @@ func newMockManager(m manager.Manager) *mockManager {
 	}
 }
 
-func (m *mockManager) CreateSchema(ctx context.Context) (err error) {
+func (m *mockManager) CreateSchema(context.Context) error {
 	return m.createSchemaErr
 }
 
-func (m *mockManager) CreateTable(ctx context.Context, tableName string, columnMap whutils.ModelTableSchema) (err error) {
+func (m *mockManager) CreateTable(context.Context, string, whutils.ModelTableSchema) error {
 	return nil
 }
 
-type mockClock struct {
-	nowTime time.Time
-}
-
-func (m *mockClock) Now() time.Time {
-	return m.nowTime
-}
-
 var (
 	usersChannelResponse = &model.ChannelResponse{
 		ChannelID: "test-users-channel",
@@ -366,43 +359,47 @@ func TestSnowpipeStreaming(t *testing.T) {
 				mockManager.createSchemaErr = fmt.Errorf("failed to create schema")
 				return mockManager, nil
 			}
-			mockClock := &mockClock{}
-			mockClock.nowTime = sm.now()
-			sm.authzBackoff = newAuthzBackoff(time.Second*10, mockClock)
+			sm.config.backoff.initialInterval = config.SingleValueLoader(time.Second * 10)
 			asyncDestStruct := &common.AsyncDestinationStruct{
 				Destination: destination,
 				FileName:    "testdata/successful_user_records.txt",
 			}
-			require.Equal(t, false, sm.authzBackoff.isInBackoff())
+			require.False(t, sm.isInBackoff())
 			output1 := sm.Upload(asyncDestStruct)
 			require.Equal(t, 2, output1.FailedCount)
 			require.Equal(t, 0, output1.AbortCount)
 			require.Equal(t, 1, managerCreatorCallCount)
-			require.Equal(t, true, sm.authzBackoff.isInBackoff())
+			require.True(t, sm.isInBackoff())
 
 			sm.Upload(asyncDestStruct)
 			// client is not created again due to backoff error
 			require.Equal(t, 1, managerCreatorCallCount)
-			require.Equal(t, true, sm.authzBackoff.isInBackoff())
-
-			mockClock.nowTime = sm.now().Add(time.Second * 5)
-			require.Equal(t, true, sm.authzBackoff.isInBackoff())
-			mockClock.nowTime = sm.now().Add(time.Second * 20)
-			require.Equal(t, false, sm.authzBackoff.isInBackoff())
+			require.True(t, sm.isInBackoff())
 
+			sm.now = func() time.Time {
+				return timeutil.Now().Add(time.Second * 5)
+			}
+			require.True(t, sm.isInBackoff())
+			sm.now = func() time.Time {
+				return timeutil.Now().Add(time.Second * 20)
+			}
+			require.False(t, sm.isInBackoff())
 			sm.Upload(asyncDestStruct)
 			// client created again since backoff duration has been exceeded
 			require.Equal(t, 2, managerCreatorCallCount)
-			require.Equal(t, false, sm.authzBackoff.isInBackoff())
+			require.True(t, sm.isInBackoff())
 
 			sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
 				sm := snowflake.New(config.New(), logger.NOP, stats.NOP)
 				managerCreatorCallCount++
 				return newMockManager(sm), nil
 			}
+			sm.now = func() time.Time {
+				return timeutil.Now().Add(time.Second * 50)
+			}
 			sm.Upload(asyncDestStruct)
 			require.Equal(t, 3, managerCreatorCallCount)
-			require.Equal(t, false, sm.authzBackoff.isInBackoff())
+			require.False(t, sm.isInBackoff())
 		})
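
Reviewer note: dropping mockClock works because now is a plain func() time.Time field, so the test swaps the clock with a closure instead of maintaining a mock type. The same pattern in isolation (all names here are hypothetical):

```go
package main

import (
	"fmt"
	"time"
)

// throttle is a stand-in for Manager: any struct holding a now func
// can have its clock replaced in tests without a mock object.
type throttle struct {
	now   func() time.Time
	until time.Time
}

func (t *throttle) active() bool { return t.now().Before(t.until) }

func main() {
	tr := &throttle{now: time.Now, until: time.Now().Add(10 * time.Second)}
	fmt.Println(tr.active()) // true: real clock, still inside the window

	// "Advance" time by swapping the clock, as the test does with sm.now.
	tr.now = func() time.Time { return time.Now().Add(20 * time.Second) }
	fmt.Println(tr.active()) // false: the window has passed
}
```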

 		t.Run("Upload with discards table authorization error should mark the job as failed", func(t *testing.T) {
@@ -432,7 +429,7 @@ func TestSnowpipeStreaming(t *testing.T) {
 			require.Equal(t, 0, output.AbortCount)
 			require.NotEmpty(t, output.FailedReason)
 			require.Empty(t, output.AbortReason)
-			require.Equal(t, true, sm.authzBackoff.isInBackoff())
+			require.Equal(t, true, sm.isInBackoff())
 		})
 
 		t.Run("Upload insert error for all events", func(t *testing.T) {