use bytes::Bytes;
use futures::{future::RemoteHandle, FutureExt};
use scylla_cql::errors::TranslationError;
use scylla_cql::frame::request::options::Options;
use scylla_cql::frame::response::Error;
use scylla_cql::frame::types::SerialConsistency;
use scylla_cql::types::serialize::batch::{BatchValues, BatchValuesIterator};
use scylla_cql::types::serialize::raw_batch::RawBatchValuesAdapter;
use scylla_cql::types::serialize::row::{RowSerializationContext, SerializedValues};
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpSocket, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::instrument::WithSubscriber;
use tracing::{debug, error, trace, warn};
use uuid::Uuid;
use std::borrow::Cow;
#[cfg(feature = "ssl")]
use std::pin::Pin;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
#[cfg(feature = "ssl")]
use tokio_openssl::SslStream;
#[cfg(feature = "ssl")]
pub(crate) use ssl_config::SslConfig;
use crate::authentication::AuthenticatorProvider;
use scylla_cql::frame::response::authenticate::Authenticate;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::{
cmp::Ordering,
net::{Ipv4Addr, Ipv6Addr},
};
use super::errors::{BadKeyspaceName, DbError, QueryError};
use super::iterator::RowIterator;
use super::session::AddressTranslator;
use super::topology::{PeerEndpoint, UntranslatedEndpoint, UntranslatedPeer};
use super::NodeAddr;
#[cfg(feature = "cloud")]
use crate::cloud::CloudConfig;
use crate::batch::{Batch, BatchStatement};
use crate::frame::protocol_features::ProtocolFeatures;
use crate::frame::{
self,
request::{self, batch, execute, query, register, SerializableRequest},
response::{event::Event, result, NonErrorResponse, Response, ResponseOpcode},
server_event_type::EventType,
FrameParams, SerializedRequest,
};
use crate::query::Query;
use crate::routing::ShardInfo;
use crate::statement::prepared_statement::PreparedStatement;
use crate::statement::Consistency;
use crate::transport::session::IntoTypedRows;
use crate::transport::Compression;
use crate::QueryResult;
// CQL query fetching the local node's current schema version.
const LOCAL_VERSION: &str = "SELECT schema_version FROM system.local WHERE key='local'";
// Number of "old" orphaned stream ids above which the connection is torn down.
const OLD_ORPHAN_COUNT_THRESHOLD: usize = 1024;
// Age after which an orphaned stream id counts as "old" (see orphaner()).
const OLD_AGE_ORPHAN_THRESHOLD: std::time::Duration = std::time::Duration::from_secs(1);
/// A single CQL connection to one node. All socket I/O is performed by a
/// background "router" task; this struct only holds handles to it.
pub(crate) struct Connection {
    // Keeps the router task alive; dropping the handle cancels the task.
    _worker_handle: RemoteHandle<()>,
    connect_address: SocketAddr,
    config: ConnectionConfig,
    // Filled in by set_features() after the OPTIONS exchange.
    features: ConnectionFeatures,
    router_handle: Arc<RouterHandle>,
}
/// Client-side handle to the router task: submits serialized requests and
/// reports requests that were abandoned before their response arrived.
struct RouterHandle {
    submit_channel: mpsc::Sender<Task>,
    // Source of unique ids used to correlate requests with orphan notifications.
    request_id_generator: AtomicU64,
    orphan_notification_sender: mpsc::UnboundedSender<RequestId>,
}
impl RouterHandle {
    /// Builds the error reported when the router task is gone, so an
    /// in-flight request can be neither delivered nor answered.
    /// (Extracted: this construction was duplicated twice in send_request.)
    fn connection_broken_error() -> QueryError {
        QueryError::IoError(Arc::new(std::io::Error::new(
            ErrorKind::Other,
            "Connection broken",
        )))
    }

    /// Returns a fresh request id. Relaxed ordering suffices: the ids only
    /// need to be unique, they carry no ordering with other memory operations.
    fn allocate_request_id(&self) -> RequestId {
        self.request_id_generator
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Serializes `request`, submits it to the router task and awaits the
    /// matching response.
    ///
    /// If this future is dropped before the response arrives, the
    /// `OrphanhoodNotifier`'s `Drop` impl reports the request id as orphaned
    /// so its stream id can be reclaimed later.
    async fn send_request(
        &self,
        request: &impl SerializableRequest,
        compression: Option<Compression>,
        tracing: bool,
    ) -> Result<TaskResponse, QueryError> {
        let serialized_request = SerializedRequest::make(request, compression, tracing)?;
        let request_id = self.allocate_request_id();
        let (response_sender, receiver) = oneshot::channel();
        let response_handler = ResponseHandler {
            response_sender,
            request_id,
        };
        // Armed now; disabled below once the response is actually received.
        let notifier = OrphanhoodNotifier::new(request_id, &self.orphan_notification_sender);
        self.submit_channel
            .send(Task {
                serialized_request,
                response_handler,
            })
            .await
            .map_err(|_| Self::connection_broken_error())?;
        let task_response = receiver
            .await
            .map_err(|_| Self::connection_broken_error())?;
        // The response arrived - no orphan notification needed on drop.
        notifier.disable();
        task_response
    }
}
/// Features negotiated with the server during the OPTIONS exchange.
#[derive(Default)]
pub(crate) struct ConnectionFeatures {
    // Scylla sharding description, if the server advertised one.
    shard_info: Option<ShardInfo>,
    // Port dedicated to shard-aware connections, if advertised.
    shard_aware_port: Option<u16>,
    protocol_features: ProtocolFeatures,
}
// Driver-internal id correlating a request with its orphan notification;
// distinct from the protocol-level stream id.
type RequestId = u64;

/// Channel endpoint through which the router delivers one request's response.
struct ResponseHandler {
    response_sender: oneshot::Sender<Result<TaskResponse, QueryError>>,
    request_id: RequestId,
}
/// Guard that, on drop, reports its request as orphaned (abandoned by the
/// caller before a response arrived) - unless explicitly disabled.
struct OrphanhoodNotifier<'a> {
    enabled: bool,
    request_id: RequestId,
    notification_sender: &'a mpsc::UnboundedSender<RequestId>,
}
impl<'a> OrphanhoodNotifier<'a> {
    /// Creates an armed notifier: unless `disable` is called first, dropping
    /// it reports `request_id` through `notification_sender`.
    fn new(
        request_id: RequestId,
        notification_sender: &'a mpsc::UnboundedSender<RequestId>,
    ) -> Self {
        Self {
            request_id,
            notification_sender,
            enabled: true,
        }
    }

    /// Consumes the notifier without emitting an orphan notification.
    fn disable(mut self) {
        self.enabled = false;
    }
}
impl<'a> Drop for OrphanhoodNotifier<'a> {
    /// Reports the request as orphaned unless `disable` was called.
    /// A send failure (router already gone) is deliberately ignored.
    fn drop(&mut self) {
        if !self.enabled {
            return;
        }
        let _ = self.notification_sender.send(self.request_id);
    }
}
/// A request handed to the router's writer, together with the handler that
/// will receive its response.
struct Task {
    serialized_request: SerializedRequest,
    response_handler: ResponseHandler,
}

/// A raw, not-yet-deserialized response frame as read off the socket.
struct TaskResponse {
    params: FrameParams,
    opcode: ResponseOpcode,
    body: Bytes,
}
/// A fully deserialized response; may still be a server-side Error response.
pub(crate) struct QueryResponse {
    pub(crate) response: Response,
    // Tracing session id, present when tracing was requested.
    pub(crate) tracing_id: Option<Uuid>,
    pub(crate) warnings: Vec<String>,
}

/// Like `QueryResponse`, but guaranteed not to be a server-side Error.
pub(crate) struct NonErrorQueryResponse {
    pub(crate) response: NonErrorResponse,
    pub(crate) tracing_id: Option<Uuid>,
    pub(crate) warnings: Vec<String>,
}
impl QueryResponse {
pub(crate) fn into_non_error_query_response(self) -> Result<NonErrorQueryResponse, QueryError> {
Ok(NonErrorQueryResponse {
response: self.response.into_non_error_response()?,
tracing_id: self.tracing_id,
warnings: self.warnings,
})
}
pub(crate) fn into_query_result(self) -> Result<QueryResult, QueryError> {
self.into_non_error_query_response()?.into_query_result()
}
}
impl NonErrorQueryResponse {
    /// Returns the SET_KEYSPACE result payload, if this is one.
    pub(crate) fn as_set_keyspace(&self) -> Option<&result::SetKeyspace> {
        match &self.response {
            NonErrorResponse::Result(result::Result::SetKeyspace(sk)) => Some(sk),
            _ => None,
        }
    }

    /// Returns the SCHEMA_CHANGE result payload, if this is one.
    pub(crate) fn as_schema_change(&self) -> Option<&result::SchemaChange> {
        match &self.response {
            NonErrorResponse::Result(result::Result::SchemaChange(sc)) => Some(sc),
            _ => None,
        }
    }

    /// Converts the response into a `QueryResult`.
    ///
    /// Non-Rows results yield an empty `QueryResult` (no rows, no paging
    /// state); any non-Result response is a protocol error.
    pub(crate) fn into_query_result(self) -> Result<QueryResult, QueryError> {
        let (rows, paging_state, col_specs, serialized_size) = match self.response {
            NonErrorResponse::Result(result::Result::Rows(rs)) => (
                Some(rs.rows),
                rs.metadata.paging_state,
                rs.metadata.col_specs,
                rs.serialized_size,
            ),
            NonErrorResponse::Result(_) => (None, None, vec![], 0),
            _ => {
                return Err(QueryError::ProtocolError(
                    "Unexpected server response, expected Result or Error",
                ))
            }
        };
        Ok(QueryResult {
            rows,
            warnings: self.warnings,
            tracing_id: self.tracing_id,
            paging_state,
            col_specs,
            serialized_size,
        })
    }
}
#[cfg(feature = "ssl")]
mod ssl_config {
    use openssl::{
        error::ErrorStack,
        ssl::{Ssl, SslContext},
    };
    #[cfg(feature = "cloud")]
    use uuid::Uuid;

    /// Per-connection TLS configuration: a shared `SslContext` plus, with
    /// the "cloud" feature, an optional SNI hostname.
    #[derive(Clone)]
    pub struct SslConfig {
        context: SslContext,
        #[cfg(feature = "cloud")]
        sni: Option<String>,
    }

    impl SslConfig {
        /// Wraps a context to be used for all connections, without SNI.
        pub fn new_with_global_context(context: SslContext) -> Self {
            Self {
                context,
                #[cfg(feature = "cloud")]
                sni: None,
            }
        }

        /// Builds a config whose SNI is `<host_id>.<domain_name>` when the
        /// target host is known, or just `domain_name` otherwise.
        #[cfg(feature = "cloud")]
        pub(crate) fn new_for_sni(
            context: SslContext,
            domain_name: &str,
            host_id: Option<Uuid>,
        ) -> Self {
            Self {
                context,
                #[cfg(feature = "cloud")]
                sni: Some(if let Some(host_id) = host_id {
                    format!("{}.{}", host_id, domain_name)
                } else {
                    domain_name.into()
                }),
            }
        }

        /// Creates a fresh `Ssl` session from the shared context, setting
        /// the SNI hostname when one is configured.
        pub(crate) fn new_ssl(&self) -> Result<Ssl, ErrorStack> {
            // `mut` is only needed when the "cloud" feature is enabled.
            #[allow(unused_mut)]
            let mut ssl = Ssl::new(&self.context)?;
            #[cfg(feature = "cloud")]
            if let Some(sni) = self.sni.as_ref() {
                ssl.set_hostname(sni)?;
            }
            Ok(ssl)
        }
    }
}
/// Options controlling how a single connection is established and behaves.
#[derive(Clone)]
pub struct ConnectionConfig {
    pub compression: Option<Compression>,
    pub tcp_nodelay: bool,
    // OS-level TCP keepalive time (applied in Connection::new via socket2).
    pub tcp_keepalive_interval: Option<Duration>,
    #[cfg(feature = "ssl")]
    pub ssl_config: Option<SslConfig>,
    pub connect_timeout: std::time::Duration,
    // If None, server events are ignored and no REGISTER is issued.
    pub event_sender: Option<mpsc::Sender<Event>>,
    pub default_consistency: Consistency,
    #[cfg(feature = "cloud")]
    pub(crate) cloud_config: Option<Arc<CloudConfig>>,
    pub authenticator: Option<Arc<dyn AuthenticatorProvider>>,
    pub address_translator: Option<Arc<dyn AddressTranslator>>,
    // When true, the writer batches several queued requests per flush.
    pub enable_write_coalescing: bool,
    // Interval between driver-level keepalive (OPTIONS) requests.
    pub keepalive_interval: Option<Duration>,
    // How long to wait for a keepalive response before failing the connection.
    pub keepalive_timeout: Option<Duration>,
}
impl Default for ConnectionConfig {
    /// Defaults: no compression, TCP_NODELAY enabled, 5 s connect timeout,
    /// write coalescing on, everything else disabled/unset.
    fn default() -> Self {
        Self {
            compression: None,
            tcp_nodelay: true,
            tcp_keepalive_interval: None,
            #[cfg(feature = "ssl")]
            ssl_config: None,
            connect_timeout: std::time::Duration::from_secs(5),
            event_sender: None,
            default_consistency: Default::default(),
            #[cfg(feature = "cloud")]
            cloud_config: None,
            authenticator: None,
            address_translator: None,
            enable_write_coalescing: true,
            keepalive_interval: None,
            keepalive_timeout: None,
        }
    }
}
impl ConnectionConfig {
    /// True if this connection should use TLS - either via an explicit SSL
    /// config, or implied by a cloud configuration.
    #[cfg(feature = "ssl")]
    pub fn is_ssl(&self) -> bool {
        #[cfg(feature = "cloud")]
        if self.cloud_config.is_some() {
            return true;
        }
        self.ssl_config.is_some()
    }

    /// Without the "ssl" feature, connections are always plaintext.
    #[cfg(not(feature = "ssl"))]
    pub fn is_ssl(&self) -> bool {
        false
    }
}
// Receives the fatal error that eventually shuts a connection's router down.
pub(crate) type ErrorReceiver = tokio::sync::oneshot::Receiver<QueryError>;
impl Connection {
/// Opens a TCP connection to `addr` and spawns the background router task
/// that owns the socket and services all requests.
///
/// Returns the connection handle together with an `ErrorReceiver` yielding
/// the fatal error that eventually terminates the router.
pub(crate) async fn new(
    addr: SocketAddr,
    source_port: Option<u16>,
    config: ConnectionConfig,
) -> Result<(Self, ErrorReceiver), QueryError> {
    // Connect under an overall timeout; optionally bind a specific local
    // source port (presumably for shard-aware routing - see callers).
    let stream_connector = match source_port {
        Some(p) => {
            tokio::time::timeout(config.connect_timeout, connect_with_source_port(addr, p))
                .await
        }
        None => tokio::time::timeout(config.connect_timeout, TcpStream::connect(addr)).await,
    };
    let stream = match stream_connector {
        Ok(stream) => stream?,
        Err(_) => {
            // Outer Err means the timeout elapsed before connect finished.
            return Err(QueryError::TimeoutError);
        }
    };
    stream.set_nodelay(config.tcp_nodelay)?;
    if let Some(tcp_keepalive_interval) = config.tcp_keepalive_interval {
        let mut tcp_keepalive = TcpKeepalive::new().with_time(tcp_keepalive_interval);
        // Probe interval and retry count are not supported by socket2 on
        // every platform, hence the target_os gating below.
        #[cfg(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "fuchsia",
            target_os = "illumos",
            target_os = "ios",
            target_os = "linux",
            target_os = "macos",
            target_os = "netbsd",
            target_os = "tvos",
            target_os = "watchos",
            target_os = "windows",
        ))]
        {
            tcp_keepalive = tcp_keepalive.with_interval(Duration::from_secs(1));
        }
        #[cfg(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "fuchsia",
            target_os = "illumos",
            target_os = "ios",
            target_os = "linux",
            target_os = "macos",
            target_os = "netbsd",
            target_os = "tvos",
            target_os = "watchos",
        ))]
        {
            tcp_keepalive = tcp_keepalive.with_retries(10);
        }
        let sf = SockRef::from(&stream);
        sf.set_tcp_keepalive(&tcp_keepalive)?;
    }
    // Request submission channel (bounded), fatal-error channel, and the
    // unbounded channel for orphan notifications.
    let (sender, receiver) = mpsc::channel(1024);
    let (error_sender, error_receiver) = tokio::sync::oneshot::channel();
    let (orphan_notification_sender, orphan_notification_receiver) = mpsc::unbounded_channel();
    let router_handle = Arc::new(RouterHandle {
        submit_channel: sender,
        request_id_generator: AtomicU64::new(0),
        orphan_notification_sender,
    });
    let _worker_handle = Self::run_router(
        config.clone(),
        stream,
        receiver,
        error_sender,
        orphan_notification_receiver,
        router_handle.clone(),
        addr.ip(),
    )
    .await?;
    let connection = Connection {
        _worker_handle,
        config,
        features: Default::default(),
        connect_address: addr,
        router_handle,
    };
    Ok((connection, error_receiver))
}
/// Sends a STARTUP request carrying the given options, returning the raw
/// server response (uncompressed, no tracing).
pub(crate) async fn startup(
    &self,
    options: HashMap<String, String>,
) -> Result<Response, QueryError> {
    let startup_frame = request::Startup { options };
    let query_response = self.send_request(&startup_frame, false, false).await?;
    Ok(query_response.response)
}
/// Sends an OPTIONS request and returns the raw server response.
pub(crate) async fn get_options(&self) -> Result<Response, QueryError> {
    let query_response = self.send_request(&request::Options {}, false, false).await?;
    Ok(query_response.response)
}
/// Prepares `query` on this connection, returning the resulting
/// `PreparedStatement` with any tracing id recorded on it.
pub(crate) async fn prepare(&self, query: &Query) -> Result<PreparedStatement, QueryError> {
    let query_response = self
        .send_request(
            &request::Prepare {
                query: &query.contents,
            },
            true,
            query.config.tracing,
        )
        .await?;
    let mut prepared_statement = match query_response.response {
        Response::Error(err) => return Err(err.into()),
        Response::Result(result::Result::Prepared(p)) => PreparedStatement::new(
            p.id,
            // LWT statements are detected via a flag bit in the prepared
            // metadata (protocol-extension dependent).
            self.features
                .protocol_features
                .prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32),
            p.prepared_metadata,
            query.contents.clone(),
            query.get_page_size(),
            query.config.clone(),
        ),
        _ => {
            return Err(QueryError::ProtocolError(
                "PREPARE: Unexpected server response",
            ))
        }
    };
    if let Some(tracing_id) = query_response.tracing_id {
        prepared_statement.prepare_tracing_ids.push(tracing_id);
    }
    Ok(prepared_statement)
}
/// Re-prepares a statement that the server reported as unprepared, and
/// verifies that the server returned the same statement id as before.
pub(crate) async fn reprepare(
    &self,
    query: impl Into<Query>,
    previous_prepared: &PreparedStatement,
) -> Result<(), QueryError> {
    let reprepare_query: Query = query.into();
    let reprepared = self.prepare(&reprepare_query).await?;
    // The id is an md5 of the statement text, so it must not change.
    if reprepared.get_id() == previous_prepared.get_id() {
        Ok(())
    } else {
        Err(QueryError::ProtocolError(
            "Prepared statement Id changed, md5 sum should stay the same",
        ))
    }
}
/// Sends an AUTH_RESPONSE frame carrying the authenticator's payload.
pub(crate) async fn authenticate_response(
    &self,
    response: Option<Vec<u8>>,
) -> Result<QueryResponse, QueryError> {
    let auth_response_frame = request::AuthResponse { response };
    self.send_request(&auth_response_frame, false, false).await
}
/// Runs `query` fetching a single page, using the query's own consistency
/// settings (falling back to this connection's default consistency).
pub(crate) async fn query_single_page(
    &self,
    query: impl Into<Query>,
) -> Result<QueryResult, QueryError> {
    let query: Query = query.into();
    let consistency = query
        .config
        .determine_consistency(self.config.default_consistency);
    let serial_consistency = query.config.serial_consistency.flatten();
    self.query_single_page_with_consistency(query, consistency, serial_consistency)
        .await
}
/// Runs `query` fetching a single page with explicit consistency settings.
pub(crate) async fn query_single_page_with_consistency(
    &self,
    query: impl Into<Query>,
    consistency: Consistency,
    serial_consistency: Option<SerialConsistency>,
) -> Result<QueryResult, QueryError> {
    let query: Query = query.into();
    let response = self
        .query_with_consistency(&query, consistency, serial_consistency, None)
        .await?;
    response.into_query_result()
}
/// Runs `query` (optionally continuing from `paging_state`) with the
/// query's own consistency settings.
pub(crate) async fn query(
    &self,
    query: &Query,
    paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
    let consistency = query
        .config
        .determine_consistency(self.config.default_consistency);
    let serial_consistency = query.config.serial_consistency.flatten();
    self.query_with_consistency(query, consistency, serial_consistency, paging_state)
        .await
}
/// Sends a QUERY frame for an unprepared statement. No values are bound
/// here - unprepared queries are sent with an empty value list.
pub(crate) async fn query_with_consistency(
    &self,
    query: &Query,
    consistency: Consistency,
    serial_consistency: Option<SerialConsistency>,
    paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
    let parameters = query::QueryParameters {
        consistency,
        serial_consistency,
        values: Cow::Borrowed(SerializedValues::EMPTY),
        page_size: query.get_page_size(),
        paging_state,
        timestamp: query.get_timestamp(),
    };
    let query_frame = query::Query {
        contents: Cow::Borrowed(&query.contents),
        parameters,
    };
    self.send_request(&query_frame, true, query.config.tracing)
        .await
}
/// Executes a prepared statement with its own consistency settings
/// (falling back to this connection's default consistency).
#[allow(dead_code)]
pub(crate) async fn execute(
    &self,
    prepared: PreparedStatement,
    values: SerializedValues,
    paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
    let consistency = prepared
        .config
        .determine_consistency(self.config.default_consistency);
    let serial_consistency = prepared.config.serial_consistency.flatten();
    self.execute_with_consistency(
        &prepared,
        &values,
        consistency,
        serial_consistency,
        paging_state,
    )
    .await
}
/// Executes a prepared statement with explicit consistency settings.
///
/// If the server responds with `DbError::Unprepared` (e.g. its prepared
/// statement cache was evicted), the statement is transparently reprepared
/// and the execution retried once.
pub(crate) async fn execute_with_consistency(
    &self,
    prepared_statement: &PreparedStatement,
    values: &SerializedValues,
    consistency: Consistency,
    serial_consistency: Option<SerialConsistency>,
    paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
    let execute_frame = execute::Execute {
        id: prepared_statement.get_id().to_owned(),
        parameters: query::QueryParameters {
            consistency,
            serial_consistency,
            values: Cow::Borrowed(values),
            page_size: prepared_statement.get_page_size(),
            timestamp: prepared_statement.get_timestamp(),
            paging_state,
        },
    };
    let query_response = self
        .send_request(&execute_frame, true, prepared_statement.config.tracing)
        .await?;
    match &query_response.response {
        Response::Error(frame::response::Error {
            error: DbError::Unprepared { statement_id },
            ..
        }) => {
            debug!("Connection::execute: Got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
            // Repreparation verifies the id stays the same, so the frame
            // built above can simply be re-sent.
            self.reprepare(prepared_statement.get_statement(), prepared_statement)
                .await?;
            self.send_request(&execute_frame, true, prepared_statement.config.tracing)
                .await
        }
        _ => Ok(query_response),
    }
}
/// Builds a paging `RowIterator` that runs `query` on this connection.
pub(crate) async fn query_iter(
    self: Arc<Self>,
    query: Query,
) -> Result<RowIterator, QueryError> {
    let serial_consistency = query.config.serial_consistency.flatten();
    let consistency = query
        .config
        .determine_consistency(self.config.default_consistency);
    RowIterator::new_for_connection_query_iter(query, self, consistency, serial_consistency)
        .await
}
/// Builds a paging `RowIterator` that executes `prepared_statement` with
/// `values` on this connection.
pub(crate) async fn execute_iter(
    self: Arc<Self>,
    prepared_statement: PreparedStatement,
    values: SerializedValues,
) -> Result<RowIterator, QueryError> {
    let serial_consistency = prepared_statement.config.serial_consistency.flatten();
    let consistency = prepared_statement
        .config
        .determine_consistency(self.config.default_consistency);
    RowIterator::new_for_connection_execute_iter(
        prepared_statement,
        values,
        self,
        consistency,
        serial_consistency,
    )
    .await
}
/// Runs `batch` using its own consistency settings (falling back to this
/// connection's default consistency).
#[allow(dead_code)]
pub(crate) async fn batch(
    &self,
    batch: &Batch,
    values: impl BatchValues,
) -> Result<QueryResult, QueryError> {
    let consistency = batch
        .config
        .determine_consistency(self.config.default_consistency);
    let serial_consistency = batch.config.serial_consistency.flatten();
    self.batch_with_consistency(batch, values, consistency, serial_consistency)
        .await
}
/// Runs a batch with explicit consistency settings.
///
/// Statements needing preparation are prepared up front; if the server
/// still reports `Unprepared` for one of them, it is reprepared and the
/// whole batch retried.
pub(crate) async fn batch_with_consistency(
    &self,
    init_batch: &Batch,
    values: impl BatchValues,
    consistency: Consistency,
    serial_consistency: Option<SerialConsistency>,
) -> Result<QueryResult, QueryError> {
    let batch = self.prepare_batch(init_batch, &values).await?;
    // One serialization context per statement: prepared statements carry
    // column metadata, plain queries get an empty context.
    let contexts = batch.statements.iter().map(|bs| match bs {
        BatchStatement::Query(_) => RowSerializationContext::empty(),
        BatchStatement::PreparedStatement(ps) => {
            RowSerializationContext::from_prepared(ps.get_prepared_metadata())
        }
    });
    let values = RawBatchValuesAdapter::new(values, contexts);
    let batch_frame = batch::Batch {
        statements: Cow::Borrowed(&batch.statements),
        values,
        batch_type: batch.get_type(),
        consistency,
        serial_consistency,
        timestamp: batch.get_timestamp(),
    };
    loop {
        let query_response = self
            .send_request(&batch_frame, true, batch.config.tracing)
            .await?;
        return match query_response.response {
            Response::Error(err) => match err.error {
                DbError::Unprepared { statement_id } => {
                    debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
                    // Find which batch statement the returned id belongs to.
                    let prepared_statement = batch.statements.iter().find_map(|s| match s {
                        BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
                            Some(s)
                        }
                        _ => None,
                    });
                    if let Some(p) = prepared_statement {
                        self.reprepare(p.get_statement(), p).await?;
                        // Retry the whole batch with the reprepared statement.
                        continue;
                    } else {
                        return Err(QueryError::ProtocolError(
                            "The server returned a prepared statement Id that did not exist in the batch",
                        ));
                    }
                }
                _ => Err(err.into()),
            },
            Response::Result(_) => Ok(query_response.into_query_result()?),
            _ => Err(QueryError::ProtocolError(
                "BATCH: Unexpected server response",
            )),
        };
    }
}
/// Prepares those plain-query statements of `init_batch` that are given
/// non-empty values (their values must be serialized against prepared
/// metadata). Returns the original batch borrowed when nothing needs
/// preparing, or an owned copy with those statements replaced.
async fn prepare_batch<'b>(
    &self,
    init_batch: &'b Batch,
    values: impl BatchValues,
) -> Result<Cow<'b, Batch>, QueryError> {
    let mut to_prepare = HashSet::<&str>::new();
    {
        // Walk statements and their value lists in lockstep; only a plain
        // query with at least one bound value needs preparation.
        let mut values_iter = values.batch_values_iter();
        for stmt in &init_batch.statements {
            if let BatchStatement::Query(query) = stmt {
                if let Some(false) = values_iter.is_empty_next() {
                    to_prepare.insert(&query.contents);
                }
            } else {
                values_iter.skip_next();
            }
        }
    }
    if to_prepare.is_empty() {
        return Ok(Cow::Borrowed(init_batch));
    }
    let mut prepared_queries = HashMap::<&str, PreparedStatement>::new();
    for query in &to_prepare {
        let prepared = self.prepare(&Query::new(query.to_string())).await?;
        prepared_queries.insert(query, prepared);
    }
    // Rebuild the batch, swapping in prepared statements where available.
    let mut batch: Cow<Batch> = Cow::Owned(Default::default());
    batch.to_mut().config = init_batch.config.clone();
    for stmt in &init_batch.statements {
        match stmt {
            BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
            {
                Some(prepared) => batch.to_mut().append_statement(prepared.clone()),
                None => batch.to_mut().append_statement(query.clone()),
            },
            BatchStatement::PreparedStatement(prepared) => {
                batch.to_mut().append_statement(prepared.clone());
            }
        }
    }
    Ok(batch)
}
/// Issues `USE <keyspace>` and verifies the server actually switched to
/// the requested keyspace.
pub(crate) async fn use_keyspace(
    &self,
    keyspace_name: &VerifiedKeyspaceName,
) -> Result<(), QueryError> {
    // The name is interpolated into the query string; it comes from
    // VerifiedKeyspaceName, which is presumably validated upstream -
    // NOTE(review): confirm that validation forbids quote injection.
    let query: Query = match keyspace_name.is_case_sensitive {
        true => format!("USE \"{}\"", keyspace_name.as_str()).into(),
        false => format!("USE {}", keyspace_name.as_str()).into(),
    };
    let query_response = self.query(&query, None).await?;
    match query_response.response {
        Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
            // The server may report the name in different case; compare
            // case-insensitively.
            if set_keyspace.keyspace_name.to_lowercase()
                != keyspace_name.as_str().to_lowercase()
            {
                return Err(QueryError::ProtocolError(
                    "USE <keyspace_name> returned response with different keyspace name",
                ));
            }
            Ok(())
        }
        Response::Error(err) => Err(err.into()),
        _ => Err(QueryError::ProtocolError(
            "USE <keyspace_name> returned unexpected response",
        )),
    }
}
/// Registers this connection for the given server event types; expects a
/// READY response.
async fn register(
    &self,
    event_types_to_register_for: Vec<EventType>,
) -> Result<(), QueryError> {
    let register_frame = register::Register {
        event_types_to_register_for,
    };
    let response = self
        .send_request(&register_frame, true, false)
        .await?
        .response;
    match response {
        Response::Ready => Ok(()),
        Response::Error(err) => Err(err.into()),
        _ => Err(QueryError::ProtocolError(
            "Unexpected response to REGISTER message",
        )),
    }
}
/// Queries `system.local` for this node's current schema version.
///
/// Fix: the first error message read "Version query returned not rows";
/// corrected to "Version query returned no rows".
pub(crate) async fn fetch_schema_version(&self) -> Result<Uuid, QueryError> {
    let (version_id,): (Uuid,) = self
        .query_single_page(LOCAL_VERSION)
        .await?
        .rows
        .ok_or(QueryError::ProtocolError("Version query returned no rows"))?
        .into_typed::<(Uuid,)>()
        .next()
        .ok_or(QueryError::ProtocolError("Admin table returned empty rows"))?
        .map_err(|_| QueryError::ProtocolError("Row is not uuid type as it should be"))?;
    Ok(version_id)
}
/// Serializes and submits `request` through the router, then parses the
/// raw response frame.
///
/// `compress` controls only the outgoing frame; incoming frames are always
/// decompressed with the configured algorithm.
async fn send_request(
    &self,
    request: &impl SerializableRequest,
    compress: bool,
    tracing: bool,
) -> Result<QueryResponse, QueryError> {
    let compression = match compress {
        true => self.config.compression,
        false => None,
    };
    let task_response = self
        .router_handle
        .send_request(request, compression, tracing)
        .await?;
    Self::parse_response(
        task_response,
        self.config.compression,
        &self.features.protocol_features,
    )
}
fn parse_response(
task_response: TaskResponse,
compression: Option<Compression>,
features: &ProtocolFeatures,
) -> Result<QueryResponse, QueryError> {
let body_with_ext = frame::parse_response_body_extensions(
task_response.params.flags,
compression,
task_response.body,
)?;
for warn_description in &body_with_ext.warnings {
warn!(
warning = warn_description.as_str(),
"Response from the database contains a warning",
);
}
let response =
Response::deserialize(features, task_response.opcode, &mut &*body_with_ext.body)?;
Ok(QueryResponse {
response,
warnings: body_with_ext.warnings,
tracing_id: body_with_ext.trace_id,
})
}
/// Wraps the TCP stream in SSL when configured, then spawns the router
/// future on the tokio runtime. The returned `RemoteHandle` cancels the
/// spawned future when dropped, tying the router's lifetime to the
/// `Connection` that holds it.
async fn run_router(
    config: ConnectionConfig,
    stream: TcpStream,
    receiver: mpsc::Receiver<Task>,
    error_sender: tokio::sync::oneshot::Sender<QueryError>,
    orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
    router_handle: Arc<RouterHandle>,
    node_address: IpAddr,
) -> Result<RemoteHandle<()>, std::io::Error> {
    #[cfg(feature = "ssl")]
    if let Some(ssl_config) = &config.ssl_config {
        let ssl = ssl_config.new_ssl()?;
        let mut stream = SslStream::new(ssl, stream)?;
        // NOTE(review): the result of the TLS handshake is discarded here;
        // a handshake failure only surfaces later as an I/O error inside
        // the router - confirm this is intentional.
        let _pin = Pin::new(&mut stream).connect().await;
        let (task, handle) = Self::router(
            config,
            stream,
            receiver,
            error_sender,
            orphan_notification_receiver,
            router_handle,
            node_address,
        )
        .remote_handle();
        tokio::task::spawn(task.with_current_subscriber());
        return Ok(handle);
    }
    // Plaintext path (also taken when the "ssl" feature is disabled).
    let (task, handle) = Self::router(
        config,
        stream,
        receiver,
        error_sender,
        orphan_notification_receiver,
        router_handle,
        node_address,
    )
    .remote_handle();
    tokio::task::spawn(task.with_current_subscriber());
    Ok(handle)
}
/// The connection worker: splits the stream and concurrently drives the
/// reader, writer, orphaner and keepaliver futures. When one of them fails,
/// every outstanding request is failed with that error and the error is
/// forwarded to the `ErrorReceiver`.
async fn router(
    config: ConnectionConfig,
    stream: (impl AsyncRead + AsyncWrite),
    receiver: mpsc::Receiver<Task>,
    error_sender: tokio::sync::oneshot::Sender<QueryError>,
    orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
    router_handle: Arc<RouterHandle>,
    node_address: IpAddr,
) {
    let (read_half, write_half) = split(stream);
    // Stream-id -> handler table shared by reader, writer and orphaner.
    let handler_map = StdMutex::new(ResponseHandlerMap::new());
    let enable_write_coalescing = config.enable_write_coalescing;
    let k = Self::keepaliver(
        router_handle,
        config.keepalive_interval,
        config.keepalive_timeout,
        node_address,
    );
    let r = Self::reader(
        BufReader::with_capacity(8192, read_half),
        &handler_map,
        config,
    );
    let w = Self::writer(
        BufWriter::with_capacity(8192, write_half),
        &handler_map,
        receiver,
        enable_write_coalescing,
    );
    let o = Self::orphaner(&handler_map, orphan_notification_receiver);
    // try_join! polls all four futures on this single task, so the StdMutex
    // around handler_map is never contended (hence try_lock().unwrap()
    // inside the sub-futures).
    let result = futures::try_join!(r, w, o, k);
    let error: QueryError = match result {
        Ok(_) => return, // All sub-futures finished without error.
        Err(err) => err,
    };
    // Fail every request still waiting for a response with the fatal error.
    let response_handlers: HashMap<i16, ResponseHandler> =
        handler_map.into_inner().unwrap().into_handlers();
    for (_, handler) in response_handlers {
        let _ = handler.response_sender.send(Err(error.clone()));
    }
    // The Connection may already be dropped; a failed send is fine.
    let _ = error_sender.send(error);
}
/// Reads response frames in a loop and dispatches each to the handler
/// registered for its stream id.
///
/// Frames on stream -1 are server-pushed events; frames with a smaller
/// stream id are skipped.
async fn reader(
    mut read_half: (impl AsyncRead + Unpin),
    handler_map: &StdMutex<ResponseHandlerMap>,
    config: ConnectionConfig,
) -> Result<(), QueryError> {
    loop {
        let (params, opcode, body) = frame::read_response_frame(&mut read_half).await?;
        let response = TaskResponse {
            params,
            opcode,
            body,
        };
        match params.stream.cmp(&-1) {
            Ordering::Less => {
                // Negative stream ids other than -1 are ignored.
                continue;
            }
            Ordering::Equal => {
                // Stream -1 carries server events; forward only when an
                // event sender is configured.
                if let Some(event_sender) = config.event_sender.as_ref() {
                    Self::handle_event(response, config.compression, event_sender).await?;
                }
                continue;
            }
            _ => {}
        }
        // try_lock().unwrap(): see router() - all users of handler_map are
        // polled by the same task, so the mutex is never contended.
        let handler_lookup_res = {
            let mut handler_map_guard = handler_map.try_lock().unwrap();
            handler_map_guard.lookup(params.stream)
        };
        use HandlerLookupResult::*;
        match handler_lookup_res {
            Handler(handler) => {
                // The receiving side may have been dropped meanwhile;
                // ignoring the send error is intentional.
                let _ = handler.response_sender.send(Ok(response));
            }
            Missing => {
                // A stream id we never allocated - the server is confused;
                // treat the connection as broken.
                debug!(
                    "Received response with unexpected StreamId {}",
                    params.stream
                );
                return Err(QueryError::ProtocolError(
                    "Received response with unexpected StreamId",
                ));
            }
            Orphaned => {
                // Response to an orphaned request - nobody is waiting for
                // it, so drop it.
            }
        }
    }
}
/// Allocates a protocol stream id for `response_handler`. On exhaustion,
/// fails the request with `UnableToAllocStreamId` and returns `None`.
fn alloc_stream_id(
    handler_map: &StdMutex<ResponseHandlerMap>,
    response_handler: ResponseHandler,
) -> Option<i16> {
    let allocation = handler_map.try_lock().unwrap().allocate(response_handler);
    match allocation {
        Ok(stream_id) => Some(stream_id),
        // On failure the handler is given back so the caller can be notified.
        Err(handler) => {
            error!("Could not allocate stream id");
            let _ = handler
                .response_sender
                .send(Err(QueryError::UnableToAllocStreamId));
            None
        }
    }
}
/// Takes submitted tasks, assigns stream ids and writes their serialized
/// requests to the socket, coalescing several queued requests into a
/// single flush when enabled.
async fn writer(
    mut write_half: (impl AsyncWrite + Unpin),
    handler_map: &StdMutex<ResponseHandlerMap>,
    mut task_receiver: mpsc::Receiver<Task>,
    enable_write_coalescing: bool,
) -> Result<(), QueryError> {
    while let Some(mut task) = task_receiver.recv().await {
        let mut num_requests = 0;
        let mut total_sent = 0;
        // Inner loop: keep writing as long as more tasks are immediately
        // available, deferring the flush until the queue drains.
        while let Some(stream_id) = Self::alloc_stream_id(handler_map, task.response_handler) {
            let mut req = task.serialized_request;
            req.set_stream(stream_id);
            let req_data: &[u8] = req.get_data();
            total_sent += req_data.len();
            num_requests += 1;
            write_half.write_all(req_data).await?;
            task = match task_receiver.try_recv() {
                Ok(t) => t,
                Err(_) if enable_write_coalescing => {
                    // Yield once so other tasks get a chance to enqueue
                    // requests, then check again before flushing.
                    tokio::task::yield_now().await;
                    match task_receiver.try_recv() {
                        Ok(t) => t,
                        Err(_) => break,
                    }
                }
                Err(_) => break,
            }
        }
        trace!("Sending {} requests; {} bytes", num_requests, total_sent);
        write_half.flush().await?;
    }
    // Channel closed: the Connection was dropped - clean shutdown.
    Ok(())
}
/// Marks stream ids of abandoned requests as orphaned, and fails the
/// connection when too many orphans stay unreclaimed for too long.
async fn orphaner(
    handler_map: &StdMutex<ResponseHandlerMap>,
    mut orphan_receiver: mpsc::UnboundedReceiver<RequestId>,
) -> Result<(), QueryError> {
    let mut interval = tokio::time::interval(OLD_AGE_ORPHAN_THRESHOLD);
    loop {
        tokio::select! {
            _ = interval.tick() => {
                // Periodic check: if too many stream ids have been orphaned
                // for longer than the threshold, give up on the connection.
                let handler_map_guard = handler_map.try_lock().unwrap();
                let old_orphan_count = handler_map_guard.old_orphans_count();
                if old_orphan_count > OLD_ORPHAN_COUNT_THRESHOLD {
                    warn!(
                        "Too many old orphaned stream ids: {}",
                        old_orphan_count,
                    );
                    return Err(QueryError::TooManyOrphanedStreamIds(old_orphan_count as u16))
                }
            }
            Some(request_id) = orphan_receiver.recv() => {
                trace!(
                    "Trying to orphan stream id associated with request_id = {}",
                    request_id,
                );
                let mut handler_map_guard = handler_map.try_lock().unwrap();
                handler_map_guard.orphan(request_id);
            }
            // Both select arms disabled: the orphan channel closed.
            else => { break }
        }
    }
    Ok(())
}
/// Periodically issues OPTIONS requests as keepalives; returns an error
/// (tearing the router down) if a keepalive fails or times out.
async fn keepaliver(
    router_handle: Arc<RouterHandle>,
    keepalive_interval: Option<Duration>,
    keepalive_timeout: Option<Duration>,
    node_address: IpAddr, // Used for logging only.
) -> Result<(), QueryError> {
    // A keepalive is just an OPTIONS request whose response body is ignored.
    async fn issue_keepalive_query(router_handle: &RouterHandle) -> Result<(), QueryError> {
        router_handle
            .send_request(&Options, None, false)
            .await
            .map(|_| ())
    }
    if let Some(keepalive_interval) = keepalive_interval {
        let mut interval = tokio::time::interval(keepalive_interval);
        // The first tick completes immediately; consume it so we don't
        // send a keepalive right after connecting.
        interval.tick().await;
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            interval.tick().await;
            let keepalive_query = issue_keepalive_query(&router_handle);
            let query_result = if let Some(timeout) = keepalive_timeout {
                match tokio::time::timeout(timeout, keepalive_query).await {
                    Ok(res) => res,
                    Err(_) => {
                        warn!(
                            "Timed out while waiting for response to keepalive request on connection to node {}",
                            node_address
                        );
                        return Err(QueryError::IoError(Arc::new(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            format!(
                                "Timed out while waiting for response to keepalive request on connection to node {}",
                                node_address
                            )
                        ))));
                    }
                }
            } else {
                keepalive_query.await
            };
            if let Err(err) = query_result {
                warn!(
                    "Failed to execute keepalive request on connection to node {} - {}",
                    node_address, err
                );
                return Err(err);
            }
        }
    } else {
        // Keepalives disabled: finish with Ok so the other futures joined
        // in router() keep running.
        Ok(())
    }
}
/// Parses a frame received on stream -1 and forwards the contained server
/// event to `event_sender`. Non-Event responses are logged and dropped.
async fn handle_event(
    task_response: TaskResponse,
    compression: Option<Compression>,
    event_sender: &mpsc::Sender<Event>,
) -> Result<(), QueryError> {
    // Events are parsed with default protocol features.
    let features = ProtocolFeatures::default();
    let response = Self::parse_response(task_response, compression, &features)?.response;
    match response {
        Response::Event(event) => event_sender.send(event).await.map_err(|_| {
            QueryError::IoError(Arc::new(std::io::Error::new(
                ErrorKind::Other,
                "Connection broken",
            )))
        }),
        other => {
            warn!("Expected to receive Event response, got {:?}", other);
            Ok(())
        }
    }
}
/// Sharding information advertised by the node, if any.
pub(crate) fn get_shard_info(&self) -> &Option<ShardInfo> {
    &self.features.shard_info
}

/// Shard-aware port advertised by the node, if any.
pub(crate) fn get_shard_aware_port(&self) -> Option<u16> {
    self.features.shard_aware_port
}

// Called once from open_named_connection after the OPTIONS response is
// parsed, before the connection is handed out.
fn set_features(&mut self, features: ConnectionFeatures) {
    self.features = features;
}

/// Address this connection was opened to.
pub(crate) fn get_connect_address(&self) -> SocketAddr {
    self.connect_address
}
}
/// Resolves an endpoint to the address the driver should actually dial.
///
/// Contact points and untranslatable peer addresses are used as-is; a
/// translatable peer address is run through the user-provided
/// `AddressTranslator`, when one is configured.
async fn maybe_translated_addr(
    endpoint: UntranslatedEndpoint,
    address_translator: Option<&dyn AddressTranslator>,
) -> Result<SocketAddr, TranslationError> {
    match endpoint {
        UntranslatedEndpoint::ContactPoint(addr) => Ok(addr.address),
        UntranslatedEndpoint::Peer(PeerEndpoint {
            host_id,
            address,
            datacenter,
            rack,
        }) => match address {
            NodeAddr::Translatable(addr) => {
                if let Some(translator) = address_translator {
                    let res = translator
                        .translate_address(&UntranslatedPeer {
                            host_id,
                            untranslated_address: addr,
                            datacenter,
                            rack,
                        })
                        .await;
                    if let Err(ref err) = res {
                        error!("Address translation failed for addr {}: {}", addr, err);
                    }
                    res
                } else {
                    // No translator configured: dial the peer address directly.
                    Ok(addr)
                }
            }
            NodeAddr::Untranslatable(addr) => {
                // Explicitly marked as not subject to translation.
                Ok(addr)
            }
        },
    }
}
/// Translates `endpoint` to a dialable address and opens a connection to
/// it, advertising this driver's name and version to the server.
pub(crate) async fn open_connection(
    endpoint: UntranslatedEndpoint,
    source_port: Option<u16>,
    config: ConnectionConfig,
) -> Result<(Connection, ErrorReceiver), QueryError> {
    let addr = maybe_translated_addr(endpoint, config.address_translator.as_deref()).await?;
    let driver_name = Some("scylla-rust-driver".to_string());
    let driver_version = option_env!("CARGO_PKG_VERSION").map(str::to_string);
    open_named_connection(addr, source_port, config, driver_name, driver_version).await
}
/// Opens a connection to `addr` and performs the full handshake:
/// OPTIONS (feature/compression negotiation), STARTUP, optional
/// authentication, and event registration when an event sender is set.
pub(crate) async fn open_named_connection(
    addr: SocketAddr,
    source_port: Option<u16>,
    config: ConnectionConfig,
    driver_name: Option<String>,
    driver_version: Option<String>,
) -> Result<(Connection, ErrorReceiver), QueryError> {
    let (mut connection, error_receiver) =
        Connection::new(addr, source_port, config.clone()).await?;
    let options_result = connection.get_options().await?;
    // The shard-aware port is advertised under a different key for TLS.
    let shard_aware_port_key = match config.is_ssl() {
        true => "SCYLLA_SHARD_AWARE_PORT_SSL",
        false => "SCYLLA_SHARD_AWARE_PORT",
    };
    let mut supported = match options_result {
        Response::Supported(supported) => supported,
        Response::Error(Error { error, reason }) => return Err(QueryError::DbError(error, reason)),
        _ => {
            return Err(QueryError::ProtocolError(
                "Wrong response to OPTIONS message was received",
            ));
        }
    };
    let shard_info = ShardInfo::try_from(&supported.options).ok();
    let supported_compression = supported.options.remove("COMPRESSION").unwrap_or_default();
    let shard_aware_port = supported
        .options
        .remove(shard_aware_port_key)
        .unwrap_or_default()
        .into_iter()
        .next()
        .and_then(|p| p.parse::<u16>().ok());
    let protocol_features = ProtocolFeatures::parse_from_supported(&supported.options);
    let mut options = HashMap::new();
    protocol_features.add_startup_options(&mut options);
    let features = ConnectionFeatures {
        shard_info,
        shard_aware_port,
        protocol_features,
    };
    connection.set_features(features);
    options.insert("CQL_VERSION".to_string(), "4.0.0".to_string());
    if let Some(name) = driver_name {
        options.insert("DRIVER_NAME".to_string(), name);
    }
    if let Some(version) = driver_version {
        options.insert("DRIVER_VERSION".to_string(), version);
    }
    if let Some(compression) = &config.compression {
        let compression_str = compression.to_string();
        if supported_compression.iter().any(|c| c == &compression_str) {
            options.insert("COMPRESSION".to_string(), compression.to_string());
        } else {
            // Fall back to no compression if the server does not support
            // the requested algorithm.
            connection.config.compression = None;
        }
    }
    let result = connection.startup(options).await?;
    match result {
        Response::Ready => {}
        Response::Authenticate(authenticate) => {
            perform_authenticate(&mut connection, &authenticate).await?;
        }
        Response::Error(Error { error, reason }) => return Err(QueryError::DbError(error, reason)),
        _ => {
            return Err(QueryError::ProtocolError(
                "Unexpected response to STARTUP message",
            ))
        }
    }
    // Only subscribe to server events when someone is listening.
    if connection.config.event_sender.is_some() {
        let all_event_types = vec![
            EventType::TopologyChange,
            EventType::StatusChange,
            EventType::SchemaChange,
        ];
        connection.register(all_event_types).await?;
    }
    Ok((connection, error_receiver))
}
async fn perform_authenticate(
connection: &mut Connection,
authenticate: &Authenticate,
) -> Result<(), QueryError> {
let authenticator = &authenticate.authenticator_name as &str;
match connection.config.authenticator {
Some(ref authenticator_provider) => {
let (mut response, mut auth_session) = authenticator_provider
.start_authentication_session(authenticator)
.await
.map_err(QueryError::InvalidMessage)?;
loop {
match connection
.authenticate_response(response)
.await?.response
{
Response::AuthChallenge(challenge) => {
response = auth_session
.evaluate_challenge(
challenge.authenticate_message.as_deref(),
)
.await
.map_err(QueryError::InvalidMessage)?;
}
Response::AuthSuccess(success) => {
auth_session
.success(success.success_message.as_deref())
.await
.map_err(QueryError::InvalidMessage)?;
break;
}
Response::Error(err) => {
return Err(err.into());
}
_ => {
return Err(QueryError::ProtocolError(
"Unexpected response to Authenticate Response message",
))
}
}
}
},
None => return Err(QueryError::InvalidMessage(
"Authentication is required. You can use SessionBuilder::user(\"user\", \"pass\") to provide credentials \
or SessionBuilder::authenticator_provider to provide custom authenticator".to_string(),
)),
}
Ok(())
}
/// Creates a TCP connection to `addr`, binding the local end to the given
/// `source_port` on the wildcard address of the matching family.
///
/// NOTE(review): binding a fixed source port presumably serves shard-aware
/// routing (cf. the `SCYLLA_SHARD_AWARE_PORT` handling above) — confirm at
/// call sites.
///
/// # Errors
/// Propagates socket creation, bind, and connect errors.
async fn connect_with_source_port(
    addr: SocketAddr,
    source_port: u16,
) -> Result<TcpStream, std::io::Error> {
    // Pick the socket family and wildcard bind address from the destination,
    // then connect once — the two arms only differ in address family.
    let socket = match addr {
        SocketAddr::V4(_) => {
            let socket = TcpSocket::new_v4()?;
            socket.bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), source_port))?;
            socket
        }
        SocketAddr::V6(_) => {
            let socket = TcpSocket::new_v6()?;
            socket.bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), source_port))?;
            socket
        }
    };
    socket.connect(addr).await
}
/// Tracks stream ids whose driver-side callers have stopped waiting for a
/// response (see `orphan()` in `ResponseHandlerMap`) while the server has
/// not yet replied. Tracked ids are counted by age via `orphans_older_than`.
struct OrphanageTracker {
    // Stream id -> moment it was orphaned.
    orphans: HashMap<i16, Instant>,
    // The same entries keyed by (orphaning time, stream id), enabling a
    // range scan to count orphans older than a deadline.
    by_orphaning_times: BTreeSet<(Instant, i16)>,
}

impl OrphanageTracker {
    /// Creates an empty tracker.
    fn new() -> Self {
        Self {
            orphans: HashMap::new(),
            by_orphaning_times: BTreeSet::new(),
        }
    }

    /// Marks `stream_id` as orphaned as of now.
    fn insert(&mut self, stream_id: i16) {
        // Drop any stale entry first: overwriting only the HashMap would
        // leave the old (time, id) pair in `by_orphaning_times`, desyncing
        // the two indices and inflating `orphans_older_than` counts.
        self.remove(stream_id);
        let now = Instant::now();
        self.orphans.insert(stream_id, now);
        self.by_orphaning_times.insert((now, stream_id));
    }

    /// Stops tracking `stream_id` (typically because its response finally
    /// arrived). No-op if the id is not tracked.
    fn remove(&mut self, stream_id: i16) {
        if let Some(time) = self.orphans.remove(&stream_id) {
            self.by_orphaning_times.remove(&(time, stream_id));
        }
    }

    /// Returns true iff `stream_id` is currently tracked as an orphan.
    fn contains(&self, stream_id: i16) -> bool {
        self.orphans.contains_key(&stream_id)
    }

    /// Counts orphans whose orphaning time is older than `age`.
    fn orphans_older_than(&self, age: std::time::Duration) -> usize {
        // `Instant::now() - age` panics when `age` reaches before the
        // earliest representable instant; `checked_sub` makes the huge-`age`
        // case well-defined — nothing can be that old, so the count is 0.
        match Instant::now().checked_sub(age) {
            Some(minimal_age) => self
                .by_orphaning_times
                .range(..(minimal_age, i16::MAX))
                .count(),
            None => 0,
        }
    }
}
/// Bookkeeping for in-flight requests on a connection: maps stream ids to
/// the handlers awaiting their responses.
struct ResponseHandlerMap {
    // Allocator of free stream ids for new requests.
    stream_set: StreamIdSet,
    // Stream id -> handler to complete when the response arrives.
    handlers: HashMap<i16, ResponseHandler>,
    // Reverse index (request id -> stream id), used to orphan requests.
    request_to_stream: HashMap<RequestId, i16>,
    // Stream ids abandoned by their callers but still awaiting a response.
    orphanage_tracker: OrphanageTracker,
}
/// Outcome of resolving an incoming response's stream id to a handler.
enum HandlerLookupResult {
    // The request was orphaned by its caller; no one awaits this response.
    Orphaned,
    // The handler registered for this stream id.
    Handler(ResponseHandler),
    // No handler was registered for this stream id.
    Missing,
}
impl ResponseHandlerMap {
    /// Creates an empty map with every stream id free.
    fn new() -> Self {
        Self {
            stream_set: StreamIdSet::new(),
            handlers: HashMap::new(),
            request_to_stream: HashMap::new(),
            orphanage_tracker: OrphanageTracker::new(),
        }
    }

    /// Assigns a free stream id to `response_handler` and registers it.
    /// When all stream ids are in use, hands the handler back via `Err`.
    fn allocate(&mut self, response_handler: ResponseHandler) -> Result<i16, ResponseHandler> {
        let stream_id = match self.stream_set.allocate() {
            Some(id) => id,
            None => return Err(response_handler),
        };
        self.request_to_stream
            .insert(response_handler.request_id, stream_id);
        let prev_handler = self.handlers.insert(stream_id, response_handler);
        // A freshly allocated stream id can never already hold a handler.
        assert!(prev_handler.is_none());
        Ok(stream_id)
    }

    /// Marks the request as abandoned by its caller: its handler is dropped
    /// and its stream id is parked in the orphanage tracker until the
    /// server's (late) response eventually frees it.
    fn orphan(&mut self, request_id: RequestId) {
        if let Some(stream_id) = self.request_to_stream.remove(&request_id) {
            debug!(
                "Orphaning stream_id = {} associated with request_id = {}",
                stream_id, request_id
            );
            self.orphanage_tracker.insert(stream_id);
            self.handlers.remove(&stream_id);
        }
    }

    /// Number of orphaned stream ids older than the old-age threshold.
    fn old_orphans_count(&self) -> usize {
        self.orphanage_tracker
            .orphans_older_than(OLD_AGE_ORPHAN_THRESHOLD)
    }

    /// Frees `stream_id` and resolves which handler (if any) should receive
    /// the response that just arrived on it.
    fn lookup(&mut self, stream_id: i16) -> HandlerLookupResult {
        // The server responded, so the stream id may be reused either way.
        self.stream_set.free(stream_id);

        if self.orphanage_tracker.contains(stream_id) {
            // The caller gave up on this request; un-park the id and tell
            // the caller to drop the response.
            self.orphanage_tracker.remove(stream_id);
            return HandlerLookupResult::Orphaned;
        }

        match self.handlers.remove(&stream_id) {
            Some(handler) => {
                self.request_to_stream.remove(&handler.request_id);
                HandlerLookupResult::Handler(handler)
            }
            None => HandlerLookupResult::Missing,
        }
    }

    /// Consumes the map, handing back every handler still awaiting
    /// a response.
    fn into_handlers(self) -> HashMap<i16, ResponseHandler> {
        self.handlers
    }
}
/// Bitmap-based allocator of non-negative CQL stream ids (0..=i16::MAX).
/// Each bit marks one stream id as in use.
struct StreamIdSet {
    used_bitmap: Box<[u64]>,
}

impl StreamIdSet {
    /// Creates a set with every stream id free.
    fn new() -> Self {
        // One bit per non-negative i16 value, 64 ids per u64 word.
        // Note: `i16::MAX` replaces the soft-deprecated `std::i16::MAX` path.
        const BITMAP_SIZE: usize = (i16::MAX as usize + 1) / 64;
        Self {
            used_bitmap: vec![0; BITMAP_SIZE].into_boxed_slice(),
        }
    }

    /// Allocates the lowest free stream id, or `None` if all are in use.
    fn allocate(&mut self) -> Option<i16> {
        for (block_id, block) in self.used_bitmap.iter_mut().enumerate() {
            if *block != !0 {
                // The lowest zero bit of this word is the id to hand out.
                let off = block.trailing_ones();
                *block |= 1u64 << off;
                let stream_id = off as i16 + block_id as i16 * 64;
                return Some(stream_id);
            }
        }
        None
    }

    /// Returns `stream_id` to the pool of free ids.
    fn free(&mut self, stream_id: i16) {
        let block_id = stream_id as usize / 64;
        let off = stream_id as usize % 64;
        self.used_bitmap[block_id] &= !(1 << off);
    }
}
/// A keyspace name that has been validated against naming rules:
/// non-empty, at most 48 characters, ASCII letters/digits/underscores only.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct VerifiedKeyspaceName {
    name: Arc<String>,
    pub(crate) is_case_sensitive: bool,
}

impl VerifiedKeyspaceName {
    /// Validates `keyspace_name` and wraps it on success.
    pub(crate) fn new(
        keyspace_name: String,
        case_sensitive: bool,
    ) -> Result<Self, BadKeyspaceName> {
        Self::verify_keyspace_name_is_valid(&keyspace_name)?;
        Ok(VerifiedKeyspaceName {
            name: Arc::new(keyspace_name),
            is_case_sensitive: case_sensitive,
        })
    }

    /// Borrows the validated keyspace name.
    pub(crate) fn as_str(&self) -> &str {
        self.name.as_str()
    }

    // Rejects names that are empty, longer than 48 characters, or contain
    // anything other than ASCII letters, digits and underscores.
    fn verify_keyspace_name_is_valid(keyspace_name: &str) -> Result<(), BadKeyspaceName> {
        if keyspace_name.is_empty() {
            return Err(BadKeyspaceName::Empty);
        }

        let keyspace_name_len: usize = keyspace_name.chars().count();
        if keyspace_name_len > 48 {
            return Err(BadKeyspaceName::TooLong(
                keyspace_name.to_string(),
                keyspace_name_len,
            ));
        }

        // Report the first character outside [a-zA-Z0-9_], if any.
        match keyspace_name
            .chars()
            .find(|c| !(c.is_ascii_alphanumeric() || *c == '_'))
        {
            Some(character) => Err(BadKeyspaceName::IllegalCharacter(
                keyspace_name.to_string(),
                character,
            )),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use scylla_cql::errors::QueryError;
use scylla_cql::frame::protocol_features::{
LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
};
use scylla_cql::frame::types;
use scylla_proxy::{
Condition, Node, Proxy, Reaction, RequestFrame, RequestOpcode, RequestReaction,
RequestRule, ResponseFrame, ShardAwareness,
};
use tokio::select;
use tokio::sync::mpsc;
use super::ConnectionConfig;
use crate::query::Query;
use crate::transport::connection::open_connection;
use crate::transport::node::ResolvedContactPoint;
use crate::transport::topology::UntranslatedEndpoint;
use crate::utils::test_utils::unique_keyspace_name;
use crate::{IntoTypedRows, SessionBuilder};
use futures::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
async fn resolve_hostname(hostname: &str) -> SocketAddr {
match tokio::net::lookup_host(hostname).await {
Ok(mut addrs) => addrs.next().unwrap(),
Err(_) => {
tokio::net::lookup_host((hostname, 9042)) .await
.unwrap()
.next()
.unwrap()
}
}
}
#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn connection_query_iter_test() {
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
let addr: SocketAddr = resolve_hostname(&uri).await;
let (connection, _) = super::open_connection(
UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
address: addr,
datacenter: None,
}),
None,
ConnectionConfig::default(),
)
.await
.unwrap();
let connection = Arc::new(connection);
let ks = unique_keyspace_name();
{
let session = SessionBuilder::new()
.known_node_addr(addr)
.build()
.await
.unwrap();
session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap();
session.use_keyspace(ks.clone(), false).await.unwrap();
session
.query("DROP TABLE IF EXISTS connection_query_iter_tab", &[])
.await
.unwrap();
session
.query(
"CREATE TABLE IF NOT EXISTS connection_query_iter_tab (p int primary key)",
&[],
)
.await
.unwrap();
}
connection
.use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
.await
.unwrap();
let select_query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
let empty_res = connection
.clone()
.query_iter(select_query.clone())
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert!(empty_res.is_empty());
let values: Vec<i32> = (0..100).collect();
let mut insert_futures = Vec::new();
let insert_query =
Query::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)").with_page_size(7);
let prepared = connection.prepare(&insert_query).await.unwrap();
for v in &values {
let prepared_clone = prepared.clone();
let values = prepared_clone.serialize_values(&(*v,)).unwrap();
let fut = async { connection.execute(prepared_clone, values, None).await };
insert_futures.push(fut);
}
futures::future::try_join_all(insert_futures).await.unwrap();
let mut results: Vec<i32> = connection
.clone()
.query_iter(select_query.clone())
.await
.unwrap()
.into_typed::<(i32,)>()
.map(|ret| ret.unwrap().0)
.collect::<Vec<_>>()
.await;
results.sort_unstable(); assert_eq!(results, values);
let insert_res1 = connection
.query_iter(Query::new(
"INSERT INTO connection_query_iter_tab (p) VALUES (0)",
))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert!(insert_res1.is_empty());
}
#[tokio::test]
#[cfg(not(scylla_cloud_tests))]
async fn test_coalescing() {
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
let addr: SocketAddr = resolve_hostname(&uri).await;
let ks = unique_keyspace_name();
{
let session = SessionBuilder::new()
.known_node_addr(addr)
.build()
.await
.unwrap();
session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap();
session.use_keyspace(ks.clone(), false).await.unwrap();
session
.query(
"CREATE TABLE IF NOT EXISTS t (p int primary key, v blob)",
&[],
)
.await
.unwrap();
}
let subtest = |enable_coalescing: bool, ks: String| async move {
let (connection, _) = super::open_connection(
UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
address: addr,
datacenter: None,
}),
None,
ConnectionConfig {
enable_write_coalescing: enable_coalescing,
..ConnectionConfig::default()
},
)
.await
.unwrap();
let connection = Arc::new(connection);
connection
.use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
.await
.unwrap();
connection.query(&"TRUNCATE t".into(), None).await.unwrap();
let mut futs = Vec::new();
const NUM_BATCHES: i32 = 10;
for batch_size in 0..NUM_BATCHES {
let base = arithmetic_sequence_sum(batch_size);
let conn = connection.clone();
futs.push(tokio::task::spawn(async move {
let futs = (base..base + batch_size).map(|j| {
let q = Query::new("INSERT INTO t (p, v) VALUES (?, ?)");
let conn = conn.clone();
async move {
let prepared = conn.prepare(&q).await.unwrap();
let values = prepared
.serialize_values(&(j, vec![j as u8; j as usize]))
.unwrap();
let response =
conn.execute(prepared.clone(), values, None).await.unwrap();
let _nonerror_response =
response.into_non_error_query_response().unwrap();
}
});
let _joined: Vec<()> = futures::future::join_all(futs).await;
}));
tokio::task::yield_now().await;
}
let _joined: Vec<()> = futures::future::try_join_all(futs).await.unwrap();
let range_end = arithmetic_sequence_sum(NUM_BATCHES);
let mut results = connection
.query(&"SELECT p, v FROM t".into(), None)
.await
.unwrap()
.into_query_result()
.unwrap()
.rows()
.unwrap()
.into_typed::<(i32, Vec<u8>)>()
.collect::<Result<Vec<_>, _>>()
.unwrap();
results.sort();
let expected = (0..range_end)
.map(|i| (i, vec![i as u8; i as usize]))
.collect::<Vec<_>>();
assert_eq!(results, expected);
};
subtest(true, ks.clone()).await;
subtest(false, ks.clone()).await;
}
fn arithmetic_sequence_sum(n: i32) -> i32 {
n * (n - 1) / 2
}
#[tokio::test]
async fn test_lwt_optimisation_mark_negotiation() {
const MASK: &str = "2137";
let lwt_optimisation_entry = format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, MASK);
let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042);
let config = ConnectionConfig::default();
let (startup_tx, mut startup_rx) = mpsc::unbounded_channel();
let options_without_lwt_optimisation_support = HashMap::<String, Vec<String>>::new();
let options_with_lwt_optimisation_support = [(
SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION.into(),
vec![lwt_optimisation_entry.clone()],
)]
.into_iter()
.collect::<HashMap<String, Vec<String>>>();
let make_rules = |options| {
vec![
RequestRule(
Condition::RequestOpcode(RequestOpcode::Options),
RequestReaction::forge_response(Arc::new(move |frame: RequestFrame| {
ResponseFrame::forged_supported(frame.params, &options).unwrap()
})),
),
RequestRule(
Condition::RequestOpcode(RequestOpcode::Startup),
RequestReaction::drop_frame().with_feedback_when_performed(startup_tx.clone()),
),
]
};
let mut proxy = Proxy::builder()
.with_node(
Node::builder()
.proxy_address(proxy_addr)
.request_rules(make_rules(options_without_lwt_optimisation_support))
.build_dry_mode(),
)
.build()
.run()
.await
.unwrap();
let (startup_without_lwt_optimisation, _shard) = select! {
_ = open_connection(UntranslatedEndpoint::ContactPoint(ResolvedContactPoint{address: proxy_addr, datacenter: None}), None, config.clone()) => unreachable!(),
startup = startup_rx.recv() => startup.unwrap(),
};
proxy.running_nodes[0]
.change_request_rules(Some(make_rules(options_with_lwt_optimisation_support)));
let (startup_with_lwt_optimisation, _shard) = select! {
_ = open_connection(UntranslatedEndpoint::ContactPoint(ResolvedContactPoint{address: proxy_addr, datacenter: None}), None, config.clone()) => unreachable!(),
startup = startup_rx.recv() => startup.unwrap(),
};
let _ = proxy.finish().await;
let chosen_options =
types::read_string_map(&mut &*startup_without_lwt_optimisation.body).unwrap();
assert!(!chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
let chosen_options =
types::read_string_map(&mut &startup_with_lwt_optimisation.body[..]).unwrap();
assert!(chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
assert_eq!(
chosen_options
.get(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION)
.unwrap(),
&lwt_optimisation_entry
)
}
#[tokio::test]
#[ntest::timeout(20000)]
#[cfg(not(scylla_cloud_tests))]
async fn connection_is_closed_on_no_response_to_keepalives() {
let proxy_addr = SocketAddr::new(scylla_proxy::get_exclusive_local_address(), 9042);
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
let node_addr: SocketAddr = resolve_hostname(&uri).await;
let drop_options_rule = RequestRule(
Condition::RequestOpcode(RequestOpcode::Options),
RequestReaction::drop_frame(),
);
let config = ConnectionConfig {
keepalive_interval: Some(Duration::from_millis(500)),
keepalive_timeout: Some(Duration::from_secs(1)),
..Default::default()
};
let mut proxy = Proxy::builder()
.with_node(
Node::builder()
.proxy_address(proxy_addr)
.real_address(node_addr)
.shard_awareness(ShardAwareness::QueryNode)
.build(),
)
.build()
.run()
.await
.unwrap();
let (conn, mut error_receiver) = open_connection(
UntranslatedEndpoint::ContactPoint(ResolvedContactPoint {
address: proxy_addr,
datacenter: None,
}),
None,
config,
)
.await
.unwrap();
for _ in 0..3 {
tokio::time::sleep(Duration::from_millis(500)).await;
conn.query_single_page("SELECT host_id FROM system.local")
.await
.unwrap();
}
assert_matches!(
error_receiver.try_recv(),
Err(tokio::sync::oneshot::error::TryRecvError::Empty)
);
proxy.running_nodes[0].change_request_rules(Some(vec![drop_options_rule]));
let err = error_receiver.await.unwrap();
assert_matches!(err, QueryError::IoError(_));
conn.query_single_page("SELECT host_id FROM system.local")
.await
.unwrap_err();
let _ = proxy.finish().await;
}
}