Skip to content

Commit

Permalink
[GLUTEN-8481][VL] Clean up shuffle reader cpp code
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma authored Jan 10, 2025
1 parent 66e816f commit 7af129c
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 131 deletions.
1 change: 0 additions & 1 deletion cpp/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
shuffle/RandomPartitioner.cc
shuffle/RoundRobinPartitioner.cc
shuffle/ShuffleMemoryPool.cc
shuffle/ShuffleReader.cc
shuffle/ShuffleWriter.cc
shuffle/SinglePartitioner.cc
shuffle/Spill.cc
Expand Down
1 change: 0 additions & 1 deletion cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,6 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper
jlong shuffleReaderHandle) {
JNI_METHOD_START
auto reader = ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle);
GLUTEN_THROW_NOT_OK(reader->close());
ObjectStore::release(shuffleReaderHandle);
JNI_METHOD_END()
}
Expand Down
55 changes: 0 additions & 55 deletions cpp/core/shuffle/ShuffleReader.cc

This file was deleted.

49 changes: 4 additions & 45 deletions cpp/core/shuffle/ShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,63 +17,22 @@

#pragma once

#include "memory/ColumnarBatch.h"

#include <arrow/ipc/message.h>
#include <arrow/ipc/options.h>

#include "Options.h"
#include "compute/ResultIterator.h"
#include "utils/Compression.h"

namespace gluten {

// Abstract factory interface. Every member function below is pure virtual, so
// a concrete subclass must supply the deserializer and all accessors.
class DeserializerFactory {
public:
// Virtual destructor so implementations can be destroyed through a base pointer.
virtual ~DeserializerFactory() = default;

// Builds a ColumnarBatchIterator for the given Arrow input stream and hands
// ownership of it to the caller (returned as std::unique_ptr).
virtual std::unique_ptr<ColumnarBatchIterator> createDeserializer(std::shared_ptr<arrow::io::InputStream> in) = 0;

// Returns a raw pointer to an Arrow memory pool. The signature does not
// express ownership or lifetime -- TODO confirm with the implementation.
virtual arrow::MemoryPool* getPool() = 0;

// Returns the decompress-time metric as a 64-bit integer.
// NOTE(review): the time unit is not visible here -- confirm.
virtual int64_t getDecompressTime() = 0;

// Returns the deserialize-time metric as a 64-bit integer.
// NOTE(review): the time unit is not visible here -- confirm.
virtual int64_t getDeserializeTime() = 0;

// Returns the ShuffleWriterType reported by the implementation.
virtual ShuffleWriterType getShuffleWriterType() = 0;
};

class ShuffleReader {
public:
explicit ShuffleReader(std::unique_ptr<DeserializerFactory> factory);

virtual ~ShuffleReader() = default;

// FIXME iterator should be unique_ptr or un-copyable singleton
virtual std::shared_ptr<ResultIterator> readStream(std::shared_ptr<arrow::io::InputStream> in);

arrow::Status close();

int64_t getDecompressTime() const;

int64_t getIpcTime() const;

int64_t getDeserializeTime() const;

arrow::MemoryPool* getPool() const;

ShuffleWriterType getShuffleWriterType() const;
virtual std::shared_ptr<ResultIterator> readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;

protected:
arrow::MemoryPool* pool_;
int64_t decompressTime_ = 0;
int64_t deserializeTime_ = 0;
virtual int64_t getDecompressTime() const = 0;

ShuffleWriterType shuffleWriterType_;
virtual int64_t getDeserializeTime() const = 0;

private:
std::shared_ptr<arrow::Schema> schema_;
std::unique_ptr<DeserializerFactory> factory_;
virtual arrow::MemoryPool* getPool() const = 0;
};

} // namespace gluten
2 changes: 1 addition & 1 deletion cpp/velox/compute/VeloxRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ std::shared_ptr<ShuffleReader> VeloxRuntime::createShuffleReader(
auto codec = gluten::createArrowIpcCodec(options.compressionType, options.codecBackend);
auto ctxVeloxPool = memoryManager()->getLeafMemoryPool();
auto veloxCompressionType = facebook::velox::common::stringToCompressionKind(options.compressionTypeStr);
auto deserializerFactory = std::make_unique<gluten::VeloxColumnarBatchDeserializerFactory>(
auto deserializerFactory = std::make_unique<gluten::VeloxShuffleReaderDeserializerFactory>(
schema,
std::move(codec),
veloxCompressionType,
Expand Down
40 changes: 26 additions & 14 deletions cpp/velox/shuffle/VeloxShuffleReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
* limitations under the License.
*/

#include "VeloxShuffleReader.h"
#include "GlutenByteStream.h"
#include "shuffle/VeloxShuffleReader.h"

#include <arrow/array/array_binary.h>
#include <arrow/io/buffered.h>

#include "memory/VeloxColumnarBatch.h"
#include "shuffle/GlutenByteStream.h"
#include "shuffle/Payload.h"
#include "shuffle/Utils.h"
#include "utils/Common.h"
Expand Down Expand Up @@ -576,7 +576,7 @@ std::shared_ptr<ColumnarBatch> VeloxRssSortShuffleReaderDeserializer::next() {
return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
}

VeloxColumnarBatchDeserializerFactory::VeloxColumnarBatchDeserializerFactory(
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::common::CompressionKind veloxCompressionType,
Expand All @@ -598,7 +598,7 @@ VeloxColumnarBatchDeserializerFactory::VeloxColumnarBatchDeserializerFactory(
initFromSchema();
}

std::unique_ptr<ColumnarBatchIterator> VeloxColumnarBatchDeserializerFactory::createDeserializer(
std::unique_ptr<ColumnarBatchIterator> VeloxShuffleReaderDeserializerFactory::createDeserializer(
std::shared_ptr<arrow::io::InputStream> in) {
switch (shuffleWriterType_) {
case ShuffleWriterType::kHashShuffle:
Expand Down Expand Up @@ -635,23 +635,19 @@ std::unique_ptr<ColumnarBatchIterator> VeloxColumnarBatchDeserializerFactory::cr
}
}

arrow::MemoryPool* VeloxColumnarBatchDeserializerFactory::getPool() {
arrow::MemoryPool* VeloxShuffleReaderDeserializerFactory::getPool() {
return memoryPool_;
}

// Returns the value stored in the shuffleWriterType_ member.
ShuffleWriterType VeloxColumnarBatchDeserializerFactory::getShuffleWriterType() {
return shuffleWriterType_;
}

int64_t VeloxColumnarBatchDeserializerFactory::getDecompressTime() {
int64_t VeloxShuffleReaderDeserializerFactory::getDecompressTime() {
return decompressTime_;
}

int64_t VeloxColumnarBatchDeserializerFactory::getDeserializeTime() {
int64_t VeloxShuffleReaderDeserializerFactory::getDeserializeTime() {
return deserializeTime_;
}

void VeloxColumnarBatchDeserializerFactory::initFromSchema() {
void VeloxShuffleReaderDeserializerFactory::initFromSchema() {
GLUTEN_ASSIGN_OR_THROW(auto arrowColumnTypes, toShuffleTypeId(schema_->fields()));
isValidityBuffer_.reserve(arrowColumnTypes.size());
for (size_t i = 0; i < arrowColumnTypes.size(); ++i) {
Expand Down Expand Up @@ -681,7 +677,23 @@ void VeloxColumnarBatchDeserializerFactory::initFromSchema() {
}
}

// Constructs the reader by moving the given factory into the ShuffleReader
// base-class constructor; this constructor body itself does nothing else.
VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr<DeserializerFactory> factory)
: ShuffleReader(std::move(factory)) {}
// Constructs the reader and takes ownership of the given factory by moving it
// into the factory_ member. No null check is performed on the argument here.
VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory)
: factory_(std::move(factory)) {}

// Asks factory_ to create a deserializer for the given input stream and wraps
// the result in a ResultIterator held by shared_ptr. Each call requests a
// separate deserializer from the factory.
std::shared_ptr<ResultIterator> VeloxShuffleReader::readStream(std::shared_ptr<arrow::io::InputStream> in) {
return std::make_shared<ResultIterator>(factory_->createDeserializer(in));
}

// Forwards to factory_->getPool() and returns its result unchanged.
arrow::MemoryPool* VeloxShuffleReader::getPool() const {
return factory_->getPool();
}

// Forwards to factory_->getDecompressTime() and returns its result unchanged.
// NOTE(review): the time unit is not visible in this file -- confirm.
int64_t VeloxShuffleReader::getDecompressTime() const {
return factory_->getDecompressTime();
}

// Forwards to factory_->getDeserializeTime() and returns its result unchanged.
// NOTE(review): the time unit is not visible in this file -- confirm.
int64_t VeloxShuffleReader::getDeserializeTime() const {
return factory_->getDeserializeTime();
}

} // namespace gluten
33 changes: 20 additions & 13 deletions cpp/velox/shuffle/VeloxShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@

#pragma once

#include "operators/serializer/VeloxColumnarBatchSerializer.h"
#include "shuffle/Payload.h"
#include "shuffle/ShuffleReader.h"
#include "shuffle/VeloxSortShuffleWriter.h"
#include "utils/Timer.h"

#include "velox/serializers/PrestoSerializer.h"
#include "velox/type/Type.h"
#include "velox/vector/ComplexVector.h"

#include <velox/serializers/PrestoSerializer.h>

namespace gluten {

class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator {
Expand Down Expand Up @@ -134,9 +132,9 @@ class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator {
std::shared_ptr<VeloxInputStream> in_;
};

class VeloxColumnarBatchDeserializerFactory : public DeserializerFactory {
class VeloxShuffleReaderDeserializerFactory {
public:
VeloxColumnarBatchDeserializerFactory(
VeloxShuffleReaderDeserializerFactory(
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::common::CompressionKind veloxCompressionType,
Expand All @@ -147,15 +145,13 @@ class VeloxColumnarBatchDeserializerFactory : public DeserializerFactory {
std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
ShuffleWriterType shuffleWriterType);

std::unique_ptr<ColumnarBatchIterator> createDeserializer(std::shared_ptr<arrow::io::InputStream> in) override;

arrow::MemoryPool* getPool() override;
std::unique_ptr<ColumnarBatchIterator> createDeserializer(std::shared_ptr<arrow::io::InputStream> in);

int64_t getDecompressTime() override;
arrow::MemoryPool* getPool();

int64_t getDeserializeTime() override;
int64_t getDecompressTime();

ShuffleWriterType getShuffleWriterType() override;
int64_t getDeserializeTime();

private:
void initFromSchema();
Expand All @@ -180,6 +176,17 @@ class VeloxColumnarBatchDeserializerFactory : public DeserializerFactory {

// Final ShuffleReader subclass that holds a VeloxShuffleReaderDeserializerFactory
// in factory_ and overrides readStream plus the metric/pool accessors.
// NOTE(review): two constructor overloads are listed below. This text appears to
// come from a rendered diff, so one of them may be the removed form -- confirm
// against the actual header before relying on both.
class VeloxShuffleReader final : public ShuffleReader {
public:
// Overload taking the generic DeserializerFactory.
VeloxShuffleReader(std::unique_ptr<DeserializerFactory> factory);
// Overload taking the Velox-specific factory type that factory_ stores.
VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory);

// Returns a ResultIterator for the given Arrow input stream.
std::shared_ptr<ResultIterator> readStream(std::shared_ptr<arrow::io::InputStream> in) override;

// Metric accessor; the time unit is not visible here -- TODO confirm.
int64_t getDecompressTime() const override;

// Metric accessor; the time unit is not visible here -- TODO confirm.
int64_t getDeserializeTime() const override;

// Returns a raw pointer to an Arrow memory pool.
arrow::MemoryPool* getPool() const override;

private:
// Factory owned exclusively by this reader.
std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory_;
};
} // namespace gluten
2 changes: 1 addition & 1 deletion cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class VeloxShuffleWriterTest : public ::testing::TestWithParam<ShuffleTestParams
facebook::velox::serializer::presto::PrestoVectorSerde::registerVectorSerde();
}
// Set batchSize to a large value so that all batches are merged by the reader.
auto deserializerFactory = std::make_unique<gluten::VeloxColumnarBatchDeserializerFactory>(
auto deserializerFactory = std::make_unique<gluten::VeloxShuffleReaderDeserializerFactory>(
schema,
std::move(codec),
veloxCompressionType,
Expand Down

0 comments on commit 7af129c

Please sign in to comment.