
Commit

add partitioning scheme for unresolved shuffle and shuffle reader exec (#1144)

* add partitioning scheme for unresolved shuffle and shuffle reader exec

* make default_codec a property of BallistaPhysicalExtensionCodec

* tests
onursatici authored Dec 19, 2024
1 parent 4ea309f commit f840585
Showing 6 changed files with 171 additions and 40 deletions.
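At a glance: `UnresolvedShuffleExec` and `ShuffleReaderExec` now carry an explicit `Partitioning` scheme instead of the `Partitioning::UnknownPartitioning(n)` placeholder that the removed TODO comments pointed at, and the serde codec round-trips that scheme through the protobuf messages. A minimal sketch of the new constructor contract, assuming the Ballista crate layout at this commit (the schema and stage id are invented for illustration):

use std::sync::Arc;

use ballista_core::execution_plans::UnresolvedShuffleExec;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::physical_plan::expressions::col;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};

fn main() -> datafusion::error::Result<()> {
    // Toy schema; any schema works.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

    // The caller now states the partitioning scheme up front
    // instead of passing a bare output_partition_count.
    let partitioning = Partitioning::Hash(vec![col("id", schema.as_ref())?], 4);
    let exec = UnresolvedShuffleExec::new(1, schema, partitioning);

    // The partition count is derived from the scheme itself.
    assert_eq!(exec.properties().output_partitioning().partition_count(), 4);
    Ok(())
}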
ballista/core/proto/ballista.proto (2 additions, 1 deletion)
@@ -50,14 +50,15 @@ message ShuffleWriterExecNode {
 message UnresolvedShuffleExecNode {
   uint32 stage_id = 1;
   datafusion_common.Schema schema = 2;
-  uint32 output_partition_count = 4;
+  datafusion.Partitioning partitioning = 5;
 }
 
 message ShuffleReaderExecNode {
   repeated ShuffleReaderPartition partition = 1;
   datafusion_common.Schema schema = 2;
   // The stage to read from
   uint32 stage_id = 3;
+  datafusion.Partitioning partitioning = 4;
 }
 
 message ShuffleReaderPartition {
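Note on wire compatibility: in UnresolvedShuffleExecNode the retired output_partition_count keeps its old tag 4 out of circulation and the new partitioning message takes tag 5, so a stale reader cannot misinterpret old payloads; ShuffleReaderExecNode had never used tag 4, so the new field claims it directly. Because the field is absent from messages produced before this change, the decode path further down treats a missing partitioning as an error rather than guessing a scheme.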
ballista/core/src/execution_plans/shuffle_reader.rs (4 additions, 3 deletions)
@@ -74,12 +74,11 @@ impl ShuffleReaderExec {
         stage_id: usize,
         partition: Vec<Vec<PartitionLocation>>,
         schema: SchemaRef,
+        partitioning: Partitioning,
     ) -> Result<Self> {
         let properties = PlanProperties::new(
             datafusion::physical_expr::EquivalenceProperties::new(schema.clone()),
-            // TODO partitioning may be known and could be populated here
-            // see https://github.com/apache/arrow-datafusion/issues/758
-            Partitioning::UnknownPartitioning(partition.len()),
+            partitioning,
             datafusion::physical_plan::ExecutionMode::Bounded,
         );
         Ok(Self {
@@ -134,6 +133,7 @@ impl ExecutionPlan for ShuffleReaderExec {
             self.stage_id,
             self.partition.clone(),
             self.schema.clone(),
+            self.properties().output_partitioning().clone(),
         )?))
     }
 
@@ -553,6 +553,7 @@ mod tests {
             input_stage_id,
             vec![partitions],
             Arc::new(schema),
+            Partitioning::UnknownPartitioning(4),
         )?;
         let mut stream = shuffle_reader_exec.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream).await;
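Call sites that genuinely do not know the scheme can keep the old behavior by passing UnknownPartitioning sized from the partition locations, as the updated test above does. A sketch of such a wrapper (the helper name is invented, and the Result alias is assumed to be DataFusion's, matching this file's try_new signature):

use ballista_core::execution_plans::ShuffleReaderExec;
use ballista_core::serde::scheduler::PartitionLocation;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::error::Result;
use datafusion::physical_plan::Partitioning;

/// Hypothetical helper mirroring the pre-change default: the outer
/// Vec holds one Vec<PartitionLocation> per output partition, so its
/// length is exactly the count the old constructor used to infer.
fn reader_with_unknown_partitioning(
    stage_id: usize,
    partition: Vec<Vec<PartitionLocation>>,
    schema: SchemaRef,
) -> Result<ShuffleReaderExec> {
    let count = partition.len();
    ShuffleReaderExec::try_new(
        stage_id,
        partition,
        schema,
        Partitioning::UnknownPartitioning(count),
    )
}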
ballista/core/src/execution_plans/unresolved_shuffle.rs (8 additions, 10 deletions)
@@ -46,22 +46,16 @@ pub struct UnresolvedShuffleExec {
 
 impl UnresolvedShuffleExec {
     /// Create a new UnresolvedShuffleExec
-    pub fn new(
-        stage_id: usize,
-        schema: SchemaRef,
-        output_partition_count: usize,
-    ) -> Self {
+    pub fn new(stage_id: usize, schema: SchemaRef, partitioning: Partitioning) -> Self {
         let properties = PlanProperties::new(
             datafusion::physical_expr::EquivalenceProperties::new(schema.clone()),
-            // TODO the output partition is known and should be populated here!
-            // see https://github.com/apache/arrow-datafusion/issues/758
-            Partitioning::UnknownPartitioning(output_partition_count),
+            partitioning,
             datafusion::physical_plan::ExecutionMode::Bounded,
         );
         Self {
             stage_id,
             schema,
-            output_partition_count,
+            output_partition_count: properties.partitioning.partition_count(),
             properties,
         }
     }
@@ -75,7 +69,11 @@ impl DisplayAs for UnresolvedShuffleExec {
     ) -> std::fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
-                write!(f, "UnresolvedShuffleExec")
+                write!(
+                    f,
+                    "UnresolvedShuffleExec: {:?}",
+                    self.properties().output_partitioning()
+                )
             }
         }
     }
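With the scheme stored in the plan properties, EXPLAIN output now identifies the shuffle: the node renders as something like `UnresolvedShuffleExec: Hash([Column { name: "id", index: 0 }], 4)` rather than a bare `UnresolvedShuffleExec`. The exact text follows Partitioning's Debug impl, so treat that rendering as illustrative.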
ballista/core/src/serde/generated/ballista.rs (4 additions, 2 deletions)
@@ -42,8 +42,8 @@ pub struct UnresolvedShuffleExecNode {
     pub stage_id: u32,
     #[prost(message, optional, tag = "2")]
     pub schema: ::core::option::Option<::datafusion_proto_common::Schema>,
-    #[prost(uint32, tag = "4")]
-    pub output_partition_count: u32,
+    #[prost(message, optional, tag = "5")]
+    pub partitioning: ::core::option::Option<::datafusion_proto::protobuf::Partitioning>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ShuffleReaderExecNode {
@@ -54,6 +54,8 @@ pub struct ShuffleReaderExecNode {
     /// The stage to read from
     #[prost(uint32, tag = "3")]
     pub stage_id: u32,
+    #[prost(message, optional, tag = "4")]
+    pub partitioning: ::core::option::Option<::datafusion_proto::protobuf::Partitioning>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ShuffleReaderPartition {
ballista/core/src/serde/mod.rs (131 additions, 15 deletions)
@@ -21,6 +21,7 @@
 use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
 
 use arrow_flight::sql::ProstMessageExt;
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{DataFusionError, Result};
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
@@ -29,6 +30,9 @@ use datafusion_proto::logical_plan::file_formats::{
     JsonLogicalExtensionCodec, ParquetLogicalExtensionCodec,
 };
 use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
+use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
+use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
 use datafusion_proto::protobuf::proto_error;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
 use datafusion_proto::{
@@ -244,8 +248,18 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
     }
 }
 
-#[derive(Debug, Default)]
-pub struct BallistaPhysicalExtensionCodec {}
+#[derive(Debug)]
+pub struct BallistaPhysicalExtensionCodec {
+    default_codec: Arc<dyn PhysicalExtensionCodec>,
+}
+
+impl Default for BallistaPhysicalExtensionCodec {
+    fn default() -> Self {
+        Self {
+            default_codec: Arc::new(DefaultPhysicalExtensionCodec {}),
+        }
+    }
+}
 
 impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
     fn try_decode(
@@ -272,14 +286,11 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
             PhysicalPlanType::ShuffleWriter(shuffle_writer) => {
                 let input = inputs[0].clone();
 
-                let default_codec =
-                    datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec {};
-
                 let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
                     shuffle_writer.output_partitioning.as_ref(),
                     registry,
                     input.schema().as_ref(),
-                    &default_codec,
+                    self.default_codec.as_ref(),
                 )?;
 
                 Ok(Arc::new(ShuffleWriterExec::try_new(
@@ -292,7 +303,8 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
             }
             PhysicalPlanType::ShuffleReader(shuffle_reader) => {
                 let stage_id = shuffle_reader.stage_id as usize;
-                let schema = Arc::new(convert_required!(shuffle_reader.schema)?);
+                let schema: SchemaRef =
+                    Arc::new(convert_required!(shuffle_reader.schema)?);
                 let partition_location: Vec<Vec<PartitionLocation>> = shuffle_reader
                     .partition
                     .iter()
@@ -309,16 +321,37 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                             .collect::<Result<Vec<_>, _>>()
                     })
                     .collect::<Result<Vec<_>, DataFusionError>>()?;
-                let shuffle_reader =
-                    ShuffleReaderExec::try_new(stage_id, partition_location, schema)?;
+                let partitioning = parse_protobuf_partitioning(
+                    shuffle_reader.partitioning.as_ref(),
+                    registry,
+                    schema.as_ref(),
+                    self.default_codec.as_ref(),
+                )?;
+                let partitioning = partitioning
+                    .ok_or_else(|| proto_error("missing required partitioning field"))?;
+                let shuffle_reader = ShuffleReaderExec::try_new(
+                    stage_id,
+                    partition_location,
+                    schema,
+                    partitioning,
+                )?;
                 Ok(Arc::new(shuffle_reader))
             }
             PhysicalPlanType::UnresolvedShuffle(unresolved_shuffle) => {
-                let schema = Arc::new(convert_required!(unresolved_shuffle.schema)?);
+                let schema: SchemaRef =
+                    Arc::new(convert_required!(unresolved_shuffle.schema)?);
+                let partitioning = parse_protobuf_partitioning(
+                    unresolved_shuffle.partitioning.as_ref(),
+                    registry,
+                    schema.as_ref(),
+                    self.default_codec.as_ref(),
+                )?;
+                let partitioning = partitioning
+                    .ok_or_else(|| proto_error("missing required partitioning field"))?;
                 Ok(Arc::new(UnresolvedShuffleExec::new(
                     unresolved_shuffle.stage_id as usize,
                     schema,
-                    unresolved_shuffle.output_partition_count as usize,
+                    partitioning,
                 )))
             }
         }
@@ -334,12 +367,10 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
         // to get the true output partitioning
         let output_partitioning = match exec.shuffle_output_partitioning() {
             Some(Partitioning::Hash(exprs, partition_count)) => {
-                let default_codec =
-                    datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec {};
                 Some(datafusion_proto::protobuf::PhysicalHashRepartition {
                     hash_expr: exprs
                         .iter()
-                        .map(|expr|datafusion_proto::physical_plan::to_proto::serialize_physical_expr(&expr.clone(), &default_codec))
+                        .map(|expr|datafusion_proto::physical_plan::to_proto::serialize_physical_expr(&expr.clone(), self.default_codec.as_ref()))
                         .collect::<Result<Vec<_>, DataFusionError>>()?,
                     partition_count: *partition_count as u64,
                 })
@@ -387,12 +418,17 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                     .collect::<Result<Vec<_>, _>>()?,
                 });
             }
+            let partitioning = serialize_partitioning(
+                &exec.properties().partitioning,
+                self.default_codec.as_ref(),
+            )?;
             let proto = protobuf::BallistaPhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::ShuffleReader(
                     protobuf::ShuffleReaderExecNode {
                         stage_id,
                         partition,
                         schema: Some(exec.schema().as_ref().try_into()?),
+                        partitioning: Some(partitioning),
                     },
                 )),
             };
@@ -404,12 +440,16 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
 
             Ok(())
         } else if let Some(exec) = node.as_any().downcast_ref::<UnresolvedShuffleExec>() {
+            let partitioning = serialize_partitioning(
+                &exec.properties().partitioning,
+                self.default_codec.as_ref(),
+            )?;
             let proto = protobuf::BallistaPhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::UnresolvedShuffle(
                     protobuf::UnresolvedShuffleExecNode {
                         stage_id: exec.stage_id as u32,
                         schema: Some(exec.schema().as_ref().try_into()?),
-                        output_partition_count: exec.output_partition_count as u32,
+                        partitioning: Some(partitioning),
                     },
                 )),
             };
@@ -449,6 +489,11 @@ struct FileFormatProto {
 
 #[cfg(test)]
 mod test {
+    use super::*;
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::execution::registry::MemoryFunctionRegistry;
+    use datafusion::physical_plan::expressions::col;
+    use datafusion::physical_plan::Partitioning;
     use datafusion::{
         common::DFSchema,
         datasource::file_format::{parquet::ParquetFormatFactory, DefaultFileType},
@@ -493,4 +538,75 @@ mod test {
         assert_eq!(o.to_string(), d.to_string())
         //logical_plan.
     }
+
+    fn create_test_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]))
+    }
+
+    #[tokio::test]
+    async fn test_unresolved_shuffle_exec_roundtrip() {
+        let schema = create_test_schema();
+        let partitioning =
+            Partitioning::Hash(vec![col("id", schema.as_ref()).unwrap()], 4);
+
+        let original_exec = UnresolvedShuffleExec::new(
+            1, // stage_id
+            schema.clone(),
+            partitioning.clone(),
+        );
+
+        let codec = BallistaPhysicalExtensionCodec::default();
+        let mut buf: Vec<u8> = vec![];
+        codec
+            .try_encode(Arc::new(original_exec.clone()), &mut buf)
+            .unwrap();
+
+        let registry = MemoryFunctionRegistry::new();
+        let decoded_plan = codec.try_decode(&buf, &[], &registry).unwrap();
+
+        let decoded_exec = decoded_plan
+            .as_any()
+            .downcast_ref::<UnresolvedShuffleExec>()
+            .expect("Expected UnresolvedShuffleExec");
+
+        assert_eq!(decoded_exec.stage_id, 1);
+        assert_eq!(decoded_exec.schema().as_ref(), schema.as_ref());
+        assert_eq!(&decoded_exec.properties().partitioning, &partitioning);
+    }
+
+    #[tokio::test]
+    async fn test_shuffle_reader_exec_roundtrip() {
+        let schema = create_test_schema();
+        let partitioning =
+            Partitioning::Hash(vec![col("id", schema.as_ref()).unwrap()], 4);
+
+        let original_exec = ShuffleReaderExec::try_new(
+            1, // stage_id
+            Vec::new(),
+            schema.clone(),
+            partitioning.clone(),
+        )
+        .unwrap();
+
+        let codec = BallistaPhysicalExtensionCodec::default();
+        let mut buf: Vec<u8> = vec![];
+        codec
+            .try_encode(Arc::new(original_exec.clone()), &mut buf)
+            .unwrap();
+
+        let registry = MemoryFunctionRegistry::new();
+        let decoded_plan = codec.try_decode(&buf, &[], &registry).unwrap();
+
+        let decoded_exec = decoded_plan
+            .as_any()
+            .downcast_ref::<ShuffleReaderExec>()
+            .expect("Expected ShuffleReaderExec");
+
+        assert_eq!(decoded_exec.stage_id, 1);
+        assert_eq!(decoded_exec.schema().as_ref(), schema.as_ref());
+        assert_eq!(&decoded_exec.properties().partitioning, &partitioning);
+    }
 }
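A small design note on the codec change above: holding default_codec as a field means try_decode and try_encode share one DefaultPhysicalExtensionCodec instead of constructing it inline at each call site, and it gives a single seam where a different inner PhysicalExtensionCodec could presumably be injected later without touching the match arms.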
(diff for the sixth changed file not shown)
