diff --git a/cpp/include/milvus-storage/packed/reader.h b/cpp/include/milvus-storage/packed/reader.h index c808af0..3c3b400 100644 --- a/cpp/include/milvus-storage/packed/reader.h +++ b/cpp/include/milvus-storage/packed/reader.h @@ -17,6 +17,7 @@ #include #include #include "common/config.h" +#include "common/result.h" #include #include #include @@ -57,6 +58,8 @@ class PackedRecordBatchReader : public arrow::RecordBatchReader { arrow::Status ReadNext(std::shared_ptr* batch) override; + Result> ReadRowGroup(int file_index, int row_group_index); + arrow::Status Close() override; private: diff --git a/cpp/src/packed/reader.cpp b/cpp/src/packed/reader.cpp index c6aa969..d02103a 100644 --- a/cpp/src/packed/reader.cpp +++ b/cpp/src/packed/reader.cpp @@ -192,6 +192,21 @@ arrow::Status PackedRecordBatchReader::ReadNext(std::shared_ptr> PackedRecordBatchReader::ReadRowGroup(int file_index, int row_group_index) { + if (file_index < 0 || file_index >= file_readers_.size()) { + throw std::out_of_range("Invalid file index"); + } + + auto reader = file_readers_[file_index]->parquet_reader(); + if (row_group_index < 0 || row_group_index >= reader->metadata()->num_row_groups()) { + throw std::out_of_range("Invalid row group index"); + } + + std::shared_ptr table; + RETURN_ARROW_NOT_OK(file_readers_[file_index]->ReadRowGroup(row_group_index, &table)); + return std::move(table); +} + arrow::Status PackedRecordBatchReader::Close() { LOG_STORAGE_DEBUG_ << "PackedRecordBatchReader::Close(), total read " << read_count_ << " times"; for (int i = 0; i < column_group_states_.size(); ++i) { diff --git a/cpp/test/packed/packed_test_base.h b/cpp/test/packed/packed_test_base.h index 506e1e4..9ea7de2 100644 --- a/cpp/test/packed/packed_test_base.h +++ b/cpp/test/packed/packed_test_base.h @@ -100,6 +100,13 @@ class PackedTestBase : public ::testing::Test { PackedRecordBatchReader pr(*fs_, paths, new_schema, column_offsets, needed_columns, reader_memory_); ASSERT_AND_ARROW_ASSIGN(auto table, pr.ToTable()); + + auto res = pr.ReadRowGroup(0, 0); + if (!res.ok()) { + ASSERT_FALSE(res.ok()); + } + auto row_group = res.value(); + ASSERT_TRUE(row_group->num_rows() > 0); ASSERT_STATUS_OK(pr.Close()); ValidateTableData(table);