diff --git a/datachannel.go b/datachannel.go index 5fef7900174..52de821b7e5 100644 --- a/datachannel.go +++ b/datachannel.go @@ -320,8 +320,13 @@ func (d *DataChannel) onMessage(msg DataChannelMessage) { func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) { d.mu.Lock() - if d.isGracefulClosed { + if d.isGracefulClosed { // The channel was closed during the connecting state d.mu.Unlock() + if err := dc.Close(); err != nil { + d.log.Errorf("Failed to close DataChannel that was closed during connecting state %v", err.Error()) + } + d.onClose() + return } d.dataChannel = dc diff --git a/datachannel_go_test.go b/datachannel_go_test.go index ccce0ea0a48..88d9fb36986 100644 --- a/datachannel_go_test.go +++ b/datachannel_go_test.go @@ -745,3 +745,80 @@ func TestDetachRemovesDatachannelReference(t *testing.T) { } } } + +func TestDataChannelClose(t *testing.T) { + // Test if onClose is fired for self and remote after Close is called + t.Run("close open channels", func(t *testing.T) { + options := &DataChannelInit{} + + offerPC, answerPC, dc, done := setUpDataChannelParametersTest(t, options) + + answerPC.OnDataChannel(func(dataChannel *DataChannel) { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if dataChannel.Label() != expectedLabel { + return + } + + dataChannel.OnOpen(func() { + assert.NoError(t, dataChannel.Close()) + }) + + dataChannel.OnClose(func() { + done <- true + }) + }) + + dc.OnClose(func() { + done <- true + }) + + assert.NoError(t, signalPair(offerPC, answerPC)) + + // Offer and Answer OnClose + <-done + <-done + + assert.NoError(t, offerPC.Close()) + assert.NoError(t, answerPC.Close()) + }) + + // Test if OnClose is fired for self and remote after Close is called on non-established channel + // https://github.com/pion/webrtc/issues/2659 + t.Run("Close connecting channels", func(t *testing.T) { + options := &DataChannelInit{} + + offerPC, answerPC, dc, done := setUpDataChannelParametersTest(t, options) + + answerPC.OnDataChannel(func(dataChannel *DataChannel) { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if dataChannel.Label() != expectedLabel { + return + } + + dataChannel.OnOpen(func() { + t.Fatal("OnOpen must not be fired after we call Close") + }) + + dataChannel.OnClose(func() { + done <- true + }) + + assert.NoError(t, dataChannel.Close()) + }) + + dc.OnClose(func() { + done <- true + }) + + assert.NoError(t, signalPair(offerPC, answerPC)) + + // Offer and Answer OnClose + <-done + <-done + + assert.NoError(t, offerPC.Close()) + assert.NoError(t, answerPC.Close()) + }) +}