use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;
use brane_ast::Workflow;
use brane_cfg::info::Info;
use brane_cfg::infra::InfraFile;
use brane_cfg::node::{CentralConfig, NodeConfig, NodeSpecificConfig};
use brane_exe::FullValue;
use brane_prx::client::ProxyClient;
use brane_tsk::errors::PlanError;
use brane_tsk::spec::AppId;
use dashmap::DashMap;
use enum_debug::EnumDebug as _;
use error_trace::{ErrorTrace as _, trace};
use log::{debug, error, info};
use specifications::driving::{CheckReply, CheckRequest, CreateSessionReply, CreateSessionRequest, DriverService, ExecuteReply, ExecuteRequest};
use specifications::profiling::ProfileReport;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};
use crate::check::RequestOutput;
use crate::errors::RemoteVmError;
use crate::planner::InstancePlanner;
use crate::vm::InstanceVm;
use crate::{check, gc};
/// Reports a fatal error to both the log and the client, then leaves the enclosing function.
///
/// Every form sends an `Err(Status)` over the channel sender `$tx` (awaiting the send, so the macro is only usable in an
/// `async` context) and logs - rather than propagates - a failure to deliver that message.
///
/// # Forms
/// - `fatal_err!(tx, Status::<ctor>, err)`: logs the trace of `err`, sends `Status::<ctor>(err.to_string())` and does a bare `return`.
/// - `fatal_err!(tx, status)`: logs `status`, sends it and does a bare `return`.
/// - `fatal_err!(tx, rx, Status::<ctor>, err)`: as the first form, but returns `Ok(Response::new(ReceiverStream::new(rx)))`.
/// - `fatal_err!(tx, rx, status)`: as the second form, but returns `Ok(Response::new(ReceiverStream::new(rx)))`.
///
/// The `err` and `status` expressions are evaluated exactly once.
macro_rules! fatal_err {
    ($tx:ident,Status:: $status:ident, $err:expr) => {{
        // Borrow the error in its own scope so the reference is gone before the `.await` below.
        let serr: String = {
            let cause = &$err;
            log::error!("{}", cause.trace());
            cause.to_string()
        };
        if let Err(err) = $tx.send(Err(Status::$status(serr))).await {
            log::error!("{}", trace!(("Failed to notify client of error"), err));
        }
        return;
    }};
    ($tx:ident, $status:expr) => {{
        // Build the status once; it is both logged and sent.
        let status: Status = $status;
        log::error!("Aborting incoming request: {}", status);
        if let Err(err) = $tx.send(Err(status)).await {
            log::error!("{}", trace!(("Failed to notify client of error"), err));
        }
        return;
    }};
    ($tx:ident, $rx:ident,Status:: $status:ident, $err:expr) => {{
        // Borrow the error in its own scope so the reference is gone before the `.await` below.
        let serr: String = {
            let cause = &$err;
            log::error!("{}", cause.trace());
            cause.to_string()
        };
        if let Err(err) = $tx.send(Err(Status::$status(serr))).await {
            log::error!("{}", trace!(("Failed to notify client of error"), err));
        }
        return Ok(Response::new(ReceiverStream::new($rx)));
    }};
    ($tx:ident, $rx:ident, $status:expr) => {{
        // Build the status once; it is both logged and sent.
        let status: Status = $status;
        log::error!("Aborting incoming request: {}", status);
        if let Err(err) = $tx.send(Err(status)).await {
            log::error!("{}", trace!(("Failed to notify client of error"), err));
        }
        return Ok(Response::new(ReceiverStream::new($rx)));
    }};
}
/// Handler that backs the [`DriverService`] gRPC service.
///
/// Clones share the same proxy client and the same session map, since both live behind an [`Arc`].
#[derive(Clone)]
pub struct DriverHandler {
/// Path to the node config file. It is re-read on every `check` call and handed to each newly created [`InstanceVm`].
node_config_path: PathBuf,
/// Proxy client that is passed to every [`InstanceVm`] created for a session.
proxy: Arc<ProxyClient>,
/// Sessions known to this driver: maps a session's [`AppId`] to its VM and the [`Instant`] at which the entry was last
/// written (on session creation and after each execution). A weak reference to this map is given to `gc::sessions`
/// (presumably to expire idle sessions - confirm against the `gc` module).
sessions: Arc<DashMap<AppId, (InstanceVm, Instant)>>,
}
impl DriverHandler {
    /// Creates a new handler with an empty session map and starts the session garbage collector.
    ///
    /// # Arguments
    /// - `node_config_path`: Location of the node config file.
    /// - `proxy`: The proxy client to hand to the VMs created by this handler.
    ///
    /// # Returns
    /// A new `DriverHandler` without any sessions.
    ///
    /// # Panics
    /// This function calls [`tokio::spawn`], which panics when invoked outside of a Tokio runtime.
    #[inline]
    pub fn new(node_config_path: impl Into<PathBuf>, proxy: Arc<ProxyClient>) -> Self {
        let sessions = Arc::new(DashMap::<AppId, (InstanceVm, Instant)>::new());

        // The background task only receives a weak handle, so it does not keep the map alive by itself.
        let weak_sessions = Arc::downgrade(&sessions);
        tokio::spawn(gc::sessions(weak_sessions));

        Self { node_config_path: node_config_path.into(), proxy, sessions }
    }
}
#[tonic::async_trait]
impl DriverService for DriverHandler {
type ExecuteStream = ReceiverStream<Result<ExecuteReply, Status>>;
/// Opens a new session: generates a fresh [`AppId`], registers a new [`InstanceVm`] under it and returns the ID.
///
/// # Arguments
/// - `_request`: The incoming [`CreateSessionRequest`]; its contents are not used.
///
/// # Returns
/// A [`CreateSessionReply`] carrying the identifier of the new session.
async fn create_session(&self, _request: Request<CreateSessionRequest>) -> Result<Response<CreateSessionReply>, Status> {
    let report = ProfileReport::auto_reporting_file("brane-drv DriverHandler::create_session", "brane-drv_create-session");
    // Lives until the end of the function so the whole call is timed.
    let _guard = report.time("Total");

    // Register a fresh VM under a fresh identifier, stamped with the current time.
    let app_id: AppId = AppId::generate();
    let session_vm = InstanceVm::new(&self.node_config_path, app_id.clone(), self.proxy.clone());
    self.sessions.insert(app_id.clone(), (session_vm, Instant::now()));
    debug!("Created new session '{}'", app_id);

    Ok(Response::new(CreateSessionReply { uuid: app_id.into() }))
}
async fn check(&self, request: Request<CheckRequest>) -> Result<Response<CheckReply>, Status> {
let report = ProfileReport::auto_reporting_file("brane-drv DriverHandler::check", "brane-drv_check");
let overhead = report.time("Handle overhead");
let CheckRequest { workflow } = request.into_inner();
debug!("Receiving check request");
debug!("Deserializing input workflow...");
let workflow: Workflow = match serde_json::from_str(&workflow) {
Ok(workflow) => workflow,
Err(err) => {
debug!("{}", trace!(("Incoming request has invalid workflow"), err));
return Err(Status::invalid_argument("Invalid workflow"));
},
};
debug!("Loading node config file '{}'...", self.node_config_path.display());
let central_cfg: CentralConfig = match NodeConfig::from_path_async(&self.node_config_path).await {
Ok(cfg) => match cfg.node {
NodeSpecificConfig::Central(central) => central,
NodeSpecificConfig::Worker(_) | NodeSpecificConfig::Proxy(_) => {
error!("Given node config file '{}' is for a {}, but expected a Central", self.node_config_path.display(), cfg.node.variant());
return Err(Status::internal("An internal error has occurred"));
},
},
Err(err) => {
error!("{}", trace!(("Failed to read node config file '{}'", self.node_config_path.display()), err));
return Err(Status::internal("An internal error has occurred"));
},
};
debug!("Loading infra file '{}'...", central_cfg.paths.infra.display());
let infra: InfraFile = match InfraFile::from_path_async(¢ral_cfg.paths.infra).await {
Ok(infra) => infra,
Err(err) => {
error!("{}", trace!(("Failed to read infra file '{}'", central_cfg.paths.infra.display()), err));
return Err(Status::internal("An internal error has occurred"));
},
};
overhead.stop();
debug!("Planning workflow on instance `brane-plr`...");
let wf_id: String = workflow.id.clone();
let workflow: Workflow =
match InstancePlanner::plan(¢ral_cfg.services.plr.address, AppId::generate(), workflow, report.nest("Planning")).await {
Ok(wf) => wf,
Err(PlanError::CheckerDenied { domain, reasons }) => {
debug!("Checker denied workflow during planning already");
return Ok(Response::new(CheckReply {
verdict: false,
who: Some(domain),
reasons,
profile: serde_json::to_string(report.scope()).ok(),
}));
},
Err(err) => {
error!("{}", trace!(("Failed to plan workflow '{wf_id}'"), err));
return Err(Status::internal("An internal error has occurred"));
},
};
debug!("Generating requests for workflow '{}'...", workflow.id);
let req_gen = report.time("Spawning requests");
let handles: Vec<(String, JoinHandle<RequestOutput>)> = match check::spawn_requests(&infra, &workflow) {
Ok(reqs) => reqs,
Err(err) => {
error!("{}", trace!(("Failed to spawn requests for workflow '{}'", workflow.id), err));
return Err(Status::internal("An internal error has occurred"));
},
};
req_gen.stop();
debug!("Waiting for requests for workflow '{}' to complete...", workflow.id);
let req_join = report.time("Joining requests");
let mut result: Option<(String, Vec<String>)> = None;
for (checker, handle) in handles {
let res: RequestOutput = match handle.await {
Ok(res) => res,
Err(err) => {
error!("{}", trace!(("Failed to await JoinHandle for workflow '{}'", workflow.id), err));
return Err(Status::internal("An internal error has occurred"));
},
};
match res {
Ok(None) => continue,
Ok(Some(who)) => {
result = Some(who);
break;
},
Err(err) => {
error!("{}", trace!(("Failed to ask checker '{checker}' for permission for workflow '{}'", workflow.id), err));
return Err(Status::internal("An internal error has occurred"));
},
}
}
req_join.stop();
info!("Checkers verdict for workflow '{}' is {}", workflow.id, if result.is_none() { "ALLOW" } else { "DENY" });
if let Some((who, reasons)) = result {
Ok(Response::new(CheckReply { verdict: false, who: Some(who), reasons, profile: serde_json::to_string(report.scope()).ok() }))
} else {
Ok(Response::new(CheckReply { verdict: true, who: None, reasons: vec![], profile: serde_json::to_string(report.scope()).ok() }))
}
}
/// Executes a workflow within an existing session and streams the outcome back to the client.
///
/// The session ID is validated and the session's VM is looked up synchronously. Parsing and running the workflow then
/// happen in a spawned task, which reports over the returned stream.
///
/// # Arguments
/// - `request`: An [`ExecuteRequest`] carrying the session identifier (`uuid`) and the JSON-serialized [`Workflow`] (`input`).
///
/// # Returns
/// Always `Ok` with a stream of [`ExecuteReply`] messages. Failures are delivered as `Err(Status)` items on that stream:
/// - `invalid_argument` if the session ID or the workflow cannot be parsed;
/// - `internal` if no session with the given ID exists, if execution fails, or if the result cannot be serialized;
/// - `permission_denied` if a checker denied the workflow, with the reasons listed in the message.
///
/// On success, a final reply with `close: true` carries the serialized result value.
async fn execute(&self, request: Request<ExecuteRequest>) -> Result<Response<Self::ExecuteStream>, Status> {
    let report = ProfileReport::auto_reporting_file("brane-drv DriverHandler::execute", "brane-drv_execute");
    let overhead = report.time("Handle overhead");

    let request = request.into_inner();
    debug!("Receiving execute request for session '{}'", request.uuid);

    // Everything, errors included, is reported to the client over this channel
    let (tx, rx) = mpsc::channel::<Result<ExecuteReply, Status>>(10);

    // Resolve the session identifier
    let app_id: AppId = match AppId::from_str(&request.uuid) {
        Ok(app_id) => app_id,
        Err(err) => {
            fatal_err!(tx, rx, Status::invalid_argument, err);
        },
    };

    // Take a copy of the session's VM; the updated VM is written back once execution completes
    let sessions: Arc<DashMap<AppId, (InstanceVm, Instant)>> = self.sessions.clone();
    let vm: InstanceVm = match sessions.get(&app_id) {
        Some(vm) => vm.0.clone(),
        None => {
            // NOTE(review): a missing session looks like a client-side error (`not_found`); the status code is kept
            // as `internal` so that existing clients observe the same code.
            fatal_err!(tx, rx, Status::internal(format!("No session with ID '{app_id}' found")));
        },
    };
    overhead.stop();

    // Do the actual work in the background so the stream can be returned right away
    tokio::spawn(async move {
        debug!("Executing workflow for session '{}'", app_id);

        // Parse the workflow
        let par = report.time("Workflow parsing");
        debug!("Parsing workflow of {} characters", request.input.len());
        let workflow: Workflow = match serde_json::from_str(&request.input) {
            Ok(workflow) => workflow,
            Err(err) => {
                // Dump the offending input between two horizontal rules to ease debugging
                debug!("Workflow:\n{}\n{}\n{}\n\n", "-".repeat(80), request.input, "-".repeat(80));
                fatal_err!(tx, Status::invalid_argument, err);
            },
        };
        par.stop();

        // Run it; the VM is handed back alongside the result
        debug!("Executing workflow of {} edges", workflow.graph.len());
        let (vm, res): (InstanceVm, Result<FullValue, RemoteVmError>) =
            report.nest_fut("VM execution", |scope| vm.exec(tx.clone(), app_id.clone(), workflow, scope)).await;

        // Store the returned VM (with a fresh timestamp) regardless of whether execution succeeded
        debug!("Saving state session state");
        sessions.insert(app_id, (vm, Instant::now()));

        match res {
            Ok(res) => {
                debug!("Completed execution.");
                let _ret = report.time("Returning value");

                let sres: String = match serde_json::to_string(&res) {
                    Ok(sres) => sres,
                    Err(err) => {
                        fatal_err!(tx, Status::internal, err);
                    },
                };

                // The closing message carries the result value
                let reply = ExecuteReply {
                    close: true,
                    debug: Some(String::from("Driver completed execution.")),
                    stderr: None,
                    stdout: None,
                    value: Some(sres),
                };
                if let Err(err) = tx.send(Ok(reply)).await {
                    error!("{}", trace!(("Failed to send workflow result back to client"), err));
                }
            },
            // A checker denial gets its own status code and lists the reasons, if any
            Err(RemoteVmError::PlanError { err: PlanError::CheckerDenied { domain, reasons } }) => {
                fatal_err!(
                    tx,
                    Status::permission_denied(format!(
                        "Checker of domain '{domain}' denied execution{}",
                        if !reasons.is_empty() {
                            format!("\n\nReasons:\n{}\n", reasons.iter().map(|r| format!(" - {r}")).collect::<Vec<String>>().join("\n"))
                        } else {
                            String::new()
                        }
                    ))
                );
            },
            Err(err) => {
                fatal_err!(tx, Status::internal, err);
            },
        };
    });

    Ok(Response::new(ReceiverStream::new(rx)))
}
}