// scylla/transport/load_balancing/plan.rs
use tracing::error;
use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
use crate::transport::ClusterData;
/// Progress marker of a [`Plan`].
///
/// Possible transitions:
/// * `Created` -> `Picked` when `pick()` returns a node,
/// * `Picked` -> `Fallback` when the second node is requested,
/// * `Created` -> `Fallback` when `pick()` returns nothing but `fallback()` does,
/// * `Created` -> `PickedNone` when neither `pick()` nor `fallback()` yields a node.
enum PlanState<'a> {
    /// Nothing has been yielded yet; `pick()` has not been consulted.
    Created,
    /// The node returned by `pick()` has been yielded; `fallback()` not called yet.
    Picked(NodeRef<'a>),
    /// Nodes are drawn from `iter`, skipping any node equal to
    /// `node_to_filter_out` (the node that was yielded first).
    Fallback {
        iter: FallbackPlan<'a>,
        node_to_filter_out: NodeRef<'a>,
    },
    /// Both `pick()` and `fallback()` produced nothing; the plan is empty.
    PickedNone,
}
/// A query plan: a lazily computed sequence of nodes to send a request to,
/// in order of preference.
///
/// The first node comes from [`LoadBalancingPolicy::pick`]. The remaining nodes
/// come from [`LoadBalancingPolicy::fallback`], with every node equal to the
/// already-yielded first node filtered out. `fallback()` is called at most once.
/// It is called either when the second node is requested, or immediately when
/// `pick()` returns `None`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Plan<'a> {
    policy: &'a dyn LoadBalancingPolicy,
    routing_info: &'a RoutingInfo<'a>,
    cluster: &'a ClusterData,
    state: PlanState<'a>,
}
impl<'a> Plan<'a> {
pub fn new(
policy: &'a dyn LoadBalancingPolicy,
routing_info: &'a RoutingInfo<'a>,
cluster: &'a ClusterData,
) -> Self {
Self {
policy,
routing_info,
cluster,
state: PlanState::Created,
}
}
}
impl<'a> Iterator for Plan<'a> {
type Item = NodeRef<'a>;
fn next(&mut self) -> Option<Self::Item> {
match &mut self.state {
PlanState::Created => {
let picked = self.policy.pick(self.routing_info, self.cluster);
if let Some(picked) = picked {
self.state = PlanState::Picked(picked);
Some(picked)
} else {
let mut iter = self.policy.fallback(self.routing_info, self.cluster);
let first_fallback_node = iter.next();
if let Some(node) = first_fallback_node {
self.state = PlanState::Fallback {
iter,
node_to_filter_out: node,
};
Some(node)
} else {
error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info);
self.state = PlanState::PickedNone;
None
}
}
}
PlanState::Picked(node) => {
self.state = PlanState::Fallback {
iter: self.policy.fallback(self.routing_info, self.cluster),
node_to_filter_out: node,
};
self.next()
}
PlanState::Fallback {
iter,
node_to_filter_out,
} => {
for node in iter {
if node == *node_to_filter_out {
continue;
} else {
return Some(node);
}
}
None
}
PlanState::PickedNone => None,
}
}
}
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, sync::Arc};

    use crate::transport::{
        locator::test::{create_locator, mock_metadata_for_token_aware_tests},
        Node, NodeAddr,
    };

    use super::*;

    /// The nodes the fallback plan of these tests is expected to yield.
    fn expected_nodes() -> Vec<Arc<Node>> {
        let addr: SocketAddr = "127.0.0.1:9042".parse().unwrap();
        let node = Node::new_for_test(NodeAddr::Translatable(addr), None, None);
        vec![Arc::new(node)]
    }

    /// A policy whose `pick()` never chooses a node, so everything must come
    /// from `fallback()`.
    #[derive(Debug)]
    struct PickingNonePolicy {
        expected_nodes: Vec<Arc<Node>>,
    }

    impl LoadBalancingPolicy for PickingNonePolicy {
        fn pick<'a>(
            &'a self,
            _query: &'a RoutingInfo,
            _cluster: &'a ClusterData,
        ) -> Option<NodeRef<'a>> {
            None
        }

        fn fallback<'a>(
            &'a self,
            _query: &'a RoutingInfo,
            _cluster: &'a ClusterData,
        ) -> FallbackPlan<'a> {
            let nodes = self.expected_nodes.iter();
            Box::new(nodes)
        }

        fn name(&self) -> String {
            String::from("PickingNone")
        }
    }

    #[tokio::test]
    async fn plan_calls_fallback_even_if_pick_returned_none() {
        let policy = PickingNonePolicy {
            expected_nodes: expected_nodes(),
        };
        let cluster_data = ClusterData {
            locator: create_locator(&mock_metadata_for_token_aware_tests()),
            known_peers: Default::default(),
            keyspaces: Default::default(),
        };
        let routing_info = RoutingInfo::default();

        let yielded: Vec<Arc<Node>> = Plan::new(&policy, &routing_info, &cluster_data)
            .cloned()
            .collect();
        assert_eq!(yielded, policy.expected_nodes);
    }
}