From e743e9326f744cbfcb2af44d093dd25130308afd Mon Sep 17 00:00:00 2001 From: lixizan Date: Tue, 23 Apr 2024 14:24:53 +0800 Subject: [PATCH] Fix: ReadMaxPayloadSize limit may be exceeded when receiving content with very high compression rate. --- benchmark_test.go | 2 +- client.go | 2 +- compress.go | 30 ++++++++++++++++++++++++++---- task_test.go | 4 ++-- upgrader.go | 2 +- writer_test.go | 31 +++++++++++++++++++++++++++++++ 6 files changed, 62 insertions(+), 9 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index c291c582..d6039687 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -97,7 +97,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) { config: config, deflater: new(deflater), } - conn1.deflater.initialize(false, conn1.pd) + conn1.deflater.initialize(false, conn1.pd, config.ReadMaxPayloadSize) var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false) var reader = bytes.NewBuffer(buf.Bytes()) diff --git a/client.go b/client.go index fc4912cd..eca03361 100644 --- a/client.go +++ b/client.go @@ -175,7 +175,7 @@ func (c *connector) handshake() (*Conn, *http.Response, error) { readQueue: make(channel, c.option.ParallelGolimit), } if pd.Enabled { - socket.deflater.initialize(false, pd) + socket.deflater.initialize(false, pd, c.option.ReadMaxPayloadSize) if pd.ServerContextTakeover { socket.dpsWindow.initialize(nil, pd.ServerMaxWindowBits) } diff --git a/compress.go b/compress.go index 989f28cb..a167a47b 100644 --- a/compress.go +++ b/compress.go @@ -24,10 +24,10 @@ type deflaterPool struct { pool []*deflater } -func (c *deflaterPool) initialize(options PermessageDeflate) *deflaterPool { +func (c *deflaterPool) initialize(options PermessageDeflate, limit int) *deflaterPool { c.num = uint64(options.PoolSize) for i := uint64(0); i < c.num; i++ { - c.pool = append(c.pool, new(deflater).initialize(true, options)) + c.pool = append(c.pool, new(deflater).initialize(true, options, limit)) } return c } @@ -39,15 +39,19 @@ func (c *deflaterPool) Select() *deflater { type deflater struct { dpsLocker sync.Mutex + buf []byte + limit int dpsBuffer *bytes.Buffer dpsReader io.ReadCloser cpsLocker sync.Mutex cpsWriter *flate.Writer } -func (c *deflater) initialize(isServer bool, options PermessageDeflate) *deflater { +func (c *deflater) initialize(isServer bool, options PermessageDeflate, limit int) *deflater { c.dpsReader = flate.NewReader(nil) c.dpsBuffer = bytes.NewBuffer(nil) + c.buf = make([]byte, 32*1024) + c.limit = limit windowBits := internal.SelectValue(isServer, options.ServerMaxWindowBits, options.ClientMaxWindowBits) if windowBits == 15 { c.cpsWriter, _ = flate.NewWriter(nil, options.Level) @@ -73,7 +77,8 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er _, _ = src.Write(flateTail) c.resetFR(src, dict) - if _, err := c.dpsReader.(io.WriterTo).WriteTo(c.dpsBuffer); err != nil { + reader := limitReader(c.dpsReader, c.limit) + if _, err := io.CopyBuffer(c.dpsBuffer, reader, c.buf); err != nil { return nil, err } var dst = binaryPool.Get(c.dpsBuffer.Len()) @@ -223,3 +228,20 @@ func permessageNegotiation(str string) PermessageDeflate { options.ServerMaxWindowBits = internal.SelectValue(options.ServerMaxWindowBits < 8, 8, options.ServerMaxWindowBits) return options } + +func limitReader(r io.Reader, limit int) io.Reader { return &limitedReader{R: r, M: limit} } + +type limitedReader struct { + R io.Reader + N int + M int +} + +func (c *limitedReader) Read(p []byte) (n int, err error) { + n, err = c.R.Read(p) + c.N += n + if c.N > c.M { + return n, internal.CloseMessageTooLarge + } + return +} diff --git a/task_test.go b/task_test.go index 436d3a31..78fc336c 100644 --- a/task_test.go +++ b/task_test.go @@ -40,7 +40,7 @@ func serveWebSocket( } if compressEnabled { if isServer { - socket.deflater = new(deflaterPool).initialize(pd).Select() + socket.deflater = new(deflaterPool).initialize(pd, config.ReadMaxPayloadSize).Select() if pd.ServerContextTakeover { socket.cpsWindow.initialize(config.cswPool, pd.ServerMaxWindowBits) } @@ -48,7 +48,7 @@ func serveWebSocket( socket.dpsWindow.initialize(config.dswPool, pd.ClientMaxWindowBits) } } else { - socket.deflater = new(deflater).initialize(false, pd) + socket.deflater = new(deflater).initialize(false, pd, config.ReadMaxPayloadSize) } } return socket diff --git a/upgrader.go b/upgrader.go index 34a6886c..da680c62 100644 --- a/upgrader.go +++ b/upgrader.go @@ -84,7 +84,7 @@ func NewUpgrader(eventHandler Event, option *ServerOption) *Upgrader { deflaterPool: new(deflaterPool), } if u.option.PermessageDeflate.Enabled { - u.deflaterPool.initialize(u.option.PermessageDeflate) + u.deflaterPool.initialize(u.option.PermessageDeflate, option.ReadMaxPayloadSize) } return u } diff --git a/writer_test.go b/writer_test.go index 47237987..25119bb0 100644 --- a/writer_test.go +++ b/writer_test.go @@ -3,6 +3,7 @@ package gws import ( "bufio" "bytes" + "errors" "io" "net" "net/http" @@ -79,6 +80,36 @@ func TestWriteBigMessage(t *testing.T) { var err = server.WriteMessage(OpcodeText, internal.AlphabetNumeric.Generate(128)) assert.Error(t, err) }) + + t.Run("", func(t *testing.T) { + var wg = &sync.WaitGroup{} + wg.Add(1) + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + serverHandler.onClose = func(socket *Conn, err error) { + assert.True(t, errors.Is(err, internal.CloseMessageTooLarge)) + wg.Done() + } + var serverOption = &ServerOption{ + ReadMaxPayloadSize: 128, + PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 1}, + } + var clientOption = &ClientOption{ + ReadMaxPayloadSize: 128 * 1024, + PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 1}, + } + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + var buf = bytes.NewBufferString("") + for i := 0; i < 64*1024; i++ { + buf.WriteString("a") + } + var err = client.WriteMessage(OpcodeText, buf.Bytes()) + assert.NoError(t, err) + wg.Wait() + }) } func TestWriteClose(t *testing.T) {