Skip to content

Commit

Permalink
change_for_knative_serving
Browse files Browse the repository at this point in the history
  • Loading branch information
Gekko0114 committed Apr 29, 2023
1 parent 68306cb commit a690d6e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 35 deletions.
48 changes: 34 additions & 14 deletions websocket/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,56 @@ type rawConnection interface {
SetReadDeadline(deadline time.Time) error
Read(p []byte) (n int, err error)
Write(p []byte) (n int, err error)
WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error
NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error)
WriteMessage(op ws.OpCode, p []byte) error
NextReader() (ws.Header, io.Reader, error)
ReadMessage() (messageType ws.OpCode, p []byte, err error)
}

type netConnExtension struct {
type NetConnExtension struct {
conn net.Conn
}

func (nc *netConnExtension) Read(p []byte) (n int, err error) {
func NewNetConnExtension(conn net.Conn) *NetConnExtension {
nc := &NetConnExtension{
conn: conn,
}
return nc
}

func (nc *NetConnExtension) Read(p []byte) (n int, err error) {
return nc.conn.Read(p)
}

func (nc *netConnExtension) Write(p []byte) (n int, err error) {
func (nc *NetConnExtension) Write(p []byte) (n int, err error) {
return nc.conn.Write(p)
}

func (nc *netConnExtension) Close() error {
func (nc *NetConnExtension) Close() error {
return nc.conn.Close()
}

func (nc *netConnExtension) SetReadDeadline(deadline time.Time) error {
func (nc *NetConnExtension) SetReadDeadline(deadline time.Time) error {
return nc.conn.SetReadDeadline(deadline)
}

func (nc *netConnExtension) WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error {
return wsutil.WriteClientMessage(w, op, p)
func (nc *NetConnExtension) WriteMessage(op ws.OpCode, p []byte) error {
return wsutil.WriteClientMessage(nc, op, p)
}

func (nc *netConnExtension) NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
return wsutil.NextReader(r, s)
func (nc *NetConnExtension) NextReader() (ws.Header, io.Reader, error) {
return wsutil.NextReader(nc, ws.StateServerSide)
}

func (nc *NetConnExtension) ReadMessage() (messageType ws.OpCode, p []byte, err error) {
var r io.Reader
var header ws.Header
header, r, err = nc.NextReader()
messageType = header.OpCode
if err != nil {
return messageType, nil, err
}
p, err = io.ReadAll(r)
return messageType, p, err
}

// ManagedConnection represents a websocket connection.
Expand Down Expand Up @@ -172,7 +192,7 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su
logger.Errorw("Websocket connection could not be established", zap.Error(err))
}
}
nc := &netConnExtension{
nc := &NetConnExtension{
conn: conn,
}
return nc, err
Expand Down Expand Up @@ -321,7 +341,7 @@ func (c *ManagedConnection) read() error {

c.connection.SetReadDeadline(time.Now().Add(pongTimeout))

header, reader, err := c.connection.NextReader(c.connection, ws.StateClientSide)
header, reader, err := c.connection.NextReader()
messageType := header.OpCode
if err != nil {
return err
Expand Down Expand Up @@ -349,7 +369,7 @@ func (c *ManagedConnection) write(messageType ws.OpCode, body []byte) error {

c.writerLock.Lock()
defer c.writerLock.Unlock()
return c.connection.WriteClientMessage(c.connection, messageType, body)
return c.connection.WriteMessage(messageType, body)
}

// Status checks the connection status of the webhook.
Expand Down
51 changes: 30 additions & 21 deletions websocket/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ import (
const propagationTimeout = 5 * time.Second

type inspectableConnection struct {
closeCalls chan struct{}
setReadDeadlineCalls chan struct{}
writeClientMessageCalls chan struct{}
nextReaderCalls chan struct{}

readFunc func() (int, error)
writeFunc func() (int, error)
nextReaderFunc func() (ws.Header, io.Reader, error)
closeCalls chan struct{}
setReadDeadlineCalls chan struct{}
writeMessageCalls chan struct{}
nextReaderCalls chan struct{}
readMessageCalls chan struct{}

readFunc func() (int, error)
writeFunc func() (int, error)
nextReaderFunc func() (ws.Header, io.Reader, error)
readMessageFunc func() (messageType ws.OpCode, p []byte, err error)
}

func (c *inspectableConnection) Close() error {
Expand All @@ -67,20 +69,27 @@ func (c *inspectableConnection) Write(_ []byte) (n int, err error) {
return c.writeFunc()
}

func (c *inspectableConnection) WriteClientMessage(_ io.Writer, _ ws.OpCode, _ []byte) error {
if c.writeClientMessageCalls != nil {
c.writeClientMessageCalls <- struct{}{}
func (c *inspectableConnection) WriteMessage(_ ws.OpCode, _ []byte) error {
if c.writeMessageCalls != nil {
c.writeMessageCalls <- struct{}{}
}
return nil
}

func (c *inspectableConnection) NextReader(_ io.Reader, _ ws.State) (ws.Header, io.Reader, error) {
func (c *inspectableConnection) NextReader() (ws.Header, io.Reader, error) {
if c.nextReaderCalls != nil {
c.nextReaderCalls <- struct{}{}
}
return c.nextReaderFunc()
}

func (c *inspectableConnection) ReadMessage() (messageType ws.OpCode, p []byte, err error) {
if c.readMessageCalls != nil {
c.readMessageCalls <- struct{}{}
}
return c.readMessageFunc()
}

// staticConnFactory returns a static connection, for example
// an inspectable connection.
func staticConnFactory(conn rawConnection) func() (rawConnection, error) {
Expand Down Expand Up @@ -155,7 +164,7 @@ func TestStatusOnNoConnection(t *testing.T) {

func TestSendErrorOnEncode(t *testing.T) {
spy := &inspectableConnection{
writeClientMessageCalls: make(chan struct{}, 1),
writeMessageCalls: make(chan struct{}, 1),
}
conn := newConnection(staticConnFactory(spy), nil)
conn.connect()
Expand All @@ -165,14 +174,14 @@ func TestSendErrorOnEncode(t *testing.T) {
if got == nil {
t.Fatal("Expected an error but got none")
}
if len(spy.writeClientMessageCalls) != 0 {
t.Fatalf("Expected 'WriteClientMessage' not to be called, but was called %v times", spy.writeClientMessageCalls)
if len(spy.writeMessageCalls) != 0 {
t.Fatalf("Expected 'WriteClientMessage' not to be called, but was called %v times", spy.writeMessageCalls)
}
}

func TestSendMessage(t *testing.T) {
spy := &inspectableConnection{
writeClientMessageCalls: make(chan struct{}, 1),
writeMessageCalls: make(chan struct{}, 1),
}
conn := newConnection(staticConnFactory(spy), nil)
conn.connect()
Expand All @@ -184,14 +193,14 @@ func TestSendMessage(t *testing.T) {
if got := conn.Send("test"); got != nil {
t.Fatalf("Expected no error but got: %+v", got)
}
if len(spy.writeClientMessageCalls) != 1 {
t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeClientMessageCalls)
if len(spy.writeMessageCalls) != 1 {
t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeMessageCalls)
}
}

func TestSendRawMessage(t *testing.T) {
spy := &inspectableConnection{
writeClientMessageCalls: make(chan struct{}, 1),
writeMessageCalls: make(chan struct{}, 1),
}
conn := newConnection(staticConnFactory(spy), nil)
conn.connect()
Expand All @@ -203,8 +212,8 @@ func TestSendRawMessage(t *testing.T) {
if got := conn.SendRaw(ws.OpBinary, []byte("test")); got != nil {
t.Fatalf("Expected no error but got: %+v", got)
}
if len(spy.writeClientMessageCalls) != 1 {
t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeClientMessageCalls)
if len(spy.writeMessageCalls) != 1 {
t.Fatalf("Expected 'WriteClientMessage' to be called once, but was called %v times", spy.writeMessageCalls)
}
}

Expand Down

0 comments on commit a690d6e

Please sign in to comment.