From ff6eb7f9962ec13925445104c3e9d11388aa0a5a Mon Sep 17 00:00:00 2001 From: Yunxiang Huang Date: Wed, 31 Jan 2018 03:21:50 +0800 Subject: [PATCH] Fix re-auth hang. (#181) * Add test for recurring re-authing hang * Fix re-authing hang --- zk/conn.go | 45 +++++++++++++++++++++++++++++++++++----- zk/conn_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 zk/conn_test.go diff --git a/zk/conn.go b/zk/conn.go index 13fb0f0c..f79a51b3 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -101,6 +101,9 @@ type Conn struct { reconnectLatch chan struct{} setWatchLimit int setWatchCallback func([]*setWatchesRequest) + // Debug (for recurring re-auth hang) + debugCloseRecvLoop bool + debugReauthDone chan struct{} logger Logger logInfo bool // true if information messages are logged; false if only errors are logged @@ -301,9 +304,9 @@ func WithMaxBufferSize(maxBufferSize int) connOption { // to a limit of 1mb. This option should be used for non-standard server setup // where znode is bigger than default 1mb. func WithMaxConnBufferSize(maxBufferSize int) connOption { - return func(c *Conn) { - c.buf = make([]byte, maxBufferSize) - } + return func(c *Conn) { + c.buf = make([]byte, maxBufferSize) + } } func (c *Conn) Close() { @@ -389,6 +392,17 @@ func (c *Conn) connect() error { } func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) { + shouldCancel := func() bool { + select { + case <-c.shouldQuit: + return true + case <-c.closeChan: + return true + default: + return false + } + } + c.credsMu.Lock() defer c.credsMu.Unlock() @@ -400,6 +414,10 @@ func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) { } for _, cred := range c.creds { + if shouldCancel() { + c.logger.Printf("Cancel rer-submitting credentials") + return + } resChan, err := c.sendRequest( opSetAuth, &setAuthRequest{Type: 0, @@ -415,7 +433,16 @@ func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) { continue } - res := <-resChan + var res response + select { + case res = <-resChan: + case <-c.closeChan: + c.logger.Printf("Recv closed, cancel re-submitting credentials") + return + case <-c.shouldQuit: + c.logger.Printf("Should quit, cancel re-submitting credentials") + return + } if res.err != nil { c.logger.Printf("Credential re-submit failed: %s", res.err) // FIXME(prozlach): lets ignore errors for now @@ -476,6 +503,9 @@ func (c *Conn) loop() { wg.Add(1) go func() { <-reauthChan + if c.debugCloseRecvLoop { + close(c.debugReauthDone) + } err := c.sendLoop() if err != nil || c.logInfo { c.logger.Printf("Send loop terminated: err=%v", err) @@ -486,7 +516,12 @@ func (c *Conn) loop() { wg.Add(1) go func() { - err := c.recvLoop(c.conn) + var err error + if c.debugCloseRecvLoop { + err = errors.New("DEBUG: close recv loop") + } else { + err = c.recvLoop(c.conn) + } if err != io.EOF || c.logInfo { c.logger.Printf("Recv loop terminated: err=%v", err) } diff --git a/zk/conn_test.go b/zk/conn_test.go new file mode 100644 index 00000000..ed4a7706 --- /dev/null +++ b/zk/conn_test.go @@ -0,0 +1,55 @@ +package zk + +import ( + "io/ioutil" + "testing" + "time" +) + +func TestRecurringReAuthHang(t *testing.T) { + sessionTimeout := 2 * time.Second + + finish := make(chan struct{}) + defer close(finish) + go func() { + select { + case <-finish: + return + case <-time.After(5 * sessionTimeout): + panic("expected not hang") + } + }() + + zkC, err := StartTestCluster(2, ioutil.Discard, ioutil.Discard) + if err != nil { + panic(err) + } + defer zkC.Stop() + + conn, evtC, err := zkC.ConnectAll() + if err != nil { + panic(err) + } + for conn.state != StateHasSession { + time.Sleep(50 * time.Millisecond) + } + + go func() { + for range evtC { + } + }() + + // Add auth. + conn.creds = append(conn.creds, authCreds{"digest", []byte("test:test")}) + + currentServer := conn.server + conn.debugCloseRecvLoop = true + conn.debugReauthDone = make(chan struct{}) + zkC.StopServer(currentServer) + // wait connect to new zookeeper. + for conn.server == currentServer && conn.state != StateHasSession { + time.Sleep(100 * time.Millisecond) + } + + <-conn.debugReauthDone +}