diff --git a/crypto-macros/src/entity_derive/derive_impl.rs b/crypto-macros/src/entity_derive/derive_impl.rs index 9027794802..111a2c232a 100644 --- a/crypto-macros/src/entity_derive/derive_impl.rs +++ b/crypto-macros/src/entity_derive/derive_impl.rs @@ -38,7 +38,7 @@ impl KeyStoreEntityFlattened { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::#struct_name(self) + crate::transaction::dynamic_dispatch::Entity::#struct_name(self.into()) } } } diff --git a/crypto-macros/src/entity_derive_new/column.rs b/crypto-macros/src/entity_derive_new/column.rs index ecc28c5944..9eddf18700 100644 --- a/crypto-macros/src/entity_derive_new/column.rs +++ b/crypto-macros/src/entity_derive_new/column.rs @@ -57,7 +57,11 @@ where pub(super) fn load_expression(&self) -> TokenStream { let column_name = self.sql_name(); - let expr = quote!(row.get::<_, Vec>(#column_name)?); + let sql_data_type = match self.transformation { + None => self.column_type.get_as_type(), + Some(FieldTransformation::Hex) => quote!(String), + }; + let expr = quote!(row.get::<_, #sql_data_type>(#column_name)?); let expr = match self.transformation { None => expr, diff --git a/crypto-macros/src/entity_derive_new/column_type.rs b/crypto-macros/src/entity_derive_new/column_type.rs index 5cb3f06fee..cde35bf274 100644 --- a/crypto-macros/src/entity_derive_new/column_type.rs +++ b/crypto-macros/src/entity_derive_new/column_type.rs @@ -133,6 +133,9 @@ impl TryFrom for ColumnType { pub(super) trait EmitGetExpression { /// Emit an expression which wraps the input expression, appropriately parsing according to this column type. fn emit_get_expression(&self, input: TokenStream) -> TokenStream; + + /// Emit an expression with the rust type which should be used in the rusqlite `get` expression + fn get_as_type(&self) -> TokenStream; } impl EmitGetExpression for IdColumnType { @@ -142,17 +145,24 @@ impl EmitGetExpression for IdColumnType { Self::String => quote!(String::from_utf8(#input).map_err(|err| err.utf8_error())?), } } + + fn get_as_type(&self) -> TokenStream { + quote!(Vec) + } } impl EmitGetExpression for ColumnType { fn emit_get_expression(&self, input: TokenStream) -> TokenStream { match self { - ColumnType::Bytes => input, + ColumnType::Bytes | ColumnType::OptionalBytes => input, ColumnType::String => quote!(String::from_utf8(#input).map_err(|err| err.utf8_error())?), - ColumnType::OptionalBytes => quote! {{ - let data = #input; - (!data.is_empty()).then_some(data) - }}, + } + } + + fn get_as_type(&self) -> TokenStream { + match self { + ColumnType::Bytes | ColumnType::String => quote!(Vec), + ColumnType::OptionalBytes => quote!(Option>), } } } diff --git a/crypto-macros/src/entity_derive_new/derive_impl.rs b/crypto-macros/src/entity_derive_new/derive_impl.rs index 3596f111d8..d731ad66d4 100644 --- a/crypto-macros/src/entity_derive_new/derive_impl.rs +++ b/crypto-macros/src/entity_derive_new/derive_impl.rs @@ -3,14 +3,15 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{Ident, Lifetime}; -use crate::entity_derive_new::{Entity, column_type::ColumnType}; +use crate::entity_derive_new::{Entity, FieldTransformation, column_type::ColumnType}; impl quote::ToTokens for Entity { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.extend(self.impl_entity_base()); + tokens.extend(self.impl_primary_key()); tokens.extend(self.impl_entity_generic()); tokens.extend(self.impl_entity_wasm()); - tokens.extend(self.impl_borrow_primary_key()); + tokens.extend(self.impl_entity_get_borrowed()); tokens.extend(self.impl_entity_database_mutation()); tokens.extend(self.impl_entity_delete_borrowed()); tokens.extend(self.impl_decrypting()); @@ -37,7 +38,36 @@ impl Entity { const COLLECTION_NAME: &'static str = #collection_name; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::#struct_name(self) + crate::transaction::dynamic_dispatch::Entity::#struct_name(self.into()) + } + } + } + } + + /// `impl PrimaryKey for MyEntity` and `impl BorrowPrimaryKey for MyEntity` + fn impl_primary_key(&self) -> TokenStream { + let Self { + struct_name, id_column, .. + } = self; + + let primary_key = id_column.column_type.owned(); + let borrowed_primary_key = id_column.column_type.borrowed(); + let pk_field_name = &id_column.field_name; + + quote! { + impl crate::traits::PrimaryKey for #struct_name { + type PrimaryKey = #primary_key; + + fn primary_key(&self) -> Self::PrimaryKey { + self.#pk_field_name.clone() + } + } + + impl crate::traits::BorrowPrimaryKey for #struct_name { + type BorrowedPrimaryKey = #borrowed_primary_key; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.#pk_field_name } } } @@ -52,9 +82,6 @@ impl Entity { .. } = self; - let primary_key = id_column.column_type.owned(); - let id_field_name = &id_column.field_name; - let field_assignments = std::iter::once(id_column.field_assignment()) .chain(other_columns.iter().map(|column| column.field_assignment())); @@ -62,14 +89,8 @@ impl Entity { #[cfg(not(target_family = "wasm"))] #[::async_trait::async_trait] impl crate::traits::Entity for #struct_name { - type PrimaryKey = #primary_key; - - fn primary_key(&self) -> #primary_key { - self.#id_field_name.clone() - } - async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> crate::CryptoKeystoreResult> { - ::get_borrowed(conn, key).await + ::get_borrowed(conn, key).await } async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult { @@ -89,25 +110,14 @@ impl Entity { /// `#[cfg(target_family = "wasm")] impl Entity for MyEntity` fn impl_entity_wasm(&self) -> TokenStream { - let Self { - struct_name, id_column, .. - } = self; - - let primary_key = id_column.column_type.owned(); - let id_field_name = &id_column.field_name; + let Self { struct_name, .. } = self; quote! { #[cfg(target_family = "wasm")] #[::async_trait::async_trait(?Send)] impl crate::traits::Entity for #struct_name { - type PrimaryKey = #primary_key; - - fn primary_key(&self) -> #primary_key { - self.#id_field_name.clone() - } - async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> crate::CryptoKeystoreResult> { - ::get_borrowed(conn, key).await + ::get_borrowed(conn, key).await } async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult { @@ -121,8 +131,8 @@ impl Entity { } } - /// `impl BorrowPrimaryKey for MyEntity` - fn impl_borrow_primary_key(&self) -> TokenStream { + /// `impl EntityGetBorrowed for MyEntity` + fn impl_entity_get_borrowed(&self) -> TokenStream { let Self { struct_name, id_column, @@ -130,8 +140,6 @@ impl Entity { .. } = self; - let borrowed_primary_key = id_column.column_type.borrowed(); - let pk_field_name = &id_column.field_name; let pk_column_name = id_column .column_name .clone() @@ -140,29 +148,31 @@ impl Entity { let field_assignments = std::iter::once(id_column.field_assignment()) .chain(other_columns.iter().map(|column| column.field_assignment())); + // if we ever add a second field transformation, we'll want this match pattern + #[allow(clippy::manual_map)] + let key_transform = match id_column.transformation { + None => None, + Some(FieldTransformation::Hex) => Some(quote! {let key = hex::encode(key);}), + }; + quote! { #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)] - impl crate::traits::BorrowPrimaryKey for #struct_name { - type BorrowedPrimaryKey = #borrowed_primary_key; - - fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { - &self.#pk_field_name - } - - async fn get_borrowed(conn: &mut Self::ConnectionType, key: &Q) -> crate::CryptoKeystoreResult> - where - Self::PrimaryKey: std::borrow::Borrow, - Q: crate::traits::KeyType, + impl crate::traits::EntityGetBorrowed for #struct_name { + async fn get_borrowed(conn: &mut Self::ConnectionType, key: &Self::BorrowedPrimaryKey) + -> crate::CryptoKeystoreResult> { + let key = <&Self::BorrowedPrimaryKey as crate::traits::KeyType>::bytes(&key); + let key = key.as_ref(); + #key_transform #[cfg(target_family = "wasm")] { - conn.storage().new_get(key.bytes().as_ref()).await + conn.storage().new_get(key).await } #[cfg(not(target_family = "wasm"))] { - crate::entities::platform::get_helper::(conn, #pk_column_name, key.bytes().as_ref(), |row| { + crate::entities::platform::get_helper::(conn, #pk_column_name, key, |row| { Ok(Self { #( #field_assignments, )* }) @@ -196,6 +206,10 @@ impl Entity { .map(|tokens| quote!(#tokens,)) .collect::(); + let sql_map_err = (!upsert).then_some(quote! { + .map_err(|_| CryptoKeystoreError::AlreadyExists(Self::COLLECTION_NAME)) + }); + quote! { #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] @@ -211,7 +225,7 @@ impl Entity { #[cfg(not(target_family = "wasm"))] { let mut stmt = tx.prepare_cached(#sql_statement)?; - stmt.execute(rusqlite::params![#fields])?; + stmt.execute(rusqlite::params![#fields])#sql_map_err?; Ok(()) } } @@ -241,27 +255,35 @@ impl Entity { } = self; let id_column_name = id_column.sql_name(); + // if we ever add a second field transformation, we'll want this match pattern + #[allow(clippy::manual_map)] + let key_transform = match id_column.transformation { + None => None, + Some(FieldTransformation::Hex) => Some(quote! {let key = hex::encode(key);}), + }; quote! { #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] impl<'a> crate::traits::EntityDeleteBorrowed<'a> for #struct_name { - async fn delete_borrowed( + async fn delete_borrowed( tx: &>::Transaction, - id: &Q, + id: &::BorrowedPrimaryKey, ) -> crate::CryptoKeystoreResult where - Self::PrimaryKey: std::borrow::Borrow, - Q: crate::traits::KeyType + for<'pk> &'pk ::BorrowedPrimaryKey: crate::traits::KeyType, { + let key = <&::BorrowedPrimaryKey as crate::traits::KeyType>::bytes(&id); + let key = key.as_ref(); + #key_transform #[cfg(target_family = "wasm")] { - tx.new_delete::(id.bytes().as_ref()).await + tx.new_delete::(key).await } #[cfg(not(target_family = "wasm"))] { - crate::entities::platform::delete_helper::(tx, #id_column_name, id.bytes().as_ref()).await + crate::entities::platform::delete_helper::(tx, #id_column_name, key).await } } } diff --git a/crypto/src/e2e_identity/pki_env.rs b/crypto/src/e2e_identity/pki_env.rs index 4e4ea3cfea..670246fc23 100644 --- a/crypto/src/e2e_identity/pki_env.rs +++ b/crypto/src/e2e_identity/pki_env.rs @@ -1,8 +1,8 @@ use std::collections::HashSet; use core_crypto_keystore::{ - connection::FetchFromDatabase, entities::{E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert}, + traits::FetchFromDatabase, }; use wire_e2e_identity::prelude::x509::revocation::{PkiEnvironment, PkiEnvironmentParams}; use x509_cert::der::Decode; @@ -33,7 +33,7 @@ impl IntoIterator for NewCrlDistributionPoints { pub(crate) async fn restore_pki_env(data_provider: &impl FetchFromDatabase) -> Result> { let mut trust_roots = vec![]; - let Ok(ta_raw) = data_provider.find_unique::().await else { + let Ok(Some(ta_raw)) = data_provider.get_unique::().await else { return Ok(None); }; @@ -42,7 +42,7 @@ pub(crate) async fn restore_pki_env(data_provider: &impl FetchFromDatabase) -> R ); let intermediates = data_provider - .find_all::(Default::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding intermediate certificates"))? .into_iter() @@ -50,7 +50,7 @@ pub(crate) async fn restore_pki_env(data_provider: &impl FetchFromDatabase) -> R .collect::, _>>()?; let crls = data_provider - .find_all::(Default::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding crls"))? .into_iter() diff --git a/crypto/src/group_store.rs b/crypto/src/group_store.rs index 9869980faf..6dadf2669c 100644 --- a/crypto/src/group_store.rs +++ b/crypto/src/group_store.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use core_crypto_keystore::connection::FetchFromDatabase; +use core_crypto_keystore::traits::FetchFromDatabase; use crate::{ConversationId, KeystoreError, MlsConversation, RecursiveError, Result}; @@ -31,7 +31,7 @@ impl GroupStoreEntity for MlsConversation { keystore: &impl FetchFromDatabase, ) -> crate::Result> { let result = keystore - .find::(id) + .get_borrowed::(id.as_ref()) .await .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?; let Some(store_value) = result else { diff --git a/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs b/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs index c0b0e360b1..36dc02561a 100644 --- a/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs +++ b/crypto/src/mls/conversation/conversation_guard/decrypt/buffer_commit.rs @@ -1,4 +1,4 @@ -use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::StoredBufferedCommit}; +use core_crypto_keystore::{entities::StoredBufferedCommit, traits::FetchFromDatabase as _}; use log::info; use openmls::framing::MlsMessageIn; use openmls_traits::OpenMlsCryptoProvider as _; @@ -34,7 +34,7 @@ impl ConversationGuard { self.crypto_provider() .await? .keystore() - .find::(conversation.id()) + .get_borrowed::(conversation.id().as_ref()) .await .map(|option| option.map(StoredBufferedCommit::into_commit_data)) .map_err(KeystoreError::wrap("attempting to retrieve buffered commit")) @@ -69,7 +69,7 @@ impl ConversationGuard { self.crypto_provider() .await? .keystore() - .remove::(conversation.id()) + .remove_borrowed::(conversation.id().as_ref()) .await .map_err(KeystoreError::wrap("attempting to clear buffered commit")) .map_err(Into::into) diff --git a/crypto/src/mls/conversation/merge.rs b/crypto/src/mls/conversation/merge.rs index e912c3b349..22eee153ab 100644 --- a/crypto/src/mls/conversation/merge.rs +++ b/crypto/src/mls/conversation/merge.rs @@ -33,7 +33,7 @@ impl MlsConversation { // ..so if there's any, we clear them after the commit is merged for oln in &previous_own_leaf_nodes { let ek = oln.encryption_key().as_slice(); - let _ = backend.key_store().remove::(ek).await; + let _ = backend.key_store().remove_borrowed::(ek).await; } client diff --git a/crypto/src/mls/conversation/persistence.rs b/crypto/src/mls/conversation/persistence.rs index 9f7c53aac7..859e14db99 100644 --- a/crypto/src/mls/conversation/persistence.rs +++ b/crypto/src/mls/conversation/persistence.rs @@ -1,10 +1,6 @@ use std::collections::HashMap; -use core_crypto_keystore::{ - CryptoKeystoreMls as _, - connection::FetchFromDatabase as _, - entities::{EntityFindParams, PersistedMlsGroup}, -}; +use core_crypto_keystore::{CryptoKeystoreMls as _, entities::PersistedMlsGroup, traits::FetchFromDatabase as _}; use mls_crypto_provider::Database; use openmls::group::{InnerState, MlsGroup}; @@ -50,7 +46,7 @@ impl MlsConversation { /// Effectively [`Database::mls_groups_restore`] but with better types pub(crate) async fn load_all(keystore: &Database) -> Result> { let groups = keystore - .find_all::(EntityFindParams::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding all persisted mls groups"))?; groups diff --git a/crypto/src/mls/conversation/renew.rs b/crypto/src/mls/conversation/renew.rs index 7106a36c37..483a647f27 100644 --- a/crypto/src/mls/conversation/renew.rs +++ b/crypto/src/mls/conversation/renew.rs @@ -131,7 +131,7 @@ impl MlsConversation { // encryption key from the keystore otherwise we would have a leak backend .key_store() - .remove::(leaf_node.encryption_key().as_slice()) + .remove_borrowed::(leaf_node.encryption_key().as_slice()) .await .map_err(KeystoreError::wrap("removing mls encryption keypair"))?; } diff --git a/crypto/src/mls/conversation/welcome.rs b/crypto/src/mls/conversation/welcome.rs index 72742b90e4..93df24d222 100644 --- a/crypto/src/mls/conversation/welcome.rs +++ b/crypto/src/mls/conversation/welcome.rs @@ -1,4 +1,4 @@ -use core_crypto_keystore::{connection::FetchFromDatabase, entities::PersistedMlsPendingGroup}; +use core_crypto_keystore::{entities::PersistedMlsPendingGroup, traits::FetchFromDatabase}; use mls_crypto_provider::MlsCryptoProvider; use openmls::prelude::{MlsGroup, Welcome}; use openmls_traits::OpenMlsCryptoProvider; @@ -38,19 +38,24 @@ impl MlsConversation { ) -> Result { let mls_group_config = configuration.as_openmls_default_configuration()?; - let group = MlsGroup::new_from_welcome(backend, &mls_group_config, welcome, None).await; - - let group = match group { - Err(openmls::prelude::WelcomeError::NoMatchingKeyPackage) - | Err(openmls::prelude::WelcomeError::NoMatchingEncryptionKey) => return Err(Error::OrphanWelcome), - _ => group.map_err(MlsError::wrap("group could not be created from welcome"))?, - }; + let group = MlsGroup::new_from_welcome(backend, &mls_group_config, welcome, None) + .await + .map_err(|err| { + use openmls::prelude::WelcomeError; + match err { + WelcomeError::NoMatchingKeyPackage | WelcomeError::NoMatchingEncryptionKey => Error::OrphanWelcome, + _ => MlsError::wrap("group could not be created from welcome")(err).into(), + } + })?; let id = ConversationId::from(group.group_id().as_slice()); let existing_conversation = mls_groups.get_fetch(&id, &backend.keystore(), None).await; let conversation_exists = existing_conversation.ok().flatten().is_some(); - let pending_group = backend.key_store().find::(id.as_ref()).await; + let pending_group = backend + .key_store() + .get_borrowed::(id.as_ref()) + .await; let pending_group_exists = pending_group.ok().flatten().is_some(); if conversation_exists || pending_group_exists { diff --git a/crypto/src/mls/credential/credential_ref/find.rs b/crypto/src/mls/credential/credential_ref/find.rs index 073e53fec3..9c82d32e49 100644 --- a/crypto/src/mls/credential/credential_ref/find.rs +++ b/crypto/src/mls/credential/credential_ref/find.rs @@ -1,7 +1,4 @@ -use core_crypto_keystore::{ - connection::FetchFromDatabase as _, - entities::{EntityFindParams, StoredCredential}, -}; +use core_crypto_keystore::{entities::StoredCredential, traits::FetchFromDatabase as _}; use mls_crypto_provider::Database; use openmls::prelude::Credential as MlsCredential; use tls_codec::Deserialize as _; @@ -67,7 +64,7 @@ impl CredentialRef { } = find_filters; let partial_credentials = database - .find_all::(EntityFindParams::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding all credentials"))? .into_iter() diff --git a/crypto/src/mls/credential/credential_ref/persistence.rs b/crypto/src/mls/credential/credential_ref/persistence.rs index ce06299a5d..bc4a7bfeec 100644 --- a/crypto/src/mls/credential/credential_ref/persistence.rs +++ b/crypto/src/mls/credential/credential_ref/persistence.rs @@ -4,10 +4,7 @@ //! useful to end users. Clients building on the CC API can't do anything useful with a full [`Credential`], //! and it's wasteful to transfer one across the FFI boundary. -use core_crypto_keystore::{ - connection::FetchFromDatabase as _, - entities::{EntityFindParams, StoredCredential}, -}; +use core_crypto_keystore::{Sha256Hash, entities::StoredCredential, traits::FetchFromDatabase as _}; use mls_crypto_provider::Database; use super::{Error, Result}; @@ -20,7 +17,7 @@ impl CredentialRef { /// For loading a single credential, prefer [`Self::load`]. pub(crate) async fn load_stored_credentials(database: &Database) -> Result> { let credentials = database - .find_all::(EntityFindParams::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding all mls credentials"))?; Ok(credentials) @@ -31,7 +28,7 @@ impl CredentialRef { /// Note that this does not attach the credential to any Session; it just does the data manipulation. pub(crate) async fn load(&self, database: &Database) -> Result { database - .find::(self.public_key()) + .get::(&Sha256Hash::hash_from(self.public_key())) .await .map_err(KeystoreError::wrap("finding credential"))? .ok_or(Error::CredentialNotFound) diff --git a/crypto/src/mls/credential/crl.rs b/crypto/src/mls/credential/crl.rs index a1d07c73ac..93e13a0ecc 100644 --- a/crypto/src/mls/credential/crl.rs +++ b/crypto/src/mls/credential/crl.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use core_crypto_keystore::{connection::FetchFromDatabase, entities::E2eiCrl}; +use core_crypto_keystore::{entities::E2eiCrl, traits::FetchFromDatabase}; use mls_crypto_provider::MlsCryptoProvider; use openmls::{ group::MlsGroup, @@ -75,7 +75,7 @@ pub(crate) async fn get_new_crl_distribution_points( let stored_crls = backend .key_store() - .find_all::(Default::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding all e2e crl"))?; let stored_crl_dps: HashSet<&str> = stored_crls.iter().map(|crl| crl.distribution_point.as_str()).collect(); diff --git a/crypto/src/mls/credential/persistence.rs b/crypto/src/mls/credential/persistence.rs index 590f6a4cf0..0b7742b25b 100644 --- a/crypto/src/mls/credential/persistence.rs +++ b/crypto/src/mls/credential/persistence.rs @@ -1,4 +1,4 @@ -use core_crypto_keystore::entities::StoredCredential; +use core_crypto_keystore::{Sha256Hash, entities::StoredCredential}; use mls_crypto_provider::Database; use tls_codec::Serialize as _; @@ -6,14 +6,6 @@ use super::{Error, Result}; use crate::{Credential, CredentialRef, KeystoreError}; impl Credential { - /// Update all the fields that were updated by the DB during the save. - /// - /// [`::pre_save`][core_crypto_keystore::entities::EntityTransactionExt::pre_save]. - fn update_from(&mut self, stored: StoredCredential) { - self.earliest_validity = stored.created_at; - } - /// Persist this credential into the database. /// /// Returns a reference which is stable over time and across the FFI boundary. @@ -26,7 +18,7 @@ impl Credential { .tls_serialize_detached() .map_err(Error::tls_serialize("credential"))?; - let stored_credential = database + self.earliest_validity = database .save(StoredCredential { session_id: self.client_id().to_owned().into_inner(), credential: credential_data, @@ -38,15 +30,13 @@ impl Credential { .await .map_err(KeystoreError::wrap("saving credential"))?; - self.update_from(stored_credential); - Ok(CredentialRef::from_credential(self)) } /// Delete this credential from the database pub(crate) async fn delete(self, database: &Database) -> Result<()> { database - .remove::(self.signature_key_pair.public()) + .remove::(&Sha256Hash::hash_from(self.signature_key_pair.public())) .await .map_err(KeystoreError::wrap("deleting credential"))?; diff --git a/crypto/src/mls/session/key_package.rs b/crypto/src/mls/session/key_package.rs index 1f7b0376a3..38ccf0f901 100644 --- a/crypto/src/mls/session/key_package.rs +++ b/crypto/src/mls/session/key_package.rs @@ -1,8 +1,8 @@ use std::{sync::Arc, time::Duration}; use core_crypto_keystore::{ - connection::FetchFromDatabase, - entities::{EntityFindParams, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage}, + entities::{StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage}, + traits::FetchFromDatabase, }; use openmls::prelude::{CryptoConfig, Lifetime}; @@ -14,10 +14,10 @@ use crate::{ /// Default number of Keypackages a client generates the first time it's created #[cfg(not(test))] -pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 100; +pub const INITIAL_KEYING_MATERIAL_COUNT: u32 = 100; /// Default number of Keypackages a client generates the first time it's created #[cfg(test)] -pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10; +pub const INITIAL_KEYING_MATERIAL_COUNT: u32 = 10; /// Default lifetime of all generated Keypackages. Matches the limit defined in openmls pub const KEYPACKAGE_DEFAULT_LIFETIME: Duration = Duration::from_secs(60 * 60 * 24 * 28 * 3); // ~3 months @@ -88,7 +88,7 @@ impl Session { let stored_keypackages: Vec = self .crypto_provider .keystore() - .find_all(EntityFindParams::default()) + .load_all() .await .map_err(KeystoreError::wrap("finding all keypackages"))?; @@ -116,7 +116,7 @@ impl Session { pub(crate) async fn load_keypackage(&self, kp_ref: &KeypackageRef) -> Result> { self.crypto_provider .keystore() - .find::(kp_ref.hash_ref()) + .get_borrowed::(kp_ref.hash_ref()) .await .map_err(KeystoreError::wrap("loading keypackage from database"))? .map(|stored_keypackage| from_stored(&stored_keypackage)) @@ -135,13 +135,13 @@ impl Session { }; let db = self.crypto_provider.keystore(); - db.remove::(kp_ref.hash_ref()) + db.remove_borrowed::(kp_ref.hash_ref()) .await .map_err(KeystoreError::wrap("removing key package from keystore"))?; - db.remove::(kp.hpke_init_key().as_slice()) + db.remove_borrowed::(kp.hpke_init_key().as_slice()) .await .map_err(KeystoreError::wrap("removing private key from keystore"))?; - db.remove::(kp.leaf_node().encryption_key().as_slice()) + db.remove_borrowed::(kp.leaf_node().encryption_key().as_slice()) .await .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?; @@ -349,4 +349,19 @@ mod tests { }) .await } + + #[apply(all_cred_cipher)] + async fn can_store_and_load_key_packages(case: TestContext) { + let [cc] = case.sessions().await; + + // generate a keypackage; automatically saves it + let kp = cc.new_keypackage(&case).await; + + let all_keypackages = cc.session.get_keypackages().await.unwrap(); + assert_eq!(all_keypackages[0], kp); + + let kp_ref = kp.make_ref().unwrap(); + let by_ref = cc.session.load_keypackage(&kp_ref).await.unwrap().unwrap(); + assert_eq!(kp, by_ref); + } } diff --git a/crypto/src/mls/session/mod.rs b/crypto/src/mls/session/mod.rs index 84724655d8..5184202553 100644 --- a/crypto/src/mls/session/mod.rs +++ b/crypto/src/mls/session/mod.rs @@ -324,7 +324,7 @@ impl Session { #[cfg(test)] mod tests { - use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::*}; + use core_crypto_keystore::{entities::*, traits::FetchFromDatabase as _}; use mls_crypto_provider::MlsCryptoProvider; use super::*; diff --git a/crypto/src/proteus.rs b/crypto/src/proteus.rs index 3851d5a1f1..f4e9164ed6 100644 --- a/crypto/src/proteus.rs +++ b/crypto/src/proteus.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, sync::Arc}; use core_crypto_keystore::{ Database as CryptoKeystore, - connection::FetchFromDatabase, entities::{ProteusIdentity, ProteusSession}, + traits::FetchFromDatabase, }; use proteus_wasm::{ keys::{IdentityKeyPair, PreKeyBundle}, @@ -73,8 +73,11 @@ impl GroupStoreEntity for ProteusConversationSession { identity: Option, keystore: &impl FetchFromDatabase, ) -> crate::Result> { + let id = str::from_utf8(id.as_ref()).map_err(KeystoreError::wrap( + "converting id to string to fetch ProteusConversationSession", + ))?; let result = keystore - .find::(id) + .get_borrowed::(id) .await .map_err(KeystoreError::wrap("finding raw group store entity by id"))?; let Some(store_value) = result else { @@ -198,7 +201,7 @@ impl ProteusCentral { /// errors) async fn load_or_create_identity(keystore: &CryptoKeystore) -> Result { let Some(identity) = keystore - .find::(ProteusIdentity::ID) + .get_unique::() .await .map_err(KeystoreError::wrap("finding proteus identity"))? else { @@ -238,10 +241,9 @@ impl ProteusCentral { ) -> Result> { let mut proteus_sessions = GroupStore::new_with_limit(crate::group_store::ITEM_LIMIT * 2); for session in keystore - .find_all::(Default::default()) + .load_all::() .await .map_err(KeystoreError::wrap("finding all proteus sessions"))? - .into_iter() { let proteus_session = Session::deserialise(identity.clone(), &session.session) .map_err(ProteusError::wrap("deserializing session"))?; @@ -381,7 +383,7 @@ impl ProteusCentral { /// Deletes a session in the store pub(crate) async fn session_delete(&mut self, keystore: &CryptoKeystore, session_id: &str) -> Result<()> { - if keystore.remove::(session_id).await.is_ok() { + if keystore.remove_borrowed::(session_id).await.is_ok() { let _ = self.proteus_sessions.remove(session_id.as_bytes()); } Ok(()) @@ -505,9 +507,7 @@ impl ProteusCentral { /// If it cannot be found, one will be created. pub(crate) async fn last_resort_prekey(&self, keystore: &CryptoKeystore) -> Result> { let last_resort = if let Some(last_resort) = keystore - .find::( - Self::last_resort_prekey_id().to_le_bytes().as_slice(), - ) + .get::(&Self::last_resort_prekey_id()) .await .map_err(KeystoreError::wrap("finding proteus prekey"))? { @@ -807,7 +807,7 @@ mod tests { gap_ids.dedup(); } for gap_id in gap_ids.iter() { - keystore.remove::(gap_id.to_le_bytes()).await.unwrap(); + keystore.remove::(gap_id).await.unwrap(); } gap_ids.sort(); @@ -828,7 +828,7 @@ mod tests { gap_ids.dedup(); } for gap_id in gap_ids.iter() { - keystore.remove::(gap_id.to_le_bytes()).await.unwrap(); + keystore.remove::(gap_id).await.unwrap(); } let potential_range = *ID_TEST_RANGE.end()..=(*ID_TEST_RANGE.end() * 2); diff --git a/crypto/src/test_utils/context.rs b/crypto/src/test_utils/context.rs index cc7d29b003..40a25cd6d4 100644 --- a/crypto/src/test_utils/context.rs +++ b/crypto/src/test_utils/context.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use core_crypto_keystore::{ - connection::FetchFromDatabase, - entities::{EntityFindParams, StoredCredential, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage}, + entities::{StoredCredential, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage}, + traits::FetchFromDatabase, }; use openmls::prelude::{Credential as MlsCredential, ExternalSender, HpkePublicKey, KeyPackage, SignaturePublicKey}; use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme}; @@ -59,7 +59,7 @@ impl SessionContext { .await .unwrap() .key_store() - .find_all::(EntityFindParams::default()) + .load_all::() .await .unwrap() .into_iter() @@ -147,7 +147,7 @@ impl SessionContext { .keystore() .await .unwrap() - .find::(&skp.tls_serialize_detached().unwrap()) + .get::(&skp.tls_serialize_detached().unwrap()) .await .unwrap() } @@ -158,14 +158,14 @@ impl SessionContext { .keystore() .await .unwrap() - .find_all::(EntityFindParams::default()) + .load_all::() .await .unwrap() .into_iter() .find(|c| c.credential[..] == credential) } - pub async fn count_hpke_private_key(&self) -> usize { + pub async fn count_hpke_private_key(&self) -> u32 { self.transaction .keystore() .await @@ -175,7 +175,7 @@ impl SessionContext { .unwrap() } - pub async fn count_encryption_keypairs(&self) -> usize { + pub async fn count_encryption_keypairs(&self) -> u32 { self.transaction .keystore() .await @@ -185,7 +185,7 @@ impl SessionContext { .unwrap() } - pub async fn count_credentials_in_keystore(&self) -> usize { + pub async fn count_credentials_in_keystore(&self) -> u32 { self.transaction .keystore() .await diff --git a/crypto/src/transaction_context/conversation/mod.rs b/crypto/src/transaction_context/conversation/mod.rs index b34355a8e5..862b9ae445 100644 --- a/crypto/src/transaction_context/conversation/mod.rs +++ b/crypto/src/transaction_context/conversation/mod.rs @@ -6,7 +6,7 @@ pub mod external_sender; pub(crate) mod proposal; pub mod welcome; -use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::PersistedMlsPendingGroup}; +use core_crypto_keystore::{entities::PersistedMlsPendingGroup, traits::FetchFromDatabase as _}; use super::{Error, Result, TransactionContext}; use crate::{ @@ -39,7 +39,7 @@ impl TransactionContext { pub(crate) async fn pending_conversation(&self, id: &ConversationIdRef) -> Result { let keystore = self.keystore().await?; let Some(pending_group) = keystore - .find::(id) + .get_borrowed::(id.as_ref()) .await .map_err(KeystoreError::wrap("finding persisted mls pending group"))? else { diff --git a/crypto/src/transaction_context/e2e_identity/error.rs b/crypto/src/transaction_context/e2e_identity/error.rs index fa744ebeb8..894ba44699 100644 --- a/crypto/src/transaction_context/e2e_identity/error.rs +++ b/crypto/src/transaction_context/e2e_identity/error.rs @@ -21,6 +21,8 @@ pub enum Error { PkiEnvironmentUnset, #[error("The certificate chain is invalid or not complete")] InvalidCertificateChain, + #[error("{0} not found")] + NotFound(&'static str), #[error(transparent)] X509Error(#[from] wire_e2e_identity::prelude::x509::RustyX509CheckError), #[error(transparent)] diff --git a/crypto/src/transaction_context/e2e_identity/init_certificates.rs b/crypto/src/transaction_context/e2e_identity/init_certificates.rs index ea62a9a31a..74ae66a5aa 100644 --- a/crypto/src/transaction_context/e2e_identity/init_certificates.rs +++ b/crypto/src/transaction_context/e2e_identity/init_certificates.rs @@ -1,6 +1,6 @@ use core_crypto_keystore::{ - connection::FetchFromDatabase, entities::{E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert}, + traits::FetchFromDatabase, }; use openmls_traits::OpenMlsCryptoProvider; use wire_e2e_identity::prelude::x509::{ @@ -36,18 +36,14 @@ impl TransactionContext { /// # Parameters /// * `trust_anchor_pem` - PEM certificate to anchor as a Trust Root pub async fn e2ei_register_acme_ca(&self, trust_anchor_pem: String) -> Result<()> { - { - if self - .mls_provider() - .await - .map_err(RecursiveError::transaction("getting mls provider"))? - .keystore() - .find_unique::() - .await - .is_ok() - { - return Err(Error::TrustAnchorAlreadyRegistered); - } + let database = self + .mls_provider() + .await + .map_err(RecursiveError::transaction("getting mls provider"))? + .keystore(); + + if matches!(database.get_unique::().await, Ok(Some(_))) { + return Err(Error::TrustAnchorAlreadyRegistered); } let pki_env = PkiEnvironment::init(PkiEnvironmentParams { @@ -66,10 +62,7 @@ impl TransactionContext { // Save DER repr in keystore let cert_der = PkiEnvironment::encode_cert_to_der(&root_cert)?; let acme_ca = E2eiAcmeCA { content: cert_der }; - self.mls_provider() - .await - .map_err(RecursiveError::transaction("getting mls provider"))? - .keystore() + database .save(acme_ca) .await .map_err(KeystoreError::wrap("saving acme ca"))?; @@ -133,9 +126,10 @@ impl TransactionContext { .await .map_err(RecursiveError::transaction("getting keystore"))?; let trust_anchor = keystore - .find_unique::() + .get_unique::() .await - .map_err(KeystoreError::wrap("finding acme ca"))?; + .map_err(KeystoreError::wrap("finding acme ca"))? + .ok_or(Error::NotFound("E2eiAcmeCA"))?; let trust_anchor = x509_cert::Certificate::from_der(&trust_anchor.content)?; // the `/federation` endpoint from smallstep repeats the root CA @@ -212,7 +206,7 @@ impl TransactionContext { .map_err(RecursiveError::transaction("getting keystore"))?; let dirty = ks - .find::(crl_dp.as_bytes()) + .get::(&crl_dp) .await .ok() .flatten() diff --git a/crypto/src/transaction_context/mod.rs b/crypto/src/transaction_context/mod.rs index 4b03df5098..46ba099eeb 100644 --- a/crypto/src/transaction_context/mod.rs +++ b/crypto/src/transaction_context/mod.rs @@ -6,7 +6,7 @@ use std::sync::Arc; #[cfg(feature = "proteus")] use async_lock::Mutex; use async_lock::{RwLock, RwLockReadGuardArc, RwLockWriteGuardArc}; -use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData}; +use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _}; pub use error::{Error, Result}; use mls_crypto_provider::{Database, MlsCryptoProvider}; use openmls_traits::OpenMlsCryptoProvider as _; @@ -292,8 +292,8 @@ impl TransactionContext { /// Get the data that has previously been set by [TransactionContext::set_data]. /// This is meant to be used as a check point at the end of a transaction. pub async fn get_data(&self) -> Result>> { - match self.keystore().await?.find_unique::().await { - Ok(data) => Ok(Some(data.into())), + match self.keystore().await?.get_unique::().await { + Ok(maybe_data) => Ok(maybe_data.map(Into::into)), Err(CryptoKeystoreError::NotFound(..)) => Ok(None), Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()), } diff --git a/crypto/src/transaction_context/test_utils.rs b/crypto/src/transaction_context/test_utils.rs index b22b35fd93..74a88a276c 100644 --- a/crypto/src/transaction_context/test_utils.rs +++ b/crypto/src/transaction_context/test_utils.rs @@ -1,25 +1,25 @@ use core_crypto_keystore::{ - connection::FetchFromDatabase as _, entities::{ MlsPendingMessage, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment, StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, }, + traits::FetchFromDatabase as _, }; use super::TransactionContext; #[derive(Debug, Clone, Eq, PartialEq)] pub struct EntitiesCount { - pub credential: usize, - pub encryption_keypair: usize, - pub epoch_encryption_keypair: usize, - pub enrollment: usize, - pub group: usize, - pub hpke_private_key: usize, - pub key_package: usize, - pub pending_group: usize, - pub pending_messages: usize, - pub psk_bundle: usize, + pub credential: u32, + pub encryption_keypair: u32, + pub epoch_encryption_keypair: u32, + pub enrollment: u32, + pub group: u32, + pub hpke_private_key: u32, + pub key_package: u32, + pub pending_group: u32, + pub pending_messages: u32, + pub psk_bundle: u32, } impl TransactionContext { diff --git a/keystore-dump/src/main.rs b/keystore-dump/src/main.rs index bbb005a62f..38401100b6 100644 --- a/keystore-dump/src/main.rs +++ b/keystore-dump/src/main.rs @@ -23,7 +23,7 @@ async fn main() -> anyhow::Result<()> { use chrono::TimeZone; use clap::Parser as _; use core_crypto_keystore::{ - ConnectionType, Database as Keystore, DatabaseKey, connection::FetchFromDatabase, entities::*, + ConnectionType, Database as Keystore, DatabaseKey, entities::*, traits::FetchFromDatabase, }; use openmls::prelude::TlsDeserializeTrait; use serde::ser::{SerializeMap, Serializer}; @@ -43,11 +43,7 @@ async fn main() -> anyhow::Result<()> { let mut json_map = json_serializer.serialize_map(None)?; let mut credentials: Vec = vec![]; - for cred in keystore - .find_all::(Default::default()) - .await? - .into_iter() - { + for cred in keystore.load_all::().await?.into_iter() { let mls_credential = openmls::prelude::Credential::tls_deserialize(&mut cred.credential.as_slice())?; let mls_keypair = openmls_basic_credential::SignatureKeyPair::from_raw( core_crypto::Ciphersuite::try_from(cred.ciphersuite) @@ -73,7 +69,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_credentials", &credentials)?; let hpke_sks: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|hpke_sk| postcard::from_bytes::(&hpke_sk.sk)) @@ -81,7 +77,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_hpke_private_keys", &hpke_sks)?; let hpke_keypairs: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|hpke_kp| postcard::from_bytes::(&hpke_kp.sk)) @@ -89,11 +85,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_hpke_keypairs", &hpke_keypairs)?; let mut external_psks: std::collections::HashMap = Default::default(); - for psk in keystore - .find_all::(Default::default()) - .await? - .into_iter() - { + for psk in keystore.load_all::().await?.into_iter() { let mls_psk = postcard::from_bytes::(&psk.psk)?; external_psks.insert(hex::encode(&psk.psk_id), mls_psk); } @@ -101,7 +93,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("external_psks", &external_psks)?; let keypackages: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|kp| postcard::from_bytes::(&kp.keypackage)) @@ -109,7 +101,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_keypackages", &keypackages)?; let e2ei_enrollments: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|enrollment| serde_json::from_slice::(&enrollment.content)) @@ -117,7 +109,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("e2ei_enrollments", &e2ei_enrollments)?; let pgroups: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|pgroup| core_crypto_keystore::deser::(&pgroup.state)) @@ -125,14 +117,14 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("mls_groups", &pgroups)?; let pegroups: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|pgroup| core_crypto_keystore::deser::(&pgroup.state)) .collect::>()?; json_map.serialize_entry("mls_pending_groups", &pegroups)?; - if let Some(proteus_identity) = keystore.find::(ProteusIdentity::ID).await? { + if let Some(proteus_identity) = keystore.get_unique::().await? { let identity = { let sk = proteus_identity.sk_raw(); let pk = proteus_identity.pk_raw(); @@ -141,7 +133,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("proteus_identity", &identity)?; let prekeys: Vec = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|pk| proteus_wasm::keys::PreKey::deserialise(&pk.prekey)) @@ -149,7 +141,7 @@ async fn main() -> anyhow::Result<()> { json_map.serialize_entry("proteus_prekeys", &prekeys)?; let proteus_sessions: Vec> = keystore - .find_all::(Default::default()) + .load_all::() .await? .into_iter() .map(|session| proteus_wasm::session::Session::deserialise(identity.clone(), &session.session)) diff --git a/keystore/src/connection/mod.rs b/keystore/src/connection/mod.rs index 24d51d2756..a31fdd6a0f 100644 --- a/keystore/src/connection/mod.rs +++ b/keystore/src/connection/mod.rs @@ -1,5 +1,6 @@ -use std::{fmt, ops::Deref}; +use std::{borrow::Borrow, fmt, ops::Deref}; +use async_trait::async_trait; use sha2::{Digest as _, Sha256}; use zeroize::{Zeroize, ZeroizeOnDrop}; @@ -29,7 +30,11 @@ use async_lock::{Mutex, MutexGuard, Semaphore}; pub use self::platform::*; use crate::{ CryptoKeystoreError, CryptoKeystoreResult, - entities::{Entity, EntityFindParams, EntityTransactionExt, MlsPendingMessage, StringEntityId, UniqueEntity}, + entities::{MlsPendingMessage, PersistedMlsGroupExt}, + traits::{ + BorrowPrimaryKey, Entity, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, FetchFromDatabase, + KeyType, + }, transaction::KeystoreTransaction, }; @@ -143,19 +148,19 @@ const ALLOWED_CONCURRENT_TRANSACTIONS_COUNT: usize = 1; /// transaaction #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -pub trait FetchFromDatabase: Send + Sync { +pub trait OldFetchFromDatabase: Send + Sync { async fn find>( &self, id: impl AsRef<[u8]> + Send, ) -> CryptoKeystoreResult>; - async fn find_unique>( + async fn find_unique>( &self, ) -> CryptoKeystoreResult; async fn find_all>( &self, - params: EntityFindParams, + params: crate::entities::EntityFindParams, ) -> CryptoKeystoreResult>; async fn find_many>( @@ -275,9 +280,7 @@ impl Database { } Ok(()) } -} -impl Database { /// Close this database and delete its contents. pub async fn wipe(&self) -> CryptoKeystoreResult<()> { self.take().await?.wipe().await @@ -322,12 +325,11 @@ impl Database { Ok(()) } - pub async fn child_groups< - E: Entity + crate::entities::PersistedMlsGroupExt + Sync, - >( - &self, - entity: E, - ) -> CryptoKeystoreResult> { + pub async fn child_groups<'a, E>(&self, entity: E) -> CryptoKeystoreResult> + where + E: Clone + Entity + EntityDatabaseMutation<'a> + BorrowPrimaryKey + PersistedMlsGroupExt + Send + Sync, + for<'pk> &'pk ::BorrowedPrimaryKey: KeyType, + { let mut conn = self.conn().await?; let persisted_records = entity.child_groups(conn.deref_mut()).await?; @@ -338,29 +340,37 @@ impl Database { transaction.child_groups(entity, persisted_records).await } - pub async fn save + Sync + EntityTransactionExt>( - &self, - entity: E, - ) -> CryptoKeystoreResult { + pub async fn save<'a, E>(&self, entity: E) -> CryptoKeystoreResult + where + E: Entity + EntityDatabaseMutation<'a> + Send + Sync, + { let transaction_guard = self.transaction.lock().await; let Some(transaction) = transaction_guard.as_ref() else { return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction); }; - transaction.save_mut(entity).await + transaction.save(entity).await } - pub async fn remove< - E: Entity + EntityTransactionExt, - S: AsRef<[u8]>, - >( - &self, - id: S, - ) -> CryptoKeystoreResult<()> { + pub async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()> + where + E: Entity + EntityDatabaseMutation<'a>, + { + let transaction_guard = self.transaction.lock().await; + let Some(transaction) = transaction_guard.as_ref() else { + return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction); + }; + transaction.remove::(id).await + } + + pub async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()> + where + E: Entity + EntityDatabaseMutation<'a> + BorrowPrimaryKey + EntityDeleteBorrowed<'a>, + { let transaction_guard = self.transaction.lock().await; let Some(transaction) = transaction_guard.as_ref() else { return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction); }; - transaction.remove::(id).await + transaction.remove_borrowed::(id).await } pub async fn find_pending_messages_by_conversation_id( @@ -390,81 +400,76 @@ impl Database { }; transaction .remove_pending_messages_by_conversation_id(conversation_id) - .await + .await; + Ok(()) } } -#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] +#[cfg_attr(target_family = "wasm", async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait)] impl FetchFromDatabase for Database { - async fn find>( - &self, - id: impl AsRef<[u8]> + Send, - ) -> CryptoKeystoreResult> { + async fn get(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult> + where + E: Entity + Clone + Send + Sync, + { // If a transaction is in progress... if let Some(transaction) = self.transaction.lock().await.as_ref() //... and it has information about this entity, ... - && let Some(cached_record) = transaction.find::(id.as_ref()).await? + && let Some(cached_record) = transaction.get(id).await { - // ... return that result - return Ok(cached_record); + return Ok(cached_record.map(Arc::unwrap_or_clone)); } // Otherwise get it from the database let mut conn = self.conn().await?; - E::find_one(&mut conn, &id.as_ref().into()).await + E::get(&mut conn, id).await } - async fn find_unique(&self) -> CryptoKeystoreResult { + async fn get_borrowed(&self, id: &::BorrowedPrimaryKey) -> CryptoKeystoreResult> + where + E: EntityGetBorrowed + Clone + Send + Sync, + E::PrimaryKey: Borrow, + for<'a> &'a E::BorrowedPrimaryKey: KeyType, + { // If a transaction is in progress... if let Some(transaction) = self.transaction.lock().await.as_ref() //... and it has information about this entity, ... - && let Some(cached_record) = transaction.find_unique::().await? + && let Some(cached_record) = transaction.get_borrowed(id).await { - // ... return that result - return Ok(cached_record); + return Ok(cached_record.map(Arc::unwrap_or_clone)); } + // Otherwise get it from the database let mut conn = self.conn().await?; - U::find_unique(&mut conn).await + E::get_borrowed(&mut conn, id).await } - async fn find_all>( - &self, - params: EntityFindParams, - ) -> CryptoKeystoreResult> { - let mut conn = self.conn().await?; - let persisted_records = E::find_all(&mut conn, params.clone()).await?; - - let transaction_guard = self.transaction.lock().await; - let Some(transaction) = transaction_guard.as_ref() else { - return Ok(persisted_records); - }; - transaction.find_all(persisted_records, params).await + async fn count(&self) -> CryptoKeystoreResult + where + E: Entity + Clone + Send + Sync, + { + if self.transaction.lock().await.is_some() { + // Unfortunately, we have to do this because of possible record id overlap + // between cache and db. + let count = self.load_all::().await?.len(); + Ok(count as _) + } else { + let mut conn = self.conn().await?; + E::count(&mut conn).await + } } - async fn find_many>( - &self, - ids: &[Vec], - ) -> CryptoKeystoreResult> { - let entity_ids: Vec = ids.iter().map(|id| id.as_slice().into()).collect(); + async fn load_all(&self) -> CryptoKeystoreResult> + where + E: Entity + Clone + Send + Sync, + { let mut conn = self.conn().await?; - let persisted_records = E::find_many(&mut conn, &entity_ids).await?; + let persisted_records = E::load_all(&mut conn).await?; let transaction_guard = self.transaction.lock().await; let Some(transaction) = transaction_guard.as_ref() else { return Ok(persisted_records); }; - transaction.find_many(persisted_records, ids).await - } - - async fn count>(&self) -> CryptoKeystoreResult { - if self.transaction.lock().await.is_some() { - // Unfortunately, we have to do this because of possible record id overlap - // between cache and db. - return Ok(self.find_all::(Default::default()).await?.len()); - }; - let mut conn = self.conn().await?; - E::count(&mut conn).await + transaction.find_all(persisted_records).await } } diff --git a/keystore/src/connection/platform/generic/mod.rs b/keystore/src/connection/platform/generic/mod.rs index dfd1ac1713..8389a82c92 100644 --- a/keystore/src/connection/platform/generic/mod.rs +++ b/keystore/src/connection/platform/generic/mod.rs @@ -324,7 +324,7 @@ mod migration_test { use crate::{ ConnectionType, Database, DatabaseKey, connection::{FetchFromDatabase, MigrationTarget}, - entities::{EntityFindParams, StoredCredential}, + entities::StoredCredential, }; const DB: &[u8] = include_bytes!("../../../../../crypto-ffi/bindings/jvm/src/test/resources/db-v10002003.sqlite"); @@ -438,7 +438,7 @@ mod migration_test { .await .unwrap(); let deduplicated_credentials = db - .find_all::(EntityFindParams::default()) + .load_all::() .await .expect("deduplicated credentials"); diff --git a/keystore/src/connection/platform/wasm/storage/storage.rs b/keystore/src/connection/platform/wasm/storage/storage.rs index b3b3b13bfa..24ad862599 100644 --- a/keystore/src/connection/platform/wasm/storage/storage.rs +++ b/keystore/src/connection/platform/wasm/storage/storage.rs @@ -280,11 +280,12 @@ impl WasmEncryptedStorage { // After putting some thought into it, I'd prefer not to redesign the `Decrypting` trait, though. // There's always the chance that `serde_wasm_bindgen` will relax that restriction, at which point // we can just relax this bound and all will be good. - pub async fn new_get<'a, E>(&self, key: &[u8]) -> CryptoKeystoreResult> + pub async fn new_get<'a, E>(&self, key: impl AsRef<[u8]>) -> CryptoKeystoreResult> where E: NewEntity + Decryptable<'a>, >::DecryptableFrom: DeserializeOwned, { + let key = key.as_ref(); let js_value = match &self.storage { WasmStorageWrapper::Persistent(idb) => { let transaction = idb.transaction(&[E::COLLECTION_NAME], TransactionMode::ReadOnly)?; diff --git a/keystore/src/connection/platform/wasm/storage/transaction.rs b/keystore/src/connection/platform/wasm/storage/transaction.rs index ab748ee380..c57f7931f8 100644 --- a/keystore/src/connection/platform/wasm/storage/transaction.rs +++ b/keystore/src/connection/platform/wasm/storage/transaction.rs @@ -150,10 +150,11 @@ impl WasmStorageTransaction<'_> { /// `BorrowPrimaryKey` or not, and without specialization, we can't just do the right thing /// and accept the more general form. But we do know the primary key and its borrowed form /// both implement `KeyType`, so it's always safe to accept a byte reference. - pub(crate) async fn new_delete(&self, key: &[u8]) -> CryptoKeystoreResult + pub(crate) async fn new_delete(&self, key: impl AsRef<[u8]>) -> CryptoKeystoreResult where E: NewEntity, { + let key = key.as_ref(); match self { WasmStorageTransaction::Persistent { tx, .. } => { let query = JsValue::from(Uint8Array::from(key)); diff --git a/keystore/src/entities/dummy_entity.rs b/keystore/src/entities/dummy_entity.rs index 725f62b292..fa6733f530 100644 --- a/keystore/src/entities/dummy_entity.rs +++ b/keystore/src/entities/dummy_entity.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, collections::HashSet, sync::LazyLock}; +use std::{collections::HashSet, sync::LazyLock}; use async_lock::RwLock; use sha2::{Digest as _, Sha256}; @@ -8,7 +8,7 @@ use crate::{ entities::{Entity, EntityBase, EntityFindParams, StringEntityId}, traits::{ BorrowPrimaryKey, DecryptData as _, Decryptable, Decrypting, EncryptData as _, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, KeyType, UniqueEntity, + EntityBase as NewEntityBase, EntityGetBorrowed, KeyType, PrimaryKey, UniqueEntity, }, }; #[cfg(not(target_family = "wasm"))] @@ -90,15 +90,23 @@ impl Entity for DummyStoreValue { } } -#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -impl NewEntity for DummyStoreValue { +impl PrimaryKey for DummyStoreValue { type PrimaryKey = Vec; - - fn primary_key(&self) -> Vec { + fn primary_key(&self) -> Self::PrimaryKey { Vec::new() } +} + +impl BorrowPrimaryKey for DummyStoreValue { + type BorrowedPrimaryKey = [u8]; + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &[] + } +} +#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] +impl NewEntity for DummyStoreValue { async fn get(_conn: &mut Self::ConnectionType, _key: &Self::PrimaryKey) -> CryptoKeystoreResult> { Ok(None) } @@ -116,24 +124,17 @@ impl NewEntity for DummyStoreValue { #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -impl BorrowPrimaryKey for DummyStoreValue { - type BorrowedPrimaryKey = [u8]; - - fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { - &[] - } - - async fn get_borrowed(_conn: &mut Self::ConnectionType, _key: &Q) -> CryptoKeystoreResult> - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { +impl EntityGetBorrowed for DummyStoreValue { + async fn get_borrowed( + _conn: &mut Self::ConnectionType, + _key: &Self::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> { Ok(None) } } impl UniqueEntity for DummyStoreValue { - const KEY: ::PrimaryKey = Vec::new(); + const KEY: Self::PrimaryKey = Vec::new(); } #[derive(Debug, Clone, PartialEq, Eq)] @@ -183,15 +184,23 @@ impl NewEntityBase for NewDummyStoreValue { } } -#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -impl NewEntity for NewDummyStoreValue { +impl PrimaryKey for NewDummyStoreValue { type PrimaryKey = Vec; - - fn primary_key(&self) -> Vec { + fn primary_key(&self) -> Self::PrimaryKey { self.id.clone() } +} + +impl BorrowPrimaryKey for NewDummyStoreValue { + type BorrowedPrimaryKey = [u8]; + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.id + } +} +#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] +impl NewEntity for NewDummyStoreValue { async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { Self::get_borrowed(conn, key).await } @@ -212,18 +221,11 @@ impl NewEntity for NewDummyStoreValue { #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -impl BorrowPrimaryKey for NewDummyStoreValue { - type BorrowedPrimaryKey = [u8]; - - fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { - &self.id - } - - async fn get_borrowed(_conn: &mut Self::ConnectionType, key: &Q) -> CryptoKeystoreResult> - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { +impl EntityGetBorrowed for NewDummyStoreValue { + async fn get_borrowed( + _conn: &mut Self::ConnectionType, + key: &Self::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> { let guard = NEW_DUMMY_STORE_IDS.read().await; let key = key.bytes(); let key = key.as_ref(); @@ -271,11 +273,7 @@ impl<'a> EntityDatabaseMutation<'a> for NewDummyStoreValue { #[async_trait::async_trait] impl EntityDeleteBorrowed<'_> for NewDummyStoreValue { /// Delete an entity by a borrowed form of its primary key. - async fn delete_borrowed(_tx: &Self::Transaction, id: &Q) -> CryptoKeystoreResult - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { + async fn delete_borrowed(_tx: &Self::Transaction, id: &Self::BorrowedPrimaryKey) -> CryptoKeystoreResult { let mut guard = NEW_DUMMY_STORE_IDS.write().await; let removed = guard.remove::<[u8]>(id.bytes().as_ref()); Ok(removed) diff --git a/keystore/src/entities/mls.rs b/keystore/src/entities/mls.rs index 857118686d..a57d7d553b 100644 --- a/keystore/src/entities/mls.rs +++ b/keystore/src/entities/mls.rs @@ -1,7 +1,10 @@ -use zeroize::Zeroize; +use zeroize::{Zeroize, ZeroizeOnDrop}; -use super::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StringEntityId}; -use crate::{CryptoKeystoreError, CryptoKeystoreResult, connection::TransactionWrapper}; +use crate::{ + CryptoKeystoreError, CryptoKeystoreResult, + connection::TransactionWrapper, + traits::{BorrowPrimaryKey, Entity, EntityBase, KeyType, OwnedKeyType, PrimaryKey}, +}; /// Entity representing a persisted `MlsGroup` #[derive( @@ -28,27 +31,33 @@ pub struct PersistedMlsGroup { #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -pub trait PersistedMlsGroupExt: Entity { +pub trait PersistedMlsGroupExt: Entity + BorrowPrimaryKey +where + for<'a> &'a ::BorrowedPrimaryKey: KeyType, +{ fn parent_id(&self) -> Option<&[u8]>; - async fn parent_group( - &self, - conn: &mut ::ConnectionType, - ) -> CryptoKeystoreResult> { + async fn parent_group(&self, conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { let Some(parent_id) = self.parent_id() else { return Ok(None); }; - ::find_one(conn, &parent_id.into()).await + let parent_id = OwnedKeyType::from_bytes(parent_id) + .ok_or(CryptoKeystoreError::InvalidPrimaryKeyBytes(Self::COLLECTION_NAME))?; + Self::get(conn, &parent_id).await } - async fn child_groups( - &self, - conn: &mut ::ConnectionType, - ) -> CryptoKeystoreResult> { - let entities = ::find_all(conn, super::EntityFindParams::default()).await?; + async fn child_groups(&self, conn: &mut ::ConnectionType) -> CryptoKeystoreResult> { + // A perfect opportunity for refactoring in WPB-20844 + // when we do that, we no longer need varying implementations according to wasm or not, + // so both `parent_group` and this method should just be implemented directly on `PersistedMlsGroup`. + let entities = Self::load_all(conn).await?; - let id = self.id_raw(); + // for whatever reason rustc needs each of these distinct bindings to prove to itself that the lifetimes work + // out + let id = self.borrow_primary_key(); + let id = id.bytes(); + let id = id.as_ref(); Ok(entities .into_iter() @@ -70,6 +79,77 @@ pub struct PersistedMlsPendingGroup { pub custom_configuration: Vec, } +/// [`MlsPendingMessage`]s have no distinct primary key; +/// they must always be accessed via [`MlsPendingMessage::find_all_by_conversation_id`] and +/// cleaned up with [`MlsPendingMessage::delete_by_conversation_id`] +/// +/// However, we have to fake a primary key type in order to support +/// `KeystoreTransaction::remove_pending_messages_by_conversation_id`. Additionally we need the same one in WASM, where +/// it's necessary for item-level encryption. +/// +/// This implementation is fairly inefficient and hopefully temporary. But it at least implements the correct semantics. +#[derive(ZeroizeOnDrop)] +pub struct MlsPendingMessagePrimaryKey { + pub(crate) foreign_id: Vec, + message: Vec, +} + +impl MlsPendingMessagePrimaryKey { + /// Construct a partial mls pending message primary key from only the conversation id. + /// + /// This does not in fact uniquely identify a single pending message--it should always uniquely + /// identify exactly 0 pending messages--but we have to have it so that we can search and delete + /// by conversation id within transactions. + pub(crate) fn from_conversation_id(conversation_id: impl AsRef<[u8]>) -> Self { + Self { + foreign_id: conversation_id.as_ref().to_owned(), + message: Vec::new(), + } + } +} + +impl From<&MlsPendingMessage> for MlsPendingMessagePrimaryKey { + fn from(value: &MlsPendingMessage) -> Self { + Self { + foreign_id: value.foreign_id.clone(), + message: value.message.clone(), + } + } +} + +impl KeyType for MlsPendingMessagePrimaryKey { + fn bytes(&self) -> std::borrow::Cow<'_, [u8]> { + // run-length encoding: 32 bits of size for each field, followed by the field + let fields = [&self.foreign_id, &self.message]; + let mut key = Vec::with_capacity( + ((u32::BITS / u8::BITS) as usize * fields.len()) + self.foreign_id.len() + self.message.len(), + ); + for field in fields { + key.extend((field.len() as u32).to_le_bytes()); + key.extend(field.as_slice()); + } + key.into() + } +} + +impl OwnedKeyType for MlsPendingMessagePrimaryKey { + fn from_bytes(bytes: &[u8]) -> Option { + // run-length decoding: 32 bits of size for each field, followed by the field + let (len, bytes) = bytes.split_at_checked(4)?; + let len = u32::from_le_bytes(len.try_into().ok()?); + let (foreign_id, bytes) = bytes.split_at_checked(len as _)?; + + let (len, bytes) = bytes.split_at_checked(4)?; + let len = u32::from_le_bytes(len.try_into().ok()?); + let (message, bytes) = bytes.split_at_checked(len as _)?; + + bytes.is_empty().then(|| Self { + foreign_id: foreign_id.to_owned(), + message: message.to_owned(), + }) + } +} + /// Entity representing a buffered message #[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] @@ -79,6 +159,13 @@ pub struct MlsPendingMessage { pub message: Vec, } +impl PrimaryKey for MlsPendingMessage { + type PrimaryKey = MlsPendingMessagePrimaryKey; + fn primary_key(&self) -> Self::PrimaryKey { + self.into() + } +} + /// Entity representing a buffered commit. /// /// There should always exist either 0 or 1 of these in the store per conversation. @@ -239,7 +326,7 @@ pub struct StoredE2eiEnrollment { #[cfg(target_family = "wasm")] #[async_trait::async_trait(?Send)] pub trait UniqueEntity: - EntityBase + crate::entities::EntityBase + serde::Serialize + serde::de::DeserializeOwned where @@ -259,7 +346,7 @@ where .ok_or(CryptoKeystoreError::NotFound(Self::COLLECTION_NAME, "".to_string()))?) } - async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult> { + async fn find_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { match Self::find_unique(conn).await { Ok(record) => Ok(vec![record]), Err(CryptoKeystoreError::NotFound(..)) => Ok(vec![]), @@ -312,7 +399,7 @@ pub trait UniqueEntity: EntityBase CryptoKeystoreResult> { + async fn find_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { match Self::find_unique(conn).await { Ok(record) => Ok(vec![record]), Err(CryptoKeystoreError::NotFound(..)) => Ok(vec![]), @@ -366,7 +453,13 @@ pub trait UniqueEntity: EntityBase EntityTransactionExt for T { +impl crate::entities::EntityTransactionExt for T +where + T: crate::entities::Entity + + UniqueEntity + + Send + + Sync, +{ #[cfg(not(target_family = "wasm"))] async fn save(&self, tx: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> { self.replace(tx).await @@ -380,7 +473,7 @@ impl EntityTransactionExt for T { #[cfg(not(target_family = "wasm"))] async fn delete_fail_on_missing_id( _: &TransactionWrapper<'_>, - _id: StringEntityId<'_>, + _id: crate::entities::StringEntityId<'_>, ) -> CryptoKeystoreResult<()> { Err(CryptoKeystoreError::NotImplemented) } @@ -388,7 +481,7 @@ impl EntityTransactionExt for T { #[cfg(target_family = "wasm")] async fn delete_fail_on_missing_id<'a>( _: &TransactionWrapper<'a>, - _id: StringEntityId<'a>, + _id: crate::entities::StringEntityId<'a>, ) -> CryptoKeystoreResult<()> { Err(CryptoKeystoreError::NotImplemented) } diff --git a/keystore/src/entities/mod.rs b/keystore/src/entities/mod.rs index fe82483be4..39d0795576 100644 --- a/keystore/src/entities/mod.rs +++ b/keystore/src/entities/mod.rs @@ -245,13 +245,14 @@ cfg_if::cfg_if! { } #[async_trait::async_trait(?Send)] - impl Entity for T { + impl Entity for T + where T : UniqueEntity + crate::entities::EntityBase { fn id_raw(&self) -> &[u8] { &Self::ID } - async fn find_all(conn: &mut Self::ConnectionType, params: EntityFindParams) -> CryptoKeystoreResult> { - ::find_all(conn, params).await + async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult> { + ::find_all(conn).await } async fn find_one(conn: &mut Self::ConnectionType, _id: &StringEntityId) -> CryptoKeystoreResult> { @@ -324,13 +325,15 @@ cfg_if::cfg_if! { } #[async_trait::async_trait] - impl Entity for T { + impl Entity for T + where T : UniqueEntity + crate::entities::EntityBase + { fn id_raw(&self) -> &[u8] { &[Self::ID as u8] } - async fn find_all(conn: &mut Self::ConnectionType, params: EntityFindParams) -> CryptoKeystoreResult> { - ::find_all(conn, params).await + async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult> { + ::find_all(conn).await } async fn find_one(conn: &mut Self::ConnectionType, _id: &StringEntityId) -> CryptoKeystoreResult> { diff --git a/keystore/src/entities/platform/generic/general.rs b/keystore/src/entities/platform/generic/general.rs index 0e578ac28b..3f477a37f0 100644 --- a/keystore/src/entities/platform/generic/general.rs +++ b/keystore/src/entities/platform/generic/general.rs @@ -36,7 +36,7 @@ impl EntityBase for ConsumerData { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ConsumerData(self) + crate::transaction::dynamic_dispatch::Entity::ConsumerData(self.into()) } } @@ -46,7 +46,7 @@ impl NewEntityBase for ConsumerData { const COLLECTION_NAME: &'static str = "consumer_data"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ConsumerData(self) + crate::transaction::dynamic_dispatch::Entity::ConsumerData(self.into()) } } diff --git a/keystore/src/entities/platform/generic/mls/credential.rs b/keystore/src/entities/platform/generic/mls/credential.rs index df2fac4422..84a2a79bf5 100644 --- a/keystore/src/entities/platform/generic/mls/credential.rs +++ b/keystore/src/entities/platform/generic/mls/credential.rs @@ -16,7 +16,7 @@ use crate::{ }, traits::{ BorrowPrimaryKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, - EntityDeleteBorrowed, KeyType, + EntityDeleteBorrowed, KeyType, PrimaryKey, }, }; @@ -169,7 +169,7 @@ impl EntityBase for StoredCredential { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::StoredCredential(self) + crate::transaction::dynamic_dispatch::Entity::StoredCredential(self.into()) } } @@ -292,18 +292,20 @@ impl NewEntityBase for StoredCredential { const COLLECTION_NAME: &'static str = "mls_credentials"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::StoredCredential(self) + crate::transaction::dynamic_dispatch::Entity::StoredCredential(self.into()) } } -#[async_trait] -impl NewEntity for StoredCredential { +impl PrimaryKey for StoredCredential { type PrimaryKey = Sha256Hash; fn primary_key(&self) -> Self::PrimaryKey { Sha256Hash::hash_from(&self.public_key) } +} +#[async_trait] +impl NewEntity for StoredCredential { async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { let conn = conn.conn().await; let mut stmt = conn.prepare_cached( @@ -387,7 +389,7 @@ impl<'a> EntityDatabaseMutation<'a> for StoredCredential { count_helper_tx::(tx).await } - async fn delete(tx: &Self::Transaction, id: &::PrimaryKey) -> CryptoKeystoreResult { + async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult { delete_helper::(tx, "public_key_sha256", id).await } } diff --git a/keystore/src/entities/platform/generic/mls/e2ei_acme_ca.rs b/keystore/src/entities/platform/generic/mls/e2ei_acme_ca.rs index 7bbad6ba7a..f8be2a53cb 100644 --- a/keystore/src/entities/platform/generic/mls/e2ei_acme_ca.rs +++ b/keystore/src/entities/platform/generic/mls/e2ei_acme_ca.rs @@ -34,7 +34,7 @@ impl EntityBase for E2eiAcmeCA { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self) + crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self.into()) } } @@ -44,7 +44,7 @@ impl NewEntityBase for E2eiAcmeCA { const COLLECTION_NAME: &'static str = "e2ei_acme_ca"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self) + crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self.into()) } } diff --git a/keystore/src/entities/platform/generic/mls/encryption_keypair.rs b/keystore/src/entities/platform/generic/mls/encryption_keypair.rs index 5ae436c978..bf5fb7f004 100644 --- a/keystore/src/entities/platform/generic/mls/encryption_keypair.rs +++ b/keystore/src/entities/platform/generic/mls/encryption_keypair.rs @@ -15,7 +15,7 @@ use crate::{ }, traits::{ BorrowPrimaryKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, - EntityDeleteBorrowed, KeyType, + EntityDeleteBorrowed, EntityGetBorrowed, KeyType, PrimaryKey, }, }; @@ -111,7 +111,7 @@ impl EntityBase for StoredEncryptionKeyPair { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self) + crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self.into()) } } @@ -180,20 +180,30 @@ impl NewEntityBase for StoredEncryptionKeyPair { const COLLECTION_NAME: &'static str = "mls_encryption_keypairs"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self) + crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self.into()) } } -#[async_trait] -impl NewEntity for StoredEncryptionKeyPair { - type PrimaryKey = Sha256Hash; +impl PrimaryKey for StoredEncryptionKeyPair { + type PrimaryKey = Vec; - fn primary_key(&self) -> Sha256Hash { - Sha256Hash::hash_from(&self.pk) + fn primary_key(&self) -> Vec { + self.pk.clone() } +} - async fn get(conn: &mut Self::ConnectionType, id: &Sha256Hash) -> CryptoKeystoreResult> { - get_helper::(conn, "pk_sha256", id, Self::from_row).await +impl BorrowPrimaryKey for StoredEncryptionKeyPair { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.pk + } +} + +#[async_trait] +impl NewEntity for StoredEncryptionKeyPair { + async fn get(conn: &mut Self::ConnectionType, id: &Vec) -> CryptoKeystoreResult> { + Self::get_borrowed(conn, id.as_slice()).await } async fn load_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { @@ -205,6 +215,13 @@ impl NewEntity for StoredEncryptionKeyPair { } } +#[async_trait] +impl EntityGetBorrowed for StoredEncryptionKeyPair { + async fn get_borrowed(conn: &mut Self::ConnectionType, id: &[u8]) -> CryptoKeystoreResult> { + get_helper::(conn, "pk", id, Self::from_row).await + } +} + #[async_trait] impl<'a> EntityDatabaseMutation<'a> for StoredEncryptionKeyPair { type Transaction = TransactionWrapper<'a>; @@ -220,7 +237,14 @@ impl<'a> EntityDatabaseMutation<'a> for StoredEncryptionKeyPair { count_helper_tx::(tx).await } - async fn delete(tx: &Self::Transaction, id: &Sha256Hash) -> CryptoKeystoreResult { - delete_helper::(tx, "pk_sha256", id).await + async fn delete(tx: &Self::Transaction, id: &Vec) -> CryptoKeystoreResult { + Self::delete_borrowed(tx, id.as_slice()).await + } +} + +#[async_trait] +impl<'a> EntityDeleteBorrowed<'a> for StoredEncryptionKeyPair { + async fn delete_borrowed(tx: &Self::Transaction, id: &[u8]) -> CryptoKeystoreResult { + delete_helper::(tx, "pk", id).await } } diff --git a/keystore/src/entities/platform/generic/mls/hpke_private_key.rs b/keystore/src/entities/platform/generic/mls/hpke_private_key.rs index 2c6fd3dd9d..7c6f06e7f8 100644 --- a/keystore/src/entities/platform/generic/mls/hpke_private_key.rs +++ b/keystore/src/entities/platform/generic/mls/hpke_private_key.rs @@ -15,7 +15,7 @@ use crate::{ }, traits::{ BorrowPrimaryKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, - EntityDeleteBorrowed, KeyType, + EntityDeleteBorrowed, EntityGetBorrowed, KeyType, PrimaryKey, }, }; @@ -111,7 +111,7 @@ impl EntityBase for StoredHpkePrivateKey { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self) + crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self.into()) } } @@ -178,20 +178,30 @@ impl NewEntityBase for StoredHpkePrivateKey { const COLLECTION_NAME: &'static str = "mls_hpke_private_keys"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self) + crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self.into()) } } -#[async_trait] -impl NewEntity for StoredHpkePrivateKey { - type PrimaryKey = Sha256Hash; +impl PrimaryKey for StoredHpkePrivateKey { + type PrimaryKey = Vec; - fn primary_key(&self) -> Sha256Hash { - Sha256Hash::hash_from(&self.pk) + fn primary_key(&self) -> Vec { + self.pk.clone() } +} - async fn get(conn: &mut Self::ConnectionType, id: &Sha256Hash) -> CryptoKeystoreResult> { - get_helper::(conn, "pk_sha256", id, Self::from_row).await +impl BorrowPrimaryKey for StoredHpkePrivateKey { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.pk + } +} + +#[async_trait] +impl NewEntity for StoredHpkePrivateKey { + async fn get(conn: &mut Self::ConnectionType, id: &Vec) -> CryptoKeystoreResult> { + Self::get_borrowed(conn, id.as_slice()).await } async fn load_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { @@ -203,6 +213,13 @@ impl NewEntity for StoredHpkePrivateKey { } } +#[async_trait] +impl EntityGetBorrowed for StoredHpkePrivateKey { + async fn get_borrowed(conn: &mut Self::ConnectionType, id: &[u8]) -> CryptoKeystoreResult> { + get_helper::(conn, "pk", id, Self::from_row).await + } +} + #[async_trait] impl<'a> EntityDatabaseMutation<'a> for StoredHpkePrivateKey { type Transaction = TransactionWrapper<'a>; @@ -218,7 +235,14 @@ impl<'a> EntityDatabaseMutation<'a> for StoredHpkePrivateKey { count_helper_tx::(tx).await } - async fn delete(tx: &Self::Transaction, id: &Sha256Hash) -> CryptoKeystoreResult { - delete_helper::(tx, "pk_sha256", id).await + async fn delete(tx: &Self::Transaction, id: &Vec) -> CryptoKeystoreResult { + Self::delete_borrowed(tx, id.as_slice()).await + } +} + +#[async_trait] +impl<'a> EntityDeleteBorrowed<'a> for StoredHpkePrivateKey { + async fn delete_borrowed(tx: &Self::Transaction, id: &[u8]) -> CryptoKeystoreResult { + delete_helper::(tx, "pk", id).await } } diff --git a/keystore/src/entities/platform/generic/mls/pending_group.rs b/keystore/src/entities/platform/generic/mls/pending_group.rs index a1ddce80c5..5b13c87a66 100644 --- a/keystore/src/entities/platform/generic/mls/pending_group.rs +++ b/keystore/src/entities/platform/generic/mls/pending_group.rs @@ -12,7 +12,7 @@ use crate::{ }, traits::{ BorrowPrimaryKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, - EntityDeleteBorrowed, KeyType, + EntityDeleteBorrowed, EntityGetBorrowed, KeyType, PrimaryKey, }, }; @@ -206,7 +206,7 @@ impl EntityBase for PersistedMlsPendingGroup { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self) + crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self.into()) } } @@ -315,18 +315,28 @@ impl NewEntityBase for PersistedMlsPendingGroup { const COLLECTION_NAME: &'static str = "mls_pending_groups"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self) + crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self.into()) } } -#[async_trait] -impl NewEntity for PersistedMlsPendingGroup { +impl PrimaryKey for PersistedMlsPendingGroup { type PrimaryKey = Vec; fn primary_key(&self) -> Self::PrimaryKey { self.id.clone() } +} + +impl BorrowPrimaryKey for PersistedMlsPendingGroup { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &[u8] { + &self.id + } +} +#[async_trait] +impl NewEntity for PersistedMlsPendingGroup { async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { Self::get_borrowed(conn, key).await } @@ -341,18 +351,11 @@ impl NewEntity for PersistedMlsPendingGroup { } #[async_trait] -impl BorrowPrimaryKey for PersistedMlsPendingGroup { - type BorrowedPrimaryKey = [u8]; - - fn borrow_primary_key(&self) -> &[u8] { - &self.id - } - - async fn get_borrowed(conn: &mut Self::ConnectionType, key: &Q) -> CryptoKeystoreResult> - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { +impl EntityGetBorrowed for PersistedMlsPendingGroup { + async fn get_borrowed( + conn: &mut Self::ConnectionType, + key: &Self::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> { get_helper::(conn, "id", key.bytes().as_ref(), Self::from_row).await } } @@ -380,11 +383,7 @@ impl<'a> EntityDatabaseMutation<'a> for PersistedMlsPendingGroup { #[async_trait] impl<'a> EntityDeleteBorrowed<'a> for PersistedMlsPendingGroup { - async fn delete_borrowed(tx: &Self::Transaction, id: &Q) -> CryptoKeystoreResult - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { + async fn delete_borrowed(tx: &Self::Transaction, id: &Self::BorrowedPrimaryKey) -> CryptoKeystoreResult { delete_helper::(tx, "id", id.bytes().as_ref()).await } } diff --git a/keystore/src/entities/platform/generic/mls/pending_message.rs b/keystore/src/entities/platform/generic/mls/pending_message.rs index ed832f27e4..900b939e47 100644 --- a/keystore/src/entities/platform/generic/mls/pending_message.rs +++ b/keystore/src/entities/platform/generic/mls/pending_message.rs @@ -12,7 +12,7 @@ use crate::{ }, traits::{ BorrowPrimaryKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, - EntityDeleteBorrowed, KeyType, + EntityDeleteBorrowed, EntityGetBorrowed, KeyType, OwnedKeyType, PrimaryKey, }, }; @@ -158,7 +158,7 @@ impl EntityBase for MlsPendingMessage { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self) + crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self.into()) } } @@ -212,18 +212,12 @@ impl NewEntityBase for MlsPendingMessage { const COLLECTION_NAME: &'static str = "mls_pending_messages"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self) + crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self.into()) } } #[async_trait] -/// Pending messages have no distinct primary key; -/// they must always be accessed via [`MlsPendingMessage::find_all_by_conversation_id`] and -/// cleaned up with [`MlsPendingMessage::delete_by_conversation_id`] impl NewEntity for MlsPendingMessage { - type PrimaryKey = (); - fn primary_key(&self) -> Self::PrimaryKey {} - async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { panic!("cannot get `MlsPendingMessage` by primary key as it has no distinct primary key") } diff --git a/keystore/src/entities/platform/generic/mls/psk_bundle.rs b/keystore/src/entities/platform/generic/mls/psk_bundle.rs index 035ef8698a..23615394bc 100644 --- a/keystore/src/entities/platform/generic/mls/psk_bundle.rs +++ b/keystore/src/entities/platform/generic/mls/psk_bundle.rs @@ -10,7 +10,10 @@ use crate::{ Entity, EntityBase, EntityFindParams, EntityIdStringExt, EntityTransactionExt, StoredPskBundle, StringEntityId, count_helper, count_helper_tx, delete_helper, get_helper, load_all_helper, }, - traits::{Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation}, + traits::{ + BorrowPrimaryKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, + EntityDeleteBorrowed, EntityGetBorrowed, PrimaryKey, + }, }; #[async_trait::async_trait] @@ -101,7 +104,7 @@ impl EntityBase for StoredPskBundle { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PskBundle(self) + crate::transaction::dynamic_dispatch::Entity::PskBundle(self.into()) } } @@ -167,20 +170,30 @@ impl NewEntityBase for StoredPskBundle { const COLLECTION_NAME: &'static str = "mls_psk_bundles"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PskBundle(self) + crate::transaction::dynamic_dispatch::Entity::PskBundle(self.into()) } } -#[async_trait] -impl NewEntity for StoredPskBundle { - type PrimaryKey = Sha256Hash; +impl PrimaryKey for StoredPskBundle { + type PrimaryKey = Vec; fn primary_key(&self) -> Self::PrimaryKey { - Sha256Hash::hash_from(&self.psk_id) + self.psk_id.clone() + } +} + +impl BorrowPrimaryKey for StoredPskBundle { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.psk_id } +} +#[async_trait] +impl NewEntity for StoredPskBundle { async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { - get_helper::(conn, "id_sha256", key, Self::from_row).await + Self::get_borrowed(conn, key.as_slice()).await } async fn count(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult { @@ -192,6 +205,16 @@ impl NewEntity for StoredPskBundle { } } +#[async_trait] +impl EntityGetBorrowed for StoredPskBundle { + async fn get_borrowed( + conn: &mut Self::ConnectionType, + key: &Self::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> { + get_helper::(conn, "id", key, Self::from_row).await + } +} + #[async_trait] impl<'a> EntityDatabaseMutation<'a> for StoredPskBundle { type Transaction = TransactionWrapper<'a>; @@ -208,6 +231,13 @@ impl<'a> EntityDatabaseMutation<'a> for StoredPskBundle { } async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult { - delete_helper::(tx, "id_sha256", id).await + Self::delete_borrowed(tx, id.as_slice()).await + } +} + +#[async_trait] +impl<'a> EntityDeleteBorrowed<'a> for StoredPskBundle { + async fn delete_borrowed(tx: &Self::Transaction, id: &Self::BorrowedPrimaryKey) -> CryptoKeystoreResult { + delete_helper::(tx, "id", id).await } } diff --git a/keystore/src/entities/platform/generic/proteus/identity.rs b/keystore/src/entities/platform/generic/proteus/identity.rs index fc25bd0737..82b236186d 100644 --- a/keystore/src/entities/platform/generic/proteus/identity.rs +++ b/keystore/src/entities/platform/generic/proteus/identity.rs @@ -8,7 +8,7 @@ use crate::{ Entity, EntityBase, EntityFindParams, EntityTransactionExt, ProteusIdentity, StringEntityId, count_helper, count_helper_tx, load_all_helper, }, - traits::{Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, UniqueEntity}, + traits::{Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, PrimaryKey, UniqueEntity}, }; #[async_trait::async_trait] @@ -98,7 +98,7 @@ impl EntityBase for ProteusIdentity { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self) + crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self.into()) } } @@ -158,7 +158,7 @@ impl NewEntityBase for ProteusIdentity { const COLLECTION_NAME: &'static str = "proteus_identities"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self) + crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self.into()) } } @@ -166,14 +166,16 @@ impl UniqueEntity for ProteusIdentity { const KEY: () = (); } -#[async_trait] -impl NewEntity for ProteusIdentity { +impl PrimaryKey for ProteusIdentity { type PrimaryKey = (); fn primary_key(&self) -> Self::PrimaryKey {} +} +#[async_trait] +impl NewEntity for ProteusIdentity { async fn get(conn: &mut Self::ConnectionType, _key: &()) -> CryptoKeystoreResult> { let conn = conn.conn().await; - let mut stmt = conn.prepare_cached("SELECT rowid FROM proteus_identities ORDER BY rowid ASC LIMIT 1")?; + let mut stmt = conn.prepare_cached("SELECT sk, pk FROM proteus_identities ORDER BY rowid ASC LIMIT 1")?; stmt.query_one([], Self::from_row).optional().map_err(Into::into) } diff --git a/keystore/src/entities/platform/generic/proteus/prekey.rs b/keystore/src/entities/platform/generic/proteus/prekey.rs index 7cfacc5848..7d2f0ccc5e 100644 --- a/keystore/src/entities/platform/generic/proteus/prekey.rs +++ b/keystore/src/entities/platform/generic/proteus/prekey.rs @@ -8,7 +8,7 @@ use crate::{ Entity, EntityBase, EntityFindParams, EntityTransactionExt, ProteusPrekey, StringEntityId, count_helper, count_helper_tx, delete_helper, get_helper, load_all_helper, }, - traits::{Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation}, + traits::{Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, PrimaryKey}, }; #[async_trait::async_trait] @@ -91,7 +91,7 @@ impl EntityBase for ProteusPrekey { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self) + crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self.into()) } } @@ -151,18 +151,20 @@ impl NewEntityBase for ProteusPrekey { const COLLECTION_NAME: &'static str = "proteus_prekeys"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self) + crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self.into()) } } -#[async_trait] -impl NewEntity for ProteusPrekey { +impl PrimaryKey for ProteusPrekey { type PrimaryKey = u16; fn primary_key(&self) -> u16 { self.id } +} +#[async_trait] +impl NewEntity for ProteusPrekey { async fn get(conn: &mut Self::ConnectionType, key: &u16) -> CryptoKeystoreResult> { get_helper::(conn, "id", *key, Self::from_row).await } diff --git a/keystore/src/entities/platform/wasm/general.rs b/keystore/src/entities/platform/wasm/general.rs index 23a6331ed2..250c290348 100644 --- a/keystore/src/entities/platform/wasm/general.rs +++ b/keystore/src/entities/platform/wasm/general.rs @@ -1,8 +1,11 @@ use crate::{ - MissingKeyErrorKind, + CryptoKeystoreResult, MissingKeyErrorKind, connection::KeystoreDatabaseConnection, entities::{ConsumerData, EntityBase, UniqueEntity}, - traits::{EntityBase as NewEntityBase, UniqueEntityImplementationHelper}, + traits::{ + DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, EntityBase as NewEntityBase, UniqueEntity as _, + UniqueEntityImplementationHelper, + }, }; #[async_trait::async_trait(?Send)] @@ -16,7 +19,7 @@ impl EntityBase for ConsumerData { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ConsumerData(self) + crate::transaction::dynamic_dispatch::Entity::ConsumerData(self.into()) } } @@ -37,7 +40,7 @@ impl NewEntityBase for ConsumerData { const COLLECTION_NAME: &'static str = "consumer_data"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ConsumerData(self) + crate::transaction::dynamic_dispatch::Entity::ConsumerData(self.into()) } } @@ -50,3 +53,30 @@ impl UniqueEntityImplementationHelper for ConsumerData { &self.content } } + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct ConsumerDataEncrypted { + content: Vec, +} + +impl<'a> Encrypting<'a> for ConsumerData { + type EncryptedForm = ConsumerDataEncrypted; + + fn encrypt(&'a self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { + let content = ::encrypt_data(self, cipher, &self.content)?; + Ok(ConsumerDataEncrypted { content }) + } +} + +impl Decrypting<'static> for ConsumerDataEncrypted { + type DecryptedForm = ConsumerData; + + fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { + let content = ::decrypt_data(cipher, &ConsumerData::KEY, &self.content)?; + Ok(ConsumerData { content }) + } +} + +impl Decryptable<'static> for ConsumerData { + type DecryptableFrom = ConsumerDataEncrypted; +} diff --git a/keystore/src/entities/platform/wasm/mls/credential.rs b/keystore/src/entities/platform/wasm/mls/credential.rs index 361d24f9ae..92f16afdf8 100644 --- a/keystore/src/entities/platform/wasm/mls/credential.rs +++ b/keystore/src/entities/platform/wasm/mls/credential.rs @@ -7,7 +7,7 @@ use crate::{ entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredCredential, StringEntityId}, traits::{ DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType as _, + EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType as _, PrimaryKey, }, }; @@ -22,7 +22,7 @@ impl EntityBase for StoredCredential { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::StoredCredential(self) + crate::transaction::dynamic_dispatch::Entity::StoredCredential(self.into()) } } @@ -91,18 +91,20 @@ impl NewEntityBase for StoredCredential { const COLLECTION_NAME: &'static str = "mls_credentials"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::StoredCredential(self) + crate::transaction::dynamic_dispatch::Entity::StoredCredential(self.into()) } } -#[async_trait(?Send)] -impl NewEntity for StoredCredential { +impl PrimaryKey for StoredCredential { type PrimaryKey = Sha256Hash; fn primary_key(&self) -> Self::PrimaryKey { Sha256Hash::hash_from(&self.public_key) } +} +#[async_trait(?Send)] +impl NewEntity for StoredCredential { async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { conn.storage().new_get(key.bytes().as_ref()).await } @@ -138,7 +140,7 @@ impl<'a> EntityDatabaseMutation<'a> for StoredCredential { tx.new_count::().await } - async fn delete(tx: &Self::Transaction, id: &::PrimaryKey) -> CryptoKeystoreResult { + async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult { tx.new_delete::(id.bytes().as_ref()).await } } diff --git a/keystore/src/entities/platform/wasm/mls/e2ei_acme_ca.rs b/keystore/src/entities/platform/wasm/mls/e2ei_acme_ca.rs index de95080983..fd34e2d4b3 100644 --- a/keystore/src/entities/platform/wasm/mls/e2ei_acme_ca.rs +++ b/keystore/src/entities/platform/wasm/mls/e2ei_acme_ca.rs @@ -1,8 +1,11 @@ use crate::{ - MissingKeyErrorKind, + CryptoKeystoreResult, MissingKeyErrorKind, connection::KeystoreDatabaseConnection, entities::{E2eiAcmeCA, EntityBase, UniqueEntity}, - traits::{EntityBase as NewEntityBase, UniqueEntityImplementationHelper}, + traits::{ + DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, EntityBase as NewEntityBase, UniqueEntity as _, + UniqueEntityImplementationHelper, + }, }; impl EntityBase for E2eiAcmeCA { @@ -15,7 +18,7 @@ impl EntityBase for E2eiAcmeCA { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self) + crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self.into()) } } @@ -35,7 +38,7 @@ impl NewEntityBase for E2eiAcmeCA { const COLLECTION_NAME: &'static str = "e2ei_acme_ca"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self) + crate::transaction::dynamic_dispatch::Entity::E2eiAcmeCA(self.into()) } } @@ -47,3 +50,30 @@ impl UniqueEntityImplementationHelper for E2eiAcmeCA { &self.content } } + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct E2eiAcmeCAEncrypted { + content: Vec, +} + +impl<'a> Encrypting<'a> for E2eiAcmeCA { + type EncryptedForm = E2eiAcmeCAEncrypted; + + fn encrypt(&'a self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { + let content = ::encrypt_data(self, cipher, &self.content)?; + Ok(E2eiAcmeCAEncrypted { content }) + } +} + +impl Decrypting<'static> for E2eiAcmeCAEncrypted { + type DecryptedForm = E2eiAcmeCA; + + fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { + let content = ::decrypt_data(cipher, &E2eiAcmeCA::KEY, &self.content)?; + Ok(E2eiAcmeCA { content }) + } +} + +impl Decryptable<'static> for E2eiAcmeCA { + type DecryptableFrom = E2eiAcmeCAEncrypted; +} diff --git a/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs b/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs index a2d59af5cb..b38aa0682a 100644 --- a/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs +++ b/keystore/src/entities/platform/wasm/mls/encryption_keypair.rs @@ -1,12 +1,12 @@ use async_trait::async_trait; use crate::{ - CryptoKeystoreResult, MissingKeyErrorKind, Sha256Hash, + CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredEncryptionKeyPair, StringEntityId}, traits::{ - DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType, + BorrowPrimaryKey, DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, + EntityBase as NewEntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, PrimaryKey, }, }; @@ -21,7 +21,7 @@ impl EntityBase for StoredEncryptionKeyPair { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self) + crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self.into()) } } @@ -74,20 +74,30 @@ impl NewEntityBase for StoredEncryptionKeyPair { const COLLECTION_NAME: &'static str = "mls_encryption_keypairs"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self) + crate::transaction::dynamic_dispatch::Entity::EncryptionKeyPair(self.into()) } } -#[async_trait(?Send)] -impl NewEntity for StoredEncryptionKeyPair { - type PrimaryKey = Sha256Hash; +impl PrimaryKey for StoredEncryptionKeyPair { + type PrimaryKey = Vec; - fn primary_key(&self) -> Sha256Hash { - Sha256Hash::hash_from(&self.pk) + fn primary_key(&self) -> Self::PrimaryKey { + self.pk.clone() } +} - async fn get(conn: &mut Self::ConnectionType, id: &Sha256Hash) -> CryptoKeystoreResult> { - conn.storage().new_get(id.bytes().as_ref()).await +impl BorrowPrimaryKey for StoredEncryptionKeyPair { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.pk + } +} + +#[async_trait(?Send)] +impl NewEntity for StoredEncryptionKeyPair { + async fn get(conn: &mut Self::ConnectionType, id: &Vec) -> CryptoKeystoreResult> { + Self::get_borrowed(conn, id.as_slice()).await } async fn load_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { @@ -99,6 +109,16 @@ impl NewEntity for StoredEncryptionKeyPair { } } +#[async_trait(?Send)] +impl EntityGetBorrowed for StoredEncryptionKeyPair { + async fn get_borrowed( + conn: &mut ::ConnectionType, + id: &[u8], + ) -> CryptoKeystoreResult> { + conn.storage().new_get(id).await + } +} + #[async_trait(?Send)] impl<'a> EntityDatabaseMutation<'a> for StoredEncryptionKeyPair { type Transaction = TransactionWrapper<'a>; @@ -111,8 +131,18 @@ impl<'a> EntityDatabaseMutation<'a> for StoredEncryptionKeyPair { tx.new_count::().await } - async fn delete(tx: &Self::Transaction, id: &Sha256Hash) -> CryptoKeystoreResult { - tx.new_delete::(id.bytes().as_ref()).await + async fn delete(tx: &Self::Transaction, id: &Vec) -> CryptoKeystoreResult { + Self::delete_borrowed(tx, id.as_slice()).await + } +} + +#[async_trait(?Send)] +impl<'a> EntityDeleteBorrowed<'a> for StoredEncryptionKeyPair { + async fn delete_borrowed( + tx: &>::Transaction, + id: &[u8], + ) -> CryptoKeystoreResult { + tx.new_delete::(id).await } } @@ -147,8 +177,7 @@ impl Decrypting<'static> for StoredEncryptionKeyPairDecrypt { fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { let Self { pk, sk } = self; - let primary_key = Sha256Hash::hash_from(&pk); - let sk = ::decrypt_data(cipher, &primary_key, &sk)?; + let sk = ::decrypt_data(cipher, &pk, &sk)?; Ok(StoredEncryptionKeyPair { pk, sk }) } diff --git a/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs b/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs index f2d1bd737f..fedcc6245f 100644 --- a/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs +++ b/keystore/src/entities/platform/wasm/mls/hpke_private_key.rs @@ -2,12 +2,12 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::{ - CryptoKeystoreResult, MissingKeyErrorKind, Sha256Hash, + CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredHpkePrivateKey, StringEntityId}, traits::{ - DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType, + BorrowPrimaryKey, DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, + EntityBase as NewEntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, PrimaryKey, }, }; @@ -22,7 +22,7 @@ impl EntityBase for StoredHpkePrivateKey { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self) + crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self.into()) } } @@ -75,20 +75,30 @@ impl NewEntityBase for StoredHpkePrivateKey { const COLLECTION_NAME: &'static str = "mls_hpke_private_keys"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self) + crate::transaction::dynamic_dispatch::Entity::HpkePrivateKey(self.into()) } } -#[async_trait(?Send)] -impl NewEntity for StoredHpkePrivateKey { - type PrimaryKey = Sha256Hash; +impl PrimaryKey for StoredHpkePrivateKey { + type PrimaryKey = Vec; - fn primary_key(&self) -> Sha256Hash { - Sha256Hash::hash_from(&self.pk) + fn primary_key(&self) -> Vec { + self.pk.clone() } +} - async fn get(conn: &mut Self::ConnectionType, id: &Sha256Hash) -> CryptoKeystoreResult> { - conn.storage().new_get(id.bytes().as_ref()).await +impl BorrowPrimaryKey for StoredHpkePrivateKey { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.pk + } +} + +#[async_trait(?Send)] +impl NewEntity for StoredHpkePrivateKey { + async fn get(conn: &mut Self::ConnectionType, id: &Vec) -> CryptoKeystoreResult> { + Self::get_borrowed(conn, id.as_slice()).await } async fn load_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { @@ -100,6 +110,16 @@ impl NewEntity for StoredHpkePrivateKey { } } +#[async_trait(?Send)] +impl EntityGetBorrowed for StoredHpkePrivateKey { + async fn get_borrowed( + conn: &mut ::ConnectionType, + id: &[u8], + ) -> CryptoKeystoreResult> { + conn.storage().new_get(id).await + } +} + #[async_trait(?Send)] impl<'a> EntityDatabaseMutation<'a> for StoredHpkePrivateKey { type Transaction = TransactionWrapper<'a>; @@ -112,8 +132,18 @@ impl<'a> EntityDatabaseMutation<'a> for StoredHpkePrivateKey { tx.new_count::().await } - async fn delete(tx: &Self::Transaction, id: &Sha256Hash) -> CryptoKeystoreResult { - tx.new_delete::(id.bytes().as_ref()).await + async fn delete(tx: &Self::Transaction, id: &Vec) -> CryptoKeystoreResult { + Self::delete_borrowed(tx, id.as_slice()).await + } +} + +#[async_trait(?Send)] +impl<'a> EntityDeleteBorrowed<'a> for StoredHpkePrivateKey { + async fn delete_borrowed( + tx: &>::Transaction, + id: &[u8], + ) -> CryptoKeystoreResult { + tx.new_delete::(id).await } } @@ -142,8 +172,7 @@ impl Decrypting<'static> for StoredHpkePrivateKeyDecrypt { type DecryptedForm = StoredHpkePrivateKey; fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { - let primary_key = Sha256Hash::hash_from(&self.pk); - let sk = ::decrypt_data(cipher, &primary_key, &self.sk)?; + let sk = ::decrypt_data(cipher, &self.pk, &self.sk)?; Ok(StoredHpkePrivateKey { sk, pk: self.pk }) } } diff --git a/keystore/src/entities/platform/wasm/mls/pending_group.rs b/keystore/src/entities/platform/wasm/mls/pending_group.rs index 1379355b10..3ccdf1d0dc 100644 --- a/keystore/src/entities/platform/wasm/mls/pending_group.rs +++ b/keystore/src/entities/platform/wasm/mls/pending_group.rs @@ -1,5 +1,3 @@ -use std::borrow::Borrow; - use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -9,7 +7,8 @@ use crate::{ entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, PersistedMlsPendingGroup, StringEntityId}, traits::{ BorrowPrimaryKey, DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType, + EntityBase as NewEntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, KeyType, + PrimaryKey, }, }; @@ -23,7 +22,7 @@ impl EntityBase for PersistedMlsPendingGroup { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self) + crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self.into()) } } @@ -90,18 +89,28 @@ impl NewEntityBase for PersistedMlsPendingGroup { const COLLECTION_NAME: &'static str = "mls_pending_groups"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self) + crate::transaction::dynamic_dispatch::Entity::PersistedMlsPendingGroup(self.into()) } } -#[async_trait(?Send)] -impl NewEntity for PersistedMlsPendingGroup { +impl PrimaryKey for PersistedMlsPendingGroup { type PrimaryKey = Vec; fn primary_key(&self) -> Self::PrimaryKey { self.id.clone() } +} + +impl BorrowPrimaryKey for PersistedMlsPendingGroup { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey { + &self.id + } +} +#[async_trait(?Send)] +impl NewEntity for PersistedMlsPendingGroup { async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { Self::get_borrowed(conn, key).await } @@ -116,18 +125,11 @@ impl NewEntity for PersistedMlsPendingGroup { } #[async_trait(?Send)] -impl BorrowPrimaryKey for PersistedMlsPendingGroup { - type BorrowedPrimaryKey = [u8]; - - fn borrow_primary_key(&self) -> &[u8] { - &self.id - } - - async fn get_borrowed(conn: &mut Self::ConnectionType, key: &Q) -> CryptoKeystoreResult> - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { +impl EntityGetBorrowed for PersistedMlsPendingGroup { + async fn get_borrowed( + conn: &mut Self::ConnectionType, + key: &Self::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> { conn.storage().new_get(key.bytes().as_ref()).await } } @@ -151,11 +153,7 @@ impl<'a> EntityDatabaseMutation<'a> for PersistedMlsPendingGroup { #[async_trait(?Send)] impl<'a> EntityDeleteBorrowed<'a> for PersistedMlsPendingGroup { - async fn delete_borrowed(tx: &Self::Transaction, id: &Q) -> CryptoKeystoreResult - where - Self::PrimaryKey: Borrow, - Q: KeyType, - { + async fn delete_borrowed(tx: &Self::Transaction, id: &Self::BorrowedPrimaryKey) -> CryptoKeystoreResult { tx.new_delete::(id.bytes().as_ref()).await } } diff --git a/keystore/src/entities/platform/wasm/mls/pending_message.rs b/keystore/src/entities/platform/wasm/mls/pending_message.rs index ef1ee2f5a7..dae5f18bd8 100644 --- a/keystore/src/entities/platform/wasm/mls/pending_message.rs +++ b/keystore/src/entities/platform/wasm/mls/pending_message.rs @@ -8,8 +8,8 @@ use crate::{ connection::{KeystoreDatabaseConnection, TransactionWrapper}, entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, MlsPendingMessage, StringEntityId}, traits::{ - DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, + DecryptWithExplicitEncryptionKey as _, Decryptable, Decrypting, EncryptWithExplicitEncryptionKey as _, + Encrypting, EncryptionKey, Entity as NewEntity, EntityBase as NewEntityBase, EntityDatabaseMutation, }, }; @@ -24,7 +24,7 @@ impl EntityBase for MlsPendingMessage { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self) + crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self.into()) } } @@ -112,24 +112,12 @@ impl NewEntityBase for MlsPendingMessage { const COLLECTION_NAME: &'static str = "mls_pending_messages"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self) + crate::transaction::dynamic_dispatch::Entity::MlsPendingMessage(self.into()) } } #[async_trait(?Send)] -/// Pending messages have no distinct primary key; -/// they must always be accessed via [`MlsPendingMessage::find_all_by_conversation_id`] and -/// cleaned up with [`MlsPendingMessage::delete_by_conversation_id`] -/// -/// However, we have to fake it as a byte vector in this impl in order for encryption and decryption -/// to work. impl NewEntity for MlsPendingMessage { - type PrimaryKey = Vec; - - fn primary_key(&self) -> Vec { - self.foreign_id.clone() - } - async fn get(_conn: &mut Self::ConnectionType, _key: &Self::PrimaryKey) -> CryptoKeystoreResult> { panic!("cannot get `MlsPendingMessage` by primary key as it has no distinct primary key") } @@ -160,6 +148,12 @@ impl<'a> EntityDatabaseMutation<'a> for MlsPendingMessage { } } +impl EncryptionKey for MlsPendingMessage { + fn encryption_key(&self) -> &[u8] { + &self.foreign_id + } +} + #[derive(Serialize)] pub struct MlsPendingMessageEncrypt<'a> { foreign_id: &'a [u8], @@ -170,7 +164,7 @@ impl<'a> Encrypting<'a> for MlsPendingMessage { type EncryptedForm = MlsPendingMessageEncrypt<'a>; fn encrypt(&'a self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { - let message = ::encrypt_data(self, cipher, &self.message)?; + let message = self.encrypt_data_with_encryption_key(cipher, &self.message)?; Ok(MlsPendingMessageEncrypt { foreign_id: &self.foreign_id, message, @@ -188,7 +182,7 @@ impl Decrypting<'static> for MlsPendingMessageDecrypt { type DecryptedForm = MlsPendingMessage; fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { - let message = ::decrypt_data(cipher, &self.foreign_id, &self.message)?; + let message = MlsPendingMessage::decrypt_data_with_encryption_key(cipher, &self.foreign_id, &self.message)?; Ok(MlsPendingMessage { foreign_id: self.foreign_id, message, diff --git a/keystore/src/entities/platform/wasm/mls/psk_bundle.rs b/keystore/src/entities/platform/wasm/mls/psk_bundle.rs index b2bb96bfa9..593b114063 100644 --- a/keystore/src/entities/platform/wasm/mls/psk_bundle.rs +++ b/keystore/src/entities/platform/wasm/mls/psk_bundle.rs @@ -2,12 +2,12 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::{ - CryptoKeystoreResult, MissingKeyErrorKind, Sha256Hash, + CryptoKeystoreResult, MissingKeyErrorKind, connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper}, entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StoredPskBundle, StringEntityId}, traits::{ - DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType as _, + BorrowPrimaryKey, DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, + EntityBase as NewEntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, PrimaryKey, }, }; @@ -22,7 +22,7 @@ impl EntityBase for StoredPskBundle { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PskBundle(self) + crate::transaction::dynamic_dispatch::Entity::PskBundle(self.into()) } } @@ -75,20 +75,30 @@ impl NewEntityBase for StoredPskBundle { const COLLECTION_NAME: &'static str = "mls_psk_bundles"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::PskBundle(self) + crate::transaction::dynamic_dispatch::Entity::PskBundle(self.into()) } } -#[async_trait(?Send)] -impl NewEntity for StoredPskBundle { - type PrimaryKey = Sha256Hash; +impl PrimaryKey for StoredPskBundle { + type PrimaryKey = Vec; - fn primary_key(&self) -> Self::PrimaryKey { - Sha256Hash::hash_from(&self.psk_id) + fn primary_key(&self) -> Vec { + self.psk_id.clone() } +} - async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { - conn.storage().new_get(key.bytes().as_ref()).await +impl BorrowPrimaryKey for StoredPskBundle { + type BorrowedPrimaryKey = [u8]; + + fn borrow_primary_key(&self) -> &[u8] { + &self.psk_id + } +} + +#[async_trait(?Send)] +impl NewEntity for StoredPskBundle { + async fn get(conn: &mut Self::ConnectionType, key: &Vec) -> CryptoKeystoreResult> { + Self::get_borrowed(conn, key.as_slice()).await } async fn count(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult { @@ -100,6 +110,16 @@ impl NewEntity for StoredPskBundle { } } +#[async_trait(?Send)] +impl EntityGetBorrowed for StoredPskBundle { + async fn get_borrowed( + conn: &mut ::ConnectionType, + key: &[u8], + ) -> CryptoKeystoreResult> { + conn.storage().new_get(key).await + } +} + #[async_trait(?Send)] impl<'a> EntityDatabaseMutation<'a> for StoredPskBundle { type Transaction = TransactionWrapper<'a>; @@ -112,8 +132,18 @@ impl<'a> EntityDatabaseMutation<'a> for StoredPskBundle { tx.new_count::().await } - async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult { - tx.new_delete::(id.bytes().as_ref()).await + async fn delete(tx: &Self::Transaction, id: &Vec) -> CryptoKeystoreResult { + Self::delete_borrowed(tx, id.as_slice()).await + } +} + +#[async_trait(?Send)] +impl<'a> EntityDeleteBorrowed<'a> for StoredPskBundle { + async fn delete_borrowed( + tx: &>::Transaction, + id: &[u8], + ) -> CryptoKeystoreResult { + tx.new_delete::(id).await } } @@ -145,8 +175,7 @@ impl Decrypting<'static> for StoredPskBundleDecrypt { type DecryptedForm = StoredPskBundle; fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { - let primary_key = Sha256Hash::hash_from(&self.psk_id); - let psk = ::decrypt_data(cipher, &primary_key, &self.psk)?; + let psk = ::decrypt_data(cipher, &self.psk_id, &self.psk)?; Ok(StoredPskBundle { psk_id: self.psk_id, psk, diff --git a/keystore/src/entities/platform/wasm/mls/refresh_token.rs b/keystore/src/entities/platform/wasm/mls/refresh_token.rs index 5f7ac5a4ea..f674bac0de 100644 --- a/keystore/src/entities/platform/wasm/mls/refresh_token.rs +++ b/keystore/src/entities/platform/wasm/mls/refresh_token.rs @@ -1,8 +1,13 @@ +use serde::{Deserialize, Serialize}; + use crate::{ - MissingKeyErrorKind, + CryptoKeystoreResult, MissingKeyErrorKind, connection::KeystoreDatabaseConnection, entities::{E2eiRefreshToken, EntityBase, UniqueEntity}, - traits::{EntityBase as NewEntityBase, UniqueEntityImplementationHelper}, + traits::{ + DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, EntityBase as NewEntityBase, UniqueEntity as _, + UniqueEntityImplementationHelper, + }, }; #[async_trait::async_trait(?Send)] @@ -16,7 +21,7 @@ impl EntityBase for E2eiRefreshToken { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::E2eiRefreshToken(self) + crate::transaction::dynamic_dispatch::Entity::E2eiRefreshToken(self.into()) } } @@ -37,7 +42,7 @@ impl NewEntityBase for E2eiRefreshToken { const COLLECTION_NAME: &'static str = "e2ei_refresh_token"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::E2eiRefreshToken(self) + crate::transaction::dynamic_dispatch::Entity::E2eiRefreshToken(self.into()) } } @@ -50,3 +55,30 @@ impl UniqueEntityImplementationHelper for E2eiRefreshToken { &self.content } } + +#[derive(Serialize, Deserialize)] +pub struct E2eiRefreshTokenEncrypted { + content: Vec, +} + +impl<'a> Encrypting<'a> for E2eiRefreshToken { + type EncryptedForm = E2eiRefreshTokenEncrypted; + + fn encrypt(&'a self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { + let content = ::encrypt_data(self, cipher, &self.content)?; + Ok(E2eiRefreshTokenEncrypted { content }) + } +} + +impl Decrypting<'static> for E2eiRefreshTokenEncrypted { + type DecryptedForm = E2eiRefreshToken; + + fn decrypt(self, cipher: &aes_gcm::Aes256Gcm) -> CryptoKeystoreResult { + let content = ::decrypt_data(cipher, &E2eiRefreshToken::KEY, &self.content)?; + Ok(E2eiRefreshToken { content }) + } +} + +impl Decryptable<'static> for E2eiRefreshToken { + type DecryptableFrom = E2eiRefreshTokenEncrypted; +} diff --git a/keystore/src/entities/platform/wasm/proteus/identity.rs b/keystore/src/entities/platform/wasm/proteus/identity.rs index cb0e627840..730b14fdb7 100644 --- a/keystore/src/entities/platform/wasm/proteus/identity.rs +++ b/keystore/src/entities/platform/wasm/proteus/identity.rs @@ -7,7 +7,7 @@ use crate::{ entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, ProteusIdentity, StringEntityId}, traits::{ DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType as _, UniqueEntity, + EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType as _, PrimaryKey, UniqueEntity, }, }; @@ -22,7 +22,7 @@ impl EntityBase for ProteusIdentity { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self) + crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self.into()) } } @@ -81,7 +81,7 @@ impl NewEntityBase for ProteusIdentity { const COLLECTION_NAME: &'static str = "proteus_identities"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self) + crate::transaction::dynamic_dispatch::Entity::ProteusIdentity(self.into()) } } @@ -89,13 +89,15 @@ impl UniqueEntity for ProteusIdentity { const KEY: [u8; 1] = [1]; } -#[async_trait(?Send)] -impl NewEntity for ProteusIdentity { +impl PrimaryKey for ProteusIdentity { type PrimaryKey = [u8; 1]; - fn primary_key(&self) -> [u8; 1] { + fn primary_key(&self) -> Self::PrimaryKey { Self::KEY } +} +#[async_trait(?Send)] +impl NewEntity for ProteusIdentity { async fn get(conn: &mut Self::ConnectionType, _key: &Self::PrimaryKey) -> CryptoKeystoreResult> { let identity = Self::load_all(conn).await?.pop(); Ok(identity) diff --git a/keystore/src/entities/platform/wasm/proteus/prekey.rs b/keystore/src/entities/platform/wasm/proteus/prekey.rs index 6f4df017eb..f74591f213 100644 --- a/keystore/src/entities/platform/wasm/proteus/prekey.rs +++ b/keystore/src/entities/platform/wasm/proteus/prekey.rs @@ -7,7 +7,7 @@ use crate::{ entities::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, ProteusPrekey, StringEntityId}, traits::{ DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity as NewEntity, - EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType, + EntityBase as NewEntityBase, EntityDatabaseMutation, KeyType, PrimaryKey, }, }; @@ -22,7 +22,7 @@ impl EntityBase for ProteusPrekey { } fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self) + crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self.into()) } } @@ -75,18 +75,19 @@ impl NewEntityBase for ProteusPrekey { const COLLECTION_NAME: &'static str = "proteus_prekeys"; fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity { - crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self) + crate::transaction::dynamic_dispatch::Entity::ProteusPrekey(self.into()) } } -#[async_trait(?Send)] -impl NewEntity for ProteusPrekey { +impl PrimaryKey for ProteusPrekey { type PrimaryKey = u16; - - fn primary_key(&self) -> u16 { + fn primary_key(&self) -> Self::PrimaryKey { self.id } +} +#[async_trait(?Send)] +impl NewEntity for ProteusPrekey { async fn get(conn: &mut Self::ConnectionType, key: &u16) -> CryptoKeystoreResult> { conn.storage().new_get(key.bytes().as_ref()).await } diff --git a/keystore/src/entities/proteus.rs b/keystore/src/entities/proteus.rs index 7b54278c3a..d8496b5c5a 100644 --- a/keystore/src/entities/proteus.rs +++ b/keystore/src/entities/proteus.rs @@ -1,6 +1,6 @@ use zeroize::Zeroize; -use crate::connection::FetchFromDatabase; +use crate::traits::FetchFromDatabase as _; #[derive(core_crypto_macros::Debug, Clone, Zeroize, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[zeroize(drop)] @@ -75,7 +75,7 @@ impl ProteusPrekey { if id == limit { return Err(crate::CryptoKeystoreError::NoFreePrekeyId); } - if conn.find::(&id.to_le_bytes()).await?.is_none() { + if conn.get::(&id).await?.is_none() { break; } id += 1; diff --git a/keystore/src/error.rs b/keystore/src/error.rs index cc3c55be2e..9ebad10122 100644 --- a/keystore/src/error.rs +++ b/keystore/src/error.rs @@ -152,6 +152,10 @@ pub enum CryptoKeystoreError { MigrationNotSupported(u32), #[error("The migration failed: {0}")] MigrationFailed(String), + #[error("the provided bytes could not be interpreted as the primary key of {0}")] + InvalidPrimaryKeyBytes(&'static str), + #[error("the entity {0} had an unknown collection name and could not be found")] + UnknownEntity(&'static str), } #[cfg(target_family = "wasm")] diff --git a/keystore/src/hash.rs b/keystore/src/hash.rs index 33d8f3be47..9e7e8f1d37 100644 --- a/keystore/src/hash.rs +++ b/keystore/src/hash.rs @@ -2,7 +2,10 @@ use std::fmt; use sha2::{Digest, Sha256}; -use crate::{CryptoKeystoreResult, traits::KeyType}; +use crate::{ + CryptoKeystoreResult, + traits::{KeyType, OwnedKeyType}, +}; /// Used to calculate ID hashes for some MlsEntities' SQLite tables (not used on wasm). /// We only use sha256 on platforms where we use SQLite. @@ -69,6 +72,12 @@ impl KeyType for Sha256Hash { } } +impl OwnedKeyType for Sha256Hash { + fn from_bytes(bytes: &[u8]) -> Option { + bytes.try_into().ok().map(Self) + } +} + #[cfg(not(target_family = "wasm"))] impl rusqlite::ToSql for Sha256Hash { fn to_sql(&self) -> rusqlite::Result> { diff --git a/keystore/src/mls.rs b/keystore/src/mls.rs index 3b56ffb8ce..c2dd12f25f 100644 --- a/keystore/src/mls.rs +++ b/keystore/src/mls.rs @@ -3,12 +3,12 @@ use openmls_basic_credential::SignatureKeyPair; use openmls_traits::key_store::{MlsEntity, MlsEntityId}; use crate::{ - CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind, - connection::FetchFromDatabase, + CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind, Sha256Hash, entities::{ - EntityFindParams, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment, - StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment, StoredEncryptionKeyPair, + StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, }, + traits::FetchFromDatabase, }; /// An interface for the specialized queries in the KeyStore @@ -124,23 +124,19 @@ pub trait CryptoKeystoreMls: Sized { #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] impl CryptoKeystoreMls for crate::Database { async fn mls_fetch_keypackages(&self, count: u32) -> CryptoKeystoreResult> { - let reverse = !cfg!(target_family = "wasm"); - let keypackages = self - .find_all::(EntityFindParams { - limit: Some(count), - offset: None, - reverse, - }) - .await?; - + let keypackages = self.load_all::().await?; Ok(keypackages .into_iter() .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok()) + .take(count as _) .collect()) } async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool { - matches!(self.find::(group_id).await, Ok(Some(_))) + matches!( + self.get_borrowed::(group_id.as_ref()).await, + Ok(Some(_)) + ) } async fn mls_group_persist( @@ -162,16 +158,20 @@ impl CryptoKeystoreMls for crate::Database { async fn mls_groups_restore( &self, ) -> CryptoKeystoreResult, (Option>, Vec)>> { - let groups = self.find_all::(EntityFindParams::default()).await?; + let groups = self.load_all::().await?; Ok(groups .into_iter() - .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone()))) + .map(|mut group: PersistedMlsGroup| { + let id = std::mem::take(&mut group.id); + let parent_id = std::mem::take(&mut group.parent_id); + let state = std::mem::take(&mut group.state); + (id, (parent_id, state)) + }) .collect()) } async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> { - self.remove::(group_id).await?; - + self.remove_borrowed::(group_id.as_ref()).await?; Ok(()) } @@ -196,7 +196,7 @@ impl CryptoKeystoreMls for crate::Database { &self, group_id: impl AsRef<[u8]> + Send, ) -> CryptoKeystoreResult<(Vec, Vec)> { - self.find(group_id) + self.get_borrowed(group_id.as_ref()) .await? .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone())) .ok_or(CryptoKeystoreError::MissingKeyInStore( @@ -205,7 +205,8 @@ impl CryptoKeystoreMls for crate::Database { } async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> { - self.remove::(group_id).await + self.remove_borrowed::(group_id.as_ref()) + .await } async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> { @@ -219,13 +220,13 @@ impl CryptoKeystoreMls for crate::Database { async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult> { // someone who has time could try to optimize this but honestly it's really on the cold path - let enrollment = self - .find::(id) - .await? - .ok_or(CryptoKeystoreError::MissingKeyInStore( - MissingKeyErrorKind::StoredE2eiEnrollment, - ))?; - self.remove::(id).await?; + let enrollment = + self.get_borrowed::(id) + .await? + .ok_or(CryptoKeystoreError::MissingKeyInStore( + MissingKeyErrorKind::StoredE2eiEnrollment, + ))?; + self.remove_borrowed::(id).await?; Ok(enrollment.content.clone()) } } @@ -313,11 +314,12 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database match V::ID { MlsEntityId::GroupState => { - let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?; + let group: PersistedMlsGroup = self.get_borrowed(k).await.ok().flatten()?; deser(&group.state).ok() } MlsEntityId::SignatureKeyPair => { - let stored_credential = self.find::(k).await.ok().flatten()?; + let hash = Sha256Hash::from_existing_hash(k).ok()?; + let stored_credential = self.get::(&hash).await.ok().flatten()?; let ciphersuite = Ciphersuite::try_from(stored_credential.ciphersuite).ok()?; let signature_scheme = ciphersuite.signature_algorithm(); @@ -333,23 +335,23 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database deser(&mls_keypair_serialized).ok() } MlsEntityId::KeyPackage => { - let kp: StoredKeypackage = self.find(k).await.ok().flatten()?; + let kp: StoredKeypackage = self.get_borrowed(k).await.ok().flatten()?; deser(&kp.keypackage).ok() } MlsEntityId::HpkePrivateKey => { - let hpke_pk: StoredHpkePrivateKey = self.find(k).await.ok().flatten()?; + let hpke_pk: StoredHpkePrivateKey = self.get_borrowed(k).await.ok().flatten()?; deser(&hpke_pk.sk).ok() } MlsEntityId::PskBundle => { - let psk_bundle: StoredPskBundle = self.find(k).await.ok().flatten()?; + let psk_bundle: StoredPskBundle = self.get_borrowed(k).await.ok().flatten()?; deser(&psk_bundle.psk).ok() } MlsEntityId::EncryptionKeyPair => { - let kp: StoredEncryptionKeyPair = self.find(k).await.ok().flatten()?; + let kp: StoredEncryptionKeyPair = self.get_borrowed(k).await.ok().flatten()?; deser(&kp.sk).ok() } MlsEntityId::EpochEncryptionKeyPair => { - let kp: StoredEpochEncryptionKeypair = self.find(k).await.ok().flatten()?; + let kp: StoredEpochEncryptionKeypair = self.get_borrowed(k).await.ok().flatten()?; deser(&kp.keypairs).ok() } } @@ -357,16 +359,16 @@ impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database async fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { match V::ID { - MlsEntityId::GroupState => self.remove::(k).await?, + MlsEntityId::GroupState => self.remove_borrowed::(k).await?, MlsEntityId::SignatureKeyPair => unimplemented!( "Deleting a signature key pair should not be done through this API, any keypair should be deleted via deleting a credential." ), - MlsEntityId::HpkePrivateKey => self.remove::(k).await?, - MlsEntityId::KeyPackage => self.remove::(k).await?, - MlsEntityId::PskBundle => self.remove::(k).await?, - MlsEntityId::EncryptionKeyPair => self.remove::(k).await?, - MlsEntityId::EpochEncryptionKeyPair => self.remove::(k).await?, + MlsEntityId::HpkePrivateKey => self.remove_borrowed::(k).await?, + MlsEntityId::KeyPackage => self.remove_borrowed::(k).await?, + MlsEntityId::PskBundle => self.remove_borrowed::(k).await?, + MlsEntityId::EncryptionKeyPair => self.remove_borrowed::(k).await?, + MlsEntityId::EpochEncryptionKeyPair => self.remove_borrowed::(k).await?, } Ok(()) diff --git a/keystore/src/proteus.rs b/keystore/src/proteus.rs index 24b7bcaa90..70de143612 100644 --- a/keystore/src/proteus.rs +++ b/keystore/src/proteus.rs @@ -1,7 +1,6 @@ use crate::{ - CryptoKeystoreError, CryptoKeystoreResult, - connection::{Database, FetchFromDatabase}, - entities::ProteusPrekey, + CryptoKeystoreError, CryptoKeystoreResult, connection::Database, entities::ProteusPrekey, + traits::FetchFromDatabase as _, }; #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] @@ -29,15 +28,12 @@ impl proteus_traits::PreKeyStore for Database { &mut self, id: proteus_traits::RawPreKeyId, ) -> Result, Self::Error> { - Ok(self - .find::(&id.to_le_bytes()) - .await? - .map(|db_prekey| db_prekey.prekey.clone())) + self.get::(&id) + .await + .map(|db_prekey| db_prekey.map(|mut db_prekey| std::mem::take(&mut db_prekey.prekey))) } async fn remove(&mut self, id: proteus_traits::RawPreKeyId) -> Result<(), Self::Error> { - Database::remove::(self, id.to_le_bytes()).await?; - - Ok(()) + Database::remove::(self, &id).await } } diff --git a/keystore/src/traits/entity.rs b/keystore/src/traits/entity.rs index c1a30a58b7..acaa782f12 100644 --- a/keystore/src/traits/entity.rs +++ b/keystore/src/traits/entity.rs @@ -4,7 +4,10 @@ use async_trait::async_trait; use crate::{ CryptoKeystoreResult, - traits::{EntityBase, KeyType}, + traits::{ + EntityBase, KeyType, OwnedKeyType, + primary_key::{BorrowPrimaryKey, PrimaryKey}, + }, }; /// Something which can be stored in our database. @@ -12,19 +15,7 @@ use crate::{ /// It has a primary key, which uniquely identifies it. #[cfg_attr(target_family = "wasm", async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait)] -pub trait Entity: EntityBase { - /// Each distinct `PrimaryKey` uniquely identifies either 0 or 1 instance. - /// - /// This constraint should be enforced at the DB level. - type PrimaryKey: KeyType; - - /// Get this entity's primary key. - /// - /// This must return an owned type, because there are some entities for which only owned primary keys are possible. - /// However, entities which have primary keys owned within the entity itself should consider also implementing - /// [`BorrowPrimaryKey`] for greater efficiency. - fn primary_key(&self) -> Self::PrimaryKey; - +pub trait Entity: EntityBase + PrimaryKey { /// Get an entity by its primary key. /// /// For entites whose primary key has a distinct borrowed type, it is best to implement this as a direct @@ -44,25 +35,17 @@ pub trait Entity: EntityBase { async fn load_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult>; } -/// An extension trait which should be implemented for all entities whose primary key has a distinct borrowed form. -/// -/// i.e. `String`, `Vec`, etc. #[cfg_attr(target_family = "wasm", async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait)] -pub trait BorrowPrimaryKey: Entity { - type BorrowedPrimaryKey: ?Sized + ToOwned; - - /// Borrow this entity's primary key without copying any data. - /// - /// This borrowed key has a lifetime tied to that of this entity. - fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey; - +pub trait EntityGetBorrowed: Entity + BorrowPrimaryKey { /// Get an entity by a borrowed form of its primary key. /// /// The type signature here is somewhat complicated, but it breaks down simply: if our primary key is something /// like `Vec`, we want to be able to use this method even if what we have on hand is `&[u8]`. - async fn get_borrowed(conn: &mut Self::ConnectionType, key: &Q) -> CryptoKeystoreResult> + async fn get_borrowed( + conn: &mut Self::ConnectionType, + key: &Self::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> where - Self::PrimaryKey: Borrow, - Q: KeyType; + for<'pk> &'pk Self::BorrowedPrimaryKey: KeyType; } diff --git a/keystore/src/traits/entity_base.rs b/keystore/src/traits/entity_base.rs index c11cb3bf75..64413b5bff 100644 --- a/keystore/src/traits/entity_base.rs +++ b/keystore/src/traits/entity_base.rs @@ -1,3 +1,5 @@ +use std::{any::Any, sync::Arc}; + use crate::connection::DatabaseConnection; /// A supertrait that all entities must implement. This handles multiplexing over the two different database backends. @@ -5,6 +7,12 @@ use crate::connection::DatabaseConnection; /// This trait should be removed once the persistence layers are unified. See WPB-16241. pub trait EntityBase: 'static + Sized { type ConnectionType: for<'a> DatabaseConnection<'a>; + + /// Entities which implement `EntityDatabaseMutation` have a `pre_save` method which might generate or + /// update some fields of the item. The canonical example is an `updated_at` field. + /// + /// This type must contain a copy of each modification to the item, so that the caller of a `.save(entity)` + /// function can know what has changed and what the new values are. type AutoGeneratedFields: Default; /// Beware: if you change the value of this constant on any WASM entity, you'll need to do a data migration @@ -16,5 +24,14 @@ pub trait EntityBase: 'static + Sized { as_dyn_any.downcast_ref() } + fn downcast_arc(self: Arc) -> Option> + where + Self: Send + Sync, + T: EntityBase + Send + Sync, + { + let as_dyn_any = self as Arc; + as_dyn_any.downcast().ok() + } + fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity; } diff --git a/keystore/src/traits/entity_database_mutation.rs b/keystore/src/traits/entity_database_mutation.rs index 22f3820b3e..be31407504 100644 --- a/keystore/src/traits/entity_database_mutation.rs +++ b/keystore/src/traits/entity_database_mutation.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use crate::{ CryptoKeystoreResult, - traits::{Entity, KeyType}, + traits::{BorrowPrimaryKey, Entity, KeyType, PrimaryKey}, }; /// Extend an [`Entity`] with db-mutating operations which can be performed when provided with a transaction. @@ -17,6 +17,9 @@ pub trait EntityDatabaseMutation<'a>: Entity CryptoKeystoreResult { Ok(Default::default()) } @@ -39,22 +42,18 @@ pub trait EntityDatabaseMutation<'a>: Entity>::delete_borrowed(tx, id).await /// } /// ``` - async fn delete(tx: &Self::Transaction, id: &::PrimaryKey) -> CryptoKeystoreResult; + async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult; } /// Extend an [`Entity`] with db-mutating operations which can be performed when provided with a transaction. #[cfg_attr(target_family = "wasm", async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait)] -pub trait EntityDeleteBorrowed<'a>: EntityDatabaseMutation<'a> { +pub trait EntityDeleteBorrowed<'a>: EntityDatabaseMutation<'a> + BorrowPrimaryKey { /// Delete an entity by a borrowed form of its primary key. /// /// The type signature here is somewhat complicated, but it breaks down simply: if our primary key is something /// like `Vec`, we want to be able to use this method even if what we have on hand is `&[u8]`. - async fn delete_borrowed( - tx: &>::Transaction, - id: &Q, - ) -> CryptoKeystoreResult + async fn delete_borrowed(tx: &Self::Transaction, id: &Self::BorrowedPrimaryKey) -> CryptoKeystoreResult where - Self::PrimaryKey: Borrow, - Q: KeyType; + for<'pk> &'pk Self::BorrowedPrimaryKey: KeyType; } diff --git a/keystore/src/traits/fetch_from_database.rs b/keystore/src/traits/fetch_from_database.rs index 360070550c..7bd472b920 100644 --- a/keystore/src/traits/fetch_from_database.rs +++ b/keystore/src/traits/fetch_from_database.rs @@ -1,9 +1,13 @@ +use std::borrow::Borrow; + use async_trait::async_trait; use crate::{ CryptoKeystoreResult, connection::KeystoreDatabaseConnection, - traits::{Entity, UniqueEntity}, + traits::{ + BorrowPrimaryKey, Entity, EntityBase, EntityGetBorrowed, KeyType, PrimaryKey, UniqueEntity, UniqueEntityExt, + }, }; /// Interface to fetch from the database either from the connection directly or through a @@ -15,22 +19,44 @@ use crate::{ #[cfg_attr(not(target_family = "wasm"), async_trait)] pub trait FetchFromDatabase: Send + Sync { /// Get an instance of `E` from the database by its primary key. - async fn get(&self, id: &::PrimaryKey) -> CryptoKeystoreResult> + async fn get(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult> where - E: Entity; + E: Entity + Clone + Send + Sync; /// Count the number of `E`s in the database. async fn count(&self) -> CryptoKeystoreResult where - E: Entity; + E: Entity + Clone + Send + Sync; /// Load all `E`s from the database. async fn load_all(&self) -> CryptoKeystoreResult> where - E: Entity; + E: Entity + Clone + Send + Sync; + + /// Get an instance of `E` from the database by the borrowed form of its primary key. + async fn get_borrowed( + &self, + id: &::BorrowedPrimaryKey, + ) -> CryptoKeystoreResult> + where + E: EntityGetBorrowed + Clone + Send + Sync, + E::PrimaryKey: Borrow, + for<'a> &'a E::BorrowedPrimaryKey: KeyType; /// Get the requested unique entity from the database. - async fn get_unique(&self) -> CryptoKeystoreResult> + async fn get_unique<'a, U>(&self) -> CryptoKeystoreResult> + where + U: UniqueEntityExt<'a> + Entity + Clone + Send + Sync, + { + self.get::(&U::KEY).await + } + + /// Determine whether a unique entity is present in the database. + async fn exists<'a, U>(&self) -> CryptoKeystoreResult where - U: UniqueEntity; + U: UniqueEntityExt<'a> + Entity + Clone + Send + Sync, + { + let count = self.count::().await?; + Ok(count > 0) + } } diff --git a/keystore/src/traits/item_encryption/aad.rs b/keystore/src/traits/item_encryption/aad.rs index 9537bcc2de..25d0ce22d8 100644 --- a/keystore/src/traits/item_encryption/aad.rs +++ b/keystore/src/traits/item_encryption/aad.rs @@ -32,4 +32,11 @@ impl Aad { let id = primary_key.bytes().into_owned(); Self { type_name, id } } + + /// Don't use this unless you really have to! Prefer [`Self::from_primary_key`]. + pub(super) fn from_encryption_key_bytes(key_bytes: impl AsRef<[u8]>) -> Self { + let type_name = E::COLLECTION_NAME.as_bytes().to_vec(); + let id = key_bytes.as_ref().to_owned(); + Self { type_name, id } + } } diff --git a/keystore/src/traits/item_encryption/decrypt_data.rs b/keystore/src/traits/item_encryption/decrypt_data.rs index cdd01d9ad6..be33d539cb 100644 --- a/keystore/src/traits/item_encryption/decrypt_data.rs +++ b/keystore/src/traits/item_encryption/decrypt_data.rs @@ -1,7 +1,7 @@ use super::aad::{AES_GCM_256_NONCE_SIZE, Aad}; use crate::{ CryptoKeystoreError, CryptoKeystoreResult, - traits::{Entity, EntityDeleteBorrowed, KeyType as _}, + traits::{EncryptionKey, Entity, EntityDeleteBorrowed, KeyType as _}, }; fn decrypt_with_nonce_and_aad( @@ -29,7 +29,7 @@ pub trait DecryptData: Entity { /// Decrypt some data, symmetrically to the process [`encrypt_data`][super::EncryptData::encrypt_data] uses. fn decrypt_data( cipher: &aes_gcm::Aes256Gcm, - primary_key: &::PrimaryKey, + primary_key: &Self::PrimaryKey, data: &[u8], ) -> CryptoKeystoreResult>; } @@ -47,3 +47,32 @@ impl DecryptData for E { decrypt_with_nonce_and_aad(cipher, msg, nonce, &aad) } } + +/// This trait uses an explicitly-set decryption key to decrypt some data. +/// +/// This should rarely be used. +pub trait DecryptWithExplicitEncryptionKey { + /// Decrypt some data with an encryption key (see [`EncryptionKey`]) instead of the instance's primary key. + fn decrypt_data_with_encryption_key( + cipher: &aes_gcm::Aes256Gcm, + encryption_key: &[u8], + data: &[u8], + ) -> CryptoKeystoreResult>; +} + +impl DecryptWithExplicitEncryptionKey for E +where + E: Entity + EncryptionKey, +{ + fn decrypt_data_with_encryption_key( + cipher: &aes_gcm::Aes256Gcm, + encryption_key: &[u8], + data: &[u8], + ) -> CryptoKeystoreResult> { + let aad = Aad::from_encryption_key_bytes::(encryption_key).serialize()?; + let (nonce, msg) = data + .split_at_checked(AES_GCM_256_NONCE_SIZE) + .ok_or(CryptoKeystoreError::AesGcmError)?; + decrypt_with_nonce_and_aad(cipher, msg, nonce, &aad) + } +} diff --git a/keystore/src/traits/item_encryption/encrypt_data.rs b/keystore/src/traits/item_encryption/encrypt_data.rs index 55b214066a..0f558d091d 100644 --- a/keystore/src/traits/item_encryption/encrypt_data.rs +++ b/keystore/src/traits/item_encryption/encrypt_data.rs @@ -46,3 +46,41 @@ impl EncryptData for E { encrypt_with_nonce_and_aad(cipher, data, &nonce_bytes, &aad) } } + +/// This trait is an hack enabling us to encrypt types for which we don't use the primary key in the AAD. +/// +/// The only reason we'd ever want this is if the primary key is not what we actually use, and the only +/// reason that would be the case is if we're faking a primary key where no such key really exists. +/// +/// In other words, MLS pending messages. +pub trait EncryptionKey { + /// Get the key bytes which are to be used as the encryption key for this data. + fn encryption_key(&self) -> &[u8]; +} + +/// This trait uses the explicitly-set encryption key to encrypt some data. +/// +/// This should rarely be used. +pub trait EncryptWithExplicitEncryptionKey { + /// Encrypt some data with an encryption key (see [`EncryptionKey`]) instead of the instance's primary key. + fn encrypt_data_with_encryption_key( + &self, + cipher: &aes_gcm::Aes256Gcm, + data: &[u8], + ) -> CryptoKeystoreResult>; +} + +impl EncryptWithExplicitEncryptionKey for E +where + E: Entity + EncryptionKey, +{ + fn encrypt_data_with_encryption_key( + &self, + cipher: &aes_gcm::Aes256Gcm, + data: &[u8], + ) -> CryptoKeystoreResult> { + let aad = Aad::from_encryption_key_bytes::(self.encryption_key()).serialize()?; + let nonce_bytes: [u8; AES_GCM_256_NONCE_SIZE] = rand::random(); + encrypt_with_nonce_and_aad(cipher, data, &nonce_bytes, &aad) + } +} diff --git a/keystore/src/traits/item_encryption/mod.rs b/keystore/src/traits/item_encryption/mod.rs index 258a23bd66..b7d8085a14 100644 --- a/keystore/src/traits/item_encryption/mod.rs +++ b/keystore/src/traits/item_encryption/mod.rs @@ -39,7 +39,7 @@ mod decrypting; mod encrypt_data; mod encrypting; -pub use decrypt_data::DecryptData; +pub use decrypt_data::{DecryptData, DecryptWithExplicitEncryptionKey}; pub use decrypting::{Decryptable, Decrypting}; -pub use encrypt_data::EncryptData; +pub use encrypt_data::{EncryptData, EncryptWithExplicitEncryptionKey, EncryptionKey}; pub use encrypting::Encrypting; diff --git a/keystore/src/traits/key_type.rs b/keystore/src/traits/key_type.rs index 53b7f5cb5f..d00d875669 100644 --- a/keystore/src/traits/key_type.rs +++ b/keystore/src/traits/key_type.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; /// /// This might be a primary key, in which case the key uniquely identifies either 0 or 1 entries in the database. /// Or it might be a search key, in which case the key could match any number of entries. -pub trait KeyType: Send + Sync { +pub trait KeyType: Send + Sync + Sized { /// Get a unique binary representation of this key. /// /// For simple keys it can just be the borrowed form of the key itself, @@ -12,35 +12,52 @@ pub trait KeyType: Send + Sync { fn bytes(&self) -> Cow<'_, [u8]>; } -// useful for unique entities; non-allocating -impl KeyType for () { - fn bytes(&self) -> Cow<'_, [u8]> { - Vec::new().into() - } +/// An owned key type can be converted to from arbitrary bytes. +pub trait OwnedKeyType: 'static + KeyType { + /// Parse some bytes into an instance of this type. + /// + /// We're just going with `Option` instead of `CryptoKeystoreResult` for now because + /// the hopeful assumption is that this is going to be a rare occurrence that doesn't + /// need much explanation. + fn from_bytes(bytes: &[u8]) -> Option; } macro_rules! impl_keytype { - ($t:ty, |$self:ident| $impl:expr) => { + ($t:ty, |$self:ident| $bytes:expr) => { impl KeyType for $t { fn bytes(&$self) -> Cow<'_, [u8]> { - $impl.into() + $bytes.into() } } }; + ($t:ty, |$self:ident| $bytes:expr, |$bytes_id:ident| $from_bytes:expr) => { + impl_keytype!($t, |$self| $bytes); + + impl OwnedKeyType for $t { + fn from_bytes($bytes_id: &[u8]) -> Option { + $from_bytes + } + } + }; + + } +// useful for unique entities; non-allocating +impl_keytype!((), |self| Vec::new(), |bytes| bytes.is_empty().then_some(())); impl_keytype!(&[u8], |self| *self); -impl_keytype!(Vec, |self| self.as_slice()); +impl_keytype!(Vec, |self| self.as_slice(), |bytes| Some(bytes.into())); impl_keytype!(&str, |self| self.as_bytes()); -impl_keytype!(String, |self| self.as_bytes()); +impl_keytype!(String, |self| self.as_bytes(), |bytes| str::from_utf8(bytes) + .ok() + .map(ToOwned::to_owned)); macro_rules! impl_keytype_for_integer { ($t:ty) => { - impl KeyType for $t { - fn bytes(&self) -> Cow<'_, [u8]> { - Vec::from(self.to_le_bytes()).into() - } - } + impl_keytype!($t, |self| Vec::from(self.to_le_bytes()), |bytes| { + let array = bytes.try_into().ok()?; + Some(<$t>::from_le_bytes(array)) + }); }; } @@ -56,8 +73,4 @@ impl_keytype_for_integer!(i64); impl_keytype_for_integer!(i128); /// Some unique entities use a single byte as a key type -impl KeyType for [u8; 1] { - fn bytes(&self) -> Cow<'_, [u8]> { - self.into() - } -} +impl_keytype!([u8; 1], |self| self, |bytes| bytes.try_into().ok()); diff --git a/keystore/src/traits/mod.rs b/keystore/src/traits/mod.rs index 5068dba59e..e3aff4b4b5 100644 --- a/keystore/src/traits/mod.rs +++ b/keystore/src/traits/mod.rs @@ -9,12 +9,17 @@ mod entity_database_mutation; mod fetch_from_database; mod item_encryption; mod key_type; +mod primary_key; mod unique_entity; -pub use entity::{BorrowPrimaryKey, Entity}; +pub use entity::{Entity, EntityGetBorrowed}; pub use entity_base::EntityBase; pub use entity_database_mutation::{EntityDatabaseMutation, EntityDeleteBorrowed}; pub use fetch_from_database::FetchFromDatabase; -pub use item_encryption::{DecryptData, Decryptable, Decrypting, EncryptData, Encrypting}; -pub use key_type::KeyType; +pub use item_encryption::{ + DecryptData, DecryptWithExplicitEncryptionKey, Decryptable, Decrypting, EncryptData, + EncryptWithExplicitEncryptionKey, Encrypting, EncryptionKey, +}; +pub use key_type::{KeyType, OwnedKeyType}; +pub use primary_key::{BorrowPrimaryKey, PrimaryKey}; pub use unique_entity::{UniqueEntity, UniqueEntityExt, UniqueEntityImplementationHelper}; diff --git a/keystore/src/traits/primary_key.rs b/keystore/src/traits/primary_key.rs new file mode 100644 index 0000000000..8aa218812b --- /dev/null +++ b/keystore/src/traits/primary_key.rs @@ -0,0 +1,30 @@ +use async_trait::async_trait; + +use crate::traits::OwnedKeyType; + +/// Something which has a distinct primary key which can uniquely identify it. +pub trait PrimaryKey { + /// Each distinct `PrimaryKey` uniquely identifies either 0 or 1 instance. + /// + /// This constraint should be enforced at the DB level. + type PrimaryKey: OwnedKeyType; + + /// Get this entity's primary key. + /// + /// This must return an owned type, because there are some entities for which only owned primary keys are possible. + /// However, entities which have primary keys owned within the entity itself should consider also implementing + /// [`BorrowPrimaryKey`] for greater efficiency. + fn primary_key(&self) -> Self::PrimaryKey; +} + +/// Something whose primary key can be borrowed as a distinct type. +/// +/// i.e. `String`, `Vec`, etc. +pub trait BorrowPrimaryKey: PrimaryKey { + type BorrowedPrimaryKey: ?Sized + ToOwned; + + /// Borrow this entity's primary key without copying any data. + /// + /// This borrowed key has a lifetime tied to that of this entity. + fn borrow_primary_key(&self) -> &Self::BorrowedPrimaryKey; +} diff --git a/keystore/src/traits/unique_entity.rs b/keystore/src/traits/unique_entity.rs index 3560081713..a0dc231f00 100644 --- a/keystore/src/traits/unique_entity.rs +++ b/keystore/src/traits/unique_entity.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; #[cfg(not(target_family = "wasm"))] -use rusqlite::{OptionalExtension as _, params}; +use rusqlite::{OptionalExtension as _, ToSql, params}; #[cfg(target_family = "wasm")] use serde::de::DeserializeOwned; @@ -10,14 +10,16 @@ use crate::entities::{count_helper, count_helper_tx, delete_helper, load_all_hel use crate::traits::{Decryptable, Decrypting, Encrypting, KeyType as _}; use crate::{ CryptoKeystoreResult, - connection::TransactionWrapper, - traits::{Entity, EntityBase, entity_database_mutation::EntityDatabaseMutation}, + connection::{KeystoreDatabaseConnection, TransactionWrapper}, + traits::{Entity, EntityBase, PrimaryKey, entity_database_mutation::EntityDatabaseMutation}, }; /// A unique entity can appear either 0 or 1 times in the database. -pub trait UniqueEntity: EntityBase + Entity { +pub trait UniqueEntity: + EntityBase + PrimaryKey +{ /// The id used as they key when storing this entity in a KV store. - const KEY: ::PrimaryKey; + const KEY: Self::PrimaryKey; } /// Unique entities get some convenience methods implemented automatically. @@ -42,8 +44,49 @@ pub trait UniqueEntityExt<'a>: UniqueEntity + EntityDatabaseMutation<'a> { async fn exists(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult; } -#[cfg_attr(target_family = "wasm", async_trait(?Send))] -#[cfg_attr(not(target_family = "wasm"), async_trait)] +// unfortunately we have to implement this trait twice, with nearly-identical but distinct bounds + +#[cfg(target_family = "wasm")] +#[async_trait(?Send)] +impl<'a, E> UniqueEntityExt<'a> for E +where + E: UniqueEntity + EntityDatabaseMutation<'a> + Sync, +{ + /// Get this unique entity from the database. + async fn get_unique(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { + Self::get(conn, &Self::KEY).await + } + + /// Set this unique entity into the database, replacing it if it already exists. + /// + /// Returns `true` if the entity previously existed and was replaced, or + /// `false` if it was not removed and this was a pure insertion. + async fn set_and_replace(&'a self, tx: &Self::Transaction) -> CryptoKeystoreResult { + let deleted = Self::delete(tx, &Self::KEY).await?; + self.save(tx).await?; + Ok(deleted) + } + + /// Set this unique entity into the database if it does not already exist. + /// + /// Returns `true` if the entity was saved, or `false` if it aborted due to an already-existing entity. + async fn set_if_absent(&'a self, tx: &Self::Transaction) -> CryptoKeystoreResult { + let count = ::count(tx).await?; + if count > 0 { + return Ok(false); + } + self.save(tx).await?; + Ok(true) + } + + /// Returns whether or not the database contains an instance of this unique entity. + async fn exists(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult { + ::count(conn).await.map(|count| count > 0) + } +} + +#[cfg(not(target_family = "wasm"))] +#[async_trait] impl<'a, E> UniqueEntityExt<'a> for E where E: UniqueEntity + EntityDatabaseMutation<'a> + Sync, @@ -96,6 +139,7 @@ where /// /// If you implement this trait, you get the following traits auto-implemented: /// +/// - `PrimaryKey` /// - `UniqueEntity` /// - `Entity` /// - `EntityDatabaseMutation` @@ -110,13 +154,29 @@ pub trait UniqueEntityImplementationHelper { fn content(&self) -> &[u8]; } +impl PrimaryKey for T +where + T: EntityBase + UniqueEntityImplementationHelper, +{ + // The old keystore trait used usize as the primary key type, but that would vary + // in width across various implementations and so is intentionally not a `KeyType`. + // So we distinguish betwen `u32` and `u64` according to whether or not we're on wasm. + #[cfg(target_family = "wasm")] + type PrimaryKey = u32; + #[cfg(not(target_family = "wasm"))] + type PrimaryKey = u64; + + fn primary_key(&self) -> Self::PrimaryKey { + Self::KEY + } +} + #[cfg(target_family = "wasm")] -#[async_trait(?Send)] impl UniqueEntity for T where T: EntityBase + UniqueEntityImplementationHelper - + Entity, + + PrimaryKey, { const KEY: u32 = 0; } @@ -131,15 +191,6 @@ where + Decryptable<'static>, >::DecryptableFrom: DeserializeOwned, { - // The old trait used usize as the primary key type, but that would vary - // in width across various implementations and so is intentionally not a `KeyType`. - // Instead we use `u32` which should be the same width on wasm. - type PrimaryKey = u32; - - fn primary_key(&self) -> Self::PrimaryKey { - Self::KEY - } - async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { conn.storage().new_get(key.bytes().as_ref()).await } @@ -175,16 +226,17 @@ where tx.new_count::().await } - async fn delete(tx: &Self::Transaction, id: &::PrimaryKey) -> CryptoKeystoreResult { + async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult { tx.new_delete::(id.bytes().as_ref()).await } } #[cfg(not(target_family = "wasm"))] -#[async_trait] impl UniqueEntity for T where - T: EntityBase + UniqueEntityImplementationHelper, + T: EntityBase + + UniqueEntityImplementationHelper + + PrimaryKey, { const KEY: u64 = 0; } @@ -193,26 +245,19 @@ where #[async_trait] impl Entity for T where - T: EntityBase + UniqueEntityImplementationHelper, + T: EntityBase + + PrimaryKey + + UniqueEntityImplementationHelper, + ::PrimaryKey: ToSql, { - // The old trait used usize as the primary key type, not u64, but that would vary - // in width across various implementations and so is intentionally not a `KeyType`. - // Instead we use `u64` which should be the same width on the expected runtimes - // for non-wasm. - type PrimaryKey = u64; - - fn primary_key(&self) -> Self::PrimaryKey { - Self::KEY - } - async fn get(conn: &mut Self::ConnectionType, key: &Self::PrimaryKey) -> CryptoKeystoreResult> { let conn = conn.conn().await; let mut statement = conn.prepare_cached(&format!( - "SELECT * FROM {collection_name} WHERE id = ?", + "SELECT content FROM {collection_name} WHERE id = ?", collection_name = Self::COLLECTION_NAME ))?; statement - .query_row([key], |row| Ok(Self::new(row.get(0)?))) + .query_row([key], |row| Ok(Self::new(row.get("content")?))) .optional() .map_err(Into::into) } @@ -224,7 +269,7 @@ where /// Retrieve all entities of this type from the database. async fn load_all(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult> { - load_all_helper::(conn, |row| Ok(Self::new(row.get(0)?))).await + load_all_helper::(conn, |row| Ok(Self::new(row.get("content")?))).await } } @@ -233,8 +278,10 @@ where impl<'a, T> EntityDatabaseMutation<'a> for T where T: EntityBase + + UniqueEntity + UniqueEntityImplementationHelper + Sync, + ::PrimaryKey: ToSql, { type Transaction = TransactionWrapper<'a>; @@ -251,7 +298,7 @@ where count_helper_tx::(tx).await } - async fn delete(tx: &Self::Transaction, id: &::PrimaryKey) -> CryptoKeystoreResult { + async fn delete(tx: &Self::Transaction, id: &Self::PrimaryKey) -> CryptoKeystoreResult { delete_helper::(tx, "id", id).await } } diff --git a/keystore/src/transaction/dynamic_dispatch.rs b/keystore/src/transaction/dynamic_dispatch.rs deleted file mode 100644 index df993237ef..0000000000 --- a/keystore/src/transaction/dynamic_dispatch.rs +++ /dev/null @@ -1,211 +0,0 @@ -//! This module exists merely because the `Entity` trait is not object safe. -//! See . - -#[cfg(target_family = "wasm")] -use crate::entities::E2eiRefreshToken; -#[cfg(feature = "proteus-keystore")] -use crate::entities::{ProteusIdentity, ProteusPrekey, ProteusSession}; -use crate::{ - CryptoKeystoreError, CryptoKeystoreResult, - connection::TransactionWrapper, - entities::{ - ConsumerData, E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, EntityBase, EntityTransactionExt, MlsPendingMessage, - PersistedMlsGroup, PersistedMlsPendingGroup, StoredBufferedCommit, StoredCredential, StoredE2eiEnrollment, - StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, - StringEntityId, UniqueEntity, - }, -}; - -#[derive(Debug)] -pub enum Entity { - ConsumerData(ConsumerData), - HpkePrivateKey(StoredHpkePrivateKey), - StoredKeypackage(StoredKeypackage), - PskBundle(StoredPskBundle), - EncryptionKeyPair(StoredEncryptionKeyPair), - StoredEpochEncryptionKeypair(StoredEpochEncryptionKeypair), - StoredCredential(StoredCredential), - StoredBufferedCommit(StoredBufferedCommit), - PersistedMlsGroup(PersistedMlsGroup), - PersistedMlsPendingGroup(PersistedMlsPendingGroup), - MlsPendingMessage(MlsPendingMessage), - StoredE2eiEnrollment(StoredE2eiEnrollment), - #[cfg(target_family = "wasm")] - E2eiRefreshToken(E2eiRefreshToken), - E2eiAcmeCA(E2eiAcmeCA), - E2eiIntermediateCert(E2eiIntermediateCert), - E2eiCrl(E2eiCrl), - #[cfg(feature = "proteus-keystore")] - ProteusIdentity(ProteusIdentity), - #[cfg(feature = "proteus-keystore")] - ProteusPrekey(ProteusPrekey), - #[cfg(feature = "proteus-keystore")] - ProteusSession(ProteusSession), -} - -#[derive(Debug, Clone, PartialEq)] -pub enum EntityId { - HpkePrivateKey(Vec), - KeyPackage(Vec), - PskBundle(Vec), - EncryptionKeyPair(Vec), - EpochEncryptionKeyPair(Vec), - StoredCredential(Vec), - StoredBufferedCommit(Vec), - PersistedMlsGroup(Vec), - PersistedMlsPendingGroup(Vec), - MlsPendingMessage(Vec), - StoredE2eiEnrollment(Vec), - #[cfg(target_family = "wasm")] - E2eiRefreshToken(Vec), - E2eiAcmeCA(Vec), - E2eiIntermediateCert(Vec), - E2eiCrl(Vec), - #[cfg(feature = "proteus-keystore")] - ProteusIdentity(Vec), - #[cfg(feature = "proteus-keystore")] - ProteusPrekey(Vec), - #[cfg(feature = "proteus-keystore")] - ProteusSession(Vec), -} - -impl EntityId { - fn as_id(&self) -> StringEntityId<'_> { - match self { - EntityId::HpkePrivateKey(vec) => vec.as_slice().into(), - EntityId::KeyPackage(vec) => vec.as_slice().into(), - EntityId::PskBundle(vec) => vec.as_slice().into(), - EntityId::EncryptionKeyPair(vec) => vec.as_slice().into(), - EntityId::EpochEncryptionKeyPair(vec) => vec.as_slice().into(), - EntityId::StoredCredential(vec) => vec.as_slice().into(), - EntityId::StoredBufferedCommit(vec) => vec.as_slice().into(), - EntityId::PersistedMlsGroup(vec) => vec.as_slice().into(), - EntityId::PersistedMlsPendingGroup(vec) => vec.as_slice().into(), - EntityId::MlsPendingMessage(vec) => vec.as_slice().into(), - EntityId::StoredE2eiEnrollment(vec) => vec.as_slice().into(), - #[cfg(target_family = "wasm")] - EntityId::E2eiRefreshToken(vec) => vec.as_slice().into(), - EntityId::E2eiAcmeCA(vec) => vec.as_slice().into(), - EntityId::E2eiIntermediateCert(vec) => vec.as_slice().into(), - EntityId::E2eiCrl(vec) => vec.as_slice().into(), - #[cfg(feature = "proteus-keystore")] - EntityId::ProteusIdentity(vec) => vec.as_slice().into(), - #[cfg(feature = "proteus-keystore")] - EntityId::ProteusSession(id) => id.as_slice().into(), - #[cfg(feature = "proteus-keystore")] - EntityId::ProteusPrekey(vec) => vec.as_slice().into(), - } - } - - pub(crate) fn from_collection_name(entity_id: &'static str, id: &[u8]) -> CryptoKeystoreResult { - match entity_id { - StoredHpkePrivateKey::COLLECTION_NAME => Ok(Self::HpkePrivateKey(id.into())), - StoredKeypackage::COLLECTION_NAME => Ok(Self::KeyPackage(id.into())), - StoredPskBundle::COLLECTION_NAME => Ok(Self::PskBundle(id.into())), - StoredEncryptionKeyPair::COLLECTION_NAME => Ok(Self::EncryptionKeyPair(id.into())), - StoredEpochEncryptionKeypair::COLLECTION_NAME => Ok(Self::EpochEncryptionKeyPair(id.into())), - StoredBufferedCommit::COLLECTION_NAME => Ok(Self::StoredBufferedCommit(id.into())), - PersistedMlsGroup::COLLECTION_NAME => Ok(Self::PersistedMlsGroup(id.into())), - PersistedMlsPendingGroup::COLLECTION_NAME => Ok(Self::PersistedMlsPendingGroup(id.into())), - StoredCredential::COLLECTION_NAME => Ok(Self::StoredCredential(id.into())), - MlsPendingMessage::COLLECTION_NAME => Ok(Self::MlsPendingMessage(id.into())), - StoredE2eiEnrollment::COLLECTION_NAME => Ok(Self::StoredE2eiEnrollment(id.into())), - E2eiCrl::COLLECTION_NAME => Ok(Self::E2eiCrl(id.into())), - E2eiAcmeCA::COLLECTION_NAME => Ok(Self::E2eiAcmeCA(id.into())), - #[cfg(target_family = "wasm")] - E2eiRefreshToken::COLLECTION_NAME => Ok(Self::E2eiRefreshToken(id.into())), - E2eiIntermediateCert::COLLECTION_NAME => Ok(Self::E2eiIntermediateCert(id.into())), - #[cfg(feature = "proteus-keystore")] - ProteusIdentity::COLLECTION_NAME => Ok(Self::ProteusIdentity(id.into())), - #[cfg(feature = "proteus-keystore")] - ProteusPrekey::COLLECTION_NAME => Ok(Self::ProteusPrekey(id.into())), - #[cfg(feature = "proteus-keystore")] - ProteusSession::COLLECTION_NAME => Ok(Self::ProteusSession(id.into())), - _ => Err(CryptoKeystoreError::NotImplemented), - } - } - - pub(crate) fn collection_name(&self) -> &'static str { - match self { - EntityId::KeyPackage(_) => StoredKeypackage::COLLECTION_NAME, - EntityId::PskBundle(_) => StoredPskBundle::COLLECTION_NAME, - EntityId::EncryptionKeyPair(_) => StoredEncryptionKeyPair::COLLECTION_NAME, - EntityId::EpochEncryptionKeyPair(_) => StoredEpochEncryptionKeypair::COLLECTION_NAME, - EntityId::StoredCredential(_) => StoredCredential::COLLECTION_NAME, - EntityId::StoredBufferedCommit(_) => StoredBufferedCommit::COLLECTION_NAME, - EntityId::PersistedMlsGroup(_) => PersistedMlsGroup::COLLECTION_NAME, - EntityId::PersistedMlsPendingGroup(_) => PersistedMlsPendingGroup::COLLECTION_NAME, - EntityId::MlsPendingMessage(_) => MlsPendingMessage::COLLECTION_NAME, - EntityId::StoredE2eiEnrollment(_) => StoredE2eiEnrollment::COLLECTION_NAME, - #[cfg(target_family = "wasm")] - EntityId::E2eiRefreshToken(_) => E2eiRefreshToken::COLLECTION_NAME, - EntityId::E2eiAcmeCA(_) => E2eiAcmeCA::COLLECTION_NAME, - EntityId::E2eiIntermediateCert(_) => E2eiIntermediateCert::COLLECTION_NAME, - EntityId::E2eiCrl(_) => E2eiCrl::COLLECTION_NAME, - #[cfg(feature = "proteus-keystore")] - EntityId::ProteusIdentity(_) => ProteusIdentity::COLLECTION_NAME, - #[cfg(feature = "proteus-keystore")] - EntityId::ProteusPrekey(_) => ProteusPrekey::COLLECTION_NAME, - #[cfg(feature = "proteus-keystore")] - EntityId::ProteusSession(_) => ProteusSession::COLLECTION_NAME, - EntityId::HpkePrivateKey(_) => StoredHpkePrivateKey::COLLECTION_NAME, - } - } -} - -pub async fn execute_save(tx: &TransactionWrapper<'_>, entity: &Entity) -> CryptoKeystoreResult<()> { - match entity { - Entity::ConsumerData(consumer_data) => consumer_data.replace(tx).await, - Entity::HpkePrivateKey(mls_hpke_private_key) => mls_hpke_private_key.save(tx).await, - Entity::StoredKeypackage(mls_key_package) => mls_key_package.save(tx).await, - Entity::PskBundle(mls_psk_bundle) => mls_psk_bundle.save(tx).await, - Entity::EncryptionKeyPair(mls_encryption_key_pair) => mls_encryption_key_pair.save(tx).await, - Entity::StoredEpochEncryptionKeypair(mls_epoch_encryption_key_pair) => { - mls_epoch_encryption_key_pair.save(tx).await - } - Entity::StoredCredential(mls_credential) => mls_credential.save(tx).await, - Entity::StoredBufferedCommit(mls_pending_commit) => mls_pending_commit.save(tx).await, - Entity::PersistedMlsGroup(persisted_mls_group) => persisted_mls_group.save(tx).await, - Entity::PersistedMlsPendingGroup(persisted_mls_pending_group) => persisted_mls_pending_group.save(tx).await, - Entity::MlsPendingMessage(mls_pending_message) => mls_pending_message.save(tx).await, - Entity::StoredE2eiEnrollment(e2ei_enrollment) => e2ei_enrollment.save(tx).await, - #[cfg(target_family = "wasm")] - Entity::E2eiRefreshToken(e2ei_refresh_token) => e2ei_refresh_token.replace(tx).await, - Entity::E2eiAcmeCA(e2ei_acme_ca) => e2ei_acme_ca.replace(tx).await, - Entity::E2eiIntermediateCert(e2ei_intermediate_cert) => e2ei_intermediate_cert.save(tx).await, - Entity::E2eiCrl(e2ei_crl) => e2ei_crl.save(tx).await, - #[cfg(feature = "proteus-keystore")] - Entity::ProteusSession(record) => record.save(tx).await, - #[cfg(feature = "proteus-keystore")] - Entity::ProteusIdentity(record) => record.save(tx).await, - #[cfg(feature = "proteus-keystore")] - Entity::ProteusPrekey(record) => record.save(tx).await, - } -} - -pub async fn execute_delete(tx: &TransactionWrapper<'_>, entity_id: &EntityId) -> CryptoKeystoreResult<()> { - match entity_id { - id @ EntityId::HpkePrivateKey(_) => StoredHpkePrivateKey::delete(tx, id.as_id()).await, - id @ EntityId::KeyPackage(_) => StoredKeypackage::delete(tx, id.as_id()).await, - id @ EntityId::PskBundle(_) => StoredPskBundle::delete(tx, id.as_id()).await, - id @ EntityId::EncryptionKeyPair(_) => StoredEncryptionKeyPair::delete(tx, id.as_id()).await, - id @ EntityId::EpochEncryptionKeyPair(_) => StoredEpochEncryptionKeypair::delete(tx, id.as_id()).await, - id @ EntityId::StoredCredential(_) => StoredCredential::delete(tx, id.as_id()).await, - id @ EntityId::StoredBufferedCommit(_) => StoredBufferedCommit::delete(tx, id.as_id()).await, - id @ EntityId::PersistedMlsGroup(_) => PersistedMlsGroup::delete(tx, id.as_id()).await, - id @ EntityId::PersistedMlsPendingGroup(_) => PersistedMlsPendingGroup::delete(tx, id.as_id()).await, - id @ EntityId::MlsPendingMessage(_) => MlsPendingMessage::delete(tx, id.as_id()).await, - id @ EntityId::StoredE2eiEnrollment(_) => StoredE2eiEnrollment::delete(tx, id.as_id()).await, - #[cfg(target_family = "wasm")] - id @ EntityId::E2eiRefreshToken(_) => E2eiRefreshToken::delete(tx, id.as_id()).await, - id @ EntityId::E2eiAcmeCA(_) => E2eiAcmeCA::delete(tx, id.as_id()).await, - id @ EntityId::E2eiIntermediateCert(_) => E2eiIntermediateCert::delete(tx, id.as_id()).await, - id @ EntityId::E2eiCrl(_) => E2eiCrl::delete(tx, id.as_id()).await, - #[cfg(feature = "proteus-keystore")] - id @ EntityId::ProteusSession(_) => ProteusSession::delete(tx, id.as_id()).await, - #[cfg(feature = "proteus-keystore")] - id @ EntityId::ProteusIdentity(_) => ProteusIdentity::delete(tx, id.as_id()).await, - #[cfg(feature = "proteus-keystore")] - id @ EntityId::ProteusPrekey(_) => ProteusPrekey::delete(tx, id.as_id()).await, - } -} diff --git a/keystore/src/transaction/dynamic_dispatch/entity.rs b/keystore/src/transaction/dynamic_dispatch/entity.rs new file mode 100644 index 0000000000..2009484930 --- /dev/null +++ b/keystore/src/transaction/dynamic_dispatch/entity.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +#[cfg(target_family = "wasm")] +use crate::entities::E2eiRefreshToken; +#[cfg(feature = "proteus-keystore")] +use crate::entities::{ProteusIdentity, ProteusPrekey, ProteusSession}; +use crate::{ + CryptoKeystoreResult, + connection::TransactionWrapper, + entities::{ + ConsumerData, E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, MlsPendingMessage, PersistedMlsGroup, + PersistedMlsPendingGroup, StoredBufferedCommit, StoredCredential, StoredE2eiEnrollment, + StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + }, + traits::{EntityBase, EntityDatabaseMutation as _, UniqueEntityExt as _}, +}; + +#[derive(Debug)] +pub enum Entity { + ConsumerData(Arc), + HpkePrivateKey(Arc), + StoredKeypackage(Arc), + PskBundle(Arc), + EncryptionKeyPair(Arc), + StoredEpochEncryptionKeypair(Arc), + StoredCredential(Arc), + StoredBufferedCommit(Arc), + PersistedMlsGroup(Arc), + PersistedMlsPendingGroup(Arc), + MlsPendingMessage(Arc), + StoredE2eiEnrollment(Arc), + #[cfg(target_family = "wasm")] + E2eiRefreshToken(Arc), + E2eiAcmeCA(Arc), + E2eiIntermediateCert(Arc), + E2eiCrl(Arc), + #[cfg(feature = "proteus-keystore")] + ProteusIdentity(Arc), + #[cfg(feature = "proteus-keystore")] + ProteusPrekey(Arc), + #[cfg(feature = "proteus-keystore")] + ProteusSession(Arc), +} + +impl Entity { + /// Downcast this entity to an instance of the requested type. + /// + /// This increments the smart pointer counter instead of cloning the potentially large item instance. + pub(crate) fn downcast(&self) -> Option> + where + E: EntityBase + Send + Sync, + { + match self { + Entity::ConsumerData(consumer_data) => consumer_data.clone().downcast_arc(), + Entity::HpkePrivateKey(stored_hpke_private_key) => stored_hpke_private_key.clone().downcast_arc(), + Entity::StoredKeypackage(stored_keypackage) => stored_keypackage.clone().downcast_arc(), + Entity::PskBundle(stored_psk_bundle) => stored_psk_bundle.clone().downcast_arc(), + Entity::EncryptionKeyPair(stored_encryption_key_pair) => stored_encryption_key_pair.clone().downcast_arc(), + Entity::StoredEpochEncryptionKeypair(stored_epoch_encryption_keypair) => { + stored_epoch_encryption_keypair.clone().downcast_arc() + } + Entity::StoredCredential(stored_credential) => stored_credential.clone().downcast_arc(), + Entity::StoredBufferedCommit(stored_buffered_commit) => stored_buffered_commit.clone().downcast_arc(), + Entity::PersistedMlsGroup(persisted_mls_group) => persisted_mls_group.clone().downcast_arc(), + Entity::PersistedMlsPendingGroup(persisted_mls_pending_group) => { + persisted_mls_pending_group.clone().downcast_arc() + } + Entity::MlsPendingMessage(mls_pending_message) => mls_pending_message.clone().downcast_arc(), + Entity::StoredE2eiEnrollment(stored_e2ei_enrollment) => stored_e2ei_enrollment.clone().downcast_arc(), + Entity::E2eiAcmeCA(e2ei_acme_ca) => e2ei_acme_ca.clone().downcast_arc(), + Entity::E2eiIntermediateCert(e2ei_intermediate_cert) => e2ei_intermediate_cert.clone().downcast_arc(), + Entity::E2eiCrl(e2ei_crl) => e2ei_crl.clone().downcast_arc(), + #[cfg(target_family = "wasm")] + Entity::E2eiRefreshToken(e2ei_refresh_token) => e2ei_refresh_token.clone().downcast_arc(), + #[cfg(feature = "proteus-keystore")] + Entity::ProteusIdentity(proteus_identity) => proteus_identity.clone().downcast_arc(), + #[cfg(feature = "proteus-keystore")] + Entity::ProteusPrekey(proteus_prekey) => proteus_prekey.clone().downcast_arc(), + #[cfg(feature = "proteus-keystore")] + Entity::ProteusSession(proteus_session) => proteus_session.clone().downcast_arc(), + } + } + + pub(crate) async fn execute_save(&self, tx: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> { + match self { + Entity::ConsumerData(consumer_data) => consumer_data.set_and_replace(tx).await.map(|_| ()), + Entity::HpkePrivateKey(mls_hpke_private_key) => mls_hpke_private_key.save(tx).await, + Entity::StoredKeypackage(mls_key_package) => mls_key_package.save(tx).await, + Entity::PskBundle(mls_psk_bundle) => mls_psk_bundle.save(tx).await, + Entity::EncryptionKeyPair(mls_encryption_key_pair) => mls_encryption_key_pair.save(tx).await, + Entity::StoredEpochEncryptionKeypair(mls_epoch_encryption_key_pair) => { + mls_epoch_encryption_key_pair.save(tx).await + } + Entity::StoredCredential(mls_credential) => mls_credential.save(tx).await, + Entity::StoredBufferedCommit(mls_pending_commit) => mls_pending_commit.save(tx).await, + Entity::PersistedMlsGroup(persisted_mls_group) => persisted_mls_group.save(tx).await, + Entity::PersistedMlsPendingGroup(persisted_mls_pending_group) => persisted_mls_pending_group.save(tx).await, + Entity::MlsPendingMessage(mls_pending_message) => mls_pending_message.save(tx).await, + Entity::StoredE2eiEnrollment(e2ei_enrollment) => e2ei_enrollment.save(tx).await, + Entity::E2eiAcmeCA(e2ei_acme_ca) => e2ei_acme_ca.set_and_replace(tx).await.map(|_| ()), + Entity::E2eiIntermediateCert(e2ei_intermediate_cert) => e2ei_intermediate_cert.save(tx).await, + #[cfg(target_family = "wasm")] + Entity::E2eiRefreshToken(e2ei_refresh_token) => e2ei_refresh_token.set_and_replace(tx).await.map(|_| ()), + Entity::E2eiCrl(e2ei_crl) => e2ei_crl.save(tx).await, + #[cfg(feature = "proteus-keystore")] + Entity::ProteusSession(record) => record.save(tx).await, + #[cfg(feature = "proteus-keystore")] + Entity::ProteusIdentity(record) => record.save(tx).await, + #[cfg(feature = "proteus-keystore")] + Entity::ProteusPrekey(record) => record.save(tx).await, + } + } +} diff --git a/keystore/src/transaction/dynamic_dispatch/entity_id.rs b/keystore/src/transaction/dynamic_dispatch/entity_id.rs new file mode 100644 index 0000000000..b8525b37d0 --- /dev/null +++ b/keystore/src/transaction/dynamic_dispatch/entity_id.rs @@ -0,0 +1,126 @@ +use core::fmt; +use std::borrow::Cow; + +#[cfg(target_family = "wasm")] +use crate::entities::E2eiRefreshToken; +#[cfg(feature = "proteus-keystore")] +use crate::entities::{ProteusIdentity, ProteusPrekey, ProteusSession}; +use crate::{ + CryptoKeystoreError, CryptoKeystoreResult, + connection::TransactionWrapper, + entities::{ + ConsumerData, E2eiCrl, E2eiIntermediateCert, MlsPendingMessage, PersistedMlsGroup, PersistedMlsPendingGroup, + StoredBufferedCommit, StoredCredential, StoredE2eiEnrollment, StoredEncryptionKeyPair, + StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + }, + traits::{BorrowPrimaryKey, Entity, EntityDatabaseMutation, KeyType, OwnedKeyType as _}, + transaction::dynamic_dispatch::EntityType, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct EntityId { + typ: EntityType, + id: Vec, +} + +impl fmt::Display for EntityId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { typ, id } = self; + write!(f, "{typ:?}: {}", hex::encode(id)) + } +} + +impl EntityId { + fn primary_key(&self) -> CryptoKeystoreResult + where + E: Entity, + { + E::PrimaryKey::from_bytes(&self.id) + .ok_or(CryptoKeystoreError::InvalidPrimaryKeyBytes(self.typ.collection_name())) + } + + fn from_key(primary_key: Cow<'_, [u8]>) -> Option + where + E: Entity, + { + let typ = EntityType::from_collection_name(E::COLLECTION_NAME)?; + let id = primary_key.into_owned(); + Some(Self { typ, id }) + } + + pub(crate) fn from_entity(entity: &E) -> Option + where + E: Entity, + { + Self::from_key::(entity.primary_key().bytes()) + } + + pub(crate) fn from_primary_key(primary_key: &E::PrimaryKey) -> Option + where + E: Entity, + { + Self::from_key::(primary_key.bytes()) + } + + pub(crate) fn from_borrowed_primary_key(primary_key: &E::BorrowedPrimaryKey) -> Option + where + E: Entity + BorrowPrimaryKey, + { + Self::from_key::(primary_key.to_owned().bytes()) + } + + pub(crate) fn collection_name(&self) -> &'static str { + self.typ.collection_name() + } + + pub(crate) async fn execute_delete(&self, tx: &TransactionWrapper<'_>) -> CryptoKeystoreResult { + match self.typ { + EntityType::HpkePrivateKey => { + StoredHpkePrivateKey::delete(tx, &self.primary_key::()?).await + } + EntityType::KeyPackage => StoredKeypackage::delete(tx, &self.primary_key::()?).await, + EntityType::PskBundle => StoredPskBundle::delete(tx, &self.primary_key::()?).await, + EntityType::EncryptionKeyPair => { + StoredEncryptionKeyPair::delete(tx, &self.primary_key::()?).await + } + EntityType::EpochEncryptionKeyPair => { + StoredEpochEncryptionKeypair::delete(tx, &self.primary_key::()?).await + } + EntityType::StoredCredential => { + StoredCredential::delete(tx, &self.primary_key::()?).await + } + EntityType::StoredBufferedCommit => { + StoredBufferedCommit::delete(tx, &self.primary_key::()?).await + } + EntityType::PersistedMlsGroup => { + PersistedMlsGroup::delete(tx, &self.primary_key::()?).await + } + EntityType::PersistedMlsPendingGroup => { + PersistedMlsPendingGroup::delete(tx, &self.primary_key::()?).await + } + EntityType::MlsPendingMessage => { + let primary_key = self.primary_key::()?; + MlsPendingMessage::delete_by_conversation_id(tx, &primary_key.foreign_id).await + } + EntityType::StoredE2eiEnrollment => { + StoredE2eiEnrollment::delete(tx, &self.primary_key::()?).await + } + #[cfg(target_family = "wasm")] + EntityType::E2eiRefreshToken => { + E2eiRefreshToken::delete(tx, &self.primary_key::()?).await + } + EntityType::E2eiAcmeCA => Err(CryptoKeystoreError::NotImplemented), + EntityType::E2eiIntermediateCert => { + E2eiIntermediateCert::delete(tx, &self.primary_key::()?).await + } + EntityType::E2eiCrl => E2eiCrl::delete(tx, &self.primary_key::()?).await, + #[cfg(feature = "proteus-keystore")] + EntityType::ProteusSession => ProteusSession::delete(tx, &self.primary_key::()?).await, + #[cfg(feature = "proteus-keystore")] + EntityType::ProteusIdentity => ProteusIdentity::delete(tx, &self.primary_key::()?).await, + #[cfg(feature = "proteus-keystore")] + EntityType::ProteusPrekey => ProteusPrekey::delete(tx, &self.primary_key::()?).await, + EntityType::ConsumerData => ConsumerData::delete(tx, &self.primary_key::()?).await, + } + } +} diff --git a/keystore/src/transaction/dynamic_dispatch/entity_type.rs b/keystore/src/transaction/dynamic_dispatch/entity_type.rs new file mode 100644 index 0000000000..8531f5535e --- /dev/null +++ b/keystore/src/transaction/dynamic_dispatch/entity_type.rs @@ -0,0 +1,98 @@ +#[cfg(target_family = "wasm")] +use crate::entities::E2eiRefreshToken; +#[cfg(feature = "proteus-keystore")] +use crate::entities::{ProteusIdentity, ProteusPrekey, ProteusSession}; +use crate::{ + entities::{ + ConsumerData, E2eiAcmeCA, E2eiCrl, E2eiIntermediateCert, MlsPendingMessage, PersistedMlsGroup, + PersistedMlsPendingGroup, StoredBufferedCommit, StoredCredential, StoredE2eiEnrollment, + StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, + }, + traits::EntityBase as _, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum EntityType { + HpkePrivateKey, + KeyPackage, + PskBundle, + EncryptionKeyPair, + EpochEncryptionKeyPair, + StoredCredential, + StoredBufferedCommit, + PersistedMlsGroup, + PersistedMlsPendingGroup, + MlsPendingMessage, + StoredE2eiEnrollment, + #[cfg(target_family = "wasm")] + E2eiRefreshToken, + E2eiAcmeCA, + E2eiIntermediateCert, + E2eiCrl, + #[cfg(feature = "proteus-keystore")] + ProteusIdentity, + #[cfg(feature = "proteus-keystore")] + ProteusPrekey, + #[cfg(feature = "proteus-keystore")] + ProteusSession, + ConsumerData, +} + +impl EntityType { + pub(crate) fn from_collection_name(collection_name: &'static str) -> Option { + match collection_name { + StoredHpkePrivateKey::COLLECTION_NAME => Some(Self::HpkePrivateKey), + StoredKeypackage::COLLECTION_NAME => Some(Self::KeyPackage), + StoredPskBundle::COLLECTION_NAME => Some(Self::PskBundle), + StoredEncryptionKeyPair::COLLECTION_NAME => Some(Self::EncryptionKeyPair), + StoredEpochEncryptionKeypair::COLLECTION_NAME => Some(Self::EpochEncryptionKeyPair), + StoredBufferedCommit::COLLECTION_NAME => Some(Self::StoredBufferedCommit), + PersistedMlsGroup::COLLECTION_NAME => Some(Self::PersistedMlsGroup), + PersistedMlsPendingGroup::COLLECTION_NAME => Some(Self::PersistedMlsPendingGroup), + StoredCredential::COLLECTION_NAME => Some(Self::StoredCredential), + MlsPendingMessage::COLLECTION_NAME => Some(Self::MlsPendingMessage), + StoredE2eiEnrollment::COLLECTION_NAME => Some(Self::StoredE2eiEnrollment), + E2eiCrl::COLLECTION_NAME => Some(Self::E2eiCrl), + E2eiAcmeCA::COLLECTION_NAME => Some(Self::E2eiAcmeCA), + #[cfg(target_family = "wasm")] + E2eiRefreshToken::COLLECTION_NAME => Some(Self::E2eiRefreshToken), + E2eiIntermediateCert::COLLECTION_NAME => Some(Self::E2eiIntermediateCert), + #[cfg(feature = "proteus-keystore")] + ProteusIdentity::COLLECTION_NAME => Some(Self::ProteusIdentity), + #[cfg(feature = "proteus-keystore")] + ProteusPrekey::COLLECTION_NAME => Some(Self::ProteusPrekey), + #[cfg(feature = "proteus-keystore")] + ProteusSession::COLLECTION_NAME => Some(Self::ProteusSession), + ConsumerData::COLLECTION_NAME => Some(Self::ConsumerData), + _ => None, + } + } + + pub(crate) fn collection_name(&self) -> &'static str { + match self { + Self::KeyPackage => StoredKeypackage::COLLECTION_NAME, + Self::PskBundle => StoredPskBundle::COLLECTION_NAME, + Self::EncryptionKeyPair => StoredEncryptionKeyPair::COLLECTION_NAME, + Self::EpochEncryptionKeyPair => StoredEpochEncryptionKeypair::COLLECTION_NAME, + Self::StoredCredential => StoredCredential::COLLECTION_NAME, + Self::StoredBufferedCommit => StoredBufferedCommit::COLLECTION_NAME, + Self::PersistedMlsGroup => PersistedMlsGroup::COLLECTION_NAME, + Self::PersistedMlsPendingGroup => PersistedMlsPendingGroup::COLLECTION_NAME, + Self::MlsPendingMessage => MlsPendingMessage::COLLECTION_NAME, + Self::StoredE2eiEnrollment => StoredE2eiEnrollment::COLLECTION_NAME, + #[cfg(target_family = "wasm")] + Self::E2eiRefreshToken => E2eiRefreshToken::COLLECTION_NAME, + Self::E2eiAcmeCA => E2eiAcmeCA::COLLECTION_NAME, + Self::E2eiIntermediateCert => E2eiIntermediateCert::COLLECTION_NAME, + Self::E2eiCrl => E2eiCrl::COLLECTION_NAME, + #[cfg(feature = "proteus-keystore")] + Self::ProteusIdentity => ProteusIdentity::COLLECTION_NAME, + #[cfg(feature = "proteus-keystore")] + Self::ProteusPrekey => ProteusPrekey::COLLECTION_NAME, + #[cfg(feature = "proteus-keystore")] + Self::ProteusSession => ProteusSession::COLLECTION_NAME, + Self::HpkePrivateKey => StoredHpkePrivateKey::COLLECTION_NAME, + Self::ConsumerData => ConsumerData::COLLECTION_NAME, + } + } +} diff --git a/keystore/src/transaction/dynamic_dispatch/mod.rs b/keystore/src/transaction/dynamic_dispatch/mod.rs new file mode 100644 index 0000000000..514a90c358 --- /dev/null +++ b/keystore/src/transaction/dynamic_dispatch/mod.rs @@ -0,0 +1,7 @@ +mod entity; +mod entity_id; +mod entity_type; + +pub(crate) use entity::Entity; +pub(crate) use entity_id::EntityId; +pub(crate) use entity_type::EntityType; diff --git a/keystore/src/transaction/mod.rs b/keystore/src/transaction/mod.rs index 0f955d9998..479c64b998 100644 --- a/keystore/src/transaction/mod.rs +++ b/keystore/src/transaction/mod.rs @@ -1,38 +1,40 @@ use std::{ - collections::{HashMap, hash_map::Entry}, + borrow::Cow, + collections::{HashMap, HashSet, hash_map::Entry}, sync::Arc, }; use async_lock::{RwLock, SemaphoreGuardArc}; use itertools::Itertools; -use zeroize::Zeroizing; -#[cfg(feature = "proteus-keystore")] -use crate::entities::proteus::*; use crate::{ CryptoKeystoreError, CryptoKeystoreResult, connection::{Database, KeystoreDatabaseConnection}, - entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity, mls::*}, + entities::{MlsPendingMessage, MlsPendingMessagePrimaryKey, PersistedMlsGroupExt}, + traits::{BorrowPrimaryKey, Entity, EntityBase as _, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType}, transaction::dynamic_dispatch::EntityId, }; -pub mod dynamic_dispatch; +pub(crate) mod dynamic_dispatch; -#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)] -struct InMemoryTable(HashMap, Zeroizing>>); - -type InMemoryCache = Arc>>; +/// table: primary key -> entity reference +type InMemoryTable = HashMap; +/// collection: collection name -> table +type InMemoryCollection = Arc>>; /// This represents a transaction, where all operations will be done in memory and committed at the /// end #[derive(Debug, Clone)] pub(crate) struct KeystoreTransaction { - cache: InMemoryCache, - deleted: Arc>>, + cache: InMemoryCollection, + deleted: Arc>>, _semaphore_guard: Arc, } impl KeystoreTransaction { + /// Instantiate a new transaction. + /// + /// Requires a semaphore guard to ensure that only one exists at a time. pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult { Ok(Self { cache: Default::default(), @@ -41,357 +43,309 @@ impl KeystoreTransaction { }) } - pub(crate) async fn save_mut< - E: crate::entities::Entity + EntityTransactionExt + Sync, - >( - &self, - mut entity: E, - ) -> CryptoKeystoreResult { - entity.pre_save().await?; - let mut cache_guard = self.cache.write().await; - let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default(); - let serialized = postcard::to_stdvec(&entity)?; - // Use merge_key() because `id_raw()` is not always unique for records. - // For `MlsPendingMessage` it's the id of the group it belongs to. - table.insert(entity.merge_key(), Zeroizing::new(serialized)); - Ok(entity) + /// Save an entity into this transaction. + /// + /// This is a multi-step process: + /// + /// - Adjust the entity by calling its [`pre_save()`][Entity::pre_save] method. + /// - Store the entity in an internal map. + /// - Remove the entity from the set of deleted entities, if it was there. + /// - On [`Self::commit`], actually persist the entity into the supplied database. + pub(crate) async fn save<'a, E>(&self, mut entity: E) -> CryptoKeystoreResult + where + E: Entity + EntityDatabaseMutation<'a> + Send + Sync, + { + let auto_generated_fields = entity.pre_save().await?; + + let entity_id = EntityId::from_entity(&entity).ok_or(CryptoKeystoreError::UnknownEntity(E::COLLECTION_NAME))?; + { + // start by adding the entity + let mut cache_guard = self.cache.write().await; + let table = cache_guard.entry(E::COLLECTION_NAME).or_default(); + table.insert(entity_id.clone(), entity.to_transaction_entity()); + } + { + // at this point remove the entity from the set of deleted entities to ensure that + // this new data gets propagated + let mut cache_guard = self.deleted.write().await; + cache_guard.remove(&entity_id); + } + + Ok(auto_generated_fields) } - pub(crate) async fn remove< - E: crate::entities::Entity + EntityTransactionExt, - S: AsRef<[u8]>, - >( - &self, - id: S, - ) -> CryptoKeystoreResult<()> { + async fn remove_by_entity_id<'a, E>(&self, entity_id: EntityId) -> CryptoKeystoreResult<()> + where + E: Entity + EntityDatabaseMutation<'a>, + { + // rm this entity from the set of added/modified items + // it might never touch the real db at all let mut cache_guard = self.cache.write().await; - if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string()) - && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec()) + if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME) + && let Entry::Occupied(cached_record) = table.get_mut().entry(entity_id.clone()) { cached_record.remove_entry(); }; - let mut deleted_list = self.deleted.write().await; - deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?); + // add this entity to the set of items which should be deleted from the persisted db + let mut deleted_set = self.deleted.write().await; + deleted_set.insert(entity_id); Ok(()) } - pub(crate) async fn child_groups(&self, entity: E, persisted_records: Vec) -> CryptoKeystoreResult> + /// Remove an entity by its primary key. + /// + /// Where the primary key has a distinct borrowed form, consider [`Self::remove_borrowed`]. + /// + /// Note that this doesn't return whether or not anything was actually removed because + /// that won't happen until the transaction is committed. + pub(crate) async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()> + where + E: Entity + EntityDatabaseMutation<'a>, + { + let entity_id = + EntityId::from_primary_key::(id).ok_or(CryptoKeystoreError::UnknownEntity(E::COLLECTION_NAME))?; + self.remove_by_entity_id::(entity_id).await + } + + /// Remove an entity by the borrowed form of its primary key. + /// + /// Note that this doesn't return whether or not anything was actually removed because + /// that won't happen until the transaction is committed. + pub(crate) async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()> where - E: crate::entities::Entity + PersistedMlsGroupExt + Sync, + E: EntityDeleteBorrowed<'a> + BorrowPrimaryKey, { - // First get all raw groups from the cache, then deserialize them to enable filtering by there parent id - // matching `entity.id_raw()`. - let cached_records = self - .find_all_in_cache() - .await? - .into_iter() - .filter(|maybe_child: &E| { + let entity_id = EntityId::from_borrowed_primary_key::(id) + .ok_or(CryptoKeystoreError::UnknownEntity(E::COLLECTION_NAME))?; + self.remove_by_entity_id::(entity_id).await + } + + pub(crate) async fn child_groups( + &self, + entity: E, + persisted_records: impl IntoIterator, + ) -> CryptoKeystoreResult> + where + E: Clone + Entity + BorrowPrimaryKey + PersistedMlsGroupExt + Send + Sync, + for<'pk> &'pk ::BorrowedPrimaryKey: KeyType, + { + // First get all raw groups from the cache, then filter by their parent id + let cached_records = self.find_all_in_cache::().await; + let cached_records = cached_records + .iter() + .filter(|maybe_child| { maybe_child .parent_id() - .map(|parent_id| parent_id == entity.id_raw()) + .map(|parent_id| parent_id == entity.borrow_primary_key().bytes().as_ref()) .unwrap_or_default() }) - .collect(); + .map(Arc::as_ref) + .map(Cow::Borrowed); + + let persisted_records = persisted_records.into_iter().map(Cow::Owned); - Ok(self - .merge_records(cached_records, persisted_records, EntityFindParams::default()) - .await) + Ok(self.merge_records(cached_records, persisted_records).await) } - pub(crate) async fn remove_pending_messages_by_conversation_id( - &self, - conversation_id: impl AsRef<[u8]> + Send, - ) -> CryptoKeystoreResult<()> { - // We cannot return an error from `retain()`, so we've got to do this dance with a mutable result. - let mut result = Ok(()); + pub(crate) async fn remove_pending_messages_by_conversation_id(&self, conversation_id: impl AsRef<[u8]> + Send) { + let conversation_id = conversation_id.as_ref(); let mut cache_guard = self.cache.write().await; - if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) { - table.get_mut().retain(|_key, record_bytes| { - postcard::from_bytes::(record_bytes) - .map(|pending_message| pending_message.foreign_id != conversation_id.as_ref()) - .inspect_err(|err| result = Err(err.clone())) - .unwrap_or(false) + if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME) { + table.get_mut().retain(|_key, entity| { + let pending_message = entity + .downcast::() + .expect("table for MlsPendingMessage contains only that type"); + pending_message.foreign_id != conversation_id }); } - - let mut deleted_list = self.deleted.write().await; - deleted_list.push(EntityId::from_collection_name( - MlsPendingMessage::COLLECTION_NAME, - conversation_id.as_ref(), - )?); - result.map_err(Into::into) + drop(cache_guard); + + let mut deleted_set = self.deleted.write().await; + deleted_set.insert( + EntityId::from_primary_key::(&MlsPendingMessagePrimaryKey::from_conversation_id( + conversation_id, + )) + .expect("mls pending messages are proper entities which can be parsed"), + ); } pub(crate) async fn find_pending_messages_by_conversation_id( &self, conversation_id: &[u8], - persisted_records: Vec, + persisted_records: impl IntoIterator, ) -> CryptoKeystoreResult> { - let cached_records = self - .find_all_in_cache::() - .await? - .into_iter() + let persisted_records = persisted_records.into_iter().map(Cow::Owned); + + let cached_records = self.find_all_in_cache::().await; + let cached_records = cached_records + .iter() .filter(|pending_message| pending_message.foreign_id == conversation_id) - .collect(); - let merged_records = self - .merge_records(cached_records, persisted_records, Default::default()) - .await; + .map(Arc::as_ref) + .map(Cow::Borrowed); + + let merged_records = self.merge_records(cached_records, persisted_records).await; Ok(merged_records) } - async fn find_in_cache(&self, id: &[u8]) -> CryptoKeystoreResult> + async fn find_in_cache(&self, entity_id: &EntityId) -> Option> where - E: crate::entities::Entity, + E: Entity + Send + Sync, { let cache_guard = self.cache.read().await; cache_guard .get(E::COLLECTION_NAME) - .and_then(|table| { - table - .get(id) - .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::(record).map_err(Into::into) }) - }) - .transpose() + .and_then(|table| table.get(entity_id).and_then(|entity| entity.downcast())) } /// The result of this function will have different contents for different scenarios: /// * `Some(Some(E))` - the transaction cache contains the record /// * `Some(None)` - the deletion of the record has been cached /// * `None` - there is no information about the record in the cache - pub(crate) async fn find(&self, id: &[u8]) -> CryptoKeystoreResult>> + async fn get_by_entity_id(&self, entity_id: &EntityId) -> Option>> where - E: crate::entities::Entity, + E: Entity + Send + Sync, { - let maybe_cached_record = self.find_in_cache(id).await?; - if let Some(cached_record) = maybe_cached_record { - return Ok(Some(Some(cached_record))); - } - + // when applying our transaction to the real database, we delete after inserting, + // so here we have to check for deletion before we check for existing values let deleted_list = self.deleted.read().await; - if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) { - return Ok(Some(None)); + if deleted_list.contains(entity_id) { + return Some(None); } - Ok(None) + self.find_in_cache::(entity_id).await.map(Some) } - pub(crate) async fn find_unique>( - &self, - ) -> CryptoKeystoreResult> { - #[cfg(target_family = "wasm")] - let id = &U::ID; - #[cfg(not(target_family = "wasm"))] - let id = &[U::ID as u8]; - let maybe_cached_record = self.find_in_cache::(id).await?; - match maybe_cached_record { - Some(cached_record) => Ok(Some(cached_record)), - _ => { - // The deleted list doesn't have to be checked because unique entities don't implement - // deletion, just replace. So we can directly return None. - Ok(None) - } - } + /// The result of this function will have different contents for different scenarios: + /// * `Some(Some(E))` - the transaction cache contains the record + /// * `Some(None)` - the deletion of the record has been cached + /// * `None` - there is no information about the record in the cache + pub(crate) async fn get(&self, id: &E::PrimaryKey) -> Option>> + where + E: Entity + Send + Sync, + { + let entity_id = EntityId::from_primary_key::(id)?; + self.get_by_entity_id(&entity_id).await } - async fn find_all_in_cache>( - &self, - ) -> CryptoKeystoreResult> { + /// The result of this function will have different contents for different scenarios: + /// * `Some(Some(E))` - the transaction cache contains the record + /// * `Some(None)` - the deletion of the record has been cached + /// * `None` - there is no information about the record in the cache + pub(crate) async fn get_borrowed(&self, id: &E::BorrowedPrimaryKey) -> Option>> + where + E: Entity + BorrowPrimaryKey + Send + Sync, + { + let entity_id = EntityId::from_borrowed_primary_key::(id)?; + self.get_by_entity_id(&entity_id).await + } + + async fn find_all_in_cache(&self) -> Vec> + where + E: Entity + Send + Sync, + { let cache_guard = self.cache.read().await; - let cached_records = cache_guard + cache_guard .get(E::COLLECTION_NAME) .map(|table| { table .values() - .map(|record| postcard::from_bytes::(record).map_err(Into::into)) - .collect::>>() + .map(|record: &dynamic_dispatch::Entity| { + record + .downcast::() + .expect("all entries in this table are of this type") + .clone() + }) + .collect::>() }) - .transpose()? - .unwrap_or_default(); - Ok(cached_records) + .unwrap_or_default() } - pub(crate) async fn find_all>( - &self, - persisted_records: Vec, - params: EntityFindParams, - ) -> CryptoKeystoreResult> { - let cached_records = self.find_all_in_cache().await?; - let merged_records = self.merge_records(cached_records, persisted_records, params).await; + pub(crate) async fn find_all(&self, persisted_records: Vec) -> CryptoKeystoreResult> + where + E: Clone + Entity + Send + Sync, + { + let cached_records = self.find_all_in_cache().await; + let merged_records = self + .merge_records( + cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed), + persisted_records.into_iter().map(Cow::Owned), + ) + .await; Ok(merged_records) } - pub(crate) async fn find_many>( - &self, - persisted_records: Vec, - ids: &[Vec], - ) -> CryptoKeystoreResult> { - let records = self - .find_all(persisted_records, EntityFindParams::default()) - .await? - .into_iter() - .filter(|record| ids.contains(&record.id_raw().to_vec())) - .collect(); - Ok(records) - } - /// Build a single list of unique records from two potentially overlapping lists. /// In case of overlap, records in `records_a` are prioritized. /// Identity from the perspective of this function is determined by the output of - /// [crate::entities::Entity::merge_key]. + /// [Entity::merge_key]. /// /// Further, the output list of records is built with respect to the provided [EntityFindParams] /// and the deleted records cached in this [Self] instance. - async fn merge_records>( + async fn merge_records<'a, E>( &self, - records_a: Vec, - records_b: Vec, - params: EntityFindParams, - ) -> Vec { - let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key()); - + records_a: impl IntoIterator>, + records_b: impl IntoIterator>, + ) -> Vec + where + E: Clone + Entity, + { let deleted_records = self.deleted.read().await; - let merged: &mut dyn Iterator = if params.reverse { &mut merged.rev() } else { &mut merged }; - - merged - .filter(|record| !Self::record_is_in_deleted_list(record, &deleted_records)) - .skip(params.offset.unwrap_or(0) as usize) - .take(params.limit.unwrap_or(u32::MAX) as usize) + records_a + .into_iter() + .chain(records_b) + .unique_by(|e| e.primary_key().bytes().into_owned()) + .filter_map(|record| { + let id = EntityId::from_entity(record.as_ref())?; + (!deleted_records.contains(&id)).then_some(record.into_owned()) + }) .collect() } - fn record_is_in_deleted_list>( - record: &E, - deleted_records: &[EntityId], - ) -> bool { - let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw()); - let Ok(id) = id else { return false }; - deleted_records.contains(&id) - } -} + /// Persists all the operations in the database. It will effectively open a transaction + /// internally, perform all the buffered operations and commit. + pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> { + let conn = db.conn().await?; + let mut conn = conn.conn().await; + + let cache = self.cache.read().await; + let deleted_ids = self.deleted.read().await; + + let table_names_with_deletion = deleted_ids.iter().map(|entity_id| entity_id.collection_name()); + let table_names_with_save = cache + .values() + .flat_map(|table| table.keys()) + .map(|entity_id| entity_id.collection_name()); + let mut tables = table_names_with_deletion + .chain(table_names_with_save) + .collect::>(); + + if tables.is_empty() { + log::debug!("Empty transaction was committed."); + return Ok(()); + } -/// Persist all records cached in `$keystore_transaction` (first argument), -/// using a transaction on `$db` (second argument). -/// Use the provided types to read from the cache and write to the `$db`. -/// -/// # Examples -/// ```rust,ignore -/// let transaction = KeystoreTransaction::new(); -/// let db = Connection::new(); -/// -/// // Commit records of all provided types -/// commit_transaction!( -/// transaction, db, -/// [ -/// (identifier_01, StoredCredential), -/// (identifier_02, StoredSignatureKeypair), -/// ], -/// ); -/// -/// // Commit records of provided types in the first list. Commit records of types in the second -/// // list only if the "proteus-keystore" cargo feature is enabled. -/// commit_transaction!( -/// transaction, db, -/// [ -/// (identifier_01, StoredCredential), -/// (identifier_02, StoredSignatureKeypair), -/// ], -/// proteus_types: [ -/// (identifier_03, ProteusPrekey), -/// (identifier_04, ProteusIdentity), -/// (identifier_05, ProteusSession) -/// ] -/// ); -/// ``` -macro_rules! commit_transaction { - ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => { - #[cfg(feature = "proteus-keystore")] - commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]); - - #[cfg(not(feature = "proteus-keystore"))] - commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]); - }; - ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => { - let cached_collections = ( $( $( - $keystore_transaction.find_all_in_cache::<$entity>().await?, - )* )* ); - - let ( $( $( $records, )* )* ) = cached_collections; - - let conn = $db.conn().await?; - let mut conn = conn.conn().await; - let deleted_ids = $keystore_transaction.deleted.read().await; - - let mut tables = Vec::new(); - $( $( - if !$records.is_empty() { - tables.push(<$entity>::COLLECTION_NAME); - } - )* )* - - for deleted_id in deleted_ids.iter() { - tables.push(deleted_id.collection_name()); - } - - if tables.is_empty() { - log::debug!("Empty transaction was committed."); - return Ok(()); - } - - #[cfg(target_family = "wasm")] - let tx = conn.new_transaction(&tables).await?; - #[cfg(not(target_family = "wasm"))] - let tx = conn.transaction()?.into(); - - $( $( - if !$records.is_empty() { - for record in $records { - dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?; - } - } - )* )* + tables.sort_unstable(); + tables.dedup(); + // open a database transaction + #[cfg(target_family = "wasm")] + let tx = conn.new_transaction(&tables).await?; + #[cfg(not(target_family = "wasm"))] + let tx = conn.transaction()?.into(); - for deleted_id in deleted_ids.iter() { - dynamic_dispatch::execute_delete(&tx, deleted_id).await? + for entity in cache.values().flat_map(|table| table.values()) { + entity.execute_save(&tx).await?; } - tx.commit_tx().await?; - }; -} + for deleted_id in deleted_ids.iter() { + deleted_id.execute_delete(&tx).await?; + } -impl KeystoreTransaction { - /// Persists all the operations in the database. It will effectively open a transaction - /// internally, perform all the buffered operations and commit. - pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> { - commit_transaction!( - self, db, - [ - (identifier_01, StoredCredential), - // (identifier_02, StoredSignatureKeypair), - (identifier_03, StoredHpkePrivateKey), - (identifier_04, StoredEncryptionKeyPair), - (identifier_05, StoredEpochEncryptionKeypair), - (identifier_06, StoredPskBundle), - (identifier_07, StoredKeypackage), - (identifier_08, PersistedMlsGroup), - (identifier_09, PersistedMlsPendingGroup), - (identifier_10, MlsPendingMessage), - (identifier_11, StoredE2eiEnrollment), - // (identifier_12, E2eiRefreshToken), - (identifier_13, E2eiAcmeCA), - (identifier_14, E2eiIntermediateCert), - (identifier_15, E2eiCrl), - (identifier_16, ConsumerData) - ], - proteus_types: [ - (identifier_17, ProteusPrekey), - (identifier_18, ProteusIdentity), - (identifier_19, ProteusSession) - ] - ); + // and commit everything + tx.commit_tx().await?; Ok(()) } diff --git a/keystore/tests/common.rs b/keystore/tests/common.rs index c6a980e6c4..3b26a6e1ce 100644 --- a/keystore/tests/common.rs +++ b/keystore/tests/common.rs @@ -61,15 +61,18 @@ impl KeystoreTestContext { impl Drop for KeystoreTestContext { fn drop(&mut self) { if let Some(store) = self.store.take() { - let commit_and_wipe = async move { - store.commit_transaction().await.expect("Could not commit transaction"); + let rollback_and_wipe = async move { + store + .rollback_transaction() + .await + .expect("could not rollback transaction"); store.wipe().await.expect("Could not wipe store"); }; #[cfg(not(target_family = "wasm"))] - futures_lite::future::block_on(commit_and_wipe); + futures_lite::future::block_on(rollback_and_wipe); #[cfg(target_family = "wasm")] - wasm_bindgen_futures::spawn_local(commit_and_wipe); + wasm_bindgen_futures::spawn_local(rollback_and_wipe); } } } diff --git a/keystore/tests/mls.rs b/keystore/tests/mls.rs index f8bc5d92aa..82f2107933 100644 --- a/keystore/tests/mls.rs +++ b/keystore/tests/mls.rs @@ -59,7 +59,7 @@ mod tests { #[apply(all_storage_types)] pub async fn can_add_read_delete_credential_openmls_traits(context: KeystoreTestContext) { - use core_crypto_keystore::connection::FetchFromDatabase; + use core_crypto_keystore::{Sha256Hash, traits::FetchFromDatabase}; use itertools::Itertools as _; use openmls_basic_credential::SignatureKeyPair; @@ -92,7 +92,7 @@ mod tests { let (credential_from_store,) = backend .key_store() - .find_all::(Default::default()) + .load_all::() .await .unwrap() .into_iter() @@ -102,7 +102,7 @@ mod tests { backend .key_store() - .remove::(credential_from_store.public_key.clone()) + .remove::(&Sha256Hash::hash_from(&credential_from_store.public_key)) .await .unwrap(); } diff --git a/keystore/tests/z_entities.rs b/keystore/tests/z_entities.rs index 75559223f4..9e17aaaa76 100644 --- a/keystore/tests/z_entities.rs +++ b/keystore/tests/z_entities.rs @@ -35,8 +35,7 @@ macro_rules! test_for_entity { crate::tests_impl::can_remove_entity::<$entity>(&store, entity).await; let ignore_count = pat_to_bool!($($ignore_entity_count)?); - let ignore_find_many = pat_to_bool!($($ignore_find_many)?); - crate::tests_impl::can_list_entities_with_find_many::<$entity>(&store, ignore_count, ignore_find_many).await; + crate::tests_impl::insert_count_entities::<$entity>(&store).await; crate::tests_impl::can_list_entities_with_find_all::<$entity>(&store, ignore_count).await; } }; @@ -45,29 +44,39 @@ macro_rules! test_for_entity { #[cfg(test)] mod tests_impl { use core_crypto_keystore::{ - connection::{FetchFromDatabase, KeystoreDatabaseConnection}, - entities::{Entity, EntityFindParams, EntityTransactionExt, MlsPendingMessage, StoredCredential}, + connection::KeystoreDatabaseConnection, + entities::{MlsPendingMessage, StoredCredential}, + traits::{Entity, EntityDatabaseMutation, FetchFromDatabase as _, PrimaryKey as _}, }; use super::common::*; use crate::{ENTITY_COUNT, utils::EntityRandomUpdateExt}; - pub(crate) async fn can_save_entity< - R: EntityRandomUpdateExt + Entity + EntityTransactionExt + Sync, - >( - store: &CryptoKeystore, - ) -> R { + pub(crate) async fn can_save_entity<'a, R>(store: &CryptoKeystore) -> R + where + R: Clone + + EntityRandomUpdateExt + + Entity + + EntityDatabaseMutation<'a> + + Send + + Sync, + { let entity = R::random(); store.save(entity.clone()).await.unwrap(); entity } - pub(crate) async fn can_find_entity< - R: EntityRandomUpdateExt + Entity + 'static + Sync, - >( - store: &CryptoKeystore, - entity: &R, - ) { + pub(crate) async fn can_find_entity<'a, R>(store: &CryptoKeystore, entity: &R) + where + R: Clone + + std::fmt::Debug + + Eq + + EntityRandomUpdateExt + + Entity + + EntityDatabaseMutation<'a> + + Send + + Sync, + { if let Some(pending_message) = entity.downcast::() { let pending_message_from_store = store .find_pending_messages_by_conversation_id(&pending_message.foreign_id) @@ -78,71 +87,76 @@ mod tests_impl { assert_eq!(*pending_message, pending_message_from_store); } else if let Some(credential) = entity.downcast::() { let mut credential_from_store = store - .find::(&entity.merge_key()) + .get::(&credential.primary_key()) .await .unwrap() .unwrap(); credential_from_store.equalize(); assert_eq!(*credential, credential_from_store); } else { - let mut entity_from_store = store.find::(entity.id_raw()).await.unwrap().unwrap(); + let primary_key = entity.primary_key(); + let mut entity_from_store = store.get::(&primary_key).await.unwrap().unwrap(); entity_from_store.equalize(); assert_eq!(*entity, entity_from_store); }; } - pub(crate) async fn can_update_entity< - R: EntityRandomUpdateExt + Entity + EntityTransactionExt + Sync, - >( - store: &CryptoKeystore, - entity: &mut R, - ) { + pub(crate) async fn can_update_entity<'a, R>(store: &CryptoKeystore, entity: &mut R) + where + R: Clone + + std::fmt::Debug + + Eq + + EntityRandomUpdateExt + + Entity + + EntityDatabaseMutation<'a> + + Send + + Sync, + { entity.random_update(); store.save(entity.clone()).await.unwrap(); - let entity2: R = store.find(entity.id_raw()).await.unwrap().unwrap(); + let entity2: R = store.get(&entity.primary_key()).await.unwrap().unwrap(); assert_eq!(*entity, entity2); } - pub(crate) async fn can_remove_entity< - R: EntityRandomUpdateExt + Entity + EntityTransactionExt + Sync, - >( - store: &CryptoKeystore, - entity: R, - ) { - store.remove::(entity.id_raw()).await.unwrap(); - let entity2: Option = store.find(entity.id_raw()).await.unwrap(); + pub(crate) async fn can_remove_entity<'a, R>(store: &CryptoKeystore, entity: R) + where + R: Clone + + EntityRandomUpdateExt + + Entity + + EntityDatabaseMutation<'a> + + Send + + Sync, + { + store.remove::(&entity.primary_key()).await.unwrap(); + let entity2: Option = store.get(&entity.primary_key()).await.unwrap(); assert!(entity2.is_none()); } - pub(crate) async fn can_list_entities_with_find_many< - R: EntityRandomUpdateExt + Entity + EntityTransactionExt + Sync, - >( - store: &CryptoKeystore, - ignore_entity_count: bool, - ignore_find_many: bool, - ) { - let mut ids: Vec> = vec![]; + pub(super) async fn insert_count_entities<'a, R>(store: &CryptoKeystore) + where + R: Clone + + EntityRandomUpdateExt + + Entity + + EntityDatabaseMutation<'a> + + Send + + Sync, + { for _ in 0..ENTITY_COUNT { let entity = R::random(); - ids.push(entity.id_raw().to_vec()); store.save(entity).await.unwrap(); } - - if !ignore_find_many { - let entities = store.find_many::(&ids).await.unwrap(); - if !ignore_entity_count { - assert_eq!(entities.len(), ENTITY_COUNT); - } - } } - pub(crate) async fn can_list_entities_with_find_all< - R: EntityRandomUpdateExt + Entity + Sync, - >( - store: &CryptoKeystore, - ignore_entity_count: bool, - ) { - let entities = store.find_all::(EntityFindParams::default()).await.unwrap(); + pub(crate) async fn can_list_entities_with_find_all<'a, R>(store: &CryptoKeystore, ignore_entity_count: bool) + where + R: Clone + + EntityRandomUpdateExt + + Entity + + EntityDatabaseMutation<'a> + + Send + + Sync, + { + let entities = store.load_all::().await.unwrap(); if !ignore_entity_count { assert_eq!(entities.len(), ENTITY_COUNT); } @@ -175,6 +189,7 @@ mod tests { test_for_entity!(test_e2ei_intermediate_cert, E2eiIntermediateCert); test_for_entity!(test_e2ei_crl, E2eiCrl); test_for_entity!(test_e2ei_enrollment, StoredE2eiEnrollment ignore_update:true); + test_for_entity!(test_e2ei_acme_ca, E2eiAcmeCA ignore_entity_count:true ignore_find_many:true); cfg_if::cfg_if! { if #[cfg(feature = "proteus-keystore")] { @@ -212,12 +227,51 @@ mod tests { store.rollback_transaction().await.unwrap(); store.new_transaction().await.unwrap(); } + + #[apply(all_storage_types)] + async fn can_save_and_load_consumer_data(context: KeystoreTestContext) { + use core_crypto_keystore::traits::FetchFromDatabase as _; + + eprintln!("creating store"); + let store = context.store(); + + eprintln!("checking consumer data before it exists"); + assert!(!store.exists::().await.unwrap()); + let consumer_data = store.get_unique::().await.unwrap(); + assert!(consumer_data.is_none()); + + eprintln!("saving some consumer data"); + const DATA: &[u8] = b"here is some arbitrary data"; + store + .save(ConsumerData { + content: DATA.to_owned(), + }) + .await + .unwrap(); + + // from transaction + eprintln!("checking retrieving consumer data from active transaction"); + assert!(store.exists::().await.unwrap()); + let consumer_data = store.get_unique::().await.unwrap().unwrap(); + assert_eq!(consumer_data.content, DATA); + + eprintln!("committing transaction"); + store.commit_transaction().await.unwrap(); + // don't forget to open a new (blank) transaction + store.new_transaction().await.unwrap(); + + // from storage (fallthrough) + eprintln!("checking retrieving consumer data from storage"); + assert!(store.exists::().await.unwrap()); + let consumer_data = store.get_unique::().await.unwrap().unwrap(); + assert_eq!(consumer_data.content, DATA); + } } #[cfg(test)] pub mod utils { use core_crypto_keystore::entities::{ - MlsPendingMessage, PersistedMlsGroup, PersistedMlsPendingGroup, ProteusSession, StoredCredential, + E2eiAcmeCA, MlsPendingMessage, PersistedMlsGroup, PersistedMlsPendingGroup, ProteusSession, StoredCredential, StoredE2eiEnrollment, StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle, }; @@ -337,6 +391,7 @@ pub mod utils { impl_entity_random_update_ext!(MlsPendingMessage, id_field = foreign_id, blob_fields = [message,]); impl_entity_random_update_ext!(StoredE2eiEnrollment, id_field = id, blob_fields = [content,]); impl_entity_random_update_ext!(StoredEpochEncryptionKeypair, id_field = id, blob_fields = [keypairs,]); + impl_entity_random_update_ext!(E2eiAcmeCA, blob_fields = [content,]); impl EntityRandomExt for core_crypto_keystore::entities::E2eiIntermediateCert { fn random() -> Self {