From 344275dc55d74d4253534a94be04986a0bcbf02f Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Mon, 9 Dec 2024 17:54:34 +0100 Subject: [PATCH] chore(embedded/sql): add support for LEFT JOIN Signed-off-by: Stefano Scafiti --- embedded/sql/engine_test.go | 138 ++++++++++++++++++++++++++ embedded/sql/joint_row_reader.go | 43 +++++--- embedded/sql/joint_row_reader_test.go | 5 +- embedded/sql/stmt.go | 1 - 4 files changed, 173 insertions(+), 14 deletions(-) diff --git a/embedded/sql/engine_test.go b/embedded/sql/engine_test.go index 78f812b6ce..3c44841de3 100644 --- a/embedded/sql/engine_test.go +++ b/embedded/sql/engine_test.go @@ -5986,6 +5986,123 @@ func TestNestedJoins(t *testing.T) { require.NoError(t, err) } +func TestLeftJoins(t *testing.T) { + e := setupCommonTest(t) + + _, _, err := e.Exec( + context.Background(), + nil, + ` + CREATE TABLE customers ( + customer_id INTEGER, + customer_name VARCHAR(50), + email VARCHAR(100), + + PRIMARY KEY customer_id + ); + + CREATE TABLE products ( + product_id INTEGER, + product_name VARCHAR(50), + price FLOAT, + + PRIMARY KEY product_id + ); + + CREATE TABLE orders ( + order_id INTEGER, + customer_id INTEGER, + order_date TIMESTAMP, + + PRIMARY KEY order_id + ); + + CREATE TABLE order_items ( + order_item_id INTEGER, + order_id INTEGER, + product_id INTEGER, + quantity INTEGER, + + PRIMARY KEY order_item_id + ); + + INSERT INTO customers (customer_id, customer_name, email) + VALUES + (1, 'Alice Johnson', 'alice@example.com'), + (2, 'Bob Smith', 'bob@example.com'), + (3, 'Charlie Brown', 'charlie@example.com'); + + INSERT INTO products (product_id, product_name, price) + VALUES + (1, 'Laptop', 1200.00), + (2, 'Smartphone', 800.00), + (3, 'Tablet', 400.00); + + INSERT INTO orders (order_id, customer_id, order_date) + VALUES + (101, 1, '2024-11-01'::TIMESTAMP), + (102, 2, '2024-11-02'::TIMESTAMP), + (103, 1, '2024-11-03'::TIMESTAMP); + + INSERT INTO order_items (order_item_id, order_id, product_id, quantity) + VALUES + (1, 101, 1, 2), + (2, 101, 2, 1), + (3, 102, 3, 3), + (4, 103, 2, 2); + `, + nil, + ) + require.NoError(t, err) + + assertQueryShouldProduceResults( + t, + e, + `SELECT c.customer_id, c.customer_name, c.email, o.order_id, o.order_date + FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id + ORDER BY c.customer_id, o.order_date;`, + ` + SELECT * + FROM ( + VALUES + (1, 'Alice Johnson', 'alice@example.com', 101, '2024-11-01'::TIMESTAMP), + (1, 'Alice Johnson', 'alice@example.com', 103, '2024-11-03'::TIMESTAMP), + (2, 'Bob Smith', 'bob@example.com', 102, '2024-11-02'::TIMESTAMP), + (3, 'Charlie Brown', 'charlie@example.com', NULL, NULL) + )`, + ) + + assertQueryShouldProduceResults( + t, + e, + ` + SELECT + c.customer_name, + c.email, + o.order_id, + o.order_date, + p.product_name, + oi.quantity, + p.price, + (oi.quantity * p.price) AS total_price + FROM + products p + LEFT JOIN order_Items oi ON p.product_id = oi.product_id + LEFT JOIN orders o ON oi.order_id = o.order_id + LEFT JOIN customers c ON o.customer_id = c.customer_id + ORDER BY o.order_date, c.customer_name;`, + ` + SELECT * + FROM ( + VALUES + ('Alice Johnson', 'alice@example.com', 101, '2024-11-01'::TIMESTAMP, 'Laptop', 2, 1200.00, 2400.00), + ('Alice Johnson', 'alice@example.com', 101, '2024-11-01'::TIMESTAMP, 'Smartphone', 1, 800.00, 800.00), + ('Bob Smith', 'bob@example.com', 102, '2024-11-02'::TIMESTAMP, 'Tablet', 3, 400.00, 1200.00), + ('Alice Johnson', 'alice@example.com', 103, '2024-11-03'::TIMESTAMP, 'Smartphone', 2, 800.00, 1600.00) + )`, + ) +} + func TestReOpening(t *testing.T) { st, err := store.Open(t.TempDir(), store.DefaultOptions().WithMultiIndexing(true)) require.NoError(t, err) @@ -9434,3 +9551,24 @@ func TestFunctions(t *testing.T) { require.Equal(t, "OBJECT", rows[0].ValuesByPosition[0].RawValue().(string)) }) } + +func assertQueryShouldProduceResults(t *testing.T, e *Engine, query, resultQuery string) { + queryReader, err := e.Query(context.Background(), nil, query, nil) + require.NoError(t, err) + defer queryReader.Close() + + resultReader, err := e.Query(context.Background(), nil, resultQuery, nil) + require.NoError(t, err) + defer resultReader.Close() + + for { + actualRow, actualErr := queryReader.Read(context.Background()) + expectedRow, expectedErr := resultReader.Read(context.Background()) + require.Equal(t, expectedErr, actualErr) + + if errors.Is(actualErr, ErrNoMoreRows) { + break + } + require.Equal(t, expectedRow.ValuesByPosition, actualRow.ValuesByPosition) + } +} diff --git a/embedded/sql/joint_row_reader.go b/embedded/sql/joint_row_reader.go index ddd04b717a..e01ba14a02 100644 --- a/embedded/sql/joint_row_reader.go +++ b/embedded/sql/joint_row_reader.go @@ -39,7 +39,9 @@ func newJointRowReader(rowReader RowReader, joins []*JoinSpec) (*jointRowReader, } for _, jspec := range joins { - if jspec.joinType != InnerJoin { + switch jspec.joinType { + case InnerJoin, LeftJoin: + default: return nil, ErrUnsupportedJoinType } } @@ -113,7 +115,6 @@ func (jointr *jointRowReader) colsBySelector(ctx context.Context) (map[string]Co colDescriptors[sel] = des } } - return colDescriptors, nil } @@ -240,17 +241,35 @@ func (jointr *jointRowReader) Read(ctx context.Context) (row *Row, err error) { r, err := reader.Read(ctx) if err == ErrNoMoreRows { - // previous reader will need to read next row - unsolvedFK = true - - err = reader.Close() - if err != nil { - return nil, err + if jspec.joinType == InnerJoin { + // previous reader will need to read next row + unsolvedFK = true + + err = reader.Close() + if err != nil { + return nil, err + } + + break + } else { // LEFT JOIN: fill column values with NULLs + cols, err := reader.Columns(ctx) + if err != nil { + return nil, err + } + + r = &Row{ + ValuesByPosition: make([]TypedValue, len(cols)), + ValuesBySelector: make(map[string]TypedValue, len(cols)), + } + + for i, col := range cols { + nullValue := NewNull(col.Type) + + r.ValuesByPosition[i] = nullValue + r.ValuesBySelector[col.Selector()] = nullValue + } } - - break - } - if err != nil { + } else if err != nil { reader.Close() return nil, err } diff --git a/embedded/sql/joint_row_reader_test.go b/embedded/sql/joint_row_reader_test.go index ba8fa267cd..f28e5e6bf4 100644 --- a/embedded/sql/joint_row_reader_test.go +++ b/embedded/sql/joint_row_reader_test.go @@ -51,9 +51,12 @@ func TestJointRowReader(t *testing.T) { r, err := newRawRowReader(tx, nil, table, period{}, "", &ScanSpecs{Index: table.primaryIndex}) require.NoError(t, err) - _, err = newJointRowReader(r, []*JoinSpec{{joinType: LeftJoin}}) + _, err = newJointRowReader(r, []*JoinSpec{{joinType: RightJoin}}) require.ErrorIs(t, err, ErrUnsupportedJoinType) + _, err = newJointRowReader(r, []*JoinSpec{{joinType: LeftJoin}}) + require.NoError(t, err) + _, err = newJointRowReader(r, []*JoinSpec{{joinType: InnerJoin, ds: &SelectStmt{}}}) require.NoError(t, err) diff --git a/embedded/sql/stmt.go b/embedded/sql/stmt.go index c6cbe46d6c..ffab9e4ab6 100644 --- a/embedded/sql/stmt.go +++ b/embedded/sql/stmt.go @@ -3409,7 +3409,6 @@ func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[strin rowReader = newLimitRowReader(rowReader, limit) } } - return rowReader, nil }