Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sql-1]accounts: preparatory commits for SQL-izing accounts #934

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
20 changes: 11 additions & 9 deletions accounts/checkers.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func NewAccountChecker(service Service,
}

return nil, service.AssociateInvoice(
acct.ID, hash,
ctx, acct.ID, hash,
)
}, mid.PassThroughErrorHandler,
),
Expand Down Expand Up @@ -615,12 +615,12 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params,
fee := lnrpc.CalculateFeeLimit(limit, sendAmt)
sendAmt += fee

err = service.CheckBalance(acct.ID, sendAmt)
err = service.CheckBalance(ctx, acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %w", err)
}

err = service.AssociatePayment(acct.ID, pHash, sendAmt)
err = service.AssociatePayment(ctx, acct.ID, pHash, sendAmt)
if err != nil {
return fmt.Errorf("error associating payment: %w", err)
}
Expand Down Expand Up @@ -661,11 +661,13 @@ func checkSendResponse(ctx context.Context, service Service,
if status == lnrpc.Payment_FAILED {
service.DeleteValues(reqID)

return nil, service.RemovePayment(hash)
return nil, service.RemovePayment(ctx, hash)
}

// If there is no immediate failure, make sure we track the payment.
err = service.TrackPayment(acct.ID, hash, lnwire.MilliSatoshi(fullAmt))
err = service.TrackPayment(
ctx, acct.ID, hash, lnwire.MilliSatoshi(fullAmt),
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -713,12 +715,12 @@ func checkSendToRoute(ctx context.Context, service Service, paymentHash []byte,
}
sendAmt += fee

err = service.CheckBalance(acct.ID, sendAmt)
err = service.CheckBalance(ctx, acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %w", err)
}

err = service.AssociatePayment(acct.ID, hash, sendAmt)
err = service.AssociatePayment(ctx, acct.ID, hash, sendAmt)
if err != nil {
return fmt.Errorf("error associating payment with hash %s: %w",
hash, err)
Expand Down Expand Up @@ -749,7 +751,7 @@ func erroredPaymentHandler(service Service) mid.ErrorHandler {
"hash: %s and amount: %d", reqVals.PaymentHash,
reqVals.PaymentAmount)

err = service.PaymentErrored(acct.ID, reqVals.PaymentHash)
err = service.PaymentErrored(ctx, acct.ID, reqVals.PaymentHash)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -812,7 +814,7 @@ func sendToRouteHTLCResponseHandler(service Service) func(ctx context.Context,
}

err = service.TrackPayment(
acct.ID, reqValues.PaymentHash,
ctx, acct.ID, reqValues.PaymentHash,
lnwire.MilliSatoshi(totalAmount),
)
if err != nil {
Expand Down
43 changes: 26 additions & 17 deletions accounts/checkers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func newMockService() *mockService {
}
}

func (m *mockService) CheckBalance(_ AccountID,
func (m *mockService) CheckBalance(_ context.Context, _ AccountID,
wantBalance lnwire.MilliSatoshi) error {

if wantBalance > m.acctBalanceMsat {
Expand All @@ -81,24 +81,28 @@ func (m *mockService) CheckBalance(_ AccountID,
return nil
}

func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error {
func (m *mockService) AssociateInvoice(_ context.Context, id AccountID,
hash lntypes.Hash) error {

m.trackedInvoices[hash] = id

return nil
}

func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash,
amt lnwire.MilliSatoshi) error {
func (m *mockService) AssociatePayment(_ context.Context, id AccountID,
paymentHash lntypes.Hash, amt lnwire.MilliSatoshi) error {

return nil
}

func (m *mockService) PaymentErrored(id AccountID, hash lntypes.Hash) error {
func (m *mockService) PaymentErrored(_ context.Context, id AccountID,
hash lntypes.Hash) error {

return nil
}

func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
amt lnwire.MilliSatoshi) error {
func (m *mockService) TrackPayment(_ context.Context, _ AccountID,
hash lntypes.Hash, amt lnwire.MilliSatoshi) error {

m.trackedPayments[hash] = &PaymentEntry{
Status: lnrpc.Payment_UNKNOWN,
Expand All @@ -108,7 +112,9 @@ func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
return nil
}

func (m *mockService) RemovePayment(hash lntypes.Hash) error {
func (m *mockService) RemovePayment(_ context.Context,
hash lntypes.Hash) error {

delete(m.trackedPayments, hash)

return nil
Expand Down Expand Up @@ -517,14 +523,15 @@ func testSendPayment(t *testing.T, uri string) {
errFunc := func(err error) {
lndMock.mainErrChan <- err
}
service, err := NewService(t.TempDir(), errFunc)
store := NewTestDB(t)
service, err := NewService(store, errFunc)
require.NoError(t, err)

err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
acct, err := service.Account(id)
acct, err := service.Account(ctx, id)
require.NoError(t, err)

require.Equal(t, expectedBalance,
Expand All @@ -539,7 +546,7 @@ func testSendPayment(t *testing.T, uri string) {

// Create an account and add it to the context.
acct, err := service.NewAccount(
5000, time.Now().Add(time.Hour), "test",
ctx, 5000, time.Now().Add(time.Hour), "test",
)
require.NoError(t, err)

Expand Down Expand Up @@ -713,14 +720,15 @@ func TestSendPaymentV2(t *testing.T) {
errFunc := func(err error) {
lndMock.mainErrChan <- err
}
service, err := NewService(t.TempDir(), errFunc)
store := NewTestDB(t)
service, err := NewService(store, errFunc)
require.NoError(t, err)

err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
acct, err := service.Account(id)
acct, err := service.Account(ctx, id)
require.NoError(t, err)

require.Equal(t, expectedBalance,
Expand All @@ -735,7 +743,7 @@ func TestSendPaymentV2(t *testing.T) {

// Create an account and add it to the context.
acct, err := service.NewAccount(
5000, time.Now().Add(time.Hour), "test",
ctx, 5000, time.Now().Add(time.Hour), "test",
)
require.NoError(t, err)

Expand Down Expand Up @@ -900,14 +908,15 @@ func TestSendToRouteV2(t *testing.T) {
errFunc := func(err error) {
lndMock.mainErrChan <- err
}
service, err := NewService(t.TempDir(), errFunc)
store := NewTestDB(t)
service, err := NewService(store, errFunc)
require.NoError(t, err)

err = service.Start(ctx, lndMock, routerMock, chainParams)
require.NoError(t, err)

assertBalance := func(id AccountID, expectedBalance int64) {
acct, err := service.Account(id)
acct, err := service.Account(ctx, id)
require.NoError(t, err)

require.Equal(t, expectedBalance,
Expand All @@ -922,7 +931,7 @@ func TestSendToRouteV2(t *testing.T) {

// Create an account and add it to the context.
acct, err := service.NewAccount(
5000, time.Now().Add(time.Hour), "test",
ctx, 5000, time.Now().Add(time.Hour), "test",
)
require.NoError(t, err)

Expand Down
11 changes: 11 additions & 0 deletions accounts/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package accounts

import "errors"

var (
// ErrLabelAlreadyExists is returned by the CreateAccount method if the
// account label is already used by an existing account.
ErrLabelAlreadyExists = errors.New(
"account label uniqueness constraint violation",
)
)
2 changes: 1 addition & 1 deletion accounts/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (s *InterceptorService) Intercept(ctx context.Context,
"macaroon caveat")
}

acct, err := s.Account(*acctID)
acct, err := s.Account(ctx, *acctID)
if err != nil {
return mid.RPCErrString(
req, "error getting account %x: %v", acctID[:], err,
Expand Down
38 changes: 23 additions & 15 deletions accounts/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package accounts

import (
"context"
"encoding/hex"
"errors"
"fmt"
Expand Down Expand Up @@ -201,30 +202,34 @@ var (
type Store interface {
// NewAccount creates a new OffChainBalanceAccount with the given
// balance and a randomly chosen ID.
NewAccount(balance lnwire.MilliSatoshi, expirationDate time.Time,
label string) (*OffChainBalanceAccount, error)
NewAccount(ctx context.Context, balance lnwire.MilliSatoshi,
expirationDate time.Time, label string) (
*OffChainBalanceAccount, error)

// UpdateAccount writes an account to the database, overwriting the
// existing one if it exists.
UpdateAccount(account *OffChainBalanceAccount) error
UpdateAccount(ctx context.Context,
account *OffChainBalanceAccount) error

// Account retrieves an account from the Store and un-marshals it. If
// the account cannot be found, then ErrAccNotFound is returned.
Account(id AccountID) (*OffChainBalanceAccount, error)
Account(ctx context.Context, id AccountID) (*OffChainBalanceAccount,
error)

// Accounts retrieves all accounts from the store and un-marshals them.
Accounts() ([]*OffChainBalanceAccount, error)
Accounts(ctx context.Context) ([]*OffChainBalanceAccount, error)

// RemoveAccount finds an account by its ID and removes it from the¨
// store.
RemoveAccount(id AccountID) error
RemoveAccount(ctx context.Context, id AccountID) error

// LastIndexes returns the last invoice add and settle index or
// ErrNoInvoiceIndexKnown if no indexes are known yet.
LastIndexes() (uint64, uint64, error)
LastIndexes(ctx context.Context) (uint64, uint64, error)

// StoreLastIndexes stores the last invoice add and settle index.
StoreLastIndexes(addIndex, settleIndex uint64) error
StoreLastIndexes(ctx context.Context, addIndex,
settleIndex uint64) error

// Close closes the underlying store.
Close() error
Expand All @@ -234,34 +239,37 @@ type Store interface {
type Service interface {
// CheckBalance ensures an account is valid and has a balance equal to
// or larger than the amount that is required.
CheckBalance(id AccountID, requiredBalance lnwire.MilliSatoshi) error
CheckBalance(ctx context.Context, id AccountID,
requiredBalance lnwire.MilliSatoshi) error

// AssociateInvoice associates a generated invoice with the given
// account, making it possible for the account to be credited in case
// the invoice is paid.
AssociateInvoice(id AccountID, hash lntypes.Hash) error
AssociateInvoice(ctx context.Context, id AccountID,
hash lntypes.Hash) error

// TrackPayment adds a new payment to be tracked to the service. If the
// payment is eventually settled, its amount needs to be debited from
// the given account.
TrackPayment(id AccountID, hash lntypes.Hash,
TrackPayment(ctx context.Context, id AccountID, hash lntypes.Hash,
fullAmt lnwire.MilliSatoshi) error

// RemovePayment removes a failed payment from the service because it no
// longer needs to be tracked. The payment is certain to never succeed,
// so we never need to debit the amount from the account.
RemovePayment(hash lntypes.Hash) error
RemovePayment(ctx context.Context, hash lntypes.Hash) error

// AssociatePayment associates a payment (hash) with the given account,
// ensuring that the payment will be tracked for a user when LiT is
// restarted.
AssociatePayment(id AccountID, paymentHash lntypes.Hash,
fullAmt lnwire.MilliSatoshi) error
AssociatePayment(ctx context.Context, id AccountID,
paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error

// PaymentErrored removes a pending payment from the accounts
// registered payment list. This should only ever be called if we are
// sure that the payment request errored out.
PaymentErrored(id AccountID, hash lntypes.Hash) error
PaymentErrored(ctx context.Context, id AccountID,
hash lntypes.Hash) error

RequestValuesStore
}
Expand Down
Loading
Loading