Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Switch classic search to Backend interface. #2109

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
58 changes: 24 additions & 34 deletions src/engine_classic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,16 @@ void EngineClassic::UpdateFromUciOptions() {
const auto network_configuration =
NetworkFactory::BackendConfiguration(options_);
if (network_configuration_ != network_configuration) {
network_ = NetworkFactory::LoadNetwork(options_);
backend_ =
CreateMemCache(BackendManager::Get()->CreateFromParams(options_),
options_.Get<int>(SharedBackendParams::kNNCacheSizeId));
network_configuration_ = network_configuration;
} else {
// Still update the cache size.
backend_->SetCacheCapacity(
options_.Get<int>(SharedBackendParams::kNNCacheSizeId));
}

// Cache size.
cache_.SetCapacity(options_.Get<int>(SharedBackendParams::kNNCacheSizeId));

// Check whether we can update the move timer in "Go".
strict_uci_timing_ = options_.Get<bool>(kStrictUciTiming);
}
Expand All @@ -178,7 +181,7 @@ void EngineClassic::NewGame() {
// newgame and goes straight into go.
ResetMoveTimer();
SharedLock lock(busy_mutex_);
cache_.Clear();
backend_->ClearCache();
search_.reset();
tree_.reset();
CreateFreshTimeManager();
Expand Down Expand Up @@ -265,52 +268,39 @@ class PonderResponseTransformer : public TransformingUciResponder {
std::string ponder_move_;
};

void ValueOnlyGo(classic::NodeTree* tree, Network* network,
const OptionsDict& options,
void ValueOnlyGo(classic::NodeTree* tree, Backend* backend,
std::unique_ptr<UciResponder> responder) {
auto input_format = network->GetCapabilities().input_format;

const auto& board = tree->GetPositionHistory().Last().GetBoard();
auto legal_moves = board.GenerateLegalMoves();
tree->GetCurrentHead()->CreateEdges(legal_moves);
PositionHistory history = tree->GetPositionHistory();
std::vector<InputPlanes> planes;
std::vector<float> comp_q;
comp_q.reserve(legal_moves.size());
auto comp = backend->CreateComputation();
for (auto edge : tree->GetCurrentHead()->Edges()) {
history.Append(edge.GetMove());
if (history.ComputeGameResult() == GameResult::UNDECIDED) {
planes.emplace_back(EncodePositionForNN(
input_format, history, 8, FillEmptyHistory::FEN_ONLY, nullptr));
comp_q.emplace_back();
comp->AddInput(
EvalPosition{
.pos = history.GetPositions(),
.legal_moves = {},
},
EvalResultPtr{.q = &comp_q.back()});
}
history.Pop();
}

std::vector<float> comp_q;
int batch_size = options.Get<int>(classic::SearchParams::kMiniBatchSizeId);
if (batch_size == 0) batch_size = network->GetMiniBatchSize();

for (size_t i = 0; i < planes.size(); i += batch_size) {
auto comp = network->NewComputation();
for (int j = 0; j < batch_size; j++) {
comp->AddInput(std::move(planes[i + j]));
if (i + j + 1 == planes.size()) break;
}
comp->ComputeBlocking();

for (int j = 0; j < batch_size; j++) comp_q.push_back(comp->GetQVal(j));
}

Move best;
int comp_idx = 0;
float max_q = std::numeric_limits<float>::lowest();
for (auto edge : tree->GetCurrentHead()->Edges()) {
for (size_t comp_idx = 0; auto edge : tree->GetCurrentHead()->Edges()) {
history.Append(edge.GetMove());
auto result = history.ComputeGameResult();
float q = -1;
if (result == GameResult::UNDECIDED) {
// NN eval is for side to move perspective - so if its good, its bad for
// us.
q = -comp_q[comp_idx];
comp_idx++;
q = -comp_q[comp_idx++];
} else if (result == GameResult::DRAW) {
q = 0;
} else {
Expand Down Expand Up @@ -375,7 +365,7 @@ void EngineClassic::Go(const GoParams& params) {
responder = std::make_unique<MovesLeftResponseFilter>(std::move(responder));
}
if (options_.Get<bool>(kValueOnly)) {
ValueOnlyGo(tree_.get(), network_.get(), options_, std::move(responder));
ValueOnlyGo(tree_.get(), backend_.get(), std::move(responder));
return;
}

Expand All @@ -385,10 +375,10 @@ void EngineClassic::Go(const GoParams& params) {

auto stopper = time_manager_->GetStopper(params, *tree_.get());
search_ = std::make_unique<classic::Search>(
*tree_, network_.get(), std::move(responder),
*tree_, backend_.get(), std::move(responder),
StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()),
*move_start_time_, std::move(stopper), params.infinite, params.ponder,
options_, &cache_, syzygy_tb_.get());
options_, syzygy_tb_.get());

LOGFILE << "Timer started at "
<< FormatTime(SteadyClockToSystemClock(*move_start_time_));
Expand Down
5 changes: 2 additions & 3 deletions src/engine_classic.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "engine_loop.h"
#include "neural/cache.h"
#include "neural/factory.h"
#include "neural/network.h"
#include "neural/memcache.h"
#include "search/classic/search.h"
#include "syzygy/syzygy.h"
#include "utils/mutex.h"
Expand Down Expand Up @@ -94,8 +94,7 @@ class EngineClassic : public EngineControllerBase {
std::unique_ptr<classic::Search> search_;
std::unique_ptr<classic::NodeTree> tree_;
std::unique_ptr<SyzygyTablebase> syzygy_tb_;
std::unique_ptr<Network> network_;
NNCache cache_;
std::unique_ptr<CachingBackend> backend_;

// Store current TB and network settings to track when they change so that
// they are reloaded.
Expand Down
16 changes: 10 additions & 6 deletions src/neural/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,22 @@ struct BackendAttributes {
int maximum_batch_size;
};

struct EvalResultPtr {
float* q = nullptr;
float* d = nullptr;
float* m = nullptr;
std::span<float> p = {};
};

struct EvalResult {
float q;
float d;
float m;
std::vector<float> p;
};

struct EvalResultPtr {
float* q = nullptr;
float* d = nullptr;
float* m = nullptr;
std::span<float> p;
EvalResultPtr AsPtr() {
return EvalResultPtr{.q = &q, .d = &d, .m = &m, .p = p};
}
};

struct EvalPosition {
Expand Down
42 changes: 29 additions & 13 deletions src/neural/memcache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void CachedValueToEvalResult(const CachedValue& cv, const EvalResultPtr& ptr) {
std::copy(cv.p.get(), cv.p.get() + ptr.p.size(), ptr.p.begin());
}

class MemCache : public Backend {
class MemCache : public CachingBackend {
public:
MemCache(std::unique_ptr<Backend> wrapped, size_t capacity)
: wrapped_backend_(std::move(wrapped)),
Expand All @@ -67,6 +67,11 @@ class MemCache : public Backend {
std::unique_ptr<BackendComputation> CreateComputation() override;
std::optional<EvalResult> GetCachedEvaluation(const EvalPosition&) override;

void ClearCache() override { cache_.Clear(); }
void SetCacheCapacity(size_t capacity) override {
cache_.SetCapacity(capacity);
}

private:
std::unique_ptr<Backend> wrapped_backend_;
HashKeyedCache<CachedValue> cache_;
Expand All @@ -91,23 +96,29 @@ class MemCacheComputation : public BackendComputation {
}
virtual AddInputResult AddInput(const EvalPosition& pos,
EvalResultPtr result) override {
assert(pos.legal_moves.size() == result.p.size() || result.p.empty());
const uint64_t hash = ComputeEvalPositionHash(pos);
{
HashKeyedCacheLock<CachedValue> lock(&memcache_->cache_, hash);
if (lock.holds_value()) {
// Sometimes search queries NN without passing the legal moves. It is
// still cached in this case, but in subsequent queries we only return it
// legal moves are not passed again.
if (lock.holds_value() && (pos.legal_moves.empty() || lock->p)) {
CachedValueToEvalResult(**lock, result);
return AddInputResult::FETCHED_IMMEDIATELY;
}
}
assert(keys_.size() < memcache_->max_batch_size_);
keys_.push_back(hash);
auto value = std::make_unique<CachedValue>();
value->p.reset(new float[result.p.size()]);
auto& value = values_.emplace_back(std::make_unique<CachedValue>());
value->p.reset(pos.legal_moves.empty() ? nullptr
: new float[pos.legal_moves.size()]);
result_ptrs_.push_back(result);
return wrapped_computation_->AddInput(
pos, EvalResultPtr{&value->q,
&value->d,
&value->m,
{value->p.get(), pos.legal_moves.size()}});
pos, EvalResultPtr{&value->q, &value->d, &value->m,
value->p ? std::span<float>{value->p.get(),
pos.legal_moves.size()}
: std::span<float>{}});
}

virtual void ComputeBlocking() override {
Expand All @@ -133,20 +144,25 @@ std::optional<EvalResult> MemCache::GetCachedEvaluation(
const EvalPosition& pos) {
const uint64_t hash = ComputeEvalPositionHash(pos);
HashKeyedCacheLock<CachedValue> lock(&cache_, hash);
if (!lock.holds_value()) return std::nullopt;
if (!lock.holds_value() || (!pos.legal_moves.empty() && !lock->p)) {
return std::nullopt;
}
EvalResult result;
result.d = lock->d;
result.q = lock->q;
result.m = lock->m;
std::copy(lock->p.get(), lock->p.get() + pos.legal_moves.size(),
result.p.begin());
if (lock->p) {
result.p.reserve(pos.legal_moves.size());
std::copy(lock->p.get(), lock->p.get() + pos.legal_moves.size(),
std::back_inserter(result.p));
}
return result;
}

} // namespace

std::unique_ptr<Backend> CreateMemCache(std::unique_ptr<Backend> wrapped,
size_t capacity) {
std::unique_ptr<CachingBackend> CreateMemCache(std::unique_ptr<Backend> wrapped,
size_t capacity) {
return std::make_unique<MemCache>(std::move(wrapped), capacity);
}

Expand Down
12 changes: 10 additions & 2 deletions src/neural/memcache.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,18 @@

namespace lczero {

class CachingBackend : public Backend {
public:
// Clears the cache.
virtual void ClearCache() = 0;
// Sets the maximum number of items in the cache.
virtual void SetCacheCapacity(size_t capacity) = 0;
};

// Creates a caching backend wrapper, which returns values immediately if they
// are found, and forwards the request to the wrapped backend otherwise (and
// caches the result).
std::unique_ptr<Backend> CreateMemCache(std::unique_ptr<Backend> parent,
size_t capacity);
std::unique_ptr<CachingBackend> CreateMemCache(std::unique_ptr<Backend> parent,
size_t capacity);

} // namespace lczero
6 changes: 0 additions & 6 deletions src/search/classic/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,6 @@ const OptionId SearchParams::kMovesLeftQuadraticFactorId{
"moves-left-quadratic-factor", "MovesLeftQuadraticFactor",
"A factor which is multiplied by the square of Q of parent node and the "
"base moves left effect."};
const OptionId SearchParams::kDisplayCacheUsageId{
"display-cache-usage", "DisplayCacheUsage",
"Display cache fullness through UCI info `hash` section."};
const OptionId SearchParams::kMaxConcurrentSearchersId{
"max-concurrent-searchers", "MaxConcurrentSearchers",
"If not 0, at most this many search workers can be gathering minibatches "
Expand Down Expand Up @@ -543,7 +540,6 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<FloatOption>(kMovesLeftScaledFactorId, -2.0f, 2.0f) = 1.6521f;
options->Add<FloatOption>(kMovesLeftQuadraticFactorId, -1.0f, 1.0f) =
-0.6521f;
options->Add<BoolOption>(kDisplayCacheUsageId) = false;
options->Add<IntOption>(kMaxConcurrentSearchersId, 0, 128) = 1;
options->Add<FloatOption>(kDrawScoreId, -1.0f, 1.0f) = 0.0f;
std::vector<std::string> mode = {"play", "white_side_analysis",
Expand Down Expand Up @@ -578,7 +574,6 @@ void SearchParams::Populate(OptionsParser* options) {
options->HideOption(kNoiseEpsilonId);
options->HideOption(kNoiseAlphaId);
options->HideOption(kLogLiveStatsId);
options->HideOption(kDisplayCacheUsageId);
options->HideOption(kRootHasOwnCpuctParamsId);
options->HideOption(kCpuctAtRootId);
options->HideOption(kCpuctBaseAtRootId);
Expand Down Expand Up @@ -643,7 +638,6 @@ SearchParams::SearchParams(const OptionsDict& options)
kMovesLeftScaledFactor(options.Get<float>(kMovesLeftScaledFactorId)),
kMovesLeftQuadraticFactor(
options.Get<float>(kMovesLeftQuadraticFactorId)),
kDisplayCacheUsage(options.Get<bool>(kDisplayCacheUsageId)),
kMaxConcurrentSearchers(options.Get<int>(kMaxConcurrentSearchersId)),
kDrawScore(options.Get<float>(kDrawScoreId)),
kContempt(GetContempt(options.Get<std::string>(kUCIOpponentId),
Expand Down
3 changes: 0 additions & 3 deletions src/search/classic/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ class SearchParams {
float GetMovesLeftQuadraticFactor() const {
return kMovesLeftQuadraticFactor;
}
bool GetDisplayCacheUsage() const { return kDisplayCacheUsage; }
int GetMaxConcurrentSearchers() const { return kMaxConcurrentSearchers; }
float GetDrawScore() const { return kDrawScore; }
ContemptMode GetContemptMode() const {
Expand Down Expand Up @@ -205,7 +204,6 @@ class SearchParams {
static const OptionId kMovesLeftScaledFactorId;
static const OptionId kMovesLeftQuadraticFactorId;
static const OptionId kMovesLeftSlopeId;
static const OptionId kDisplayCacheUsageId;
static const OptionId kMaxConcurrentSearchersId;
static const OptionId kDrawScoreId;
static const OptionId kContemptModeId;
Expand Down Expand Up @@ -271,7 +269,6 @@ class SearchParams {
const float kMovesLeftConstantFactor;
const float kMovesLeftScaledFactor;
const float kMovesLeftQuadraticFactor;
const bool kDisplayCacheUsage;
const int kMaxConcurrentSearchers;
const float kDrawScore;
const float kContempt;
Expand Down
Loading