From 8c280710fb701f8e4d64307defff3fa754f9dd46 Mon Sep 17 00:00:00 2001 From: OrlandoCo Date: Wed, 28 Oct 2020 13:58:31 -0600 Subject: [PATCH] Fix sub transceiver direction and rate limit pli (#264) * Fix sub transceiver direction and rate limit pli * Fix tests --- pkg/router.go | 23 ++++++++++------------- pkg/simplesender.go | 7 +++++-- pkg/simplesender_test.go | 11 +++-------- pkg/simulcastsender.go | 19 ++++++++++++------- pkg/simulcastsender_test.go | 7 ------- 5 files changed, 30 insertions(+), 37 deletions(-) diff --git a/pkg/router.go b/pkg/router.go index 654cba975..04606cec8 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -188,29 +188,26 @@ func (r *router) addSender(p *WebRTCTransport, rr *receiverRouter) error { return err } // Create webrtc sender for the peer we are sending track to - s, err := p.pc.AddTrack(outTrack) + t, err := p.pc.AddTransceiverFromTrack(outTrack, webrtc.RtpTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionSendonly, + }) if err != nil { return err } if rr.kind == SimulcastReceiver { - sender = NewSimulcastSender(p.ctx, p.id, rr, s, recv.SpatialLayer(), r.config.Simulcast) + sender = NewSimulcastSender(p.ctx, p.id, rr, t.Sender(), recv.SpatialLayer(), r.config.Simulcast) } else { - sender = NewSimpleSender(p.ctx, p.id, rr, s) + sender = NewSimpleSender(p.ctx, p.id, rr, t.Sender()) } sender.OnCloseHandler(func() { - if err := p.pc.RemoveTrack(s); err != nil { + if err := p.pc.RemoveTrack(t.Sender()); err != nil { log.Errorf("Error closing sender: %s", err) } }) - for _, t := range p.pc.GetTransceivers() { - if t.Sender() != nil && t.Sender().Track().SSRC() == ssrc { - p.pendingSenders.PushBack(&pendingSender{ - transceiver: t, - sender: sender, - }) - break - } - } + p.pendingSenders.PushBack(&pendingSender{ + transceiver: t, + sender: sender, + }) p.AddSender(rr.stream, sender) recv.AddSender(sender) return nil diff --git a/pkg/simplesender.go b/pkg/simplesender.go index f6cb1bfd1..7922fb7b4 100644 --- a/pkg/simplesender.go +++ b/pkg/simplesender.go @@ -64,6 +64,7 @@ func (s *SimpleSender) ID() string { func (s *SimpleSender) Start() { s.start.Do(func() { log.Debugf("starting sender %s with ssrc %d", s.id, s.track.SSRC()) + s.reSync.set(true) s.enabled.set(true) }) } @@ -216,8 +217,10 @@ func (s *SimpleSender) receiveRTCP() { for _, pkt := range pkts { switch pkt := pkt.(type) { case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest: - fwdPkts = append(fwdPkts, pkt) - s.lastPli = time.Now() + if !s.reSync.get() && s.enabled.get() && time.Now().Sub(s.lastPli) > time.Second { + fwdPkts = append(fwdPkts, pkt) + s.lastPli = time.Now() + } case *rtcp.TransportLayerNack: log.Tracef("sender got nack: %+v", pkt) for _, pair := range pkt.Nacks { diff --git a/pkg/simplesender_test.go b/pkg/simplesender_test.go index 1ebfb116d..692160a3e 100644 --- a/pkg/simplesender_test.go +++ b/pkg/simplesender_test.go @@ -189,13 +189,6 @@ forLoop: MediaSSRC: 1234, }, }, - { - name: "Sender must forward FIR messages", - want: &rtcp.FullIntraRequest{ - SenderSSRC: 1234, - MediaSSRC: 1234, - }, - }, } for _, tt := range tests { tt := tt @@ -212,12 +205,14 @@ forLoop: sender: s, track: senderTrack, } + wss.enabled.set(true) + wss.lastPli = time.Now().Add(-5 * time.Second) go wss.receiveRTCP() tmr := time.NewTimer(5000 * time.Millisecond) testLoop: for { select { - case <-time.After(20 * time.Millisecond): + case <-time.After(10 * time.Millisecond): err := remote.WriteRTCP([]rtcp.Packet{tt.want, tt.want, tt.want, tt.want}) assert.NoError(t, err) case <-tmr.C: diff --git a/pkg/simulcastsender.go b/pkg/simulcastsender.go index 286196306..b5ff4d989 100644 --- a/pkg/simulcastsender.go +++ b/pkg/simulcastsender.go @@ -294,14 +294,19 @@ func (s *SimulcastSender) receiveRTCP() { for _, pkt := range pkts { switch pkt := pkt.(type) { case *rtcp.PictureLossIndication: - pkt.MediaSSRC = s.lSSRC - pkt.SenderSSRC = s.lSSRC - fwdPkts = append(fwdPkts, pkt) - s.lastPli = time.Now() + if s.enabled.get() && time.Now().Sub(s.lastPli) > time.Second { + pkt.MediaSSRC = s.lSSRC + pkt.SenderSSRC = s.lSSRC + fwdPkts = append(fwdPkts, pkt) + s.lastPli = time.Now() + } case *rtcp.FullIntraRequest: - pkt.MediaSSRC = s.lSSRC - pkt.SenderSSRC = s.lSSRC - fwdPkts = append(fwdPkts, pkt) + if s.enabled.get() && time.Now().Sub(s.lastPli) > time.Second { + pkt.MediaSSRC = s.lSSRC + pkt.SenderSSRC = s.lSSRC + fwdPkts = append(fwdPkts, pkt) + s.lastPli = time.Now() + } case *rtcp.TransportLayerNack: log.Tracef("sender got nack: %+v", pkt) for _, pair := range pkt.Nacks { diff --git a/pkg/simulcastsender_test.go b/pkg/simulcastsender_test.go index a8e7572d4..688ee03e8 100644 --- a/pkg/simulcastsender_test.go +++ b/pkg/simulcastsender_test.go @@ -259,13 +259,6 @@ forLoop: MediaSSRC: simulcastSSRC, }, }, - { - name: "Sender must forward FIR messages, with correct SSRC", - want: &rtcp.FullIntraRequest{ - SenderSSRC: simulcastSSRC, - MediaSSRC: simulcastSSRC, - }, - }, } for _, tt := range tests { tt := tt