Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions crypto-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
// use quote::quote;
use syn::{
Attribute, Block, FnArg, ItemFn, ReturnType, Visibility, parse_macro_input, punctuated::Punctuated, token::Comma,
};
Expand All @@ -18,8 +18,11 @@ mod idempotent;
/// To be used internally inside the `core-crypto-keystore` crate only.
#[proc_macro_derive(Entity, attributes(entity, id))]
pub fn derive_entity(input: TokenStream) -> TokenStream {
let parsed = parse_macro_input!(input as KeyStoreEntity).flatten();
TokenStream::from(quote! { #parsed })
// TODO, DO NOT MERGE THIS PR UNTIL ATTENDED TO
// temporarily disable deriving entity
let _parsed = parse_macro_input!(input as KeyStoreEntity).flatten();
TokenStream::new()
// TokenStream::from(quote! { #parsed })
}

/// Will drop current MLS group in memory and replace it with the one in the keystore.
Expand Down
1 change: 0 additions & 1 deletion keystore/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ dummy-entity = []

[dependencies]
thiserror.workspace = true
cfg-if.workspace = true
derive_more.workspace = true
hex.workspace = true
zeroize = { workspace = true, features = ["zeroize_derive"] }
Expand Down
275 changes: 145 additions & 130 deletions keystore/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
use std::{fmt, ops::Deref};

use sha2::{Digest as _, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};

pub mod platform {
cfg_if::cfg_if! {
if #[cfg(target_family = "wasm")] {
mod wasm;
pub use self::wasm::WasmConnection as KeystoreDatabaseConnection;
pub use wasm::storage;
pub use self::wasm::storage::WasmStorageTransaction as TransactionWrapper;
} else {
mod generic;
pub use self::generic::SqlCipherConnection as KeystoreDatabaseConnection;
pub use self::generic::TransactionWrapper;
}
}
#[cfg(not(target_family = "wasm"))]
mod generic;
#[cfg(target_family = "wasm")]
mod wasm;

#[cfg(not(target_family = "wasm"))]
pub use self::generic::{SqlCipherConnection as KeystoreDatabaseConnection, TransactionWrapper};
#[cfg(target_family = "wasm")]
pub use self::wasm::{
WasmConnection as KeystoreDatabaseConnection, storage, storage::WasmStorageTransaction as TransactionWrapper,
};
}

use std::{ops::DerefMut, sync::Arc};
use std::{
fmt,
ops::{Deref, DerefMut},
sync::Arc,
};

use async_lock::{Mutex, MutexGuard, Semaphore};
use sha2::{Digest as _, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};

pub use self::platform::*;
use crate::{
CryptoKeystoreError, CryptoKeystoreResult,
entities::{Entity, EntityFindParams, EntityTransactionExt, MlsPendingMessage, StringEntityId, UniqueEntity},
CryptoKeystoreError, CryptoKeystoreResult, Entity, EntityTransactionExt, FetchFromDatabase, UniqueEntity,
transaction::KeystoreTransaction,
};

Expand Down Expand Up @@ -135,32 +133,6 @@ pub struct Database {

const ALLOWED_CONCURRENT_TRANSACTIONS_COUNT: usize = 1;

/// Interface to fetch from the database either from the connection directly or through a
/// transaaction
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait FetchFromDatabase: Send + Sync {
async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
id: impl AsRef<[u8]> + Send,
) -> CryptoKeystoreResult<Option<E>>;

async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
) -> CryptoKeystoreResult<U>;

async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
params: EntityFindParams,
) -> CryptoKeystoreResult<Vec<E>>;

async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
ids: &[Vec<u8>],
) -> CryptoKeystoreResult<Vec<E>>;
async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize>;
}

// SAFETY: this has mutexes and atomics protecting underlying data so this is safe to share between threads
unsafe impl Send for Database {}
// SAFETY: this has mutexes and atomics protecting underlying data so this is safe to share between threads
Expand Down Expand Up @@ -302,12 +274,10 @@ impl Database {
Ok(())
}

pub async fn child_groups<
pub async fn child_groups<E>(&self, entity: E) -> CryptoKeystoreResult<Vec<E>>
where
E: Entity<ConnectionType = KeystoreDatabaseConnection> + crate::entities::PersistedMlsGroupExt + Sync,
>(
&self,
entity: E,
) -> CryptoKeystoreResult<Vec<E>> {
{
let mut conn = self.conn().await?;
let persisted_records = entity.child_groups(conn.deref_mut()).await?;

Expand All @@ -318,47 +288,45 @@ impl Database {
transaction.child_groups(entity, persisted_records).await
}

pub async fn save<E: Entity<ConnectionType = KeystoreDatabaseConnection> + Sync + EntityTransactionExt>(
&self,
entity: E,
) -> CryptoKeystoreResult<E> {
pub async fn save<'a, E>(&self, entity: E) -> CryptoKeystoreResult<E>
where
E: Entity<ConnectionType = KeystoreDatabaseConnection> + Sync + EntityTransactionExt<'a>,
{
let transaction_guard = self.transaction.lock().await;
let Some(transaction) = transaction_guard.as_ref() else {
return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
};
transaction.save_mut(entity).await
}

pub async fn remove<
E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
pub async fn remove<'a, E, S>(&self, id: S) -> CryptoKeystoreResult<()>
where
E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt<'a>,
S: AsRef<[u8]>,
>(
&self,
id: S,
) -> CryptoKeystoreResult<()> {
{
let transaction_guard = self.transaction.lock().await;
let Some(transaction) = transaction_guard.as_ref() else {
return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
};
transaction.remove::<E, S>(id).await
}

pub async fn find_pending_messages_by_conversation_id(
&self,
conversation_id: &[u8],
) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
let mut conn = self.conn().await?;
let persisted_records =
MlsPendingMessage::find_all_by_conversation_id(&mut conn, conversation_id, Default::default()).await?;

let transaction_guard = self.transaction.lock().await;
let Some(transaction) = transaction_guard.as_ref() else {
return Ok(persisted_records);
};
transaction
.find_pending_messages_by_conversation_id(conversation_id, persisted_records)
.await
}
// pub async fn find_pending_messages_by_conversation_id(
// &self,
// conversation_id: &[u8],
// ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
// let mut conn = self.conn().await?;
// let persisted_records =
// MlsPendingMessage::find_all_by_conversation_id(&mut conn, conversation_id, Default::default()).await?;

// let transaction_guard = self.transaction.lock().await;
// let Some(transaction) = transaction_guard.as_ref() else {
// return Ok(persisted_records);
// };
// transaction
// .find_pending_messages_by_conversation_id(conversation_id, persisted_records)
// .await
// }

pub async fn remove_pending_messages_by_conversation_id(
&self,
Expand All @@ -382,13 +350,89 @@ impl Database {
}
}

// #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
// #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
// impl FetchFromDatabase for Database {
// async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
// &self,
// id: impl AsRef<[u8]> + Send,
// ) -> CryptoKeystoreResult<Option<E>> {
// // If a transaction is in progress...
// if let Some(transaction) = self.transaction.lock().await.as_ref()
// //... and it has information about this entity, ...
// && let Some(cached_record) = transaction.find::<E>(id.as_ref()).await?
// {
// // ... return that result
// return Ok(cached_record);
// }

// // Otherwise get it from the database
// let mut conn = self.conn().await?;
// E::find_one(&mut conn, &id.as_ref().into()).await
// }

// async fn find_unique<U: UniqueEntity>(&self) -> CryptoKeystoreResult<U> {
// // If a transaction is in progress...
// if let Some(transaction) = self.transaction.lock().await.as_ref()
// //... and it has information about this entity, ...
// && let Some(cached_record) = transaction.find_unique::<U>().await?
// {
// // ... return that result
// return Ok(cached_record);
// }
// // Otherwise get it from the database
// let mut conn = self.conn().await?;
// U::find_unique(&mut conn).await
// }

// async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
// &self,
// params: EntityFindParams,
// ) -> CryptoKeystoreResult<Vec<E>> {
// let mut conn = self.conn().await?;
// let persisted_records = E::find_all(&mut conn, params.clone()).await?;

// let transaction_guard = self.transaction.lock().await;
// let Some(transaction) = transaction_guard.as_ref() else {
// return Ok(persisted_records);
// };
// transaction.find_all(persisted_records, params).await
// }

// async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
// &self,
// ids: &[Vec<u8>],
// ) -> CryptoKeystoreResult<Vec<E>> {
// let entity_ids: Vec<StringEntityId> = ids.iter().map(|id| id.as_slice().into()).collect();
// let mut conn = self.conn().await?;
// let persisted_records = E::find_many(&mut conn, &entity_ids).await?;

// let transaction_guard = self.transaction.lock().await;
// let Some(transaction) = transaction_guard.as_ref() else {
// return Ok(persisted_records);
// };
// transaction.find_many(persisted_records, ids).await
// }

// async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize> {
// if self.transaction.lock().await.is_some() {
// // Unfortunately, we have to do this because of possible record id overlap
// // between cache and db.
// return Ok(self.find_all::<E>(Default::default()).await?.len());
// };
// let mut conn = self.conn().await?;
// E::count(&mut conn).await
// }
// }

#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl FetchFromDatabase for Database {
async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
id: impl AsRef<[u8]> + Send,
) -> CryptoKeystoreResult<Option<E>> {
/// Get an instance of `E` from the database by its primary key.
async fn get<E>(&self, id: &<E as Entity>::PrimaryKey) -> CryptoKeystoreResult<Option<E>>
where
E: Entity<ConnectionType = KeystoreDatabaseConnection>,
{
// If a transaction is in progress...
if let Some(transaction) = self.transaction.lock().await.as_ref()
//... and it has information about this entity, ...
Expand All @@ -400,59 +444,30 @@ impl FetchFromDatabase for Database {

// Otherwise get it from the database
let mut conn = self.conn().await?;
E::find_one(&mut conn, &id.as_ref().into()).await
}

async fn find_unique<U: UniqueEntity>(&self) -> CryptoKeystoreResult<U> {
// If a transaction is in progress...
if let Some(transaction) = self.transaction.lock().await.as_ref()
//... and it has information about this entity, ...
&& let Some(cached_record) = transaction.find_unique::<U>().await?
{
// ... return that result
return Ok(cached_record);
}
// Otherwise get it from the database
let mut conn = self.conn().await?;
U::find_unique(&mut conn).await
E::get(&mut conn, &id.as_ref().into()).await
}

async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
params: EntityFindParams,
) -> CryptoKeystoreResult<Vec<E>> {
let mut conn = self.conn().await?;
let persisted_records = E::find_all(&mut conn, params.clone()).await?;

let transaction_guard = self.transaction.lock().await;
let Some(transaction) = transaction_guard.as_ref() else {
return Ok(persisted_records);
};
transaction.find_all(persisted_records, params).await
/// Count the number of `E`s in the database.
async fn count<E>(&self) -> CryptoKeystoreResult<u32>
where
E: Entity<ConnectionType = KeystoreDatabaseConnection>,
{
todo!()
}

async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
&self,
ids: &[Vec<u8>],
) -> CryptoKeystoreResult<Vec<E>> {
let entity_ids: Vec<StringEntityId> = ids.iter().map(|id| id.as_slice().into()).collect();
let mut conn = self.conn().await?;
let persisted_records = E::find_many(&mut conn, &entity_ids).await?;

let transaction_guard = self.transaction.lock().await;
let Some(transaction) = transaction_guard.as_ref() else {
return Ok(persisted_records);
};
transaction.find_many(persisted_records, ids).await
/// Load all `E`s from the database.
async fn load_all<E>(&self) -> CryptoKeystoreResult<Vec<E>>
where
E: Entity<ConnectionType = KeystoreDatabaseConnection>,
{
todo!()
}

async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize> {
if self.transaction.lock().await.is_some() {
// Unfortunately, we have to do this because of possible record id overlap
// between cache and db.
return Ok(self.find_all::<E>(Default::default()).await?.len());
};
let mut conn = self.conn().await?;
E::count(&mut conn).await
/// Get the requested unique entity from the database.
async fn get_unique<U>(&self) -> CryptoKeystoreResult<Option<U>>
where
U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>,
{
todo!()
}
}
Loading