diff --git a/README.md b/README.md index f7ec87f5..dd61d27d 100755 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ to be processed in a non-blocking way. - Reliability and Stability - **Robust Error Handling**: Advanced mechanisms to manage and mitigate errors, ensuring continuous operation. - - **Well-Developed Test Cases**: Passed all `Autobahn` test cases, fully compliant with `RFC 6455`. 99% unit test coverage, covering almost all conditional branches. + - **Well-Developed Test Cases**: Passed all `Autobahn` test cases, fully compliant with `RFC 7692`. 99% unit test coverage, covering almost all conditional branches. ### Benchmark @@ -95,7 +95,7 @@ PASS - [x] Dial via Proxy - [x] Context-Takeover - [x] Zero Allocs Read / Write -- [x] Passed `Autobahn` Test Cases [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/) +- [x] Passed Autobahn Test Cases [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/) - [x] Concurrent & Asynchronous Non-Blocking Write ### Attention @@ -152,7 +152,7 @@ const ( func main() { upgrader := gws.NewUpgrader(&Handler{}, &gws.ServerOption{ - ReadAsyncEnabled: true, // Parallel message processing + ParallelEnabled: true, // Parallel message processing Recovery: gws.Recovery, // Exception recovery PermessageDeflate: gws.PermessageDeflate{Enabled: true}, // Enable compression }) diff --git a/README_CN.md b/README_CN.md index 29ed5d2d..2c5c3713 100755 --- a/README_CN.md +++ b/README_CN.md @@ -30,7 +30,7 @@ GWS(Go WebSocket)是一个用 Go 编写的非常简单、快速、可靠且 - 稳定可靠 - **健壮的错误处理**: 管理和减少错误的先进机制,确保持续运行. - - **完善的测试用例**: 通过了所有 `Autobahn` 测试用例, 完全符合 `RFC 6455` 标准. 单元测试覆盖率达到 99%, 几乎覆盖所有条件分支. + - **完善的测试用例**: 通过了所有 `Autobahn` 测试用例, 符合 `RFC 7692` 标准. 单元测试覆盖率达到 99%, 几乎覆盖所有条件分支. ### 基准测试 @@ -88,7 +88,7 @@ PASS - [x] 上下文接管 - [x] 读写过程零动态内存分配 - [x] 支持并发和异步非阻塞写入 -- [x] 通过所有 `Autobahn` 测试用例 [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/) +- [x] 通过所有 Autobahn 测试用例 [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/) ### 注意 @@ -145,7 +145,7 @@ const ( func main() { upgrader := gws.NewUpgrader(&Handler{}, &gws.ServerOption{ - ReadAsyncEnabled: true, // 开启并行消息处理 + ParallelEnabled: true, // 开启并行消息处理 Recovery: gws.Recovery, // 开启异常恢复 PermessageDeflate: gws.PermessageDeflate{Enabled: true}, // 开启压缩 }) diff --git a/client.go b/client.go index 55bfcd01..7dbb26d9 100644 --- a/client.go +++ b/client.go @@ -172,15 +172,15 @@ func (c *connector) handshake() (*Conn, *http.Response, error) { closed: 0, deflater: new(deflater), writeQueue: workerQueue{maxConcurrency: 1}, - readQueue: make(channel, c.option.ReadAsyncGoLimit), + readQueue: make(channel, c.option.ParallelGolimit), } if pd.Enabled { socket.deflater.initialize(false, pd) if pd.ServerContextTakeover { - socket.dpsWindow.initialize(pd.ServerMaxWindowBits) + socket.dpsWindow.initialize(nil, pd.ServerMaxWindowBits) } if pd.ClientContextTakeover { - socket.cpsWindow.initialize(pd.ClientMaxWindowBits) + socket.cpsWindow.initialize(nil, pd.ClientMaxWindowBits) } } return socket, resp, c.conn.SetDeadline(time.Time{}) diff --git a/compress.go b/compress.go index f5ed7f0b..fbc8c153 100644 --- a/compress.go +++ b/compress.go @@ -87,10 +87,7 @@ func (c *deflater) Compress(src []byte, dst *bytes.Buffer, dict []byte) error { defer c.cpsLocker.Unlock() c.cpsWriter.ResetDict(dst, dict) - if err := internal.WriteN(c.cpsWriter, src); err != nil { - return err - } - if err := c.cpsWriter.Flush(); err != nil { + if err := internal.CheckErrors(internal.WriteN(c.cpsWriter, src), c.cpsWriter.Flush()); err != nil { return err } if n := dst.Len(); n >= 4 { @@ -108,10 +105,14 @@ type slideWindow struct { size int } -func (c *slideWindow) initialize(windowBits int) *slideWindow { +func (c *slideWindow) initialize(pool *internal.Pool[[]byte], windowBits int) *slideWindow { c.enabled = true c.size = internal.BinaryPow(windowBits) - c.dict = make([]byte, 0, c.size) + if pool != nil { + c.dict = pool.Get()[:0] + } else { + c.dict = make([]byte, 0, c.size) + } return c } @@ -127,10 +128,11 @@ func (c *slideWindow) Write(p []byte) { return } - var m = c.size - length - c.dict = append(c.dict, p[:m]...) - p = p[m:] - n = len(p) + if m := c.size - length; m > 0 { + c.dict = append(c.dict, p[:m]...) + p = p[m:] + n = len(p) + } if n >= c.size { copy(c.dict, p[n-c.size:]) diff --git a/compress_test.go b/compress_test.go index c9499d5a..fd458ce6 100644 --- a/compress_test.go +++ b/compress_test.go @@ -11,7 +11,7 @@ import ( func TestSlideWindow(t *testing.T) { t.Run("", func(t *testing.T) { - var sw = new(slideWindow).initialize(3) + var sw = new(slideWindow).initialize(nil, 3) sw.Write([]byte("abc")) assert.Equal(t, string(sw.dict), "abc") @@ -23,13 +23,35 @@ func TestSlideWindow(t *testing.T) { }) t.Run("", func(t *testing.T) { - var sw = new(slideWindow).initialize(3) + var sw = new(slideWindow).initialize(nil, 3) sw.Write([]byte("abc")) assert.Equal(t, string(sw.dict), "abc") sw.Write([]byte("defgh123456789")) assert.Equal(t, string(sw.dict), "23456789") }) + + t.Run("", func(t *testing.T) { + const size = 4 * 1024 + var sw = slideWindow{enabled: true, size: size} + for i := 0; i < 1000; i++ { + var n = internal.AlphabetNumeric.Intn(100) + sw.Write(internal.AlphabetNumeric.Generate(n)) + } + assert.Equal(t, len(sw.dict), size) + }) + + t.Run("", func(t *testing.T) { + const size = 4 * 1024 + for i := 0; i < 10; i++ { + var sw = slideWindow{enabled: true, size: size, dict: make([]byte, internal.AlphabetNumeric.Intn(size))} + for j := 0; j < 1000; j++ { + var n = internal.AlphabetNumeric.Intn(100) + sw.Write(internal.AlphabetNumeric.Generate(n)) + } + assert.Equal(t, len(sw.dict), size) + } + }) } func TestNegotiation(t *testing.T) { diff --git a/conn.go b/conn.go index 3f91a21c..4cd5c109 100644 --- a/conn.go +++ b/conn.go @@ -52,8 +52,17 @@ func (c *Conn) ReadLoop() { // 回收资源 if c.isServer { c.br.Reset(nil) - c.config.readerPool.Put(c.br) + c.config.brPool.Put(c.br) c.br = nil + + if c.cpsWindow.enabled { + c.config.cswPool.Put(c.cpsWindow.dict) + c.cpsWindow.dict = nil + } + if c.dpsWindow.enabled { + c.config.dswPool.Put(c.dpsWindow.dict) + c.dpsWindow.dict = nil + } } } diff --git a/examples/autobahn/server/main.go b/examples/autobahn/server/main.go index 5381675b..b819d263 100644 --- a/examples/autobahn/server/main.go +++ b/examples/autobahn/server/main.go @@ -18,7 +18,7 @@ func main() { }) s2 := gws.NewServer(&Handler{Sync: false}, &gws.ServerOption{ - ReadAsyncEnabled: true, + ParallelEnabled: true, PermessageDeflate: gws.PermessageDeflate{ Enabled: true, ServerContextTakeover: true, @@ -39,7 +39,7 @@ func main() { }) s4 := gws.NewServer(&Handler{Sync: false}, &gws.ServerOption{ - ReadAsyncEnabled: true, + ParallelEnabled: true, PermessageDeflate: gws.PermessageDeflate{ Enabled: true, ServerContextTakeover: false, diff --git a/init.go b/init.go index 4e07d55e..48f127f7 100644 --- a/init.go +++ b/init.go @@ -6,5 +6,4 @@ var ( framePadding = frameHeader{} // 帧头填充物 binaryPool = internal.NewBufferPool() // 缓冲池 defaultLogger = new(stdLogger) // 默认日志工具 - callbackFunc = func(err error) {} // 回调函数 ) diff --git a/internal/deque_test.go b/internal/deque_test.go index 66f49f4f..57492341 100644 --- a/internal/deque_test.go +++ b/internal/deque_test.go @@ -2,10 +2,11 @@ package internal import ( "container/list" - "github.com/stretchr/testify/assert" "math/rand" "testing" "unsafe" + + "github.com/stretchr/testify/assert" ) type Ordered interface { @@ -463,3 +464,11 @@ func TestDeque_Clone(t *testing.T) { assert.NotEqual(t, addr, addr1) assert.Equal(t, addr, addr2) } + +func TestDeque_PushFront(t *testing.T) { + var q Deque[int] + q.PushFront(1) + q.PushFront(3) + q.PushFront(5) + assert.Equal(t, q.PopFront(), 5) +} diff --git a/internal/utils.go b/internal/utils.go index b67cbeb4..18d10b93 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -232,6 +232,13 @@ func Min[T int | int64](a, b T) T { return b } +func Max[T int | int64](a, b T) T { + if a > b { + return a + } + return b +} + func IsSameSlice[T comparable](a, b []T) bool { if len(a) != len(b) { return false @@ -243,3 +250,12 @@ func IsSameSlice[T comparable](a, b []T) bool { } return true } + +func CheckErrors(errs ...error) error { + for _, item := range errs { + if item != nil { + return item + } + } + return nil +} diff --git a/internal/utils_test.go b/internal/utils_test.go index 505f4f67..4e3dac7f 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "hash/fnv" "io" "net/http" @@ -249,6 +250,11 @@ func TestMin(t *testing.T) { assert.Equal(t, Min(4, 3), 3) } +func TestMax(t *testing.T) { + assert.Equal(t, Max(1, 2), 2) + assert.Equal(t, Max(4, 3), 4) +} + func TestIsSameSlice(t *testing.T) { assert.True(t, IsSameSlice( []int{1, 2, 3}, @@ -265,3 +271,12 @@ func TestIsSameSlice(t *testing.T) { []int{1, 2, 4}, )) } + +func TestCheckErrors(t *testing.T) { + var err0 error + var err1 error + var err2 = errors.New("1") + assert.NoError(t, CheckErrors(err0, err1)) + assert.Error(t, CheckErrors(err0, err1, err2)) + assert.True(t, errors.Is(CheckErrors(err0, err1, err2), err2)) +} diff --git a/option.go b/option.go index 1486654c..1a2eb2ea 100644 --- a/option.go +++ b/option.go @@ -12,7 +12,7 @@ import ( ) const ( - defaultReadAsyncGoLimit = 8 + defaultParallelGolimit = 8 defaultCompressLevel = flate.BestSpeed defaultReadMaxPayloadSize = 16 * 1024 * 1024 defaultWriteMaxPayloadSize = 16 * 1024 * 1024 @@ -71,15 +71,22 @@ type ( } Config struct { - readerPool *internal.Pool[*bufio.Reader] + // bufio.Reader内存池 + brPool *internal.Pool[*bufio.Reader] - // 是否开启异步读, 开启的话会并行调用OnMessage - // Whether to enable asynchronous reading, if enabled OnMessage will be called in parallel - ReadAsyncEnabled bool + // 压缩器滑动窗口内存池 + cswPool *internal.Pool[[]byte] - // 异步读的最大并行协程数量 - // Maximum number of parallel concurrent processes for asynchronous reads - ReadAsyncGoLimit int + // 解压器滑动窗口内存池 + dswPool *internal.Pool[[]byte] + + // 是否开启并行消息处理 + // Whether to enable parallel message processing + ParallelEnabled bool + + // (单个连接)用于并行消息处理的协程数量限制 + // Limit on the number of concurrent goroutine used for parallel message processing (single connection) + ParallelGolimit int // 最大读取的消息内容长度 // Maximum read message content length @@ -118,8 +125,8 @@ type ( WriteBufferSize int PermessageDeflate PermessageDeflate - ReadAsyncEnabled bool - ReadAsyncGoLimit int + ParallelEnabled bool + ParallelGolimit int ReadMaxPayloadSize int ReadBufferSize int WriteMaxPayloadSize int @@ -168,8 +175,8 @@ func initServerOption(c *ServerOption) *ServerOption { if c.ReadMaxPayloadSize <= 0 { c.ReadMaxPayloadSize = defaultReadMaxPayloadSize } - if c.ReadAsyncGoLimit <= 0 { - c.ReadAsyncGoLimit = defaultReadAsyncGoLimit + if c.ParallelGolimit <= 0 { + c.ParallelGolimit = defaultParallelGolimit } if c.ReadBufferSize <= 0 { c.ReadBufferSize = defaultReadBufferSize @@ -221,9 +228,8 @@ func initServerOption(c *ServerOption) *ServerOption { c.deleteProtectedHeaders() c.config = &Config{ - readerPool: internal.NewPool(func() *bufio.Reader { return bufio.NewReaderSize(nil, c.ReadBufferSize) }), - ReadAsyncEnabled: c.ReadAsyncEnabled, - ReadAsyncGoLimit: c.ReadAsyncGoLimit, + ParallelEnabled: c.ParallelEnabled, + ParallelGolimit: c.ParallelGolimit, ReadMaxPayloadSize: c.ReadMaxPayloadSize, ReadBufferSize: c.ReadBufferSize, WriteMaxPayloadSize: c.WriteMaxPayloadSize, @@ -231,6 +237,24 @@ func initServerOption(c *ServerOption) *ServerOption { CheckUtf8Enabled: c.CheckUtf8Enabled, Recovery: c.Recovery, Logger: c.Logger, + brPool: internal.NewPool(func() *bufio.Reader { + return bufio.NewReaderSize(nil, c.ReadBufferSize) + }), + } + + if c.PermessageDeflate.Enabled { + if c.PermessageDeflate.ServerContextTakeover { + windowSize := internal.BinaryPow(c.PermessageDeflate.ServerMaxWindowBits) + c.config.cswPool = internal.NewPool[[]byte](func() []byte { + return make([]byte, 0, windowSize) + }) + } + if c.PermessageDeflate.ClientContextTakeover { + windowSize := internal.BinaryPow(c.PermessageDeflate.ClientMaxWindowBits) + c.config.dswPool = internal.NewPool[[]byte](func() []byte { + return make([]byte, 0, windowSize) + }) + } } return c @@ -245,8 +269,8 @@ type ClientOption struct { WriteBufferSize int PermessageDeflate PermessageDeflate - ReadAsyncEnabled bool - ReadAsyncGoLimit int + ParallelEnabled bool + ParallelGolimit int ReadMaxPayloadSize int ReadBufferSize int WriteMaxPayloadSize int @@ -290,8 +314,8 @@ func initClientOption(c *ClientOption) *ClientOption { if c.ReadMaxPayloadSize <= 0 { c.ReadMaxPayloadSize = defaultReadMaxPayloadSize } - if c.ReadAsyncGoLimit <= 0 { - c.ReadAsyncGoLimit = defaultReadAsyncGoLimit + if c.ParallelGolimit <= 0 { + c.ParallelGolimit = defaultParallelGolimit } if c.ReadBufferSize <= 0 { c.ReadBufferSize = defaultReadBufferSize @@ -340,8 +364,8 @@ func initClientOption(c *ClientOption) *ClientOption { func (c *ClientOption) getConfig() *Config { config := &Config{ - ReadAsyncEnabled: c.ReadAsyncEnabled, - ReadAsyncGoLimit: c.ReadAsyncGoLimit, + ParallelEnabled: c.ParallelEnabled, + ParallelGolimit: c.ParallelGolimit, ReadMaxPayloadSize: c.ReadMaxPayloadSize, ReadBufferSize: c.ReadBufferSize, WriteMaxPayloadSize: c.WriteMaxPayloadSize, diff --git a/option_test.go b/option_test.go index bdc7a219..53fe630e 100644 --- a/option_test.go +++ b/option_test.go @@ -12,14 +12,14 @@ import ( func validateServerOption(as *assert.Assertions, u *Upgrader) { var option = u.option var config = u.option.getConfig() - as.Equal(config.ReadAsyncEnabled, option.ReadAsyncEnabled) - as.Equal(config.ReadAsyncGoLimit, option.ReadAsyncGoLimit) + as.Equal(config.ParallelEnabled, option.ParallelEnabled) + as.Equal(config.ParallelGolimit, option.ParallelGolimit) as.Equal(config.ReadMaxPayloadSize, option.ReadMaxPayloadSize) as.Equal(config.WriteMaxPayloadSize, option.WriteMaxPayloadSize) as.Equal(config.CheckUtf8Enabled, option.CheckUtf8Enabled) as.Equal(config.ReadBufferSize, option.ReadBufferSize) as.Equal(config.WriteBufferSize, option.WriteBufferSize) - as.NotNil(config.readerPool) + as.NotNil(config.brPool) as.NotNil(config.Recovery) as.Equal(config.Logger, defaultLogger) @@ -29,14 +29,14 @@ func validateServerOption(as *assert.Assertions, u *Upgrader) { func validateClientOption(as *assert.Assertions, option *ClientOption) { var config = option.getConfig() - as.Equal(config.ReadAsyncEnabled, option.ReadAsyncEnabled) - as.Equal(config.ReadAsyncGoLimit, option.ReadAsyncGoLimit) + as.Equal(config.ParallelEnabled, option.ParallelEnabled) + as.Equal(config.ParallelGolimit, option.ParallelGolimit) as.Equal(config.ReadMaxPayloadSize, option.ReadMaxPayloadSize) as.Equal(config.WriteMaxPayloadSize, option.WriteMaxPayloadSize) as.Equal(config.CheckUtf8Enabled, option.CheckUtf8Enabled) as.Equal(config.ReadBufferSize, option.ReadBufferSize) as.Equal(config.WriteBufferSize, option.WriteBufferSize) - as.Nil(config.readerPool) + as.Nil(config.brPool) as.NotNil(config.Recovery) as.Equal(config.Logger, defaultLogger) @@ -52,16 +52,13 @@ func TestDefaultUpgrader(t *testing.T) { "Sec-Websocket-Extensions": []string{"chat"}, "X-Server": []string{"gws"}, }, - PermessageDeflate: PermessageDeflate{ - Enabled: true, - ServerContextTakeover: true, - ClientContextTakeover: true, - }, }) var config = updrader.option.getConfig() - as.Equal(false, config.ReadAsyncEnabled) + as.Nil(config.cswPool) + as.Nil(config.dswPool) + as.Equal(false, config.ParallelEnabled) as.Equal(false, config.CheckUtf8Enabled) - as.Equal(defaultReadAsyncGoLimit, config.ReadAsyncGoLimit) + as.Equal(defaultParallelGolimit, config.ParallelGolimit) as.Equal(defaultReadMaxPayloadSize, config.ReadMaxPayloadSize) as.Equal(defaultWriteMaxPayloadSize, config.WriteMaxPayloadSize) as.Equal(defaultHandshakeTimeout, updrader.option.HandshakeTimeout) @@ -74,8 +71,8 @@ func TestDefaultUpgrader(t *testing.T) { as.Nil(updrader.option.SubProtocols) as.Equal("", updrader.option.ResponseHeader.Get("Sec-Websocket-Extensions")) as.Equal("gws", updrader.option.ResponseHeader.Get("X-Server")) - as.Equal(updrader.option.PermessageDeflate.ServerMaxWindowBits, 12) - as.Equal(updrader.option.PermessageDeflate.ClientMaxWindowBits, 12) + as.Equal(updrader.option.PermessageDeflate.ServerMaxWindowBits, 0) + as.Equal(updrader.option.PermessageDeflate.ClientMaxWindowBits, 0) validateServerOption(as, updrader) } @@ -103,30 +100,41 @@ func TestCompressServerOption(t *testing.T) { t.Run("", func(t *testing.T) { var updrader = NewUpgrader(new(BuiltinEventHandler), &ServerOption{ PermessageDeflate: PermessageDeflate{ - Enabled: true, - Level: flate.BestCompression, - Threshold: 1024, + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + ServerMaxWindowBits: 10, + ClientMaxWindowBits: 12, + Level: flate.BestCompression, + Threshold: 1024, }, }) + as.Equal(updrader.option.PermessageDeflate.ServerMaxWindowBits, 10) + as.Equal(updrader.option.PermessageDeflate.ClientMaxWindowBits, 12) as.Equal(true, updrader.option.PermessageDeflate.Enabled) as.Equal(flate.BestCompression, updrader.option.PermessageDeflate.Level) as.Equal(1024, updrader.option.PermessageDeflate.Threshold) as.Equal(defaultCompressorPoolSize, updrader.option.PermessageDeflate.PoolSize) validateServerOption(as, updrader) + + as.Equal(cap(updrader.option.config.cswPool.Get()), 1024) + as.Equal(cap(updrader.option.config.dswPool.Get()), 4*1024) + as.Equal(len(updrader.option.config.cswPool.Get()), 0) + as.Equal(len(updrader.option.config.dswPool.Get()), 0) }) } func TestReadServerOption(t *testing.T) { var as = assert.New(t) var updrader = NewUpgrader(new(BuiltinEventHandler), &ServerOption{ - ReadAsyncEnabled: true, - ReadAsyncGoLimit: 4, + ParallelEnabled: true, + ParallelGolimit: 4, ReadMaxPayloadSize: 1024, HandshakeTimeout: 10 * time.Second, }) var config = updrader.option.getConfig() - as.Equal(true, config.ReadAsyncEnabled) - as.Equal(4, config.ReadAsyncGoLimit) + as.Equal(true, config.ParallelEnabled) + as.Equal(4, config.ParallelGolimit) as.Equal(1024, config.ReadMaxPayloadSize) as.Equal(10*time.Second, updrader.option.HandshakeTimeout) validateServerOption(as, updrader) @@ -138,9 +146,12 @@ func TestDefaultClientOption(t *testing.T) { NewClient(new(BuiltinEventHandler), option) var config = option.getConfig() - as.Equal(false, config.ReadAsyncEnabled) + as.Nil(config.brPool) + as.Nil(config.cswPool) + as.Nil(config.dswPool) + as.Equal(false, config.ParallelEnabled) as.Equal(false, config.CheckUtf8Enabled) - as.Equal(defaultReadAsyncGoLimit, config.ReadAsyncGoLimit) + as.Equal(defaultParallelGolimit, config.ParallelGolimit) as.Equal(defaultReadMaxPayloadSize, config.ReadMaxPayloadSize) as.Equal(defaultWriteMaxPayloadSize, config.WriteMaxPayloadSize) as.NotNil(config) diff --git a/reader.go b/reader.go index d29299be..db6259a8 100644 --- a/reader.go +++ b/reader.go @@ -155,7 +155,7 @@ func (c *Conn) emitMessage(msg *Message) (err error) { if !c.isTextValid(msg.Opcode, msg.Bytes()) { return internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) } - if c.config.ReadAsyncEnabled { + if c.config.ParallelEnabled { return c.readQueue.Go(msg, c.dispatch) } return c.dispatch(msg) diff --git a/reader_test.go b/reader_test.go index c130d442..aa3011cb 100644 --- a/reader_test.go +++ b/reader_test.go @@ -106,14 +106,14 @@ func TestRead(t *testing.T) { var serverHandler = new(webSocketMocker) var clientHandler = new(webSocketMocker) var serverOption = &ServerOption{ - ReadAsyncEnabled: true, + ParallelEnabled: true, CheckUtf8Enabled: false, ReadMaxPayloadSize: 1024 * 1024, WriteMaxPayloadSize: 16 * 1024 * 1024, PermessageDeflate: PermessageDeflate{Enabled: true}, } var clientOption = &ClientOption{ - ReadAsyncEnabled: true, + ParallelEnabled: true, PermessageDeflate: PermessageDeflate{Enabled: true, ServerContextTakeover: true, ClientContextTakeover: true}, CheckUtf8Enabled: true, ReadMaxPayloadSize: 1024 * 1024, @@ -359,3 +359,53 @@ func TestFrameHeader_Parse(t *testing.T) { assert.Error(t, err) }) } + +func TestConn_ReadMessage(t *testing.T) { + t.Run("", func(t *testing.T) { + var addr = ":" + nextPort() + var serverHandler = &webSocketMocker{} + serverHandler.onOpen = func(socket *Conn) { + frame, _ := socket.genFrame(OpcodePing, []byte("123"), false) + socket.conn.Write(frame.Bytes()[:2]) + socket.conn.Close() + } + var server = NewServer(serverHandler, nil) + go server.Run(addr) + + time.Sleep(100 * time.Millisecond) + client, _, err := NewClient(new(BuiltinEventHandler), &ClientOption{ + Addr: "ws://localhost" + addr, + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + }, + }) + assert.NoError(t, err) + client.ReadLoop() + }) + + t.Run("", func(t *testing.T) { + var addr = ":" + nextPort() + var serverHandler = &webSocketMocker{} + serverHandler.onOpen = func(socket *Conn) { + frame, _ := socket.genFrame(OpcodeText, []byte("123"), false) + socket.conn.Write(frame.Bytes()[:2]) + socket.conn.Close() + } + var server = NewServer(serverHandler, nil) + go server.Run(addr) + + time.Sleep(100 * time.Millisecond) + client, _, err := NewClient(new(BuiltinEventHandler), &ClientOption{ + Addr: "ws://localhost" + addr, + PermessageDeflate: PermessageDeflate{ + Enabled: true, + ServerContextTakeover: true, + ClientContextTakeover: true, + }, + }) + assert.NoError(t, err) + client.ReadLoop() + }) +} diff --git a/task.go b/task.go index d5bc8426..5463312f 100644 --- a/task.go +++ b/task.go @@ -1,8 +1,9 @@ package gws import ( - "github.com/lxzan/gws/internal" "sync" + + "github.com/lxzan/gws/internal" ) type ( diff --git a/task_test.go b/task_test.go index e2bc55cb..436d3a31 100644 --- a/task_test.go +++ b/task_test.go @@ -42,10 +42,10 @@ func serveWebSocket( if isServer { socket.deflater = new(deflaterPool).initialize(pd).Select() if pd.ServerContextTakeover { - socket.cpsWindow.initialize(pd.ServerMaxWindowBits) + socket.cpsWindow.initialize(config.cswPool, pd.ServerMaxWindowBits) } if pd.ClientContextTakeover { - socket.dpsWindow.initialize(pd.ClientMaxWindowBits) + socket.dpsWindow.initialize(config.dswPool, pd.ClientMaxWindowBits) } } else { socket.deflater = new(deflater).initialize(false, pd) @@ -216,11 +216,11 @@ func TestReadAsync(t *testing.T) { var clientHandler = new(webSocketMocker) var serverOption = &ServerOption{ PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 512}, - ReadAsyncEnabled: true, + ParallelEnabled: true, } var clientOption = &ClientOption{ PermessageDeflate: PermessageDeflate{Enabled: true, Threshold: 512}, - ReadAsyncEnabled: true, + ParallelEnabled: true, } server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) @@ -340,7 +340,7 @@ func TestWriteAsyncBlocking(t *testing.T) { // 第一个msg被异步协程从chan取出了,取出后阻塞在writePublic、没有后续的取出,再入defaultAsyncIOGoLimit个msg到chan里, // 则defaultAsyncIOGoLimit+2个消息会导致入chan阻塞。 // 1s后client 0开始读取,广播才会继续,这一轮对应的时间约为1s - for i := 0; i <= defaultReadAsyncGoLimit+2; i++ { + for i := 0; i <= defaultParallelGolimit+2; i++ { t0 := time.Now() for wsConn := range allConns { wsConn.WriteAsync(OpcodeBinary, []byte{0}, nil) diff --git a/types.go b/types.go index 1dde35e3..9d76d346 100644 --- a/types.go +++ b/types.go @@ -93,8 +93,8 @@ type Event interface { OnPong(socket *Conn, payload []byte) // OnMessage 消息事件 - // 如果开启了ReadAsyncEnabled, 会并行地调用OnMessage; 没有做recover处理. - // If ReadAsyncEnabled is enabled, OnMessage is called in parallel. No recover is done. + // 如果开启了ParallelEnabled, 会并行地调用OnMessage; 没有做recover处理. + // If ParallelEnabled is enabled, OnMessage is called in parallel. No recover is done. OnMessage(socket *Conn, message *Message) } diff --git a/upgrader.go b/upgrader.go index 973ae8c3..34a6886c 100644 --- a/upgrader.go +++ b/upgrader.go @@ -99,7 +99,7 @@ func (c *Upgrader) hijack(w http.ResponseWriter) (net.Conn, *bufio.Reader, error if err != nil { return nil, nil, err } - br := c.option.config.readerPool.Get() + br := c.option.config.brPool.Get() br.Reset(netConn) return netConn, br, nil } @@ -194,28 +194,29 @@ func (c *Upgrader) doUpgradeFromConn(netConn net.Conn, br *bufio.Reader, r *http return nil, err } + config := c.option.getConfig() socket := &Conn{ ss: session, isServer: true, subprotocol: rw.subprotocol, pd: pd, conn: netConn, - config: c.option.getConfig(), + config: config, br: br, continuationFrame: continuationFrame{}, fh: frameHeader{}, handler: c.eventHandler, closed: 0, writeQueue: workerQueue{maxConcurrency: 1}, - readQueue: make(channel, c.option.ReadAsyncGoLimit), + readQueue: make(channel, c.option.ParallelGolimit), } if pd.Enabled { socket.deflater = c.deflaterPool.Select() if c.option.PermessageDeflate.ServerContextTakeover { - socket.cpsWindow.initialize(c.option.PermessageDeflate.ServerMaxWindowBits) + socket.cpsWindow.initialize(config.cswPool, c.option.PermessageDeflate.ServerMaxWindowBits) } if c.option.PermessageDeflate.ClientContextTakeover { - socket.dpsWindow.initialize(c.option.PermessageDeflate.ClientMaxWindowBits) + socket.dpsWindow.initialize(config.dswPool, c.option.PermessageDeflate.ClientMaxWindowBits) } } return socket, nil @@ -296,7 +297,7 @@ func (c *Server) RunListener(listener net.Listener) error { } go func(conn net.Conn) { - br := c.option.config.readerPool.Get() + br := c.option.config.brPool.Get() br.Reset(conn) if r, err := http.ReadRequest(br); err != nil { c.OnError(conn, err) diff --git a/writer_test.go b/writer_test.go index 8d8a3c68..cb5247d1 100644 --- a/writer_test.go +++ b/writer_test.go @@ -294,7 +294,7 @@ func TestNewBroadcaster(t *testing.T) { go server.ReadLoop() go client.ReadLoop() - _ = server.conn.Close() + server.WriteClose(0, nil) var broadcaster = NewBroadcaster(OpcodeText, internal.AlphabetNumeric.Generate(16)) _ = broadcaster.Broadcast(server) wg.Wait()