diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c17958..2f3b8dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.4.2] - 2014-04-10 +### Added +- added `OrderByBuilder` to prevent sql injection (#32) + ## [1.4.1] - 2014-04-09 ### Added - added multi-dht support on `DB` (#31) diff --git a/sqlbuilder_orderby.go b/sqlbuilder_orderby.go new file mode 100644 index 0000000..5691a56 --- /dev/null +++ b/sqlbuilder_orderby.go @@ -0,0 +1,65 @@ +package sqle + +import ( + "slices" + "strings" +) + +type OrderByBuilder struct { + *Builder + isWritten bool + allowedColumns []string +} + +// OrderBy create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided +func (b *Builder) OrderBy(allowedColumns ...string) *OrderByBuilder { + ob := &OrderByBuilder{ + Builder: b, + allowedColumns: allowedColumns, + } + + b.SQL(" ORDER BY ") + + return ob +} + +// isAllowed check if column is included in allowed columns. It will remove any untrust input from client +func (ob *OrderByBuilder) isAllowed(col string) bool { + if ob.allowedColumns == nil { + return true + } + + return slices.ContainsFunc(ob.allowedColumns, func(c string) bool { + return strings.EqualFold(c, col) + }) +} + +// Asc order by ASC with columns +func (ob *OrderByBuilder) Asc(columns ...string) *OrderByBuilder { + for _, c := range columns { + if ob.isAllowed(c) { + if ob.isWritten { + ob.Builder.SQL(", ").SQL(c).SQL(" ASC") + } else { + ob.Builder.SQL(c).SQL(" ASC") + ob.isWritten = true + } + } + } + return ob +} + +// Desc order by desc with columns +func (ob *OrderByBuilder) Desc(columns ...string) *OrderByBuilder { + for _, c := range columns { + if ob.isAllowed(c) { + if ob.isWritten { + ob.Builder.SQL(", ").SQL(c).SQL(" DESC") + } else { + ob.Builder.SQL(c).SQL(" DESC") + ob.isWritten = true + } + } + } + return ob +} diff --git a/sqlbuilder_orderby_test.go b/sqlbuilder_orderby_test.go new file mode 100644 index 0000000..2857d57 --- /dev/null +++ b/sqlbuilder_orderby_test.go @@ -0,0 +1,50 @@ +package sqle + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOrderByBuilder(t *testing.T) { + tests := []struct { + name string + build func() *Builder + wanted string + }{ + { + name: "no_safe_columns_should_work", + build: func() *Builder { + b := New("SELECT * FROM users") + b.OrderBy(). + Desc("created_at"). + Asc("id", "name"). + Asc("updated_at") + + return b + }, + wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, name ASC, updated_at ASC", + }, + { + name: "safe_columns_should_work", + build: func() *Builder { + b := New("SELECT * FROM users") + b.OrderBy("id", "created_at", "updated_at"). + Asc("id", "name"). + Desc("created_at", "unsafe_input"). + Asc("updated_at") + + return b + }, + wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := test.build().String() + + require.Equal(t, test.wanted, actual) + }) + } +}