From 3fcbbec427774fa8bd3e812a4234dcece791456f Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 6 Oct 2024 14:21:42 +0200 Subject: [PATCH 1/2] Add support for Row expression. --- internal/jet/func_expression.go | 7 +- internal/jet/literal_expression.go | 26 ------- internal/jet/order_set_aggregate_functions.go | 7 +- internal/jet/row_expression.go | 78 +++++++++++++++++++ mysql/expressions.go | 8 ++ postgres/expressions.go | 8 ++ postgres/literal.go | 10 +++ sqlite/expressions.go | 8 ++ sqlite/functions.go | 6 +- tests/mysql/alltypes_test.go | 39 +++++++++- tests/mysql/with_test.go | 4 +- tests/postgres/alltypes_test.go | 50 ++++++++++-- tests/postgres/update_test.go | 5 +- tests/postgres/with_test.go | 12 +-- tests/sqlite/alltypes_test.go | 41 +++++++++- tests/sqlite/with_test.go | 8 +- 16 files changed, 254 insertions(+), 63 deletions(-) create mode 100644 internal/jet/row_expression.go diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 7e498806..ddc579e4 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression { return newBoolExpressionListOperator("OR", expressions...) } -// ROW function is used to create a tuple value that consists of a set of expressions or column values. -func ROW(expressions ...Expression) Expression { - return NewFunc("ROW", expressions, nil) -} - // ------------------ Mathematical functions ---------------// // ABSf calculates absolute value from float expression @@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder if _, isStatement := expression.(Statement); isStatement { expression.serialize(statement, out, options...) } else { - skipWrap(expression).serialize(statement, out, options...) + expression.serialize(statement, out, append(options, NoWrap, Ident)...) } } } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index d6f0b415..251d3ab9 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option //---------------------------------------------------// -type wrap struct { - ExpressionInterfaceImpl - expressions []Expression -} - -func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") - - if len(n.expressions) == 1 { - options = append(options, NoWrap, Ident) - } - serializeExpressionList(statementType, n.expressions, ", ", out, options...) - - out.WriteString(")") -} - -// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) -func WRAP(expression ...Expression) Expression { - wrap := &wrap{expressions: expression} - wrap.ExpressionInterfaceImpl.Parent = wrap - - return wrap -} - -//---------------------------------------------------// - type rawExpression struct { ExpressionInterfaceImpl diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index 8ce5d1e1..eff954a5 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct { func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(p.name) - WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + + if p.fraction != nil { + WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + } else { + WRAP().serialize(statement, out, FallTrough(options)...) + } out.WriteString("WITHIN GROUP") p.orderBy.serialize(statement, out) } diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go new file mode 100644 index 00000000..819cf15a --- /dev/null +++ b/internal/jet/row_expression.go @@ -0,0 +1,78 @@ +package jet + +// RowExpression interface +type RowExpression interface { + Expression + + EQ(rhs RowExpression) BoolExpression + NOT_EQ(rhs RowExpression) BoolExpression + IS_DISTINCT_FROM(rhs RowExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression + + LT(rhs RowExpression) BoolExpression + LT_EQ(rhs RowExpression) BoolExpression + GT(rhs RowExpression) BoolExpression + GT_EQ(rhs RowExpression) BoolExpression +} + +type rowInterfaceImpl struct { + parent RowExpression +} + +func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { + return Eq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression { + return NotEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression { + return IsDistinctFrom(n.parent, rhs) +} + +func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression { + return IsNotDistinctFrom(n.parent, rhs) +} + +func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression { + return Gt(n.parent, rhs) +} + +func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression { + return GtEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression { + return Lt(n.parent, rhs) +} + +func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression { + return LtEq(n.parent, rhs) +} + +//---------------------------------------------------// + +type rowExpressionWrapper struct { + rowInterfaceImpl + Expression +} + +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +func RowExp(expression Expression) RowExpression { + rowExpressionWrap := rowExpressionWrapper{Expression: expression} + rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap + return &rowExpressionWrap +} + +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(expressions ...Expression) RowExpression { + return RowExp(NewFunc("ROW", expressions, nil)) +} + +// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. +func WRAP(expressions ...Expression) RowExpression { + return RowExp(NewFunc("", expressions, nil)) +} diff --git a/mysql/expressions.go b/mysql/expressions.go index 53b1fa7f..4073ef56 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -30,6 +30,9 @@ type DateTimeExpression = jet.TimestampExpression // TimestampExpression interface type TimestampExpression = jet.TimestampExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -70,6 +73,11 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // CustomExpression is used to define custom expressions. var CustomExpression = jet.CustomExpression diff --git a/postgres/expressions.go b/postgres/expressions.go index 98729100..d8ad34b4 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -36,6 +36,9 @@ type TimestampExpression = jet.TimestampExpression // TimestampzExpression interface type TimestampzExpression = jet.TimestampzExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // DateRange Expression interface type DateRange = jet.Range[DateExpression] @@ -99,6 +102,11 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // RangeExp is range expression wrapper around arbitrary expression. // Allows go compiler to see any expression as range expression. // Does not add sql cast to generated sql builder output. diff --git a/postgres/literal.go b/postgres/literal.go index e3a95b3b..26b75d88 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -57,6 +57,16 @@ func Uint64(value uint64) IntegerExpression { // Float creates new float literal expression var Float = jet.Float +// Float32 is constructor for 32 bit float literals +func Float32(value float32) FloatExpression { + return CAST(jet.Literal(value)).AS_REAL() +} + +// Float64 is constructor for 64 bit float literals +func Float64(value float64) FloatExpression { + return CAST(jet.Literal(value)).AS_DOUBLE() +} + // Decimal creates new float literal expression var Decimal = jet.Decimal diff --git a/sqlite/expressions.go b/sqlite/expressions.go index 42ccc96c..0b2d320a 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -33,6 +33,9 @@ type DateTimeExpression = jet.TimestampExpression // TimestampExpression interface type TimestampExpression = jet.TimestampExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -73,6 +76,11 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // CustomExpression is used to define custom expressions. var CustomExpression = jet.CustomExpression diff --git a/sqlite/functions.go b/sqlite/functions.go index 47a0a5b7..a76f236d 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -15,10 +15,8 @@ var ( OR = jet.OR ) -// ROW is construct one table row from list of expressions. -func ROW(expressions ...Expression) Expression { - return jet.NewFunc("", expressions, nil) -} +// ROW is construct one row from a list of expressions. +var ROW = jet.WRAP // ------------------ Mathematical functions ---------------// diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index d5702bda..61bc7f25 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -97,18 +97,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - )) AS "result.in_select", + ))) AS "result.in_select", (CURRENT_USER()) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -1404,3 +1404,34 @@ VALUES ('91.23', '45.67', '12.35', '56.79', 0.2, 0.22, 0.3, 0.33, 0.4, 0.44); require.Equal(t, 45.67, *result.Floats.DecimalPtr) }) } + +func TestRowExpression(t *testing.T) { + now := time.Now() + nowAddHour := time.Now().Add(time.Hour) + + stmt := SELECT( + ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))), + ROW(Bool(false), DateT(now)).NOT_EQ(ROW(Bool(true), DateT(now))), + ROW(TimestampT(nowAddHour), String("txt")).IS_DISTINCT_FROM(RowExp(Raw("row(NOW(), 'png')"))), + ROW(TimestampT(now), DateTimeT(nowAddHour)).GT(ROW(TimestampT(now), DateTimeT(now))), + ROW(DateTimeT(nowAddHour), Int(1)).GT_EQ(ROW(DateTimeT(now), Int(2))), + ROW(TimestampT(now), DateTimeT(nowAddHour)).LT(ROW(TimestampT(now), DateTimeT(now))), + ROW(DateTimeT(nowAddHour), Float(1.22)).LT_EQ(ROW(DateTimeT(now), Float(2.33))), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT ROW(?, CAST(? AS DATE)) = ROW(?, CAST(? AS DATE)), + ROW(?, CAST(? AS DATE)) != ROW(?, CAST(? AS DATE)), + NOT(ROW(TIMESTAMP(?), ?) <=> (row(NOW(), 'png'))), + ROW(TIMESTAMP(?), CAST(? AS DATETIME)) > ROW(TIMESTAMP(?), CAST(? AS DATETIME)), + ROW(CAST(? AS DATETIME), ?) >= ROW(CAST(? AS DATETIME), ?), + ROW(TIMESTAMP(?), CAST(? AS DATETIME)) < ROW(TIMESTAMP(?), CAST(? AS DATETIME)), + ROW(CAST(? AS DATETIME), ?) <= ROW(CAST(? AS DATETIME), ?); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index d7d8d3d7..59da3a5c 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -164,10 +164,10 @@ WITH payments_to_delete AS ( WHERE payment.amount < 0.5 ) DELETE FROM dvds.payment -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_delete - ); + )); `, "''", "`")) tx, err := db.Begin() diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 70c43326..07350b8b 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -347,24 +347,24 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) - //fmt.Println(query.Sql()) + // fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - )) AS "result.in_select", + ))) AS "result.in_select", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; `, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) @@ -1111,6 +1111,46 @@ FROM test_sample.all_types; require.NoError(t, err) } +func TestRowExpression(t *testing.T) { + now := time.Now() + nowAddHour := time.Now().Add(time.Hour) + + stmt := SELECT( + ROW(Int32(1), Float32(11.22), String("john")).AS("row"), + WRAP(Int64(1), Float64(11.22), String("john")).AS("wrap"), + + ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))), + WRAP(Bool(false), DateT(now)).NOT_EQ(WRAP(Bool(true), DateT(now))), + + ROW(TimeT(nowAddHour)).IS_DISTINCT_FROM(RowExp(Raw("row(NOW()::time)"))), + ROW().IS_NOT_DISTINCT_FROM(ROW()), + + ROW(TimestampT(now), TimestampzT(nowAddHour)).GT(WRAP(TimestampT(now), TimestampzT(now))), + ROW(TimestampzT(nowAddHour)).GT_EQ(ROW(TimestampzT(now))), + WRAP(TimestampT(now), TimestampzT(nowAddHour)).LT(ROW(TimestampT(now), TimestampzT(now))), + ROW(TimestampzT(nowAddHour)).LT_EQ(ROW(TimestampzT(now))), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT ROW($1::integer, $2::real, $3::text) AS "row", + ($4::bigint, $5::double precision, $6::text) AS "wrap", + ROW($7::boolean, $8::date) = ROW($9::boolean, $10::date), + ($11::boolean, $12::date) != ($13::boolean, $14::date), + ROW($15::time without time zone) IS DISTINCT FROM (row(NOW()::time)), + ROW() IS NOT DISTINCT FROM ROW(), + ROW($16::timestamp without time zone, $17::timestamp with time zone) > ($18::timestamp without time zone, $19::timestamp with time zone), + ROW($20::timestamp with time zone) >= ROW($21::timestamp with time zone), + ($22::timestamp without time zone, $23::timestamp with time zone) < ROW($24::timestamp without time zone, $25::timestamp with time zone), + ROW($26::timestamp with time zone) <= ROW($27::timestamp with time zone); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { sql string diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index e0c7da2a..103b4208 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -344,7 +344,10 @@ func TestUpdateExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - testutils.AssertExecContextErr(ctx, t, updateStmt, db, "context deadline exceeded") + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := updateStmt.ExecContext(ctx, tx) + require.Error(t, err, "context deadline exceeded") + }) } func TestUpdateFrom(t *testing.T) { diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 21fca326..92b5649f 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -83,10 +83,10 @@ SELECT orders.ship_region AS "orders.ship_region", SUM(order_details.quantity) AS "product_sales" FROM northwind.orders INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) -WHERE orders.ship_region IN ( +WHERE orders.ship_region IN (( SELECT top_region."orders.ship_region" AS "orders.ship_region" FROM top_region - ) + )) GROUP BY orders.ship_region, order_details.product_id ORDER BY SUM(order_details.quantity) DESC; `) @@ -157,19 +157,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( DELETE FROM northwind.order_details - WHERE order_details.product_id IN ( + WHERE order_details.product_id IN (( SELECT products.product_id AS "products.product_id" FROM northwind.products WHERE products.discontinued = $1 - ) + )) RETURNING order_details.product_id AS "order_details.product_id" ),update_discontinued_price AS ( UPDATE northwind.products SET unit_price = $2 - WHERE products.product_id IN ( + WHERE products.product_id IN (( SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" FROM remove_discontinued_orders - ) + )) RETURNING products.product_id AS "products.product_id", products.product_name AS "products.product_name", products.supplier_id AS "products.supplier_id", diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 41e5cf75..080e870d 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -234,18 +234,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" FROM all_types - )) AS "result.in_select", + ))) AS "result.in_select", (length(121232459)) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -900,3 +900,36 @@ func TestDateTimeExpressions(t *testing.T) { require.Equal(t, dest.JulianDay, 2.4551543576232754e+06) require.Equal(t, dest.StrfTime, "20:34") } + +func TestRowExpression(t *testing.T) { + date := Date(2000, 9, 9) + time := Time(11, 22, 11) + dateTime := DateTime(2008, 11, 22, 10, 12, 40) + dateTime2 := DateTime(2011, 1, 2, 5, 12, 40) + + stmt := SELECT( + ROW(Bool(false), date).EQ(ROW(Bool(true), date)), + ROW(Bool(false), time).NOT_EQ(ROW(Bool(true), time)), + ROW(time).IS_DISTINCT_FROM(RowExp(Raw("(time('now'))"))), + ROW(dateTime, dateTime2).GT(ROW(dateTime, dateTime2)), + ROW(dateTime2).GT_EQ(ROW(dateTime)), + ROW(dateTime, dateTime2).LT(ROW(dateTime, dateTime2)), + ROW(dateTime2).LT_EQ(ROW(dateTime2)), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT (FALSE, DATE('2000-09-09')) = (TRUE, DATE('2000-09-09')), + (FALSE, TIME('11:22:11')) != (TRUE, TIME('11:22:11')), + (TIME('11:22:11')) IS NOT ((time('now'))), + (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) > (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')), + (DATETIME('2011-01-02 05:12:40')) >= (DATETIME('2008-11-22 10:12:40')), + (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) < (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')), + (DATETIME('2011-01-02 05:12:40')) <= (DATETIME('2011-01-02 05:12:40')); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go index 402df2f0..485655c1 100644 --- a/tests/sqlite/with_test.go +++ b/tests/sqlite/with_test.go @@ -153,10 +153,10 @@ WITH payments_to_update AS ( ) UPDATE payment SET amount = 0 -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_update - ); + )); `, "''", "`", -1)) tx := beginDBTx(t) @@ -205,10 +205,10 @@ WITH payments_to_delete AS ( WHERE payment.amount < 0.5 ) DELETE FROM payment -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_delete - ); + )); `, "''", "`", -1)) tx := beginDBTx(t) From 8d112f7db8ec674496bf5e8a0919c892b8827fd9 Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 17 Oct 2024 14:12:21 +0200 Subject: [PATCH 2/2] Add support for VALUES statement. --- internal/jet/dialect.go | 8 + internal/jet/expression.go | 17 +- internal/jet/order_set_aggregate_functions.go | 4 +- internal/jet/projection_test.go | 2 +- internal/jet/raw_statement.go | 2 +- internal/jet/row_expression.go | 48 ++- internal/jet/select_table.go | 28 +- internal/jet/values.go | 35 ++ internal/jet/with_statement.go | 11 +- mysql/cast.go | 30 +- mysql/dialect.go | 4 + mysql/functions.go | 4 +- mysql/select_statement.go | 2 +- mysql/select_table.go | 4 +- mysql/set_statement.go | 2 +- mysql/statement.go | 2 +- mysql/values.go | 32 ++ mysql/with_statement.go | 6 +- postgres/dialect.go | 4 + postgres/dialect_test.go | 16 +- postgres/functions.go | 14 +- postgres/insert_statement_test.go | 26 +- postgres/select_statement.go | 2 +- postgres/select_table.go | 6 +- postgres/set_statement.go | 2 +- postgres/values.go | 32 ++ postgres/with_statement.go | 6 +- sqlite/dialect.go | 4 + sqlite/functions.go | 6 +- sqlite/select_statement.go | 2 +- sqlite/select_table.go | 4 +- sqlite/set_statement.go | 2 +- sqlite/values.go | 26 ++ sqlite/with_statement.go | 10 +- tests/mysql/main_test.go | 6 + tests/mysql/values_test.go | 347 ++++++++++++++++++ tests/postgres/alltypes_test.go | 9 +- tests/postgres/northwind_test.go | 32 +- tests/postgres/values_test.go | 284 ++++++++++++++ tests/sqlite/update_test.go | 2 +- tests/sqlite/values_test.go | 344 +++++++++++++++++ 41 files changed, 1296 insertions(+), 131 deletions(-) create mode 100644 internal/jet/values.go create mode 100644 mysql/values.go create mode 100644 postgres/values.go create mode 100644 sqlite/values.go create mode 100644 tests/mysql/values_test.go create mode 100644 tests/postgres/values_test.go create mode 100644 tests/sqlite/values_test.go diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index f3ad2b40..68c4c022 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -13,6 +13,7 @@ type Dialect interface { ArgumentPlaceholder() QueryPlaceholderFunc IsReservedWord(name string) bool SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + ValuesDefaultColumnName(index int) string } // SerializerFunc func @@ -35,6 +36,7 @@ type DialectParams struct { ArgumentPlaceholder QueryPlaceholderFunc ReservedWords []string SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + ValuesDefaultColumnName func(index int) string } // NewDialect creates new dialect with params @@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect { argumentPlaceholder: params.ArgumentPlaceholder, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), serializeOrderBy: params.SerializeOrderBy, + valuesDefaultColumnName: params.ValuesDefaultColumnName, } } @@ -62,6 +65,7 @@ type dialectImpl struct { argumentPlaceholder QueryPlaceholderFunc reservedWords map[string]bool serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + valuesDefaultColumnName func(index int) string } func (d *dialectImpl) Name() string { @@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending, return d.serializeOrderBy } +func (d *dialectImpl) ValuesDefaultColumnName(index int) string { + return d.valuesDefaultColumnName(index) +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 05b1797f..9999803f 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { // IN checks if this expressions matches any in expressions list func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") + return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN") } // NOT_IN checks if this expressions is different of all expressions in expressions list func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") + return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN") } // AS the temporary alias name to assign to the expression @@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, } } -type skipParenthesisWrap struct { - Expression -} - -func skipWrap(expression Expression) Expression { - return &skipParenthesisWrap{expression} -} - -// since the expression is a function parameter, there is no need to wrap it in parentheses -func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - s.Expression.serialize(statement, out, append(options, NoWrap)...) +func wrap(expressions ...Expression) Expression { + return NewFunc("", expressions, nil) } diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index eff954a5..c8538450 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -56,9 +56,9 @@ func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out out.WriteString(p.name) if p.fraction != nil { - WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + wrap(p.fraction).serialize(statement, out, FallTrough(options)...) } else { - WRAP().serialize(statement, out, FallTrough(options)...) + wrap().serialize(statement, out, FallTrough(options)...) } out.WriteString("WITHIN GROUP") p.orderBy.serialize(statement, out) diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go index 0370b437..61dd9209 100644 --- a/internal/jet/projection_test.go +++ b/internal/jet/projection_test.go @@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg", table2.col3 AS "col3", table2.col4 AS "col4"`) - subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery")) + subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil)) assertProjectionSerialize(t, subQueryProjections, `"subQuery"."table1.col3" AS "table1.col3", diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go index 191c7b47..99fb8eb5 100644 --- a/internal/jet/raw_statement.go +++ b/internal/jet/raw_statement.go @@ -8,7 +8,7 @@ type rawStatementImpl struct { } // RawStatement creates new sql statements from raw query and optional map of named arguments -func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement { +func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement { newRawStatement := rawStatementImpl{ serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ dialect: dialect, diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go index 819cf15a..e7d5ed57 100644 --- a/internal/jet/row_expression.go +++ b/internal/jet/row_expression.go @@ -3,6 +3,7 @@ package jet // RowExpression interface type RowExpression interface { Expression + HasProjections EQ(rhs RowExpression) BoolExpression NOT_EQ(rhs RowExpression) BoolExpression @@ -16,7 +17,9 @@ type RowExpression interface { } type rowInterfaceImpl struct { - parent RowExpression + parent Expression + dialect Dialect + elemCount int } func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { @@ -51,13 +54,44 @@ func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression { return LtEq(n.parent, rhs) } -//---------------------------------------------------// +func (n *rowInterfaceImpl) projections() ProjectionList { + var ret ProjectionList + for i := 0; i < n.elemCount; i++ { + rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil) + ret = append(ret, &rowColumn) + } + + return ret +} + +// ---------------------------------------------------// type rowExpressionWrapper struct { rowInterfaceImpl Expression } +func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression { + ret := &rowExpressionWrapper{} + ret.rowInterfaceImpl.parent = ret + + ret.Expression = NewFunc(name, expressions, ret) + ret.dialect = dialect + ret.elemCount = len(expressions) + + return ret +} + +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(dialect Dialect, expressions ...Expression) RowExpression { + return newRowExpression("ROW", dialect, expressions...) +} + +// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. +func WRAP(dialect Dialect, expressions ...Expression) RowExpression { + return newRowExpression("", dialect, expressions...) +} + // RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. // This enables the Go compiler to interpret any expression as a row expression // Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. @@ -66,13 +100,3 @@ func RowExp(expression Expression) RowExpression { rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap return &rowExpressionWrap } - -// ROW function is used to create a tuple value that consists of a set of expressions or column values. -func ROW(expressions ...Expression) RowExpression { - return RowExp(NewFunc("ROW", expressions, nil)) -} - -// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. -func WRAP(expressions ...Expression) RowExpression { - return RowExp(NewFunc("", expressions, nil)) -} diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index c25fba36..f1f58ac9 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -8,15 +8,21 @@ type SelectTable interface { } type selectTableImpl struct { - Statement SerializerHasProjections - alias string + Statement SerializerHasProjections + alias string + columnAliases []ColumnExpression } // NewSelectTable func -func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl { +func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl { selectTable := selectTableImpl{ - Statement: selectStmt, - alias: alias, + Statement: selectStmt, + alias: alias, + columnAliases: columnAliases, + } + + for _, column := range selectTable.columnAliases { + column.setSubQuery(selectTable) } return selectTable @@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string { } func (s selectTableImpl) AllColumns() ProjectionList { + if len(s.columnAliases) > 0 { + return ColumnListToProjectionList(s.columnAliases) + } + projectionList := s.projections().fromImpl(s) return projectionList.(ProjectionList) } @@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt out.WriteString("AS") out.WriteIdentifier(s.alias) + + if len(s.columnAliases) > 0 { + out.WriteByte('(') + SerializeColumnExpressionNames(s.columnAliases, out) + out.WriteByte(')') + } } // -------------------------------------- @@ -50,7 +66,7 @@ type lateralImpl struct { // NewLateral creates new lateral expression from select statement with alias func NewLateral(selectStmt SerializerStatement, alias string) SelectTable { - return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)} + return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)} } func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/jet/values.go b/internal/jet/values.go new file mode 100644 index 00000000..9d0f6911 --- /dev/null +++ b/internal/jet/values.go @@ -0,0 +1,35 @@ +package jet + +// Values hold a set of one or more rows +type Values []RowExpression + +func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteByte('(') + out.IncreaseIdent(5) + + out.NewLine() + out.WriteString("VALUES") + + for rowIndex, row := range v { + if rowIndex > 0 { + out.WriteString(",") + out.NewLine() + } else { + out.IncreaseIdent(7) + } + + row.serialize(statement, out, options...) + } + out.DecreaseIdent(7) + out.DecreaseIdent(5) + out.NewLine() + out.WriteByte(')') +} + +func (v Values) projections() ProjectionList { + if len(v) == 0 { + return nil + } + + return v[0].projections() +} diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index 783fa274..03330e61 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -64,7 +64,7 @@ type CommonTableExpression struct { // CTE creates new named CommonTableExpression func CTE(name string, columns ...ColumnExpression) CommonTableExpression { cte := CommonTableExpression{ - selectTableImpl: NewSelectTable(nil, name), + selectTableImpl: NewSelectTable(nil, name, columns), Columns: columns, } @@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde out.WriteIdentifier(c.alias) } } - -// AllColumns returns list of all projections in the CTE -func (c CommonTableExpression) AllColumns() ProjectionList { - if len(c.Columns) > 0 { - return ColumnListToProjectionList(c.Columns) - } - - return c.selectTableImpl.AllColumns() -} diff --git a/mysql/cast.go b/mysql/cast.go index 83a0578b..ca647a5c 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -6,23 +6,27 @@ import ( ) type cast interface { - // Cast expressions as castType type + // AS casts expressions as castType type AS(castType string) Expression - // Cast expression as char with optional length + // AS_CHAR casts expression as char with optional length AS_CHAR(length ...int) StringExpression - // Cast expression AS date type + // AS_DATE casts expression AS date type AS_DATE() DateExpression - // Cast expression AS numeric type, using precision and optionally scale + // AS_FLOAT casts expressions as float type + AS_FLOAT() FloatExpression + // AS_DOUBLE casts expressions as double type + AS_DOUBLE() FloatExpression + // AS_DECIMAL casts expression AS numeric type AS_DECIMAL() FloatExpression - // Cast expression AS time type + // AS_TIME casts expression AS time type AS_TIME() TimeExpression - // Cast expression as datetime type + // AS_DATETIME casts expression as datetime type AS_DATETIME() DateTimeExpression - // Cast expressions as signed integer type + // AS_SIGNED casts expressions as signed integer type AS_SIGNED() IntegerExpression - // Cast expression as unsigned integer type + // AS_UNSIGNED casts expression as unsigned integer type AS_UNSIGNED() IntegerExpression - // Cast expression as binary type + // AS_BINARY casts expression as binary type AS_BINARY() StringExpression } @@ -73,6 +77,14 @@ func (c *castImpl) AS_DATE() DateExpression { return DateExp(c.AS("DATE")) } +func (c *castImpl) AS_FLOAT() FloatExpression { + return FloatExp(c.AS("FLOAT")) +} + +func (c *castImpl) AS_DOUBLE() FloatExpression { + return FloatExp(c.AS("DOUBLE")) +} + // AS_DECIMAL casts expression AS DECIMAL type func (c *castImpl) AS_DECIMAL() FloatExpression { return FloatExp(c.AS("DECIMAL")) diff --git a/mysql/dialect.go b/mysql/dialect.go index 18d2eec7..9628bfbb 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -1,6 +1,7 @@ package mysql import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -28,6 +29,9 @@ func newDialect() jet.Dialect { }, ReservedWords: reservedWords, SerializeOrderBy: serializeOrderBy, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column_%d", index) + }, } return jet.NewDialect(mySQLDialectParams) diff --git a/mysql/functions.go b/mysql/functions.go index ca31d18d..ceec7ab2 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -12,7 +12,9 @@ var ( ) // ROW function is used to create a tuple value that consists of a set of expressions or column values. -var ROW = jet.ROW +func ROW(expressions ...Expression) RowExpression { + return jet.ROW(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 6c5f3450..aaeff9a2 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -172,7 +172,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/mysql/select_table.go b/mysql/select_table.go index ad221934..8ca06d47 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/mysql/set_statement.go b/mysql/set_statement.go index 2df75a08..7147f00c 100644 --- a/mysql/set_statement.go +++ b/mysql/set_statement.go @@ -85,7 +85,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/mysql/statement.go b/mysql/statement.go index 073adce6..883b99d5 100644 --- a/mysql/statement.go +++ b/mysql/statement.go @@ -3,6 +3,6 @@ package mysql import "github.com/go-jet/jet/v2/internal/jet" // RawStatement creates new sql statements from raw query and optional map of named arguments -func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { +func RawStatement(rawQuery string, namedArguments ...RawArgs) jet.SerializerStatement { return jet.RawStatement(Dialect, rawQuery, namedArguments...) } diff --git a/mysql/values.go b/mysql/values.go new file mode 100644 index 00000000..b55895be --- /dev/null +++ b/mysql/values.go @@ -0,0 +1,32 @@ +package mysql + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the ROW constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// ROW(Int32(204), Float32(1.21)), +// ROW(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be +// overwritten by passing new list of columns. +// +// Example usage: +// +// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date")) +func (v values) AS(alias string, columns ...Column) SelectTable { + return newSelectTable(v, alias, columns) +} diff --git a/mysql/with_statement.go b/mysql/with_statement.go index ca608cbf..03b4d5b7 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -41,7 +41,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } @@ -52,7 +52,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/postgres/dialect.go b/postgres/dialect.go index 9484ab1a..14929ad1 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -1,6 +1,7 @@ package postgres import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" "strconv" ) @@ -25,6 +26,9 @@ func newDialect() jet.Dialect { return "$" + strconv.Itoa(ord) }, ReservedWords: reservedWords, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column%d", index+1) + }, } return jet.NewDialect(dialectParams) diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 9aadbc92..6fb987c7 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -46,33 +46,33 @@ func TestExists(t *testing.T) { func TestIN(t *testing.T) { assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), - `($1 IN ( + `($1 IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 -))`, float64(1.11)) +)))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) IN ( + `(ROW($1, table1.col1) IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -))`, int64(12)) +)))`, int64(12)) } func TestNOT_IN(t *testing.T) { assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), - `($1 NOT IN ( + `($1 NOT IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 -))`, float64(1.11)) +)))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) NOT IN ( + `(ROW($1, table1.col1) NOT IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -))`, int64(12)) +)))`, int64(12)) } func TestReservedWordEscaped(t *testing.T) { diff --git a/postgres/functions.go b/postgres/functions.go index 7e71a960..31343254 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -14,7 +14,9 @@ var ( ) // ROW function is used to create a tuple value that consists of a set of expressions or column values. -var ROW = jet.ROW +func ROW(expressions ...Expression) RowExpression { + return jet.ROW(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// @@ -425,10 +427,12 @@ func castFloatLiteral(fraction FloatExpression) FloatExpression { // ), var GROUPING_SETS = jet.GROUPING_SETS -// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) -// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW method behave exactly the same, -// except when used in GROUPING_SETS. For top level GROUPING SETS expression lists WRAP has to be used. -var WRAP = jet.WRAP +// WRAP surrounds a list of expressions or columns with parentheses, producing new row: (expression1, expression2, ...) +// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW methods behave exactly the same, +// except when used in GROUPING_SETS and VALUES. In these contexts, WRAP must be used instead of ROW. +func WRAP(expressions ...Expression) RowExpression { + return jet.WRAP(Dialect, expressions...) +} // ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. // It creates extra rows in the result set that represent the subtotal values for each combination of columns. diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 25300c27..5ace301c 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -1,7 +1,6 @@ package postgres import ( - "github.com/go-jet/jet/v2/internal/jet" "github.com/stretchr/testify/require" "testing" "time" @@ -151,12 +150,13 @@ func TestInsert_ON_CONFLICT(t *testing.T) { VALUES("one", "two"). VALUES("1", "2"). VALUES("theta", "beta"). - ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( - SET(table1ColBool.SET(Bool(true)), - table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), - ).WHERE(table1Col1.GT(Int(2))), - ). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()). + DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), + ). RETURNING(table1Col1, table1ColBool) assertDebugStatementSql(t, stmt, ` @@ -178,12 +178,12 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { stmt := table1.INSERT(table1Col1, table1ColBool). VALUES("one", "two"). VALUES("1", "2"). - ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( - SET(table1ColBool.SET(Bool(false)), - table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), - ).WHERE(table1Col1.GT(Int(2))), - ). + ON_CONFLICT().ON_CONSTRAINT("idk_primary_key"). + DO_UPDATE( + SET(table1ColBool.SET(Bool(false)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2)))). RETURNING(table1Col1, table1ColBool) assertDebugStatementSql(t, stmt, ` diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 70a9a50b..2a48fc59 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -181,7 +181,7 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/postgres/select_table.go b/postgres/select_table.go index f3d680db..5f57c1de 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -2,7 +2,7 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// SelectTable is interface for postgres sub-queries +// SelectTable is interface for postgres temporary tables like sub-queries, VALUES, CTEs etc... type SelectTable interface { readableTable jet.SelectTable @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(serializerWithProjections jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(serializerWithProjections, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/postgres/set_statement.go b/postgres/set_statement.go index 834560d4..0dee00d4 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -136,7 +136,7 @@ func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/postgres/values.go b/postgres/values.go new file mode 100644 index 00000000..28d17cd1 --- /dev/null +++ b/postgres/values.go @@ -0,0 +1,32 @@ +package postgres + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the WRAP constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// WRAP(Int32(204), Float32(1.21)), +// WRAP(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be +// overwritten by passing new list of columns. +// +// Example usage: +// +// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date")) +func (v values) AS(alias string, columns ...Column) SelectTable { + return newSelectTable(v, alias, columns) +} diff --git a/postgres/with_statement.go b/postgres/with_statement.go index 698d6e3d..99ddc8fb 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -42,7 +42,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } @@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/sqlite/dialect.go b/sqlite/dialect.go index 93e1d2f1..da03364b 100644 --- a/sqlite/dialect.go +++ b/sqlite/dialect.go @@ -1,6 +1,7 @@ package sqlite import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -23,6 +24,9 @@ func newDialect() jet.Dialect { return "?" }, ReservedWords: reservedWords2, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column%d", index+1) + }, } return jet.NewDialect(mySQLDialectParams) diff --git a/sqlite/functions.go b/sqlite/functions.go index a76f236d..ac6bd085 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -15,8 +15,10 @@ var ( OR = jet.OR ) -// ROW is construct one row from a list of expressions. -var ROW = jet.WRAP +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(expressions ...Expression) RowExpression { + return jet.WRAP(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index e74cda64..531ae766 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -155,7 +155,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/sqlite/select_table.go b/sqlite/select_table.go index 9ac7f720..7421f4b0 100644 --- a/sqlite/select_table.go +++ b/sqlite/select_table.go @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go index 0a004bfc..47723a3a 100644 --- a/sqlite/set_statement.go +++ b/sqlite/set_statement.go @@ -86,7 +86,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/sqlite/values.go b/sqlite/values.go new file mode 100644 index 00000000..cb683cb2 --- /dev/null +++ b/sqlite/values.go @@ -0,0 +1,26 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the ROW constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// ROW(Int32(204), Float32(1.21)), +// ROW(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +func (v values) AS(alias string) SelectTable { + return newSelectTable(v, alias, nil) +} diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go index 5375fffc..b05da7dd 100644 --- a/sqlite/with_statement.go +++ b/sqlite/with_statement.go @@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression - AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -42,13 +42,13 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } // AS_NOT_MATERIALIZED is used to define not materialized CTE query -func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.NotMaterialized = true c.CommonTableExpression.Statement = statement return c @@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index 4e3b5d51..b67b9014 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -92,3 +92,9 @@ func skipForMariaDB(t *testing.T) { t.SkipNow() } } + +func onlyMariaDB(t *testing.T) { + if !sourceIsMariaDB() { + t.SkipNow() + } +} diff --git a/tests/mysql/values_test.go b/tests/mysql/values_test.go new file mode 100644 index 00000000..09e9332d --- /dev/null +++ b/tests/mysql/values_test.go @@ -0,0 +1,347 @@ +package mysql + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" + + . "github.com/go-jet/jet/v2/mysql" + + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" +) + +func TestVALUES(t *testing.T) { + skipForMariaDB(t) + + valuesTable := VALUES( + ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")), + ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")), + ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + valuesTable.AllColumns(), + ).FROM( + valuesTable, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column_0 AS "column_0", + values_table.column_1 AS "column_1", + values_table.column_2 AS "column_2", + values_table.column_3 AS "column_3", + values_table.column_4 AS "column_4" +FROM ( + VALUES ROW(?, ?, ?, ?, ?), + ROW(? + ?, ?, ?, ?, ?), + ROW(?, ?, ?, ?, NULL) + ) AS values_table; +`) + + var dest []struct { + Column0 int + Column1 int + Column2 float32 + Column3 bool + Column4 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column0": 1, + "Column1": 2, + "Column2": 4.666, + "Column3": false, + "Column4": "txt" + }, + { + "Column0": 13, + "Column1": 22, + "Column2": 33.222, + "Column3": true, + "Column4": "png" + }, + { + "Column0": 11, + "Column1": 22, + "Column2": 33.222, + "Column3": true, + "Column4": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + skipForMariaDB(t) + + title := StringColumn("title") + releaseYear := IntegerColumn("ReleaseYear") + rentalRate := FloatColumn("rental_rate") + + lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0) + + films := VALUES( + ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate), + ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))), + ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL), + ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL), + ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))), + ).AS("film_values", + title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("last_update")) + + stmt := SELECT( + Film.AllColumns, + films.AllColumns(), + ).FROM( + Film. + INNER_JOIN(films, title.EQ(Film.Title)), + ).WHERE(AND( + Film.ReleaseYear.GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + film_values.title AS "title", + film_values.length AS "length", + film_values.''ReleaseYear'' AS "ReleaseYear", + film_values.rental_rate AS "rental_rate", + film_values.last_update AS "last_update" +FROM dvds.film + INNER JOIN ( + VALUES ROW('Chamber Italian', 117, 2005, 5.82, TIMESTAMP('2007-02-11 12:00:00')), + ROW('Grosse Wonderful', 49, 2004, 6.242, TIMESTAMP('2007-02-11 12:00:00') + INTERVAL 1 HOUR), + ROW('Airport Pollock', 54, 2001, 7.22, NULL), + ROW('Bright Encounters', 73, 2002, 8.25, NULL), + ROW('Academy Dinosaur', 83, 2010, 9.22, TIMESTAMP('2007-02-11 12:00:00') - INTERVAL 2 MINUTE) + ) AS film_values (title, length, ''ReleaseYear'', rental_rate, last_update) ON (film_values.title = film.title) +WHERE ( + (film.release_year > film_values.''ReleaseYear'') + AND (film.rental_rate < film_values.rental_rate) + ) +ORDER BY film_values.title; +`, "''", "`")) + + var dest []struct { + Film model.Film + + Title string + Length int + ReleaseYear int + RentalRate float32 + LastUpdate *time.Time + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Len(t, dest, 4) + testutils.AssertJSON(t, dest[0:2], ` +[ + { + "Film": { + "FilmID": 8, + "Title": "AIRPORT POLLOCK", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "SpecialFeatures": "Trailers", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + "Title": "Airport Pollock", + "Length": 54, + "ReleaseYear": 2001, + "RentalRate": 7.22, + "LastUpdate": null + }, + { + "Film": { + "FilmID": 98, + "Title": "BRIGHT ENCOUNTERS", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + "Title": "Bright Encounters", + "Length": 73, + "ReleaseYear": 2002, + "RentalRate": 8.25, + "LastUpdate": null + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + skipForMariaDB(t) + + paymentID := IntegerColumn("payment_id") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + ROW(Int32(204), Float(1.21)), + ROW(Int32(207), Float(1.02)), + ROW(Int32(200), Float(1.34)), + ROW(Int32(203), Float(1.72)), + ), + ), + )( + Payment.INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)). + UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(increase)), + ).WHERE(Bool(true)), + ) + + testutils.AssertStatementSql(t, stmt, ` +WITH values_cte (payment_id, increase) AS ( + VALUES ROW(?, ?), + ROW(?, ?), + ROW(?, ?), + ROW(?, ?) +) +UPDATE dvds.payment +INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id) +SET amount = (payment.amount * values_cte.increase) +WHERE ?; +`) + + testutils.AssertExecAndRollback(t, stmt, db, 4) +} + +func TestVALUES_MariaDB(t *testing.T) { + onlyMariaDB(t) // mariadb won't accept values rows if all the elements are placeholders, so we have to use raw statement + + paymentID := IntegerColumn("payment_id") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + RawStatement(` + VALUES (204, 1.21), + (207, 1.02), + (200, 1.34), + (203, 1.72) + `), + ), + )( + SELECT( + Payment.AllColumns, + paymentsToUpdate.AllColumns(), + ).FROM( + Payment. + INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)), + ).WHERE( + increase.GT(Float(1.03)), + ).ORDER_BY( + increase, + ), + ) + + testutils.AssertStatementSql(t, stmt, ` +WITH values_cte (payment_id, increase) AS ( + VALUES (204, 1.21), + (207, 1.02), + (200, 1.34), + (203, 1.72) + +) +SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update", + values_cte.payment_id AS "payment_id", + values_cte.increase AS "increase" +FROM dvds.payment + INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id) +WHERE values_cte.increase > ? +ORDER BY values_cte.increase; +`) + + var dest []struct { + model.Payment + + Increase float64 + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "PaymentID": 204, + "CustomerID": 7, + "StaffID": 1, + "RentalID": 13476, + "Amount": 2.99, + "PaymentDate": "2005-08-20T01:06:04Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.21 + }, + { + "PaymentID": 200, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 11542, + "Amount": 7.99, + "PaymentDate": "2005-08-17T00:51:32Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.34 + }, + { + "PaymentID": 203, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 13373, + "Amount": 2.99, + "PaymentDate": "2005-08-19T21:23:31Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.72 + } +] +`) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 07350b8b..b899ca2e 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -4,7 +4,7 @@ import ( "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" - "log/slog" + "testing" "time" @@ -347,8 +347,6 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) - // fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", @@ -376,9 +374,6 @@ LIMIT $11; err := query.Query(db, &dest) require.NoError(t, err) - - //testutils.PrintJson(dest) - testutils.AssertJSON(t, dest, ` [ { @@ -936,7 +931,7 @@ func TestTimeExpression(t *testing.T) { func TestIntervalSetFunctionality(t *testing.T) { t.Run("updateQueryIntervalTest", func(t *testing.T) { - slog.Info("Running test", slog.Any("test", t.Name())) + expectedQuery := ` UPDATE test_sample.employee SET pto_accrual = INTERVAL '3 HOUR' diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 2d6784d2..7aad0d4f 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -2,6 +2,7 @@ package postgres import ( "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/table" "github.com/stretchr/testify/require" @@ -10,19 +11,7 @@ import ( func TestNorthwindJoinEverything(t *testing.T) { - stmt := Customers. - LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). - LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)). - LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)). - LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)). - LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)). - LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)). - LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)). - LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)). - LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)). - LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). - LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). - LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)). + stmt := SELECT( Customers.AllColumns, CustomerDemographics.AllColumns, @@ -32,8 +21,21 @@ func TestNorthwindJoinEverything(t *testing.T) { Products.AllColumns, Categories.AllColumns, Suppliers.AllColumns, - ). - ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) + ).FROM( + Customers. + LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). + LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)). + LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)). + LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)). + LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)). + LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)). + LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)). + LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)). + LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)). + LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). + LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). + LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)), + ).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) var dest []struct { model.Customers diff --git a/tests/postgres/values_test.go b/tests/postgres/values_test.go new file mode 100644 index 00000000..9e89f7ed --- /dev/null +++ b/tests/postgres/values_test.go @@ -0,0 +1,284 @@ +package postgres + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestVALUES(t *testing.T) { + + values := VALUES( + WRAP(Int32(1), Int32(2), Float32(4.666), Bool(false), String("txt")), + WRAP(Int32(11).ADD(Int32(2)), Int32(22), Float32(33.222), Bool(true), String("png")), + WRAP(Int32(11), Int32(22), Float32(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + values.AllColumns(), + ).FROM( + values, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column1 AS "column1", + values_table.column2 AS "column2", + values_table.column3 AS "column3", + values_table.column4 AS "column4", + values_table.column5 AS "column5" +FROM ( + VALUES ($1::integer, $2::integer, $3::real, $4::boolean, $5::text), + ($6::integer + $7::integer, $8::integer, $9::real, $10::boolean, $11::text), + ($12::integer, $13::integer, $14::real, $15::boolean, NULL) + ) AS values_table; +`) + + var dest []struct { + Column1 int + Column2 int + Column3 float32 + Column4 bool + Column5 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column1": 1, + "Column2": 2, + "Column3": 4.666, + "Column4": false, + "Column5": "txt" + }, + { + "Column1": 13, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": "png" + }, + { + "Column1": 11, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + + title := StringColumn("title") + releaseYear := IntegerColumn("ReleaseYear") + rentalRate := FloatColumn("rental_rate") + + lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0) + + filmValues := VALUES( + WRAP(String("Chamber Italian"), Int64(117), Int32(2005), Float32(5.82), lastUpdate), + WRAP(String("Grosse Wonderful"), Int64(49), Int32(2004), Float32(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))), + WRAP(String("Airport Pollock"), Int64(54), Int32(2001), Float32(7.22), NULL), + WRAP(String("Bright Encounters"), Int64(73), Int32(2002), Float32(8.25), NULL), + WRAP(String("Academy Dinosaur"), Int64(83), Int32(2010), Float32(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))), + ).AS("film_values", + title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("update_date")) + + stmt := SELECT( + Film.AllColumns, + filmValues.AllColumns(), + ).FROM( + Film. + INNER_JOIN(filmValues, title.EQ(Film.Title)), + ).WHERE(AND( + Film.ReleaseYear.GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.last_update AS "film.last_update", + film.special_features AS "film.special_features", + film.fulltext AS "film.fulltext", + film_values.title AS "title", + film_values.length AS "length", + film_values."ReleaseYear" AS "ReleaseYear", + film_values.rental_rate AS "rental_rate", + film_values.update_date AS "update_date" +FROM dvds.film + INNER JOIN ( + VALUES ('Chamber Italian'::text, 117::bigint, 2005::integer, 5.820000171661377::real, '2007-02-11 12:00:00'::timestamp without time zone), + ('Grosse Wonderful'::text, 49::bigint, 2004::integer, 6.242000102996826::real, '2007-02-11 12:00:00'::timestamp without time zone + INTERVAL '1 HOUR'), + ('Airport Pollock'::text, 54::bigint, 2001::integer, 7.21999979019165::real, NULL), + ('Bright Encounters'::text, 73::bigint, 2002::integer, 8.25::real, NULL), + ('Academy Dinosaur'::text, 83::bigint, 2010::integer, 9.220000267028809::real, '2007-02-11 12:00:00'::timestamp without time zone - INTERVAL '2 MINUTE') + ) AS film_values (title, length, "ReleaseYear", rental_rate, update_date) ON (film_values.title = film.title) +WHERE ( + (film.release_year > film_values."ReleaseYear") + AND (film.rental_rate < film_values.rental_rate) + ) +ORDER BY film_values.title; +`) + + //fmt.Println(stmt.DebugSql()) + + var dest []struct { + Film model.Film + + Title string + Length int + ReleaseYear int + RentalRate float32 + UpdateDate *time.Time + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + assert.Len(t, dest, 4) + testutils.AssertJSON(t, dest[0:2], ` +[ + { + "Film": { + "FilmID": 8, + "Title": "Airport Pollock", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers}", + "Fulltext": "'airport':1 'ancient':18 'confront':14 'epic':4 'girl':11 'india':19 'monkey':16 'moos':8 'must':13 'pollock':2 'tale':5" + }, + "Title": "Airport Pollock", + "Length": 54, + "ReleaseYear": 2001, + "RentalRate": 7.22, + "UpdateDate": null + }, + { + "Film": { + "FilmID": 98, + "Title": "Bright Encounters", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers}", + "Fulltext": "'boat':20 'bright':1 'conquer':14 'encount':2 'fate':4 'feminist':11 'jet':19 'lumberjack':8 'must':13 'student':16 'yarn':5" + }, + "Title": "Bright Encounters", + "Length": 73, + "ReleaseYear": 2002, + "RentalRate": 8.25, + "UpdateDate": null + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + + paymentID := IntegerColumn("payment_ID") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + WRAP(Int32(20564), Float32(1.21)), + WRAP(Int32(20567), Float32(1.02)), + WRAP(Int32(20570), Float32(1.34)), + WRAP(Int32(20573), Float32(1.72)), + ), + ), + )( + Payment.UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(CAST(increase).AS_DECIMAL())), + ). + FROM(paymentsToUpdate). + WHERE(Payment.PaymentID.EQ(paymentID)). + RETURNING(Payment.AllColumns), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH values_cte ("payment_ID", increase) AS ( + VALUES (20564::integer, 1.2100000381469727::real), + (20567::integer, 1.0199999809265137::real), + (20570::integer, 1.340000033378601::real), + (20573::integer, 1.7200000286102295::real) +) +UPDATE dvds.payment +SET amount = (payment.amount * values_cte.increase::decimal) +FROM values_cte +WHERE payment.payment_id = values_cte."payment_ID" +RETURNING payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + + var payments []model.Payment + + err := stmt.Query(tx, &payments) + require.NoError(t, err) + + assert.Len(t, payments, 4) + testutils.AssertJSON(t, payments[0:2], ` +[ + { + "PaymentID": 20564, + "CustomerID": 379, + "StaffID": 2, + "RentalID": 11457, + "Amount": 4.83, + "PaymentDate": "2007-03-02T19:42:42.996577Z" + }, + { + "PaymentID": 20567, + "CustomerID": 379, + "StaffID": 2, + "RentalID": 13397, + "Amount": 8.15, + "PaymentDate": "2007-03-19T20:35:01.996577Z" + } +] +`) + }) + +} diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index 110c6590..e5c52455 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -283,7 +283,7 @@ func TestUpdateContextDeadlineExceeded(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := updateStmt.QueryContext(ctx, tx, &dest) require.Error(t, err, "context deadline exceeded") diff --git a/tests/sqlite/values_test.go b/tests/sqlite/values_test.go new file mode 100644 index 00000000..0793397f --- /dev/null +++ b/tests/sqlite/values_test.go @@ -0,0 +1,344 @@ +package sqlite + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" + + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" +) + +func TestVALUES(t *testing.T) { + + values := VALUES( + ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")), + ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")), + ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + values.AllColumns(), + ).FROM( + values, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column1 AS "column1", + values_table.column2 AS "column2", + values_table.column3 AS "column3", + values_table.column4 AS "column4", + values_table.column5 AS "column5" +FROM ( + VALUES (?, ?, ?, ?, ?), + (? + ?, ?, ?, ?, ?), + (?, ?, ?, ?, NULL) + ) AS values_table; +`) + + var dest []struct { + Column1 int + Column2 int + Column3 float32 + Column4 bool + Column5 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column1": 1, + "Column2": 2, + "Column3": 4.666, + "Column4": false, + "Column5": "txt" + }, + { + "Column1": 13, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": "png" + }, + { + "Column1": 11, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + + lastUpdate := DateTime(2007, time.February, 11, 12, 0, 0) + + films := VALUES( + ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate), + ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate), + ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL), + ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL), + ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), DATETIME(lastUpdate, YEARS(2))), + ).AS("film_values") + + title := StringColumn("column1").From(films) + releaseYear := IntegerColumn("column3").From(films) + rentalRate := FloatColumn("column4").From(films) + + stmt := SELECT( + Film.AllColumns, + films.AllColumns(), + ).FROM( + Film. + INNER_JOIN(films, LOWER(title).EQ(LOWER(Film.Title))), + ).WHERE(AND( + CAST(Film.ReleaseYear).AS_INTEGER().GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + film_values.column1 AS "column1", + film_values.column2 AS "column2", + film_values.column3 AS "column3", + film_values.column4 AS "column4", + film_values.column5 AS "column5" +FROM film + INNER JOIN ( + VALUES ('Chamber Italian', 117, 2005, 5.82, DATETIME('2007-02-11 12:00:00')), + ('Grosse Wonderful', 49, 2004, 6.242, DATETIME('2007-02-11 12:00:00')), + ('Airport Pollock', 54, 2001, 7.22, NULL), + ('Bright Encounters', 73, 2002, 8.25, NULL), + ('Academy Dinosaur', 83, 2010, 9.22, DATETIME(DATETIME('2007-02-11 12:00:00'), '2 YEARS')) + ) AS film_values ON (LOWER(film_values.column1) = LOWER(film.title)) +WHERE ( + (CAST(film.release_year AS INTEGER) > film_values.column3) + AND (film.rental_rate < film_values.column4) + ) +ORDER BY film_values.column1; +`) + + var dest []struct { + Film model.Film + + Column1 string + Column2 int + Column3 int + Column4 float32 + Column5 *time.Time + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Film": { + "FilmID": 8, + "Title": "AIRPORT POLLOCK", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Airport Pollock", + "Column2": 54, + "Column3": 2001, + "Column4": 7.22, + "Column5": null + }, + { + "Film": { + "FilmID": 98, + "Title": "BRIGHT ENCOUNTERS", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Bright Encounters", + "Column2": 73, + "Column3": 2002, + "Column4": 8.25, + "Column5": null + }, + { + "Film": { + "FilmID": 133, + "Title": "CHAMBER ITALIAN", + "Description": "A Fateful Reflection of a Moose And a Husband who must Overcome a Monkey in Nigeria", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 7, + "RentalRate": 4.99, + "Length": 117, + "ReplacementCost": 14.99, + "Rating": "NC-17", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Chamber Italian", + "Column2": 117, + "Column3": 2005, + "Column4": 5.82, + "Column5": "2007-02-11T12:00:00Z" + }, + { + "Film": { + "FilmID": 384, + "Title": "GROSSE WONDERFUL", + "Description": "A Epic Drama of a Cat And a Explorer who must Redeem a Moose in Australia", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 5, + "RentalRate": 4.99, + "Length": 49, + "ReplacementCost": 19.99, + "Rating": "R", + "SpecialFeatures": "Behind the Scenes", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Grosse Wonderful", + "Column2": 49, + "Column3": 2004, + "Column4": 6.242, + "Column5": "2007-02-11T12:00:00Z" + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + + paymentID := IntegerColumn("payment_ID") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + ROW(Int32(204), Float(1.21)), + ROW(Int32(207), Float(1.02)), + ROW(Int32(200), Float(1.34)), + ROW(Int32(203), Float(1.72)), + ), + ), + )( + Payment.UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(increase)), + ). + FROM(paymentsToUpdate). + WHERE(Payment.PaymentID.EQ(paymentID)). + RETURNING(Payment.AllColumns), + ) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH values_cte (''payment_ID'', increase) AS ( + VALUES (?, ?), + (?, ?), + (?, ?), + (?, ?) +) +UPDATE payment +SET amount = (payment.amount * values_cte.increase) +FROM values_cte +WHERE payment.payment_id = values_cte.''payment_ID'' +RETURNING payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update"; +`, "''", "`")) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var payments []model.Payment + + err := stmt.Query(tx, &payments) + + require.NoError(t, err) + testutils.AssertJSON(t, payments, ` +[ + { + "PaymentID": 200, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 11542, + "Amount": 10.706600000000002, + "PaymentDate": "2005-08-17T00:51:32Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 203, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 13373, + "Amount": 5.1428, + "PaymentDate": "2005-08-19T21:23:31Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 204, + "CustomerID": 7, + "StaffID": 1, + "RentalID": 13476, + "Amount": 3.6179, + "PaymentDate": "2005-08-20T01:06:04Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 207, + "CustomerID": 8, + "StaffID": 2, + "RentalID": 866, + "Amount": 7.1298, + "PaymentDate": "2005-05-30T03:43:54Z", + "LastUpdate": "2019-04-11T18:11:50Z" + } +] +`) + }) + +}