Skip to content

Commit

Permalink
apply code review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
sczembor committed Nov 26, 2024
1 parent 3cc5554 commit 20b025c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 84 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ require (
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools v2.2.0+incompatible
gotest.tools/v3 v3.5.1 // indirect
nhooyr.io/websocket v1.8.6 // indirect
pgregory.net/rapid v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
Expand Down
77 changes: 42 additions & 35 deletions native/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import (
_ "github.com/mattn/go-sqlite3" // Import the SQLite driver
)

// TransactionStatus represents the different states of a transaction.
type TransactionStatus string
// TxStatus represents the different states of a transaction.
type TxStatus int

// Different tx states
// Transaction status constants
const (
StatusPending TransactionStatus = "pending"
StatusBroadcasted TransactionStatus = "broadcasted"
StatusConfirmed TransactionStatus = "confirmed"
StatusPending TxStatus = iota
StatusBroadcasted
StatusConfirmed
)

// SQL queries
Expand All @@ -26,39 +26,46 @@ const (
CREATE TABLE IF NOT EXISTS transactions (
txid TEXT PRIMARY KEY,
rawtx TEXT NOT NULL,
status TEXT NOT NULL
status INTEGER NOT NULL NOT NULL
)
`
)

// Transaction represents a transaction record in the database.
type Transaction struct {
BtcTxID uint64 `json:"txid"`
RawTx string `json:"rawtx"`
Status TransactionStatus `json:"status"`
// Tx represents a transaction record in the database.
type Tx struct {
BtcTxID uint64 `json:"txid"`
RawTx string `json:"rawtx"`
Status TxStatus `json:"status"`
// TODO: other fields
}

var db *sql.DB
// DB holds the database connection and provides methods for interacting with it.
type DB struct {
conn *sql.DB
}

// NewDB creates a new DB instance and initializes the database connection.
func NewDB(dbPath string) (*DB, error) {
db := &DB{} // Create a new DB instance

// InitDB initializes the database connection and creates the table if it doesn't exist
func InitDB(dbPath string) error {
// Initialize the database connection
var err error
db, err = sql.Open("sqlite3", dbPath)
db.conn, err = sql.Open("sqlite3", dbPath)
if err != nil {
return err
return nil, err
}

_, err = db.Exec(createTransactionsTableSQL)
_, err = db.conn.Exec(createTransactionsTableSQL)
if err != nil {
return err
return nil, err
}
return nil

return db, nil // Return the initialized DB instance
}

// InsertTransaction inserts a new transaction into the database
func InsertTransaction(tx Transaction) error {
stmt, err := db.Prepare(insertTransactionSQL)
// InsertTx inserts a new transaction into the database
func (db DB) InsertTx(tx Tx) error {
stmt, err := db.conn.Prepare(insertTransactionSQL)
if err != nil {
return err
}
Expand All @@ -72,17 +79,17 @@ func InsertTransaction(tx Transaction) error {
return nil
}

// GetTransaction retrives a transaction by its txid
func GetTransaction(txID uint64) (*Transaction, error) {
stmt, err := db.Prepare(getTransactionByTxidSQL)
// GetTx retrives a transaction by its txid
func (db DB) GetTx(txID uint64) (*Tx, error) {
stmt, err := db.conn.Prepare(getTransactionByTxidSQL)
if err != nil {
return nil, err
}
defer stmt.Close()

row := stmt.QueryRow(txID)

var tx Transaction
var tx Tx
err = row.Scan(&tx.BtcTxID, &tx.RawTx, &tx.Status)
if err != nil {
if err == sql.ErrNoRows {
Expand All @@ -94,17 +101,17 @@ func GetTransaction(txID uint64) (*Transaction, error) {
return &tx, nil
}

// GetPendingTransactions retrieves all transactions with a "pending" status
func GetPendingTransactions() ([]Transaction, error) {
rows, err := db.Query(getPendingTransactionsSQL, StatusPending)
// GetPendingTxs retrieves all transactions with a "pending" status
func (db DB) GetPendingTxs() ([]Tx, error) {
rows, err := db.conn.Query(getPendingTransactionsSQL, StatusPending)
if err != nil {
return nil, err
}
defer rows.Close()

var transactions []Transaction
var transactions []Tx
for rows.Next() {
var tx Transaction
var tx Tx
err := rows.Scan(&tx.BtcTxID, &tx.RawTx, &tx.Status)
if err != nil {
return nil, err
Expand All @@ -115,9 +122,9 @@ func GetPendingTransactions() ([]Transaction, error) {
return transactions, nil
}

// UpdateTransactionStatus updates the status of a transaction by txid
func UpdateTransactionStatus(txID uint64, status TransactionStatus) error {
stmt, err := db.Prepare(updateTransactionStatusSQL)
// UpdateTxStatus updates the status of a transaction by txid
func (db DB) UpdateTxStatus(txID uint64, status TxStatus) error {
stmt, err := db.conn.Prepare(updateTransactionStatusSQL)
if err != nil {
return err
}
Expand Down
85 changes: 36 additions & 49 deletions native/database/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,83 +2,70 @@ package database

import (
"testing"

"gotest.tools/assert"
)

func TestInsertTransaction(t *testing.T) {
err := InitDB(":memory:") // in-memory database
if err != nil {
t.Fatal(err)
}
func TestInsertTx(t *testing.T) {
db := initTestDB(t)

tx := Transaction{
tx := Tx{
BtcTxID: 1,
RawTx: "raw-transaction-hex",
Status: StatusPending,
}

err = InsertTransaction(tx)
if err != nil {
t.Errorf("InsertTransaction() error = %v", err)
}
err := db.InsertTx(tx)
assert.NilError(t, err)

retrievedTx, err := db.GetTx(1)
assert.NilError(t, err)
assert.DeepEqual(t, retrievedTx, &tx)
}

func TestGetPendingTransactions(t *testing.T) {
err := InitDB(":memory:") // in-memory database
if err != nil {
t.Fatal(err)
}
func TestGetPendingTxs(t *testing.T) {
db := initTestDB(t)

transactions := []Transaction{
transactions := []Tx{
{BtcTxID: 1, RawTx: "tx1-hex", Status: StatusPending},
{BtcTxID: 2, RawTx: "tx2-hex", Status: StatusBroadcasted},
{BtcTxID: 3, RawTx: "tx3-hex", Status: StatusPending},
}
for _, tx := range transactions {
err = InsertTransaction(tx)
if err != nil {
t.Fatal(err)
}
}

pendingTxs, err := GetPendingTransactions()
if err != nil {
t.Errorf("GetPendingTransactions() error = %v", err)
err := db.InsertTx(tx)
assert.NilError(t, err)
}

if len(pendingTxs) != 2 {
t.Errorf("Expected 2 pending transactions, got %d", len(pendingTxs))
}
pendingTxs, err := db.GetPendingTxs()
assert.NilError(t, err)
assert.Equal(t, len(pendingTxs), 2)
}

func TestUpdateTransactionStatus(t *testing.T) {
err := InitDB(":memory:") // in-memory database
if err != nil {
t.Fatal(err)
}
func TestUpdateTxStatus(t *testing.T) {
db := initTestDB(t)

txID := uint64(1)

tx := Transaction{
tx := Tx{
BtcTxID: txID,
RawTx: "raw-transaction-hex",
Status: StatusPending,
}
err = InsertTransaction(tx)
if err != nil {
t.Fatal(err)
}
err := db.InsertTx(tx)
assert.NilError(t, err)

err = UpdateTransactionStatus(txID, StatusBroadcasted)
if err != nil {
t.Errorf("UpdateTransactionStatus() error = %v", err)
}
err = db.UpdateTxStatus(txID, StatusBroadcasted)
assert.NilError(t, err)

updatedTx, err := GetTransaction(txID)
if err != nil {
t.Errorf("GetTransaction() error = %v", err)
}
updatedTx, err := db.GetTx(txID)
assert.NilError(t, err)
assert.Equal(t, updatedTx.Status, StatusBroadcasted)
}

if updatedTx.Status != StatusBroadcasted {
t.Errorf("Expected StatusBrodcasted, got %s", updatedTx.Status)
}
func initTestDB(t *testing.T) *DB {
t.Helper()

db, err := NewDB(":memory:")
assert.NilError(t, err)
return db
}

0 comments on commit 20b025c

Please sign in to comment.