diff --git a/CHANGELOG.md b/CHANGELOG.md index a09f40bbc..e3b5aa721 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305] - Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320] - Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357] +- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] * Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365] diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java index 9208661af..a493b065f 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -11,6 +11,7 @@ import org.opensearch.knn.index.mapper.VectorValidator; import java.util.Map; +import java.util.function.BiFunction; /** * Context a library gives to build one of its indices @@ -47,4 +48,10 @@ public interface KNNLibraryIndexingContext { * @return Get the per dimension processor */ PerDimensionProcessor getPerDimensionProcessor(); + + /** + * + * @return Get function that validates training model parameters + */ + BiFunction getTrainingConfigValidationSetup(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java index f5329fc31..e0286de37 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -13,6 +13,11 @@ import java.util.Collections; import java.util.Map; +import java.util.function.BiFunction; + +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; /** * Simple implementation of {@link KNNLibraryIndexingContext} @@ -52,4 +57,33 @@ public PerDimensionValidator getPerDimensionValidator() { public PerDimensionProcessor getPerDimensionProcessor() { return perDimensionProcessor; } + + @Override + public BiFunction getTrainingConfigValidationSetup() { + return (trainingVectors, knnMethodContext) -> { + + long minTrainingVectorCount = 1000; + TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder(); + + MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER); + + if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST) + && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) { + + int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST)); + int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE)); + minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size)); + } + + if (trainingVectors < minTrainingVectorCount) { + builder.valid(false).minTrainingVectorCount(minTrainingVectorCount); + return builder.build(); + } + + builder.valid(true); + return builder.build(); + }; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java new file mode 100644 index 000000000..0cbe6cad5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * This object provides the output of the validation checks for training model inputs. + * The values in this object need to be dynamically set and calling code needs to handle + * the possibility that the values have not been set. + */ +@Setter +@Getter +@Builder +@AllArgsConstructor +public class TrainingConfigValidationOutput { + private boolean valid; + private long minTrainingVectorCount; +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 78f3769c5..99e0d1940 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -23,12 +23,17 @@ import org.opensearch.common.ValidationException; import org.opensearch.common.inject.Inject; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.TrainingConfigValidationOutput; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; import java.util.Map; +import java.util.function.BiFunction; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER; @@ -134,6 +139,25 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques trainingVectors = trainingModelRequest.getMaximumVectorCount(); } + KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext(); + KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext(); + + KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + + BiFunction validateTrainingConfig = knnLibraryIndexingContext + .getTrainingConfigValidationSetup(); + + TrainingConfigValidationOutput validation = validateTrainingConfig.apply(trainingVectors, knnMethodContext); + if (!validation.isValid()) { + ValidationException exception = new ValidationException(); + exception.addValidationError( + String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount()) + ); + listener.onFailure(exception); + return; + } + listener.onResponse( estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType()) ); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 9906ab490..8d43d3dd9 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -35,6 +35,8 @@ import java.io.IOException; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; + /** * Request to train and serialize a model */ @@ -283,6 +285,15 @@ public ActionRequestValidationException validate() { exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters"); } + // Check if ENCODER_PARAMETER_PQ_M is divisible by vector dimension + if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M) + && knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext() + .getParameters() + .get(ENCODER_PARAMETER_PQ_M) != 0) { + exception = exception == null ? new ActionRequestValidationException() : exception; + exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions"); + } + // Validate training index exists IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex); if (indexMetadata == null) { diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index b479192e8..275aa2f47 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -203,9 +203,7 @@ public void run() { } catch (Exception e) { logger.error("Failed to run training job for model \"" + modelId + "\": ", e); modelMetadata.setState(ModelState.FAILED); - modelMetadata.setError( - "Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training." - ); + modelMetadata.setError("Failed to execute training. " + e.getMessage()); KNNCounter.TRAINING_ERRORS.increment(); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 30c5d33a1..aee45e2cc 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -36,6 +36,7 @@ import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -344,6 +345,55 @@ public void testTrainingIndexSize() { transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); } + public void testTrainingIndexSizeFailure() { + + String trainingIndexName = "training-index"; + int dimension = 133; + int vectorCount = 100; + + // Setup the request + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + null, + getDefaultKNNMethodContextForModel(), + dimension, + trainingIndexName, + "training-field", + null, + "description", + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + // Mock client to return the right number of docs + TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO); + SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + Client client = mock(Client.class); + doAnswer(invocationOnMock -> { + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + // Setup the action + ClusterService clusterService = mock(ClusterService.class); + TransportService transportService = mock(TransportService.class); + TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); + + ActionListener listener = ActionListener.wrap( + size -> size.intValue(), + e -> assertThat(e.getMessage(), containsString("Number of training points should be greater than")) + ); + + transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); + } + public void testTrainIndexSize_whenDataTypeIsBinary() { String trainingIndexName = "training-index"; int dimension = 8; diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 6fd399434..fdffc91d0 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -621,11 +621,61 @@ public void testValidation_invalid_descriptionToLong() { ActionRequestValidationException exception = trainingModelRequest.validate(); assertNotNull(exception); List validationErrors = exception.validationErrors(); - logger.error("Validation errorsa " + validationErrors); + logger.error("Validation errors " + validationErrors); assertEquals(1, validationErrors.size()); assertTrue(validationErrors.get(0).contains("Description exceeds limit")); } + public void testValidation_invalid_mNotDivisibleByDimension() { + + // Setup the training request + String modelId = "test-model-id"; + int dimension = 10; + String trainingIndex = "test-training-index"; + String trainingField = "test-training-field"; + String trainingFieldModeId = "training-field-model-id"; + + Map parameters = Map.of("m", 3); + + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters); + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext); + + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null, + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + // Mock the model dao to return metadata for modelId to recognize it is a duplicate + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(null); + when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata); + + // Cluster service that wont produce validation exception + ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); + + // Initialize static components with the mocks + TrainingModelRequest.initialize(modelDao, clusterService); + + // Test that validation produces m not divisible by vector dimension error message + ActionRequestValidationException exception = trainingModelRequest.validate(); + assertNotNull(exception); + List validationErrors = exception.validationErrors(); + logger.error("Validation errors " + validationErrors); + assertEquals(2, validationErrors.size()); + assertTrue(validationErrors.get(1).contains("Training request ENCODER_PARAMETER_PQ_M")); + } + public void testValidation_valid_trainingIndexBuiltFromMethod() { // This cluster service will result in no validation exceptions diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 4706bd000..8db9d67bc 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -41,6 +41,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -217,7 +218,6 @@ public void testRun_success() throws IOException, ExecutionException { Model model = trainingJob.getModel(); assertNotNull(model); - assertEquals(ModelState.CREATED, model.getModelMetadata().getState()); // Simple test that creates the index from template and doesnt fail @@ -308,6 +308,10 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept Model model = trainingJob.getModel(); assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); + assertThat( + "Failed to load training data into memory. " + "Check if there is enough memory to perform the request.", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); assertNotNull(model); assertFalse(model.getModelMetadata().getError().isEmpty()); } @@ -382,6 +386,10 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce Model model = trainingJob.getModel(); assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); + assertThat( + "Failed to allocate space in native memory for the model. " + "Check if there is enough memory to perform the request.", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); assertNotNull(model); assertFalse(model.getModelMetadata().getError().isEmpty()); } @@ -435,7 +443,7 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep when(nativeMemoryAllocation.isClosed()).thenReturn(true); when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0); - // Throw error on getting data + // Throw error on allocation is closed when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); TrainingJob trainingJob = new TrainingJob( @@ -443,7 +451,83 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep knnMethodContext, nativeMemoryCacheManager, trainingDataEntryContext, - mock(NativeMemoryEntryContext.AnonymousEntryContext.class), + modelContext, + knnMethodConfigContext, + "", + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + trainingJob.run(); + + Model model = trainingJob.getModel(); + assertThat( + "Failed to execute training. Unable to load training data into memory: allocation is already closed", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); + assertNotNull(model); + assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); + } + + public void testRun_failure_closedModelAnonymousAllocation() throws ExecutionException { + // In this test, the model anonymous allocation should be closed. Then, run should fail and update the error of + // the model + String modelId = "test-model-id"; + + // Define the method setup for method that requires training + int nlists = 5; + int dimension = 16; + KNNEngine knnEngine = KNNEngine.FAISS; + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .dimension(dimension) + .versionCreated(Version.CURRENT) + .build(); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); + + String tdataKey = "t-data-key"; + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); + when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); + + // Setup model manager + NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); + + // Setup mock allocation for model that's closed + NativeMemoryAllocation modelAllocation = mock(NativeMemoryAllocation.class); + doAnswer(invocationOnMock -> null).when(modelAllocation).readLock(); + doAnswer(invocationOnMock -> null).when(modelAllocation).readUnlock(); + when(modelAllocation.isClosed()).thenReturn(true); + + String modelKey = "model-test-key"; + NativeMemoryEntryContext.AnonymousEntryContext modelContext = mock(NativeMemoryEntryContext.AnonymousEntryContext.class); + when(modelContext.getKey()).thenReturn(modelKey); + + // Throw error on allocation is closed + when(nativeMemoryCacheManager.get(modelContext, false)).thenReturn(modelAllocation); + doAnswer(invocationOnMock -> null).when(nativeMemoryCacheManager).invalidate(modelKey); + + // Setup mock allocation thats not closed + NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class); + doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readLock(); + doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readUnlock(); + when(nativeMemoryAllocation.isClosed()).thenReturn(false); + when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0); + + when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); + + TrainingJob trainingJob = new TrainingJob( + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + modelContext, knnMethodConfigContext, "", "test-node", @@ -454,6 +538,10 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep trainingJob.run(); Model model = trainingJob.getModel(); + assertThat( + "Failed to execute training. Unable to reserve memory for model: allocation is already closed", + containsString(trainingJob.getModel().getModelMetadata().getError()) + ); assertNotNull(model); assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState()); }