From a690d6ee623b85493e6272c727e66359954d9c52 Mon Sep 17 00:00:00 2001 From: moriya Date: Sat, 29 Apr 2023 22:39:33 +0900 Subject: [PATCH] change_for_knative_serving --- websocket/connection.go | 48 +++++++++++++++++++++++---------- websocket/connection_test.go | 51 +++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/websocket/connection.go b/websocket/connection.go index a22d79df53..4616e8380d 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -55,36 +55,56 @@ type rawConnection interface { SetReadDeadline(deadline time.Time) error Read(p []byte) (n int, err error) Write(p []byte) (n int, err error) - WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error - NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) + WriteMessage(op ws.OpCode, p []byte) error + NextReader() (ws.Header, io.Reader, error) + ReadMessage() (messageType ws.OpCode, p []byte, err error) } -type netConnExtension struct { +type NetConnExtension struct { conn net.Conn } -func (nc *netConnExtension) Read(p []byte) (n int, err error) { +func NewNetConnExtension(conn net.Conn) *NetConnExtension { + nc := &NetConnExtension{ + conn: conn, + } + return nc +} + +func (nc *NetConnExtension) Read(p []byte) (n int, err error) { return nc.conn.Read(p) } -func (nc *netConnExtension) Write(p []byte) (n int, err error) { +func (nc *NetConnExtension) Write(p []byte) (n int, err error) { return nc.conn.Write(p) } -func (nc *netConnExtension) Close() error { +func (nc *NetConnExtension) Close() error { return nc.conn.Close() } -func (nc *netConnExtension) SetReadDeadline(deadline time.Time) error { +func (nc *NetConnExtension) SetReadDeadline(deadline time.Time) error { return nc.conn.SetReadDeadline(deadline) } -func (nc *netConnExtension) WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error { - return wsutil.WriteClientMessage(w, op, p) +func (nc *NetConnExtension) WriteMessage(op ws.OpCode, p []byte) error { + return wsutil.WriteClientMessage(nc, op, p) } -func (nc *netConnExtension) NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) { - return wsutil.NextReader(r, s) +func (nc *NetConnExtension) NextReader() (ws.Header, io.Reader, error) { + return wsutil.NextReader(nc, ws.StateServerSide) +} + +func (nc *NetConnExtension) ReadMessage() (messageType ws.OpCode, p []byte, err error) { + var r io.Reader + var header ws.Header + header, r, err = nc.NextReader() + messageType = header.OpCode + if err != nil { + return messageType, nil, err + } + p, err = io.ReadAll(r) + return messageType, p, err } // ManagedConnection represents a websocket connection. @@ -172,7 +192,7 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su logger.Errorw("Websocket connection could not be established", zap.Error(err)) } } - nc := &netConnExtension{ + nc := &NetConnExtension{ conn: conn, } return nc, err @@ -321,7 +341,7 @@ func (c *ManagedConnection) read() error { c.connection.SetReadDeadline(time.Now().Add(pongTimeout)) - header, reader, err := c.connection.NextReader(c.connection, ws.StateClientSide) + header, reader, err := c.connection.NextReader() messageType := header.OpCode if err != nil { return err @@ -349,7 +369,7 @@ func (c *ManagedConnection) write(messageType ws.OpCode, body []byte) error { c.writerLock.Lock() defer c.writerLock.Unlock() - return c.connection.WriteClientMessage(c.connection, messageType, body) + return c.connection.WriteMessage(messageType, body) } // Status checks the connection status of the webhook. diff --git a/websocket/connection_test.go b/websocket/connection_test.go index 6ae27173b3..feef620c2e 100644 --- a/websocket/connection_test.go +++ b/websocket/connection_test.go @@ -35,14 +35,16 @@ import ( const propagationTimeout = 5 * time.Second type inspectableConnection struct { - closeCalls chan struct{} - setReadDeadlineCalls chan struct{} - writeClientMessageCalls chan struct{} - nextReaderCalls chan struct{} - - readFunc func() (int, error) - writeFunc func() (int, error) - nextReaderFunc func() (ws.Header, io.Reader, error) + closeCalls chan struct{} + setReadDeadlineCalls chan struct{} + writeMessageCalls chan struct{} + nextReaderCalls chan struct{} + readMessageCalls chan struct{} + + readFunc func() (int, error) + writeFunc func() (int, error) + nextReaderFunc func() (ws.Header, io.Reader, error) + readMessageFunc func() (messageType ws.OpCode, p []byte, err error) } func (c *inspectableConnection) Close() error { @@ -67,20 +69,27 @@ func (c *inspectableConnection) Write(_ []byte) (n int, err error) { return c.writeFunc() } -func (c *inspectableConnection) WriteClientMessage(_ io.Writer, _ ws.OpCode, _ []byte) error { - if c.writeClientMessageCalls != nil { - c.writeClientMessageCalls <- struct{}{} +func (c *inspectableConnection) WriteMessage(_ ws.OpCode, _ []byte) error { + if c.writeMessageCalls != nil { + c.writeMessageCalls <- struct{}{} } return nil } -func (c *inspectableConnection) NextReader(_ io.Reader, _ ws.State) (ws.Header, io.Reader, error) { +func (c *inspectableConnection) NextReader() (ws.Header, io.Reader, error) { if c.nextReaderCalls != nil { c.nextReaderCalls <- struct{}{} } return c.nextReaderFunc() } +func (c *inspectableConnection) ReadMessage() (messageType ws.OpCode, p []byte, err error) { + if c.readMessageCalls != nil { + c.readMessageCalls <- struct{}{} + } + return c.readMessageFunc() +} + // staticConnFactory returns a static connection, for example // an inspectable connection. func staticConnFactory(conn rawConnection) func() (rawConnection, error) { @@ -155,7 +164,7 @@ func TestStatusOnNoConnection(t *testing.T) { func TestSendErrorOnEncode(t *testing.T) { spy := &inspectableConnection{ - writeClientMessageCalls: make(chan struct{}, 1), + writeMessageCalls: make(chan struct{}, 1), } conn := newConnection(staticConnFactory(spy), nil) conn.connect() @@ -165,14 +174,14 @@ func TestSendErrorOnEncode(t *testing.T) { if got == nil { t.Fatal("Expected an error but got none") } - if len(spy.writeClientMessageCalls) != 0 { - t.Fatalf("Expected 'WriteClientMessage' not to be called, but was called %v times", spy.writeClientMessageCalls) + if len(spy.writeMessageCalls) != 0 { + t.Fatalf("Expected 'WriteClientMessage' not to be called, but was called %v times", spy.writeMessageCalls) } } func TestSendMessage(t *testing.T) { spy := &inspectableConnection{ - writeClientMessageCalls: make(chan struct{}, 1), + writeMessageCalls: make(chan struct{}, 1), } conn := newConnection(staticConnFactory(spy), nil) conn.connect() @@ -184,14 +193,14 @@ func TestSendMessage(t *testing.T) { if got := conn.Send("test"); got != nil { t.Fatalf("Expected no error but got: %+v", got) } - if len(spy.writeClientMessageCalls) != 1 { - t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeClientMessageCalls) + if len(spy.writeMessageCalls) != 1 { + t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeMessageCalls) } } func TestSendRawMessage(t *testing.T) { spy := &inspectableConnection{ - writeClientMessageCalls: make(chan struct{}, 1), + writeMessageCalls: make(chan struct{}, 1), } conn := newConnection(staticConnFactory(spy), nil) conn.connect() @@ -203,8 +212,8 @@ func TestSendRawMessage(t *testing.T) { if got := conn.SendRaw(ws.OpBinary, []byte("test")); got != nil { t.Fatalf("Expected no error but got: %+v", got) } - if len(spy.writeClientMessageCalls) != 1 { - t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeClientMessageCalls) + if len(spy.writeMessageCalls) != 1 { + t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeMessageCalls) } }