From b1bd801d44c39eb4db1e2828a84797dca0bd258d Mon Sep 17 00:00:00 2001 From: Volodymyr Shkolka Date: Mon, 22 Apr 2024 15:38:20 +0300 Subject: [PATCH] Array fields support (#81) Fixes #77 --- Milvus.Client.Tests/FieldTests.cs | 137 ++++++++++++++++++ Milvus.Client.sln.DotSettings | 3 + Milvus.Client/ArrayFieldData.cs | 145 +++++++++++++++++++ Milvus.Client/Constants.cs | 5 + Milvus.Client/FieldData.cs | 96 ++++++++---- Milvus.Client/FieldSchema.cs | 48 ++++++ Milvus.Client/MilvusClient.Collection.cs | 10 ++ Milvus.Client/MilvusCollection.Collection.cs | 17 ++- 8 files changed, 428 insertions(+), 33 deletions(-) create mode 100644 Milvus.Client/ArrayFieldData.cs diff --git a/Milvus.Client.Tests/FieldTests.cs b/Milvus.Client.Tests/FieldTests.cs index 4afa717..4bae3f8 100644 --- a/Milvus.Client.Tests/FieldTests.cs +++ b/Milvus.Client.Tests/FieldTests.cs @@ -67,4 +67,141 @@ public void CreateFloatVectorTest() Assert.Equal(2, field.RowCount); Assert.Equal(2, field.Data.Count); } + + [Fact] + public void CreateInt8ArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new sbyte[][] + { + [1, 2], + [3, 4], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Int8, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + [Fact] + public void CreateInt16ArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new short[][] + { + [1, 2], + [3, 4], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Int16, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + [Fact] + public void CreateInt32ArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new int[][] + { + [1, 2], + [3, 4], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Int32, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + [Fact] + public void CreateInt64ArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new long[][] + { + [1, 2], + [3, 4], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Int64, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + [Fact] + public void CreateBoolArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new bool[][] + { + [true, false], + [false, false], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Bool, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + [Fact] + public void CreateFloatArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new float[][] + { + [1, 2], + [3, 4], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Float, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + [Fact] + public void CreateDoubleArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new double[][] + { + [1, 2], + [3, 4], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.Double, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } + + //TODO: differentiate VarChar and String somehow + [Fact] + public void CreateVarCharArrayTest() + { + var field = FieldData.CreateArray( + "vector", + new string[][] + { + ["3d4d387208e04a9abe77be65e2b7c7b3", "a5502ddb557047968a70ff69720d2dd2"], + ["4c246789a91f4b15aa3b26799df61457", "00a23e95823b4f14854ceed5f7059953"], + }); + + Assert.Equal(MilvusDataType.Array, field.DataType); + Assert.Equal(MilvusDataType.VarChar, field.ElementType); + Assert.Equal(2, field.RowCount); + Assert.Equal(2, field.Data.Count); + } } diff --git a/Milvus.Client.sln.DotSettings b/Milvus.Client.sln.DotSettings index b7ffb9a..d186ee4 100644 --- a/Milvus.Client.sln.DotSettings +++ b/Milvus.Client.sln.DotSettings @@ -1,9 +1,12 @@  + 1 + True True True True True True + True True True True diff --git a/Milvus.Client/ArrayFieldData.cs b/Milvus.Client/ArrayFieldData.cs new file mode 100644 index 0000000..3874588 --- /dev/null +++ b/Milvus.Client/ArrayFieldData.cs @@ -0,0 +1,145 @@ +namespace Milvus.Client; + +/// +/// Binary Field +/// +public sealed class ArrayFieldData : FieldData> +{ + /// + /// Construct an array field + /// + /// + /// + /// + public ArrayFieldData(string fieldName, IReadOnlyList> data, bool isDynamic) + : base(fieldName, data, MilvusDataType.Array, isDynamic) + { + ElementType = EnsureDataType(); + } + + /// + /// Array element type + /// + public MilvusDataType ElementType { get; } + + /// + internal override Grpc.FieldData ToGrpcFieldData() + { + Check(); + + Grpc.FieldData fieldData = new() + { + Type = Grpc.DataType.Array, + IsDynamic = IsDynamic + }; + + if (FieldName is not null) + { + fieldData.FieldName = FieldName; + } + + var arrayArray = new ArrayArray + { + ElementType = (DataType) ElementType, + }; + + fieldData.Scalars = new ScalarField + { + ArrayData = arrayArray + }; + + foreach (var array in Data) + { + switch (ElementType) + { + case MilvusDataType.Bool: + BoolArray boolData = new(); + boolData.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { BoolData = boolData, }); + break; + + case MilvusDataType.Int8: + IntArray int8Data = new(); + var sbytes = array as IEnumerable ?? Enumerable.Empty(); + int8Data.Data.AddRange(sbytes.Select(x => (int) x)); + arrayArray.Data.Add(new ScalarField { IntData = int8Data, }); + break; + + case MilvusDataType.Int16: + IntArray int16Data = new(); + var shorts = array as IEnumerable ?? Enumerable.Empty(); + int16Data.Data.AddRange(shorts.Select(x => (int) x)); + arrayArray.Data.Add(new ScalarField { IntData = int16Data, }); + break; + + case MilvusDataType.Int32: + IntArray int32Data = new(); + int32Data.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { IntData = int32Data, }); + break; + + case MilvusDataType.Int64: + LongArray int64Data = new(); + int64Data.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { LongData = int64Data, }); + break; + + case MilvusDataType.Float: + FloatArray floatData = new(); + floatData.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { FloatData = floatData, }); + break; + + case MilvusDataType.Double: + DoubleArray doubleData = new(); + doubleData.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { DoubleData = doubleData, }); + break; + + case MilvusDataType.String: + StringArray stringData = new(); + stringData.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { StringData = stringData, }); + break; + + case MilvusDataType.VarChar: + StringArray varcharData = new(); + varcharData.Data.AddRange(array as IEnumerable); + arrayArray.Data.Add(new ScalarField { StringData = varcharData, }); + break; + + case MilvusDataType.Json: + JSONArray jsonData = new(); + var enumerable = array as IEnumerable ?? Enumerable.Empty(); + jsonData.Data.AddRange(enumerable.Select(ByteString.CopyFromUtf8)); + arrayArray.Data.Add(new ScalarField { JsonData = jsonData, }); + break; + + case MilvusDataType.None: + throw new MilvusException($"ElementType Error:{DataType}"); + + default: + throw new MilvusException($"ElementType Error:{DataType}, not supported"); + } + } + + return fieldData; + } + + internal override object GetValueAsObject(int index) + => ElementType switch + { + MilvusDataType.Bool => ((IReadOnlyList>) Data)[index], + MilvusDataType.Int8 => ((IReadOnlyList>) Data)[index], + MilvusDataType.Int16 => ((IReadOnlyList>) Data)[index], + MilvusDataType.Int32 => ((IReadOnlyList>) Data)[index], + MilvusDataType.Int64 => ((IReadOnlyList>) Data)[index], + MilvusDataType.Float => ((IReadOnlyList>) Data)[index], + MilvusDataType.Double => ((IReadOnlyList>) Data)[index], + MilvusDataType.String => ((IReadOnlyList>) Data)[index], + MilvusDataType.VarChar => ((IReadOnlyList>) Data)[index], + + MilvusDataType.None => throw new MilvusException($"DataType Error:{DataType}"), + _ => throw new MilvusException($"DataType Error:{DataType}, not supported") + }; +} diff --git a/Milvus.Client/Constants.cs b/Milvus.Client/Constants.cs index cce7c09..c59455f 100644 --- a/Milvus.Client/Constants.cs +++ b/Milvus.Client/Constants.cs @@ -70,6 +70,11 @@ internal static class Constants /// internal const string FailedReason = "failed_reason"; + /// + /// Key name. + /// + internal const string MaxCapacity = "max_capacity"; + /// /// Files. /// diff --git a/Milvus.Client/FieldData.cs b/Milvus.Client/FieldData.cs index 97c67d9..a4ea17e 100644 --- a/Milvus.Client/FieldData.cs +++ b/Milvus.Client/FieldData.cs @@ -13,7 +13,7 @@ public abstract class FieldData /// /// Field name. /// Field data type. - /// + /// Whether the field is dynamic. protected FieldData(string fieldName, MilvusDataType dataType, bool isDynamic = false) { FieldName = fieldName; @@ -120,23 +120,41 @@ internal static FieldData FromGrpcFieldData(Grpc.FieldData fieldData) } case Grpc.FieldData.FieldOneofCase.Scalars: - return fieldData.Scalars.DataCase switch + return fieldData.Scalars switch { - Grpc.ScalarField.DataOneofCase.BoolData + { DataCase: ScalarField.DataOneofCase.BoolData } => Create(fieldData.FieldName, fieldData.Scalars.BoolData.Data, fieldData.IsDynamic), - Grpc.ScalarField.DataOneofCase.FloatData + { DataCase: ScalarField.DataOneofCase.FloatData } => Create(fieldData.FieldName, fieldData.Scalars.FloatData.Data, fieldData.IsDynamic), - Grpc.ScalarField.DataOneofCase.IntData + { DataCase: ScalarField.DataOneofCase.IntData } => Create(fieldData.FieldName, fieldData.Scalars.IntData.Data, fieldData.IsDynamic), - Grpc.ScalarField.DataOneofCase.LongData + { DataCase: ScalarField.DataOneofCase.LongData } => Create(fieldData.FieldName, fieldData.Scalars.LongData.Data, fieldData.IsDynamic), - Grpc.ScalarField.DataOneofCase.StringData + { DataCase: ScalarField.DataOneofCase.StringData } => CreateVarChar(fieldData.FieldName, fieldData.Scalars.StringData.Data, fieldData.IsDynamic), - Grpc.ScalarField.DataOneofCase.JsonData - => CreateJson(fieldData.FieldName, fieldData.Scalars.JsonData.Data - .Select(p => p.ToStringUtf8()).ToList(), fieldData.IsDynamic), - - _ => throw new NotSupportedException($"{fieldData.Scalars.DataCase} not support"), + { DataCase: ScalarField.DataOneofCase.JsonData } + => CreateJson(fieldData.FieldName, fieldData.Scalars.JsonData.Data.Select(p => p.ToStringUtf8()).ToList(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Bool } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.BoolData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Int8 } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.IntData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Int16 } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.IntData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Int32 } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.IntData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Int64 } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.LongData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Float } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.FloatData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Double } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.DoubleData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.String } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.StringData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.VarChar } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.StringData.Data).ToArray(), fieldData.IsDynamic), + { DataCase: ScalarField.DataOneofCase.ArrayData, ArrayData.ElementType: Grpc.DataType.Json } + => CreateArray(fieldData.FieldName, fieldData.Scalars.ArrayData.Data.Select(x => x.JsonData.Data).ToArray(), fieldData.IsDynamic), + _ => throw new NotSupportedException($"{ fieldData.Scalars.DataCase } not support"), }; default: @@ -147,7 +165,7 @@ internal static FieldData FromGrpcFieldData(Grpc.FieldData fieldData) internal static MilvusDataType EnsureDataType() { Type type = typeof(TDataType); - MilvusDataType dataType = MilvusDataType.Double; + MilvusDataType dataType; if (type == typeof(bool)) { @@ -191,7 +209,7 @@ internal static MilvusDataType EnsureDataType() } else { - throw new NotSupportedException($"Not Support DataType:{dataType}"); + throw new NotSupportedException($"Type {type.Name} cannot be mapped to DataType"); } return dataType; @@ -227,8 +245,8 @@ public static FieldData Create( /// Create a varchar field. /// /// Field name. - /// Data. - /// + /// Data in this field + /// Whether the field is dynamic. /// public static FieldData CreateVarChar( string fieldName, @@ -236,6 +254,21 @@ public static FieldData CreateVarChar( bool isDynamic = false) => new(fieldName, data, MilvusDataType.VarChar, isDynamic); + /// + /// Create array of elements. + /// + /// Field name. + /// Data in this field + /// Whether the field is dynamic. + /// + public static ArrayFieldData CreateArray( + string fieldName, + IReadOnlyList> data, + bool isDynamic = false) + { + return new(fieldName, data, isDynamic); + } + /// /// Create a field from array. /// @@ -267,8 +300,8 @@ public static BinaryVectorFieldData CreateFromBytes(string fieldName, ReadOnlySp /// /// Create a binary vectors /// - /// - /// + /// Field name. + /// Data in this field /// public static BinaryVectorFieldData CreateBinaryVectors(string fieldName, IReadOnlyList> data) { @@ -281,7 +314,7 @@ public static BinaryVectorFieldData CreateBinaryVectors(string fieldName, IReadO /// Create a float vector. /// /// Field name. - /// Data + /// Data in this field /// public static FloatVectorFieldData CreateFloatVector(string fieldName, IReadOnlyList> data) => new(fieldName, data); @@ -290,7 +323,7 @@ public static FloatVectorFieldData CreateFloatVector(string fieldName, IReadOnly /// Create a field from stream /// /// Field name - /// + /// Stream data /// Dimension of data /// New created field public static FieldData CreateFromStream(string fieldName, Stream stream, long dimension) @@ -305,8 +338,8 @@ public static FieldData CreateFromStream(string fieldName, Stream stream, long d /// Create json field. /// /// Field name. - /// json field. - /// + /// Json field. + /// Whether the field is dynamic. /// public static FieldData CreateJson(string fieldName, IReadOnlyList json, bool isDynamic = false) { @@ -324,9 +357,9 @@ public class FieldData : FieldData /// /// Construct a field /// - /// - /// - /// + /// Field name. + /// Data in this field + /// Whether the field is dynamic. public FieldData(string fieldName, IReadOnlyList data, bool isDynamic = false) : base(fieldName, EnsureDataType(), isDynamic) => Data = data; @@ -334,10 +367,10 @@ public FieldData(string fieldName, IReadOnlyList data, bool isDynamic = f /// /// Construct a field /// - /// - /// + /// Field name. + /// Data in this field /// Milvus data type. - /// + /// Whether the field is dynamic. public FieldData(string fieldName, IReadOnlyList data, MilvusDataType milvusDataType, bool isDynamic) : base(fieldName, milvusDataType, isDynamic) => Data = data; @@ -446,7 +479,6 @@ internal override Grpc.FieldData ToGrpcFieldData() } fieldData.Scalars = new Grpc.ScalarField { JsonData = jsonData, }; break; - case MilvusDataType.None: throw new MilvusException($"DataType Error:{DataType}"); default: @@ -481,7 +513,11 @@ internal override object GetValueAsObject(int index) public override string ToString() => $"Field: {{{nameof(FieldName)}: {FieldName}, {nameof(DataType)}: {DataType}, {nameof(Data)}: {Data?.Count}, {nameof(RowCount)}: {RowCount}}}"; - private void Check() + /// + /// Checks data + /// + /// + protected void Check() { if (Data.Any() != true) { diff --git a/Milvus.Client/FieldSchema.cs b/Milvus.Client/FieldSchema.cs index 0c9720b..46b093f 100644 --- a/Milvus.Client/FieldSchema.cs +++ b/Milvus.Client/FieldSchema.cs @@ -95,6 +95,42 @@ public static FieldSchema CreateBinaryVector(string name, int dimension, string public static FieldSchema CreateJson(string name) => new(name, MilvusDataType.Json); + /// + /// Create a field schema for a array of TData field. + /// + /// Determines the element data type of array stored in the field based. + /// The field name. + /// Maximum number of elements that an array field can contain. + /// An optional description for the field. + public static FieldSchema CreateArray( + string name, + int maxCapacity, + string description = "") + => new(name, MilvusDataType.Array, description: description) + { + ElementDataType = FieldData.EnsureDataType(), + MaxCapacity = maxCapacity, + }; + + /// + /// Create a field schema for a array of varchar field. + /// + /// The field name. + /// Maximum number of elements that an array field can contain. + /// Maximum length of strings for each varchar element in an array field. + /// An optional description for the field. + public static FieldSchema CreateVarcharArray( + string name, + int maxCapacity, + int maxLength, + string description = "") + => new(name, MilvusDataType.Array, description: description) + { + ElementDataType = MilvusDataType.VarChar, + MaxCapacity = maxCapacity, + MaxLength = maxLength, + }; + // Construct used when the user constructs a schema to be provided to CreateSchema private FieldSchema( string name, @@ -118,6 +154,7 @@ internal FieldSchema( long id, string name, MilvusDataType dataType, + MilvusDataType elementType, FieldState state, bool isPrimaryKey, bool autoId, @@ -128,6 +165,7 @@ internal FieldSchema( FieldId = id; Name = name; DataType = dataType; + ElementDataType = elementType; State = state; IsPrimaryKey = isPrimaryKey; AutoId = autoId; @@ -146,6 +184,11 @@ internal FieldSchema( /// public MilvusDataType DataType { get; } + /// + /// The element data type for array stored in the field. + /// + public MilvusDataType ElementDataType { get; set; } + /// /// Whether the field is a primary key. /// @@ -187,6 +230,11 @@ internal FieldSchema( /// public int? MaxLength { get; set; } + /// + /// Maximum number of elements that an array field can contain. Mandatory for fields, and must be in a range [1, 4096] + /// + public int? MaxCapacity { get; set; } + /// /// The dimension of the vector. Mandatory for /// and fields, and must be greater than zero. diff --git a/Milvus.Client/MilvusClient.Collection.cs b/Milvus.Client/MilvusClient.Collection.cs index d37a354..a12f54c 100644 --- a/Milvus.Client/MilvusClient.Collection.cs +++ b/Milvus.Client/MilvusClient.Collection.cs @@ -87,6 +87,7 @@ public async Task CreateCollectionAsync( { Name = field.Name, DataType = (DataType)(int)field.DataType, + ElementType = (DataType)(int)field.ElementDataType, IsPrimaryKey = field.IsPrimaryKey, IsPartitionKey = field.IsPartitionKey, AutoID = field.AutoId, @@ -111,6 +112,15 @@ public async Task CreateCollectionAsync( }); } + if (field.MaxCapacity is not null) + { + grpcField.TypeParams.Add(new Grpc.KeyValuePair + { + Key = Constants.MaxCapacity, + Value = field.MaxCapacity.Value.ToString(CultureInfo.InvariantCulture) + }); + } + grpcCollectionSchema.Fields.Add(grpcField); } diff --git a/Milvus.Client/MilvusCollection.Collection.cs b/Milvus.Client/MilvusCollection.Collection.cs index 51e656b..bdd802d 100644 --- a/Milvus.Client/MilvusCollection.Collection.cs +++ b/Milvus.Client/MilvusCollection.Collection.cs @@ -34,9 +34,16 @@ await _client.InvokeAsync(_client.GrpcClient.DescribeCollectionAsync, request, r foreach (Grpc.FieldSchema grpcField in response.Schema.Fields) { FieldSchema milvusField = new( - grpcField.FieldID, grpcField.Name, (MilvusDataType)grpcField.DataType, - (FieldState)grpcField.State, grpcField.IsPrimaryKey, grpcField.AutoID, grpcField.IsPartitionKey, - grpcField.IsDynamic, grpcField.Description); + grpcField.FieldID, + grpcField.Name, + (MilvusDataType) grpcField.DataType, + (MilvusDataType) grpcField.ElementType, + (FieldState) grpcField.State, + grpcField.IsPrimaryKey, + grpcField.AutoID, + grpcField.IsPartitionKey, + grpcField.IsDynamic, + grpcField.Description); foreach (Grpc.KeyValuePair parameter in grpcField.TypeParams) { @@ -46,6 +53,10 @@ await _client.InvokeAsync(_client.GrpcClient.DescribeCollectionAsync, request, r milvusField.MaxLength = int.Parse(parameter.Value, CultureInfo.InvariantCulture); break; + case Constants.MaxCapacity: + milvusField.MaxCapacity = int.Parse(parameter.Value, CultureInfo.InvariantCulture); + break; + case Constants.VectorDim: milvusField.Dimension = int.Parse(parameter.Value, CultureInfo.InvariantCulture); break;