Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added more detailed error messages for KNN model training #2378

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
Expand All @@ -31,6 +32,9 @@
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;

/**
Expand Down Expand Up @@ -134,6 +138,30 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
trainingVectors = trainingModelRequest.getMaximumVectorCount();
}

long minTrainingVectorCount = 1000;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make the 1000 value a class level constant?

MethodComponentContext encoderContext = (MethodComponentContext) trainingModelRequest.getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (trainingModelRequest.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) trainingModelRequest.getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
ValidationException exception = new ValidationException();
exception.addValidationError("Number of training points should be greater than " + minTrainingVectorCount);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use String.format for concatenation

listener.onFailure(exception);
return;
}

listener.onResponse(
estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType())
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import java.io.IOException;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;

/**
* Request to train and serialize a model
*/
Expand Down Expand Up @@ -283,6 +285,15 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
}

// Check if ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove these checks here now, correct?

&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if that parameter is always present or not, but if it's optional then this line can generate the runtime exception in case parameter is not present

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe Java has short-circuit evaluation, so if containsKey(ENCODER_PARAMETER_PQ_M) returns false then the second expression will not be evaluated. So a runtime exception shouldn't be thrown.

exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
}

// Validate training index exists
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ public void run() {
} catch (Exception e) {
logger.error("Failed to run training job for model \"" + modelId + "\": ", e);
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(
"Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training."
);
modelMetadata.setError("Failed to execute training. " + e.getMessage());

KNNCounter.TRAINING_ERRORS.increment();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -344,6 +345,55 @@ public void testTrainingIndexSize() {
transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener);
}

public void testTrainingIndexSizeFailure() {

String trainingIndexName = "training-index";
int dimension = 133;
int vectorCount = 100;

// Setup the request
TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
null,
getDefaultKNNMethodContextForModel(),
dimension,
trainingIndexName,
"training-field",
null,
"description",
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
CompressionLevel.NOT_CONFIGURED
);

// Mock client to return the right number of docs
TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO);
SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(searchHits);
Client client = mock(Client.class);
doAnswer(invocationOnMock -> {
((ActionListener<SearchResponse>) invocationOnMock.getArguments()[1]).onResponse(searchResponse);
return null;
}).when(client).search(any(), any());

// Setup the action
ClusterService clusterService = mock(ClusterService.class);
TransportService transportService = mock(TransportService.class);
TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction(
transportService,
new ActionFilters(Collections.emptySet()),
clusterService,
client
);

ActionListener<Integer> listener = ActionListener.wrap(
size -> size.intValue(),
e -> assertThat(e.getMessage(), containsString("Number of training points should be greater than"))
);

transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener);
}

public void testTrainIndexSize_whenDataTypeIsBinary() {
String trainingIndexName = "training-index";
int dimension = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,61 @@ public void testValidation_invalid_descriptionToLong() {
ActionRequestValidationException exception = trainingModelRequest.validate();
assertNotNull(exception);
List<String> validationErrors = exception.validationErrors();
logger.error("Validation errorsa " + validationErrors);
logger.error("Validation errors " + validationErrors);
assertEquals(1, validationErrors.size());
assertTrue(validationErrors.get(0).contains("Description exceeds limit"));
}

public void testValidation_invalid_mNotDivisibleByDimension() {

// Setup the training request
String modelId = "test-model-id";
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
String trainingFieldModeId = "training-field-model-id";

Map<String, Object> parameters = Map.of("m", 3);

MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters);
final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext);

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
CompressionLevel.NOT_CONFIGURED
);

// Mock the model dao to return metadata for modelId to recognize it is a duplicate
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);

ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(null);
when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata);

// Cluster service that wont produce validation exception
ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension);

// Initialize static components with the mocks
TrainingModelRequest.initialize(modelDao, clusterService);

// Test that validation produces m not divisible by vector dimension error message
ActionRequestValidationException exception = trainingModelRequest.validate();
assertNotNull(exception);
List<String> validationErrors = exception.validationErrors();
logger.error("Validation errors " + validationErrors);
assertEquals(2, validationErrors.size());
assertTrue(validationErrors.get(1).contains("Training request ENCODER_PARAMETER_PQ_M"));
}

public void testValidation_valid_trainingIndexBuiltFromMethod() {
// This cluster service will result in no validation exceptions

Expand Down
94 changes: 91 additions & 3 deletions src/test/java/org/opensearch/knn/training/TrainingJobTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.UUID;
import java.util.concurrent.ExecutionException;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -217,7 +218,6 @@ public void testRun_success() throws IOException, ExecutionException {

Model model = trainingJob.getModel();
assertNotNull(model);

assertEquals(ModelState.CREATED, model.getModelMetadata().getState());

// Simple test that creates the index from template and doesnt fail
Expand Down Expand Up @@ -308,6 +308,10 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept

Model model = trainingJob.getModel();
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
assertThat(
"Failed to load training data into memory. " + "Check if there is enough memory to perform the request.",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertFalse(model.getModelMetadata().getError().isEmpty());
}
Expand Down Expand Up @@ -382,6 +386,10 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce

Model model = trainingJob.getModel();
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
assertThat(
"Failed to allocate space in native memory for the model. " + "Check if there is enough memory to perform the request.",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertFalse(model.getModelMetadata().getError().isEmpty());
}
Expand Down Expand Up @@ -435,15 +443,91 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
when(nativeMemoryAllocation.isClosed()).thenReturn(true);
when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0);

// Throw error on getting data
// Throw error on allocation is closed
when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);

TrainingJob trainingJob = new TrainingJob(
modelId,
knnMethodContext,
nativeMemoryCacheManager,
trainingDataEntryContext,
mock(NativeMemoryEntryContext.AnonymousEntryContext.class),
modelContext,
knnMethodConfigContext,
"",
"test-node",
Mode.NOT_CONFIGURED,
CompressionLevel.NOT_CONFIGURED
);

trainingJob.run();

Model model = trainingJob.getModel();
assertThat(
"Failed to execute training. Unable to load training data into memory: allocation is already closed",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
}

public void testRun_failure_closedModelAnonymousAllocation() throws ExecutionException {
// In this test, the model anonymous allocation should be closed. Then, run should fail and update the error of
// the model
String modelId = "test-model-id";

// Define the method setup for method that requires training
int nlists = 5;
int dimension = 16;
KNNEngine knnEngine = KNNEngine.FAISS;
KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(VectorDataType.FLOAT)
.dimension(dimension)
.versionCreated(Version.CURRENT)
.build();
KNNMethodContext knnMethodContext = new KNNMethodContext(
knnEngine,
SpaceType.INNER_PRODUCT,
new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))
);

String tdataKey = "t-data-key";
NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock(
NativeMemoryEntryContext.TrainingDataEntryContext.class
);
when(trainingDataEntryContext.getKey()).thenReturn(tdataKey);

// Setup model manager
NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class);

// Setup mock allocation for model that's closed
NativeMemoryAllocation modelAllocation = mock(NativeMemoryAllocation.class);
doAnswer(invocationOnMock -> null).when(modelAllocation).readLock();
doAnswer(invocationOnMock -> null).when(modelAllocation).readUnlock();
when(modelAllocation.isClosed()).thenReturn(true);

String modelKey = "model-test-key";
NativeMemoryEntryContext.AnonymousEntryContext modelContext = mock(NativeMemoryEntryContext.AnonymousEntryContext.class);
when(modelContext.getKey()).thenReturn(modelKey);

// Throw error on allocation is closed
when(nativeMemoryCacheManager.get(modelContext, false)).thenReturn(modelAllocation);
doAnswer(invocationOnMock -> null).when(nativeMemoryCacheManager).invalidate(modelKey);

// Setup mock allocation thats not closed
NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class);
doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readLock();
doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readUnlock();
when(nativeMemoryAllocation.isClosed()).thenReturn(false);
when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0);

when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);

TrainingJob trainingJob = new TrainingJob(
modelId,
knnMethodContext,
nativeMemoryCacheManager,
trainingDataEntryContext,
modelContext,
knnMethodConfigContext,
"",
"test-node",
Expand All @@ -454,6 +538,10 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
trainingJob.run();

Model model = trainingJob.getModel();
assertThat(
"Failed to execute training. Unable to reserve memory for model: allocation is already closed",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
}
Expand Down