diff --git a/shardid/id.go b/shardid/id.go index f960b50..c547b4b 100644 --- a/shardid/id.go +++ b/shardid/id.go @@ -2,6 +2,7 @@ package shardid import ( "database/sql/driver" + "encoding/json" "errors" "time" ) @@ -98,3 +99,18 @@ func (b *ID) Scan(src interface{}) error { // skipcq: GO-W1029 b.WorkerID = id.WorkerID return nil } + +// MarshalJSON implements the json.Marshaler interface +func (id ID) MarshalJSON() ([]byte, error) { + return json.Marshal(id.Int64) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (id *ID) UnmarshalJSON(data []byte) error { + var value int64 + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *id = Parse(value) + return nil +} diff --git a/shardid/id_test.go b/shardid/id_test.go index 7cf7edd..1ddbd3d 100644 --- a/shardid/id_test.go +++ b/shardid/id_test.go @@ -2,6 +2,7 @@ package shardid import ( "database/sql" + "encoding/json" "fmt" "math/rand" "testing" @@ -97,7 +98,7 @@ func TestID(t *testing.T) { } } -func TestSQLDriver(t *testing.T) { +func TestIDInSQL(t *testing.T) { d, err := sql.Open("sqlite3", "file::memory:") require.NoError(t, err) @@ -127,3 +128,29 @@ func TestSQLDriver(t *testing.T) { require.Equal(t, id.WorkerID, i.WorkerID) } + +func TestIdInJSON(t *testing.T) { + + now := time.Now() + id := Build(now.UnixMilli(), 1, 2, MonthlyRotate, 3) + + idInt64 := id.Int64 + + bufID, err := json.Marshal(id) + require.NoError(t, err) + + bufIdInt64, err := json.Marshal(idInt64) + require.NoError(t, err) + + require.Equal(t, bufIdInt64, bufID) + + var jsID ID + err = json.Unmarshal(bufIdInt64, &jsID) + require.NoError(t, err) + require.Equal(t, id, jsID) + + var jsIdInt64 int64 + err = json.Unmarshal(bufID, &jsIdInt64) + require.NoError(t, err) + require.Equal(t, idInt64, jsIdInt64) +}