From 37aaa6cc506051ca567241c47891f5548362f004 Mon Sep 17 00:00:00 2001 From: lxzan Date: Thu, 9 Nov 2023 15:40:18 +0800 Subject: [PATCH 1/2] Simplified Connection Pool Management --- .github/workflows/go.yml | 2 + benchmark_test.go | 4 +- compress.go | 8 ++-- compress_test.go | 4 +- internal/pool.go | 78 ++++++++++++++++++++++++++++---------- internal/pool_test.go | 81 ++++++++++++++++++++++++++++++++-------- reader.go | 10 ++--- reader_test.go | 2 +- recovery.go | 27 -------------- protocol.go => types.go | 28 ++++++++++++-- upgrader.go | 5 +-- writer.go | 33 ++++++++-------- 12 files changed, 181 insertions(+), 101 deletions(-) delete mode 100644 recovery.go rename protocol.go => types.go (93%) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index ee3755c2..3940232e 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -22,3 +22,5 @@ jobs: go-version: 1.18 - name: Test run: go test -v ./... + - name: Bench + run: make bench diff --git a/benchmark_test.go b/benchmark_test.go index 456a28b4..f1b08b20 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -66,7 +66,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) { conn: &benchConn{}, config: upgrader.option.getConfig(), } - var buf, _, _ = conn1.genFrame(OpcodeText, githubData) + var buf, _ = conn1.genFrame(OpcodeText, githubData) var reader = bytes.NewBuffer(buf.Bytes()) var conn2 = &Conn{ @@ -94,7 +94,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) { compressor: config.compressors.Select(), decompressor: config.decompressors.Select(), } - var buf, _, _ = conn1.genFrame(OpcodeText, githubData) + var buf, _ = conn1.genFrame(OpcodeText, githubData) var reader = bytes.NewBuffer(buf.Bytes()) var conn2 = &Conn{ diff --git a/compress.go b/compress.go index 4e633733..d6043dfa 100644 --- a/compress.go +++ b/compress.go @@ -109,16 +109,16 @@ func (c *decompressor) reset(r io.Reader) { } // Decompress 解压 -func (c *decompressor) Decompress(src *bytes.Buffer) (*bytes.Buffer, int, error) { +func (c *decompressor) Decompress(src *bytes.Buffer) (*bytes.Buffer, error) { c.Lock() defer c.Unlock() _, _ = src.Write(flateTail) c.reset(src) if _, err := c.fr.(io.WriterTo).WriteTo(c.b); err != nil { - return nil, 0, err + return nil, err } - var dst, idx = binaryPool.Get(c.b.Len()) + var dst = binaryPool.Get(c.b.Len()) _, _ = c.b.WriteTo(dst) - return dst, idx, nil + return dst, nil } diff --git a/compress_test.go b/compress_test.go index 7adef6b6..947f7517 100644 --- a/compress_test.go +++ b/compress_test.go @@ -26,7 +26,7 @@ func TestFlate(t *testing.T) { var buf = bytes.NewBufferString("") buf.Write(compressedBuf.Bytes()) - plainText, _, err := dps.Decompress(buf) + plainText, err := dps.Decompress(buf) if err != nil { as.NoError(err) return @@ -49,7 +49,7 @@ func TestFlate(t *testing.T) { var buf = bytes.NewBufferString("") buf.Write(compressedBuf.Bytes()) buf.WriteString("1234") - _, _, err := dps.Decompress(buf) + _, err := dps.Decompress(buf) as.Error(err) }) } diff --git a/internal/pool.go b/internal/pool.go index 49b6ae9f..2d8b24f8 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -20,43 +20,81 @@ const ( ) type BufferPool struct { - pools [poolSize]*sync.Pool - limits [poolSize]int + pools []*sync.Pool + limits []int } func NewBufferPool() *BufferPool { - var p BufferPool - p.limits = [poolSize]int{0, Lv1, Lv2, Lv3, Lv4, Lv5, Lv6, Lv7, Lv8, Lv9} + var p = &BufferPool{ + pools: make([]*sync.Pool, poolSize), + limits: []int{0, Lv1, Lv2, Lv3, Lv4, Lv5, Lv6, Lv7, Lv8, Lv9}, + } for i := 1; i < poolSize; i++ { var capacity = p.limits[i] p.pools[i] = &sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 0, capacity)) }} } - return &p + return p } -func (p *BufferPool) Put(b *bytes.Buffer, index int) { - if index == 0 || b == nil { +func (p *BufferPool) Put(b *bytes.Buffer) { + if b == nil || b.Cap() == 0 { return } - if b.Cap() <= p.limits[index] { - p.pools[index].Put(b) + if i := p.getIndex(uint32(b.Cap())); i > 0 { + p.pools[i].Put(b) } } -func (p *BufferPool) Get(n int) (*bytes.Buffer, int) { - for i := 1; i < poolSize; i++ { - if n <= p.limits[i] { - b := p.pools[i].Get().(*bytes.Buffer) - if b.Cap() < n { - b.Grow(p.limits[i]) - } - b.Reset() - return b, i - } +func (p *BufferPool) Get(n int) *bytes.Buffer { + var index = p.getIndex(uint32(n)) + if index == 0 { + return bytes.NewBuffer(make([]byte, 0, n)) + } + + buf := p.pools[index].Get().(*bytes.Buffer) + if buf.Cap() < n { + buf.Grow(p.limits[index]) + } + buf.Reset() + return buf +} + +func (p *BufferPool) getIndex(v uint32) int { + if v > Lv9 { + return 0 + } + if v <= 128 { + return 1 + } + + v-- + v |= v >> 1 + v |= v >> 2 + v |= v >> 4 + v |= v >> 8 + v |= v >> 16 + v++ + + switch v { + case Lv3: + return 3 + case Lv4: + return 4 + case Lv5: + return 5 + case Lv6: + return 6 + case Lv7: + return 7 + case Lv8: + return 8 + case Lv9: + return 9 + default: + return 2 } - return bytes.NewBuffer(make([]byte, 0, n)), 0 } func NewPool[T any](f func() T) *Pool[T] { diff --git a/internal/pool_test.go b/internal/pool_test.go index 57341f09..a17a8733 100644 --- a/internal/pool_test.go +++ b/internal/pool_test.go @@ -12,45 +12,39 @@ func TestBufferPool(t *testing.T) { for i := 0; i < 10; i++ { var n = AlphabetNumeric.Intn(126) - var buf, index = pool.Get(n) + var buf = pool.Get(n) as.Equal(128, buf.Cap()) as.Equal(0, buf.Len()) - as.Equal(index, 1) } for i := 0; i < 10; i++ { - var buf, index = pool.Get(500) + var buf = pool.Get(500) as.Equal(Lv2, buf.Cap()) as.Equal(0, buf.Len()) - as.Equal(index, 2) } for i := 0; i < 10; i++ { - var buf, index = pool.Get(2000) + var buf = pool.Get(2000) as.Equal(Lv3, buf.Cap()) as.Equal(0, buf.Len()) - as.Equal(index, 3) } for i := 0; i < 10; i++ { - var buf, index = pool.Get(5000) + var buf = pool.Get(5000) as.Equal(Lv5, buf.Cap()) as.Equal(0, buf.Len()) - as.Equal(index, 5) } { - pool.Put(bytes.NewBuffer(make([]byte, 2)), 2) - b, index := pool.Get(120) + pool.Put(bytes.NewBuffer(make([]byte, 2))) + b := pool.Get(120) as.GreaterOrEqual(b.Cap(), 120) - as.Equal(index, 1) } { - pool.Put(bytes.NewBuffer(make([]byte, 2000)), 4) - b, index := pool.Get(3000) + pool.Put(bytes.NewBuffer(make([]byte, 2000))) + b := pool.Get(3000) as.GreaterOrEqual(b.Cap(), 3000) - as.Equal(index, 4) } - pool.Put(nil, 0) - buffer, _ := pool.Get(256 * 1024) + pool.Put(nil) + buffer := pool.Get(256 * 1024) as.GreaterOrEqual(buffer.Cap(), 256*1024) } @@ -61,3 +55,58 @@ func TestPool(t *testing.T) { assert.Equal(t, 0, p.Get()) p.Put(1) } + +func TestBufferPool_GetIndex(t *testing.T) { + var p = NewBufferPool() + assert.Equal(t, p.getIndex(200*1024), 0) + + assert.Equal(t, p.getIndex(0), 1) + assert.Equal(t, p.getIndex(1), 1) + assert.Equal(t, p.getIndex(10), 1) + assert.Equal(t, p.getIndex(100), 1) + assert.Equal(t, p.getIndex(128), 1) + + assert.Equal(t, p.getIndex(200), 2) + assert.Equal(t, p.getIndex(1000), 2) + assert.Equal(t, p.getIndex(500), 2) + assert.Equal(t, p.getIndex(1024), 2) + + assert.Equal(t, p.getIndex(2*1024), 3) + assert.Equal(t, p.getIndex(2000), 3) + assert.Equal(t, p.getIndex(1025), 3) + + assert.Equal(t, p.getIndex(4*1024), 4) + assert.Equal(t, p.getIndex(3000), 4) + assert.Equal(t, p.getIndex(2*1024+1), 4) + + assert.Equal(t, p.getIndex(8*1024), 5) + assert.Equal(t, p.getIndex(5000), 5) + assert.Equal(t, p.getIndex(4*1024+1), 5) + + assert.Equal(t, p.getIndex(16*1024), 6) + assert.Equal(t, p.getIndex(10000), 6) + assert.Equal(t, p.getIndex(8*1024+1), 6) + + assert.Equal(t, p.getIndex(32*1024), 7) + assert.Equal(t, p.getIndex(20000), 7) + assert.Equal(t, p.getIndex(16*1024+1), 7) + + assert.Equal(t, p.getIndex(64*1024), 8) + assert.Equal(t, p.getIndex(40000), 8) + assert.Equal(t, p.getIndex(32*1024+1), 8) + + assert.Equal(t, p.getIndex(128*1024), 9) + assert.Equal(t, p.getIndex(100000), 9) + assert.Equal(t, p.getIndex(64*1024+1), 9) +} + +func BenchmarkPool_GetIndex(b *testing.B) { + var p = NewBufferPool() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 1000000; j++ { + p.getIndex(uint32(j)) + } + } +} diff --git a/reader.go b/reader.go index 87530d15..22ca65fe 100644 --- a/reader.go +++ b/reader.go @@ -93,9 +93,9 @@ func (c *Conn) readMessage() error { } var fin = c.fh.GetFIN() - var buf, index = binaryPool.Get(contentLength + len(flateTail)) + var buf = binaryPool.Get(contentLength + len(flateTail)) var p = buf.Bytes()[:contentLength] - var closer = Message{Data: buf, index: index} + var closer = Message{Data: buf} defer closer.Close() if err := internal.ReadN(c.br, p); err != nil { @@ -112,9 +112,9 @@ func (c *Conn) readMessage() error { if fin && opcode != OpcodeContinuation { *(*[]byte)(unsafe.Pointer(buf)) = p if !compressed { - closer.Data, closer.index = nil, 0 + closer.Data = nil } - return c.emitMessage(&Message{index: index, Opcode: opcode, Data: buf, compressed: compressed}) + return c.emitMessage(&Message{Opcode: opcode, Data: buf, compressed: compressed}) } if !fin && opcode != OpcodeContinuation { @@ -149,7 +149,7 @@ func (c *Conn) dispatch(msg *Message) error { func (c *Conn) emitMessage(msg *Message) (err error) { if msg.compressed { - msg.Data, msg.index, err = c.decompressor.Decompress(msg.Data) + msg.Data, err = c.decompressor.Decompress(msg.Data) if err != nil { return internal.NewError(internal.CloseInternalServerErr, err) } diff --git a/reader_test.go b/reader_test.go index bb45a36b..b58db3c2 100644 --- a/reader_test.go +++ b/reader_test.go @@ -279,7 +279,7 @@ func TestSegments(t *testing.T) { go client.ReadLoop() go func() { - frame, _, _ := client.genFrame(OpcodeText, testdata) + frame, _ := client.genFrame(OpcodeText, testdata) data := frame.Bytes() data[20] = 'x' client.conn.Write(data) diff --git a/recovery.go b/recovery.go deleted file mode 100644 index 05ce762c..00000000 --- a/recovery.go +++ /dev/null @@ -1,27 +0,0 @@ -package gws - -import ( - "log" - "runtime" - "unsafe" -) - -type Logger interface { - Error(v ...any) -} - -type stdLogger struct{} - -func (c *stdLogger) Error(v ...any) { - log.Println(v...) -} - -func Recovery(logger Logger) { - if e := recover(); e != nil { - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - msg := *(*string)(unsafe.Pointer(&buf)) - logger.Error("fatal error:", e, msg) - } -} diff --git a/protocol.go b/types.go similarity index 93% rename from protocol.go rename to types.go index cff826bc..720e99e2 100644 --- a/protocol.go +++ b/types.go @@ -7,7 +7,10 @@ import ( "fmt" "github.com/lxzan/gws/internal" "io" + "log" "net" + "runtime" + "unsafe" ) const frameHeaderSize = 14 @@ -226,9 +229,6 @@ type Message struct { // 是否压缩 compressed bool - // 内存池下标索引 - index int - // 操作码 Opcode Opcode @@ -246,7 +246,7 @@ func (c *Message) Bytes() []byte { // Close recycle buffer func (c *Message) Close() error { - binaryPool.Put(c.Data, c.index) + binaryPool.Put(c.Data) c.Data = nil return nil } @@ -264,3 +264,23 @@ func (c *continuationFrame) reset() { c.opcode = 0 c.buffer = nil } + +type Logger interface { + Error(v ...any) +} + +type stdLogger struct{} + +func (c *stdLogger) Error(v ...any) { + log.Println(v...) +} + +func Recovery(logger Logger) { + if e := recover(); e != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + msg := *(*string)(unsafe.Pointer(&buf)) + logger.Error("fatal error:", e, msg) + } +} diff --git a/upgrader.go b/upgrader.go index c1177135..055e34d4 100644 --- a/upgrader.go +++ b/upgrader.go @@ -16,12 +16,11 @@ import ( type responseWriter struct { err error b *bytes.Buffer - idx int subprotocol string } func (c *responseWriter) Init() *responseWriter { - c.b, c.idx = binaryPool.Get(512) + c.b = binaryPool.Get(512) c.b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") c.b.WriteString("Upgrade: websocket\r\n") c.b.WriteString("Connection: Upgrade\r\n") @@ -29,7 +28,7 @@ func (c *responseWriter) Init() *responseWriter { } func (c *responseWriter) Close() { - binaryPool.Put(c.b, c.idx) + binaryPool.Put(c.b) c.b = nil } diff --git a/writer.go b/writer.go index 46e2802a..c449b529 100644 --- a/writer.go +++ b/writer.go @@ -43,7 +43,7 @@ func (c *Conn) WriteString(s string) error { // WriteAsync 异步非阻塞地写入消息 // Write messages asynchronously and non-blocking func (c *Conn) WriteAsync(opcode Opcode, payload []byte) error { - frame, index, err := c.genFrame(opcode, payload) + frame, err := c.genFrame(opcode, payload) if err != nil { c.emitError(err) return err @@ -54,7 +54,7 @@ func (c *Conn) WriteAsync(opcode Opcode, payload []byte) error { return } err = internal.WriteN(c.conn, frame.Bytes()) - binaryPool.Put(frame, index) + binaryPool.Put(frame) c.emitError(err) }) return nil @@ -74,21 +74,21 @@ func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { // 执行写入逻辑, 关闭状态置为1后还能写, 以便发送关闭帧 // Execute the write logic, and write after the close state is set to 1, so that the close frame can be sent func (c *Conn) doWrite(opcode Opcode, payload []byte) error { - frame, index, err := c.genFrame(opcode, payload) + frame, err := c.genFrame(opcode, payload) if err != nil { return err } err = internal.WriteN(c.conn, frame.Bytes()) - binaryPool.Put(frame, index) + binaryPool.Put(frame) return err } // 帧生成 -func (c *Conn) genFrame(opcode Opcode, payload []byte) (*bytes.Buffer, int, error) { +func (c *Conn) genFrame(opcode Opcode, payload []byte) (*bytes.Buffer, error) { // 不要删除 opcode == OpcodeText if opcode == OpcodeText && !c.isTextValid(opcode, payload) { - return nil, 0, internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) + return nil, internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) } if c.compressEnabled && opcode.isDataFrame() && len(payload) >= c.config.CompressThreshold { @@ -97,32 +97,32 @@ func (c *Conn) genFrame(opcode Opcode, payload []byte) (*bytes.Buffer, int, erro var n = len(payload) if n > c.config.WriteMaxPayloadSize { - return nil, 0, internal.CloseMessageTooLarge + return nil, internal.CloseMessageTooLarge } var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, false, opcode, n) - var buf, index = binaryPool.Get(n + headerLength) + var buf = binaryPool.Get(n + headerLength) buf.Write(header[:headerLength]) buf.Write(payload) var contents = buf.Bytes() if !c.isServer { internal.MaskXOR(contents[headerLength:], maskBytes) } - return buf, index, nil + return buf, nil } -func (c *Conn) compressData(opcode Opcode, payload []byte) (*bytes.Buffer, int, error) { - var buf, index = binaryPool.Get(len(payload) + frameHeaderSize) +func (c *Conn) compressData(opcode Opcode, payload []byte) (*bytes.Buffer, error) { + var buf = binaryPool.Get(len(payload) + frameHeaderSize) buf.Write(myPadding[0:]) err := c.compressor.Compress(payload, buf) if err != nil { - return nil, 0, err + return nil, err } var contents = buf.Bytes() var payloadSize = buf.Len() - frameHeaderSize if payloadSize > c.config.WriteMaxPayloadSize { - return nil, 0, internal.CloseMessageTooLarge + return nil, internal.CloseMessageTooLarge } var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, true, opcode, payloadSize) @@ -131,7 +131,7 @@ func (c *Conn) compressData(opcode Opcode, payload []byte) (*bytes.Buffer, int, } copy(contents[frameHeaderSize-headerLength:], header[:headerLength]) buf.Next(frameHeaderSize - headerLength) - return buf, index, nil + return buf, nil } type ( @@ -145,7 +145,6 @@ type ( broadcastMessageWrapper struct { once sync.Once err error - index int frame *bytes.Buffer } ) @@ -169,7 +168,7 @@ func NewBroadcaster(opcode Opcode, payload []byte) *Broadcaster { func (c *Broadcaster) Broadcast(socket *Conn) error { var idx = internal.SelectValue(socket.compressEnabled, 1, 0) var msg = c.msgs[idx] - msg.once.Do(func() { msg.frame, msg.index, msg.err = socket.genFrame(c.opcode, c.payload) }) + msg.once.Do(func() { msg.frame, msg.err = socket.genFrame(c.opcode, c.payload) }) if msg.err != nil { return msg.err } @@ -189,7 +188,7 @@ func (c *Broadcaster) Broadcast(socket *Conn) error { func (c *Broadcaster) doClose() { for _, item := range c.msgs { if item != nil { - binaryPool.Put(item.frame, item.index) + binaryPool.Put(item.frame) } } } From e5db458c82df91dc4b7258054d6feec566fd6dae Mon Sep 17 00:00:00 2001 From: lxzan Date: Thu, 9 Nov 2023 17:41:08 +0800 Subject: [PATCH 2/2] Curing anonymous functions --- internal/pool.go | 14 +++---- task.go | 105 +++++++++++++++++++++++++++++++++++++++-------- task_test.go | 44 ++++++++++++++++---- writer.go | 39 ++++++++++-------- 4 files changed, 152 insertions(+), 50 deletions(-) diff --git a/internal/pool.go b/internal/pool.go index 2d8b24f8..7734c3b8 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -42,8 +42,8 @@ func (p *BufferPool) Put(b *bytes.Buffer) { if b == nil || b.Cap() == 0 { return } - if i := p.getIndex(uint32(b.Cap())); i > 0 { - p.pools[i].Put(b) + if index := p.getIndex(uint32(b.Cap())); index > 0 { + p.pools[index].Put(b) } } @@ -53,12 +53,12 @@ func (p *BufferPool) Get(n int) *bytes.Buffer { return bytes.NewBuffer(make([]byte, 0, n)) } - buf := p.pools[index].Get().(*bytes.Buffer) - if buf.Cap() < n { - buf.Grow(p.limits[index]) + b := p.pools[index].Get().(*bytes.Buffer) + if b.Cap() < n { + b.Grow(p.limits[index]) } - buf.Reset() - return buf + b.Reset() + return b } func (p *BufferPool) getIndex(v uint32) int { diff --git a/task.go b/task.go index 1cd289fd..30c7008f 100644 --- a/task.go +++ b/task.go @@ -1,18 +1,24 @@ package gws import ( + "bytes" "sync" ) type ( workerQueue struct { mu sync.Mutex // 锁 - q []asyncJob // 任务队列 + q heap // 任务队列 maxConcurrency int32 // 最大并发 curConcurrency int32 // 当前并发 } - asyncJob func() + asyncJob struct { + serial int + socket *Conn + frame *bytes.Buffer + execute func(conn *Conn, buffer *bytes.Buffer) + } ) // newWorkerQueue 创建一个任务队列 @@ -25,28 +31,19 @@ func newWorkerQueue(maxConcurrency int32) *workerQueue { return c } -func (c *workerQueue) pop() asyncJob { - if len(c.q) == 0 { - return nil - } - var job = c.q[0] - c.q = c.q[1:] - return job -} - // 获取一个任务 -func (c *workerQueue) getJob(newJob asyncJob, delta int32) asyncJob { +func (c *workerQueue) getJob(newJob *asyncJob, delta int32) *asyncJob { c.mu.Lock() defer c.mu.Unlock() if newJob != nil { - c.q = append(c.q, newJob) + c.q.Push(newJob) } c.curConcurrency += delta if c.curConcurrency >= c.maxConcurrency { return nil } - var job = c.pop() + var job = c.q.Pop() if job == nil { return nil } @@ -55,15 +52,15 @@ func (c *workerQueue) getJob(newJob asyncJob, delta int32) asyncJob { } // 循环执行任务 -func (c *workerQueue) do(job asyncJob) { +func (c *workerQueue) do(job *asyncJob) { for job != nil { - job() + job.execute(job.socket, job.frame) job = c.getJob(nil, -1) } } // Push 追加任务, 有资源空闲的话会立即执行 -func (c *workerQueue) Push(job asyncJob) { +func (c *workerQueue) Push(job *asyncJob) { if nextJob := c.getJob(job, 0); nextJob != nil { go c.do(nextJob) } @@ -78,8 +75,80 @@ func (c channel) done() { <-c } func (c channel) Go(m *Message, f func(*Message) error) error { c.add() go func() { - f(m) + _ = f(m) c.done() }() return nil } + +type heap struct { + data []*asyncJob + serial int +} + +func (c *heap) next() int { + c.serial++ + return c.serial +} + +func (c *heap) less(i, j int) bool { + return c.data[i].serial < c.data[j].serial +} + +func (c *heap) Len() int { + return len(c.data) +} + +func (c *heap) swap(i, j int) { + c.data[i], c.data[j] = c.data[j], c.data[i] +} + +func (c *heap) Push(v *asyncJob) { + if v.serial == 0 { + v.serial = c.next() + } + c.data = append(c.data, v) + c.up(c.Len() - 1) +} + +func (c *heap) up(i int) { + var j = (i - 1) / 2 + if i >= 1 && c.less(i, j) { + c.swap(i, j) + c.up(j) + } +} + +func (c *heap) Pop() *asyncJob { + n := c.Len() + switch n { + case 0: + return nil + case 1: + v := c.data[0] + c.data = c.data[:0] + return v + default: + v := c.data[0] + c.data[0] = c.data[n-1] + c.data = c.data[:n-1] + c.down(0, n-1) + return v + } +} + +func (c *heap) down(i, n int) { + var j = 2*i + 1 + var k = 2*i + 2 + var x = -1 + if j < n { + x = j + } + if k < n && c.less(k, j) { + x = k + } + if x != -1 && c.less(x, i) { + c.swap(i, x) + c.down(x, n) + } +} diff --git a/task_test.go b/task_test.go index 538d7a62..04a216db 100644 --- a/task_test.go +++ b/task_test.go @@ -2,8 +2,10 @@ package gws import ( "bufio" + "bytes" "fmt" "net" + "sort" "sync" "sync/atomic" "testing" @@ -233,14 +235,14 @@ func TestTaskQueue(t *testing.T) { listA = append(listA, i) v := i - q.Push(func() { + q.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) { defer wg.Done() var latency = time.Duration(internal.AlphabetNumeric.Intn(100)) * time.Microsecond time.Sleep(latency) mu.Lock() listB = append(listB, v) mu.Unlock() - }) + }}) } wg.Wait() as.ElementsMatch(listA, listB) @@ -253,11 +255,11 @@ func TestTaskQueue(t *testing.T) { wg.Add(1000) for i := int64(1); i <= 1000; i++ { var tmp = i - w.Push(func() { + w.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) { time.Sleep(time.Millisecond) atomic.AddInt64(&sum, tmp) wg.Done() - }) + }}) } wg.Wait() as.Equal(sum, int64(500500)) @@ -270,11 +272,11 @@ func TestTaskQueue(t *testing.T) { wg.Add(1000) for i := int64(1); i <= 1000; i++ { var tmp = i - w.Push(func() { + w.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) { time.Sleep(time.Millisecond) atomic.AddInt64(&sum, tmp) wg.Done() - }) + }}) } wg.Wait() as.Equal(sum, int64(500500)) @@ -348,7 +350,7 @@ func TestRQueue(t *testing.T) { var serial = int64(0) var done = make(chan struct{}) for i := 0; i < total; i++ { - q.Push(func() { + q.Push(&asyncJob{execute: func(conn *Conn, buffer *bytes.Buffer) { x := atomic.AddInt64(&concurrency, 1) assert.LessOrEqual(t, x, int64(limit)) time.Sleep(10 * time.Millisecond) @@ -356,8 +358,34 @@ func TestRQueue(t *testing.T) { if atomic.AddInt64(&serial, 1) == total { done <- struct{}{} } - }) + }}) } <-done }) } + +func TestHeap_Sort(t *testing.T) { + var count = 1000 + var list0 []int + var list1 []int + var h heap + for i := 0; i < count; i++ { + var v = internal.Numeric.Intn(count) + 1 + list0 = append(list0, v) + h.Push(&asyncJob{serial: v}) + } + + sort.Ints(list0) + for h.Len() > 0 { + list1 = append(list1, h.Pop().serial) + } + for i := 0; i < count; i++ { + assert.Equal(t, list0[i], list1[i]) + } + assert.Zero(t, h.Len()) +} + +func TestHeap_Pop(t *testing.T) { + var h = heap{} + assert.Nil(t, h.Pop()) +} diff --git a/writer.go b/writer.go index c449b529..381d5a6b 100644 --- a/writer.go +++ b/writer.go @@ -40,6 +40,15 @@ func (c *Conn) WriteString(s string) error { return c.WriteMessage(OpcodeText, internal.StringToBytes(s)) } +func writeAsync(socket *Conn, buffer *bytes.Buffer) { + if socket.isClosed() { + return + } + err := internal.WriteN(socket.conn, buffer.Bytes()) + binaryPool.Put(buffer) + socket.emitError(err) +} + // WriteAsync 异步非阻塞地写入消息 // Write messages asynchronously and non-blocking func (c *Conn) WriteAsync(opcode Opcode, payload []byte) error { @@ -48,15 +57,8 @@ func (c *Conn) WriteAsync(opcode Opcode, payload []byte) error { c.emitError(err) return err } - - c.writeQueue.Push(func() { - if c.isClosed() { - return - } - err = internal.WriteN(c.conn, frame.Bytes()) - binaryPool.Put(frame) - c.emitError(err) - }) + job := &asyncJob{socket: c, frame: frame, execute: writeAsync} + c.writeQueue.Push(job) return nil } @@ -162,6 +164,15 @@ func NewBroadcaster(opcode Opcode, payload []byte) *Broadcaster { return c } +func (c *Broadcaster) writeAsync(socket *Conn, buffer *bytes.Buffer) { + if !socket.isClosed() { + socket.emitError(internal.WriteN(socket.conn, buffer.Bytes())) + } + if atomic.AddInt64(&c.state, -1) == 0 { + c.doClose() + } +} + // Broadcast 广播 // 向客户端发送广播消息 // Send a broadcast message to a client. @@ -174,14 +185,8 @@ func (c *Broadcaster) Broadcast(socket *Conn) error { } atomic.AddInt64(&c.state, 1) - socket.writeQueue.Push(func() { - if !socket.isClosed() { - socket.emitError(internal.WriteN(socket.conn, msg.frame.Bytes())) - } - if atomic.AddInt64(&c.state, -1) == 0 { - c.doClose() - } - }) + var job = &asyncJob{socket: socket, frame: msg.frame, execute: c.writeAsync} + socket.writeQueue.Push(job) return nil }