Commit 05c0aac
Use StreamWriter instead of FileWriter (apache#943)
avantgardnerio authored Dec 21, 2023
1 parent 3af8465 commit 05c0aac
Showing 4 changed files with 53 additions and 31 deletions.
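This commit moves Ballista's shuffle files from Arrow's IPC file format (FileWriter/FileReader, which writes a footer on finish and supports random access) to the IPC stream format (StreamWriter/StreamReader, which is purely sequential). Below is a minimal round-trip sketch, separate from the commit itself, of the StreamWriter/StreamReader API the changed code relies on; the in-memory buffer stands in for Ballista's shuffle files.

use std::io::Cursor;
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::arrow::ipc::writer::StreamWriter;
use datafusion::arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Stream format: batches are written back-to-back with no footer, so
    // the bytes can be consumed sequentially as they are produced.
    let mut buf = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch)?;
        writer.finish()?; // writes the end-of-stream marker
    }

    // StreamReader::try_new wraps its input in a BufReader, which is why
    // the types throughout this diff become StreamReader<BufReader<File>>.
    let reader = StreamReader::try_new(Cursor::new(buf), None)?;
    for batch in reader {
        println!("read a batch with {} rows", batch?.num_rows());
    }
    Ok(())
}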
18 changes: 10 additions & 8 deletions ballista/core/src/execution_plans/shuffle_reader.rs
@@ -16,11 +16,13 @@
 // under the License.
 
 use async_trait::async_trait;
+use datafusion::arrow::ipc::reader::StreamReader;
 use datafusion::common::stats::Precision;
 use std::any::Any;
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::fs::File;
+use std::io::BufReader;
 use std::pin::Pin;
 use std::result;
 use std::sync::Arc;
@@ -31,7 +33,6 @@ use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::error::ArrowError;
-use datafusion::arrow::ipc::reader::FileReader;
 use datafusion::arrow::record_batch::RecordBatch;
 
 use datafusion::error::Result;
@@ -209,11 +210,11 @@ fn stats_for_partitions(
 }
 
 struct LocalShuffleStream {
-    reader: FileReader<File>,
+    reader: StreamReader<BufReader<File>>,
 }
 
 impl LocalShuffleStream {
-    pub fn new(reader: FileReader<File>) -> Self {
+    pub fn new(reader: StreamReader<BufReader<File>>) -> Self {
         LocalShuffleStream { reader }
     }
 }
@@ -412,13 +413,14 @@ async fn fetch_partition_local(
 
 fn fetch_partition_local_inner(
     path: &str,
-) -> result::Result<FileReader<File>, BallistaError> {
+) -> result::Result<StreamReader<BufReader<File>>, BallistaError> {
     let file = File::open(path).map_err(|e| {
         BallistaError::General(format!("Failed to open partition file at {path}: {e:?}"))
     })?;
-    FileReader::try_new(file, None).map_err(|e| {
+    let reader = StreamReader::try_new(file, None).map_err(|e| {
         BallistaError::General(format!("Failed to new arrow FileReader at {path}: {e:?}"))
-    })
+    })?;
+    Ok(reader)
 }
 
 async fn fetch_partition_object_store(
@@ -437,7 +439,7 @@ mod tests {
     use crate::utils;
     use datafusion::arrow::array::{Int32Array, StringArray, UInt32Array};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
-    use datafusion::arrow::ipc::writer::FileWriter;
+    use datafusion::arrow::ipc::writer::StreamWriter;
     use datafusion::arrow::record_batch::RecordBatch;
     use datafusion::common::DataFusionError;
     use datafusion::physical_expr::expressions::Column;
@@ -627,7 +629,7 @@
         let tmp_dir = tempdir().unwrap();
         let file_path = tmp_dir.path().join("shuffle_data");
         let file = File::create(&file_path).unwrap();
-        let mut writer = FileWriter::try_new(file, &schema).unwrap();
+        let mut writer = StreamWriter::try_new(file, &schema).unwrap();
         writer.write(&batch).unwrap();
         writer.finish().unwrap();
 
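LocalShuffleStream adapts the synchronous StreamReader iterator to an async batch stream. A simplified sketch of that adapter follows, under the assumption that a blocking read inside poll_next is acceptable for a local file; the real implementation in shuffle_reader.rs is more complete (schema handling, DataFusion error types).

use std::fs::File;
use std::io::BufReader;
use std::pin::Pin;
use std::task::{Context, Poll};

use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::arrow::record_batch::RecordBatch;
use futures::Stream;

struct LocalShuffleStream {
    reader: StreamReader<BufReader<File>>,
}

impl Stream for LocalShuffleStream {
    type Item = Result<RecordBatch, ArrowError>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // StreamReader reads synchronously; for a local file this is
        // cheap enough to do inline without yielding to the runtime.
        Poll::Ready(self.reader.next())
    }
}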
52 changes: 36 additions & 16 deletions ballista/core/src/execution_plans/shuffle_writer.rs
@@ -24,7 +24,10 @@ use datafusion::arrow::ipc::writer::IpcWriteOptions;
 use datafusion::arrow::ipc::CompressionType;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 
+use datafusion::arrow::ipc::writer::StreamWriter;
 use std::any::Any;
+use std::fs;
+use std::fs::File;
 use std::future::Future;
 use std::iter::Iterator;
 use std::path::PathBuf;
@@ -42,7 +45,6 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError, Result};
-use datafusion::physical_plan::common::IPCWriter;
 use datafusion::physical_plan::memory::MemoryStream;
 use datafusion::physical_plan::metrics::{
     self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -81,6 +83,13 @@ pub struct ShuffleWriterExec {
     metrics: ExecutionPlanMetricsSet,
 }
 
+pub struct WriteTracker {
+    pub num_batches: usize,
+    pub num_rows: usize,
+    pub writer: StreamWriter<File>,
+    pub path: PathBuf,
+}
+
 #[derive(Debug, Clone)]
 struct ShuffleWriteMetrics {
     /// Time spend writing batches to shuffle files
@@ -210,7 +219,7 @@ impl ShuffleWriterExec {
             Some(Partitioning::Hash(exprs, num_output_partitions)) => {
                 // we won't necessary produce output for every possible partition, so we
                 // create writers on demand
-                let mut writers: Vec<Option<IPCWriter>> = vec![];
+                let mut writers: Vec<Option<WriteTracker>> = vec![];
                 for _ in 0..num_output_partitions {
                     writers.push(None);
                 }
@@ -232,7 +241,9 @@
                     let timer = write_metrics.write_time.timer();
                     match &mut writers[output_partition] {
                         Some(w) => {
-                            w.write(&output_batch)?;
+                            w.num_batches += 1;
+                            w.num_rows += output_batch.num_rows();
+                            w.writer.write(&output_batch)?;
                         }
                         None => {
                             let mut path = path.clone();
@@ -248,14 +259,22 @@
                                 .try_with_compression(Some(
                                     CompressionType::LZ4_FRAME,
                                 ))?;
-                            let mut writer = IPCWriter::new_with_options(
-                                &path,
-                                stream.schema().as_ref(),
-                                options,
-                            )?;
+
+                            let file = File::create(path.clone())?;
+                            let mut writer =
+                                StreamWriter::try_new_with_options(
+                                    file,
+                                    stream.schema().as_ref(),
+                                    options,
+                                )?;
 
                             writer.write(&output_batch)?;
-                            writers[output_partition] = Some(writer);
+                            writers[output_partition] = Some(WriteTracker {
+                                num_batches: 1,
+                                num_rows: output_batch.num_rows(),
+                                writer,
+                                path,
+                            });
                         }
                     }
                     write_metrics.output_rows.add(output_batch.num_rows());
@@ -270,22 +289,23 @@
                 for (i, w) in writers.iter_mut().enumerate() {
                     match w {
                         Some(w) => {
-                            w.finish()?;
+                            let num_bytes = fs::metadata(&w.path)?.len();
+                            w.writer.finish()?;
                             debug!(
                                 "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.",
                                 i,
-                                w.path(),
+                                w.path,
                                 w.num_batches,
                                 w.num_rows,
-                                w.num_bytes
+                                num_bytes
                             );
 
                             part_locs.push(ShuffleWritePartition {
                                 partition_id: i as u64,
-                                path: w.path().to_string_lossy().to_string(),
-                                num_batches: w.num_batches,
-                                num_rows: w.num_rows,
-                                num_bytes: w.num_bytes,
+                                path: w.path.to_string_lossy().to_string(),
+                                num_batches: w.num_batches as u64,
+                                num_rows: w.num_rows as u64,
+                                num_bytes,
                             });
                         }
                         None => {}
5 changes: 3 additions & 2 deletions ballista/core/src/utils.rs
@@ -26,8 +26,9 @@ use crate::serde::scheduler::PartitionStats;
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::ipc::writer::IpcWriteOptions;
+use datafusion::arrow::ipc::writer::StreamWriter;
 use datafusion::arrow::ipc::CompressionType;
-use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch};
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::physical_plan::{CsvExec, ParquetExec};
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{
@@ -89,7 +90,7 @@ pub async fn write_stream_to_disk(
         .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
 
     let mut writer =
-        FileWriter::try_new_with_options(file, stream.schema().as_ref(), options)?;
+        StreamWriter::try_new_with_options(file, stream.schema().as_ref(), options)?;
 
     while let Some(result) = stream.next().await {
         let batch = result?;
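Both shuffle_writer.rs and write_stream_to_disk now follow the same pattern: build IpcWriteOptions with LZ4_FRAME compression, then construct a StreamWriter over the file. A condensed sketch of that pattern (the function name and signature are illustrative, not the commit's code):

use std::fs::File;

use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
use datafusion::arrow::ipc::CompressionType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;

// Sketch: write a set of batches to `path` in the IPC stream format
// with LZ4_FRAME compression, mirroring write_stream_to_disk.
fn write_compressed(path: &str, schema: &Schema, batches: &[RecordBatch]) -> Result<()> {
    let file = File::create(path)?;
    let options = IpcWriteOptions::default()
        .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
    let mut writer = StreamWriter::try_new_with_options(file, schema, options)?;
    for batch in batches {
        writer.write(batch)?;
    }
    writer.finish()?;
    Ok(())
}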
9 changes: 4 additions & 5 deletions ballista/executor/src/flight_service.rs
@@ -17,6 +17,7 @@
 
 //! Implementation of the Apache Arrow Flight protocol that wraps an executor.
+use arrow::ipc::reader::StreamReader;
 use std::convert::TryFrom;
 use std::fs::File;
 use std::pin::Pin;
@@ -34,9 +35,7 @@ use arrow_flight::{
     FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
     PutResult, SchemaResult, Ticket,
 };
-use datafusion::arrow::{
-    error::ArrowError, ipc::reader::FileReader, record_batch::RecordBatch,
-};
+use datafusion::arrow::{error::ArrowError, record_batch::RecordBatch};
 use futures::{Stream, StreamExt, TryStreamExt};
 use log::{debug, info};
 use std::io::{Read, Seek};
@@ -97,7 +96,7 @@ impl FlightService for BallistaFlightService {
             })
             .map_err(|e| from_ballista_err(&e))?;
         let reader =
-            FileReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;
+            StreamReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;
 
         let (tx, rx) = channel(2);
         let schema = reader.schema();
@@ -207,7 +206,7 @@
 }
 
 fn read_partition<T>(
-    reader: FileReader<T>,
+    reader: StreamReader<std::io::BufReader<T>>,
     tx: Sender<Result<RecordBatch, FlightError>>,
 ) -> Result<(), FlightError>
 where
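On the read side, read_partition now receives a StreamReader<BufReader<T>> because StreamReader::try_new buffers its input. A simplified sketch of draining such a reader into a channel, in the spirit of the Flight service (the service's FlightError mapping is elided; this is meant to run on a blocking thread, e.g. under tokio's spawn_blocking, since blocking_send must not be called from an async context):

use std::io::{BufReader, Read};

use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::arrow::record_batch::RecordBatch;
use tokio::sync::mpsc::Sender;

// Sketch: forward each batch from an IPC stream to a channel, in the
// spirit of read_partition in flight_service.rs.
fn read_partition<T: Read>(
    reader: StreamReader<BufReader<T>>,
    tx: Sender<Result<RecordBatch, ArrowError>>,
) -> Result<(), ArrowError> {
    for batch in reader {
        // A send error just means the receiver hung up; stop reading.
        if tx.blocking_send(batch).is_err() {
            break;
        }
    }
    Ok(())
}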
