Skip to content

Commit

Permalink
Merge branch 'master' into stan/28-transaction-processing
Browse files Browse the repository at this point in the history
  • Loading branch information
sczembor authored Nov 27, 2024
2 parents 59f655f + 1364ad0 commit f34572a
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 0 deletions.
119 changes: 119 additions & 0 deletions dal/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package dal

import (
"database/sql"

_ "github.com/mattn/go-sqlite3" // Import the SQLite driver
)

// TxStatus represents the different states of a transaction.
type TxStatus byte

// Transaction status constants
const (
StatusPending TxStatus = iota
StatusBroadcasted
StatusConfirmed
)

// SQL queries
const (
insertTransactionSQL = "INSERT INTO transactions(txid, rawtx, status) values(?,?,?)"
getPendingTransactionsSQL = "SELECT txid, rawtx, status FROM transactions WHERE status = ?"
getTransactionByTxidSQL = "SELECT txid, rawtx, status FROM transactions WHERE txid = ?"
updateTransactionStatusSQL = "UPDATE transactions SET status = ? WHERE txid = ?"
createTransactionsTableSQL = `
CREATE TABLE IF NOT EXISTS transactions (
txid TEXT PRIMARY KEY,
rawtx TEXT NOT NULL,
status INTEGER NOT NULL NOT NULL
)
`
)

// Tx represents a transaction record in the database.
type Tx struct {
BtcTxID uint64 `json:"txid"`
RawTx string `json:"rawtx"`
Status TxStatus `json:"status"`
// TODO: other fields
}

// DB holds the database connection and provides methods for interacting with it.
type DB struct {
conn *sql.DB
}

// NewDB creates a new DB instance
func NewDB(dbPath string) (*DB, error) {
db := &DB{}

var err error
db.conn, err = sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
}
return db, err
}

// InitDB initializes the database
func (db DB) InitDB() error {
_, err := db.conn.Exec(createTransactionsTableSQL)
return err
}

// InsertTx inserts a new transaction into the database
func (db DB) InsertTx(tx Tx) error {
_, err := db.conn.Exec(insertTransactionSQL, tx.BtcTxID, tx.RawTx, tx.Status)
if err != nil {
return err
}

return nil
}

// GetTx retrives a transaction by its txid
func (db DB) GetTx(txID uint64) (*Tx, error) {
row := db.conn.QueryRow(getTransactionByTxidSQL, txID)
var tx Tx
err := row.Scan(&tx.BtcTxID, &tx.RawTx, &tx.Status)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}

return &tx, nil
}

// GetPendingTxs retrieves all transactions with a "pending" status
func (db DB) GetPendingTxs() ([]Tx, error) {
rows, err := db.conn.Query(getPendingTransactionsSQL, StatusPending)
if err != nil {
return nil, err
}
defer rows.Close()

var transactions []Tx
for rows.Next() {
var tx Tx
err := rows.Scan(&tx.BtcTxID, &tx.RawTx, &tx.Status)
if err != nil {
return nil, err
}
transactions = append(transactions, tx)
}

return transactions, nil
}

// UpdateTxStatus updates the status of a transaction by txid
func (db DB) UpdateTxStatus(txID uint64, status TxStatus) error {
_, err := db.conn.Exec(updateTransactionStatusSQL, status, txID)
if err != nil {
return err
}

return nil
}
73 changes: 73 additions & 0 deletions dal/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package dal

import (
"testing"

"gotest.tools/assert"
)

func TestInsertTx(t *testing.T) {
db := initTestDB(t)

tx := Tx{
BtcTxID: 1,
RawTx: "raw-transaction-hex",
Status: StatusPending,
}

err := db.InsertTx(tx)
assert.NilError(t, err)

retrievedTx, err := db.GetTx(1)
assert.NilError(t, err)
assert.DeepEqual(t, retrievedTx, &tx)
}

func TestGetPendingTxs(t *testing.T) {
db := initTestDB(t)

transactions := []Tx{
{BtcTxID: 1, RawTx: "tx1-hex", Status: StatusPending},
{BtcTxID: 2, RawTx: "tx2-hex", Status: StatusBroadcasted},
{BtcTxID: 3, RawTx: "tx3-hex", Status: StatusPending},
}
for _, tx := range transactions {
err := db.InsertTx(tx)
assert.NilError(t, err)
}

pendingTxs, err := db.GetPendingTxs()
assert.NilError(t, err)
assert.Equal(t, len(pendingTxs), 2)
}

func TestUpdateTxStatus(t *testing.T) {
db := initTestDB(t)

txID := uint64(1)

tx := Tx{
BtcTxID: txID,
RawTx: "raw-transaction-hex",
Status: StatusPending,
}
err := db.InsertTx(tx)
assert.NilError(t, err)

err = db.UpdateTxStatus(txID, StatusBroadcasted)
assert.NilError(t, err)

updatedTx, err := db.GetTx(txID)
assert.NilError(t, err)
assert.Equal(t, updatedTx.Status, StatusBroadcasted)
}

func initTestDB(t *testing.T) *DB {
t.Helper()

db, err := NewDB(":memory:")
assert.NilError(t, err)
err = db.InitDB()
assert.NilError(t, err)
return db
}

0 comments on commit f34572a

Please sign in to comment.