Skip to content

Commit

Permalink
Merge pull request #404 from safaci2000/feature/testTools
Browse files Browse the repository at this point in the history
Replacing several test util function with a generic version
  • Loading branch information
go-jet authored Oct 7, 2024
2 parents d17ab3d + a77ecc3 commit 31a6b95
Show file tree
Hide file tree
Showing 18 changed files with 161 additions and 214 deletions.
4 changes: 2 additions & 2 deletions examples/quick-start/quick-start.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"os"

_ "github.com/lib/pq"

Expand Down Expand Up @@ -90,7 +90,7 @@ func main() {
func jsonSave(path string, v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")

err := ioutil.WriteFile(path, jsonText, 0644)
err := os.WriteFile(path, jsonText, 0644)

panicOnError(err)
}
Expand Down
79 changes: 4 additions & 75 deletions internal/testutils/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -109,7 +108,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
jsonText, _ := json.MarshalIndent(v, "", "\t")

filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644)
err := os.WriteFile(filePath, jsonText, 0644)

throw.OnError(err)
}
Expand All @@ -118,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {

filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath)
fileJSONData, err := os.ReadFile(filePath)
require.NoError(t, err)

if runtime.GOOS == "windows" {
Expand Down Expand Up @@ -245,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter

// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
enumFileData, err := os.ReadFile(filePath)

require.NoError(t, err)

Expand All @@ -254,7 +253,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) {

// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) {
files, err := ioutil.ReadDir(dirPath)
files, err := os.ReadDir(dirPath)
require.NoError(t, err)

require.Equal(t, len(files), len(fileNames))
Expand Down Expand Up @@ -293,76 +292,6 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) {
fmt.Println(expected)
}

// BoolPtr returns address of bool parameter
func BoolPtr(b bool) *bool {
return &b
}

// Int8Ptr returns address of int8 parameter
func Int8Ptr(i int8) *int8 {
return &i
}

// UInt8Ptr returns address of uint8 parameter
func UInt8Ptr(i uint8) *uint8 {
return &i
}

// Int16Ptr returns address of int16 parameter
func Int16Ptr(i int16) *int16 {
return &i
}

// UInt16Ptr returns address of uint16 parameter
func UInt16Ptr(i uint16) *uint16 {
return &i
}

// Int32Ptr returns address of int32 parameter
func Int32Ptr(i int32) *int32 {
return &i
}

// UInt32Ptr returns address of uint32 parameter
func UInt32Ptr(i uint32) *uint32 {
return &i
}

// Int64Ptr returns address of int64 parameter
func Int64Ptr(i int64) *int64 {
return &i
}

// UInt64Ptr returns address of uint64 parameter
func UInt64Ptr(i uint64) *uint64 {
return &i
}

// StringPtr returns address of string parameter
func StringPtr(s string) *string {
return &s
}

// TimePtr returns address of time.Time parameter
func TimePtr(t time.Time) *time.Time {
return &t
}

// ByteArrayPtr returns address of []byte parameter
func ByteArrayPtr(arr []byte) *[]byte {
return &arr
}

// Float32Ptr returns address of float32 parameter
func Float32Ptr(f float32) *float32 {
return &f
}

// Float64Ptr returns address of float64 parameter
func Float64Ptr(f float64) *float64 {
return &f
}

// UUIDPtr returns address of uuid.UUID
func UUIDPtr(u string) *uuid.UUID {
newUUID := uuid.MustParse(u)
Expand Down
6 changes: 6 additions & 0 deletions internal/utils/ptr/ptr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package ptr

// Of returns the address of any given parameter
func Of[T any](value T) *T {
return &value
}
2 changes: 2 additions & 0 deletions tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ setup: checkout-testdata docker-compose-up
checkout-testdata:
git submodule init
git submodule update
#
checkout-latest-testdata: checkout-testdata
cd ./testdata && git fetch && git checkout master && git pull

# docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize
Expand Down
2 changes: 1 addition & 1 deletion tests/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ services:
- ./testdata/init/postgres:/docker-entrypoint-initdb.d

mysql:
image: mysql:8.0.27
image: mysql:8.0
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always
environment:
Expand Down
3 changes: 1 addition & 2 deletions tests/init/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/internal/utils/errfmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"io/ioutil"
"os"
"os/exec"
"strings"
Expand Down Expand Up @@ -184,7 +183,7 @@ func initPostgresDB(dbType string, connectionString string) error {
}

func execFile(db *sql.DB, sqlFilePath string) error {
testSampleSql, err := ioutil.ReadFile(sqlFilePath)
testSampleSql, err := os.ReadFile(sqlFilePath)
if err != nil {
return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err)
}
Expand Down
5 changes: 2 additions & 3 deletions tests/internal/utils/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package file

import (
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path"
"testing"
Expand All @@ -11,7 +10,7 @@ import (
// Exists expects file to exist on path constructed from pathElems and returns content of the file
func Exists(t *testing.T, pathElems ...string) (fileContent string) {
modelFilePath := path.Join(pathElems...)
file, err := ioutil.ReadFile(modelFilePath)
file, err := os.ReadFile(modelFilePath)
require.Nil(t, err)
require.NotEmpty(t, file)
return string(file)
Expand All @@ -20,6 +19,6 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) {
// NotExists expects file not to exist on path constructed from pathElems
func NotExists(t *testing.T, pathElems ...string) {
modelFilePath := path.Join(pathElems...)
_, err := ioutil.ReadFile(modelFilePath)
_, err := os.ReadFile(modelFilePath)
require.True(t, os.IsNotExist(err))
}
67 changes: 34 additions & 33 deletions tests/mysql/alltypes_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mysql

import (
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"strings"
Expand Down Expand Up @@ -1067,7 +1068,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {

var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),
BooleanPtr: ptr.Of(true),
TinyInt: 1,
UTinyInt: 2,
SmallInt: 3,
Expand All @@ -1078,53 +1079,53 @@ var toInsert = model.AllTypes{
UInteger: 8,
BigInt: 9,
UBigInt: 1122334455,
TinyIntPtr: testutils.Int8Ptr(11),
UTinyIntPtr: testutils.UInt8Ptr(22),
SmallIntPtr: testutils.Int16Ptr(33),
USmallIntPtr: testutils.UInt16Ptr(44),
MediumIntPtr: testutils.Int32Ptr(55),
UMediumIntPtr: testutils.UInt32Ptr(66),
IntegerPtr: testutils.Int32Ptr(77),
UIntegerPtr: testutils.UInt32Ptr(88),
BigIntPtr: testutils.Int64Ptr(99),
UBigIntPtr: testutils.UInt64Ptr(111),
TinyIntPtr: ptr.Of(int8(11)),
UTinyIntPtr: ptr.Of(uint8(22)),
SmallIntPtr: ptr.Of(int16(33)),
USmallIntPtr: ptr.Of(uint16(44)),
MediumIntPtr: ptr.Of(int32(55)),
UMediumIntPtr: ptr.Of(uint32(66)),
IntegerPtr: ptr.Of(int32(77)),
UIntegerPtr: ptr.Of(uint32(88)),
BigIntPtr: ptr.Of(int64(99)),
UBigIntPtr: ptr.Of(uint64(111)),
Decimal: 11.22,
DecimalPtr: testutils.Float64Ptr(33.44),
DecimalPtr: ptr.Of(33.44),
Numeric: 55.66,
NumericPtr: testutils.Float64Ptr(77.88),
NumericPtr: ptr.Of(77.88),
Float: 99.00,
FloatPtr: testutils.Float64Ptr(11.22),
FloatPtr: ptr.Of(11.22),
Double: 33.44,
DoublePtr: testutils.Float64Ptr(55.66),
DoublePtr: ptr.Of(55.66),
Real: 77.88,
RealPtr: testutils.Float64Ptr(99.00),
RealPtr: ptr.Of(99.00),
Bit: "1",
BitPtr: testutils.StringPtr("0"),
BitPtr: ptr.Of("0"),
Time: time.Date(1, 1, 1, 10, 11, 12, 100, &time.Location{}),
TimePtr: testutils.TimePtr(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)),
TimePtr: ptr.Of(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)),
Date: time.Now(),
DatePtr: testutils.TimePtr(time.Now()),
DatePtr: ptr.Of(time.Now()),
DateTime: time.Now(),
DateTimePtr: testutils.TimePtr(time.Now()),
DateTimePtr: ptr.Of(time.Now()),
Timestamp: time.Now(),
//TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB
Year: 2000,
YearPtr: testutils.Int16Ptr(2001),
YearPtr: ptr.Of(int16(2001)),
Char: "abcd",
CharPtr: testutils.StringPtr("absd"),
CharPtr: ptr.Of("absd"),
VarChar: "abcd",
VarCharPtr: testutils.StringPtr("absd"),
VarCharPtr: ptr.Of("absd"),
Binary: []byte("1010"),
BinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
BinaryPtr: ptr.Of([]byte("100001")),
VarBinary: []byte("1010"),
VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
VarBinaryPtr: ptr.Of([]byte("100001")),
Blob: []byte("large file"),
BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
BlobPtr: ptr.Of([]byte("very large file")),
Text: "some text",
TextPtr: testutils.StringPtr("text"),
TextPtr: ptr.Of("text"),
Enum: model.AllTypesEnum_Value1,
JSON: "{}",
JSONPtr: testutils.StringPtr(`{"a": 1}`),
JSONPtr: ptr.Of(`{"a": 1}`),
}

var allTypesJson = `
Expand Down Expand Up @@ -1358,17 +1359,17 @@ func TestExactDecimals(t *testing.T) {
Floats: model.Floats{
// overwritten by wrapped(floats) scope
Numeric: 0.1,
NumericPtr: testutils.Float64Ptr(0.1),
NumericPtr: ptr.Of(0.1),
Decimal: 0.1,
DecimalPtr: testutils.Float64Ptr(0.1),
DecimalPtr: ptr.Of(0.1),

// not overwritten
Float: 0.2,
FloatPtr: testutils.Float64Ptr(0.22),
FloatPtr: ptr.Of(0.22),
Double: 0.3,
DoublePtr: testutils.Float64Ptr(0.33),
DoublePtr: ptr.Of(0.33),
Real: 0.4,
RealPtr: testutils.Float64Ptr(0.44),
RealPtr: ptr.Of(0.44),
},
Numeric: decimal.RequireFromString("12.35"),
NumericPtr: decimal.RequireFromString("56.79"),
Expand Down
6 changes: 4 additions & 2 deletions tests/mysql/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table"

"github.com/stretchr/testify/require"
"math/rand"
"testing"
Expand Down Expand Up @@ -300,7 +302,7 @@ func TestInsertOnDuplicateKeyUpdateNEW(t *testing.T) {
ID: randId,
URL: "https://www.yahoo.com",
Name: "Yahoo",
Description: testutils.StringPtr("web portal and search engine"),
Description: ptr.Of("web portal and search engine"),
},
}).AS_NEW().
ON_DUPLICATE_KEY_UPDATE(
Expand Down Expand Up @@ -337,7 +339,7 @@ ON DUPLICATE KEY UPDATE id = (link.id + ?),
ID: randId + 11,
URL: "https://www.yahoo.com",
Name: "Yahoo",
Description: testutils.StringPtr("web portal and search engine"),
Description: ptr.Of("web portal and search engine"),
})
})
}
Expand Down
Loading

0 comments on commit 31a6b95

Please sign in to comment.