Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message ID can be int, string, or null as per OpenRPC spec #48

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (e *ErrClient) Unwrap(err error) error {
type clientResponse struct {
Jsonrpc string `json:"jsonrpc"`
Result json.RawMessage `json:"result"`
ID int64 `json:"id"`
ID requestID `json:"id"`
Error *respError `json:"error,omitempty"`
}

Expand Down Expand Up @@ -170,7 +170,7 @@ func httpClient(ctx context.Context, addr string, namespace string, outs []inter
return clientResponse{}, xerrors.Errorf("unmarshaling response: %w", err)
}

if resp.ID != *cr.req.ID {
if cr.req.ID.actual != resp.ID.actual {
return clientResponse{}, xerrors.New("request and response id didn't match")
}

Expand Down Expand Up @@ -240,7 +240,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
req: request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(*cr.req.ID)}},
Params: []param{{v: reflect.ValueOf(cr.req.ID.actual)}},
},
}
select {
Expand Down Expand Up @@ -498,7 +498,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)

req := request{
Jsonrpc: "2.0",
ID: &id,
ID: requestID{id},
Method: fn.client.namespace + "." + fn.name,
Params: params,
}
Expand Down Expand Up @@ -528,7 +528,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
}

if resp.ID != *req.ID {
if req.ID.actual != resp.ID.actual {
return fn.processError(xerrors.New("request and response id didn't match"))
}

Expand Down
44 changes: 38 additions & 6 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,44 @@ type rpcHandler struct {

type request struct {
Jsonrpc string `json:"jsonrpc"`
ID *int64 `json:"id,omitempty"`
ID requestID `json:"id,omitempty"`
Method string `json:"method"`
Params []param `json:"params"`
Meta map[string]string `json:"meta,omitempty"`
}

type requestID struct {
actual interface{} // nil, int64, or string
}

func (r *requestID) UnmarshalJSON(data []byte) error {
switch data[0] {
case 'n': // null
case '"': // string
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
r.actual = s
default: // number
var n int64
if err := json.Unmarshal(data, &n); err != nil {
return err
}
r.actual = n
}
return nil
}

func (r requestID) MarshalJSON() ([]byte, error) {
rvagg marked this conversation as resolved.
Show resolved Hide resolved
switch r.actual.(type) {
case nil, int64, string:
return json.Marshal(r.actual)
default:
return nil, fmt.Errorf("unexpected ID type: %T", r.actual)
}
}

// Limit request size. Ideally this limit should be specific for each field
// in the JSON request but as a simple defensive measure we just limit the
// entire HTTP body.
Expand All @@ -64,7 +96,7 @@ func (e *respError) Error() string {
type response struct {
Jsonrpc string `json:"jsonrpc"`
Result interface{} `json:"result,omitempty"`
ID int64 `json:"id"`
ID requestID `json:"id"`
Error *respError `json:"error,omitempty"`
}

Expand Down Expand Up @@ -109,7 +141,7 @@ func (s *RPCServer) register(namespace string, r interface{}) {
// Handle

type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
type chanOut func(reflect.Value, int64) error
type chanOut func(reflect.Value, requestID) error

func (s *RPCServer) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
wf := func(cb func(io.Writer)) {
Expand Down Expand Up @@ -262,15 +294,15 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}
if req.ID == nil {
if req.ID.actual == nil {
return // notification
}

///////////////////

resp := response{
Jsonrpc: "2.0",
ID: *req.ID,
ID: req.ID,
}

if handler.errOut != -1 {
Expand Down Expand Up @@ -302,7 +334,7 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ
// sending channel messages before this rpc call returns

//noinspection GoNilness // already checked above
err = chOut(callResult[handler.valOut], *req.ID)
err = chOut(callResult[handler.valOut], req.ID)
if err == nil {
return // channel goroutine handles responding
}
Expand Down
77 changes: 77 additions & 0 deletions rpc_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package jsonrpc

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
Expand Down Expand Up @@ -360,6 +362,81 @@ func TestRPCHttpClient(t *testing.T) {
closer()
}

func TestRPCCustomHttpClient(t *testing.T) {
// setup server
serverHandler := &SimpleServerHandler{}
rpcServer := NewServer()
rpcServer.Register("SimpleServerHandler", serverHandler)
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

// setup custom client
addr := "http://" + testServ.Listener.Addr().String()
doReq := func(reqStr string) string {
hreq, err := http.NewRequest("POST", addr, bytes.NewReader([]byte(reqStr)))
require.NoError(t, err)

hreq.Header = http.Header{}
hreq.Header.Set("Content-Type", "application/json")

httpResp, err := testServ.Client().Do(hreq)
defer httpResp.Body.Close()

respBytes, err := ioutil.ReadAll(httpResp.Body)
require.NoError(t, err)

return string(respBytes)
}

// Add(2)
reqStr := `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":100}"`
respBytes := doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":100}`+"\n", string(respBytes))
require.Equal(t, 2, serverHandler.n)

// Add(-3546) error
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[-3546],"id":1010102}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":1010102,"error":{"code":1,"message":"test"}}`+"\n", string(respBytes))
require.Equal(t, 2, serverHandler.n)

// AddGet(3)
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.AddGet","params":[3],"id":0}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","result":5,"id":0}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// StringMatch("0", 0, 0)
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"0","I":0},0],"id":1}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","result":{"S":"0","I":0,"Ok":true},"id":1}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// StringMatch("5", 0, 5) error
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"5","I":0},5],"id":2}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":2,"error":{"code":1,"message":":("}}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// StringMatch("8", 8, 8) error
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.StringMatch","params":[{"S":"8","I":8},8],"id":3}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","result":{"S":"8","I":8,"Ok":true},"id":3}`+"\n", string(respBytes))
require.Equal(t, 5, serverHandler.n)

// Add(int) string ID
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":"100"}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":"100"}`+"\n", string(respBytes))
require.Equal(t, 7, serverHandler.n)

// Add(int) random string ID
reqStr = `{"jsonrpc":"2.0","method":"SimpleServerHandler.Add","params":[2],"id":"OpenRPC says this can be whatever you want"}"`
respBytes = doReq(reqStr)
require.Equal(t, `{"jsonrpc":"2.0","id":"OpenRPC says this can be whatever you want"}`+"\n", string(respBytes))
require.Equal(t, 9, serverHandler.n)
}

type CtxHandler struct {
lk sync.Mutex

Expand Down
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func rpcError(wf func(func(io.Writer)), req *request, code int, err error) {

log.Warnf("rpc error: %s", err)

if req.ID == nil { // notification
if req.ID.actual == nil { // notification
return
}

resp := response{
Jsonrpc: "2.0",
ID: *req.ID,
ID: req.ID,
Error: &respError{
Code: code,
Message: err.Error(),
Expand Down
Loading