diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 3eb2f5a76b..e39a6a4d93 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -360,6 +360,9 @@ void search_neighbors_to_add( float d_entry_point, int level, VisitedTable& vt) { + // selects a version + const bool reference_version = false; + // top is nearest candidate std::priority_queue candidates; @@ -381,26 +384,89 @@ void search_neighbors_to_add( // loop over neighbors size_t begin, end; hnsw.neighbor_range(currNode, level, &begin, &end); - for (size_t i = begin; i < end; i++) { - storage_idx_t nodeId = hnsw.neighbors[i]; - if (nodeId < 0) - break; - if (vt.get(nodeId)) - continue; - vt.set(nodeId); - float dis = qdis(nodeId); - NodeDistFarther evE1(dis, nodeId); + // select a version, based on a flag + if (reference_version) { + // a reference version + for (size_t i = begin; i < end; i++) { + storage_idx_t nodeId = hnsw.neighbors[i]; + if (nodeId < 0) + break; + if (vt.get(nodeId)) + continue; + vt.set(nodeId); + + float dis = qdis(nodeId); + NodeDistFarther evE1(dis, nodeId); + + if (results.size() < hnsw.efConstruction || + results.top().d > dis) { + results.emplace(dis, nodeId); + candidates.emplace(dis, nodeId); + if (results.size() > hnsw.efConstruction) { + results.pop(); + } + } + } + } else { + // a faster version + + // the following version processes 4 neighbors at a time + auto update_with_candidate = [&](const storage_idx_t idx, + const float dis) { + if (results.size() < hnsw.efConstruction || + results.top().d > dis) { + results.emplace(dis, idx); + candidates.emplace(dis, idx); + if (results.size() > hnsw.efConstruction) { + results.pop(); + } + } + }; + + int n_buffered = 0; + storage_idx_t buffered_ids[4]; - if (results.size() < hnsw.efConstruction || results.top().d > dis) { - results.emplace(dis, nodeId); - candidates.emplace(dis, nodeId); - if (results.size() > hnsw.efConstruction) { - results.pop(); + for (size_t j = begin; j < end; j++) { + storage_idx_t nodeId = hnsw.neighbors[j]; + if (nodeId < 0) + break; + if (vt.get(nodeId)) { + continue; + } + vt.set(nodeId); + + buffered_ids[n_buffered] = nodeId; + n_buffered += 1; + + if (n_buffered == 4) { + float dis[4]; + qdis.distances_batch_4( + buffered_ids[0], + buffered_ids[1], + buffered_ids[2], + buffered_ids[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + update_with_candidate(buffered_ids[id4], dis[id4]); + } + + n_buffered = 0; } } + + // process leftovers + for (size_t icnt = 0; icnt < n_buffered; icnt++) { + float dis = qdis(buffered_ids[icnt]); + update_with_candidate(buffered_ids[icnt], dis); + } } } + vt.advance(); } @@ -415,6 +481,9 @@ HNSWStats greedy_update_nearest( int level, storage_idx_t& nearest, float& d_nearest) { + // selects a version + const bool reference_version = false; + HNSWStats stats; for (;;) { @@ -424,14 +493,69 @@ HNSWStats greedy_update_nearest( hnsw.neighbor_range(nearest, level, &begin, &end); size_t ndis = 0; - for (size_t i = begin; i < end; i++, ndis++) { - storage_idx_t v = hnsw.neighbors[i]; - if (v < 0) - break; - float dis = qdis(v); - if (dis < d_nearest) { - nearest = v; - d_nearest = dis; + + // select a version, based on a flag + if (reference_version) { + // a reference version + for (size_t i = begin; i < end; i++) { + storage_idx_t v = hnsw.neighbors[i]; + if (v < 0) + break; + ndis += 1; + float dis = qdis(v); + if (dis < d_nearest) { + nearest = v; + d_nearest = dis; + } + } + } else { + // a faster version + + // the following version processes 4 neighbors at a time + auto update_with_candidate = [&](const storage_idx_t idx, + const float dis) { + if (dis < d_nearest) { + nearest = idx; + d_nearest = dis; + } + }; + + int n_buffered = 0; + storage_idx_t buffered_ids[4]; + + for (size_t j = begin; j < end; j++) { + storage_idx_t v = hnsw.neighbors[j]; + if (v < 0) + break; + ndis += 1; + + buffered_ids[n_buffered] = v; + n_buffered += 1; + + if (n_buffered == 4) { + float dis[4]; + qdis.distances_batch_4( + buffered_ids[0], + buffered_ids[1], + buffered_ids[2], + buffered_ids[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + update_with_candidate(buffered_ids[id4], dis[id4]); + } + + n_buffered = 0; + } + } + + // process leftovers + for (size_t icnt = 0; icnt < n_buffered; icnt++) { + float dis = qdis(buffered_ids[icnt]); + update_with_candidate(buffered_ids[icnt], dis); } } @@ -563,6 +687,9 @@ int search_from_candidates( int level, int nres_in = 0, const SearchParametersHNSW* params = nullptr) { + // selects a version + const bool reference_version = false; + int nres = nres_in; int ndis = 0; @@ -607,87 +734,94 @@ int search_from_candidates( size_t begin, end; hnsw.neighbor_range(v0, level, &begin, &end); - // // baseline version - // for (size_t j = begin; j < end; j++) { - // int v1 = hnsw.neighbors[j]; - // if (v1 < 0) - // break; - // if (vt.get(v1)) { - // continue; - // } - // vt.set(v1); - // ndis++; - // float d = qdis(v1); - // if (!sel || sel->is_member(v1)) { - // if (nres < k) { - // faiss::maxheap_push(++nres, D, I, d, v1); - // } else if (d < D[0]) { - // faiss::maxheap_replace_top(nres, D, I, d, v1); - // } - // } - // candidates.push(v1, d); - // } - - // the following version processes 4 neighbors at a time - size_t jmax = begin; - for (size_t j = begin; j < end; j++) { - int v1 = hnsw.neighbors[j]; - if (v1 < 0) - break; + // select a version, based on a flag + if (reference_version) { + // a reference version + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) + break; + if (vt.get(v1)) { + continue; + } + vt.set(v1); + ndis++; + float d = qdis(v1); + if (!sel || sel->is_member(v1)) { + if (d < threshold) { + if (res.add_result(d, v1)) { + threshold = res.threshold; + nres += 1; + } + } + } - prefetch_L2(vt.visited.data() + v1); - jmax += 1; - } + candidates.push(v1, d); + } + } else { + // a faster version + + // the following version processes 4 neighbors at a time + size_t jmax = begin; + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) + break; + + prefetch_L2(vt.visited.data() + v1); + jmax += 1; + } - int counter = 0; - size_t saved_j[4]; + int counter = 0; + size_t saved_j[4]; - ndis += jmax - begin; - threshold = res.threshold; + ndis += jmax - begin; + threshold = res.threshold; - auto add_to_heap = [&](const size_t idx, const float dis) { - if (!sel || sel->is_member(idx)) { - if (dis < threshold) { - if (res.add_result(dis, idx)) { - threshold = res.threshold; - nres += 1; + auto add_to_heap = [&](const size_t idx, const float dis) { + if (!sel || sel->is_member(idx)) { + if (dis < threshold) { + if (res.add_result(dis, idx)) { + threshold = res.threshold; + nres += 1; + } } } - } - candidates.push(idx, dis); - }; - - for (size_t j = begin; j < jmax; j++) { - int v1 = hnsw.neighbors[j]; - - bool vget = vt.get(v1); - vt.set(v1); - saved_j[counter] = v1; - counter += vget ? 0 : 1; - - if (counter == 4) { - float dis[4]; - qdis.distances_batch_4( - saved_j[0], - saved_j[1], - saved_j[2], - saved_j[3], - dis[0], - dis[1], - dis[2], - dis[3]); - - for (size_t id4 = 0; id4 < 4; id4++) { - add_to_heap(saved_j[id4], dis[id4]); - } + candidates.push(idx, dis); + }; + + for (size_t j = begin; j < jmax; j++) { + int v1 = hnsw.neighbors[j]; + + bool vget = vt.get(v1); + vt.set(v1); + saved_j[counter] = v1; + counter += vget ? 0 : 1; + + if (counter == 4) { + float dis[4]; + qdis.distances_batch_4( + saved_j[0], + saved_j[1], + saved_j[2], + saved_j[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + add_to_heap(saved_j[id4], dis[id4]); + } - counter = 0; + counter = 0; + } } - } - for (size_t icnt = 0; icnt < counter; icnt++) { - float dis = qdis(saved_j[icnt]); - add_to_heap(saved_j[icnt], dis); + for (size_t icnt = 0; icnt < counter; icnt++) { + float dis = qdis(saved_j[icnt]); + add_to_heap(saved_j[icnt], dis); + } } nstep++; @@ -715,6 +849,9 @@ std::priority_queue search_from_candidate_unbounded( int ef, VisitedTable* vt, HNSWStats& stats) { + // selects a version + const bool reference_version = false; + int ndis = 0; std::priority_queue top_candidates; std::priority_queue, std::greater> candidates; @@ -738,92 +875,96 @@ std::priority_queue search_from_candidate_unbounded( size_t begin, end; hnsw.neighbor_range(v0, 0, &begin, &end); - // // baseline version - // for (size_t j = begin; j < end; ++j) { - // int v1 = hnsw.neighbors[j]; - // - // if (v1 < 0) { - // break; - // } - // if (vt->get(v1)) { - // continue; - // } - // - // vt->set(v1); - // - // float d1 = qdis(v1); - // ++ndis; - // - // if (top_candidates.top().first > d1 || - // top_candidates.size() < ef) { - // candidates.emplace(d1, v1); - // top_candidates.emplace(d1, v1); - // - // if (top_candidates.size() > ef) { - // top_candidates.pop(); - // } - // } - // } - - // the following version processes 4 neighbors at a time - size_t jmax = begin; - for (size_t j = begin; j < end; j++) { - int v1 = hnsw.neighbors[j]; - if (v1 < 0) - break; + if (reference_version) { + // reference version + for (size_t j = begin; j < end; ++j) { + int v1 = hnsw.neighbors[j]; - prefetch_L2(vt->visited.data() + v1); - jmax += 1; - } + if (v1 < 0) { + break; + } + if (vt->get(v1)) { + continue; + } - int counter = 0; - size_t saved_j[4]; + vt->set(v1); - ndis += jmax - begin; + float d1 = qdis(v1); + ++ndis; - auto add_to_heap = [&](const size_t idx, const float dis) { - if (top_candidates.top().first > dis || - top_candidates.size() < ef) { - candidates.emplace(dis, idx); - top_candidates.emplace(dis, idx); + if (top_candidates.top().first > d1 || + top_candidates.size() < ef) { + candidates.emplace(d1, v1); + top_candidates.emplace(d1, v1); - if (top_candidates.size() > ef) { - top_candidates.pop(); + if (top_candidates.size() > ef) { + top_candidates.pop(); + } } } - }; - - for (size_t j = begin; j < jmax; j++) { - int v1 = hnsw.neighbors[j]; - - bool vget = vt->get(v1); - vt->set(v1); - saved_j[counter] = v1; - counter += vget ? 0 : 1; - - if (counter == 4) { - float dis[4]; - qdis.distances_batch_4( - saved_j[0], - saved_j[1], - saved_j[2], - saved_j[3], - dis[0], - dis[1], - dis[2], - dis[3]); - - for (size_t id4 = 0; id4 < 4; id4++) { - add_to_heap(saved_j[id4], dis[id4]); + } else { + // a faster version + + // the following version processes 4 neighbors at a time + size_t jmax = begin; + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) + break; + + prefetch_L2(vt->visited.data() + v1); + jmax += 1; + } + + int counter = 0; + size_t saved_j[4]; + + ndis += jmax - begin; + + auto add_to_heap = [&](const size_t idx, const float dis) { + if (top_candidates.top().first > dis || + top_candidates.size() < ef) { + candidates.emplace(dis, idx); + top_candidates.emplace(dis, idx); + + if (top_candidates.size() > ef) { + top_candidates.pop(); + } } + }; + + for (size_t j = begin; j < jmax; j++) { + int v1 = hnsw.neighbors[j]; + + bool vget = vt->get(v1); + vt->set(v1); + saved_j[counter] = v1; + counter += vget ? 0 : 1; + + if (counter == 4) { + float dis[4]; + qdis.distances_batch_4( + saved_j[0], + saved_j[1], + saved_j[2], + saved_j[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + add_to_heap(saved_j[id4], dis[id4]); + } - counter = 0; + counter = 0; + } } - } - for (size_t icnt = 0; icnt < counter; icnt++) { - float dis = qdis(saved_j[icnt]); - add_to_heap(saved_j[icnt], dis); + for (size_t icnt = 0; icnt < counter; icnt++) { + float dis = qdis(saved_j[icnt]); + add_to_heap(saved_j[icnt], dis); + } } stats.nhops += 1;