From f429d4a801465b1862b7da724624d16be31533ef Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 8 Jun 2018 16:18:07 +0100 Subject: [PATCH] max-length per sentence in batch --- src/amun/common/history.h | 3 +++ src/amun/common/search.cpp | 12 +++++++++--- src/amun/common/search.h | 3 ++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/amun/common/history.h b/src/amun/common/history.h index 62d44ae03..2878cd424 100644 --- a/src/amun/common/history.h +++ b/src/amun/common/history.h @@ -46,6 +46,9 @@ class History { unsigned GetLineNum() const { return lineNo_; } + unsigned GetMaxLength() const + { return maxLength_; } + void SetActive(bool active); bool GetActive() const; diff --git a/src/amun/common/search.cpp b/src/amun/common/search.cpp index 606b79118..e7583c35d 100644 --- a/src/amun/common/search.cpp +++ b/src/amun/common/search.cpp @@ -100,7 +100,7 @@ std::shared_ptr Search::Translate(const Sentences& sentences) { } //cerr << "beamSizes=" << Debug(beamSizes, 1) << endl; - bool hasSurvivors = CalcBeam(histories, beamSizes, prevHyps, states, nextStates); + bool hasSurvivors = CalcBeam(histories, beamSizes, prevHyps, states, nextStates, decoderStep); if (!hasSurvivors) { break; } @@ -134,18 +134,24 @@ bool Search::CalcBeam( std::vector& beamSizes, Beam& prevHyps, States& states, - States& nextStates) + States& nextStates, + unsigned decoderStep) { unsigned batchSize = beamSizes.size(); Beams beams(batchSize); bestHyps_->CalcBeam(prevHyps, scorers_, filterIndices_, beams, beamSizes); histories->Add(beams); + //cerr << "batchSize=" << batchSize << endl; histories->SetActive(false); Beam survivors; for (unsigned batchId = 0; batchId < batchSize; ++batchId) { + const History &hist = *histories->at(batchId); + unsigned maxLength = hist.GetMaxLength(); + + //cerr << "beamSizes[batchId]=" << batchId << " " << beamSizes[batchId] << " " << maxLength << endl; for (auto& h : beams[batchId]) { - if (h->GetWord() != EOS_ID) { + if (decoderStep < maxLength && h->GetWord() != EOS_ID) { survivors.push_back(h); histories->SetActive(batchId, true); diff --git a/src/amun/common/search.h b/src/amun/common/search.h index 159ede63c..81c3d807d 100644 --- a/src/amun/common/search.h +++ b/src/amun/common/search.h @@ -30,7 +30,8 @@ class Search { std::vector& beamSizes, Beam& prevHyps, States& states, - States& nextStates); + States& nextStates, + unsigned decoderStep); Search(const Search&) = delete;