diff --git a/conn_std.go b/conn_std.go index 539c59a9..f27b819e 100644 --- a/conn_std.go +++ b/conn_std.go @@ -29,7 +29,7 @@ type Conn struct { conn net.Conn connUDP *udpConn - rTimer *timer.Item + rTimer *time.Timer typ ConnType closed bool diff --git a/conn_unix.go b/conn_unix.go index 7695111a..8a84f7c8 100644 --- a/conn_unix.go +++ b/conn_unix.go @@ -17,7 +17,6 @@ import ( "time" "github.com/lesismal/nbio/mempool" - "github.com/lesismal/nbio/timer" ) // Conn implements net.Conn. @@ -30,8 +29,8 @@ type Conn struct { connUDP *udpConn - rTimer *timer.Item - wTimer *timer.Item + rTimer *time.Timer + wTimer *time.Timer writeBuffer []byte @@ -253,16 +252,15 @@ func (c *Conn) SetDeadline(t time.Time) error { if !c.closed { if !t.IsZero() { g := c.p.g - now := time.Now() if c.rTimer == nil { - c.rTimer = g.AfterFunc(t.Sub(now), func() { c.closeWithError(errReadTimeout) }) + c.rTimer = g.AfterFunc(time.Until(t), func() { c.closeWithError(errReadTimeout) }) } else { - c.rTimer.Reset(t.Sub(now)) + c.rTimer.Reset(time.Until(t)) } if c.wTimer == nil { - c.wTimer = g.AfterFunc(t.Sub(now), func() { c.closeWithError(errWriteTimeout) }) + c.wTimer = g.AfterFunc(time.Until(t), func() { c.closeWithError(errWriteTimeout) }) } else { - c.wTimer.Reset(t.Sub(now)) + c.wTimer.Reset(time.Until(t)) } } else { if c.rTimer != nil { @@ -279,7 +277,7 @@ func (c *Conn) SetDeadline(t time.Time) error { return nil } -func (c *Conn) setDeadline(timer **timer.Item, returnErr error, t time.Time) error { +func (c *Conn) setDeadline(timer **time.Timer, returnErr error, t time.Time) error { c.mux.Lock() defer c.mux.Unlock() if c.closed { @@ -287,9 +285,9 @@ func (c *Conn) setDeadline(timer **timer.Item, returnErr error, t time.Time) err } if !t.IsZero() { if *timer == nil { - *timer = c.p.g.UntilFunc(t, func() { c.closeWithError(returnErr) }) + *timer = c.p.g.AfterFunc(time.Until(t), func() { c.closeWithError(returnErr) }) } else { - (*timer).ResetUntil(t) + (*timer).Reset(time.Until(t)) } } else if *timer != nil { (*timer).Stop() @@ -387,7 +385,7 @@ func (c *Conn) write(b []byte) (int, error) { } if c.overflow(len(b)) { - return -1, syscall.EINVAL + return -1, errOverflow } if len(c.writeBuffer) == 0 { @@ -469,7 +467,7 @@ func (c *Conn) writev(in [][]byte) (int, error) { size += len(v) } if c.overflow(size) { - return -1, syscall.EINVAL + return -1, errOverflow } if len(c.writeBuffer) > 0 { for _, v := range in { diff --git a/engine.go b/engine.go index 4ff257d5..f3f4184a 100644 --- a/engine.go +++ b/engine.go @@ -52,7 +52,7 @@ type Config struct { // NPoller represents poller goroutine num, it's set to runtime.NumCPU() by default. NPoller int - // ReadBufferSize represents buffer size for reading, it's set to 16k by default. + // ReadBufferSize represents buffer size for reading, it's set to 64k by default. ReadBufferSize int // MaxWriteBufferSize represents max write buffer size for Conn, it's set to 1m by default. @@ -233,10 +233,10 @@ func (g *Engine) OnClose(h func(c *Conn, err error)) { panic("invalid nil handler") } g.onClose = func(c *Conn, err error) { - // g.Async(func() { - defer g.wgConn.Done() - h(c, err) - // }) + g.Async(func() { + defer g.wgConn.Done() + h(c, err) + }) } } diff --git a/engine_std.go b/engine_std.go index 7891d766..4fcac4a5 100644 --- a/engine_std.go +++ b/engine_std.go @@ -96,7 +96,7 @@ func (g *Engine) Start() error { } } - g.Timer.Start() + // g.Timer.Start() if len(g.addrs) == 0 { logging.Info("NBIO[%v] start", g.Name) diff --git a/error.go b/error.go index 07b99f8b..fa1b2642 100644 --- a/error.go +++ b/error.go @@ -11,4 +11,5 @@ import ( var ( errReadTimeout = errors.New("read timeout") errWriteTimeout = errors.New("write timeout") + errOverflow = errors.New("write overflow") ) diff --git a/logging/log.go b/logging/log.go index 40968266..cee66e65 100644 --- a/logging/log.go +++ b/logging/log.go @@ -6,6 +6,7 @@ package logging import ( "fmt" + "io" "os" "time" ) @@ -15,7 +16,7 @@ var ( TimeFormat = "2006/01/02 15:04:05.000" // Output is used to receive log output. - Output = os.Stdout + Output io.Writer = os.Stdout // DefaultLogger is the default logger and is used by arpc. DefaultLogger Logger = &logger{level: LevelInfo} diff --git a/nbhttp/engine.go b/nbhttp/engine.go index ea9222fc..3a38ff9a 100644 --- a/nbhttp/engine.go +++ b/nbhttp/engine.go @@ -7,6 +7,7 @@ package nbhttp import ( "context" "errors" + "io" "net" "net/http" "runtime" @@ -53,7 +54,7 @@ const ( // DefaultKeepaliveTime . DefaultKeepaliveTime = time.Second * 120 - // DefaultBlockingReadBufferSize sets to 4k(<= goroutine stack size). + // DefaultBlockingReadBufferSize sets to 4k. DefaultBlockingReadBufferSize = 1024 * 4 ) @@ -116,7 +117,7 @@ type Config struct { // ReadLimit represents the max size for parser reading, it's set to 64M by default. ReadLimit int - // ReadBufferSize represents buffer size for reading, it's set to 32k by default. + // ReadBufferSize represents buffer size for reading, it's set to 64k by default. ReadBufferSize int // MaxWriteBufferSize represents max write buffer size for Conn, it's set to 1m by default. @@ -200,15 +201,10 @@ type Config struct { ReadBufferPool mempool.Allocator // WebsocketCompressor . - WebsocketCompressor func() interface { - Compress([]byte) []byte - Close() - } + WebsocketCompressor func(w io.WriteCloser, level int) io.WriteCloser + // WebsocketDecompressor . - WebsocketDecompressor func() interface { - Decompress([]byte) ([]byte, error) - Close() - } + WebsocketDecompressor func(r io.Reader) io.ReadCloser } // Engine . @@ -546,6 +542,7 @@ func (engine *Engine) AddTransferredConn(nbc *nbio.Conn) error { key, err := conn2Array(nbc) if err != nil { nbc.Close() + logging.Error("AddTransferredConn failed: %v", err) return err } @@ -553,6 +550,7 @@ func (engine *Engine) AddTransferredConn(nbc *nbio.Conn) error { if len(engine.conns) >= engine.MaxLoad { engine.mux.Unlock() nbc.Close() + logging.Error("AddTransferredConn failed: overload, already has %v online", engine.MaxLoad) return ErrServiceOverload } engine.conns[key] = struct{}{} @@ -567,6 +565,7 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config nbc, err := nbio.NBConn(c) if err != nil { c.Close() + logging.Error("AddConnNonTLSNonBlocking failed: %v", err) return } if nbc.Session() != nil { @@ -576,6 +575,7 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config key, err := conn2Array(nbc) if err != nil { nbc.Close() + logging.Error("AddConnNonTLSNonBlocking failed: %v", err) return } @@ -583,6 +583,7 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config if len(engine.conns) >= engine.MaxLoad { engine.mux.Unlock() nbc.Close() + logging.Error("AddConnNonTLSNonBlocking failed: overload, already has %v online", engine.MaxLoad) return } engine.conns[key] = struct{}{} @@ -606,6 +607,7 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config engine.mux.Lock() if len(engine.conns) >= engine.MaxLoad { engine.mux.Unlock() + logging.Error("AddConnNonTLSBlocking failed: overload, already has %v online", engine.MaxLoad) conn.Close() decrease() return @@ -617,6 +619,7 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config engine.mux.Unlock() conn.Close() decrease() + logging.Error("AddConnNonTLSBlocking failed: %v", err) return } engine.conns[key] = struct{}{} @@ -624,6 +627,7 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config engine.mux.Unlock() conn.Close() decrease() + logging.Error("AddConnNonTLSBlocking failed: unknown conn type: %v", vt) return } engine.mux.Unlock() @@ -641,15 +645,18 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config nbc, err := nbio.NBConn(conn) if err != nil { conn.Close() + logging.Error("AddConnTLSNonBlocking failed: %v", err) return } if nbc.Session() != nil { nbc.Close() + logging.Error("AddConnTLSNonBlocking failed: session should not be nil") return } key, err := conn2Array(nbc) if err != nil { nbc.Close() + logging.Error("AddConnTLSNonBlocking failed: %v", err) return } @@ -657,6 +664,7 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config if len(engine.conns) >= engine.MaxLoad { engine.mux.Unlock() nbc.Close() + logging.Error("AddConnTLSNonBlocking failed: overload, already has %v online", engine.MaxLoad) return } @@ -689,6 +697,7 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d engine.mux.Unlock() conn.Close() decrease() + logging.Error("AddConnTLSBlocking failed: overload, already has %v online", engine.MaxLoad) return } @@ -699,6 +708,7 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d engine.mux.Unlock() conn.Close() decrease() + logging.Error("AddConnTLSBlocking failed: %v", err) return } engine.conns[key] = struct{}{} @@ -706,6 +716,7 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d engine.mux.Unlock() conn.Close() decrease() + logging.Error("AddConnTLSBlocking unknown conn type: %v", vt) return } engine.mux.Unlock() diff --git a/nbhttp/websocket/conn.go b/nbhttp/websocket/conn.go index 7534535e..98b70377 100644 --- a/nbhttp/websocket/conn.go +++ b/nbhttp/websocket/conn.go @@ -74,6 +74,16 @@ type Conn struct { message []byte } +// IsClient . +func (c *Conn) IsClient() bool { + return c.isClient +} + +// SetClient . +func (c *Conn) SetClient(isClient bool) { + c.isClient = isClient +} + // IsBlockingMod . func (c *Conn) IsBlockingMod() bool { return c.isBlockingMod @@ -158,55 +168,80 @@ func (c *Conn) handleProtocolMessage(p *nbhttp.Parser, opcode MessageType, body } func (c *Conn) handleWsMessage(opcode MessageType, data []byte) { + const errInvalidUtf8Text = "invalid UTF-8 bytes" + if c.KeepaliveTime > 0 { defer c.SetReadDeadline(time.Now().Add(c.KeepaliveTime)) } + switch opcode { case BinaryMessage: c.messageHandler(c, opcode, data) + return case TextMessage: if !c.Engine.CheckUtf8(data) { - const errText = "Invalid UTF-8 bytes" - protoErrorData := make([]byte, 2+len(errText)) + protoErrorData := make([]byte, 2+len(errInvalidUtf8Text)) binary.BigEndian.PutUint16(protoErrorData, 1002) - copy(protoErrorData[2:], errText) + copy(protoErrorData[2:], errInvalidUtf8Text) + c.SetCloseError(ErrInvalidUtf8) c.WriteMessage(CloseMessage, protoErrorData) - return + goto ErrExit } c.messageHandler(c, opcode, data) + return + case PingMessage: + c.pingMessageHandler(c, string(data)) + return + case PongMessage: + c.pongMessageHandler(c, string(data)) + return case CloseMessage: if len(data) >= 2 { code := int(binary.BigEndian.Uint16(data[:2])) - if !validCloseCode(code) || !c.Engine.CheckUtf8(data[2:]) { + if !validCloseCode(code) { protoErrorCode := make([]byte, 2) binary.BigEndian.PutUint16(protoErrorCode, 1002) + c.SetCloseError(ErrInvalidCloseCode) c.WriteMessage(CloseMessage, protoErrorCode) - } else { - reson := string(data[2:]) + goto ErrExit + } + if !c.Engine.CheckUtf8(data[2:]) { + protoErrorData := make([]byte, 2+len(errInvalidUtf8Text)) + binary.BigEndian.PutUint16(protoErrorData, 1002) + copy(protoErrorData[2:], errInvalidUtf8Text) + c.SetCloseError(ErrInvalidUtf8) + c.WriteMessage(CloseMessage, protoErrorData) + goto ErrExit + } + + reson := string(data[2:]) + if code != 1000 { c.SetCloseError(&CloseError{ Code: code, Reason: reson, }) - c.closeMessageHandler(c, code, reson) } + c.closeMessageHandler(c, code, reson) } else { - c.WriteMessage(CloseMessage, nil) + c.SetCloseError(ErrInvalidControlFrame) } - // close immediately, no need to wait for data flushed on a blocked conn - c.Conn.Close() - case PingMessage: - c.pingMessageHandler(c, string(data)) - case PongMessage: - c.pongMessageHandler(c, string(data)) case FragmentMessage: logging.Debug("invalid fragment message") - c.Conn.Close() + c.SetCloseError(ErrInvalidFragmentMessage) default: + logging.Debug("invalid message type: %v", opcode) + c.SetCloseError(fmt.Errorf("websocket: invalid message type: %v", opcode)) + } + +ErrExit: + if c.IsAsyncWrite() { + c.Engine.AfterFunc(time.Second, func() { c.Conn.Close() }) + } else { c.Conn.Close() } } -func (c *Conn) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2, res3 bool) { +func (c *Conn) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2, res3 bool, err error) { l := int64(len(c.buffer)) headLen := int64(2) if l >= 2 { @@ -232,6 +267,13 @@ func (c *Conn) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2 default: bodyLen = int64(payloadLen) } + + if (bodyLen > maxControlFramePayloadSize) && + ((opcode == PingMessage) || (opcode == PongMessage) || (opcode == CloseMessage)) { + err = ErrControlMessageTooBig + return + } + if bodyLen >= 0 { masked := (c.buffer[1] & 0x80) != 0 if masked { @@ -241,10 +283,7 @@ func (c *Conn) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2 if l >= total { body = c.buffer[headLen:total] if masked { - maskKey := c.buffer[headLen-4 : headLen] - for i := 0; i < len(body); i++ { - body[i] ^= maskKey[i%4] - } + maskXOR(body, c.buffer[headLen-4:headLen]) } ok = true @@ -253,7 +292,7 @@ func (c *Conn) nextFrame() (opcode MessageType, body []byte, ok, fin, res1, res2 } } - return opcode, body, ok, fin, res1, res2, res3 + return opcode, body, ok, fin, res1, res2, res3, err } // Read . @@ -274,7 +313,11 @@ func (c *Conn) Read(p *nbhttp.Parser, data []byte) error { var err error for i := 0; true; i++ { - opcode, body, ok, fin, res1, res2, res3 := c.nextFrame() + opcode, body, ok, fin, res1, res2, res3, e := c.nextFrame() + if e != nil { + err = e + break + } if !ok { break } @@ -322,26 +365,19 @@ func (c *Conn) Read(p *nbhttp.Parser, data []byte) error { if fin { if c.messageHandler != nil { if c.compress { + var b []byte + var rc io.ReadCloser if c.Engine.WebsocketDecompressor != nil { - var b []byte - decompressor := c.Engine.WebsocketDecompressor() - defer decompressor.Close() - b, err = decompressor.Decompress(c.message) - if err != nil { - break - } - c.Engine.BodyAllocator.Free(c.message) - c.message = b + rc = c.Engine.WebsocketDecompressor(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail))) } else { - var b []byte - rc := decompressReader(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail))) - b, err = c.readAll(rc, len(c.message)*2) - c.Engine.BodyAllocator.Free(c.message) - c.message = b - rc.Close() - if err != nil { - break - } + rc = decompressReader(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail))) + } + b, err = c.readAll(rc, len(c.message)*2) + c.Engine.BodyAllocator.Free(c.message) + c.message = b + rc.Close() + if err != nil { + break } } c.handleMessage(p, c.opcode, c.message) @@ -468,7 +504,7 @@ func (c *Conn) WriteMessage(messageType MessageType, data []byte) error { case BinaryMessage: case PingMessage, PongMessage, CloseMessage: if len(data) > maxControlFramePayloadSize { - return ErrInvalidControlFrame + return ErrControlMessageTooBig } case FragmentMessage: default: @@ -479,24 +515,24 @@ func (c *Conn) WriteMessage(messageType MessageType, data []byte) error { // compress = true // if user customize mempool, they should promise it's safe to mempool.Free a buffer which is not from their mempool.Malloc // or we need to implement a writebuffer that use mempool.Realloc to grow or append the buffer + w := &writeBuffer{ + Buffer: bytes.NewBuffer(mempool.Malloc(len(data))), + } + defer w.Close() + w.Reset() + + var cw io.WriteCloser if c.Engine.WebsocketCompressor != nil { - compressor := c.Engine.WebsocketCompressor() - defer compressor.Close() - data = compressor.Compress(data) + cw = c.Engine.WebsocketCompressor(w, c.compressionLevel) } else { - w := &writeBuffer{ - Buffer: bytes.NewBuffer(mempool.Malloc(len(data))), - } - defer w.Close() - w.Reset() - cw := compressWriter(w, c.compressionLevel) - _, err := cw.Write(data) - if err != nil { - compress = false - } else { - cw.Close() - data = w.Bytes() - } + cw = compressWriter(w, c.compressionLevel) + } + _, err := cw.Write(data) + if err != nil { + compress = false + } else { + cw.Close() + data = w.Bytes() } } @@ -620,11 +656,9 @@ func (c *Conn) writeFrame(messageType MessageType, sendOpcode, fin bool, data [] if c.isClient { u32 := rand.Uint32() - maskKey := []byte{byte(u32), byte(u32 >> 8), byte(u32 >> 16), byte(u32 >> 24)} - copy(buf[headLen-4:headLen], maskKey) - for i := 0; i < len(data); i++ { - buf[headLen+i] = (data[i] ^ maskKey[i%4]) - } + binary.LittleEndian.PutUint32(buf[headLen-4:headLen], u32) + copy(buf[headLen:], data) + maskXOR(buf[headLen:], buf[headLen-4:headLen]) } else { copy(buf[headLen:], data) } @@ -865,3 +899,39 @@ func validCloseCode(code int) bool { } return false } + +func maskXOR(b, key []byte) { + key64 := uint64(binary.LittleEndian.Uint32(key)) + key64 |= (key64 << 32) + + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b[:8]) + binary.LittleEndian.PutUint64(b[:8], v^key64) + b = b[8:] + } + + for i := 0; i < len(b); i++ { + idx := i & 3 + b[i] ^= key[idx] + } +} diff --git a/nbhttp/websocket/error.go b/nbhttp/websocket/error.go index ca31cd7d..e81f8b35 100644 --- a/nbhttp/websocket/error.go +++ b/nbhttp/websocket/error.go @@ -46,6 +46,9 @@ var ( // ErrControlMessageFragmented . ErrControlMessageFragmented = errors.New("websocket: control messages must not be fragmented") + // ErrControlMessageTooBig . + ErrControlMessageTooBig = errors.New("websocket: control frame length > 125") + // ErrFragmentsShouldNotHaveBinaryOrTextOpcode . ErrFragmentsShouldNotHaveBinaryOrTextOpcode = errors.New("websocket: fragments should not have opcode of text or binary") @@ -58,8 +61,14 @@ var ( // ErrInvalidCompression . ErrInvalidCompression = errors.New("websocket: invalid compression negotiation") + // ErrInvalidUtf8 . + ErrInvalidUtf8 = errors.New("websocket: invalid UTF-8 bytes") + + // ErrInvalidFragmentMessage . + ErrInvalidFragmentMessage = errors.New("invalid fragment message") + // ErrMalformedURL . - ErrMalformedURL = errors.New("malformed ws or wss URL") + ErrMalformedURL = errors.New("websocket: malformed ws or wss URL") // ErrMessageTooLarge. ErrMessageTooLarge = errors.New("message exceeds the configured limit") diff --git a/nbhttp/websocket/upgrader.go b/nbhttp/websocket/upgrader.go index 3c6ad2dd..b4f85e44 100644 --- a/nbhttp/websocket/upgrader.go +++ b/nbhttp/websocket/upgrader.go @@ -101,10 +101,6 @@ func NewUpgrader() *Upgrader { BlockingModSendQueueMaxSize: DefaultBlockingModSendQueueMaxSize, } u.pingMessageHandler = func(c *Conn, data string) { - if len(data) > 125 { - c.Close() - return - } err := c.WriteMessage(PongMessage, []byte(data)) if err != nil { logging.Debug("failed to send pong %v", err) @@ -114,9 +110,6 @@ func NewUpgrader() *Upgrader { } u.pongMessageHandler = func(*Conn, string) {} u.closeMessageHandler = func(c *Conn, code int, text string) { - if len(text)+2 > maxControlFramePayloadSize { - return //ErrInvalidControlFrame - } buf := mempool.Malloc(len(text) + 2) binary.BigEndian.PutUint16(buf[:2], uint16(code)) copy(buf[2:], text) diff --git a/timer/timer.go b/timer/timer.go index 57292998..cd3c00be 100644 --- a/timer/timer.go +++ b/timer/timer.go @@ -5,10 +5,8 @@ package timer import ( - "container/heap" "math" "runtime" - "sync" "time" "unsafe" @@ -20,290 +18,58 @@ const ( ) type Timer struct { - name string - - wg sync.WaitGroup - mux sync.Mutex - + name string executor func(f func()) - - chCalling chan struct{} - callings []func() - - trigger *time.Timer - items timerHeap - - chClose chan struct{} } func New(name string, executor func(f func())) *Timer { - t := &Timer{} - - t.mux.Lock() - t.name = name - t.executor = executor - t.callings = []func(){} - t.chCalling = make(chan struct{}, 1) - t.trigger = time.NewTimer(TimeForever) - t.chClose = make(chan struct{}) - t.mux.Unlock() + return &Timer{name: name, executor: executor} +} - return t +// IsTimerRunning . +func (t *Timer) IsTimerRunning() bool { + return true } // Start . -func (t *Timer) Start() { - t.wg.Add(1) - go t.loop() -} +func (t *Timer) Start() {} // Stop . -func (t *Timer) Stop() { - close(t.chClose) - t.wg.Wait() -} +func (t *Timer) Stop() {} // After used as time.After. -func (t *Timer) After(timeout time.Duration) <-chan time.Time { - c := make(chan time.Time, 1) - t.AfterFunc(timeout, func() { - c <- time.Now() - }) - return c +func (t *Timer) After(d time.Duration) <-chan time.Time { + return time.After(d) } // AfterFunc used as time.AfterFunc. -func (t *Timer) AfterFunc(timeout time.Duration, f func()) *Item { - t.mux.Lock() - - now := time.Now() - it := &Item{ - index: len(t.items), - expire: now.Add(timeout), - f: f, - parent: t, - } - - heap.Push(&t.items, it) - if t.items[0] == it { - t.trigger.Reset(timeout) - } - - t.mux.Unlock() - - return it -} - -// UntilFunc call f when time expire. -func (t *Timer) UntilFunc(expire time.Time, f func()) *Item { - t.mux.Lock() - - it := &Item{ - index: len(t.items), - expire: expire, - f: f, - parent: t, - } - - heap.Push(&t.items, it) - if t.items[0] == it { - t.trigger.Reset(time.Until(it.expire)) - } - - t.mux.Unlock() - - return it +func (t *Timer) AfterFunc(timeout time.Duration, f func()) *time.Timer { + return time.AfterFunc(timeout, func() { + defer func() { + err := recover() + if err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + logging.Error("Timer[%v] exec call failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf))) + } + }() + f() + }) } // Async executes f in another goroutine. func (t *Timer) Async(f func()) { - if f != nil { - t.mux.Lock() - t.callings = append(t.callings, f) - t.mux.Unlock() - select { - case t.chCalling <- struct{}{}: - default: - } - } -} - -func (t *Timer) removeTimer(it *Item) { - t.mux.Lock() - defer t.mux.Unlock() - - index := it.index - if index < 0 || index >= len(t.items) { - return - } - - if t.items[index] == it { - heap.Remove(&t.items, index) - if len(t.items) > 0 { - if index == 0 { - t.trigger.Reset(time.Until(t.items[0].expire)) + go func() { + defer func() { + err := recover() + if err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + logging.Error("Timer[%v] exec call failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf))) } - } else { - t.trigger.Reset(TimeForever) - } - } -} - -func (t *Timer) resetTimer(it *Item, d time.Duration) { - t.mux.Lock() - defer t.mux.Unlock() - - index := it.index - if index < 0 || index >= len(t.items) { - return - } - - if t.items[index] == it { - it.expire = time.Now().Add(d) - heap.Fix(&t.items, index) - if index == 0 || it.index == 0 { - t.trigger.Reset(time.Until(t.items[0].expire)) - } - } -} - -func (t *Timer) resetTimerUntil(it *Item, expire time.Time) { - t.mux.Lock() - defer t.mux.Unlock() - - index := it.index - if index < 0 || index >= len(t.items) { - return - } - - if t.items[index] == it { - it.expire = expire - heap.Fix(&t.items, index) - if index == 0 || it.index == 0 { - t.trigger.Reset(time.Until(t.items[0].expire)) - } - } -} - -func (t *Timer) loop() { - defer t.wg.Done() - logging.Debug("Timer[%v] timer start", t.name) - defer logging.Debug("Timer[%v] timer stopped", t.name) - for { - select { - case <-t.chCalling: - for { - t.mux.Lock() - if len(t.callings) == 0 { - t.callings = nil - t.mux.Unlock() - break - } - f := t.callings[0] - t.callings = t.callings[1:] - t.mux.Unlock() - if t.executor != nil { - t.executor(f) - } else { - func() { - defer func() { - err := recover() - if err != nil { - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - logging.Error("Timer[%v] exec call failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf))) - } - }() - f() - }() - } - } - case <-t.trigger.C: - for { - t.mux.Lock() - if t.items.Len() == 0 { - t.trigger.Reset(TimeForever) - t.mux.Unlock() - break - } - now := time.Now() - it := t.items[0] - if now.After(it.expire) { - heap.Remove(&t.items, it.index) - t.mux.Unlock() - if t.executor != nil { - t.executor(it.f) - } else { - func() { - defer func() { - err := recover() - if err != nil { - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - logging.Error("NBIO[%v] exec timer failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf))) - } - }() - it.f() - }() - } - } else { - t.trigger.Reset(it.expire.Sub(now)) - t.mux.Unlock() - break - } - } - case <-t.chClose: - return - } - } -} - -type timerHeap []*Item - -func (h timerHeap) Len() int { return len(h) } -func (h timerHeap) Less(i, j int) bool { return h[i].expire.Before(h[j].expire) } -func (h timerHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] - h[i].index = i - h[j].index = j -} - -func (h *timerHeap) Push(x interface{}) { - *h = append(*h, x.(*Item)) - n := len(*h) - (*h)[n-1].index = n - 1 -} - -func (h *timerHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - old[n-1] = nil // avoid memory leak - *h = old[0 : n-1] - return x -} - -// Item is a heap timer item. -type Item struct { - index int - expire time.Time - f func() - parent *Timer -} - -// Stop stops a timer. -func (it *Item) Stop() { - it.parent.removeTimer(it) -} - -// Reset resets timer. -func (it *Item) Reset(timeout time.Duration) { - it.parent.resetTimer(it, timeout) -} - -// ResetUntil resets timer. -func (it *Item) ResetUntil(t time.Time) { - it.parent.resetTimerUntil(it, t) + }() + f() + }() } diff --git a/timer/timer_group.go b/timer/timer_group.go index f70fb6e3..5f540a70 100644 --- a/timer/timer_group.go +++ b/timer/timer_group.go @@ -39,7 +39,7 @@ func (tg *TimerGroup) After(timeout time.Duration) <-chan time.Time { } // AfterFunc used as time.AfterFunc. -func (tg *TimerGroup) AfterFunc(timeout time.Duration, f func()) *Item { +func (tg *TimerGroup) AfterFunc(timeout time.Duration, f func()) *time.Timer { return tg.NextTimer().AfterFunc(timeout, f) } diff --git a/timer/timer_test.go b/timer/timer_test.go index 84d9fd86..ef04e990 100644 --- a/timer/timer_test.go +++ b/timer/timer_test.go @@ -1,7 +1,6 @@ package timer import ( - "container/heap" "log" "math/rand" "sync" @@ -102,7 +101,7 @@ func testTimerNormalExecMany(tg *TimerGroup, timeout time.Duration) { } func testTimerExecManyRandtime(tg *TimerGroup) { - its := make([]*Item, 100)[0:0] + its := make([]*time.Timer, 100)[0:0] ch5 := make(chan int, 100) for i := 0; i < 100; i++ { n := 500 + rand.Int()%200 @@ -132,32 +131,3 @@ LOOP_RECV: log.Panicf("invalid recved num: %v", recved) } } - -func TestTimerHeap(t *testing.T) { - now := time.Now() - th := make(timerHeap, 0, 10) - for i := 0; i < 100; i++ { - date := now.Add(time.Duration(rand.Int63n(10000)) * time.Second) - heap.Push(&th, &Item{ - expire: date, - }) - } - - last := now - for i := 0; i < 100; i++ { - if len(th) == 0 { - break - } - - item := heap.Pop(&th) - if item == nil { - break - } - cur := item.(*Item) - if cur.expire.After(last) || cur.expire.Equal(last) { - last = cur.expire - continue - } - t.Error("timer error") - } -}