From fe3d21be041b55c64c4a21c3af0d0779cde57654 Mon Sep 17 00:00:00 2001 From: lxzan Date: Mon, 22 Jan 2024 21:04:39 +0800 Subject: [PATCH 1/2] add writev method --- README.md | 23 ++++++++------ README_CN.md | 2 +- internal/utils.go | 7 +++++ internal/utils_test.go | 30 ++++++++++++++++++ writer.go | 17 ++++++++-- writer_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 136 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index dd61d27d..f57ccee1 100755 --- a/README.md +++ b/README.md @@ -27,18 +27,20 @@ to be processed in a non-blocking way. - Simplicity and Ease of Use - - **User-Friendly**: Simple and clear `WebSocket` Event API design makes server-client interaction easy. - - **Code Efficiency**: Minimizes the amount of code needed to implement complex WebSocket solutions. + - **User-Friendly**: Simple and clear `WebSocket` Event API design makes server-client interaction easy. + - **Code Efficiency**: Minimizes the amount of code needed to implement complex WebSocket solutions. - High-Performance - - **High IOPS Low Latency**: Designed for rapid data transmission and reception, ideal for time-sensitive - applications. - - **Low Memory Usage**: Highly optimized memory multiplexing system to minimize memory usage and reduce your cost of ownership. + - **High IOPS Low Latency**: Designed for rapid data transmission and reception, ideal for time-sensitive + applications. + - **Low Memory Usage**: Highly optimized memory multiplexing system to minimize memory usage and reduce your cost of + ownership. - Reliability and Stability - - **Robust Error Handling**: Advanced mechanisms to manage and mitigate errors, ensuring continuous operation. - - **Well-Developed Test Cases**: Passed all `Autobahn` test cases, fully compliant with `RFC 7692`. 99% unit test coverage, covering almost all conditional branches. + - **Robust Error Handling**: Advanced mechanisms to manage and mitigate errors, ensuring continuous operation. + - **Well-Developed Test Cases**: Passed all `Autobahn` test cases, fully compliant with `RFC 7692`. Unit test + coverage is almost 100%, covering all conditional branches. ### Benchmark @@ -317,9 +319,12 @@ func WriteWithTimeout(socket *gws.Conn, p []byte, timeout time.Duration) error { #### Pub / Sub -Use the event_emitter package to implement the publish-subscribe model. Wrap `gws.Conn` in a structure and implement the GetSubscriberID method to get the subscription ID, which must be unique. The subscription ID is used to identify the subscriber, who can only receive messages on the subject of his subscription. +Use the event_emitter package to implement the publish-subscribe model. Wrap `gws.Conn` in a structure and implement the +GetSubscriberID method to get the subscription ID, which must be unique. The subscription ID is used to identify the +subscriber, who can only receive messages on the subject of his subscription. -This example is useful for building chat rooms or push messages using gws. This means that a user can subscribe to one or more topics via websocket, and when a message is posted to that topic, all subscribers will receive the message. +This example is useful for building chat rooms or push messages using gws. This means that a user can subscribe to one +or more topics via websocket, and when a message is posted to that topic, all subscribers will receive the message. ```go package main diff --git a/README_CN.md b/README_CN.md index 2c5c3713..34241af4 100755 --- a/README_CN.md +++ b/README_CN.md @@ -30,7 +30,7 @@ GWS(Go WebSocket)是一个用 Go 编写的非常简单、快速、可靠且 - 稳定可靠 - **健壮的错误处理**: 管理和减少错误的先进机制,确保持续运行. - - **完善的测试用例**: 通过了所有 `Autobahn` 测试用例, 符合 `RFC 7692` 标准. 单元测试覆盖率达到 99%, 几乎覆盖所有条件分支. + - **完善的测试用例**: 通过了所有 `Autobahn` 测试用例, 符合 `RFC 7692` 标准. 单元测试覆盖率几乎达到 100%, 覆盖所有条件分支. ### 基准测试 diff --git a/internal/utils.go b/internal/utils.go index 18d10b93..4ac09f7e 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -259,3 +259,10 @@ func CheckErrors(errs ...error) error { } return nil } + +func Reduce[T any, S any](arr []T, initialValue S, reducer func(s S, i int, v T) S) S { + for index, value := range arr { + initialValue = reducer(initialValue, index, value) + } + return initialValue +} diff --git a/internal/utils_test.go b/internal/utils_test.go index 4e3dac7f..d531fbb7 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -280,3 +280,33 @@ func TestCheckErrors(t *testing.T) { assert.Error(t, CheckErrors(err0, err1, err2)) assert.True(t, errors.Is(CheckErrors(err0, err1, err2), err2)) } + +func TestReduce(t *testing.T) { + t.Run("", func(t *testing.T) { + var arr = []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + var sum = Reduce(arr, 0, func(summarize int, i int, item int) int { + return summarize + item + }) + assert.Equal(t, sum, 55) + }) + + t.Run("", func(t *testing.T) { + var arr []int + var sum = Reduce(arr, 0, func(summarize int, i int, item int) int { + return summarize + item + }) + assert.Equal(t, sum, 0) + }) + + t.Run("", func(t *testing.T) { + var payloads = [][]byte{ + AlphabetNumeric.Generate(10), + AlphabetNumeric.Generate(20), + AlphabetNumeric.Generate(30), + } + var n = Reduce(payloads, 0, func(s int, i int, v []byte) int { + return s + len(v) + }) + assert.Equal(t, n, 60) + }) +} diff --git a/writer.go b/writer.go index 17b6abc2..5ff0cf49 100644 --- a/writer.go +++ b/writer.go @@ -49,14 +49,25 @@ func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { return err } +// WriteV 批量写入文本/二进制消息, 文本消息应该使用UTF8编码 +// writes batch text/binary messages, text messages should be encoded in UTF8. +func (c *Conn) WriteV(opcode Opcode, payloads ...[]byte) error { + var n = internal.Reduce(payloads, 0, func(s int, i int, v []byte) int { return s + len(v) }) + var buf = binaryPool.Get(n) + for _, item := range payloads { + buf.Write(item) + } + var err = c.WriteMessage(opcode, buf.Bytes()) + binaryPool.Put(buf) + return err +} + // WriteAsync 异步写 // 异步非阻塞地将消息写入到任务队列, 收到回调后才允许回收payload内存 // Asynchronously and non-blockingly write the message to the task queue, allowing the payload memory to be reclaimed only after a callback is received. func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) { c.writeQueue.Push(func() { - var err = c.doWrite(opcode, payload) - c.emitError(err) - if callback != nil { + if err := c.WriteMessage(opcode, payload); callback != nil { callback(err) } }) diff --git a/writer_test.go b/writer_test.go index cb5247d1..16b4b05a 100644 --- a/writer_test.go +++ b/writer_test.go @@ -328,3 +328,73 @@ func TestRecovery(t *testing.T) { as.NoError(client.WriteString("hi")) time.Sleep(100 * time.Millisecond) } + +func TestConn_WriteV(t *testing.T) { + t.Run("", func(t *testing.T) { + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{} + var clientOption = &ClientOption{} + var wg = &sync.WaitGroup{} + wg.Add(1) + + serverHandler.onMessage = func(socket *Conn, message *Message) { + if bytes.Equal(message.Bytes(), []byte("hello, world!")) { + wg.Done() + } + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + var err = client.WriteV(OpcodeText, [][]byte{ + []byte("he"), + []byte("llo"), + []byte(", world!"), + }...) + assert.NoError(t, err) + wg.Wait() + }) + + t.Run("", func(t *testing.T) { + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{ + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + Threshold: 1, + }, + } + var clientOption = &ClientOption{ + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + Threshold: 1, + }, + } + var wg = &sync.WaitGroup{} + wg.Add(1) + + serverHandler.onMessage = func(socket *Conn, message *Message) { + if bytes.Equal(message.Bytes(), []byte("hello, world!")) { + wg.Done() + } + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + var err = client.WriteV(OpcodeText, [][]byte{ + []byte("he"), + []byte("llo"), + []byte(", world!"), + }...) + assert.NoError(t, err) + wg.Wait() + }) +} From 8efe3c8e1d3f1f0a0f7e0d72566b1c76a0a31e66 Mon Sep 17 00:00:00 2001 From: lxzan Date: Tue, 23 Jan 2024 20:22:50 +0800 Subject: [PATCH 2/2] optimize writev method --- README.md | 1 - README_CN.md | 1 - benchmark_test.go | 4 +- compress.go | 19 +++++---- compress_test.go | 77 ++++++++++++++++++++++++++++++++++++ conn.go | 17 +++----- internal/io.go | 85 +++++++++++++++++++++++++++++++++++++++ internal/io_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++ internal/utils.go | 20 ---------- internal/utils_test.go | 54 ------------------------- reader_test.go | 8 ++-- writer.go | 57 +++++++++++++------------- writer_test.go | 34 +++++++++++++++- 13 files changed, 336 insertions(+), 131 deletions(-) create mode 100644 internal/io.go create mode 100644 internal/io_test.go diff --git a/README.md b/README.md index f57ccee1..010a7ad5 100755 --- a/README.md +++ b/README.md @@ -96,7 +96,6 @@ PASS - [x] Broadcast - [x] Dial via Proxy - [x] Context-Takeover -- [x] Zero Allocs Read / Write - [x] Passed Autobahn Test Cases [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/) - [x] Concurrent & Asynchronous Non-Blocking Write diff --git a/README_CN.md b/README_CN.md index 34241af4..1b1c45de 100755 --- a/README_CN.md +++ b/README_CN.md @@ -86,7 +86,6 @@ PASS - [x] 广播 - [x] 代理拨号 - [x] 上下文接管 -- [x] 读写过程零动态内存分配 - [x] 支持并发和异步非阻塞写入 - [x] 通过所有 Autobahn 测试用例 [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/) diff --git a/benchmark_test.go b/benchmark_test.go index 5c7fcc27..c291c582 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -68,7 +68,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) { conn: &benchConn{}, config: upgrader.option.getConfig(), } - var buf, _ = conn1.genFrame(OpcodeText, githubData, false) + var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false) var reader = bytes.NewBuffer(buf.Bytes()) var conn2 = &Conn{ @@ -98,7 +98,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) { deflater: new(deflater), } conn1.deflater.initialize(false, conn1.pd) - var buf, _ = conn1.genFrame(OpcodeText, githubData, false) + var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false) var reader = bytes.NewBuffer(buf.Bytes()) var conn2 = &Conn{ diff --git a/compress.go b/compress.go index fbc8c153..989f28cb 100644 --- a/compress.go +++ b/compress.go @@ -82,12 +82,15 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er } // Compress 压缩 -func (c *deflater) Compress(src []byte, dst *bytes.Buffer, dict []byte) error { +func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte) error { c.cpsLocker.Lock() defer c.cpsLocker.Unlock() c.cpsWriter.ResetDict(dst, dict) - if err := internal.CheckErrors(internal.WriteN(c.cpsWriter, src), c.cpsWriter.Flush()); err != nil { + if _, err := src.WriteTo(c.cpsWriter); err != nil { + return err + } + if err := c.cpsWriter.Flush(); err != nil { return err } if n := dst.Len(); n >= 4 { @@ -116,16 +119,17 @@ func (c *slideWindow) initialize(pool *internal.Pool[[]byte], windowBits int) *s return c } -func (c *slideWindow) Write(p []byte) { +func (c *slideWindow) Write(p []byte) (int, error) { if !c.enabled { - return + return 0, nil } - var n = len(p) + var total = len(p) + var n = total var length = len(c.dict) if n+length <= c.size { c.dict = append(c.dict, p...) - return + return total, nil } if m := c.size - length; m > 0 { @@ -136,11 +140,12 @@ func (c *slideWindow) Write(p []byte) { if n >= c.size { copy(c.dict, p[n-c.size:]) - return + return total, nil } copy(c.dict, c.dict[n:]) copy(c.dict[c.size-n:], p) + return total, nil } func (c *PermessageDeflate) genRequestHeader() string { diff --git a/compress_test.go b/compress_test.go index fd458ce6..8135d962 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,6 +1,8 @@ package gws import ( + "errors" + "io" "testing" "time" @@ -181,4 +183,79 @@ func TestPermessageNegotiation(t *testing.T) { assert.NoError(t, err) client.WriteMessage(OpcodeText, internal.AlphabetNumeric.Generate(1024)) }) + + t.Run("ok 5", func(t *testing.T) { + var addr = ":" + nextPort() + var serverHandler = &webSocketMocker{} + serverHandler.onMessage = func(socket *Conn, message *Message) { + println(message.Data.String()) + } + var server = NewServer(serverHandler, &ServerOption{PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + ServerMaxWindowBits: 10, + ClientMaxWindowBits: 10, + }}) + go server.Run(addr) + + time.Sleep(100 * time.Millisecond) + client, _, err := NewClient(new(BuiltinEventHandler), &ClientOption{ + Addr: "ws://localhost" + addr, + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + Threshold: 1, + }, + }) + assert.NoError(t, err) + _ = client.WriteString("he") + assert.Equal(t, string(client.cpsWindow.dict), "he") + _ = client.WriteString("llo") + assert.Equal(t, string(client.cpsWindow.dict), "hello") + _ = client.WriteV(OpcodeText, []byte(", "), []byte("world!")) + assert.Equal(t, string(client.cpsWindow.dict), "hello, world!") + }) + + t.Run("fail", func(t *testing.T) { + var addr = ":" + nextPort() + var serverHandler = &webSocketMocker{} + var server = NewServer(serverHandler, &ServerOption{PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + ServerMaxWindowBits: 10, + ClientMaxWindowBits: 10, + }}) + go server.Run(addr) + + time.Sleep(100 * time.Millisecond) + client, _, err := NewClient(new(BuiltinEventHandler), &ClientOption{ + Addr: "ws://localhost" + addr, + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + Threshold: 1, + }, + }) + assert.NoError(t, err) + err = client.doWrite(OpcodeText, new(writerTo)) + assert.Equal(t, err.Error(), "1") + }) +} + +type writerTo struct{} + +func (c *writerTo) CheckEncoding(enabled bool, opcode uint8) bool { + return true +} + +func (c *writerTo) Len() int { + return 10 +} + +func (c *writerTo) WriteTo(w io.Writer) (n int64, err error) { + return 0, errors.New("1") } diff --git a/conn.go b/conn.go index 4cd5c109..2833c18b 100644 --- a/conn.go +++ b/conn.go @@ -5,13 +5,11 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "github.com/lxzan/gws/internal" "net" "sync" "sync/atomic" "time" - "unicode/utf8" - - "github.com/lxzan/gws/internal" ) type Conn struct { @@ -91,22 +89,17 @@ func (c *Conn) getDpsDict() []byte { } func (c *Conn) isTextValid(opcode Opcode, payload []byte) bool { - if !c.config.CheckUtf8Enabled { - return true - } - switch opcode { - case OpcodeText, OpcodeCloseConnection: - return utf8.Valid(payload) - default: - return true + if c.config.CheckUtf8Enabled { + return internal.CheckEncoding(uint8(opcode), payload) } + return true } func (c *Conn) isClosed() bool { return atomic.LoadUint32(&c.closed) == 1 } func (c *Conn) close(reason []byte, err error) { c.err.Store(err) - _ = c.doWrite(OpcodeCloseConnection, reason) + _ = c.doWrite(OpcodeCloseConnection, internal.Bytes(reason)) _ = c.conn.Close() } diff --git a/internal/io.go b/internal/io.go new file mode 100644 index 00000000..9f19fd82 --- /dev/null +++ b/internal/io.go @@ -0,0 +1,85 @@ +package internal + +import ( + "io" + "unicode/utf8" +) + +// ReadN 精准地读取len(data)个字节, 否则返回错误 +func ReadN(reader io.Reader, data []byte) error { + _, err := io.ReadFull(reader, data) + return err +} + +func WriteN(writer io.Writer, content []byte) error { + _, err := writer.Write(content) + return err +} + +func CheckEncoding(opcode uint8, payload []byte) bool { + switch opcode { + case 1, 8: + return utf8.Valid(payload) + default: + return true + } +} + +type Payload interface { + io.WriterTo + Len() int + CheckEncoding(enabled bool, opcode uint8) bool +} + +type Buffers [][]byte + +func (b Buffers) CheckEncoding(enabled bool, opcode uint8) bool { + if enabled { + for i, _ := range b { + if !CheckEncoding(opcode, b[i]) { + return false + } + } + } + return true +} + +func (b Buffers) Len() int { + var sum = 0 + for i, _ := range b { + sum += len(b[i]) + } + return sum +} + +// WriteTo 可重复写 +func (b Buffers) WriteTo(w io.Writer) (int64, error) { + var n = 0 + for i, _ := range b { + x, err := w.Write(b[i]) + n += x + if err != nil { + return int64(n), err + } + } + return int64(n), nil +} + +type Bytes []byte + +func (b Bytes) CheckEncoding(enabled bool, opcode uint8) bool { + if enabled { + return CheckEncoding(opcode, b) + } + return true +} + +func (b Bytes) Len() int { + return len(b) +} + +// WriteTo 可重复写 +func (b Bytes) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(b) + return int64(n), err +} diff --git a/internal/io_test.go b/internal/io_test.go new file mode 100644 index 00000000..597b0bf7 --- /dev/null +++ b/internal/io_test.go @@ -0,0 +1,90 @@ +package internal + +import ( + "bytes" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIOUtil(t *testing.T) { + var as = assert.New(t) + + t.Run("", func(t *testing.T) { + var reader = strings.NewReader("hello") + var p = make([]byte, 5) + var err = ReadN(reader, p) + as.Nil(err) + }) + + t.Run("", func(t *testing.T) { + var writer = bytes.NewBufferString("") + var err = WriteN(writer, nil) + as.NoError(err) + }) + + t.Run("", func(t *testing.T) { + var writer = bytes.NewBufferString("") + var p = []byte("hello") + var err = WriteN(writer, p) + as.NoError(err) + }) +} + +func TestBuffers_WriteTo(t *testing.T) { + t.Run("", func(t *testing.T) { + var b = Buffers{ + []byte("he"), + []byte("llo"), + } + var w = bytes.NewBufferString("") + b.WriteTo(w) + n, _ := b.WriteTo(w) + assert.Equal(t, w.String(), "hellohello") + assert.Equal(t, n, int64(5)) + assert.Equal(t, b.Len(), 5) + assert.True(t, b.CheckEncoding(true, 1)) + }) + + t.Run("", func(t *testing.T) { + var conn, _ = net.Pipe() + _ = conn.Close() + var b = Buffers{ + []byte("he"), + []byte("llo"), + } + _, err := b.WriteTo(conn) + assert.Error(t, err) + }) + + t.Run("", func(t *testing.T) { + var str = "你好" + var b = Buffers{ + []byte("he"), + []byte(str[2:]), + } + assert.False(t, b.CheckEncoding(true, 1)) + }) +} + +func TestBytes_WriteTo(t *testing.T) { + t.Run("", func(t *testing.T) { + var b = Bytes("hello") + var w = bytes.NewBufferString("") + b.WriteTo(w) + n, _ := b.WriteTo(w) + assert.Equal(t, w.String(), "hellohello") + assert.Equal(t, n, int64(5)) + assert.Equal(t, b.Len(), 5) + }) + + t.Run("", func(t *testing.T) { + var str = "你好" + var b = Bytes(str[2:]) + assert.False(t, b.CheckEncoding(true, 1)) + assert.True(t, b.CheckEncoding(false, 1)) + assert.True(t, b.CheckEncoding(true, 2)) + }) +} diff --git a/internal/utils.go b/internal/utils.go index 4ac09f7e..7b0650fd 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -5,7 +5,6 @@ import ( "crypto/sha1" "encoding/base64" "encoding/binary" - "io" "reflect" "strings" "unsafe" @@ -88,18 +87,6 @@ func FnvNumber[T Integer](x T) uint64 { return h } -// ReadN 精准地读取len(data)个字节, 否则返回错误 -func ReadN(reader io.Reader, data []byte) error { - _, err := io.ReadFull(reader, data) - return err -} - -// WriteN 精准地写入len(data)个字节, 否则返回错误 -func WriteN(writer io.Writer, content []byte) error { - _, err := writer.Write(content) - return err -} - func MaskXOR(b []byte, key []byte) { var maskKey = binary.LittleEndian.Uint32(key) var key64 = uint64(maskKey)<<32 + uint64(maskKey) @@ -259,10 +246,3 @@ func CheckErrors(errs ...error) error { } return nil } - -func Reduce[T any, S any](arr []T, initialValue S, reducer func(s S, i int, v T) S) S { - for index, value := range arr { - initialValue = reducer(initialValue, index, value) - } - return initialValue -} diff --git a/internal/utils_test.go b/internal/utils_test.go index d531fbb7..484e4a83 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -81,30 +81,6 @@ func TestFNV64(t *testing.T) { _ = FnvNumber(1234) } -func TestIOUtil(t *testing.T) { - var as = assert.New(t) - - t.Run("", func(t *testing.T) { - var reader = strings.NewReader("hello") - var p = make([]byte, 5) - var err = ReadN(reader, p) - as.Nil(err) - }) - - t.Run("", func(t *testing.T) { - var writer = bytes.NewBufferString("") - var err = WriteN(writer, nil) - as.NoError(err) - }) - - t.Run("", func(t *testing.T) { - var writer = bytes.NewBufferString("") - var p = []byte("hello") - var err = WriteN(writer, p) - as.NoError(err) - }) -} - func TestNewMaskKey(t *testing.T) { var key = NewMaskKey() assert.Equal(t, 4, len(key)) @@ -280,33 +256,3 @@ func TestCheckErrors(t *testing.T) { assert.Error(t, CheckErrors(err0, err1, err2)) assert.True(t, errors.Is(CheckErrors(err0, err1, err2), err2)) } - -func TestReduce(t *testing.T) { - t.Run("", func(t *testing.T) { - var arr = []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - var sum = Reduce(arr, 0, func(summarize int, i int, item int) int { - return summarize + item - }) - assert.Equal(t, sum, 55) - }) - - t.Run("", func(t *testing.T) { - var arr []int - var sum = Reduce(arr, 0, func(summarize int, i int, item int) int { - return summarize + item - }) - assert.Equal(t, sum, 0) - }) - - t.Run("", func(t *testing.T) { - var payloads = [][]byte{ - AlphabetNumeric.Generate(10), - AlphabetNumeric.Generate(20), - AlphabetNumeric.Generate(30), - } - var n = Reduce(payloads, 0, func(s int, i int, v []byte) int { - return s + len(v) - }) - assert.Equal(t, n, 60) - }) -} diff --git a/reader_test.go b/reader_test.go index aa3011cb..5807e894 100644 --- a/reader_test.go +++ b/reader_test.go @@ -287,7 +287,7 @@ func TestSegments(t *testing.T) { go client.ReadLoop() go func() { - frame, _ := client.genFrame(OpcodeText, testdata, false) + frame, _ := client.genFrame(OpcodeText, internal.Bytes(testdata), false) data := frame.Bytes() data[20] = 'x' client.conn.Write(data) @@ -365,7 +365,8 @@ func TestConn_ReadMessage(t *testing.T) { var addr = ":" + nextPort() var serverHandler = &webSocketMocker{} serverHandler.onOpen = func(socket *Conn) { - frame, _ := socket.genFrame(OpcodePing, []byte("123"), false) + var p = []byte("123") + frame, _ := socket.genFrame(OpcodePing, internal.Bytes(p), false) socket.conn.Write(frame.Bytes()[:2]) socket.conn.Close() } @@ -389,7 +390,8 @@ func TestConn_ReadMessage(t *testing.T) { var addr = ":" + nextPort() var serverHandler = &webSocketMocker{} serverHandler.onOpen = func(socket *Conn) { - frame, _ := socket.genFrame(OpcodeText, []byte("123"), false) + var p = []byte("123") + frame, _ := socket.genFrame(OpcodeText, internal.Bytes(p), false) socket.conn.Write(frame.Bytes()[:2]) socket.conn.Close() } diff --git a/writer.go b/writer.go index 5ff0cf49..6f3fc019 100644 --- a/writer.go +++ b/writer.go @@ -44,7 +44,7 @@ func (c *Conn) WriteString(s string) error { // WriteMessage 写入文本/二进制消息, 文本消息应该使用UTF8编码 // Write text/binary messages, text messages should be encoded in UTF8. func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { - err := c.doWrite(opcode, payload) + err := c.doWrite(opcode, internal.Bytes(payload)) c.emitError(err) return err } @@ -52,13 +52,8 @@ func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { // WriteV 批量写入文本/二进制消息, 文本消息应该使用UTF8编码 // writes batch text/binary messages, text messages should be encoded in UTF8. func (c *Conn) WriteV(opcode Opcode, payloads ...[]byte) error { - var n = internal.Reduce(payloads, 0, func(s int, i int, v []byte) int { return s + len(v) }) - var buf = binaryPool.Get(n) - for _, item := range payloads { - buf.Write(item) - } - var err = c.WriteMessage(opcode, buf.Bytes()) - binaryPool.Put(buf) + var err = c.doWrite(opcode, internal.Buffers(payloads)) + c.emitError(err) return err } @@ -74,7 +69,7 @@ func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) { } // 执行写入逻辑, 注意妥善维护压缩字典 -func (c *Conn) doWrite(opcode Opcode, payload []byte) error { +func (c *Conn) doWrite(opcode Opcode, payload internal.Payload) error { c.mu.Lock() defer c.mu.Unlock() @@ -88,58 +83,58 @@ func (c *Conn) doWrite(opcode Opcode, payload []byte) error { } err = internal.WriteN(c.conn, frame.Bytes()) - c.cpsWindow.Write(payload) + _, _ = payload.WriteTo(&c.cpsWindow) binaryPool.Put(frame) return err } // 帧生成 -func (c *Conn) genFrame(opcode Opcode, payload []byte, isBroadcast bool) (*bytes.Buffer, error) { - // 不要删除 opcode == OpcodeText - if opcode == OpcodeText && !c.isTextValid(opcode, payload) { +func (c *Conn) genFrame(opcode Opcode, payload internal.Payload, isBroadcast bool) (*bytes.Buffer, error) { + if opcode == OpcodeText && !payload.CheckEncoding(c.config.CheckUtf8Enabled, uint8(opcode)) { return nil, internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) } - if c.pd.Enabled && opcode.isDataFrame() && len(payload) >= c.pd.Threshold { - return c.compressData(opcode, payload, isBroadcast) - } + var n = payload.Len() - var n = len(payload) if n > c.config.WriteMaxPayloadSize { return nil, internal.CloseMessageTooLarge } + var buf = binaryPool.Get(n + frameHeaderSize) + buf.Write(framePadding[0:]) + + if c.pd.Enabled && opcode.isDataFrame() && n >= c.pd.Threshold { + return c.compressData(buf, opcode, payload, isBroadcast) + } + var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, false, opcode, n) - var buf = binaryPool.Get(n + headerLength) - buf.Write(header[:headerLength]) - buf.Write(payload) + _, _ = payload.WriteTo(buf) var contents = buf.Bytes() if !c.isServer { - internal.MaskXOR(contents[headerLength:], maskBytes) + internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } + var m = frameHeaderSize - headerLength + copy(contents[m:], header[:headerLength]) + buf.Next(m) return buf, nil } -func (c *Conn) compressData(opcode Opcode, payload []byte, isBroadcast bool) (*bytes.Buffer, error) { - var buf = binaryPool.Get(len(payload) + frameHeaderSize) - buf.Write(framePadding[0:]) +func (c *Conn) compressData(buf *bytes.Buffer, opcode Opcode, payload internal.Payload, isBroadcast bool) (*bytes.Buffer, error) { err := c.deflater.Compress(payload, buf, c.getCpsDict(isBroadcast)) if err != nil { return nil, err } var contents = buf.Bytes() var payloadSize = buf.Len() - frameHeaderSize - if payloadSize > c.config.WriteMaxPayloadSize { - return nil, internal.CloseMessageTooLarge - } var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, true, opcode, payloadSize) if !c.isServer { internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } - copy(contents[frameHeaderSize-headerLength:], header[:headerLength]) - buf.Next(frameHeaderSize - headerLength) + var m = frameHeaderSize - headerLength + copy(contents[m:], header[:headerLength]) + buf.Next(m) return buf, nil } @@ -189,7 +184,9 @@ func (c *Broadcaster) Broadcast(socket *Conn) error { var idx = internal.SelectValue(socket.pd.Enabled, 1, 0) var msg = c.msgs[idx] - msg.once.Do(func() { msg.frame, msg.err = socket.genFrame(c.opcode, c.payload, true) }) + msg.once.Do(func() { + msg.frame, msg.err = socket.genFrame(c.opcode, internal.Bytes(c.payload), true) + }) if msg.err != nil { return msg.err } diff --git a/writer_test.go b/writer_test.go index 16b4b05a..31ecf40a 100644 --- a/writer_test.go +++ b/writer_test.go @@ -18,7 +18,7 @@ func testWrite(c *Conn, fin bool, opcode Opcode, payload []byte) error { var useCompress = c.pd.Enabled && opcode.isDataFrame() && len(payload) >= c.pd.Threshold if useCompress { var buf = bytes.NewBufferString("") - err := c.deflater.Compress(payload, buf, c.cpsWindow.dict) + err := c.deflater.Compress(internal.Bytes(payload), buf, c.cpsWindow.dict) if err != nil { return internal.NewError(internal.CloseInternalServerErr, err) } @@ -397,4 +397,36 @@ func TestConn_WriteV(t *testing.T) { assert.NoError(t, err) wg.Wait() }) + + t.Run("", func(t *testing.T) { + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{ + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + Threshold: 1, + }, + } + var clientOption = &ClientOption{ + CheckUtf8Enabled: true, + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + Threshold: 1, + }, + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + var err = client.WriteV(OpcodeText, [][]byte{ + []byte("山高月小"), + []byte("水落石出")[2:], + }...) + assert.Error(t, err) + }) }