From 476643974552daa1e1ee58403f885305514430bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 27 Aug 2024 12:39:26 +0800 Subject: [PATCH] Remove generics on BallistaCodec --- ballista/core/src/serde/mod.rs | 19 ++-------- .../core/src/serde/scheduler/from_proto.rs | 21 +++++----- ballista/executor/src/executor_process.rs | 3 +- ballista/executor/src/executor_server.rs | 33 ++++++++-------- ballista/scheduler/src/api/handlers.rs | 22 +++++------ ballista/scheduler/src/api/mod.rs | 12 ++---- ballista/scheduler/src/cluster/kv.rs | 25 +++--------- ballista/scheduler/src/cluster/mod.rs | 6 +-- ballista/scheduler/src/flight_sql.rs | 5 +-- ballista/scheduler/src/planner.rs | 4 +- ballista/scheduler/src/scheduler_process.rs | 17 ++++----- .../scheduler/src/scheduler_server/grpc.rs | 38 ++++++++----------- .../scheduler/src/scheduler_server/mod.rs | 10 ++--- .../scheduler_server/query_stage_scheduler.rs | 14 +++---- .../scheduler/src/state/execution_graph.rs | 16 ++------ .../scheduler/src/state/execution_stage.rs | 38 +++++++++---------- ballista/scheduler/src/state/mod.rs | 12 +++--- ballista/scheduler/src/state/task_manager.rs | 12 +++--- ballista/scheduler/src/test_utils.rs | 18 ++++----- 19 files changed, 128 insertions(+), 197 deletions(-) diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 81705d83..72a9d86f 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -26,16 +26,14 @@ use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; use datafusion_proto::common::proto_error; use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use datafusion_proto::{ convert_required, - logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}, - physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}, + logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}, + physical_plan::PhysicalExtensionCodec, }; use prost::Message; use std::fmt::Debug; -use std::marker::PhantomData; use std::sync::Arc; use std::{convert::TryInto, io::Cursor}; @@ -69,14 +67,9 @@ pub fn decode_protobuf(bytes: &[u8]) -> Result { } #[derive(Clone, Debug)] -pub struct BallistaCodec< - T: 'static + AsLogicalPlan = LogicalPlanNode, - U: 'static + AsExecutionPlan = PhysicalPlanNode, -> { +pub struct BallistaCodec { logical_extension_codec: Arc, physical_extension_codec: Arc, - logical_plan_repr: PhantomData, - physical_plan_repr: PhantomData, } impl Default for BallistaCodec { @@ -84,13 +77,11 @@ impl Default for BallistaCodec { Self { logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}), physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec {}), - logical_plan_repr: PhantomData, - physical_plan_repr: PhantomData, } } } -impl BallistaCodec { +impl BallistaCodec { pub fn new( logical_extension_codec: Arc, physical_extension_codec: Arc, @@ -98,8 +89,6 @@ impl BallistaCodec for protobuf::ExecutorSpecification { } } -pub fn get_task_definition_vec( +pub fn get_task_definition_vec( multi_task: protobuf::MultiTaskDefinition, runtime: Arc, - codec: BallistaCodec, + codec: BallistaCodec, ) -> Result, BallistaError> { let mut props = HashMap::new(); for kv_pair in multi_task.props { @@ -254,13 +254,14 @@ pub fn get_task_definition_vec = U::try_decode(encoded_plan).and_then(|proto| { - proto.try_into_physical_plan( - &SessionContext::new(), - runtime.as_ref(), - codec.physical_extension_codec(), - ) - })?; + let plan: Arc = + PhysicalPlanNode::try_decode(encoded_plan).and_then(|proto| { + proto.try_into_physical_plan( + &SessionContext::new(), + runtime.as_ref(), + codec.physical_extension_codec(), + ) + })?; let job_id = multi_task.job_id; let stage_id = multi_task.stage_id as usize; diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index e1f51616..e78cff3a 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -34,7 +34,6 @@ use tracing_subscriber::EnvFilter; use uuid::Uuid; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use ballista_core::error::BallistaError; use ballista_core::serde::protobuf::executor_resource::Resource; @@ -143,7 +142,7 @@ pub async fn start_executor_process(opt: Arc) -> Result<( .max_encoding_message_size(16 * 1024 * 1024) .max_decoding_message_size(16 * 1024 * 1024); - let default_codec: BallistaCodec = BallistaCodec::default(); + let default_codec: BallistaCodec = BallistaCodec::default(); let mut service_handlers: FuturesUnordered>> = FuturesUnordered::new(); diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index 9f1a6ea7..38830c72 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -48,7 +48,6 @@ use datafusion::common::DataFusionError; use datafusion::config::ConfigOptions; use datafusion::execution::TaskContext; use datafusion::prelude::SessionConfig; -use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan}; use tokio::sync::mpsc::error::TryRecvError; use tokio::task::JoinHandle; @@ -74,11 +73,11 @@ struct CuratorTaskStatus { task_status: TaskStatus, } -pub async fn startup( +pub async fn startup( scheduler: SchedulerGrpcClient, config: Arc, executor: Arc, - codec: BallistaCodec, + codec: BallistaCodec, ) -> Result { let channel_buf_size = executor.concurrent_tasks * 50; let (tx_task, rx_task) = mpsc::channel::(channel_buf_size); @@ -134,11 +133,11 @@ pub async fn startup( } #[derive(Clone)] -pub struct ExecutorServer { +pub struct ExecutorServer { _start_time: u128, executor: Arc, executor_env: ExecutorEnv, - codec: BallistaCodec, + codec: BallistaCodec, scheduler_to_register: SchedulerGrpcClient, schedulers: SchedulerClients, } @@ -157,12 +156,12 @@ unsafe impl Sync for ExecutorEnv {} /// set to `true` when the executor receives a shutdown signal pub static TERMINATING: AtomicBool = AtomicBool::new(false); -impl ExecutorServer { +impl ExecutorServer { fn new( scheduler_to_register: SchedulerGrpcClient, executor: Arc, executor_env: ExecutorEnv, - codec: BallistaCodec, + codec: BallistaCodec, ) -> Self { Self { _start_time: SystemTime::now() @@ -370,12 +369,12 @@ impl ExecutorServer { - executor_server: Arc>, +struct Heartbeater { + executor_server: Arc, } -impl Heartbeater { - fn new(executor_server: Arc>) -> Self { +impl Heartbeater { + fn new(executor_server: Arc) -> Self { Self { executor_server } } @@ -395,12 +394,12 @@ impl Heartbeater /// First is for sending back task status to scheduler /// Second is for receiving task from scheduler and run. /// The two loops will run forever. -struct TaskRunnerPool { - executor_server: Arc>, +struct TaskRunnerPool { + executor_server: Arc, } -impl TaskRunnerPool { - fn new(executor_server: Arc>) -> Self { +impl TaskRunnerPool { + fn new(executor_server: Arc) -> Self { Self { executor_server } } @@ -511,9 +510,7 @@ impl TaskRunnerPool ExecutorGrpc - for ExecutorServer -{ +impl ExecutorGrpc for ExecutorServer { /// by this interface, it can reduce the deserialization cost for multiple tasks /// belong to the same job stage running on the same one executor async fn launch_multi_task( diff --git a/ballista/scheduler/src/api/handlers.rs b/ballista/scheduler/src/api/handlers.rs index 19fa0a31..87a08dc6 100644 --- a/ballista/scheduler/src/api/handlers.rs +++ b/ballista/scheduler/src/api/handlers.rs @@ -16,8 +16,6 @@ use crate::state::execution_graph::ExecutionStage; use ballista_core::serde::protobuf::job_status::Status; use ballista_core::BALLISTA_VERSION; use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Time}; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use std::time::Duration; use warp::Rejection; @@ -60,8 +58,8 @@ pub struct QueryStageSummary { } /// Return current scheduler state -pub(crate) async fn get_scheduler_state( - data_server: SchedulerServer, +pub(crate) async fn get_scheduler_state( + data_server: SchedulerServer, ) -> Result { let response = SchedulerStateResponse { started: data_server.start_time, @@ -71,8 +69,8 @@ pub(crate) async fn get_scheduler_state( } /// Return list of executors -pub(crate) async fn get_executors( - data_server: SchedulerServer, +pub(crate) async fn get_executors( + data_server: SchedulerServer, ) -> Result { let state = data_server.state; let executors: Vec = state @@ -93,9 +91,7 @@ pub(crate) async fn get_executors( } /// Return list of jobs -pub(crate) async fn get_jobs( - data_server: SchedulerServer, -) -> Result { +pub(crate) async fn get_jobs(data_server: SchedulerServer) -> Result { // TODO: Display last seen information in UI let state = data_server.state; @@ -155,8 +151,8 @@ pub(crate) async fn get_jobs( Ok(warp::reply::json(&jobs)) } -pub(crate) async fn cancel_job( - data_server: SchedulerServer, +pub(crate) async fn cancel_job( + data_server: SchedulerServer, job_id: String, ) -> Result { // 404 if job doesn't exist @@ -185,8 +181,8 @@ pub struct QueryStagesResponse { } /// Get the execution graph for the specified job id -pub(crate) async fn get_query_stages( - data_server: SchedulerServer, +pub(crate) async fn get_query_stages( + data_server: SchedulerServer, job_id: String, ) -> Result { if let Some(graph) = data_server diff --git a/ballista/scheduler/src/api/mod.rs b/ballista/scheduler/src/api/mod.rs index 46c812c6..94b0422f 100644 --- a/ballista/scheduler/src/api/mod.rs +++ b/ballista/scheduler/src/api/mod.rs @@ -14,8 +14,6 @@ mod handlers; use crate::scheduler_server::SchedulerServer; use anyhow::Result; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use std::{ pin::Pin, task::{Context as TaskContext, Poll}, @@ -73,15 +71,13 @@ fn map_option_err>(err: Option>) -> Option( - db: SchedulerServer, -) -> impl Filter,), Error = std::convert::Infallible> + Clone { +fn with_data_server( + db: SchedulerServer, +) -> impl Filter + Clone { warp::any().map(move || db.clone()) } -pub fn get_routes( - scheduler_server: SchedulerServer, -) -> BoxedFilter<(impl Reply,)> { +pub fn get_routes(scheduler_server: SchedulerServer) -> BoxedFilter<(impl Reply,)> { let route_scheduler_state = warp::path!("api" / "state") .and(with_data_server(scheduler_server.clone())) .and_then(handlers::get_scheduler_state); diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs index e26dd047..bd9a9327 100644 --- a/ballista/scheduler/src/cluster/kv.rs +++ b/ballista/scheduler/src/cluster/kv.rs @@ -36,9 +36,6 @@ use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; use ballista_core::serde::BallistaCodec; use dashmap::DashMap; use datafusion::prelude::SessionContext; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use futures::StreamExt; use itertools::Itertools; use log::info; @@ -48,11 +45,7 @@ use std::future::Future; use std::sync::Arc; /// State implementation based on underlying `KeyValueStore` -pub struct KeyValueState< - S: KeyValueStore, - T: 'static + AsLogicalPlan = LogicalPlanNode, - U: 'static + AsExecutionPlan = PhysicalPlanNode, -> { +pub struct KeyValueState { /// Underlying `KeyValueStore` store: S, /// ExecutorMetadata cache, executor_id -> ExecutorMetadata @@ -60,7 +53,7 @@ pub struct KeyValueState< /// ExecutorHeartbeat cache, executor_id -> ExecutorHeartbeat executor_heartbeats: Arc>, /// Codec used to serialize/deserialize execution plan - codec: BallistaCodec, + codec: BallistaCodec, /// Name of current scheduler. Should be `{host}:{port}` #[allow(dead_code)] scheduler: String, @@ -70,13 +63,11 @@ pub struct KeyValueState< session_builder: SessionBuilder, } -impl - KeyValueState -{ +impl KeyValueState { pub fn new( scheduler: impl Into, store: S, - codec: BallistaCodec, + codec: BallistaCodec, session_builder: SessionBuilder, ) -> Self { Self { @@ -134,9 +125,7 @@ impl } #[async_trait] -impl ClusterState - for KeyValueState -{ +impl ClusterState for KeyValueState { /// Initialize a background process that will listen for executor heartbeats and update the in-memory cache /// of executor heartbeats async fn init(&self) -> Result<()> { @@ -393,9 +382,7 @@ impl } #[async_trait] -impl JobState - for KeyValueState -{ +impl JobState for KeyValueState { fn accept_job(&self, job_id: &str, queued_at: u64) -> Result<()> { self.queued_jobs.insert(job_id.to_string(), queued_at); diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs index 2b71b7c4..ad4aa0de 100644 --- a/ballista/scheduler/src/cluster/mod.rs +++ b/ballista/scheduler/src/cluster/mod.rs @@ -20,8 +20,6 @@ use std::pin::Pin; use std::sync::Arc; use datafusion::prelude::SessionContext; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use futures::Stream; use log::{debug, info, warn}; @@ -63,11 +61,11 @@ impl BallistaCluster { } } - pub fn new_kv( + pub fn new_kv( store: S, scheduler: impl Into, session_builder: SessionBuilder, - codec: BallistaCodec, + codec: BallistaCodec, ) -> Self { let kv_state = Arc::new(KeyValueState::new(scheduler, store, codec, session_builder)); Self { diff --git a/ballista/scheduler/src/flight_sql.rs b/ballista/scheduler/src/flight_sql.rs index ff6b59cb..aef4c237 100644 --- a/ballista/scheduler/src/flight_sql.rs +++ b/ballista/scheduler/src/flight_sql.rs @@ -60,7 +60,6 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::DFSchemaRef; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::SessionContext; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use prost::Message; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::time::sleep; @@ -69,7 +68,7 @@ use tonic::metadata::MetadataValue; use uuid::Uuid; pub struct FlightSqlServiceImpl { - server: SchedulerServer, + server: SchedulerServer, statements: Arc>, contexts: Arc>>, } @@ -77,7 +76,7 @@ pub struct FlightSqlServiceImpl { const TABLE_TYPES: [&str; 2] = ["TABLE", "VIEW"]; impl FlightSqlServiceImpl { - pub fn new(server: SchedulerServer) -> Self { + pub fn new(server: SchedulerServer) -> Self { Self { server, statements: Default::default(), diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs index d927379d..fe2018a8 100644 --- a/ballista/scheduler/src/planner.rs +++ b/ballista/scheduler/src/planner.rs @@ -266,8 +266,6 @@ mod test { use datafusion::physical_plan::{displayable, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion_proto::physical_plan::AsExecutionPlan; - use datafusion_proto::protobuf::LogicalPlanNode; - use datafusion_proto::protobuf::PhysicalPlanNode; use std::ops::Deref; use std::sync::Arc; use uuid::Uuid; @@ -571,7 +569,7 @@ order by ctx: &SessionContext, plan: Arc, ) -> Result, BallistaError> { - let codec: BallistaCodec = BallistaCodec::default(); + let codec: BallistaCodec = BallistaCodec::default(); let proto: datafusion_proto::protobuf::PhysicalPlanNode = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan( plan.clone(), diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index 842e70b0..eca1207d 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -26,8 +26,6 @@ use std::sync::Arc; use tonic::transport::server::Connected; use tower::Service; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; - use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer; use ballista_core::serde::BallistaCodec; use ballista_core::utils::create_grpc_server; @@ -52,14 +50,13 @@ pub async fn start_server( // Should only call SchedulerServer::new() once in the process info!("Starting Scheduler grpc server with push task scheduling policy",); - let mut scheduler_server: SchedulerServer = - SchedulerServer::new( - config.scheduler_name(), - cluster, - BallistaCodec::default(), - config.clone(), - Arc::new(DefaultTaskLauncher::new(config.scheduler_name())), - ); + let mut scheduler_server: SchedulerServer = SchedulerServer::new( + config.scheduler_name(), + cluster, + BallistaCodec::default(), + config.clone(), + Arc::new(DefaultTaskLauncher::new(config.scheduler_name())), + ); scheduler_server.init().await?; diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 2cacaef5..d5d02d23 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -22,8 +22,6 @@ use ballista_core::serde::protobuf::{ }; use ballista_core::serde::scheduler::ExecutorMetadata; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, warn}; use tonic::{Request, Response, Status}; @@ -31,9 +29,7 @@ use tonic::{Request, Response, Status}; use crate::scheduler_server::{timestamp_secs, SchedulerServer}; #[tonic::async_trait] -impl SchedulerGrpc - for SchedulerServer -{ +impl SchedulerGrpc for SchedulerServer { async fn heart_beat_from_executor( &self, request: Request, @@ -127,8 +123,6 @@ mod test { use std::sync::Arc; use std::time::Duration; - use datafusion_proto::protobuf::LogicalPlanNode; - use datafusion_proto::protobuf::PhysicalPlanNode; use tonic::Request; use crate::config::SchedulerConfig; @@ -150,14 +144,13 @@ mod test { let config = SchedulerConfig::default(); let scheduler_name = "localhost:50050".to_owned(); - let mut scheduler: SchedulerServer = - SchedulerServer::new( - scheduler_name.clone(), - cluster, - BallistaCodec::default(), - Arc::new(config), - Arc::new(DefaultTaskLauncher::new(scheduler_name)), - ); + let mut scheduler: SchedulerServer = SchedulerServer::new( + scheduler_name.clone(), + cluster, + BallistaCodec::default(), + Arc::new(config), + Arc::new(DefaultTaskLauncher::new(scheduler_name)), + ); scheduler.init().await?; let exec_meta = ExecutorRegistration { @@ -202,14 +195,13 @@ mod test { let config = SchedulerConfig::default(); let scheduler_name = "localhost:50050".to_owned(); - let mut scheduler: SchedulerServer = - SchedulerServer::new( - scheduler_name.clone(), - cluster.clone(), - BallistaCodec::default(), - Arc::new(config), - Arc::new(DefaultTaskLauncher::new(scheduler_name)), - ); + let mut scheduler: SchedulerServer = SchedulerServer::new( + scheduler_name.clone(), + cluster.clone(), + BallistaCodec::default(), + Arc::new(config), + Arc::new(DefaultTaskLauncher::new(scheduler_name)), + ); scheduler.init().await?; let exec_meta = ExecutorRegistration { diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 980b773b..43d3477c 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -26,8 +26,6 @@ use ballista_core::serde::BallistaCodec; use datafusion::execution::context::SessionState; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; @@ -49,20 +47,20 @@ pub(crate) mod query_stage_scheduler; pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState; #[derive(Clone)] -pub struct SchedulerServer { +pub struct SchedulerServer { pub scheduler_name: String, pub start_time: u128, - pub state: Arc>, + pub state: Arc, pub(crate) query_stage_event_loop: EventLoop, #[allow(dead_code)] config: Arc, } -impl SchedulerServer { +impl SchedulerServer { pub fn new( scheduler_name: String, cluster: BallistaCluster, - codec: BallistaCodec, + codec: BallistaCodec, config: Arc, task_launcher: Arc, ) -> Self { diff --git a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs index 4899fde8..d2e8e5b1 100644 --- a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs +++ b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs @@ -25,30 +25,26 @@ use ballista_core::event_loop::{EventAction, EventSender}; use crate::config::SchedulerConfig; use crate::scheduler_server::timestamp_millis; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use tokio::sync::mpsc; use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::state::SchedulerState; -pub(crate) struct QueryStageScheduler { - state: Arc>, +pub(crate) struct QueryStageScheduler { + state: Arc, #[allow(dead_code)] config: Arc, } -impl QueryStageScheduler { - pub(crate) fn new(state: Arc>, config: Arc) -> Self { +impl QueryStageScheduler { + pub(crate) fn new(state: Arc, config: Arc) -> Self { Self { state, config } } } #[async_trait] -impl EventAction - for QueryStageScheduler -{ +impl EventAction for QueryStageScheduler { fn on_start(&self) { info!("Starting QueryStageScheduler"); } diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index 5e46b977..47ecb5be 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -24,7 +24,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor}; use datafusion::prelude::SessionContext; -use datafusion_proto::logical_plan::AsLogicalPlan; use log::{debug, info, warn}; use ballista_core::error::{BallistaError, Result}; @@ -39,7 +38,6 @@ use ballista_core::serde::scheduler::{ ExecutorMetadata, PartitionId, PartitionLocation, PartitionStats, }; use ballista_core::serde::BallistaCodec; -use datafusion_proto::physical_plan::AsExecutionPlan; use crate::display::print_stage_metrics; use crate::planner::DistributedPlanner; @@ -636,12 +634,9 @@ impl ExecutionGraph { Ok(()) } - pub(crate) async fn decode_execution_graph< - T: 'static + AsLogicalPlan, - U: 'static + AsExecutionPlan, - >( + pub(crate) async fn decode_execution_graph( proto: protobuf::ExecutionGraph, - codec: &BallistaCodec, + codec: &BallistaCodec, session_ctx: &SessionContext, ) -> Result { let mut stages: HashMap = HashMap::new(); @@ -693,12 +688,9 @@ impl ExecutionGraph { /// Running stages will not be persisted so that will not be encoded. /// Running stages will be convert back to the resolved stages to be encoded and persisted - pub(crate) fn encode_execution_graph< - T: 'static + AsLogicalPlan, - U: 'static + AsExecutionPlan, - >( + pub(crate) fn encode_execution_graph( graph: ExecutionGraph, - codec: &BallistaCodec, + codec: &BallistaCodec, ) -> Result { let job_id = graph.job_id().to_owned(); diff --git a/ballista/scheduler/src/state/execution_stage.rs b/ballista/scheduler/src/state/execution_stage.rs index b927aee3..f6d8afd9 100644 --- a/ballista/scheduler/src/state/execution_stage.rs +++ b/ballista/scheduler/src/state/execution_stage.rs @@ -28,7 +28,6 @@ use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion::physical_plan::{ExecutionPlan, Metric}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_proto::logical_plan::AsLogicalPlan; use log::{debug, warn}; use ballista_core::error::{BallistaError, Result}; @@ -40,6 +39,7 @@ use ballista_core::serde::protobuf::{task_status, RunningTask}; use ballista_core::serde::scheduler::PartitionLocation; use ballista_core::serde::BallistaCodec; use datafusion_proto::physical_plan::AsExecutionPlan; +use datafusion_proto::protobuf::PhysicalPlanNode; use crate::display::DisplayableBallistaExecutionPlan; @@ -260,12 +260,12 @@ impl UnresolvedStage { )) } - pub(super) fn decode( + pub(super) fn decode( stage: protobuf::UnResolvedStage, - codec: &BallistaCodec, + codec: &BallistaCodec, session_ctx: &SessionContext, ) -> Result { - let plan_proto = U::try_decode(&stage.plan)?; + let plan_proto = PhysicalPlanNode::try_decode(&stage.plan)?; let plan = plan_proto.try_into_physical_plan( session_ctx, session_ctx.runtime_env().as_ref(), @@ -282,12 +282,12 @@ impl UnresolvedStage { }) } - pub(super) fn encode( + pub(super) fn encode( stage: UnresolvedStage, - codec: &BallistaCodec, + codec: &BallistaCodec, ) -> Result { let mut plan: Vec = vec![]; - U::try_from_physical_plan(stage.plan, codec.physical_extension_codec()) + PhysicalPlanNode::try_from_physical_plan(stage.plan, codec.physical_extension_codec()) .and_then(|proto| proto.try_encode(&mut plan))?; let inputs = encode_inputs(stage.inputs)?; @@ -345,12 +345,12 @@ impl ResolvedStage { ) } - pub(super) fn decode( + pub(super) fn decode( stage: protobuf::ResolvedStage, - codec: &BallistaCodec, + codec: &BallistaCodec, session_ctx: &SessionContext, ) -> Result { - let plan_proto = U::try_decode(&stage.plan)?; + let plan_proto = PhysicalPlanNode::try_decode(&stage.plan)?; let plan = plan_proto.try_into_physical_plan( session_ctx, session_ctx.runtime_env().as_ref(), @@ -368,12 +368,12 @@ impl ResolvedStage { }) } - pub(super) fn encode( + pub(super) fn encode( stage: ResolvedStage, - codec: &BallistaCodec, + codec: &BallistaCodec, ) -> Result { let mut plan: Vec = vec![]; - U::try_from_physical_plan(stage.plan, codec.physical_extension_codec()) + PhysicalPlanNode::try_from_physical_plan(stage.plan, codec.physical_extension_codec()) .and_then(|proto| proto.try_encode(&mut plan))?; let inputs = encode_inputs(stage.inputs)?; @@ -618,12 +618,12 @@ impl Debug for RunningStage { } impl SuccessfulStage { - pub(super) fn decode( + pub(super) fn decode( stage: protobuf::SuccessfulStage, - codec: &BallistaCodec, + codec: &BallistaCodec, session_ctx: &SessionContext, ) -> Result { - let plan_proto = U::try_decode(&stage.plan)?; + let plan_proto = PhysicalPlanNode::try_decode(&stage.plan)?; let plan = plan_proto.try_into_physical_plan( session_ctx, session_ctx.runtime_env().as_ref(), @@ -654,15 +654,15 @@ impl SuccessfulStage { }) } - pub(super) fn encode( + pub(super) fn encode( _job_id: String, stage: SuccessfulStage, - codec: &BallistaCodec, + codec: &BallistaCodec, ) -> Result { let stage_id = stage.stage_id; let mut plan: Vec = vec![]; - U::try_from_physical_plan(stage.plan, codec.physical_extension_codec()) + PhysicalPlanNode::try_from_physical_plan(stage.plan, codec.physical_extension_codec()) .and_then(|proto| proto.try_encode(&mut plan))?; let inputs = encode_inputs(stage.inputs)?; diff --git a/ballista/scheduler/src/state/mod.rs b/ballista/scheduler/src/state/mod.rs index 24d92dd9..098dcf87 100644 --- a/ballista/scheduler/src/state/mod.rs +++ b/ballista/scheduler/src/state/mod.rs @@ -40,8 +40,6 @@ use ballista_core::serde::BallistaCodec; use datafusion::logical_expr::LogicalPlan; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::prelude::SessionContext; -use datafusion_proto::logical_plan::AsLogicalPlan; -use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info, warn}; use prost::Message; @@ -74,18 +72,18 @@ pub fn encode_protobuf(msg: &T) -> Result> { } #[derive(Clone)] -pub struct SchedulerState { +pub struct SchedulerState { pub executor_manager: ExecutorManager, - pub task_manager: TaskManager, + pub task_manager: TaskManager, pub session_manager: SessionManager, - pub codec: BallistaCodec, + pub codec: BallistaCodec, pub config: Arc, } -impl SchedulerState { +impl SchedulerState { pub fn new( cluster: BallistaCluster, - codec: BallistaCodec, + codec: BallistaCodec, scheduler_name: String, config: Arc, launcher: Arc, diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index e990e2c1..4f2c95f7 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -34,8 +34,8 @@ use ballista_core::serde::BallistaCodec; use dashmap::DashMap; use datafusion::physical_plan::ExecutionPlan; -use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; +use datafusion_proto::protobuf::PhysicalPlanNode; use log::{debug, error, info, warn}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; @@ -101,9 +101,9 @@ impl TaskLauncher for DefaultTaskLauncher { } #[derive(Clone)] -pub struct TaskManager { +pub struct TaskManager { state: Arc, - codec: BallistaCodec, + codec: BallistaCodec, scheduler_id: String, // Cache for active jobs curated by this scheduler active_job_cache: ActiveJobCache, @@ -138,10 +138,10 @@ pub struct UpdatedStages { pub failed_stages: HashMap, } -impl TaskManager { +impl TaskManager { pub fn new( state: Arc, - codec: BallistaCodec, + codec: BallistaCodec, scheduler_id: String, launcher: Arc, ) -> Self { @@ -455,7 +455,7 @@ impl TaskManager plan.clone() } else { let mut plan_buf: Vec = vec![]; - let plan_proto = U::try_from_physical_plan( + let plan_proto = PhysicalPlanNode::try_from_physical_plan( task.plan.clone(), self.codec.physical_extension_codec(), )?; diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index bba586bc..db5b7490 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -54,7 +54,6 @@ use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::cluster::storage::sled::SledClient; use crate::state::execution_graph::{ExecutionGraph, TaskDescription}; use ballista_core::utils::default_session_builder; -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use tokio::sync::mpsc::{channel, Receiver, Sender}; pub const TPCH_TABLES: &[&str] = &[ @@ -366,7 +365,7 @@ impl TaskLauncher for VirtualTaskLauncher { } pub struct SchedulerTest { - scheduler: SchedulerServer, + scheduler: SchedulerServer, ballista_config: HashMap, status_receiver: Option)>>, } @@ -412,14 +411,13 @@ impl SchedulerTest { executors: executors.clone(), }; - let mut scheduler: SchedulerServer = - SchedulerServer::new( - "localhost:50050".to_owned(), - cluster, - BallistaCodec::default(), - Arc::new(config), - Arc::new(launcher), - ); + let mut scheduler: SchedulerServer = SchedulerServer::new( + "localhost:50050".to_owned(), + cluster, + BallistaCodec::default(), + Arc::new(config), + Arc::new(launcher), + ); scheduler.init().await?; for (executor_id, VirtualExecutor { task_slots, .. }) in executors {