diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto index cb3c148b4..70e8bba53 100644 --- a/ballista/core/proto/ballista.proto +++ b/ballista/core/proto/ballista.proto @@ -50,7 +50,7 @@ message ShuffleWriterExecNode { message UnresolvedShuffleExecNode { uint32 stage_id = 1; datafusion_common.Schema schema = 2; - uint32 output_partition_count = 4; + datafusion.Partitioning partitioning = 5; } message ShuffleReaderExecNode { @@ -58,6 +58,7 @@ message ShuffleReaderExecNode { datafusion_common.Schema schema = 2; // The stage to read from uint32 stage_id = 3; + datafusion.Partitioning partitioning = 4; } message ShuffleReaderPartition { diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 4a2c25b87..f50d6a291 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -74,12 +74,11 @@ impl ShuffleReaderExec { stage_id: usize, partition: Vec>, schema: SchemaRef, + partitioning: Partitioning, ) -> Result { let properties = PlanProperties::new( datafusion::physical_expr::EquivalenceProperties::new(schema.clone()), - // TODO partitioning may be known and could be populated here - // see https://github.com/apache/arrow-datafusion/issues/758 - Partitioning::UnknownPartitioning(partition.len()), + partitioning, datafusion::physical_plan::ExecutionMode::Bounded, ); Ok(Self { @@ -134,6 +133,7 @@ impl ExecutionPlan for ShuffleReaderExec { self.stage_id, self.partition.clone(), self.schema.clone(), + self.properties().output_partitioning().clone(), )?)) } @@ -553,6 +553,7 @@ mod tests { input_stage_id, vec![partitions], Arc::new(schema), + Partitioning::UnknownPartitioning(4), )?; let mut stream = shuffle_reader_exec.execute(0, task_ctx)?; let batches = utils::collect_stream(&mut stream).await; diff --git a/ballista/core/src/execution_plans/unresolved_shuffle.rs b/ballista/core/src/execution_plans/unresolved_shuffle.rs index e227e2ac3..9d4d3077d 100644 --- a/ballista/core/src/execution_plans/unresolved_shuffle.rs +++ b/ballista/core/src/execution_plans/unresolved_shuffle.rs @@ -46,22 +46,16 @@ pub struct UnresolvedShuffleExec { impl UnresolvedShuffleExec { /// Create a new UnresolvedShuffleExec - pub fn new( - stage_id: usize, - schema: SchemaRef, - output_partition_count: usize, - ) -> Self { + pub fn new(stage_id: usize, schema: SchemaRef, partitioning: Partitioning) -> Self { let properties = PlanProperties::new( datafusion::physical_expr::EquivalenceProperties::new(schema.clone()), - // TODO the output partition is known and should be populated here! - // see https://github.com/apache/arrow-datafusion/issues/758 - Partitioning::UnknownPartitioning(output_partition_count), + partitioning, datafusion::physical_plan::ExecutionMode::Bounded, ); Self { stage_id, schema, - output_partition_count, + output_partition_count: properties.partitioning.partition_count(), properties, } } @@ -75,7 +69,11 @@ impl DisplayAs for UnresolvedShuffleExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "UnresolvedShuffleExec") + write!( + f, + "UnresolvedShuffleExec: {:?}", + self.properties().output_partitioning() + ) } } } diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs index d4faef825..ed73ab404 100644 --- a/ballista/core/src/serde/generated/ballista.rs +++ b/ballista/core/src/serde/generated/ballista.rs @@ -42,8 +42,8 @@ pub struct UnresolvedShuffleExecNode { pub stage_id: u32, #[prost(message, optional, tag = "2")] pub schema: ::core::option::Option<::datafusion_proto_common::Schema>, - #[prost(uint32, tag = "4")] - pub output_partition_count: u32, + #[prost(message, optional, tag = "5")] + pub partitioning: ::core::option::Option<::datafusion_proto::protobuf::Partitioning>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ShuffleReaderExecNode { @@ -54,6 +54,8 @@ pub struct ShuffleReaderExecNode { /// The stage to read from #[prost(uint32, tag = "3")] pub stage_id: u32, + #[prost(message, optional, tag = "4")] + pub partitioning: ::core::option::Option<::datafusion_proto::protobuf::Partitioning>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ShuffleReaderPartition { diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index d7d6474f7..84cf80684 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -21,6 +21,7 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use arrow_flight::sql::ProstMessageExt; +use datafusion::arrow::datatypes::SchemaRef; use datafusion::common::{DataFusionError, Result}; use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; @@ -29,6 +30,9 @@ use datafusion_proto::logical_plan::file_formats::{ JsonLogicalExtensionCodec, ParquetLogicalExtensionCodec, }; use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning; +use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning; +use datafusion_proto::physical_plan::to_proto::serialize_partitioning; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::protobuf::proto_error; use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use datafusion_proto::{ @@ -244,8 +248,18 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { } } -#[derive(Debug, Default)] -pub struct BallistaPhysicalExtensionCodec {} +#[derive(Debug)] +pub struct BallistaPhysicalExtensionCodec { + default_codec: Arc, +} + +impl Default for BallistaPhysicalExtensionCodec { + fn default() -> Self { + Self { + default_codec: Arc::new(DefaultPhysicalExtensionCodec {}), + } + } +} impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { fn try_decode( @@ -272,14 +286,11 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { PhysicalPlanType::ShuffleWriter(shuffle_writer) => { let input = inputs[0].clone(); - let default_codec = - datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec {}; - let shuffle_output_partitioning = parse_protobuf_hash_partitioning( shuffle_writer.output_partitioning.as_ref(), registry, input.schema().as_ref(), - &default_codec, + self.default_codec.as_ref(), )?; Ok(Arc::new(ShuffleWriterExec::try_new( @@ -292,7 +303,8 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { } PhysicalPlanType::ShuffleReader(shuffle_reader) => { let stage_id = shuffle_reader.stage_id as usize; - let schema = Arc::new(convert_required!(shuffle_reader.schema)?); + let schema: SchemaRef = + Arc::new(convert_required!(shuffle_reader.schema)?); let partition_location: Vec> = shuffle_reader .partition .iter() @@ -309,16 +321,37 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { .collect::, _>>() }) .collect::, DataFusionError>>()?; - let shuffle_reader = - ShuffleReaderExec::try_new(stage_id, partition_location, schema)?; + let partitioning = parse_protobuf_partitioning( + shuffle_reader.partitioning.as_ref(), + registry, + schema.as_ref(), + self.default_codec.as_ref(), + )?; + let partitioning = partitioning + .ok_or_else(|| proto_error("missing required partitioning field"))?; + let shuffle_reader = ShuffleReaderExec::try_new( + stage_id, + partition_location, + schema, + partitioning, + )?; Ok(Arc::new(shuffle_reader)) } PhysicalPlanType::UnresolvedShuffle(unresolved_shuffle) => { - let schema = Arc::new(convert_required!(unresolved_shuffle.schema)?); + let schema: SchemaRef = + Arc::new(convert_required!(unresolved_shuffle.schema)?); + let partitioning = parse_protobuf_partitioning( + unresolved_shuffle.partitioning.as_ref(), + registry, + schema.as_ref(), + self.default_codec.as_ref(), + )?; + let partitioning = partitioning + .ok_or_else(|| proto_error("missing required partitioning field"))?; Ok(Arc::new(UnresolvedShuffleExec::new( unresolved_shuffle.stage_id as usize, schema, - unresolved_shuffle.output_partition_count as usize, + partitioning, ))) } } @@ -334,12 +367,10 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { // to get the true output partitioning let output_partitioning = match exec.shuffle_output_partitioning() { Some(Partitioning::Hash(exprs, partition_count)) => { - let default_codec = - datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec {}; Some(datafusion_proto::protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr|datafusion_proto::physical_plan::to_proto::serialize_physical_expr(&expr.clone(), &default_codec)) + .map(|expr|datafusion_proto::physical_plan::to_proto::serialize_physical_expr(&expr.clone(), self.default_codec.as_ref())) .collect::, DataFusionError>>()?, partition_count: *partition_count as u64, }) @@ -387,12 +418,17 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { .collect::, _>>()?, }); } + let partitioning = serialize_partitioning( + &exec.properties().partitioning, + self.default_codec.as_ref(), + )?; let proto = protobuf::BallistaPhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ShuffleReader( protobuf::ShuffleReaderExecNode { stage_id, partition, schema: Some(exec.schema().as_ref().try_into()?), + partitioning: Some(partitioning), }, )), }; @@ -404,12 +440,16 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { Ok(()) } else if let Some(exec) = node.as_any().downcast_ref::() { + let partitioning = serialize_partitioning( + &exec.properties().partitioning, + self.default_codec.as_ref(), + )?; let proto = protobuf::BallistaPhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::UnresolvedShuffle( protobuf::UnresolvedShuffleExecNode { stage_id: exec.stage_id as u32, schema: Some(exec.schema().as_ref().try_into()?), - output_partition_count: exec.output_partition_count as u32, + partitioning: Some(partitioning), }, )), }; @@ -449,6 +489,11 @@ struct FileFormatProto { #[cfg(test)] mod test { + use super::*; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::execution::registry::MemoryFunctionRegistry; + use datafusion::physical_plan::expressions::col; + use datafusion::physical_plan::Partitioning; use datafusion::{ common::DFSchema, datasource::file_format::{parquet::ParquetFormatFactory, DefaultFileType}, @@ -493,4 +538,75 @@ mod test { assert_eq!(o.to_string(), d.to_string()) //logical_plan. } + + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])) + } + + #[tokio::test] + async fn test_unresolved_shuffle_exec_roundtrip() { + let schema = create_test_schema(); + let partitioning = + Partitioning::Hash(vec![col("id", schema.as_ref()).unwrap()], 4); + + let original_exec = UnresolvedShuffleExec::new( + 1, // stage_id + schema.clone(), + partitioning.clone(), + ); + + let codec = BallistaPhysicalExtensionCodec::default(); + let mut buf: Vec = vec![]; + codec + .try_encode(Arc::new(original_exec.clone()), &mut buf) + .unwrap(); + + let registry = MemoryFunctionRegistry::new(); + let decoded_plan = codec.try_decode(&buf, &[], ®istry).unwrap(); + + let decoded_exec = decoded_plan + .as_any() + .downcast_ref::() + .expect("Expected UnresolvedShuffleExec"); + + assert_eq!(decoded_exec.stage_id, 1); + assert_eq!(decoded_exec.schema().as_ref(), schema.as_ref()); + assert_eq!(&decoded_exec.properties().partitioning, &partitioning); + } + + #[tokio::test] + async fn test_shuffle_reader_exec_roundtrip() { + let schema = create_test_schema(); + let partitioning = + Partitioning::Hash(vec![col("id", schema.as_ref()).unwrap()], 4); + + let original_exec = ShuffleReaderExec::try_new( + 1, // stage_id + Vec::new(), + schema.clone(), + partitioning.clone(), + ) + .unwrap(); + + let codec = BallistaPhysicalExtensionCodec::default(); + let mut buf: Vec = vec![]; + codec + .try_encode(Arc::new(original_exec.clone()), &mut buf) + .unwrap(); + + let registry = MemoryFunctionRegistry::new(); + let decoded_plan = codec.try_decode(&buf, &[], ®istry).unwrap(); + + let decoded_exec = decoded_plan + .as_any() + .downcast_ref::() + .expect("Expected ShuffleReaderExec"); + + assert_eq!(decoded_exec.stage_id, 1); + assert_eq!(decoded_exec.schema().as_ref(), schema.as_ref()); + assert_eq!(&decoded_exec.properties().partitioning, &partitioning); + } } diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs index 47500ac11..fc32262e0 100644 --- a/ballista/scheduler/src/planner.rs +++ b/ballista/scheduler/src/planner.rs @@ -168,10 +168,7 @@ fn create_unresolved_shuffle( Arc::new(UnresolvedShuffleExec::new( shuffle_writer.stage_id(), shuffle_writer.schema(), - shuffle_writer - .properties() - .output_partitioning() - .partition_count(), + shuffle_writer.properties().output_partitioning().clone(), )) } @@ -239,6 +236,10 @@ pub fn remove_unresolved_shuffles( unresolved_shuffle.stage_id, relevant_locations, unresolved_shuffle.schema().clone(), + unresolved_shuffle + .properties() + .output_partitioning() + .clone(), )?)) } else { new_children.push(remove_unresolved_shuffles( @@ -259,16 +260,12 @@ pub fn rollback_resolved_shuffles( let mut new_children: Vec> = vec![]; for child in stage.children() { if let Some(shuffle_reader) = child.as_any().downcast_ref::() { - let output_partition_count = shuffle_reader - .properties() - .output_partitioning() - .partition_count(); let stage_id = shuffle_reader.stage_id; let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new( stage_id, shuffle_reader.schema(), - output_partition_count, + shuffle_reader.properties().partitioning.clone(), )); new_children.push(unresolved_shuffle); } else { @@ -396,6 +393,10 @@ mod test { downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec); assert_eq!(unresolved_shuffle.stage_id, 1); assert_eq!(unresolved_shuffle.output_partition_count, 2); + assert_eq!( + unresolved_shuffle.properties().partitioning, + Partitioning::Hash(vec![Arc::new(Column::new("l_returnflag", 0))], 2) + ); // verify stage 2 let stage2 = stages[2].children()[0].clone(); @@ -405,6 +406,10 @@ mod test { downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec); assert_eq!(unresolved_shuffle.stage_id, 2); assert_eq!(unresolved_shuffle.output_partition_count, 2); + assert_eq!( + unresolved_shuffle.properties().partitioning, + Partitioning::Hash(vec![Arc::new(Column::new("l_returnflag", 0))], 2) + ); Ok(()) } @@ -559,6 +564,10 @@ order by let unresolved_shuffle_reader_1 = downcast_exec!(join_input_1, UnresolvedShuffleExec); assert_eq!(unresolved_shuffle_reader_1.output_partition_count, 2); + assert_eq!( + unresolved_shuffle_reader_1.properties().partitioning, + Partitioning::Hash(vec![Arc::new(Column::new("l_orderkey", 0))], 2) + ); let join_input_2 = join.children()[1].clone(); // skip CoalesceBatches @@ -566,6 +575,10 @@ order by let unresolved_shuffle_reader_2 = downcast_exec!(join_input_2, UnresolvedShuffleExec); assert_eq!(unresolved_shuffle_reader_2.output_partition_count, 2); + assert_eq!( + unresolved_shuffle_reader_2.properties().partitioning, + Partitioning::Hash(vec![Arc::new(Column::new("o_orderkey", 0))], 2) + ); // final partitioned hash aggregate assert_eq!(