use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::error;
use std::fmt::{Display, Formatter, Result as FResult};
use std::panic::catch_unwind;
use brane_ast::spec::BuiltinFunctions;
use brane_ast::{MergeStrategy, ast};
use brane_exe::pc::{ProgramCounter, ResolvedProgramCounter};
use enum_debug::EnumDebug as _;
use log::{Level, debug, trace};
use specifications::data::{AvailabilityKind, DataName, PreprocessKind};
use super::preprocess;
use super::spec::{Dataset, Elem, ElemBranch, ElemCommit, ElemLoop, ElemParallel, ElemTask, User, Workflow};
use crate::{Metadata, utils};
/// Errors that can occur while compiling a WIR workflow into a [`Workflow`].
#[derive(Debug)]
pub enum Error {
/// The given workflow does not specify a user.
MissingUser,
/// Preprocessing the input WIR workflow failed; `err` is the underlying cause.
Preprocess { err: super::preprocess::Error },
/// The program counter `pc` is out-of-bounds; `max` is the number of edges reported for the function.
PcOutOfBounds { pc: ResolvedProgramCounter, max: usize },
/// The merge pointer `merge` of the Parallel-edge at `pc` does not point to an existing edge.
ParallelMergeOutOfBounds { pc: ResolvedProgramCounter, merge: ResolvedProgramCounter },
/// The edge at `merge`, referred to by the Parallel-edge at `pc`, is not a Join-edge; `got` names the variant found instead.
ParallelWithNonJoin { pc: ResolvedProgramCounter, merge: ResolvedProgramCounter, got: String },
/// A Join-edge was found at `pc` without a preceding Parallel-edge.
StrayJoin { pc: ResolvedProgramCounter },
/// The call at `pc` targets function `name`, which is neither a task nor a supported builtin.
IllegalCall { pc: ResolvedProgramCounter, name: String },
/// The call to `commit_result()` at `pc` returns more than one output; `got` is the number of outputs.
CommitTooMuchOutput { pc: ResolvedProgramCounter, got: usize },
/// The call to `commit_result()` at `pc` does not return a dataset.
CommitNoOutput { pc: ResolvedProgramCounter },
/// The call to `commit_result()` at `pc` returns an intermediate result instead of a dataset.
CommitReturnsResult { pc: ResolvedProgramCounter },
}
impl Display for Error {
    /// Writes a one-line, human-readable description of this error to `f`.
    ///
    /// The error wrapped by [`Error::Preprocess`] is not printed by this implementation; only the
    /// top-level message is written.
    fn fmt(&self, f: &mut Formatter<'_>) -> FResult {
        use Error::*;
        match self {
            MissingUser => write!(f, "User not specified in given workflow"),
            Preprocess { .. } => write!(f, "Failed to preprocess input WIR workflow"),

            PcOutOfBounds { pc, max } => write!(
                f,
                "Program counter {} is out-of-bounds (function {} has {} edges)",
                pc,
                // Prefer the function's name when the counter carries one; otherwise show its ID.
                if let Some(func_name) = pc.func_name() { func_name.clone() } else { pc.func_id().to_string() },
                max
            ),

            ParallelMergeOutOfBounds { pc, merge } => {
                write!(f, "Parallel edge at {pc}'s merge pointer {merge} is out-of-bounds")
            },
            ParallelWithNonJoin { pc, merge, got } => {
                write!(f, "Parallel edge at {pc}'s merge edge (at {merge}) was not an Edge::Join, but an Edge::{got}")
            },
            StrayJoin { pc } => write!(f, "Found Join-edge without preceding Parallel-edge at {pc}"),

            IllegalCall { pc, name } => {
                write!(f, "Encountered illegal call to function '{name}' at {pc} (calls to non-task, non-builtin functions are not supported)")
            },

            CommitTooMuchOutput { pc, got } => {
                // Message fixed: the location is introduced with "at" (was "as"), and "output" is singular.
                write!(f, "Call to `commit_result()` at {pc} returns more than 1 output (got {got})")
            },
            CommitNoOutput { pc } => write!(f, "Call to `commit_result()` at {pc} does not return a dataset"),
            CommitReturnsResult { pc } => {
                write!(f, "Call to `commit_result()` at {pc} returns an IntermediateResult instead of a Data")
            },
        }
    }
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
use Error::*;
match self {
MissingUser => None,
Preprocess { err, .. } => Some(err),
PcOutOfBounds { .. }
| ParallelMergeOutOfBounds { .. }
| ParallelWithNonJoin { .. }
| StrayJoin { .. }
| IllegalCall { .. }
| CommitTooMuchOutput { .. }
| CommitNoOutput { .. }
| CommitReturnsResult { .. } => None,
}
}
}
/// Walks the edges of `wir` starting at `pc` and records, per data name, a set of location names in
/// `lkls`.
///
/// # Arguments
/// - `lkls`: The map that is updated in-place while walking.
/// - `wir`: The workflow whose edges are looked up (through `utils::get_edge`).
/// - `pc`: The position of the edge to analyse.
/// - `breakpoint`: If given, the walk stops as soon as `pc` equals it.
///
/// The walk also stops when `pc` does not point to an edge.
fn analyse_data_lkls(lkls: &mut HashMap<DataName, HashSet<String>>, wir: &ast::Workflow, pc: ProgramCounter, breakpoint: Option<ProgramCounter>) {
    // Reaching the breakpoint ends this leg of the walk.
    if breakpoint == Some(pc) {
        return;
    }

    // Nothing to analyse if the counter does not point to an edge.
    let Some(edge) = utils::get_edge(wir, pc) else {
        return;
    };
    trace!("Analysing data LKLs in {:?}", edge.variant());

    match edge {
        // Terminates the walk without touching any data.
        ast::Edge::Stop {} => {},

        // Edges that only pass control on to their successor.
        ast::Edge::Linear { next, .. } | ast::Edge::Join { next, .. } | ast::Edge::Call { next, .. } => {
            analyse_data_lkls(lkls, wir, pc.jump(*next), breakpoint)
        },

        ast::Edge::Node { at, input, result, next, .. } => {
            for (name, access) in input {
                // Available inputs are recorded at the task's own location (if one is set);
                // transferred inputs are recorded at the location they are transferred from.
                let source = match access {
                    Some(AvailabilityKind::Available { .. }) => at.as_ref(),
                    Some(AvailabilityKind::Unavailable { how: PreprocessKind::TransferRegistryTar { location, dataname: _ } }) => Some(location),
                    None => None,
                };
                if let Some(source) = source {
                    // Overwrites whatever was recorded for this input before.
                    lkls.insert(name.clone(), HashSet::from([source.clone()]));
                }
            }

            // The task's result (if any) is recorded at the task's location (if any).
            if let Some((produced, site)) = result.as_ref().zip(at.as_ref()) {
                lkls.insert(DataName::IntermediateResult(produced.clone()), HashSet::from([site.clone()]));
            }

            analyse_data_lkls(lkls, wir, pc.jump(*next), breakpoint)
        },

        ast::Edge::Branch { true_next, false_next, merge } => {
            // Both arms run up to the merge point (when there is one).
            let join = merge.map(|m| pc.jump(m));
            analyse_data_lkls(lkls, wir, pc.jump(*true_next), join);
            if let Some(false_next) = false_next {
                analyse_data_lkls(lkls, wir, pc.jump(*false_next), join);
            }

            // Afterwards, continue from the merge point with the caller's breakpoint.
            if let Some(join) = join {
                analyse_data_lkls(lkls, wir, join, breakpoint)
            }
        },

        ast::Edge::Parallel { branches, merge } => {
            // Every branch runs up to the merge edge, after which the walk resumes from there.
            let join = pc.jump(*merge);
            for branch in branches {
                analyse_data_lkls(lkls, wir, pc.jump(*branch), Some(join));
            }
            analyse_data_lkls(lkls, wir, join, breakpoint)
        },

        ast::Edge::Loop { cond, body, next } => {
            // The body is walked first (up to the condition), then the condition (up to the edge
            // directly before the body).
            let (cond_pc, body_pc) = (pc.jump(*cond), pc.jump(*body));
            analyse_data_lkls(lkls, wir, body_pc, Some(cond_pc));
            analyse_data_lkls(lkls, wir, cond_pc, Some(pc.jump(*body - 1)));
            if let Some(next) = next {
                analyse_data_lkls(lkls, wir, pc.jump(*next), breakpoint);
            }
        },

        ast::Edge::Return { result } => {
            // NOTE(review): hard-coded location name, presumably a placeholder for the end user
            // receiving the result - confirm.
            for returned in result {
                lkls.entry(returned.clone()).or_default().insert("Danny Data Scientist".into());
            }
        },
    }
}
/// Recursively converts the edges of `wir`, starting at `pc`, into a tree of [`Elem`]s.
///
/// # Arguments
/// - `wir`: The workflow whose edges are converted (looked up through `utils::get_edge`).
/// - `wf_id`: The workflow identifier, used as prefix for the IDs of generated task and commit elements.
/// - `calls`: Maps the program counter of a Call-edge to the index (in `wir.table.funcs`) of the function it calls.
/// - `lkls`: Map from data name to a set of location names. This function only reads it (to find where the inputs of a `commit_result()` call come from) and passes it on.
/// - `pc`: The position of the edge to convert.
/// - `plug`: The element returned instead of a converted edge when `breakpoint` is reached or `pc` does not point to an edge.
/// - `breakpoint`: If given, conversion stops as soon as `pc` equals it.
///
/// # Returns
/// The element representing the edge at `pc`, with its successors nested inside.
///
/// # Errors
/// Fails if a Parallel-edge's merge pointer is out-of-bounds or does not point to a Join-edge, if a Join-edge is
/// reached directly, if a call targets an unsupported function, or if a `commit_result()` call does not have exactly
/// one output that is a `DataName::Data`.
///
/// # Panics
/// Panics if a Node-edge refers to an unknown task, if a Call-edge has no entry in `calls`, or if that entry does not
/// exist in `wir.table.funcs`. Also hits `unimplemented!()` for tasks that are not `TaskDef::Compute`.
fn reconstruct_graph(
wir: &ast::Workflow,
wf_id: &str,
calls: &HashMap<ProgramCounter, usize>,
lkls: &mut HashMap<DataName, HashSet<String>>,
pc: ProgramCounter,
plug: Elem,
breakpoint: Option<ProgramCounter>,
) -> Result<Elem, Error> {
// Reaching the breakpoint ends this leg; the caller-provided plug closes it off.
if let Some(breakpoint) = breakpoint {
if pc == breakpoint {
return Ok(plug);
}
}
// A counter that does not point to an edge is closed off with the plug as well.
let edge: &ast::Edge = match utils::get_edge(wir, pc) {
Some(edge) => edge,
None => return Ok(plug),
};
trace!("Compiling {:?}", edge.variant());
match edge {
// Linear edges produce no element of their own; conversion continues at the successor.
ast::Edge::Linear { next, .. } => {
reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*next), plug, breakpoint)
},
ast::Edge::Node { task, locs: _, at, input, result, metadata, next } => {
// NOTE(review): presumably `table.task()` panics on an unknown ID, which is why the lookup is wrapped in `catch_unwind` - confirm.
let def: &ast::ComputeTaskDef = match catch_unwind(|| wir.table.task(*task)) {
Ok(def) => {
if let ast::TaskDef::Compute(c) = def {
c
} else {
unimplemented!();
}
},
Err(_) => panic!("Encountered unknown task '{task}' after preprocessing"),
};
Ok(Elem::Task(ElemTask {
id: format!("{}-{}-task", wf_id, pc.resolved(&wir.table)),
name: def.function.name.clone(),
package: def.package.clone(),
version: def.version,
// An input only gets a `from` location when it has to be transferred from somewhere.
input: input
.iter()
.map(|(name, avail)| Dataset {
name: name.name().into(),
from: avail.as_ref().and_then(|avail| match avail {
AvailabilityKind::Available { how: _ } => None,
AvailabilityKind::Unavailable { how: PreprocessKind::TransferRegistryTar { location, dataname: _ } } => {
Some(location.clone())
},
}),
})
.collect(),
output: result.as_ref().map(|name| Dataset { name: name.clone(), from: None }),
location: at.clone(),
metadata: metadata
.iter()
.map(|md| Metadata { owner: md.owner.clone(), tag: md.tag.clone(), signature: md.signature.clone() })
.collect(),
next: Box::new(reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*next), plug, breakpoint)?),
}))
},
// A Stop-edge ends the workflow without returning any data.
ast::Edge::Stop {} => Ok(Elem::Stop(HashSet::new())),
ast::Edge::Branch { true_next, false_next, merge } => {
// Each arm is converted up to the merge point (if any) and closed off with `Elem::Next`.
let mut branches: Vec<Elem> =
vec![reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*true_next), Elem::Next, merge.map(|merge| pc.jump(merge)))?];
if let Some(false_next) = false_next {
branches.push(reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*false_next), Elem::Next, merge.map(|merge| pc.jump(merge)))?)
}
// Without a merge point there is nothing after the branch: an empty Stop is used and `plug` is discarded.
let next: Elem = merge
.map(|merge| reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(merge), plug, breakpoint))
.transpose()?
.unwrap_or(Elem::Stop(HashSet::new()));
Ok(Elem::Branch(ElemBranch { branches, next: Box::new(next) }))
},
ast::Edge::Parallel { branches, merge } => {
// Each branch is converted up to the merge edge and closed off with `Elem::Next`.
let mut elem_branches: Vec<Elem> = Vec::with_capacity(branches.len());
for branch in branches {
elem_branches.push(reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*branch), Elem::Next, Some(pc.jump(*merge)))?);
}
// The merge pointer must refer to an existing Join-edge, which provides the merge strategy and the successor.
let merge_edge: &ast::Edge = match utils::get_edge(wir, pc.jump(*merge)) {
Some(edge) => edge,
None => return Err(Error::ParallelMergeOutOfBounds { pc: pc.resolved(&wir.table), merge: pc.jump(*merge).resolved(&wir.table) }),
};
let (strategy, next): (MergeStrategy, usize) = if let ast::Edge::Join { merge, next } = merge_edge {
(*merge, *next)
} else {
return Err(Error::ParallelWithNonJoin {
pc: pc.resolved(&wir.table),
merge: pc.jump(*merge).resolved(&wir.table),
got: merge_edge.variant().to_string(),
});
};
// Conversion resumes after the Join-edge, so the Join itself is never visited by the recursion.
let next: Elem = reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(next), plug, breakpoint)?;
Ok(Elem::Parallel(ElemParallel { branches: elem_branches, merge: strategy, next: Box::new(next) }))
},
// Join-edges are consumed by the Parallel arm above; reaching one directly means it has no Parallel-edge.
ast::Edge::Join { .. } => Err(Error::StrayJoin { pc: pc.resolved(&wir.table) }),
ast::Edge::Loop { cond, body, next } => {
// The body is converted first (up to the condition) and then used as the plug of the condition, which is
// converted up to the edge directly before the body. The loop's `body` element thus holds the condition
// followed by the body.
let body_elems: Elem = reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*body), Elem::Next, Some(pc.jump(*cond)))?;
let cond: Elem = reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*cond), body_elems, Some(pc.jump(*body - 1)))?;
// Without a successor the loop is followed by an empty Stop and `plug` is discarded.
let next: Elem = next
.map(|next| reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(next), plug, breakpoint))
.transpose()?
.unwrap_or(Elem::Stop(HashSet::new()));
Ok(Elem::Loop(ElemLoop { body: Box::new(cond), next: Box::new(next) }))
},
ast::Edge::Call { input, result, next } => {
// Resolve which function this edge calls through the `calls` map.
let func_def: &ast::FunctionDef = match calls.get(&pc) {
Some(id) => match wir.table.funcs.get(*id) {
Some(def) => def,
None => panic!("Encountered unknown function '{id}' after preprocessing"),
},
None => panic!("Encountered unresolved call after preprocessing"),
};
if func_def.name == BuiltinFunctions::CommitResult.name() {
// For every input, look up a location recorded in `lkls`. If several are recorded, an arbitrary one is
// taken, since a `HashSet` has no defined iteration order.
let mut locs: HashSet<String> = HashSet::with_capacity(input.len());
let mut new_input: Vec<Dataset> = Vec::with_capacity(input.len());
for i in input {
let location: Option<String> = lkls.get(i).and_then(|locs| locs.iter().next().cloned());
if let Some(location) = &location {
locs.insert(location.clone());
}
new_input.push(Dataset { name: i.name().into(), from: location });
}
// The commit must produce exactly one output, and it must be a `DataName::Data`.
if result.len() > 1 {
return Err(Error::CommitTooMuchOutput { pc: pc.resolved(&wir.table), got: result.len() });
}
let data_name: String = if let Some(name) = result.iter().next() {
if let DataName::Data(name) = name {
name.clone()
} else {
return Err(Error::CommitReturnsResult { pc: pc.resolved(&wir.table) });
}
} else {
return Err(Error::CommitNoOutput { pc: pc.resolved(&wir.table) });
};
let next: Elem = reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*next), plug, breakpoint)?;
Ok(Elem::Commit(ElemCommit {
id: format!("{}-{}-commit", wf_id, pc.resolved(&wir.table)),
data_name,
// Takes one of the input locations; which one is arbitrary if the inputs live at different locations.
location: locs.into_iter().next(),
input: new_input,
next: Box::new(next),
}))
} else if func_def.name == BuiltinFunctions::Print.name()
|| func_def.name == BuiltinFunctions::PrintLn.name()
|| func_def.name == BuiltinFunctions::Len.name()
{
// These builtins produce no element; conversion continues at the successor.
reconstruct_graph(wir, wf_id, calls, lkls, pc.jump(*next), plug, breakpoint)
} else {
Err(Error::IllegalCall { pc: pc.resolved(&wir.table), name: func_def.name.clone() })
}
},
// A Return-edge ends the workflow and lists the returned data (without a `from` location).
ast::Edge::Return { result } => Ok(Elem::Stop(result.iter().map(|data| Dataset { name: data.name().into(), from: None }).collect())),
}
}
impl TryFrom<ast::Workflow> for Workflow {
type Error = Error;
#[inline]
fn try_from(value: ast::Workflow) -> Result<Self, Self::Error> {
let mut buf: Vec<u8> = Vec::new();
brane_ast::traversals::print::ast::do_traversal(&value, &mut buf).unwrap();
debug!("Compiling workflow:\n\n{}\n", String::from_utf8(buf).unwrap());
let user: String = if let Some(user) = (*value.user).clone() {
user
} else {
return Err(Error::MissingUser);
};
let wf_id: String = value.id.clone();
let (wir, calls): (ast::Workflow, HashMap<ProgramCounter, usize>) = match preprocess::simplify(value) {
Ok(res) => res,
Err(err) => return Err(Error::Preprocess { err }),
};
if log::max_level() >= Level::Debug {
let mut buf: Vec<u8> = vec![];
brane_ast::traversals::print::ast::do_traversal(&wir, &mut buf).unwrap();
debug!("Preprocessed workflow:\n\n{}\n", String::from_utf8_lossy(&buf));
}
let mut lkls: HashMap<DataName, HashSet<String>> = HashMap::new();
analyse_data_lkls(&mut lkls, &wir, ProgramCounter::start(), None);
let graph: Elem = reconstruct_graph(&wir, &wf_id, &calls, &mut lkls, ProgramCounter::start(), Elem::Stop(HashSet::new()), None)?;
Ok(Self {
id: wf_id,
start: graph,
user: User { name: user },
metadata: wir
.metadata
.iter()
.map(|md| Metadata { owner: md.owner.clone(), tag: md.tag.clone(), signature: md.signature.clone() })
.collect(),
signature: "its_signed_i_swear_mom".into(),
})
}
}