Skip to content

Commit

Permalink
fix(orderby): improved OrderByBuilder for api input (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Apr 12, 2024
1 parent 4a878bb commit c2510fe
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.4.3] - 2014-04-12
### Fixes
- fixed close issue when it fails to build prepareStmt (#33)
- improved `OrderByBuilder` for api input (#34)

## [1.4.2] - 2014-04-10
### Added
- added `OrderByBuilder` to prevent sql injection (#32)
- added `OrderByBuilder` to prevent sql injection (#32, #33)

## [1.4.1] - 2014-04-09
### Added
Expand Down
39 changes: 33 additions & 6 deletions sqlbuilder_orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type OrderByBuilder struct {
allowedColumns []string
}

// OrderBy create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided
func (b *Builder) OrderBy(allowedColumns ...string) *OrderByBuilder {
// Order create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided
func (b *Builder) Order(allowedColumns ...string) *OrderByBuilder {
ob := &OrderByBuilder{
Builder: b,
allowedColumns: allowedColumns,
Expand All @@ -34,8 +34,35 @@ func (ob *OrderByBuilder) isAllowed(col string) bool {
})
}

// Asc order by ASC with columns
func (ob *OrderByBuilder) Asc(columns ...string) *OrderByBuilder {
// By order by raw sql. eg By("a asc, b desc")
func (ob *OrderByBuilder) By(raw string) *OrderByBuilder {
cols := strings.Split(raw, ",")

var n int
var items []string
var by string
for _, col := range cols {
items = strings.Split(strings.TrimSpace(col), " ")
n = len(items)
switch n {
case 1:
ob.ByAsc(strings.TrimSpace(col))
case 2:
by = strings.TrimSpace(items[1])
if strings.EqualFold(by, "ASC") {
ob.ByAsc(strings.TrimSpace(items[0]))
} else if strings.EqualFold(by, "DESC") {
ob.ByDesc(strings.TrimSpace(items[0]))
}
}
}

return ob

}

// ByAsc order by ascending with columns
func (ob *OrderByBuilder) ByAsc(columns ...string) *OrderByBuilder {
for _, c := range columns {
if ob.isAllowed(c) {
if ob.isWritten {
Expand All @@ -49,8 +76,8 @@ func (ob *OrderByBuilder) Asc(columns ...string) *OrderByBuilder {
return ob
}

// Desc order by desc with columns
func (ob *OrderByBuilder) Desc(columns ...string) *OrderByBuilder {
// ByDesc order by descending with columns
func (ob *OrderByBuilder) ByDesc(columns ...string) *OrderByBuilder {
for _, c := range columns {
if ob.isAllowed(c) {
if ob.isWritten {
Expand Down
27 changes: 19 additions & 8 deletions sqlbuilder_orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ func TestOrderByBuilder(t *testing.T) {
name: "no_safe_columns_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.OrderBy().
Desc("created_at").
Asc("id", "name").
Asc("updated_at")
b.Order().
ByDesc("created_at").
ByAsc("id", "name").
ByAsc("updated_at")

return b
},
Expand All @@ -29,15 +29,26 @@ func TestOrderByBuilder(t *testing.T) {
name: "safe_columns_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.OrderBy("id", "created_at", "updated_at").
Asc("id", "name").
Desc("created_at", "unsafe_input").
Asc("updated_at")
b.Order("id", "created_at", "updated_at").
ByAsc("id", "name").
ByDesc("created_at", "unsafe_input").
ByAsc("updated_at")

return b
},
wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC",
},
{
name: "order_by_raw_sql_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.Order("id", "created_at", "updated_at", "age").
By("created_at desc, id, name asc, updated_at asc, age invalid_by, unsafe_asc, unsafe_desc desc")

return b
},
wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, updated_at ASC",
},
}

for _, test := range tests {
Expand Down

0 comments on commit c2510fe

Please sign in to comment.