diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 43f4d7ad6..2c110fb79 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -24,7 +24,7 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index ed0e8149a..886c6d93d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -24,7 +24,7 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.indices.Model; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java index 8fec1af6d..bebe9e8b0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -7,7 +7,7 @@ import lombok.experimental.UtilityClass; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java similarity index 94% rename from src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java rename to src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java index 4cf68d16c..d880a4178 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.quantizationService; +package org.opensearch.knn.index.quantizationservice; import lombok.extern.log4j.Log4j2; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; @@ -49,7 +49,7 @@ public T getVectorAtThePosition(int position) throws IOException { } knnVectorValues.nextDoc(); } - // Return the vector and the updated index + // Return the vector return knnVectorValues.getVector(); } } diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java similarity index 99% rename from src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java rename to src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index a9e3cc715..b1c94f993 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.quantizationService; +package org.opensearch.knn.index.quantizationservice; import lombok.AccessLevel; import lombok.NoArgsConstructor; diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 79ce7b955..dbef3a72a 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -150,7 +150,14 @@ public int getDimensions() { if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) { throw new IllegalStateException("Error in getting Dimension: The thresholds array is not initialized."); } - return thresholds.length * thresholds[0].length; + int originalDimensions = thresholds[0].length; + + // Align the original dimensions to the next multiple of 8 for each bit level + int alignedDimensions = (originalDimensions + 7) & ~7; + + // The final dimension count should consider the bit levels + return thresholds.length * alignedDimensions; + } /** diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 9c4ff7460..0a8c33771 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -123,7 +123,8 @@ public int getBytesPerVector() { @Override public int getDimensions() { // For one-bit quantization, the dimension for indexing is just the length of the thresholds array. - return meanThresholds.length; + // Align the original dimensions to the next multiple of 8 + return (meanThresholds.length + 7) & ~7; } /** diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 0b5a06dfc..1a8a832aa 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -17,7 +17,7 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 3bfec4104..81d490bb4 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -16,7 +16,7 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java index 30a2098dd..61d3d7589 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java @@ -9,7 +9,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; diff --git a/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java similarity index 99% rename from src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java rename to src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 886dbeabc..720b67fd5 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.quantizationService; +package org.opensearch.knn.index.quantizationservice; import org.opensearch.knn.KNNTestCase; import org.junit.Before; diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 298256127..4fd4f40a6 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -100,6 +100,78 @@ public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException { assertEquals(expectedRamBytesUsed, actualRamBytesUsed); } + public void testMultiBitScalarQuantizationStateGetDimensions_withDimensionNotMultipleOf8_thenSuccess() { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + + // Case 1: 3 thresholds, each with 2 dimensions + float[][] thresholds1 = { { 0.5f, 1.5f }, { 1.0f, 2.0f }, { 1.5f, 2.5f } }; + MultiBitScalarQuantizationState state1 = new MultiBitScalarQuantizationState(params, thresholds1); + int expectedDimensions1 = 24; // The next multiple of 8 considering all bits + assertEquals(expectedDimensions1, state1.getDimensions()); + + // Case 2: 1 threshold, with 5 dimensions (5 bits, should align to 8) + float[][] thresholds2 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f } }; + MultiBitScalarQuantizationState state2 = new MultiBitScalarQuantizationState(params, thresholds2); + int expectedDimensions2 = 8; // The next multiple of 8 considering all bits + assertEquals(expectedDimensions2, state2.getDimensions()); + + // Case 3: 4 thresholds, each with 7 dimensions (28 bits, should align to 32) + float[][] thresholds3 = { + { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f }, + { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f }, + { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, + { 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } }; + MultiBitScalarQuantizationState state3 = new MultiBitScalarQuantizationState(params, thresholds3); + int expectedDimensions3 = 32; // The next multiple of 8 considering all bits + assertEquals(expectedDimensions3, state3.getDimensions()); + + // Case 4: 2 thresholds, each with 8 dimensions (16 bits, already aligned) + float[][] thresholds4 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } }; + MultiBitScalarQuantizationState state4 = new MultiBitScalarQuantizationState(params, thresholds4); + int expectedDimensions4 = 16; // Already aligned to 8 + assertEquals(expectedDimensions4, state4.getDimensions()); + + // Case 5: 2 thresholds, each with 6 dimensions (12 bits, should align to 16) + float[][] thresholds5 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f } }; + MultiBitScalarQuantizationState state5 = new MultiBitScalarQuantizationState(params, thresholds5); + int expectedDimensions5 = 16; // The next multiple of 8 considering all bits + assertEquals(expectedDimensions5, state5.getDimensions()); + } + + public void testOneBitScalarQuantizationStateGetDimensions_withDimensionNotMultipleOf8_thenSuccess() { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + + // Case 1: 5 dimensions (should align to 8) + float[] thresholds1 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f }; + OneBitScalarQuantizationState state1 = new OneBitScalarQuantizationState(params, thresholds1); + int expectedDimensions1 = 8; // The next multiple of 8 + assertEquals(expectedDimensions1, state1.getDimensions()); + + // Case 2: 7 dimensions (should align to 8) + float[] thresholds2 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f }; + OneBitScalarQuantizationState state2 = new OneBitScalarQuantizationState(params, thresholds2); + int expectedDimensions2 = 8; // The next multiple of 8 + assertEquals(expectedDimensions2, state2.getDimensions()); + + // Case 3: 8 dimensions (already aligned to 8) + float[] thresholds3 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }; + OneBitScalarQuantizationState state3 = new OneBitScalarQuantizationState(params, thresholds3); + int expectedDimensions3 = 8; // Already aligned to 8 + assertEquals(expectedDimensions3, state3.getDimensions()); + + // Case 4: 10 dimensions (should align to 16) + float[] thresholds4 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f }; + OneBitScalarQuantizationState state4 = new OneBitScalarQuantizationState(params, thresholds4); + int expectedDimensions4 = 16; // The next multiple of 8 + assertEquals(expectedDimensions4, state4.getDimensions()); + + // Case 5: 16 dimensions (already aligned to 16) + float[] thresholds5 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, 12.5f, 13.5f, 14.5f, 15.5f }; + OneBitScalarQuantizationState state5 = new OneBitScalarQuantizationState(params, thresholds5); + int expectedDimensions5 = 16; // Already aligned to 16 + assertEquals(expectedDimensions5, state5.getDimensions()); + } + public void testMultiBitScalarQuantizationStateRamBytesUsedManualCalculation() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } };