diff --git a/bool_test.go b/bool_test.go index 5c479e6..fc35062 100644 --- a/bool_test.go +++ b/bool_test.go @@ -11,7 +11,7 @@ func TestBool(t *testing.T) { d, err := sql.Open("sqlite3", "file::memory:") require.NoError(t, err) - _, err = d.Exec("CREATE TABLE `users` (`id` id NOT NULL,`status` BIT(1), PRIMARY KEY (`id`))") + _, err = d.Exec("CREATE TABLE `users` (`id` int NOT NULL,`status` BIT(1), PRIMARY KEY (`id`))") require.NoError(t, err) result, err := d.Exec("INSERT INTO `users`(`id`, `status`) VALUES(?, ?)", 10, Bool(true)) diff --git a/string.go b/string.go new file mode 100644 index 0000000..4579bd8 --- /dev/null +++ b/string.go @@ -0,0 +1,57 @@ +package sqle + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" +) + +type String struct { + sql.NullString +} + +func NewString(s string) String { + return String{NullString: sql.NullString{String: s, Valid: true}} +} + +// Scan implements the [sql.Scanner] interface. +func (t *String) Scan(value any) error { // skipcq: GO-W1029 + return t.NullString.Scan(value) +} + +// Value implements the [driver.Valuer] interface. +func (t String) Value() (driver.Value, error) { // skipcq: GO-W1029 + return t.NullString.Value() +} + +// Time returns the underlying time.Time value of the Time struct. +func (t *String) String() string { // skipcq: GO-W1029 + return t.NullString.String +} + +// MarshalJSON implements the json.Marshaler interface +func (t String) MarshalJSON() ([]byte, error) { // skipcq: GO-W1029 + if t.Valid { + return json.Marshal(t.NullString.String) + } + return nullJsonBytes, nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (t *String) UnmarshalJSON(data []byte) error { // skipcq: GO-W1029 + if len(data) == 0 || string(data) == nullJson { + t.NullString.Valid = false + return nil + } + + var v string + err := json.Unmarshal(data, &v) + if err != nil { + return err + } + + t.NullString.String = v + t.NullString.Valid = true + + return nil +} diff --git a/string_test.go b/string_test.go new file mode 100644 index 0000000..4e1950d --- /dev/null +++ b/string_test.go @@ -0,0 +1,112 @@ +package sqle + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringInSQL(t *testing.T) { + + v := "has value" + d, err := sql.Open("sqlite3", "file::memory:") + require.NoError(t, err) + + _, err = d.Exec("CREATE TABLE `strings` (`id` int NOT NULL,`name` varchar(125), PRIMARY KEY (`id`))") + require.NoError(t, err) + + result, err := d.Exec("INSERT INTO `strings`(`id`) VALUES(?)", 10) + require.NoError(t, err) + + rows, err := result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + result, err = d.Exec("INSERT INTO `strings`(`id`, `name`) VALUES(?, ?)", 20, v) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + var v10 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 10).Scan(&v10) + require.NoError(t, err) + + require.EqualValues(t, false, v10.Valid) + + var v20 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 20).Scan(&v20) + require.NoError(t, err) + + require.EqualValues(t, true, v20.Valid) + require.EqualValues(t, v, v20.String()) + + result, err = d.Exec("INSERT INTO `strings`(`id`,`name`) VALUES(?, ?)", 11, v10) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + result, err = d.Exec("INSERT INTO `strings`(`id`, `name`) VALUES(?, ?)", 21, v20) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + var v11 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 11).Scan(&v11) + require.NoError(t, err) + + require.EqualValues(t, false, v11.Valid) + + var v21 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 21).Scan(&v21) + require.NoError(t, err) + + require.EqualValues(t, true, v21.Valid) + require.EqualValues(t, v, v21.String()) + +} + +func TestStringInJSON(t *testing.T) { + + sysString := "has value" + + bufSysString, err := json.Marshal(sysString) + require.NoError(t, err) + + sqleString := NewString(sysString) + + bufSqleString, err := json.Marshal(sqleString) + require.NoError(t, err) + + require.Equal(t, bufSysString, bufSqleString) + + var jsSqleString String + // Unmarshal sqle.Time from time.Time json bytes + err = json.Unmarshal(bufSysString, &jsSqleString) + require.NoError(t, err) + + require.Equal(t, sysString, jsSqleString.String()) + require.Equal(t, true, jsSqleString.Valid) + + var jsSysString string + // Unmarshal time.Time from sqle.Time json bytes + err = json.Unmarshal(bufSqleString, &jsSysString) + require.NoError(t, err) + require.Equal(t, sysString, jsSysString) + + var nullString String + err = json.Unmarshal([]byte("null"), &nullString) + require.NoError(t, err) + require.Equal(t, false, nullString.Valid) + + bufNull, err := json.Marshal(nullString) + require.NoError(t, err) + require.Equal(t, []byte("null"), bufNull) +} diff --git a/time.go b/time.go index d28b8b7..2c43489 100644 --- a/time.go +++ b/time.go @@ -7,10 +7,6 @@ import ( "time" ) -var nullTimeJsonBytes = []byte("null") - -const nullTimeJson = "null" - // Time represents a nullable time value. type Time struct { sql.NullTime @@ -41,12 +37,12 @@ func (t Time) MarshalJSON() ([]byte, error) { // skipcq: GO-W1029 if t.Valid { return json.Marshal(t.NullTime.Time) } - return nullTimeJsonBytes, nil + return nullJsonBytes, nil } // UnmarshalJSON implements the json.Unmarshaler interface func (t *Time) UnmarshalJSON(data []byte) error { // skipcq: GO-W1029 - if len(data) == 0 || string(data) == nullTimeJson { + if len(data) == 0 || string(data) == nullJson { t.NullTime.Time = time.Time{} t.NullTime.Valid = false return nil