Skip to content

Commit

Permalink
Remove generics on BallistaCodec
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw committed Aug 27, 2024
1 parent ea94f21 commit 4766439
Show file tree
Hide file tree
Showing 19 changed files with 128 additions and 197 deletions.
19 changes: 4 additions & 15 deletions ballista/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@ use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion_proto::common::proto_error;
use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
use datafusion_proto::{
convert_required,
logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec},
physical_plan::{AsExecutionPlan, PhysicalExtensionCodec},
logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec},
physical_plan::PhysicalExtensionCodec,
};

use prost::Message;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use std::{convert::TryInto, io::Cursor};

Expand Down Expand Up @@ -69,37 +67,28 @@ pub fn decode_protobuf(bytes: &[u8]) -> Result<BallistaAction, BallistaError> {
}

#[derive(Clone, Debug)]
pub struct BallistaCodec<
T: 'static + AsLogicalPlan = LogicalPlanNode,
U: 'static + AsExecutionPlan = PhysicalPlanNode,
> {
pub struct BallistaCodec {
logical_extension_codec: Arc<dyn LogicalExtensionCodec>,
physical_extension_codec: Arc<dyn PhysicalExtensionCodec>,
logical_plan_repr: PhantomData<T>,
physical_plan_repr: PhantomData<U>,
}

impl Default for BallistaCodec {
fn default() -> Self {
Self {
logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}),
physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec {}),
logical_plan_repr: PhantomData,
physical_plan_repr: PhantomData,
}
}
}

impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> BallistaCodec<T, U> {
impl BallistaCodec {
pub fn new(
logical_extension_codec: Arc<dyn LogicalExtensionCodec>,
physical_extension_codec: Arc<dyn PhysicalExtensionCodec>,
) -> Self {
Self {
logical_extension_codec,
physical_extension_codec,
logical_plan_repr: PhantomData,
physical_plan_repr: PhantomData,
}
}

Expand Down
21 changes: 11 additions & 10 deletions ballista/core/src/serde/scheduler/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::metrics::{Count, Gauge, MetricValue, MetricsSet, Time, Timestamp};
use datafusion::physical_plan::{ExecutionPlan, Metric};
use datafusion::prelude::SessionContext;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf::PhysicalPlanNode;
use std::collections::HashMap;
use std::convert::TryInto;
use std::sync::Arc;
Expand Down Expand Up @@ -242,10 +242,10 @@ impl Into<ExecutorSpecification> for protobuf::ExecutorSpecification {
}
}

pub fn get_task_definition_vec<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
pub fn get_task_definition_vec(
multi_task: protobuf::MultiTaskDefinition,
runtime: Arc<RuntimeEnv>,
codec: BallistaCodec<T, U>,
codec: BallistaCodec,
) -> Result<Vec<TaskDefinition>, BallistaError> {
let mut props = HashMap::new();
for kv_pair in multi_task.props {
Expand All @@ -254,13 +254,14 @@ pub fn get_task_definition_vec<T: 'static + AsLogicalPlan, U: 'static + AsExecut
let props = Arc::new(props);

let encoded_plan = multi_task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
&SessionContext::new(),
runtime.as_ref(),
codec.physical_extension_codec(),
)
})?;
let plan: Arc<dyn ExecutionPlan> =
PhysicalPlanNode::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
&SessionContext::new(),
runtime.as_ref(),
codec.physical_extension_codec(),
)
})?;

let job_id = multi_task.job_id;
let stage_id = multi_task.stage_id as usize;
Expand Down
3 changes: 1 addition & 2 deletions ballista/executor/src/executor_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use tracing_subscriber::EnvFilter;
use uuid::Uuid;

use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};

use ballista_core::error::BallistaError;
use ballista_core::serde::protobuf::executor_resource::Resource;
Expand Down Expand Up @@ -143,7 +142,7 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<(
.max_encoding_message_size(16 * 1024 * 1024)
.max_decoding_message_size(16 * 1024 * 1024);

let default_codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> = BallistaCodec::default();
let default_codec: BallistaCodec = BallistaCodec::default();

let mut service_handlers: FuturesUnordered<JoinHandle<Result<(), BallistaError>>> =
FuturesUnordered::new();
Expand Down
33 changes: 15 additions & 18 deletions ballista/executor/src/executor_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ use datafusion::common::DataFusionError;
use datafusion::config::ConfigOptions;
use datafusion::execution::TaskContext;
use datafusion::prelude::SessionConfig;
use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;

Expand All @@ -74,11 +73,11 @@ struct CuratorTaskStatus {
task_status: TaskStatus,
}

pub async fn startup<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
pub async fn startup(
scheduler: SchedulerGrpcClient<Channel>,
config: Arc<ExecutorProcessConfig>,
executor: Arc<Executor>,
codec: BallistaCodec<T, U>,
codec: BallistaCodec,
) -> Result<ServerHandle, BallistaError> {
let channel_buf_size = executor.concurrent_tasks * 50;
let (tx_task, rx_task) = mpsc::channel::<CuratorTaskDefinition>(channel_buf_size);
Expand Down Expand Up @@ -134,11 +133,11 @@ pub async fn startup<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
}

#[derive(Clone)]
pub struct ExecutorServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
pub struct ExecutorServer {
_start_time: u128,
executor: Arc<Executor>,
executor_env: ExecutorEnv,
codec: BallistaCodec<T, U>,
codec: BallistaCodec,
scheduler_to_register: SchedulerGrpcClient<Channel>,
schedulers: SchedulerClients,
}
Expand All @@ -157,12 +156,12 @@ unsafe impl Sync for ExecutorEnv {}
/// set to `true` when the executor receives a shutdown signal
pub static TERMINATING: AtomicBool = AtomicBool::new(false);

impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T, U> {
impl ExecutorServer {
fn new(
scheduler_to_register: SchedulerGrpcClient<Channel>,
executor: Arc<Executor>,
executor_env: ExecutorEnv,
codec: BallistaCodec<T, U>,
codec: BallistaCodec,
) -> Self {
Self {
_start_time: SystemTime::now()
Expand Down Expand Up @@ -370,12 +369,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
}

/// Heartbeater will run forever.
struct Heartbeater<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
executor_server: Arc<ExecutorServer<T, U>>,
struct Heartbeater {
executor_server: Arc<ExecutorServer>,
}

impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> Heartbeater<T, U> {
fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
impl Heartbeater {
fn new(executor_server: Arc<ExecutorServer>) -> Self {
Self { executor_server }
}

Expand All @@ -395,12 +394,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> Heartbeater<T, U>
/// First is for sending back task status to scheduler
/// Second is for receiving task from scheduler and run.
/// The two loops will run forever.
struct TaskRunnerPool<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
executor_server: Arc<ExecutorServer<T, U>>,
struct TaskRunnerPool {
executor_server: Arc<ExecutorServer>,
}

impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T, U> {
fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
impl TaskRunnerPool {
fn new(executor_server: Arc<ExecutorServer>) -> Self {
Self { executor_server }
}

Expand Down Expand Up @@ -511,9 +510,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T,
}

#[tonic::async_trait]
impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
for ExecutorServer<T, U>
{
impl ExecutorGrpc for ExecutorServer {
/// by this interface, it can reduce the deserialization cost for multiple tasks
/// belong to the same job stage running on the same one executor
async fn launch_multi_task(
Expand Down
22 changes: 9 additions & 13 deletions ballista/scheduler/src/api/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ use crate::state::execution_graph::ExecutionStage;
use ballista_core::serde::protobuf::job_status::Status;
use ballista_core::BALLISTA_VERSION;
use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Time};
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;

use std::time::Duration;
use warp::Rejection;
Expand Down Expand Up @@ -60,8 +58,8 @@ pub struct QueryStageSummary {
}

/// Return current scheduler state
pub(crate) async fn get_scheduler_state<T: AsLogicalPlan, U: AsExecutionPlan>(
data_server: SchedulerServer<T, U>,
pub(crate) async fn get_scheduler_state(
data_server: SchedulerServer,
) -> Result<impl warp::Reply, Rejection> {
let response = SchedulerStateResponse {
started: data_server.start_time,
Expand All @@ -71,8 +69,8 @@ pub(crate) async fn get_scheduler_state<T: AsLogicalPlan, U: AsExecutionPlan>(
}

/// Return list of executors
pub(crate) async fn get_executors<T: AsLogicalPlan, U: AsExecutionPlan>(
data_server: SchedulerServer<T, U>,
pub(crate) async fn get_executors(
data_server: SchedulerServer,
) -> Result<impl warp::Reply, Rejection> {
let state = data_server.state;
let executors: Vec<ExecutorMetaResponse> = state
Expand All @@ -93,9 +91,7 @@ pub(crate) async fn get_executors<T: AsLogicalPlan, U: AsExecutionPlan>(
}

/// Return list of jobs
pub(crate) async fn get_jobs<T: AsLogicalPlan, U: AsExecutionPlan>(
data_server: SchedulerServer<T, U>,
) -> Result<impl warp::Reply, Rejection> {
pub(crate) async fn get_jobs(data_server: SchedulerServer) -> Result<impl warp::Reply, Rejection> {
// TODO: Display last seen information in UI
let state = data_server.state;

Expand Down Expand Up @@ -155,8 +151,8 @@ pub(crate) async fn get_jobs<T: AsLogicalPlan, U: AsExecutionPlan>(
Ok(warp::reply::json(&jobs))
}

pub(crate) async fn cancel_job<T: AsLogicalPlan, U: AsExecutionPlan>(
data_server: SchedulerServer<T, U>,
pub(crate) async fn cancel_job(
data_server: SchedulerServer,
job_id: String,
) -> Result<impl warp::Reply, Rejection> {
// 404 if job doesn't exist
Expand Down Expand Up @@ -185,8 +181,8 @@ pub struct QueryStagesResponse {
}

/// Get the execution graph for the specified job id
pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: AsExecutionPlan>(
data_server: SchedulerServer<T, U>,
pub(crate) async fn get_query_stages(
data_server: SchedulerServer,
job_id: String,
) -> Result<impl warp::Reply, Rejection> {
if let Some(graph) = data_server
Expand Down
12 changes: 4 additions & 8 deletions ballista/scheduler/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ mod handlers;

use crate::scheduler_server::SchedulerServer;
use anyhow::Result;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use std::{
pin::Pin,
task::{Context as TaskContext, Poll},
Expand Down Expand Up @@ -73,15 +71,13 @@ fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result
err.map(|e| e.map_err(Into::into))
}

fn with_data_server<T: AsLogicalPlan + Clone, U: 'static + AsExecutionPlan>(
db: SchedulerServer<T, U>,
) -> impl Filter<Extract = (SchedulerServer<T, U>,), Error = std::convert::Infallible> + Clone {
fn with_data_server(
db: SchedulerServer,
) -> impl Filter<Extract = (SchedulerServer,), Error = std::convert::Infallible> + Clone {
warp::any().map(move || db.clone())
}

pub fn get_routes<T: AsLogicalPlan + Clone, U: 'static + AsExecutionPlan>(
scheduler_server: SchedulerServer<T, U>,
) -> BoxedFilter<(impl Reply,)> {
pub fn get_routes(scheduler_server: SchedulerServer) -> BoxedFilter<(impl Reply,)> {
let route_scheduler_state = warp::path!("api" / "state")
.and(with_data_server(scheduler_server.clone()))
.and_then(handlers::get_scheduler_state);
Expand Down
25 changes: 6 additions & 19 deletions ballista/scheduler/src/cluster/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
use ballista_core::serde::BallistaCodec;
use dashmap::DashMap;
use datafusion::prelude::SessionContext;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
use futures::StreamExt;
use itertools::Itertools;
use log::info;
Expand All @@ -48,19 +45,15 @@ use std::future::Future;
use std::sync::Arc;

/// State implementation based on underlying `KeyValueStore`
pub struct KeyValueState<
S: KeyValueStore,
T: 'static + AsLogicalPlan = LogicalPlanNode,
U: 'static + AsExecutionPlan = PhysicalPlanNode,
> {
pub struct KeyValueState<S: KeyValueStore> {
/// Underlying `KeyValueStore`
store: S,
/// ExecutorMetadata cache, executor_id -> ExecutorMetadata
executors: Arc<DashMap<String, ExecutorMetadata>>,
/// ExecutorHeartbeat cache, executor_id -> ExecutorHeartbeat
executor_heartbeats: Arc<DashMap<String, ExecutorHeartbeat>>,
/// Codec used to serialize/deserialize execution plan
codec: BallistaCodec<T, U>,
codec: BallistaCodec,
/// Name of current scheduler. Should be `{host}:{port}`
#[allow(dead_code)]
scheduler: String,
Expand All @@ -70,13 +63,11 @@ pub struct KeyValueState<
session_builder: SessionBuilder,
}

impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
KeyValueState<S, T, U>
{
impl<S: KeyValueStore> KeyValueState<S> {
pub fn new(
scheduler: impl Into<String>,
store: S,
codec: BallistaCodec<T, U>,
codec: BallistaCodec,
session_builder: SessionBuilder,
) -> Self {
Self {
Expand Down Expand Up @@ -134,9 +125,7 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
}

#[async_trait]
impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ClusterState
for KeyValueState<S, T, U>
{
impl<S: KeyValueStore> ClusterState for KeyValueState<S> {
/// Initialize a background process that will listen for executor heartbeats and update the in-memory cache
/// of executor heartbeats
async fn init(&self) -> Result<()> {
Expand Down Expand Up @@ -393,9 +382,7 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
}

#[async_trait]
impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> JobState
for KeyValueState<S, T, U>
{
impl<S: KeyValueStore> JobState for KeyValueState<S> {
fn accept_job(&self, job_id: &str, queued_at: u64) -> Result<()> {
self.queued_jobs.insert(job_id.to_string(), queued_at);

Expand Down
Loading

0 comments on commit 4766439

Please sign in to comment.