Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ROW expressions and VALUES statement #410

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/jet/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Dialect interface {
ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string
}

// SerializerFunc func
Expand All @@ -35,6 +36,7 @@ type DialectParams struct {
ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string
}

// NewDialect creates new dialect with params
Expand All @@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect {
argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName,
}
}

Expand All @@ -62,6 +65,7 @@ type dialectImpl struct {
argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string
}

func (d *dialectImpl) Name() string {
Expand Down Expand Up @@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending,
return d.serializeOrderBy
}

func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
return d.valuesDefaultColumnName(index)
}

func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {
Expand Down
17 changes: 4 additions & 13 deletions internal/jet/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {

// IN checks if this expressions matches any in expressions list
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN")
}

// NOT_IN checks if this expressions is different of all expressions in expressions list
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN")
}

// AS the temporary alias name to assign to the expression
Expand Down Expand Up @@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder,
}
}

type skipParenthesisWrap struct {
Expression
}

func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}

// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
func wrap(expressions ...Expression) Expression {
return NewFunc("", expressions, nil)
}
7 changes: 1 addition & 6 deletions internal/jet/func_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("OR", expressions...)
}

// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(expressions ...Expression) Expression {
return NewFunc("ROW", expressions, nil)
}

// ------------------ Mathematical functions ---------------//

// ABSf calculates absolute value from float expression
Expand Down Expand Up @@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder
if _, isStatement := expression.(Statement); isStatement {
expression.serialize(statement, out, options...)
} else {
skipWrap(expression).serialize(statement, out, options...)
expression.serialize(statement, out, append(options, NoWrap, Ident)...)
}
}
}
Expand Down
26 changes: 0 additions & 26 deletions internal/jet/literal_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option

//---------------------------------------------------//

type wrap struct {
ExpressionInterfaceImpl
expressions []Expression
}

func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")

if len(n.expressions) == 1 {
options = append(options, NoWrap, Ident)
}
serializeExpressionList(statementType, n.expressions, ", ", out, options...)

out.WriteString(")")
}

// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.ExpressionInterfaceImpl.Parent = wrap

return wrap
}

//---------------------------------------------------//

type rawExpression struct {
ExpressionInterfaceImpl

Expand Down
7 changes: 6 additions & 1 deletion internal/jet/order_set_aggregate_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct {

func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name)
WRAP(p.fraction).serialize(statement, out, FallTrough(options)...)

if p.fraction != nil {
wrap(p.fraction).serialize(statement, out, FallTrough(options)...)
} else {
wrap().serialize(statement, out, FallTrough(options)...)
}
out.WriteString("WITHIN GROUP")
p.orderBy.serialize(statement, out)
}
2 changes: 1 addition & 1 deletion internal/jet/projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg",
table2.col3 AS "col3",
table2.col4 AS "col4"`)

subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery"))
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil))

assertProjectionSerialize(t, subQueryProjections,
`"subQuery"."table1.col3" AS "table1.col3",
Expand Down
2 changes: 1 addition & 1 deletion internal/jet/raw_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type rawStatementImpl struct {
}

// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement {
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
Expand Down
102 changes: 102 additions & 0 deletions internal/jet/row_expression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package jet

// RowExpression interface
type RowExpression interface {
Expression
HasProjections

EQ(rhs RowExpression) BoolExpression
NOT_EQ(rhs RowExpression) BoolExpression
IS_DISTINCT_FROM(rhs RowExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression

LT(rhs RowExpression) BoolExpression
LT_EQ(rhs RowExpression) BoolExpression
GT(rhs RowExpression) BoolExpression
GT_EQ(rhs RowExpression) BoolExpression
}

type rowInterfaceImpl struct {
parent Expression
dialect Dialect
elemCount int
}

func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
return Eq(n.parent, rhs)
}

func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression {
return NotEq(n.parent, rhs)
}

func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsDistinctFrom(n.parent, rhs)
}

func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsNotDistinctFrom(n.parent, rhs)
}

func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression {
return Gt(n.parent, rhs)
}

func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression {
return GtEq(n.parent, rhs)
}

func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression {
return Lt(n.parent, rhs)
}

func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
return LtEq(n.parent, rhs)
}

func (n *rowInterfaceImpl) projections() ProjectionList {
var ret ProjectionList

for i := 0; i < n.elemCount; i++ {
rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil)
ret = append(ret, &rowColumn)
}

return ret
}

// ---------------------------------------------------//
type rowExpressionWrapper struct {
rowInterfaceImpl
Expression
}

func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression {
ret := &rowExpressionWrapper{}
ret.rowInterfaceImpl.parent = ret

ret.Expression = NewFunc(name, expressions, ret)
ret.dialect = dialect
ret.elemCount = len(expressions)

return ret
}

// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("ROW", dialect, expressions...)
}

// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`.
func WRAP(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("", dialect, expressions...)
}

// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
func RowExp(expression Expression) RowExpression {
rowExpressionWrap := rowExpressionWrapper{Expression: expression}
rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap
return &rowExpressionWrap
}
28 changes: 22 additions & 6 deletions internal/jet/select_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@ type SelectTable interface {
}

type selectTableImpl struct {
Statement SerializerHasProjections
alias string
Statement SerializerHasProjections
alias string
columnAliases []ColumnExpression
}

// NewSelectTable func
func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl {
selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
Statement: selectStmt,
alias: alias,
columnAliases: columnAliases,
}

for _, column := range selectTable.columnAliases {
column.setSubQuery(selectTable)
}

return selectTable
Expand All @@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string {
}

func (s selectTableImpl) AllColumns() ProjectionList {
if len(s.columnAliases) > 0 {
return ColumnListToProjectionList(s.columnAliases)
}

projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
Expand All @@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt

out.WriteString("AS")
out.WriteIdentifier(s.alias)

if len(s.columnAliases) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(s.columnAliases, out)
out.WriteByte(')')
}
}

// --------------------------------------
Expand All @@ -50,7 +66,7 @@ type lateralImpl struct {

// NewLateral creates new lateral expression from select statement with alias
func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)}
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)}
}

func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
Expand Down
35 changes: 35 additions & 0 deletions internal/jet/values.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package jet

// Values hold a set of one or more rows
type Values []RowExpression

func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.IncreaseIdent(5)

out.NewLine()
out.WriteString("VALUES")

for rowIndex, row := range v {
if rowIndex > 0 {
out.WriteString(",")
out.NewLine()
} else {
out.IncreaseIdent(7)
}

row.serialize(statement, out, options...)
}
out.DecreaseIdent(7)
out.DecreaseIdent(5)
out.NewLine()
out.WriteByte(')')
}

func (v Values) projections() ProjectionList {
if len(v) == 0 {
return nil
}

return v[0].projections()
}
11 changes: 1 addition & 10 deletions internal/jet/with_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type CommonTableExpression struct {
// CTE creates new named CommonTableExpression
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
selectTableImpl: NewSelectTable(nil, name),
selectTableImpl: NewSelectTable(nil, name, columns),
Columns: columns,
}

Expand Down Expand Up @@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde
out.WriteIdentifier(c.alias)
}
}

// AllColumns returns list of all projections in the CTE
func (c CommonTableExpression) AllColumns() ProjectionList {
if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
}

return c.selectTableImpl.AllColumns()
}
Loading
Loading