From 20b025c5bd4a7e768e7d05b75d837a4d198c2183 Mon Sep 17 00:00:00 2001 From: sczembor Date: Tue, 26 Nov 2024 17:07:51 +0100 Subject: [PATCH] apply code review suggestions --- go.mod | 1 + go.sum | 2 + native/database/db.go | 77 ++++++++++++++++++---------------- native/database/db_test.go | 85 ++++++++++++++++---------------------- 4 files changed, 81 insertions(+), 84 deletions(-) diff --git a/go.mod b/go.mod index 042ee8d..61aede0 100644 --- a/go.mod +++ b/go.mod @@ -174,6 +174,7 @@ require ( google.golang.org/protobuf v1.35.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools v2.2.0+incompatible gotest.tools/v3 v3.5.1 // indirect nhooyr.io/websocket v1.8.6 // indirect pgregory.net/rapid v1.1.0 // indirect diff --git a/go.sum b/go.sum index c38c84a..4e2ecfe 100644 --- a/go.sum +++ b/go.sum @@ -1027,6 +1027,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/native/database/db.go b/native/database/db.go index 2ac9498..6273815 100644 --- a/native/database/db.go +++ b/native/database/db.go @@ -6,14 +6,14 @@ import ( _ "github.com/mattn/go-sqlite3" // Import the SQLite driver ) -// TransactionStatus represents the different states of a transaction. -type TransactionStatus string +// TxStatus represents the different states of a transaction. +type TxStatus int -// Different tx states +// Transaction status constants const ( - StatusPending TransactionStatus = "pending" - StatusBroadcasted TransactionStatus = "broadcasted" - StatusConfirmed TransactionStatus = "confirmed" + StatusPending TxStatus = iota + StatusBroadcasted + StatusConfirmed ) // SQL queries @@ -26,39 +26,46 @@ const ( CREATE TABLE IF NOT EXISTS transactions ( txid TEXT PRIMARY KEY, rawtx TEXT NOT NULL, - status TEXT NOT NULL + status INTEGER NOT NULL NOT NULL ) ` ) -// Transaction represents a transaction record in the database. -type Transaction struct { - BtcTxID uint64 `json:"txid"` - RawTx string `json:"rawtx"` - Status TransactionStatus `json:"status"` +// Tx represents a transaction record in the database. +type Tx struct { + BtcTxID uint64 `json:"txid"` + RawTx string `json:"rawtx"` + Status TxStatus `json:"status"` // TODO: other fields } -var db *sql.DB +// DB holds the database connection and provides methods for interacting with it. +type DB struct { + conn *sql.DB +} + +// NewDB creates a new DB instance and initializes the database connection. +func NewDB(dbPath string) (*DB, error) { + db := &DB{} // Create a new DB instance -// InitDB initializes the database connection and creates the table if it doesn't exist -func InitDB(dbPath string) error { + // Initialize the database connection var err error - db, err = sql.Open("sqlite3", dbPath) + db.conn, err = sql.Open("sqlite3", dbPath) if err != nil { - return err + return nil, err } - _, err = db.Exec(createTransactionsTableSQL) + _, err = db.conn.Exec(createTransactionsTableSQL) if err != nil { - return err + return nil, err } - return nil + + return db, nil // Return the initialized DB instance } -// InsertTransaction inserts a new transaction into the database -func InsertTransaction(tx Transaction) error { - stmt, err := db.Prepare(insertTransactionSQL) +// InsertTx inserts a new transaction into the database +func (db DB) InsertTx(tx Tx) error { + stmt, err := db.conn.Prepare(insertTransactionSQL) if err != nil { return err } @@ -72,9 +79,9 @@ func InsertTransaction(tx Transaction) error { return nil } -// GetTransaction retrives a transaction by its txid -func GetTransaction(txID uint64) (*Transaction, error) { - stmt, err := db.Prepare(getTransactionByTxidSQL) +// GetTx retrives a transaction by its txid +func (db DB) GetTx(txID uint64) (*Tx, error) { + stmt, err := db.conn.Prepare(getTransactionByTxidSQL) if err != nil { return nil, err } @@ -82,7 +89,7 @@ func GetTransaction(txID uint64) (*Transaction, error) { row := stmt.QueryRow(txID) - var tx Transaction + var tx Tx err = row.Scan(&tx.BtcTxID, &tx.RawTx, &tx.Status) if err != nil { if err == sql.ErrNoRows { @@ -94,17 +101,17 @@ func GetTransaction(txID uint64) (*Transaction, error) { return &tx, nil } -// GetPendingTransactions retrieves all transactions with a "pending" status -func GetPendingTransactions() ([]Transaction, error) { - rows, err := db.Query(getPendingTransactionsSQL, StatusPending) +// GetPendingTxs retrieves all transactions with a "pending" status +func (db DB) GetPendingTxs() ([]Tx, error) { + rows, err := db.conn.Query(getPendingTransactionsSQL, StatusPending) if err != nil { return nil, err } defer rows.Close() - var transactions []Transaction + var transactions []Tx for rows.Next() { - var tx Transaction + var tx Tx err := rows.Scan(&tx.BtcTxID, &tx.RawTx, &tx.Status) if err != nil { return nil, err @@ -115,9 +122,9 @@ func GetPendingTransactions() ([]Transaction, error) { return transactions, nil } -// UpdateTransactionStatus updates the status of a transaction by txid -func UpdateTransactionStatus(txID uint64, status TransactionStatus) error { - stmt, err := db.Prepare(updateTransactionStatusSQL) +// UpdateTxStatus updates the status of a transaction by txid +func (db DB) UpdateTxStatus(txID uint64, status TxStatus) error { + stmt, err := db.conn.Prepare(updateTransactionStatusSQL) if err != nil { return err } diff --git a/native/database/db_test.go b/native/database/db_test.go index c8598bb..c61d4fb 100644 --- a/native/database/db_test.go +++ b/native/database/db_test.go @@ -2,83 +2,70 @@ package database import ( "testing" + + "gotest.tools/assert" ) -func TestInsertTransaction(t *testing.T) { - err := InitDB(":memory:") // in-memory database - if err != nil { - t.Fatal(err) - } +func TestInsertTx(t *testing.T) { + db := initTestDB(t) - tx := Transaction{ + tx := Tx{ BtcTxID: 1, RawTx: "raw-transaction-hex", Status: StatusPending, } - err = InsertTransaction(tx) - if err != nil { - t.Errorf("InsertTransaction() error = %v", err) - } + err := db.InsertTx(tx) + assert.NilError(t, err) + + retrievedTx, err := db.GetTx(1) + assert.NilError(t, err) + assert.DeepEqual(t, retrievedTx, &tx) } -func TestGetPendingTransactions(t *testing.T) { - err := InitDB(":memory:") // in-memory database - if err != nil { - t.Fatal(err) - } +func TestGetPendingTxs(t *testing.T) { + db := initTestDB(t) - transactions := []Transaction{ + transactions := []Tx{ {BtcTxID: 1, RawTx: "tx1-hex", Status: StatusPending}, {BtcTxID: 2, RawTx: "tx2-hex", Status: StatusBroadcasted}, {BtcTxID: 3, RawTx: "tx3-hex", Status: StatusPending}, } for _, tx := range transactions { - err = InsertTransaction(tx) - if err != nil { - t.Fatal(err) - } - } - - pendingTxs, err := GetPendingTransactions() - if err != nil { - t.Errorf("GetPendingTransactions() error = %v", err) + err := db.InsertTx(tx) + assert.NilError(t, err) } - if len(pendingTxs) != 2 { - t.Errorf("Expected 2 pending transactions, got %d", len(pendingTxs)) - } + pendingTxs, err := db.GetPendingTxs() + assert.NilError(t, err) + assert.Equal(t, len(pendingTxs), 2) } -func TestUpdateTransactionStatus(t *testing.T) { - err := InitDB(":memory:") // in-memory database - if err != nil { - t.Fatal(err) - } +func TestUpdateTxStatus(t *testing.T) { + db := initTestDB(t) txID := uint64(1) - tx := Transaction{ + tx := Tx{ BtcTxID: txID, RawTx: "raw-transaction-hex", Status: StatusPending, } - err = InsertTransaction(tx) - if err != nil { - t.Fatal(err) - } + err := db.InsertTx(tx) + assert.NilError(t, err) - err = UpdateTransactionStatus(txID, StatusBroadcasted) - if err != nil { - t.Errorf("UpdateTransactionStatus() error = %v", err) - } + err = db.UpdateTxStatus(txID, StatusBroadcasted) + assert.NilError(t, err) - updatedTx, err := GetTransaction(txID) - if err != nil { - t.Errorf("GetTransaction() error = %v", err) - } + updatedTx, err := db.GetTx(txID) + assert.NilError(t, err) + assert.Equal(t, updatedTx.Status, StatusBroadcasted) +} - if updatedTx.Status != StatusBroadcasted { - t.Errorf("Expected StatusBrodcasted, got %s", updatedTx.Status) - } +func initTestDB(t *testing.T) *DB { + t.Helper() + + db, err := NewDB(":memory:") + assert.NilError(t, err) + return db }