use crate::frame::response::event::{Event, StatusChangeEvent};
use crate::prepared_statement::TokenCalculationError;
use crate::routing::Token;
use crate::transport::host_filter::HostFilter;
use crate::transport::{
connection::{Connection, VerifiedKeyspaceName},
connection_pool::PoolConfig,
errors::QueryError,
node::Node,
partitioner::PartitionerName,
topology::{Keyspace, Metadata, MetadataReader},
};
use arc_swap::ArcSwap;
use futures::future::join_all;
use futures::{future::RemoteHandle, FutureExt};
use itertools::Itertools;
use scylla_cql::errors::{BadQuery, NewSessionError};
use scylla_cql::types::serialize::row::SerializedValues;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::instrument::WithSubscriber;
use tracing::{debug, warn};
use uuid::Uuid;
use super::node::{KnownNode, NodeAddr};
use super::locator::ReplicaLocator;
use super::partitioner::calculate_token_for_partition_key;
use super::topology::Strategy;
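/// Cluster manages up-to-date information about database nodes and
/// connections to them. The current state is an immutable snapshot kept
/// behind an `ArcSwap`, refreshed in the background by a `ClusterWorker`.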
pub(crate) struct Cluster {
data: Arc<ArcSwap<ClusterData>>,
refresh_channel: tokio::sync::mpsc::Sender<RefreshRequest>,
use_keyspace_channel: tokio::sync::mpsc::Sender<UseKeyspaceRequest>,
_worker_handle: RemoteHandle<()>,
}
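/// Enables printing [Cluster] concisely, skipping the `ArcSwap` machinery
/// and showing only the currently loaded `ClusterData`.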
pub(crate) struct ClusterNeatDebug<'a>(pub(crate) &'a Cluster);
impl<'a> std::fmt::Debug for ClusterNeatDebug<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let cluster = self.0;
f.debug_struct("Cluster")
.field("data", &ClusterDataNeatDebug(&cluster.data.load()))
.finish_non_exhaustive()
}
}
#[derive(Clone, Debug)]
pub struct Datacenter {
pub nodes: Vec<Arc<Node>>,
pub rack_count: usize,
}
#[derive(Clone)]
pub struct ClusterData {
    /// Invariant: nonempty after Cluster::new().
    pub(crate) known_peers: HashMap<Uuid, Arc<Node>>,
    pub(crate) keyspaces: HashMap<String, Keyspace>,
pub(crate) locator: ReplicaLocator,
}
pub(crate) struct ClusterDataNeatDebug<'a>(pub(crate) &'a Arc<ClusterData>);
impl<'a> std::fmt::Debug for ClusterDataNeatDebug<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let cluster_data = &self.0;
f.debug_struct("ClusterData")
.field("known_peers", &cluster_data.known_peers)
.field("ring", {
struct RingSizePrinter(usize);
impl std::fmt::Debug for RingSizePrinter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<size={}>", self.0)
}
}
&RingSizePrinter(cluster_data.locator.ring().len())
})
.field("keyspaces", &cluster_data.keyspaces.keys())
.finish_non_exhaustive()
}
}
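/// Works in the background to keep the cluster state updated: it listens for
/// refresh requests, server events, `USE keyspace` requests and control
/// connection breakage, and periodically refreshes the metadata.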
struct ClusterWorker {
cluster_data: Arc<ArcSwap<ClusterData>>,
metadata_reader: MetadataReader,
pool_config: PoolConfig,
refresh_channel: tokio::sync::mpsc::Receiver<RefreshRequest>,
use_keyspace_channel: tokio::sync::mpsc::Receiver<UseKeyspaceRequest>,
server_events_channel: tokio::sync::mpsc::Receiver<Event>,
control_connection_repair_channel: tokio::sync::broadcast::Receiver<()>,
used_keyspace: Option<VerifiedKeyspaceName>,
host_filter: Option<Arc<dyn HostFilter>>,
cluster_metadata_refresh_interval: Duration,
}
#[derive(Debug)]
struct RefreshRequest {
response_chan: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
}
#[derive(Debug)]
struct UseKeyspaceRequest {
keyspace_name: VerifiedKeyspaceName,
response_chan: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
}
impl Cluster {
pub(crate) async fn new(
known_nodes: Vec<KnownNode>,
pool_config: PoolConfig,
keyspaces_to_fetch: Vec<String>,
fetch_schema_metadata: bool,
host_filter: Option<Arc<dyn HostFilter>>,
cluster_metadata_refresh_interval: Duration,
) -> Result<Cluster, NewSessionError> {
let (refresh_sender, refresh_receiver) = tokio::sync::mpsc::channel(32);
let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32);
let (server_events_sender, server_events_receiver) = tokio::sync::mpsc::channel(32);
let (control_connection_repair_sender, control_connection_repair_receiver) =
tokio::sync::broadcast::channel(32);
let mut metadata_reader = MetadataReader::new(
known_nodes,
control_connection_repair_sender,
pool_config.connection_config.clone(),
pool_config.keepalive_interval,
server_events_sender,
keyspaces_to_fetch,
fetch_schema_metadata,
&host_filter,
)
.await?;
let metadata = metadata_reader.read_metadata(true).await?;
let cluster_data = ClusterData::new(
metadata,
&pool_config,
&HashMap::new(),
&None,
host_filter.as_deref(),
)
.await;
cluster_data.wait_until_all_pools_are_initialized().await;
let cluster_data: Arc<ArcSwap<ClusterData>> =
Arc::new(ArcSwap::from(Arc::new(cluster_data)));
let worker = ClusterWorker {
cluster_data: cluster_data.clone(),
metadata_reader,
pool_config,
refresh_channel: refresh_receiver,
server_events_channel: server_events_receiver,
control_connection_repair_channel: control_connection_repair_receiver,
use_keyspace_channel: use_keyspace_receiver,
used_keyspace: None,
host_filter,
cluster_metadata_refresh_interval,
};
let (fut, worker_handle) = worker.work().remote_handle();
tokio::spawn(fut.with_current_subscriber());
let result = Cluster {
data: cluster_data,
refresh_channel: refresh_sender,
use_keyspace_channel: use_keyspace_sender,
_worker_handle: worker_handle,
};
Ok(result)
}
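    /// Returns a snapshot of the current cluster state.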
pub(crate) fn get_data(&self) -> Arc<ClusterData> {
self.data.load_full()
}
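    /// Asks the worker to refresh metadata now and waits for the result.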
pub(crate) async fn refresh_metadata(&self) -> Result<(), QueryError> {
let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
self.refresh_channel
.send(RefreshRequest {
response_chan: response_sender,
})
.await
.expect("Bug in Cluster::refresh_metadata sending");
response_receiver
.await
.expect("Bug in Cluster::refresh_metadata receiving")
}
pub(crate) async fn use_keyspace(
&self,
keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
self.use_keyspace_channel
.send(UseKeyspaceRequest {
keyspace_name,
response_chan: response_sender,
})
.await
.expect("Bug in Cluster::use_keyspace sending");
        response_receiver
            .await
            .expect("Bug in Cluster::use_keyspace receiving")
    }
}
impl ClusterData {
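    /// Recomputes each datacenter's rack count as the number of distinct
    /// rack names among its nodes; nodes with no rack set are ignored.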
fn update_rack_count(datacenters: &mut HashMap<String, Datacenter>) {
for datacenter in datacenters.values_mut() {
datacenter.rack_count = datacenter
.nodes
.iter()
.filter_map(|node| node.rack.as_ref())
.unique()
.count();
}
}
pub(crate) async fn wait_until_all_pools_are_initialized(&self) {
for node in self.locator.unique_nodes_in_global_ring().iter() {
node.wait_until_pool_initialized().await;
}
}
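    /// Creates a new ClusterData from freshly fetched `metadata`, reusing
    /// `Node` objects from `known_peers` where possible so that their
    /// already-opened connection pools are preserved.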
pub(crate) async fn new(
metadata: Metadata,
pool_config: &PoolConfig,
known_peers: &HashMap<Uuid, Arc<Node>>,
used_keyspace: &Option<VerifiedKeyspaceName>,
host_filter: Option<&dyn HostFilter>,
) -> Self {
let mut new_known_peers: HashMap<Uuid, Arc<Node>> =
HashMap::with_capacity(metadata.peers.len());
let mut ring: Vec<(Token, Arc<Node>)> = Vec::new();
let mut datacenters: HashMap<String, Datacenter> = HashMap::new();
let mut all_nodes: Vec<Arc<Node>> = Vec::with_capacity(metadata.peers.len());
for peer in metadata.peers {
let peer_host_id = peer.host_id;
let peer_address = peer.address;
let peer_tokens;
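            // Reuse the existing Node if its datacenter and rack are
            // unchanged (inheriting it with a new endpoint if only the IP
            // changed); otherwise create a fresh Node.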
let node: Arc<Node> = match known_peers.get(&peer_host_id) {
Some(node) if node.datacenter == peer.datacenter && node.rack == peer.rack => {
let (peer_endpoint, tokens) = peer.into_peer_endpoint_and_tokens();
peer_tokens = tokens;
if node.address == peer_address {
node.clone()
} else {
Arc::new(Node::inherit_with_ip_changed(node, peer_endpoint))
}
}
_ => {
let is_enabled = host_filter.map_or(true, |f| f.accept(&peer));
let (peer_endpoint, tokens) = peer.into_peer_endpoint_and_tokens();
peer_tokens = tokens;
Arc::new(Node::new(
peer_endpoint,
pool_config.clone(),
used_keyspace.clone(),
is_enabled,
))
}
};
new_known_peers.insert(peer_host_id, node.clone());
if let Some(dc) = &node.datacenter {
match datacenters.get_mut(dc) {
Some(v) => v.nodes.push(node.clone()),
None => {
let v = Datacenter {
nodes: vec![node.clone()],
rack_count: 0,
};
datacenters.insert(dc.clone(), v);
}
}
}
for token in peer_tokens {
ring.push((token, node.clone()));
}
all_nodes.push(node);
}
Self::update_rack_count(&mut datacenters);
let keyspaces = metadata.keyspaces;
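        // Constructing the ReplicaLocator precomputes replica sets, which can
        // be CPU-heavy for large rings, so it is run on the blocking pool.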
let (locator, keyspaces) = tokio::task::spawn_blocking(move || {
let keyspace_strategies = keyspaces.values().map(|ks| &ks.strategy);
let locator = ReplicaLocator::new(ring.into_iter(), keyspace_strategies);
(locator, keyspaces)
})
.await
.unwrap();
ClusterData {
known_peers: new_known_peers,
keyspaces,
locator,
}
}
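    /// Access keyspace details collected by the driver, such as tables,
    /// partitioners and replication strategies.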
pub fn get_keyspace_info(&self) -> &HashMap<String, Keyspace> {
&self.keyspaces
}
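    /// Access datacenter details collected by the driver.
    /// The returned `HashMap` is indexed by datacenter name.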
pub fn get_datacenters_info(&self) -> HashMap<String, Datacenter> {
self.locator
.datacenter_names()
.iter()
.map(|dc_name| {
let nodes = self
.locator
.unique_nodes_in_datacenter_ring(dc_name)
.unwrap()
.to_vec();
                // Ignore nodes with no rack, consistently with update_rack_count().
                let rack_count = nodes
                    .iter()
                    .filter_map(|node| node.rack.as_ref())
                    .unique()
                    .count();
(dc_name.clone(), Datacenter { nodes, rack_count })
})
.collect()
}
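    /// Access details about all nodes known to the driver.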
pub fn get_nodes_info(&self) -> &[Arc<Node>] {
self.locator.unique_nodes_in_global_ring()
}
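    /// Computes the token of a table's partition key, using the table's
    /// partitioner if known and the default partitioner otherwise.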
pub fn compute_token(
&self,
keyspace: &str,
table: &str,
partition_key: &SerializedValues,
) -> Result<Token, BadQuery> {
let partitioner = self
.keyspaces
.get(keyspace)
.and_then(|k| k.tables.get(table))
.and_then(|t| t.partitioner.as_deref())
.and_then(PartitionerName::from_str)
.unwrap_or_default();
calculate_token_for_partition_key(partition_key, &partitioner).map_err(|err| match err {
TokenCalculationError::ValueTooLong(values_len) => {
BadQuery::ValuesTooLongForKey(values_len, u16::MAX.into())
}
})
}
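    /// Returns the replicas owning a given token in a keyspace; if the
    /// keyspace is unknown, `Strategy::LocalStrategy` is assumed.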
pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec<Arc<Node>> {
self.get_token_endpoints_iter(keyspace, token)
.cloned()
.collect()
}
pub(crate) fn get_token_endpoints_iter(
&self,
keyspace: &str,
token: Token,
) -> impl Iterator<Item = &Arc<Node>> {
let keyspace = self.keyspaces.get(keyspace);
let strategy = keyspace
.map(|k| &k.strategy)
.unwrap_or(&Strategy::LocalStrategy);
let replica_set = self
.replica_locator()
.replicas_for_token(token, strategy, None);
replica_set.into_iter()
}
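    /// Returns the replicas owning a given partition key, similar to
    /// `nodetool getendpoints`.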
pub fn get_endpoints(
&self,
keyspace: &str,
table: &str,
partition_key: &SerializedValues,
) -> Result<Vec<Arc<Node>>, BadQuery> {
Ok(self.get_token_endpoints(
keyspace,
self.compute_token(keyspace, table, partition_key)?,
))
}
pub fn replica_locator(&self) -> &ReplicaLocator {
&self.locator
}
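    /// Returns a nonempty iterator over working connections to all shards,
    /// or an error if no node has any working connection.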
pub(crate) fn iter_working_connections(
&self,
) -> Result<impl Iterator<Item = Arc<Connection>> + '_, QueryError> {
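        // By the nonemptiness invariant of `known_peers`, `peers_iter` yields
        // at least one element, which the `expect` below relies on.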
assert!(!self.known_peers.is_empty());
let mut peers_iter = self.known_peers.values();
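        // Find the first working pool; if every pool is broken, this
        // propagates the error from the first node instead.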
let first_working_pool = peers_iter
.by_ref()
.map(|node| node.get_working_connections())
.find_or_first(Result::is_ok)
.expect("impossible: known_peers was asserted to be nonempty")?;
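        // For the remaining pools, skip the broken ones and flatten the rest
        // into a single stream of connections.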
let remaining_pools_iter = peers_iter
.map(|node| node.get_working_connections())
.flatten_ok()
.flatten();
Ok(first_working_pool.into_iter().chain(remaining_pools_iter))
}
}
impl ClusterWorker {
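    /// Main loop: sleeps until the next scheduled metadata refresh, waking
    /// early for refresh requests, server events, `USE keyspace` requests,
    /// or a broken control connection.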
pub(crate) async fn work(mut self) {
use tokio::time::Instant;
        // When the control connection is broken, attempt to refresh (and thus
        // reconnect) much more frequently than the regular metadata refresh.
        let control_connection_repair_duration = Duration::from_secs(1);
        let mut last_refresh_time = Instant::now();
let mut control_connection_works = true;
loop {
let mut cur_request: Option<RefreshRequest> = None;
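            // While the control connection is broken, wake up every second to
            // retry the refresh instead of waiting the full refresh interval.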
let sleep_until: Instant = last_refresh_time
.checked_add(if control_connection_works {
self.cluster_metadata_refresh_interval
} else {
control_connection_repair_duration
})
.unwrap_or_else(Instant::now);
let sleep_future = tokio::time::sleep_until(sleep_until);
tokio::pin!(sleep_future);
tokio::select! {
_ = sleep_future => {},
recv_res = self.refresh_channel.recv() => {
match recv_res {
Some(request) => cur_request = Some(request),
                        None => return, // refresh_channel closed: Cluster was dropped, stop working.
                    }
}
recv_res = self.server_events_channel.recv() => {
if let Some(event) = recv_res {
debug!("Received server event: {:?}", event);
match event {
                            Event::TopologyChange(_) => (), // Fall through to refresh metadata immediately.
                            Event::StatusChange(status) => {
match status {
StatusChangeEvent::Down(addr) => self.change_node_down_marker(addr, true),
StatusChangeEvent::Up(addr) => self.change_node_down_marker(addr, false),
}
continue;
},
                                _ => continue, // Other events don't warrant a refresh.
                            }
} else {
return;
}
}
recv_res = self.use_keyspace_channel.recv() => {
match recv_res {
Some(request) => {
self.used_keyspace = Some(request.keyspace_name.clone());
let cluster_data = self.cluster_data.load_full();
let use_keyspace_future = Self::handle_use_keyspace_request(cluster_data, request);
tokio::spawn(use_keyspace_future.with_current_subscriber());
},
                        None => return, // use_keyspace_channel closed: Cluster was dropped, stop working.
                    }
                    continue; // Don't go to refreshing; wait for the next event.
                }
recv_res = self.control_connection_repair_channel.recv() => {
match recv_res {
                        Ok(()) => {
                            // The control connection was broken. Fall through to an
                            // immediate metadata refresh, which re-establishes it; if the
                            // refresh fails, `control_connection_works` becomes false and
                            // the worker retries every second.
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                            // Very unlikely: many control connections would have to break
                            // at once. Ignoring a lagged signal is harmless, because any
                            // single signal triggers the same refresh.
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                            // The channel is closed, so MetadataReader was dropped;
                            // stop working.
                            return;
}
}
}
}
debug!("Requesting topology refresh");
last_refresh_time = Instant::now();
let refresh_res = self.perform_refresh().await;
control_connection_works = refresh_res.is_ok();
if let Some(request) = cur_request {
let _ = request.response_chan.send(refresh_res);
}
}
}
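    /// Finds the node with the given address and updates its down marker;
    /// logs a warning if the address does not match any known node.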
fn change_node_down_marker(&mut self, addr: SocketAddr, is_down: bool) {
let cluster_data = self.cluster_data.load_full();
let node = match cluster_data
.known_peers
.values()
.find(|&peer| peer.address == NodeAddr::Translatable(addr))
{
Some(node) => node,
None => {
warn!("Unknown node address {}", addr);
return;
}
};
node.change_down_marker(is_down);
}
async fn handle_use_keyspace_request(
cluster_data: Arc<ClusterData>,
request: UseKeyspaceRequest,
) {
let result = Self::send_use_keyspace(cluster_data, &request.keyspace_name).await;
let _ = request.response_chan.send(result);
}
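    /// Issues `USE keyspace` on every node's pools. Succeeds if at least one
    /// node succeeded; IO errors on the rest are tolerated.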
async fn send_use_keyspace(
cluster_data: Arc<ClusterData>,
keyspace_name: &VerifiedKeyspaceName,
) -> Result<(), QueryError> {
let use_keyspace_futures = cluster_data
.known_peers
.values()
.map(|node| node.use_keyspace(keyspace_name.clone()));
let use_keyspace_results: Vec<Result<(), QueryError>> =
join_all(use_keyspace_futures).await;
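        // If at least one node succeeded, the keyspace name is valid and will
        // be set on the remaining nodes' connections when they reconnect, so
        // IO errors are tolerated. Any other error is returned immediately.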
let mut was_ok: bool = false;
let mut io_error: Option<Arc<std::io::Error>> = None;
for result in use_keyspace_results {
match result {
Ok(()) => was_ok = true,
Err(err) => match err {
QueryError::IoError(io_err) => io_error = Some(io_err),
_ => return Err(err),
},
}
}
if was_ok {
return Ok(());
}
Err(QueryError::IoError(io_error.unwrap()))
}
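    /// Fetches fresh metadata, builds a new ClusterData (reusing existing
    /// nodes where possible), waits for all new pools to initialize, and
    /// atomically swaps it in.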
async fn perform_refresh(&mut self) -> Result<(), QueryError> {
let metadata = self.metadata_reader.read_metadata(false).await?;
let cluster_data: Arc<ClusterData> = self.cluster_data.load_full();
let new_cluster_data = Arc::new(
ClusterData::new(
metadata,
&self.pool_config,
&cluster_data.known_peers,
&self.used_keyspace,
self.host_filter.as_deref(),
)
.await,
);
new_cluster_data
.wait_until_all_pools_are_initialized()
.await;
self.update_cluster_data(new_cluster_data);
Ok(())
}
fn update_cluster_data(&mut self, new_cluster_data: Arc<ClusterData>) {
self.cluster_data.store(new_cluster_data);
}
}