diff --git a/README.md b/README.md index cd4d6049..f07e6837 100644 --- a/README.md +++ b/README.md @@ -380,3 +380,101 @@ delivery.UseTestnet = true BinanceClient = delivery.NewClient(ApiKey, SecretKey) ``` +#### Websocket client +##### Order place +##### Async write/read +```go +func main() { + orderPlaceService, _ := futures.NewOrderPlaceWsService(apiKey, secretKey) + + ctx, cancel := context.WithCancel(context.Background()) + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + select { + case <-c: + cancel() + } + }() + + request := futures.NewOrderPlaceWsRequest() + request. + Symbol("BTCUSDT"). + Side(futures.SideTypeSell). + Type(futures.OrderTypeLimit). + Price("68198.00"). + Quantity("0.002"). + TimeInForce(futures.TimeInForceTypeGTC) + + // sender + go func() { + for { + select { + case <-ctx.Done(): + return + default: + err := orderPlaceService.Do("id", request) + if err != nil { + return + } + } + } + }() + + wg := &sync.WaitGroup{} + wg.Add(1) + go listenOrderPlaceResponse(ctx, wg, orderPlaceService) + wg.Wait() + + log.Println("exit") +} + +func listenOrderPlaceResponse(ctx context.Context, wg *sync.WaitGroup, orderPlaceService *futures.OrderPlaceWsService) { + defer wg.Done() + + go func() { + for msg := range orderPlaceService.GetReadChannel() { + log.Println("order place response", string(msg)) + } + }() + + go func() { + for err := range orderPlaceService.GetReadErrorChannel() { + log.Println("order place error", err) + } + }() + + select { + case <-ctx.Done(): + orderPlaceService.ReceiveAllDataBeforeStop(10 * time.Second) + } +} +``` +##### Sync write/read +```go +func main() { + orderPlaceService, _ := futures.NewOrderPlaceWsService(apiKey, secretKey) + + id := "some-id" + request := futures.NewOrderPlaceWsRequest() + request. + Symbol("BTCUSDT"). + Side(futures.SideTypeSell). + Type(futures.OrderTypeLimit). + Price("68198.00"). + Quantity("0.002"). + TimeInForce(futures.TimeInForceTypeGTC) + + response, err := orderPlaceService.SyncDo(id, request) + if err != nil { + log.Fatal(err) + } + + // handle response +} +``` + + + + diff --git a/v2/futures/client_ws.go b/v2/futures/client_ws.go new file mode 100644 index 00000000..6e894eab --- /dev/null +++ b/v2/futures/client_ws.go @@ -0,0 +1,500 @@ +package futures + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "github.com/jpillora/backoff" + + "github.com/adshao/go-binance/v2/common" +) + +//go:generate mockgen -source client_ws.go -destination mock/client_ws.go -package mock + +const ( + // reconnectMinInterval define reconnect min interval + reconnectMinInterval = 100 * time.Millisecond + + // reconnectMaxInterval define reconnect max interval + reconnectMaxInterval = 10 * time.Second +) + +var ( + // ErrorWsConnectionClosed defines that connection closed + ErrorWsConnectionClosed = errors.New("ws error: connection closed") + + // ErrorWsReadConnectionTimeout defines that connection read timeout expired + ErrorWsReadConnectionTimeout = errors.New("ws error: read connection timeout") + + // ErrorWsIdAlreadySent defines that request with the same id was already sent + ErrorWsIdAlreadySent = errors.New("ws error: request with same id already sent") +) + +// messageId define id field of request/response +type messageId struct { + Id string `json:"id"` +} + +// ClientWs define API websocket client +type ClientWs struct { + APIKey string + SecretKey string + Debug bool + KeyType string + TimeOffset int64 + logger *log.Logger + conn wsConnection + connMu sync.Mutex + reconnectSignal chan struct{} + connectionEstablishedSignal chan struct{} + requestsList RequestList + readC chan []byte + readErrChan chan error + reconnectCount atomic.Int64 +} + +func (c *ClientWs) debug(format string, v ...interface{}) { + if c.Debug { + c.logger.Println(fmt.Sprintf(format, v...)) + } +} + +// NewClientWs init ClientWs +func NewClientWs(conn wsConnection, apiKey, secretKey string) (*ClientWs, error) { + client := &ClientWs{ + APIKey: apiKey, + SecretKey: secretKey, + KeyType: common.KeyTypeHmac, + logger: log.New(os.Stderr, "Binance-golang ", log.LstdFlags), + conn: conn, + connMu: sync.Mutex{}, + reconnectSignal: make(chan struct{}, 1), + connectionEstablishedSignal: make(chan struct{}, 1), + requestsList: NewRequestList(), + readErrChan: make(chan error, 1), + readC: make(chan []byte), + } + + go client.handleReconnect() + go client.read() + + return client, nil +} + +type wsClient interface { + Write(id string, data []byte) error + WriteSync(id string, data []byte, timeout time.Duration) ([]byte, error) + GetReadChannel() <-chan []byte + GetReadErrorChannel() <-chan error + GetApiKey() string + GetSecretKey() string + GetTimeOffset() int64 + GetKeyType() string + GetReconnectCount() int64 + Wait(timeout time.Duration) +} + +// Write sends data into websocket connection +func (c *ClientWs) Write(id string, data []byte) error { + c.connMu.Lock() + defer c.connMu.Unlock() + + if c.requestsList.IsAlreadyInList(id) { + return ErrorWsIdAlreadySent + } + + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + c.debug("write: unable to write message into websocket conn '%v'", err) + return err + } + + c.requestsList.Add(id) + + return nil +} + +// WriteSync sends data to the websocket connection and waits for a response synchronously +// Should be used separately from the asynchronous Write method (do not send anything in parallel) +func (c *ClientWs) WriteSync(id string, data []byte, timeout time.Duration) ([]byte, error) { + c.connMu.Lock() + defer c.connMu.Unlock() + + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + c.debug("write sync: unable to write message into websocket conn '%v'", err) + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + c.debug("write sync: timeout expired") + return nil, ErrorWsReadConnectionTimeout + case rawData := <-c.readC: + // check that the correct response from websocket has been read + msg := messageId{} + err := json.Unmarshal(rawData, &msg) + if err != nil { + return nil, err + } + if msg.Id != id { + c.debug("write sync: wrong response with id '%v' has been read", msg.Id) + continue + } + + return rawData, nil + case err := <-c.readErrChan: + c.debug("write sync: error read '%v'", err) + return nil, err + } + } +} + +func (c *ClientWs) GetReadChannel() <-chan []byte { + return c.readC +} + +func (c *ClientWs) GetReadErrorChannel() <-chan error { + return c.readErrChan +} + +func (c *ClientWs) GetApiKey() string { + return c.APIKey +} + +func (c *ClientWs) GetSecretKey() string { + return c.SecretKey +} + +func (c *ClientWs) GetTimeOffset() int64 { + return c.TimeOffset +} + +func (c *ClientWs) GetKeyType() string { + return c.KeyType +} + +func (c *ClientWs) Wait(timeout time.Duration) { + c.wait(timeout) +} + +// read data from connection +func (c *ClientWs) read() { + defer func() { + // reading from closed connection 1000 times caused panic + // prevent panic for any case + if r := recover(); r != nil { + } + }() + + for { + c.debug("read: waiting for message") + _, message, err := c.conn.ReadMessage() + if err != nil { + c.debug("read: error reading message '%v'", err) + c.reconnectSignal <- struct{}{} + c.readErrChan <- errors.Join(err, ErrorWsConnectionClosed) + + c.debug("read: wait to get connected") + <-c.connectionEstablishedSignal + + // refresh map after reconnect to avoid useless waiting after stop application + c.requestsList.RecreateList() + + c.debug("read: connection established") + continue + } + c.debug("read: got new message") + + msg := messageId{} + err = json.Unmarshal(message, &msg) + if err != nil { + c.debug("read: error unmarshalling message '%v'", err) + c.readErrChan <- err + continue + } + + c.debug("read: sending message into read channel '%v'", msg) + c.readC <- message + + c.debug("read: remove message from request list '%v'", msg) + c.requestsList.Remove(msg.Id) + } +} + +// wait until all responses received +// make sure that you are not sending requests +func (c *ClientWs) wait(timeout time.Duration) { + doneC := make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + for { + select { + case <-ctx.Done(): + return + default: + if c.requestsList.Len() == 0 { + doneC <- struct{}{} + return + } + } + } + }() + + t := time.After(timeout) + select { + case <-t: + case <-doneC: + } + + cancel() +} + +// handleReconnect waits for reconnect signal and starts reconnect +func (c *ClientWs) handleReconnect() { + for _ = range c.reconnectSignal { + c.debug("reconnect: received signal") + + b := &backoff.Backoff{ + Min: reconnectMinInterval, + Max: reconnectMaxInterval, + Factor: 1.8, + Jitter: false, + } + + conn := c.startReconnect(b) + + b.Reset() + + c.connMu.Lock() + c.conn = conn + c.connMu.Unlock() + + c.debug("reconnect: connected") + c.connectionEstablishedSignal <- struct{}{} + } +} + +// startReconnect starts reconnect loop with increasing delay +func (c *ClientWs) startReconnect(b *backoff.Backoff) *connection { + for { + c.reconnectCount.Add(1) + conn, err := newConnection() + if err != nil { + delay := b.Duration() + c.debug("reconnect: error while reconnecting. try in %s", delay.Round(time.Millisecond)) + time.Sleep(delay) + continue + } + + return conn + } +} + +// GetReconnectCount returns reconnect counter value +func (c *ClientWs) GetReconnectCount() int64 { + return c.reconnectCount.Load() +} + +// NewRequestList creates request list +func NewRequestList() RequestList { + return RequestList{ + mu: sync.Mutex{}, + requests: make(map[string]struct{}), // TODO preallocate buckets + } +} + +// RequestList state of requests that was sent/received +type RequestList struct { + mu sync.Mutex + requests map[string]struct{} +} + +// Add adds request into list +func (l *RequestList) Add(id string) { + l.mu.Lock() + defer l.mu.Unlock() + l.requests[id] = struct{}{} +} + +// RecreateList creates new request list +func (l *RequestList) RecreateList() { + l.mu.Lock() + defer l.mu.Unlock() + l.requests = make(map[string]struct{}) +} + +// Remove adds request from list +func (l *RequestList) Remove(id string) { + l.mu.Lock() + defer l.mu.Unlock() + delete(l.requests, id) +} + +// Len get list length +func (l *RequestList) Len() int { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.requests) +} + +// IsAlreadyInList checks if id is presented in requests list +func (l *RequestList) IsAlreadyInList(id string) bool { + l.mu.Lock() + defer l.mu.Unlock() + if _, ok := l.requests[id]; ok { + return true + } + + return false +} + +// constructor for connection +func newConnection() (*connection, error) { + conn, err := WsApiInitReadWriteConn() + if err != nil { + return nil, err + } + + wsConn := &connection{ + conn: conn, + connectionMu: sync.Mutex{}, + lastResponseMu: sync.Mutex{}, + } + + if WebsocketKeepalive { + go wsConn.keepAlive(WebsocketTimeoutReadWriteConnection) + } + + return wsConn, nil +} + +// instance of single connection with keepalive handler +type connection struct { + conn *websocket.Conn + connectionMu sync.Mutex + lastResponse time.Time + lastResponseMu sync.Mutex +} + +type wsConnection interface { + WriteMessage(messageType int, data []byte) error + ReadMessage() (messageType int, p []byte, err error) +} + +// WriteMessage is a thread-safe method for conn.WriteMessage +func (c *connection) WriteMessage(messageType int, data []byte) error { + c.connectionMu.Lock() + defer c.connectionMu.Unlock() + return c.conn.WriteMessage(messageType, data) +} + +// ReadMessage wrapper for conn.ReadMessage +func (c *connection) ReadMessage() (int, []byte, error) { + return c.conn.ReadMessage() +} + +// keepAlive handles ping-pong for connection +func (c *connection) keepAlive(timeout time.Duration) { + ticker := time.NewTicker(timeout) + + c.updateLastResponse() + + c.conn.SetPongHandler(func(msg string) error { + c.updateLastResponse() + return nil + }) + + go func() { + defer ticker.Stop() + for { + err := c.ping() + if err != nil { + return + } + + <-ticker.C + if c.isLastResponseOutdated(timeout) { + c.close() + return + } + } + }() +} + +// updateLastResponse sets lastResponse now +func (c *connection) updateLastResponse() { + c.lastResponseMu.Lock() + defer c.lastResponseMu.Unlock() + c.lastResponse = time.Now() +} + +// isLastResponseOutdated checks that time since last pong message exceeded timeout +func (c *connection) isLastResponseOutdated(timeout time.Duration) bool { + c.lastResponseMu.Lock() + defer c.lastResponseMu.Unlock() + return time.Since(c.lastResponse) > timeout +} + +// close thread-safe method for closing connection +func (c *connection) close() error { + c.connectionMu.Lock() + defer c.connectionMu.Unlock() + return c.conn.Close() +} + +// ping thread-safe method sending ping message +func (c *connection) ping() error { + c.connectionMu.Lock() + defer c.connectionMu.Unlock() + + deadline := time.Now().Add(10 * time.Second) + err := c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline) + if err != nil { + return err + } + + return nil +} + +// NewOrderPlaceWsService init OrderPlaceWsService +func NewOrderPlaceWsService(apiKey, secretKey string) (*OrderPlaceWsService, error) { + conn, err := newConnection() + if err != nil { + return nil, err + } + + client, err := NewClientWs(conn, apiKey, secretKey) + if err != nil { + return nil, err + } + + return &OrderPlaceWsService{c: client}, nil +} + +// NewOrderCancelWsService init OrderCancelWsService +func NewOrderCancelWsService(apiKey, secretKey string) (*OrderCancelWsService, error) { + conn, err := newConnection() + if err != nil { + return nil, err + } + + client, err := NewClientWs(conn, apiKey, secretKey) + if err != nil { + return nil, err + } + + return &OrderCancelWsService{c: client}, nil +} diff --git a/v2/futures/client_ws_test.go b/v2/futures/client_ws_test.go new file mode 100644 index 00000000..50d1d7dc --- /dev/null +++ b/v2/futures/client_ws_test.go @@ -0,0 +1,242 @@ +package futures + +import ( + "context" + "encoding/json" + "errors" + "log" + "net/http" + "testing" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/suite" +) + +func (s *clientWsTestSuite) SetupTest() { + s.apiKey = "dummyApiKey" + s.secretKey = "dummySecretKey" +} + +type clientWsTestSuite struct { + suite.Suite + apiKey string + secretKey string +} + +func TestClientWs(t *testing.T) { + suite.Run(t, new(clientWsTestSuite)) +} + +func (s *clientWsTestSuite) TestReadWriteSync() { + stopCh := make(chan struct{}) + go func() { + startWsTestServer(stopCh) + }() + defer func() { + stopCh <- struct{}{} + }() + + useLocalhost = true + WebsocketKeepalive = true + + conn, err := newConnection() + s.Require().NoError(err) + + client, err := NewClientWs(conn, s.apiKey, s.secretKey) + s.Require().NoError(err) + + tests := []struct { + name string + testCallback func() + }{ + { + name: "WriteSync success", + testCallback: func() { + id, err := uuid.NewRandom() + s.Require().NoError(err) + requestID := id.String() + + req := WsApiRequest{ + Id: requestID, + Method: "some-method", + Params: map[string]interface{}{}, + } + reqRaw, err := json.Marshal(req) + s.Require().NoError(err) + + responseRaw, err := client.WriteSync(requestID, reqRaw, WriteSyncWsTimeout) + s.Require().NoError(err) + s.Require().Equal(reqRaw, responseRaw) + }, + }, + { + name: "WriteSync success read message with parallel writing", + testCallback: func() { + id, err := uuid.NewRandom() + s.Require().NoError(err) + requestID := id.String() + + req := WsApiRequest{ + Id: "some-other-request-id", + Method: "some-method", + Params: map[string]interface{}{}, + } + reqRaw, err := json.Marshal(req) + s.Require().NoError(err) + + err = client.Write(requestID, reqRaw) + s.Require().NoError(err) + + req = WsApiRequest{ + Id: requestID, + Method: "some-method", + Params: map[string]interface{}{}, + } + reqRaw, err = json.Marshal(req) + s.Require().NoError(err) + + responseRaw, err := client.WriteSync(requestID, reqRaw, WriteSyncWsTimeout) + s.Require().NoError(err) + s.Require().Equal(reqRaw, responseRaw) + }, + }, + { + name: "WriteSync timeout expired", + testCallback: func() { + id, err := uuid.NewRandom() + s.Require().NoError(err) + requestID := id.String() + + req := WsApiRequest{ + Id: requestID, + Method: "some-method", + Params: map[string]interface{}{ + "timeout": "true", + }, + } + reqRaw, err := json.Marshal(req) + s.Require().NoError(err) + + responseRaw, err := client.WriteSync(requestID, reqRaw, 500*time.Millisecond) + s.Require().Nil(responseRaw) + s.Require().ErrorIs(err, ErrorWsReadConnectionTimeout) + }, + }, + { + name: "WriteAsync success", + testCallback: func() { + id, err := uuid.NewRandom() + s.Require().NoError(err) + requestID := id.String() + + req := WsApiRequest{ + Id: requestID, + Method: "some-method", + Params: map[string]interface{}{}, + } + reqRaw, err := json.Marshal(req) + s.Require().NoError(err) + + err = client.Write(requestID, reqRaw) + s.Require().NoError(err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + select { + case <-ctx.Done(): + s.T().Fatal("timeout waiting for write") + case responseRaw := <-client.GetReadChannel(): + s.Require().Equal(reqRaw, responseRaw) + case err := <-client.GetReadErrorChannel(): + s.T().Fatalf("unexpected error: '%v'", err) + } + }, + }, + } + + for _, tt := range tests { + tt := tt + s.T().Run(tt.name, func(t *testing.T) { + tt.testCallback() + }) + } +} + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +func wsHandler(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println("Error upgrading to WebSocket:", err) + return + } + defer conn.Close() + + conn.SetPingHandler(func(appData string) error { + log.Println("Received ping:", appData) + err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(10*time.Second)) + if err != nil { + log.Println("Error sending pong:", err) + } + return nil + }) + + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + log.Println("Error reading message:", err) + break + } + + log.Printf("Received message: %s\n", message) + + req := WsApiRequest{} + if err := json.Unmarshal(message, &req); err != nil { + log.Println("Error unmarshalling message:", err) + continue + } + + if req.Params["timeout"] == "true" { + continue + } + + err = conn.WriteMessage(messageType, message) + if err != nil { + log.Println("Error writing message:", err) + break + } + } +} + +func startWsTestServer(stopCh chan struct{}) { + server := &http.Server{ + Addr: ":8080", + } + + http.HandleFunc("/ws", wsHandler) + log.Println("WebSocket server started on :8080") + + go func() { + if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("WebSocket server error: %v", err) + } + log.Println("Stopped serving new connections.") + }() + + <-stopCh + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + log.Fatalf("WebSocket shutdown error: %v", err) + } + log.Println("Graceful shutdown complete.") +} diff --git a/v2/futures/mock/client_ws.go b/v2/futures/mock/client_ws.go new file mode 100644 index 00000000..e3f5c4f5 --- /dev/null +++ b/v2/futures/mock/client_ws.go @@ -0,0 +1,227 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client_ws.go + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" +) + +// MockwsClient is a mock of wsClient interface. +type MockwsClient struct { + ctrl *gomock.Controller + recorder *MockwsClientMockRecorder +} + +// MockwsClientMockRecorder is the mock recorder for MockwsClient. +type MockwsClientMockRecorder struct { + mock *MockwsClient +} + +// NewMockwsClient creates a new mock instance. +func NewMockwsClient(ctrl *gomock.Controller) *MockwsClient { + mock := &MockwsClient{ctrl: ctrl} + mock.recorder = &MockwsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockwsClient) EXPECT() *MockwsClientMockRecorder { + return m.recorder +} + +// GetApiKey mocks base method. +func (m *MockwsClient) GetApiKey() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetApiKey") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetApiKey indicates an expected call of GetApiKey. +func (mr *MockwsClientMockRecorder) GetApiKey() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApiKey", reflect.TypeOf((*MockwsClient)(nil).GetApiKey)) +} + +// GetKeyType mocks base method. +func (m *MockwsClient) GetKeyType() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKeyType") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetKeyType indicates an expected call of GetKeyType. +func (mr *MockwsClientMockRecorder) GetKeyType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyType", reflect.TypeOf((*MockwsClient)(nil).GetKeyType)) +} + +// GetReadChannel mocks base method. +func (m *MockwsClient) GetReadChannel() <-chan []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetReadChannel") + ret0, _ := ret[0].(<-chan []byte) + return ret0 +} + +// GetReadChannel indicates an expected call of GetReadChannel. +func (mr *MockwsClientMockRecorder) GetReadChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReadChannel", reflect.TypeOf((*MockwsClient)(nil).GetReadChannel)) +} + +// GetReadErrorChannel mocks base method. +func (m *MockwsClient) GetReadErrorChannel() <-chan error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetReadErrorChannel") + ret0, _ := ret[0].(<-chan error) + return ret0 +} + +// GetReadErrorChannel indicates an expected call of GetReadErrorChannel. +func (mr *MockwsClientMockRecorder) GetReadErrorChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReadErrorChannel", reflect.TypeOf((*MockwsClient)(nil).GetReadErrorChannel)) +} + +// GetReconnectCount mocks base method. +func (m *MockwsClient) GetReconnectCount() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetReconnectCount") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetReconnectCount indicates an expected call of GetReconnectCount. +func (mr *MockwsClientMockRecorder) GetReconnectCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReconnectCount", reflect.TypeOf((*MockwsClient)(nil).GetReconnectCount)) +} + +// GetSecretKey mocks base method. +func (m *MockwsClient) GetSecretKey() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSecretKey") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetSecretKey indicates an expected call of GetSecretKey. +func (mr *MockwsClientMockRecorder) GetSecretKey() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSecretKey", reflect.TypeOf((*MockwsClient)(nil).GetSecretKey)) +} + +// GetTimeOffset mocks base method. +func (m *MockwsClient) GetTimeOffset() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTimeOffset") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetTimeOffset indicates an expected call of GetTimeOffset. +func (mr *MockwsClientMockRecorder) GetTimeOffset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTimeOffset", reflect.TypeOf((*MockwsClient)(nil).GetTimeOffset)) +} + +// Wait mocks base method. +func (m *MockwsClient) Wait(timeout time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Wait", timeout) +} + +// Wait indicates an expected call of Wait. +func (mr *MockwsClientMockRecorder) Wait(timeout interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockwsClient)(nil).Wait), timeout) +} + +// Write mocks base method. +func (m *MockwsClient) Write(id string, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", id, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write. +func (mr *MockwsClientMockRecorder) Write(id, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockwsClient)(nil).Write), id, data) +} + +// WriteSync mocks base method. +func (m *MockwsClient) WriteSync(id string, data []byte, timeout time.Duration) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteSync", id, data, timeout) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WriteSync indicates an expected call of WriteSync. +func (mr *MockwsClientMockRecorder) WriteSync(id, data, timeout interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteSync", reflect.TypeOf((*MockwsClient)(nil).WriteSync), id, data, timeout) +} + +// MockwsConnection is a mock of wsConnection interface. +type MockwsConnection struct { + ctrl *gomock.Controller + recorder *MockwsConnectionMockRecorder +} + +// MockwsConnectionMockRecorder is the mock recorder for MockwsConnection. +type MockwsConnectionMockRecorder struct { + mock *MockwsConnection +} + +// NewMockwsConnection creates a new mock instance. +func NewMockwsConnection(ctrl *gomock.Controller) *MockwsConnection { + mock := &MockwsConnection{ctrl: ctrl} + mock.recorder = &MockwsConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockwsConnection) EXPECT() *MockwsConnectionMockRecorder { + return m.recorder +} + +// ReadMessage mocks base method. +func (m *MockwsConnection) ReadMessage() (int, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadMessage indicates an expected call of ReadMessage. +func (mr *MockwsConnectionMockRecorder) ReadMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessage", reflect.TypeOf((*MockwsConnection)(nil).ReadMessage)) +} + +// WriteMessage mocks base method. +func (m *MockwsConnection) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage. +func (mr *MockwsConnectionMockRecorder) WriteMessage(messageType, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*MockwsConnection)(nil).WriteMessage), messageType, data) +} diff --git a/v2/futures/order_cancel_service_ws.go b/v2/futures/order_cancel_service_ws.go new file mode 100644 index 00000000..9d7d4a39 --- /dev/null +++ b/v2/futures/order_cancel_service_ws.go @@ -0,0 +1,129 @@ +package futures + +import ( + "encoding/json" + "time" + + "github.com/adshao/go-binance/v2/common" +) + +// NewOrderCancelRequest init OrderCancelRequest +func NewOrderCancelRequest() *OrderCancelRequest { + return &OrderCancelRequest{} +} + +// OrderCancelRequest parameters for 'order.cancel' websocket API +type OrderCancelRequest struct { + symbol string + orderID *int64 + origClientOrderID *string +} + +// Symbol set symbol +func (s *OrderCancelRequest) Symbol(symbol string) *OrderCancelRequest { + s.symbol = symbol + return s +} + +// OrderID set orderID +func (s *OrderCancelRequest) OrderID(orderID int64) *OrderCancelRequest { + s.orderID = &orderID + return s +} + +// OrigClientOrderID set origClientOrderID +func (s *OrderCancelRequest) OrigClientOrderID(origClientOrderID string) *OrderCancelRequest { + s.origClientOrderID = &origClientOrderID + return s +} + +// buildParams builds params +func (s *OrderCancelRequest) buildParams() params { + m := params{ + "symbol": s.symbol, + } + + if s.orderID != nil { + m["orderId"] = *s.orderID + } + + if s.origClientOrderID != nil { + m["origClientOrderId"] = *s.origClientOrderID + } + + return m +} + +// CancelOrderResult define order cancel result +type CancelOrderResult struct { + CancelOrderResponse +} + +// OrderCancelWsResponse define 'order.cancel' websocket API response +type OrderCancelWsResponse struct { + Id string `json:"id"` + Status int `json:"status"` + Result CancelOrderResult `json:"result"` + + // error response + Error *common.APIError `json:"error,omitempty"` +} + +// OrderCancelWsService cancel order +type OrderCancelWsService struct { + c wsClient +} + +// Do - sends 'order.cancel' request +func (s *OrderCancelWsService) Do(requestID string, request *OrderCancelRequest) error { + rawData, err := createWsRequest(requestID, s.c, CancelWsApiMethod, request.buildParams()) + if err != nil { + return err + } + + if err := s.c.Write(requestID, rawData); err != nil { + return err + } + + return nil +} + +// SyncDo - sends 'order.cancel' request and receives response +func (s *OrderCancelWsService) SyncDo(requestID string, request *OrderCancelRequest) (*OrderCancelWsResponse, error) { + rawData, err := createWsRequest(requestID, s.c, CancelWsApiMethod, request.buildParams()) + if err != nil { + return nil, err + } + + response, err := s.c.WriteSync(requestID, rawData, WriteSyncWsTimeout) + if err != nil { + return nil, err + } + + cancelOrderWsResponse := &OrderCancelWsResponse{} + if err := json.Unmarshal(response, cancelOrderWsResponse); err != nil { + return nil, err + } + + return cancelOrderWsResponse, nil +} + +// ReceiveAllDataBeforeStop waits until all responses will be received from websocket until timeout expired +func (s *OrderCancelWsService) ReceiveAllDataBeforeStop(timeout time.Duration) { + s.c.Wait(timeout) +} + +// GetReadChannel returns channel with API response data (including API errors) +func (s *OrderCancelWsService) GetReadChannel() <-chan []byte { + return s.c.GetReadChannel() +} + +// GetReadErrorChannel returns channel with errors which are occurred while reading websocket connection +func (s *OrderCancelWsService) GetReadErrorChannel() <-chan error { + return s.c.GetReadErrorChannel() +} + +// GetReconnectCount returns count of reconnect attempts by client +func (s *OrderCancelWsService) GetReconnectCount() int64 { + return s.c.GetReconnectCount() +} diff --git a/v2/futures/order_cancel_service_ws_test.go b/v2/futures/order_cancel_service_ws_test.go new file mode 100644 index 00000000..1019a6ea --- /dev/null +++ b/v2/futures/order_cancel_service_ws_test.go @@ -0,0 +1,174 @@ +package futures + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/adshao/go-binance/v2/futures/mock" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/suite" +) + +func (s *orderCancelServiceWsTestSuite) SetupTest() { + s.apiKey = "dummyApiKey" + s.secretKey = "dummySecretKey" + s.signedKey = "HMAC" + s.timeOffset = 0 + + s.requestID = "e2a85d9f-07a5-4f94-8d5f-789dc3deb098" + + s.ctrl = gomock.NewController(s.T()) + s.client = mock.NewMockwsClient(s.ctrl) + + s.orderCancel = &OrderCancelWsService{ + c: s.client, + } + + s.orderCancelRequest = NewOrderCancelRequest().OrigClientOrderID(s.requestID) +} + +func (s *orderCancelServiceWsTestSuite) TearDownTest() { + s.ctrl.Finish() +} + +type orderCancelServiceWsTestSuite struct { + suite.Suite + apiKey string + secretKey string + signedKey string + timeOffset int64 + + ctrl *gomock.Controller + client *mock.MockwsClient + + requestID string + + orderCancel *OrderCancelWsService + orderCancelRequest *OrderCancelRequest +} + +func TestOrderCancelServiceWs(t *testing.T) { + suite.Run(t, new(orderCancelServiceWsTestSuite)) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancel() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).AnyTimes() + + err := s.orderCancel.Do(s.requestID, s.orderCancelRequest) + s.NoError(err) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancel_EmptyRequestID() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).Times(0) + + err := s.orderCancel.Do("", s.orderCancelRequest) + s.ErrorIs(err, ErrorRequestIDNotSet) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancel_EmptyApiKey() { + s.expectCalls("", s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).Times(0) + + err := s.orderCancel.Do(s.requestID, s.orderCancelRequest) + s.ErrorIs(err, ErrorApiKeyIsNotSet) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancel_EmptySecretKey() { + s.expectCalls(s.apiKey, "", s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).Times(0) + + err := s.orderCancel.Do(s.requestID, s.orderCancelRequest) + s.ErrorIs(err, ErrorSecretKeyIsNotSet) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancel_EmptySignKeyType() { + s.expectCalls(s.apiKey, s.secretKey, "", s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).Times(0) + + err := s.orderCancel.Do(s.requestID, s.orderCancelRequest) + s.Error(err) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancelSync() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + orderCancelResponse := OrderCancelWsResponse{ + Id: s.requestID, + Status: 200, + Result: CancelOrderResult{ + CancelOrderResponse{ + ClientOrderID: s.requestID, + }, + }, + } + + rawResponseData, err := json.Marshal(orderCancelResponse) + s.NoError(err) + + s.client.EXPECT().WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(rawResponseData, nil).Times(1) + + req := s.orderCancelRequest + response, err := s.orderCancel.SyncDo(s.requestID, req) + s.Require().NoError(err) + s.Equal(s.requestID, response.Result.ClientOrderID) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancelSync_EmptyRequestID() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT(). + WriteSync(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + req := s.orderCancelRequest + response, err := s.orderCancel.SyncDo("", req) + s.Nil(response) + s.ErrorIs(err, ErrorRequestIDNotSet) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancelSync_EmptyApiKey() { + s.expectCalls("", s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT(). + WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + response, err := s.orderCancel.SyncDo(s.requestID, s.orderCancelRequest) + s.Nil(response) + s.ErrorIs(err, ErrorApiKeyIsNotSet) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancelSync_EmptySecretKey() { + s.expectCalls(s.apiKey, "", s.signedKey, s.timeOffset) + + s.client.EXPECT(). + WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + response, err := s.orderCancel.SyncDo(s.requestID, s.orderCancelRequest) + s.Nil(response) + s.ErrorIs(err, ErrorSecretKeyIsNotSet) +} + +func (s *orderCancelServiceWsTestSuite) TestOrderCancelSync_EmptySignKeyType() { + s.expectCalls(s.apiKey, s.secretKey, "", s.timeOffset) + + s.client.EXPECT(). + WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + response, err := s.orderCancel.SyncDo(s.requestID, s.orderCancelRequest) + s.Nil(response) + s.Error(err) +} + +func (s *orderCancelServiceWsTestSuite) expectCalls(apiKey, secretKey, signKeyType string, timeOffset int64) { + s.client.EXPECT().GetApiKey().Return(apiKey).AnyTimes() + s.client.EXPECT().GetSecretKey().Return(secretKey).AnyTimes() + s.client.EXPECT().GetKeyType().Return(signKeyType).AnyTimes() + s.client.EXPECT().GetTimeOffset().Return(timeOffset).AnyTimes() +} diff --git a/v2/futures/order_place_service_ws.go b/v2/futures/order_place_service_ws.go new file mode 100644 index 00000000..759bca51 --- /dev/null +++ b/v2/futures/order_place_service_ws.go @@ -0,0 +1,252 @@ +package futures + +import ( + "encoding/json" + "time" + + "github.com/adshao/go-binance/v2/common" +) + +// OrderPlaceWsService creates order +type OrderPlaceWsService struct { + c wsClient + signFn func(string, string) (*string, error) +} + +// OrderPlaceWsRequest parameters for 'order.place' websocket API +type OrderPlaceWsRequest struct { + symbol string + side SideType + positionSide *PositionSideType + orderType OrderType + timeInForce *TimeInForceType + quantity string + reduceOnly *bool + price *string + newClientOrderID *string + stopPrice *string + workingType *WorkingType + activationPrice *string + callbackRate *string + priceProtect *bool + newOrderRespType NewOrderRespType + closePosition *bool +} + +// NewOrderPlaceWsRequest init OrderPlaceWsRequest +func NewOrderPlaceWsRequest() *OrderPlaceWsRequest { + return &OrderPlaceWsRequest{} +} + +// Symbol set symbol +func (s *OrderPlaceWsRequest) Symbol(symbol string) *OrderPlaceWsRequest { + s.symbol = symbol + return s +} + +// Side set side +func (s *OrderPlaceWsRequest) Side(side SideType) *OrderPlaceWsRequest { + s.side = side + return s +} + +// PositionSide set side +func (s *OrderPlaceWsRequest) PositionSide(positionSide PositionSideType) *OrderPlaceWsRequest { + s.positionSide = &positionSide + return s +} + +// Type set type +func (s *OrderPlaceWsRequest) Type(orderType OrderType) *OrderPlaceWsRequest { + s.orderType = orderType + return s +} + +// TimeInForce set timeInForce +func (s *OrderPlaceWsRequest) TimeInForce(timeInForce TimeInForceType) *OrderPlaceWsRequest { + s.timeInForce = &timeInForce + return s +} + +// Quantity set quantity +func (s *OrderPlaceWsRequest) Quantity(quantity string) *OrderPlaceWsRequest { + s.quantity = quantity + return s +} + +// ReduceOnly set reduceOnly +func (s *OrderPlaceWsRequest) ReduceOnly(reduceOnly bool) *OrderPlaceWsRequest { + s.reduceOnly = &reduceOnly + return s +} + +// Price set price +func (s *OrderPlaceWsRequest) Price(price string) *OrderPlaceWsRequest { + s.price = &price + return s +} + +// NewClientOrderID set newClientOrderID +func (s *OrderPlaceWsRequest) NewClientOrderID(newClientOrderID string) *OrderPlaceWsRequest { + s.newClientOrderID = &newClientOrderID + return s +} + +// StopPrice set stopPrice +func (s *OrderPlaceWsRequest) StopPrice(stopPrice string) *OrderPlaceWsRequest { + s.stopPrice = &stopPrice + return s +} + +// WorkingType set workingType +func (s *OrderPlaceWsRequest) WorkingType(workingType WorkingType) *OrderPlaceWsRequest { + s.workingType = &workingType + return s +} + +// ActivationPrice set activationPrice +func (s *OrderPlaceWsRequest) ActivationPrice(activationPrice string) *OrderPlaceWsRequest { + s.activationPrice = &activationPrice + return s +} + +// CallbackRate set callbackRate +func (s *OrderPlaceWsRequest) CallbackRate(callbackRate string) *OrderPlaceWsRequest { + s.callbackRate = &callbackRate + return s +} + +// PriceProtect set priceProtect +func (s *OrderPlaceWsRequest) PriceProtect(priceProtect bool) *OrderPlaceWsRequest { + s.priceProtect = &priceProtect + return s +} + +// NewOrderResponseType set newOrderResponseType +func (s *OrderPlaceWsRequest) NewOrderResponseType(newOrderResponseType NewOrderRespType) *OrderPlaceWsRequest { + s.newOrderRespType = newOrderResponseType + return s +} + +// ClosePosition set closePosition +func (s *OrderPlaceWsRequest) ClosePosition(closePosition bool) *OrderPlaceWsRequest { + s.closePosition = &closePosition + return s +} + +// CreateOrderResult define order creation result +type CreateOrderResult struct { + CreateOrderResponse +} + +// CreateOrderWsResponse define 'order.place' websocket API response +type CreateOrderWsResponse struct { + Id string `json:"id"` + Status int `json:"status"` + Result CreateOrderResult `json:"result"` + + // error response + Error *common.APIError `json:"error,omitempty"` +} + +// buildParams builds params +func (s *OrderPlaceWsRequest) buildParams() params { + m := params{ + "symbol": s.symbol, + "side": s.side, + "type": s.orderType, + "newOrderRespType": s.newOrderRespType, + } + if s.quantity != "" { + m["quantity"] = s.quantity + } + if s.positionSide != nil { + m["positionSide"] = *s.positionSide + } + if s.timeInForce != nil { + m["timeInForce"] = *s.timeInForce + } + if s.reduceOnly != nil { + m["reduceOnly"] = *s.reduceOnly + } + if s.price != nil { + m["price"] = *s.price + } + if s.newClientOrderID != nil { + m["newClientOrderId"] = *s.newClientOrderID + } + if s.stopPrice != nil { + m["stopPrice"] = *s.stopPrice + } + if s.workingType != nil { + m["workingType"] = *s.workingType + } + if s.priceProtect != nil { + m["priceProtect"] = *s.priceProtect + } + if s.activationPrice != nil { + m["activationPrice"] = *s.activationPrice + } + if s.callbackRate != nil { + m["callbackRate"] = *s.callbackRate + } + if s.closePosition != nil { + m["closePosition"] = *s.closePosition + } + + return m +} + +// Do - sends 'order.place' request +func (s *OrderPlaceWsService) Do(requestID string, request *OrderPlaceWsRequest) error { + rawData, err := createWsRequest(requestID, s.c, OrderPlaceWsApiMethod, request.buildParams()) + if err != nil { + return err + } + + if err := s.c.Write(requestID, rawData); err != nil { + return err + } + + return nil +} + +// SyncDo - sends 'order.place' request and receives response +func (s *OrderPlaceWsService) SyncDo(requestID string, request *OrderPlaceWsRequest) (*CreateOrderWsResponse, error) { + rawData, err := createWsRequest(requestID, s.c, OrderPlaceWsApiMethod, request.buildParams()) + if err != nil { + return nil, err + } + + response, err := s.c.WriteSync(requestID, rawData, WriteSyncWsTimeout) + if err != nil { + return nil, err + } + + createOrderWsResponse := &CreateOrderWsResponse{} + if err := json.Unmarshal(response, createOrderWsResponse); err != nil { + return nil, err + } + + return createOrderWsResponse, nil +} + +// ReceiveAllDataBeforeStop waits until all responses will be received from websocket until timeout expired +func (s *OrderPlaceWsService) ReceiveAllDataBeforeStop(timeout time.Duration) { + s.c.Wait(timeout) +} + +// GetReadChannel returns channel with API response data (including API errors) +func (s *OrderPlaceWsService) GetReadChannel() <-chan []byte { + return s.c.GetReadChannel() +} + +// GetReadErrorChannel returns channel with errors which are occurred while reading websocket connection +func (s *OrderPlaceWsService) GetReadErrorChannel() <-chan error { + return s.c.GetReadErrorChannel() +} + +// GetReconnectCount returns count of reconnect attempts by client +func (s *OrderPlaceWsService) GetReconnectCount() int64 { + return s.c.GetReconnectCount() +} diff --git a/v2/futures/order_place_service_ws_test.go b/v2/futures/order_place_service_ws_test.go new file mode 100644 index 00000000..0e425568 --- /dev/null +++ b/v2/futures/order_place_service_ws_test.go @@ -0,0 +1,205 @@ +package futures + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/adshao/go-binance/v2/futures/mock" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/suite" +) + +func (s *orderPlaceServiceWsTestSuite) SetupTest() { + s.apiKey = "dummyApiKey" + s.secretKey = "dummySecretKey" + s.signedKey = "HMAC" + s.timeOffset = 0 + + s.requestID = "e2a85d9f-07a5-4f94-8d5f-789dc3deb098" + + s.symbol = "BTCUSDT" + s.side = SideTypeSell + s.orderType = OrderTypeLimit + s.timeInForce = TimeInForceTypeGTC + s.quantity = "0.1001" + s.price = "50000" + s.newClientOrderID = "testOrder" + + s.ctrl = gomock.NewController(s.T()) + s.client = mock.NewMockwsClient(s.ctrl) + + s.orderPlace = &OrderPlaceWsService{ + c: s.client, + } + + s.orderPlaceRequest = NewOrderPlaceWsRequest(). + Symbol(s.symbol). + Side(s.side). + Type(s.orderType). + TimeInForce(s.timeInForce). + Quantity(s.quantity). + Price(s.price). + NewClientOrderID(s.newClientOrderID) +} + +func (s *orderPlaceServiceWsTestSuite) TearDownTest() { + s.ctrl.Finish() +} + +type orderPlaceServiceWsTestSuite struct { + suite.Suite + apiKey string + secretKey string + signedKey string + timeOffset int64 + + ctrl *gomock.Controller + client *mock.MockwsClient + + requestID string + symbol string + side SideType + orderType OrderType + timeInForce TimeInForceType + quantity string + price string + newClientOrderID string + + orderPlace *OrderPlaceWsService + orderPlaceRequest *OrderPlaceWsRequest +} + +func TestOrderPlaceServiceWsPlace(t *testing.T) { + suite.Run(t, new(orderPlaceServiceWsTestSuite)) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlace() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).AnyTimes() + + err := s.orderPlace.Do(s.requestID, s.orderPlaceRequest) + s.NoError(err) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlace_EmptyRequestID() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(gomock.Any(), gomock.Any()).Return(nil).Times(0) + + err := s.orderPlace.Do("", s.orderPlaceRequest) + s.ErrorIs(err, ErrorRequestIDNotSet) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlace_EmptyApiKey() { + s.expectCalls("", s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).Times(0) + + err := s.orderPlace.Do(s.requestID, s.orderPlaceRequest) + s.ErrorIs(err, ErrorApiKeyIsNotSet) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlace_EmptySecretKey() { + s.expectCalls(s.apiKey, "", s.signedKey, s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).Times(0) + + err := s.orderPlace.Do(s.requestID, s.orderPlaceRequest) + s.ErrorIs(err, ErrorSecretKeyIsNotSet) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlace_EmptySignKeyType() { + s.expectCalls(s.apiKey, s.secretKey, "", s.timeOffset) + + s.client.EXPECT().Write(s.requestID, gomock.Any()).Return(nil).Times(0) + + err := s.orderPlace.Do(s.requestID, s.orderPlaceRequest) + s.Error(err) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlaceSync() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + orderPlaceResponse := CreateOrderWsResponse{ + Id: s.requestID, + Status: 200, + Result: CreateOrderResult{ + CreateOrderResponse{ + Symbol: s.symbol, + OrderID: 0, + ClientOrderID: s.newClientOrderID, + Price: s.price, + TimeInForce: s.timeInForce, + Type: s.orderType, + Side: s.side, + }, + }, + } + + rawResponseData, err := json.Marshal(orderPlaceResponse) + s.NoError(err) + + s.client.EXPECT().WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(rawResponseData, nil).Times(1) + + req := s.orderPlaceRequest + response, err := s.orderPlace.SyncDo(s.requestID, req) + s.Require().NoError(err) + s.Equal(*req.newClientOrderID, response.Result.ClientOrderID) + s.Equal(req.symbol, response.Result.Symbol) + s.Equal(req.orderType, response.Result.Type) + s.Equal(*req.price, response.Result.Price) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlaceSync_EmptyRequestID() { + s.expectCalls(s.apiKey, s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT(). + WriteSync(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + req := s.orderPlaceRequest + response, err := s.orderPlace.SyncDo("", req) + s.Nil(response) + s.ErrorIs(err, ErrorRequestIDNotSet) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlaceSync_EmptyApiKey() { + s.expectCalls("", s.secretKey, s.signedKey, s.timeOffset) + + s.client.EXPECT(). + WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + response, err := s.orderPlace.SyncDo(s.requestID, s.orderPlaceRequest) + s.Nil(response) + s.ErrorIs(err, ErrorApiKeyIsNotSet) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlaceSync_EmptySecretKey() { + s.expectCalls(s.apiKey, "", s.signedKey, s.timeOffset) + + s.client.EXPECT(). + WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + response, err := s.orderPlace.SyncDo(s.requestID, s.orderPlaceRequest) + s.Nil(response) + s.ErrorIs(err, ErrorSecretKeyIsNotSet) +} + +func (s *orderPlaceServiceWsTestSuite) TestOrderPlaceSync_EmptySignKeyType() { + s.expectCalls(s.apiKey, s.secretKey, "", s.timeOffset) + + s.client.EXPECT(). + WriteSync(s.requestID, gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("write sync: error")).Times(0) + + response, err := s.orderPlace.SyncDo(s.requestID, s.orderPlaceRequest) + s.Nil(response) + s.Error(err) +} + +func (s *orderPlaceServiceWsTestSuite) expectCalls(apiKey, secretKey, signKeyType string, timeOffset int64) { + s.client.EXPECT().GetApiKey().Return(apiKey).AnyTimes() + s.client.EXPECT().GetSecretKey().Return(secretKey).AnyTimes() + s.client.EXPECT().GetKeyType().Return(signKeyType).AnyTimes() + s.client.EXPECT().GetTimeOffset().Return(timeOffset).AnyTimes() +} diff --git a/v2/futures/request.go b/v2/futures/request.go index 396dee6a..0d218c21 100644 --- a/v2/futures/request.go +++ b/v2/futures/request.go @@ -17,6 +17,15 @@ const ( type params map[string]interface{} +// encode encodes the parameters to a URL encoded string +func (p *params) encode() string { + queryValues := url.Values{} + for key, value := range *p { + queryValues.Add(key, fmt.Sprintf("%v", value)) + } + return queryValues.Encode() +} + // request define an API request type request struct { method string diff --git a/v2/futures/websocket.go b/v2/futures/websocket.go index 77519a8b..b379380f 100644 --- a/v2/futures/websocket.go +++ b/v2/futures/websocket.go @@ -108,3 +108,27 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { } }() } + +var WsGetReadWriteConnection = func(cfg *WsConfig) (*websocket.Conn, error) { + proxy := http.ProxyFromEnvironment + if cfg.Proxy != nil { + u, err := url.Parse(*cfg.Proxy) + if err != nil { + return nil, err + } + proxy = http.ProxyURL(u) + } + + Dialer := websocket.Dialer{ + Proxy: proxy, + HandshakeTimeout: 45 * time.Second, + EnableCompression: false, + } + + c, _, err := Dialer.Dial(cfg.Endpoint, nil) + if err != nil { + return nil, err + } + + return c, nil +} diff --git a/v2/futures/websocket_api_types.go b/v2/futures/websocket_api_types.go new file mode 100644 index 00000000..64432405 --- /dev/null +++ b/v2/futures/websocket_api_types.go @@ -0,0 +1,85 @@ +package futures + +import ( + "encoding/json" + "errors" + "time" + + "github.com/adshao/go-binance/v2/common" +) + +// WsApiMethodType define method name for websocket API +type WsApiMethodType string + +// WsApiRequest define common websocket API request +type WsApiRequest struct { + Id string `json:"id"` + Method WsApiMethodType `json:"method"` + Params params `json:"params"` +} + +const ( + // apiKey define key for websocket API parameters + apiKey = "apiKey" + + // OrderPlaceWsApiMethod define method for creation order via websocket API + OrderPlaceWsApiMethod WsApiMethodType = "order.place" + + // CancelWsApiMethod define method for cancel order via websocket API + CancelWsApiMethod WsApiMethodType = "order.cancel" + + // WriteSyncWsTimeout defines timeout for WriteSync method of client_ws + WriteSyncWsTimeout = 5 * time.Second +) + +var ( + // ErrorRequestIDNotSet defines that request ID is not set + ErrorRequestIDNotSet = errors.New("ws service: request id is not set") + + // ErrorApiKeyIsNotSet defines that ApiKey is not set + ErrorApiKeyIsNotSet = errors.New("ws service: api key is not set") + + // ErrorSecretKeyIsNotSet defines that SecretKey is not set + ErrorSecretKeyIsNotSet = errors.New("ws service: secret key is not set") +) + +// createWsRequest creates signed ws request +func createWsRequest(requestID string, client wsClient, method WsApiMethodType, params params) ([]byte, error) { + if requestID == "" { + return nil, ErrorRequestIDNotSet + } + + if client.GetApiKey() == "" { + return nil, ErrorApiKeyIsNotSet + } + + if client.GetSecretKey() == "" { + return nil, ErrorSecretKeyIsNotSet + } + + params[apiKey] = client.GetApiKey() + params[timestampKey] = currentTimestamp() - client.GetTimeOffset() + + sf, err := common.SignFunc(client.GetKeyType()) + if err != nil { + return nil, err + } + signature, err := sf(client.GetSecretKey(), params.encode()) + if err != nil { + return nil, err + } + params[signatureKey] = signature + + req := WsApiRequest{ + Id: requestID, + Method: method, + Params: params, + } + + rawData, err := json.Marshal(req) + if err != nil { + return nil, err + } + + return rawData, nil +} diff --git a/v2/futures/websocket_service.go b/v2/futures/websocket_service.go index d488bed6..38a7fdf8 100644 --- a/v2/futures/websocket_service.go +++ b/v2/futures/websocket_service.go @@ -6,7 +6,7 @@ import ( "fmt" "strings" "time" - + "github.com/gorilla/websocket" "github.com/bitly/go-simplejson" ) @@ -16,6 +16,9 @@ const ( baseWsTestnetUrl = "wss://stream.binancefuture.com/ws" baseCombinedMainURL = "wss://fstream.binance.com/stream?streams=" baseCombinedTestnetURL = "wss://stream.binancefuture.com/stream?streams=" + BaseWsApiMainURL = "wss://ws-fapi.binance.com/ws-fapi/v1" + BaseWsApiTestnetURL = "wss://testnet.binancefuture.com/ws-fapi/v1" + localhostWsApiURL = "ws://localhost:8080/ws" ) var ( @@ -26,6 +29,11 @@ var ( // UseTestnet switch all the WS streams from production to the testnet UseTestnet = false ProxyUrl = "" + // WebsocketTimeoutReadWriteConnection is an interval for sending ping/pong messages if WebsocketKeepalive is enabled + // using for websocket API (read/write) + WebsocketTimeoutReadWriteConnection = time.Second * 10 + // useLocalhost switch all the WS streams from production to localhost testing + useLocalhost = false ) func getWsProxyUrl() *string { @@ -1189,3 +1197,27 @@ func WsUserDataServe(listenKey string, handler WsUserDataHandler, errHandler Err } return wsServe(cfg, wsHandler, errHandler) } + +// WsApiInitReadWriteConn create and serve connection +func WsApiInitReadWriteConn() (*websocket.Conn, error) { + cfg := newWsConfig(getWsApiEndpoint()) + conn, err := WsGetReadWriteConnection(cfg) + if err != nil { + return nil, err + } + + return conn, err +} + +// getWsApiEndpoint return the base endpoint of the API WS according the UseTestnet flag +func getWsApiEndpoint() string { + if UseTestnet { + return BaseWsApiTestnetURL + } + + if useLocalhost { + return localhostWsApiURL + } + + return BaseWsApiMainURL +} diff --git a/v2/go.mod b/v2/go.mod index 4c8da8fc..ca758f6d 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -4,7 +4,10 @@ go 1.18 require ( github.com/bitly/go-simplejson v0.5.0 + github.com/golang/mock v1.6.0 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 + github.com/jpillora/backoff v1.0.0 github.com/json-iterator/go v1.1.12 github.com/stretchr/testify v1.8.1 ) diff --git a/v2/go.sum b/v2/go.sum index ccb2f234..35e6e106 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -5,9 +5,15 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= @@ -30,6 +36,29 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=