use bytes::{Bytes, BytesMut};
use scylla_cql::errors::{BadQuery, QueryError};
use scylla_cql::frame::types::RawValue;
use scylla_cql::types::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
use scylla_cql::types::serialize::SerializationError;
use smallvec::{smallvec, SmallVec};
use std::convert::TryInto;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
use scylla_cql::frame::response::result::ColumnSpec;
use super::StatementConfig;
use crate::frame::response::result::PreparedMetadata;
use crate::frame::types::{Consistency, SerialConsistency};
use crate::history::HistoryListener;
use crate::retry_policy::RetryPolicy;
use crate::routing::Token;
use crate::transport::execution_profile::ExecutionProfileHandle;
use crate::transport::partitioner::{Partitioner, PartitionerHasher, PartitionerName};
/// A prepared statement: its id, the shared statement text and prepared metadata,
/// and the per-statement execution settings.
///
/// The statement text and metadata live behind an `Arc`, so clones share them.
#[derive(Debug)]
pub struct PreparedStatement {
/// Per-statement execution settings (consistency, timestamp, retry policy, ...).
pub(crate) config: StatementConfig,
/// Tracing ids exposed through `get_prepare_tracing_ids`.
/// NOTE(review): presumably filled in by the code that prepares the statement — confirm at the call site.
pub prepare_tracing_ids: Vec<Uuid>,
/// Identifier of the prepared statement, as passed to `new`.
id: Bytes,
/// Statement text and prepared metadata, shared between clones.
shared: Arc<PreparedStatementSharedData>,
/// Page size; `None` after `disable_paging` (or when constructed without one).
page_size: Option<i32>,
/// Partitioner used by `calculate_token_untyped` to hash the partition key.
partitioner_name: PartitionerName,
/// Value of the `is_lwt` flag passed to `new`.
is_confirmed_lwt: bool,
}
/// Data of a prepared statement that is held behind an `Arc` and therefore
/// shared (not copied) when a `PreparedStatement` is cloned.
#[derive(Debug)]
struct PreparedStatementSharedData {
/// Prepared metadata: column specs and partition key indexes.
metadata: PreparedMetadata,
/// Text of the statement, as returned by `get_statement`.
statement: String,
}
impl Clone for PreparedStatement {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
prepare_tracing_ids: Vec::new(),
id: self.id.clone(),
shared: self.shared.clone(),
page_size: self.page_size,
partitioner_name: self.partitioner_name.clone(),
is_confirmed_lwt: self.is_confirmed_lwt,
}
}
}
impl PreparedStatement {
    /// Builds a prepared statement from its id, metadata, statement text and settings.
    ///
    /// `metadata` and `statement` are placed behind an `Arc` so that clones share them.
    /// The tracing id list starts empty, the partitioner name starts at its `Default`
    /// value, and `is_lwt` becomes the value reported by `is_confirmed_lwt`.
    pub(crate) fn new(
        id: Bytes,
        is_lwt: bool,
        metadata: PreparedMetadata,
        statement: String,
        page_size: Option<i32>,
        config: StatementConfig,
    ) -> Self {
        let shared = Arc::new(PreparedStatementSharedData {
            metadata,
            statement,
        });
        Self {
            config,
            prepare_tracing_ids: Vec::new(),
            id,
            shared,
            page_size,
            partitioner_name: Default::default(),
            is_confirmed_lwt: is_lwt,
        }
    }

    /// Returns the id of this prepared statement.
    pub fn get_id(&self) -> &Bytes {
        &self.id
    }

    /// Returns the text of the statement.
    pub fn get_statement(&self) -> &str {
        &self.shared.statement
    }

    /// Sets the page size for this statement.
    ///
    /// # Panics
    ///
    /// Panics if `page_size` is not greater than 0.
    pub fn set_page_size(&mut self, page_size: i32) {
        assert!(page_size > 0, "page size must be larger than 0");
        self.page_size = Some(page_size);
    }

    /// Disables paging by clearing the page size.
    pub fn disable_paging(&mut self) {
        self.page_size = None;
    }

    /// Returns the page size, or `None` when paging is disabled.
    pub fn get_page_size(&self) -> Option<i32> {
        self.page_size
    }

    /// Returns the tracing ids stored in `prepare_tracing_ids`.
    pub fn get_prepare_tracing_ids(&self) -> &[Uuid] {
        &self.prepare_tracing_ids
    }

    /// Returns `true` when the prepared metadata lists at least one partition key
    /// index, i.e. when a token can be computed from the bound values.
    pub fn is_token_aware(&self) -> bool {
        !self.get_prepared_metadata().pk_indexes.is_empty()
    }

    /// Returns the LWT flag this statement was constructed with.
    pub fn is_confirmed_lwt(&self) -> bool {
        self.is_confirmed_lwt
    }

    /// Serializes `bound_values` and returns the encoded partition key.
    ///
    /// The encoding is the one produced by `PartitionKey::write_encoded_partition_key`.
    ///
    /// # Errors
    ///
    /// * `PartitionKeyError::Serialization` if the values cannot be serialized,
    /// * `PartitionKeyError::PartitionKeyExtraction` if a partition key value is missing,
    /// * `PartitionKeyError::TokenCalculation` if a value is too long to be encoded.
    pub fn compute_partition_key(
        &self,
        bound_values: &impl SerializeRow,
    ) -> Result<Bytes, PartitionKeyError> {
        let serialized = self.serialize_values(bound_values)?;
        let partition_key = self.extract_partition_key(&serialized)?;

        let mut buf = BytesMut::new();
        let mut writer = |chunk: &[u8]| buf.extend_from_slice(chunk);
        partition_key.write_encoded_partition_key(&mut writer)?;

        Ok(buf.freeze())
    }

    /// Extracts the partition key values from already serialized bound values.
    ///
    /// The returned key borrows from both this statement's metadata and `bound_values`.
    pub(crate) fn extract_partition_key<'ps>(
        &'ps self,
        bound_values: &'ps SerializedValues,
    ) -> Result<PartitionKey<'ps>, PartitionKeyExtractionError> {
        PartitionKey::new(self.get_prepared_metadata(), bound_values)
    }

    /// Extracts the partition key and hashes it into a token with `partitioner_name`.
    ///
    /// Returns `Ok(None)` when the statement is not token aware (no partition key
    /// indexes in the prepared metadata).
    ///
    /// # Errors
    ///
    /// * `QueryError::ProtocolError` if the bound values lack a value for one of the
    ///   partition key indexes,
    /// * `QueryError::BadQuery(BadQuery::ValuesTooLongForKey)` if a partition key value
    ///   is too long to be encoded.
    pub(crate) fn extract_partition_key_and_calculate_token<'ps>(
        &'ps self,
        partitioner_name: &'ps PartitionerName,
        serialized_values: &'ps SerializedValues,
    ) -> Result<Option<(PartitionKey<'ps>, Token)>, QueryError> {
        if !self.is_token_aware() {
            return Ok(None);
        }

        let partition_key =
            self.extract_partition_key(serialized_values)
                .map_err(|err| match err {
                    // Reaching this point means pk indexes DO exist (checked above);
                    // what is missing is a bound value at one of those indexes.
                    PartitionKeyExtractionError::NoPkIndexValue(_, _) => {
                        QueryError::ProtocolError(
                            "No value bound for a partition key index - can't calculate token",
                        )
                    }
                })?;

        let token = partition_key
            .calculate_token(partitioner_name)
            .map_err(|err| match err {
                TokenCalculationError::ValueTooLong(values_len) => {
                    QueryError::BadQuery(BadQuery::ValuesTooLongForKey(values_len, u16::MAX.into()))
                }
            })?;

        Ok(Some((partition_key, token)))
    }

    /// Serializes `values` and computes the token of their partition key using this
    /// statement's partitioner.
    ///
    /// Returns `Ok(None)` when the statement is not token aware.
    pub fn calculate_token(&self, values: &impl SerializeRow) -> Result<Option<Token>, QueryError> {
        let serialized = self.serialize_values(values)?;
        self.calculate_token_untyped(&serialized)
    }

    /// Computes the token for already serialized values using this statement's
    /// partitioner, discarding the extracted partition key.
    pub(crate) fn calculate_token_untyped(
        &self,
        values: &SerializedValues,
    ) -> Result<Option<Token>, QueryError> {
        self.extract_partition_key_and_calculate_token(&self.partitioner_name, values)
            .map(|opt| opt.map(|(_pk, token)| token))
    }

    /// Returns the keyspace name taken from the first column spec of the prepared
    /// metadata, or `None` when there are no column specs.
    pub fn get_keyspace_name(&self) -> Option<&str> {
        self.get_prepared_metadata()
            .col_specs
            .first()
            .map(|col_spec| col_spec.table_spec.ks_name.as_str())
    }

    /// Returns the table name taken from the first column spec of the prepared
    /// metadata, or `None` when there are no column specs.
    pub fn get_table_name(&self) -> Option<&str> {
        self.get_prepared_metadata()
            .col_specs
            .first()
            .map(|col_spec| col_spec.table_spec.table_name.as_str())
    }

    /// Sets the consistency stored in this statement's config.
    pub fn set_consistency(&mut self, c: Consistency) {
        self.config.consistency = Some(c);
    }

    /// Returns the consistency stored in this statement's config, if any.
    pub fn get_consistency(&self) -> Option<Consistency> {
        self.config.consistency
    }

    /// Sets the serial consistency stored in this statement's config.
    ///
    /// The value is recorded as an explicit choice: passing `None` stores
    /// `Some(None)`, which differs from never having called this setter.
    pub fn set_serial_consistency(&mut self, sc: Option<SerialConsistency>) {
        self.config.serial_consistency = Some(sc);
    }

    /// Returns the serial consistency stored in this statement's config.
    ///
    /// Both "never set" and "explicitly set to `None`" are reported as `None`.
    pub fn get_serial_consistency(&self) -> Option<SerialConsistency> {
        self.config.serial_consistency.flatten()
    }

    /// Sets the idempotence flag stored in this statement's config.
    pub fn set_is_idempotent(&mut self, is_idempotent: bool) {
        self.config.is_idempotent = is_idempotent;
    }

    /// Returns the idempotence flag stored in this statement's config.
    pub fn get_is_idempotent(&self) -> bool {
        self.config.is_idempotent
    }

    /// Sets the tracing flag stored in this statement's config.
    pub fn set_tracing(&mut self, should_trace: bool) {
        self.config.tracing = should_trace;
    }

    /// Returns the tracing flag stored in this statement's config.
    pub fn get_tracing(&self) -> bool {
        self.config.tracing
    }

    /// Sets (or clears, with `None`) the timestamp stored in this statement's config.
    pub fn set_timestamp(&mut self, timestamp: Option<i64>) {
        self.config.timestamp = timestamp;
    }

    /// Returns the timestamp stored in this statement's config, if any.
    pub fn get_timestamp(&self) -> Option<i64> {
        self.config.timestamp
    }

    /// Sets (or clears, with `None`) the request timeout stored in this statement's config.
    pub fn set_request_timeout(&mut self, timeout: Option<Duration>) {
        self.config.request_timeout = timeout;
    }

    /// Returns the request timeout stored in this statement's config, if any.
    pub fn get_request_timeout(&self) -> Option<Duration> {
        self.config.request_timeout
    }

    /// Replaces the partitioner used when calculating tokens for this statement.
    pub(crate) fn set_partitioner_name(&mut self, partitioner_name: PartitionerName) {
        self.partitioner_name = partitioner_name;
    }

    /// Returns the prepared metadata (column specs and partition key indexes).
    pub fn get_prepared_metadata(&self) -> &PreparedMetadata {
        &self.shared.metadata
    }

    /// Returns the partitioner used when calculating tokens for this statement.
    pub(crate) fn get_partitioner_name(&self) -> &PartitionerName {
        &self.partitioner_name
    }

    /// Sets (or clears, with `None`) the retry policy stored in this statement's config.
    #[inline]
    pub fn set_retry_policy(&mut self, retry_policy: Option<Arc<dyn RetryPolicy>>) {
        self.config.retry_policy = retry_policy;
    }

    /// Returns the retry policy stored in this statement's config, if any.
    #[inline]
    pub fn get_retry_policy(&self) -> Option<&Arc<dyn RetryPolicy>> {
        self.config.retry_policy.as_ref()
    }

    /// Installs a history listener in this statement's config, replacing any previous one.
    pub fn set_history_listener(&mut self, history_listener: Arc<dyn HistoryListener>) {
        self.config.history_listener = Some(history_listener);
    }

    /// Removes the history listener from this statement's config and returns it, if one was set.
    pub fn remove_history_listener(&mut self) -> Option<Arc<dyn HistoryListener>> {
        self.config.history_listener.take()
    }

    /// Sets (or clears, with `None`) the execution profile handle stored in this statement's config.
    pub fn set_execution_profile_handle(&mut self, profile_handle: Option<ExecutionProfileHandle>) {
        self.config.execution_profile_handle = profile_handle;
    }

    /// Returns the execution profile handle stored in this statement's config, if any.
    pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
        self.config.execution_profile_handle.as_ref()
    }

    /// Serializes `values` using a context built from this statement's prepared metadata.
    pub(crate) fn serialize_values(
        &self,
        values: &impl SerializeRow,
    ) -> Result<SerializedValues, SerializationError> {
        let ctx = RowSerializationContext::from_prepared(self.get_prepared_metadata());
        SerializedValues::from_serializable(&ctx, values)
    }
}
/// Error returned when the partition key cannot be extracted from the bound values.
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
pub enum PartitionKeyExtractionError {
/// The bound values contain no value at a partition key index.
/// Carries the partition key index and the number of bound values.
#[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
NoPkIndexValue(u16, u16),
}
/// Error returned when the partition key cannot be encoded for token calculation.
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
pub enum TokenCalculationError {
/// A partition key value's length does not fit in a `u16` length prefix.
/// Carries the length of the offending value.
#[error("Value bytes too long to create partition key, max 65 535 allowed! value.len(): {0}")]
ValueTooLong(usize),
}
/// Error returned by `PreparedStatement::compute_partition_key`.
///
/// Every variant is transparent: its message is the message of the wrapped error.
#[derive(Clone, Debug, Error)]
pub enum PartitionKeyError {
/// A partition key value could not be found among the bound values.
#[error(transparent)]
PartitionKeyExtraction(PartitionKeyExtractionError),
/// The partition key could not be encoded.
#[error(transparent)]
TokenCalculation(TokenCalculationError),
/// The bound values could not be serialized.
#[error(transparent)]
Serialization(SerializationError),
}
impl From<PartitionKeyExtractionError> for PartitionKeyError {
fn from(err: PartitionKeyExtractionError) -> Self {
Self::PartitionKeyExtraction(err)
}
}
impl From<TokenCalculationError> for PartitionKeyError {
fn from(err: TokenCalculationError) -> Self {
Self::TokenCalculation(err)
}
}
impl From<SerializationError> for PartitionKeyError {
fn from(err: SerializationError) -> Self {
Self::Serialization(err)
}
}
/// One partition key column: its serialized value bytes paired with the column's spec.
pub(crate) type PartitionKeyValue<'ps> = (&'ps [u8], &'ps ColumnSpec);
/// Partition key values borrowed from a statement's prepared metadata and a set of
/// serialized bound values.
pub(crate) struct PartitionKey<'ps> {
/// One slot per partition key index, placed at the position given by that index's
/// `sequence`. A slot stays `None` when the bound value is not a `RawValue::Value`.
pk_values: SmallVec<[Option<PartitionKeyValue<'ps>>; PartitionKey::SMALLVEC_ON_STACK_SIZE]>,
}
impl<'ps> PartitionKey<'ps> {
    /// Inline capacity of the `SmallVec` holding the partition key slots.
    const SMALLVEC_ON_STACK_SIZE: usize = 8;

    /// Collects the partition key values out of `bound_values`.
    ///
    /// Walks `prepared_metadata.pk_indexes` in their stored order while advancing a
    /// single iterator over the bound values, so the indexes are expected to be in
    /// strictly increasing `index` order. Each found value is stored at the slot
    /// given by its `sequence`.
    ///
    /// # Errors
    ///
    /// Returns `PartitionKeyExtractionError::NoPkIndexValue` when there is no bound
    /// value at a partition key index, or when an index does not lie past the
    /// previously consumed one (unsorted or duplicated indexes).
    ///
    /// # Panics
    ///
    /// Panics if the metadata is inconsistent: a partition key `index` beyond
    /// `col_specs`, or a `sequence` beyond the number of partition key indexes.
    fn new(
        prepared_metadata: &'ps PreparedMetadata,
        bound_values: &'ps SerializedValues,
    ) -> Result<Self, PartitionKeyExtractionError> {
        let mut pk_values: SmallVec<[_; PartitionKey::SMALLVEC_ON_STACK_SIZE]> =
            smallvec![None; prepared_metadata.pk_indexes.len()];
        let mut values_iter = bound_values.iter();
        // Number of bound values already consumed from `values_iter`. Tracked as
        // `usize` so that `index + 1` cannot overflow when `index == u16::MAX`.
        let mut values_iter_offset: usize = 0;

        for pk_index in prepared_metadata.pk_indexes.iter().copied() {
            let no_value = || {
                PartitionKeyExtractionError::NoPkIndexValue(
                    pk_index.index,
                    bound_values.element_count(),
                )
            };

            let value_position = usize::from(pk_index.index);
            // How many values to skip to land on `value_position`. A checked
            // subtraction turns an out-of-order index into an error instead of a
            // debug-mode panic or a release-mode wrap-around.
            let skip = value_position
                .checked_sub(values_iter_offset)
                .ok_or_else(no_value)?;
            let next_val = values_iter.nth(skip).ok_or_else(no_value)?;

            // Values that are not `RawValue::Value` leave their slot as `None`.
            if let RawValue::Value(v) = next_val {
                let spec = &prepared_metadata.col_specs[value_position];
                pk_values[usize::from(pk_index.sequence)] = Some((v, spec));
            }
            values_iter_offset = value_position + 1;
        }

        Ok(Self { pk_values })
    }

    /// Iterates over the present partition key values in slot (sequence) order,
    /// skipping empty slots.
    pub(crate) fn iter(&self) -> impl Iterator<Item = PartitionKeyValue<'ps>> + Clone + '_ {
        self.pk_values.iter().flatten().copied()
    }

    /// Feeds the encoded partition key to `writer`, chunk by chunk.
    ///
    /// * No values: nothing is written.
    /// * Exactly one value: its raw bytes are written as-is.
    /// * Two or more values: each one is written as a big-endian `u16` length,
    ///   the value bytes, and a single `0` byte.
    ///
    /// # Errors
    ///
    /// Returns `TokenCalculationError::ValueTooLong` when, in the multi-value case,
    /// a value's length does not fit in a `u16`. Chunks written before the failure
    /// have already been passed to `writer`.
    fn write_encoded_partition_key(
        &self,
        writer: &mut impl FnMut(&[u8]),
    ) -> Result<(), TokenCalculationError> {
        let mut pk_val_iter = self.iter().map(|(val, _spec)| val);

        if let Some(first_value) = pk_val_iter.next() {
            if let Some(second_value) = pk_val_iter.next() {
                // Composite key: re-attach the two peeked values in front of the rest.
                for value in std::iter::once(first_value)
                    .chain(std::iter::once(second_value))
                    .chain(pk_val_iter)
                {
                    let v_len_u16: u16 = value
                        .len()
                        .try_into()
                        .map_err(|_| TokenCalculationError::ValueTooLong(value.len()))?;
                    writer(&v_len_u16.to_be_bytes());
                    writer(value);
                    writer(&[0u8]);
                }
            } else {
                // Single-value key: no length prefix and no trailing byte.
                writer(first_value);
            }
        }

        Ok(())
    }

    /// Hashes the encoded partition key with the hasher built from
    /// `partitioner_name` and returns the resulting token.
    ///
    /// # Errors
    ///
    /// Propagates `TokenCalculationError::ValueTooLong` from the encoding step.
    pub(crate) fn calculate_token(
        &self,
        partitioner_name: &PartitionerName,
    ) -> Result<Token, TokenCalculationError> {
        let mut partitioner_hasher = partitioner_name.build_hasher();
        let mut writer = |chunk: &[u8]| partitioner_hasher.write(chunk);
        self.write_encoded_partition_key(&mut writer)?;
        Ok(partitioner_hasher.finish())
    }
}
#[cfg(test)]
mod tests {
    use scylla_cql::{
        frame::response::result::{
            ColumnSpec, ColumnType, PartitionKeyIndex, PreparedMetadata, TableSpec,
        },
        types::serialize::row::SerializedValues,
    };

    use crate::prepared_statement::PartitionKey;

    /// Builds prepared metadata for table `ks.t` with columns `col_0`, `col_1`, ...
    /// of the given types.
    ///
    /// `idx` lists the column indexes that make up the partition key, in partition
    /// key order; the resulting `pk_indexes` are sorted by column index.
    fn make_meta(
        cols: impl IntoIterator<Item = ColumnType>,
        idx: impl IntoIterator<Item = usize>,
    ) -> PreparedMetadata {
        let table_spec = TableSpec {
            ks_name: "ks".to_owned(),
            table_name: "t".to_owned(),
        };

        let mut col_specs = Vec::new();
        for (i, typ) in cols.into_iter().enumerate() {
            col_specs.push(ColumnSpec {
                name: format!("col_{}", i),
                table_spec: table_spec.clone(),
                typ,
            });
        }

        let mut pk_indexes = Vec::new();
        for (sequence, index) in idx.into_iter().enumerate() {
            pk_indexes.push(PartitionKeyIndex {
                index: index as u16,
                sequence: sequence as u16,
            });
        }
        pk_indexes.sort_unstable_by_key(|pki| pki.index);

        PreparedMetadata {
            flags: 0,
            col_count: col_specs.len(),
            col_specs,
            pk_indexes,
        }
    }

    /// The partition key must come out in sequence order (columns 4, 0, 3) even
    /// though the values are bound in column order.
    #[test]
    fn test_partition_key_multiple_columns_shuffled() {
        let column_types = [
            ColumnType::TinyInt,
            ColumnType::SmallInt,
            ColumnType::Int,
            ColumnType::BigInt,
            ColumnType::Blob,
        ];
        let meta = make_meta(column_types, [4, 0, 3]);

        let mut values = SerializedValues::new();
        values.add_value(&67i8, &ColumnType::TinyInt).unwrap();
        values.add_value(&42i16, &ColumnType::SmallInt).unwrap();
        values.add_value(&23i32, &ColumnType::Int).unwrap();
        values.add_value(&89i64, &ColumnType::BigInt).unwrap();
        values
            .add_value(&[1u8, 2, 3, 4, 5], &ColumnType::Blob)
            .unwrap();

        let pk = PartitionKey::new(&meta, &values).unwrap();
        let pk_cols: Vec<_> = pk.iter().collect();

        // Expected serialized bytes, kept in named locals so they outlive the comparison.
        let blob_bytes = [1u8, 2, 3, 4, 5];
        let tinyint_bytes = 67i8.to_be_bytes();
        let bigint_bytes = 89i64.to_be_bytes();

        assert_eq!(
            pk_cols,
            vec![
                (blob_bytes.as_slice(), &meta.col_specs[4]),
                (tinyint_bytes.as_slice(), &meta.col_specs[0]),
                (bigint_bytes.as_slice(), &meta.col_specs[3]),
            ]
        );
    }
}