diff --git a/examples/client/client.go b/examples/client/client.go index 72bf48b..c7b2662 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -17,7 +17,7 @@ func main() { } conn.OnClose(types.OnCloseHandler) conn.OnError(types.OnErrorHandler) - conn.OnMessage(func(c *wshelper.Connection, data wshelper.Payload) { + conn.OnMessage(func(c *wshelper.Connection, mtype websocket.MessageType, data wshelper.Payload) { var p types.Message err := data.Into(&p) if err != nil { diff --git a/examples/server/server.go b/examples/server/server.go index d9586f9..ab4d6ed 100644 --- a/examples/server/server.go +++ b/examples/server/server.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net/http" + "nhooyr.io/websocket" "github.com/BOOMfinity-Developers/wshelper" "github.com/BOOMfinity-Developers/wshelper/examples/types" @@ -21,7 +22,7 @@ func main() { } conn.OnClose(types.OnCloseHandler) conn.OnError(types.OnErrorHandler) - conn.OnMessage(func(c *wshelper.Connection, data wshelper.Payload) { + conn.OnMessage(func(c *wshelper.Connection, mtype websocket.MessageType, data wshelper.Payload) { var p types.Message err := data.Into(&p) if err != nil { diff --git a/websocket.go b/websocket.go index 5c8df19..9680d9d 100644 --- a/websocket.go +++ b/websocket.go @@ -2,6 +2,7 @@ package wshelper import ( + "bytes" "context" "encoding/json" "errors" @@ -31,9 +32,15 @@ var ( } ) -type CloseHandler func(connection *Connection, code websocket.StatusCode, reason string) -type ErrorHandler func(connection *Connection, err error) -type MessageHandler func(connection *Connection, data Payload) +var ( + EmptyCloseHandler = CloseHandler(func(connection *Connection, code websocket.StatusCode, reason string) {}) + EmptyErrorHandler = ErrorHandler(func(connection *Connection, err error) {}) +) + +type CloseHandler func(conn *Connection, code websocket.StatusCode, reason string) +type ErrorHandler func(conn *Connection, err error) +type MessageHandler func(conn *Connection, mtype websocket.MessageType, data Payload) +type MessageBufferHandler func(conn *Connection, mtype websocket.MessageType, data *bytes.Buffer) type Payload []byte @@ -48,7 +55,7 @@ type Handler struct { conn *Connection id uint64 synchronous bool - run MessageHandler + run interface{} } func (h *Handler) SetAsync(enabled bool) { @@ -59,6 +66,29 @@ func (h Handler) Delete() { h.conn.removeHandler(h.id) } +func (h Handler) exec(mtype websocket.MessageType, data []byte) { + switch run := h.run.(type) { + case MessageHandler: + run(h.conn, mtype, data) + break + case MessageBufferHandler: + run(h.conn, mtype, bytes.NewBuffer(data)) + break + } +} + +func (h Handler) check() (ok bool) { + switch h.run.(type) { + case MessageHandler: + ok = true + break + case MessageBufferHandler: + ok = true + break + } + return +} + type Connection struct { onClose *CloseHandler onError *ErrorHandler @@ -67,12 +97,24 @@ type Connection struct { handlerID *atomic.Uint64 mutex sync.Mutex uuid string + closed bool } func (c *Connection) UUID() string { return c.uuid } +func (c *Connection) Close(status websocket.StatusCode, reason string) error { + if c.closed == true { + return nil + } + if err := c.WS().Close(status, reason); err != nil { + return err + } + c.closed = true + return nil +} + func (c *Connection) WS() *websocket.Conn { return c.ws } @@ -110,37 +152,52 @@ func (c *Connection) OnError(h ErrorHandler) { c.onError = &h } -func (c *Connection) OnMessage(h MessageHandler) *Handler { +func (c *Connection) newHandler() *Handler { p := new(Handler) p.id = c.handlerID.Inc() p.conn = c - p.run = h c.mutex.Lock() c.handlers[p.id] = p c.mutex.Unlock() return p } +func (c *Connection) OnMessage(h MessageHandler) *Handler { + p := c.newHandler() + p.run = h + return p +} + +func (c *Connection) OnMessageBuffer(h MessageBufferHandler) *Handler { + p := c.newHandler() + p.run = h + return p +} + func (c *Connection) loop() { for { - _, data, err := c.ws.Read(context.Background()) + if c.closed { + return + } + t, data, err := c.ws.Read(context.Background()) if err != nil { var closeErr websocket.CloseError if errors.As(err, &closeErr) && c.onClose != nil { (*c.onClose)(c, closeErr.Code, closeErr.Reason) + c.closed = true break } if c.onError != nil { (*c.onError)(c, err) } - break + continue } c.mutex.Lock() for _, h := range c.handlers { if h.synchronous { - h.run(c, data) + h.exec(t, data) } else { - go h.run(c, data) + go h.exec(t, data) } } c.mutex.Unlock()