diff --git a/Cargo.lock b/Cargo.lock index 82f750f..5effb29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,40 @@ dependencies = [ "memchr", ] +[[package]] +name = "anyhow" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" + +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "backtrace" version = "0.3.74" @@ -120,6 +154,95 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.31.0" @@ -231,6 +354,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "prettyplease" version = "0.2.25" @@ -353,6 +482,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "sqlparser" version = "0.52.0" @@ -378,6 +516,8 @@ dependencies = [ name = "static_sqlite" version = "0.1.0" dependencies = [ + "anyhow", + "futures", "static_sqlite_async", "static_sqlite_core", "static_sqlite_macros", @@ -389,7 +529,9 @@ dependencies = [ name = "static_sqlite_async" version = "0.1.0" dependencies = [ + "async-stream", "crossbeam-channel", + "futures", "static_sqlite_core", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 1e72bd0..ee11665 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,8 @@ resolver = "2" static_sqlite_macros = { path = "static_sqlite_macros", version = "0.1.0" } static_sqlite_core = { path = "static_sqlite_core", version = "0.1.0" } static_sqlite_async = { path = "static_sqlite_async", version = "0.1.0" } - +futures = { version = "0.3" } +anyhow = "1.0.97" [dev-dependencies] tokio = { version = "1", features = ["rt", "sync", "macros"] } trybuild = "1.0" diff --git a/README.md b/README.md index 51369dc..021247f 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,182 @@ async fn main() -> Result<()> { cargo add --git https://github.com/swlkr/static_sqlite ``` + +# Example for Transactions + +Use the methods begin_transaction, commit_transaction and rollback_transaction to manage Sqlite transactions. + + +```rust + + // migration and sql-fn definition goes here + + let db = static_sqlite::open(":memory:").await?; + + migrate(&db).await?; + + db.begin_transaction()?; + insert_row(&db, "test1").await?.first_row()?; + insert_row(&db, "test2").await?.first_row()?; + db.commit_transaction()?; +``` + +# Example for First + +If the name of your statement ends with "_first", the created fn return an Option with the first value instead of a Vec. + +I the query returns more than one rows, it throws an error. + +```rust + sql! { + let migrate = r#" + create table Row ( + id integer primary key autoincrement, + txt text NOT NULL + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_row = r#" + select * from Row where id = :id + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, "test1").await?.first_row()?; + insert_row(&db, "test2").await?.first_row()?; + + match select_row_first(&db, 1).await? { + Some(row) => assert_eq!(row.txt, "test1"), + None => panic!("Row 1 not found"), + } +``` + +# Example for Streams + +If the name of your statement ends with "_stream", the created fn return an async Stream instead of a Vec. + +This way you can iterate over large result sets. + +```rust +sql! { + let migrate = r#" + create table Row ( + txt text + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_rows_stream = r#" + select * from Row + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, Some("test1")).await?.first_row()?; + insert_row(&db, Some("test2")).await?.first_row()?; + insert_row(&db, Some("test3")).await?.first_row()?; + insert_row(&db, Some("test4")).await?.first_row()?; + + let f = select_rows_stream(&db).await?; + + pin_mut!(f); + + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test1".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test2".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test3".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test4".into())); +} + +``` + +# Example with aliased columns and type-hints + +Sometimes the type of either a bound parameter or a returned column can not be inferred by +sqlite / static_sqlite (see [sqlite3 docs](https://www.sqlite.org/c3ref/column_decltype.html)) + +In this case you can use type-hints to help the static_sqlite to use the correct type. + +To use type-hints your parameter or column name needs to follow the following format: + +``` +__ +``` + +or + +``` +____ +``` + +If not explicitly specified, the parameter or column is assumed to be NOT NULL. + +```rust +sql! { + let migrate = r#" + create table User ( + id integer primary key, + name text unique not null + ); + create table Friendship ( + id integer primary key, + user_id integer not null references User(id), + friend_id integer not null references User(id) + ); + "#; + + let insert_user = r#" + insert into User (name) + values (:name) + returning * + "#; + let create_friendship = r#" + insert into Friendship (user_id, friend_id) + values (:user_id, :friend_id) + returning * + "#; + let get_friendship = r#" + select + u1.name as friend1_name__TEXT, + u2.name as friend2_name__TEXT + from Friendship, User as u1, User as u2 + where Friendship.user_id = u1.id + and Friendship.friend_id = u2.id + and Friendship.id = :friendship_id__INTEGER + "#; +} + + +#[tokio::main] +async fn main() -> Result<()> { + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + insert_user(&db, "swlkr").await?; + insert_user(&db, "toolbar23").await?; + create_friendship(&db, 1, 2).await?; + + let friends = get_friendship(&db, 1).await?; + + assert_eq!(friends.len(), 1); + assert_eq!(friends.first().unwrap().friend1_name, "swlkr"); + assert_eq!(friends.first().unwrap().friend2_name, "toolbar23"); + + Ok(()) +} +``` + + + # Treesitter ``` diff --git a/src/lib.rs b/src/lib.rs index 1b2404a..76ddb83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ extern crate self as static_sqlite; pub use static_sqlite_async::{ - execute, execute_all, open, query, rows, Error, FromRow, Result, Savepoint, Sqlite, Value, + execute, execute_all, open, query, query_first, rows, stream, Error, FromRow, Result, + Savepoint, Sqlite, Value, }; pub use static_sqlite_core::FirstRow; pub use static_sqlite_macros::sql; diff --git a/static_sqlite_async/Cargo.toml b/static_sqlite_async/Cargo.toml index 8b22739..04d61bb 100644 --- a/static_sqlite_async/Cargo.toml +++ b/static_sqlite_async/Cargo.toml @@ -7,3 +7,5 @@ edition = "2021" static_sqlite_core = { path = "../static_sqlite_core", version = "0.1.0" } tokio = { version = "1", features = ["sync"] } crossbeam-channel = { version = "0.5" } +futures = { version = "0.3" } +async-stream = "0.3" diff --git a/static_sqlite_async/src/lib.rs b/static_sqlite_async/src/lib.rs index f9597b0..2fba3fa 100644 --- a/static_sqlite_async/src/lib.rs +++ b/static_sqlite_async/src/lib.rs @@ -1,8 +1,9 @@ // Inspired by the incredible tokio-rusqlite crate // https://github.com/programatik29/tokio-rusqlite/blob/master/src/lib.rs -use static_sqlite_core as core; use crossbeam_channel::Sender; +pub use futures::Stream; +use static_sqlite_core as core; use tokio::sync::oneshot; pub use static_sqlite_core::*; @@ -33,9 +34,7 @@ impl Sqlite { return Ok(()); } - result - .unwrap() - .map_err(|e| Error::Sqlite(e.to_string())) + result.unwrap().map_err(|e| Error::Sqlite(e.to_string())) } pub async fn call(&self, function: F) -> Result @@ -54,6 +53,18 @@ impl Sqlite { receiver.await.map_err(|_| Error::ConnectionClosed)? } + + pub async fn begin_transaction(&self) -> Result<()> { + self.call(move |conn| conn.begin_transaction()).await + } + + pub async fn commit_transaction(&self) -> Result<()> { + self.call(move |conn| conn.commit_transaction()).await + } + + pub async fn rollback_transaction(&self) -> Result<()> { + self.call(move |conn| conn.rollback_transaction()).await + } } pub async fn open(path: impl ToString) -> Result { @@ -126,6 +137,41 @@ pub async fn query( conn.call(move |conn| conn.query(sql, ¶ms)).await } +pub async fn query_first( + conn: &Sqlite, + sql: &'static str, + params: Vec, +) -> Result> { + conn.call(move |conn| conn.query_first(sql, ¶ms)).await +} + +pub async fn stream( + conn: &Sqlite, + sql: &'static str, + params: Vec, +) -> Result>> { + let (sender, receiver) = std::sync::mpsc::channel(); + + conn.sender + .send(Message::Execute(Box::new(move |conn| { + let value = conn.iter(sql, ¶ms).unwrap(); + + for item in value { + let res = sender.send(item); + if res.is_err() { + break; + } + } + }))) + .map_err(|_| Error::ConnectionClosed)?; + + Ok(async_stream::stream! { + for item in receiver { + yield item; + } + }) +} + pub async fn rows( conn: Sqlite, sql: &'static str, diff --git a/static_sqlite_core/Cargo.toml b/static_sqlite_core/Cargo.toml index c8eff2c..be2367c 100644 --- a/static_sqlite_core/Cargo.toml +++ b/static_sqlite_core/Cargo.toml @@ -6,4 +6,3 @@ edition = "2021" [dependencies] thiserror = "1" static_sqlite_ffi = { path = "../static_sqlite_ffi" } - diff --git a/static_sqlite_core/src/ffi.rs b/static_sqlite_core/src/ffi.rs index e17cff3..143e494 100644 --- a/static_sqlite_core/src/ffi.rs +++ b/static_sqlite_core/src/ffi.rs @@ -9,6 +9,7 @@ use static_sqlite_ffi::{ use std::{ ffi::{c_char, c_int, CStr, CString, NulError}, + marker::PhantomData, num::TryFromIntError, ops::Deref, str::Utf8Error, @@ -33,6 +34,8 @@ pub enum Error { ConnectionClosed, #[error("sqlite row not found")] RowNotFound, + #[error("sqlite returned too many rows in result")] + TooManyRowsInResult, #[error(transparent)] Utf8Error(#[from] Utf8Error), } @@ -147,6 +150,27 @@ impl Sqlite { self.execute(sql, vec![]) } + pub fn begin_transaction(&self) -> Result<()> { + match self.execute("BEGIN TRANSACTION", vec![]) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + pub fn commit_transaction(&self) -> Result<()> { + match self.execute("COMMIT TRANSACTION", vec![]) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + + pub fn rollback_transaction(&self) -> Result<()> { + match self.execute("ROLLBACK TRANSACTION", vec![]) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } + } + pub fn query(&self, sql: &'static str, params: &[Value]) -> Result> { unsafe { let stmt = self.prepare(sql, params)?; @@ -160,25 +184,7 @@ impl Sqlite { .to_string_lossy() .into_owned(); - let value = match sqlite3_column_type(stmt, i) { - 1 => Value::Integer(sqlite3_column_int64(stmt, i)), - 2 => Value::Real(sqlite3_column_double(stmt, i)), - 3 => { - let text = - CStr::from_ptr(sqlite3_column_text(stmt, i) as *const c_char) - .to_string_lossy() - .into_owned(); - Value::Text(text) - } - 4 => { - let len = sqlite3_column_bytes(stmt, i) as usize; - let ptr = sqlite3_column_text(stmt, i); - let slice = std::slice::from_raw_parts(ptr, len); - Value::Blob(slice.to_vec()) - } - _ => Value::Null, - }; - + let value = Self::get_column_value(stmt, i)?; values.push((name, value)); } @@ -186,23 +192,99 @@ impl Sqlite { rows.push(row); } - if sqlite3_finalize(stmt) != 0 { - let error = CStr::from_ptr(sqlite3_errmsg(self.db)) - .to_string_lossy() - .into_owned(); - if error.starts_with("UNIQUE constraint failed: ") { - return Err(Error::UniqueConstraint( - error.replace("UNIQUE constraint failed: ", ""), + Self::finalize_statement(self.db, stmt)?; + + Ok(rows) + } + } + + pub fn query_first( + &self, + sql: &'static str, + params: &[Value], + ) -> Result> { + match self.query(sql, params) { + Ok(rows) => Ok(match rows.len() { + 0 => None, + 1 => Some(rows.into_iter().nth(0).unwrap()), + _ => return Err(Error::TooManyRowsInResult), + }), + Err(e) => Err(e), + } + } + + unsafe fn get_column_value(stmt: *mut sqlite3_stmt, i: c_int) -> Result { + match sqlite3_column_type(stmt, i) { + x if x == static_sqlite_ffi::SQLITE_INTEGER as i32 => { + Ok(Value::Integer(sqlite3_column_int64(stmt, i))) + } + x if x == static_sqlite_ffi::SQLITE_FLOAT as i32 => { + Ok(Value::Real(sqlite3_column_double(stmt, i))) + } + x if x == static_sqlite_ffi::SQLITE_TEXT as i32 => { + let text_ptr = sqlite3_column_text(stmt, i) as *const c_char; + if text_ptr.is_null() { + Ok(Value::Text(String::new())) + } else { + let text = CStr::from_ptr(text_ptr).to_str()?.to_owned(); + Ok(Value::Text(text)) + } + } + x if x == static_sqlite_ffi::SQLITE_BLOB as i32 => { + let len = sqlite3_column_bytes(stmt, i); + if len < 0 { + return Err(Error::Sqlite( + "SQLite returned negative length for BLOB column".into(), )); + } + let len = len as usize; + let ptr = static_sqlite_ffi::sqlite3_column_blob(stmt, i); + if ptr.is_null() { + if len == 0 { + Ok(Value::Blob(vec![])) + } else { + Err(Error::Sqlite("SQLite returned null pointer for non-empty BLOB column (likely out of memory)".into())) + } } else { - return Err(Error::Sqlite(error)); + let slice = std::slice::from_raw_parts(ptr as *const u8, len); + Ok(Value::Blob(slice.to_vec())) } } + x if x == static_sqlite_ffi::SQLITE_NULL as i32 => Ok(Value::Null), + _ => Err(Error::Sqlite(format!( + "Unexpected column type {}", + sqlite3_column_type(stmt, i) + ))), + } + } - Ok(rows) + unsafe fn finalize_statement(db: *mut sqlite3, stmt: *mut sqlite3_stmt) -> Result<()> { + let rc = sqlite3_finalize(stmt); + if rc != static_sqlite_ffi::SQLITE_OK as i32 { + let error = CStr::from_ptr(sqlite3_errmsg(db)) + .to_string_lossy() + .into_owned(); + if error.starts_with("UNIQUE constraint failed: ") { + Err(Error::UniqueConstraint( + error.replace("UNIQUE constraint failed: ", ""), + )) + } else { + Err(Error::Sqlite(error)) + } + } else { + Ok(()) } } + pub fn iter<'a, T: FromRow + 'a>( + &'a self, + sql: &str, + params: &[Value], + ) -> Result> + 'a> { + let stmt = self.prepare(sql, params)?; + Ok(SqliteIterator::new(self, stmt)) + } + pub fn rows(&self, sql: &str, params: &[Value]) -> Result>> { unsafe { let stmt = self.prepare(sql, params)?; @@ -216,43 +298,14 @@ impl Sqlite { .to_string_lossy() .into_owned(); - let value = match sqlite3_column_type(stmt, i) { - 1 => Value::Integer(sqlite3_column_int64(stmt, i)), - 2 => Value::Real(sqlite3_column_double(stmt, i)), - 3 => { - let text = - CStr::from_ptr(sqlite3_column_text(stmt, i) as *const c_char) - .to_string_lossy() - .into_owned(); - Value::Text(text) - } - 4 => { - let len = sqlite3_column_bytes(stmt, i) as usize; - let ptr = sqlite3_column_text(stmt, i); - let slice = std::slice::from_raw_parts(ptr, len); - Value::Blob(slice.to_vec()) - } - _ => Value::Null, - }; - + let value = Self::get_column_value(stmt, i)?; values.push((name, value)); } rows.push(values); } - if sqlite3_finalize(stmt) != 0 { - let error = CStr::from_ptr(sqlite3_errmsg(self.db)) - .to_string_lossy() - .into_owned(); - if error.starts_with("UNIQUE constraint failed: ") { - return Err(Error::UniqueConstraint( - error.replace("UNIQUE constraint failed: ", ""), - )); - } else { - return Err(Error::Sqlite(error)); - } - } + Self::finalize_statement(self.db, stmt)?; Ok(rows) } @@ -581,3 +634,88 @@ impl From<()> for Value { Value::Null } } + +#[derive(Debug)] +pub struct SqliteIterator<'a, T: FromRow> { + db: &'a Sqlite, + stmt: *mut sqlite3_stmt, + finished: bool, + _marker: PhantomData, +} + +impl<'a, T: FromRow> SqliteIterator<'a, T> { + fn new(db: &'a Sqlite, stmt: *mut sqlite3_stmt) -> Self { + SqliteIterator { + db, + stmt, + finished: false, + _marker: PhantomData, + } + } +} + +impl<'a, T: FromRow> Iterator for SqliteIterator<'a, T> { + type Item = Result; + + fn next(&mut self) -> Option { + if self.finished { + return None; + } + + unsafe { + match sqlite3_step(self.stmt) { + SQLITE_ROW => { + let column_count = sqlite3_column_count(self.stmt); + let mut values: Vec<(String, Value)> = vec![]; + + for i in 0..column_count { + let name_ptr = sqlite3_column_name(self.stmt, i); + let name = if name_ptr.is_null() { + format!("column_{}", i) + } else { + match CStr::from_ptr(name_ptr).to_str() { + Ok(s) => s.to_owned(), + Err(e) => return Some(Err(e.into())), + } + }; + + match Sqlite::get_column_value(self.stmt, i) { + Ok(value) => values.push((name, value)), + Err(e) => { + self.finished = true; + return Some(Err(e)); + } + } + } + + match T::from_row(values) { + Ok(row) => Some(Ok(row)), + Err(e) => { + self.finished = true; + Some(Err(e)) + } + } + } + SQLITE_DONE => { + self.finished = true; + None + } + _ => { + self.finished = true; + let error = CStr::from_ptr(sqlite3_errmsg(self.db.db)) + .to_string_lossy() + .into_owned(); + Some(Err(Error::Sqlite(error))) + } + } + } + } +} + +impl<'a, T: FromRow> Drop for SqliteIterator<'a, T> { + fn drop(&mut self) { + unsafe { + let _ = Sqlite::finalize_statement(self.db.db, self.stmt); + } + } +} diff --git a/static_sqlite_core/src/lib.rs b/static_sqlite_core/src/lib.rs index b5ce9e7..850008d 100644 --- a/static_sqlite_core/src/lib.rs +++ b/static_sqlite_core/src/lib.rs @@ -21,6 +21,14 @@ pub fn query( conn.query(sql, params) } +pub fn query_first( + conn: &Sqlite, + sql: &'static str, + params: &[Value], +) -> Result> { + conn.query_first(sql, params) +} + pub fn rows(conn: &Sqlite, sql: &str, params: &[Value]) -> Result>> { conn.rows(sql, params) } diff --git a/static_sqlite_macros/src/lib.rs b/static_sqlite_macros/src/lib.rs index 350bcd4..2c1599d 100644 --- a/static_sqlite_macros/src/lib.rs +++ b/static_sqlite_macros/src/lib.rs @@ -380,7 +380,7 @@ fn migrate_fn(expr: &SqlExpr) -> TokenStream { let SqlExpr { ident, sql, .. } = expr; quote! { - pub async fn #ident(sqlite: &static_sqlite::Sqlite) -> Result<()> { + pub async fn #ident(sqlite: &static_sqlite::Sqlite) -> static_sqlite::Result<()> { let sql = #sql.to_string(); let _ = static_sqlite::execute_all(&sqlite, "create table if not exists __migrations__ (sql text primary key not null);".into()).await?; for stmt in sql.split(";").filter(|s| !s.trim().is_empty()) { @@ -395,6 +395,12 @@ fn migrate_fn(expr: &SqlExpr) -> TokenStream { } } +enum FunctionType { + QueryVec, + QueryOption, + Stream, +} + fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result> { let mut output = vec![]; for expr in exprs { @@ -420,84 +426,195 @@ fn fn_tokens(db: &Sqlite, schema: &Schema, exprs: &[&SqlExpr]) -> Result {} }; } - let input_schema_rows: Vec<&&SchemaRow> = inputs - .iter() - .filter_map(|col_name| schema_rows.iter().find(|row| &row.column_name == col_name)) - .collect(); - let fn_args = input_schema_rows + + let fn_args = inputs .iter() - .map(|field| { - let field_type = match field.column_type.as_str() { - "BLOB" => quote! { Vec }, - "INTEGER" => quote! { i64 }, - "REAL" | "DOUBLE" => quote! { f64 }, - "TEXT" => quote! { impl ToString }, - _ => unimplemented!("Sqlite fn arg not supported"), - }; - let field_name = Ident::new(&field.column_name, expr.ident.span()); - let not_null = field.not_null; - let pk = field.pk; - match (pk, not_null) { - (0, 0) => quote! { #field_name: Option<#field_type> }, - _ => quote! { #field_name: #field_type }, + .map(|aliases_column_name| { + match parse_type_hinted_column_name(aliases_column_name, &schema_rows) { + TypedToken::FromTypeHint(type_hint) => { + let field_name = Ident::new(&type_hint.alias, expr.ident.span()); + let field_type = + create_fn_argument_type(&type_hint.alias, &type_hint.column_type); + match type_hint.not_null { + 0 => quote! { #field_name: Option<#field_type> }, + _ => quote! { #field_name: #field_type }, + } + } + TypedToken::FromSchemaRow(schema_row) => { + let field_name = Ident::new(&schema_row.column_name, expr.ident.span()); + let field_type = + create_fn_argument_type(aliases_column_name, &schema_row.column_type); + match (schema_row.pk, schema_row.not_null) { + (0, 0) => quote! { #field_name: Option<#field_type> }, + _ => quote! { #field_name: #field_type }, + } + } } }) .collect::>(); - let params = input_schema_rows + + let params = inputs .iter() - .map(|field| { - let not_null = field.not_null; - let name = Ident::new(&field.column_name, expr.ident.span()); - match field.column_type.as_str() { - "BLOB" => { - quote! { #name.into() } + .map(|aliases_column_name| { + match parse_type_hinted_column_name(aliases_column_name, &schema_rows) { + TypedToken::FromTypeHint(type_hint) => { + let field_name = Ident::new(&type_hint.alias, expr.ident.span()); + create_binding_value(&type_hint.column_type, type_hint.not_null, field_name) + } + TypedToken::FromSchemaRow(schema_row) => { + let field_name = Ident::new(&schema_row.column_name, expr.ident.span()); + create_binding_value( + &schema_row.column_type, + schema_row.not_null, + field_name, + ) } - "INTEGER" => quote! { #name.into() }, - "REAL" | "DOUBLE" => quote! { #name.into() }, - "TEXT" => match not_null { - 1 => quote! { - #name.to_string().into() - }, - 0 => quote! { - match #name { - Some(val) => val.to_string().into(), - None => static_sqlite::Value::Null - } - }, - _ => unreachable!(), - }, - _ => unimplemented!("Sqlite param not supported"), } }) .collect::>(); + + let fn_type = if expr.ident.to_string().ends_with("_stream") { + FunctionType::Stream + } else if expr.ident.to_string().ends_with("_first") { + FunctionType::QueryOption + } else { + FunctionType::QueryVec + }; + let ident = &expr.ident; let outputs = output_column_names(db, expr)?; let pascal_case = snake_to_pascal_case(&ident); - let cols: Vec = outputs + + let output_typed = outputs .iter() - .filter_map(|col_name| { - schema_rows - .iter() - .find(|row| &row.column_name == col_name) - .cloned() - .cloned() - }) - .collect(); - let struct_tokens = struct_tokens(expr.ident.span(), &pascal_case, &cols); + .map(|output| parse_type_hinted_column_name(output, &schema_rows)) + .collect::>(); + + let struct_tokens = struct_tokens(expr.ident.span(), &pascal_case, &output_typed); + let sql = &expr.sql; + + let fn_tokens = match fn_type { + FunctionType::QueryVec => quote! { + #[allow(non_snake_case)] + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { + let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; + Ok(rows) + } + }, + FunctionType::QueryOption => quote! { + #[allow(non_snake_case)] + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result> { + static_sqlite::query_first(db, #sql, vec![#(#params,)*]).await + } + }, + FunctionType::Stream => quote! { + #[allow(non_snake_case)] + pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> static_sqlite::Result>> { + static_sqlite::stream(db, #sql, vec![#(#params,)*]).await + } + }, + }; + output.push(quote! { #struct_tokens - #[doc = #sql] - pub async fn #ident(db: &static_sqlite::Sqlite, #(#fn_args),*) -> Result> { - let rows: Vec<#pascal_case> = static_sqlite::query(db, #sql, vec![#(#params,)*]).await?; - Ok(rows) - } + #fn_tokens + }) } Ok(output) } +fn create_fn_argument_type(fieldname: &String, column_type: &str) -> TokenStream { + match column_type { + "BLOB" => quote! { Vec }, + "INTEGER" => quote! { i64 }, + "REAL" | "DOUBLE" => quote! { f64 }, + "TEXT" => quote! { impl ToString }, + _ => unimplemented!( + "type {:?} not supported for fn arg {:?}", + column_type, + fieldname + ), + } +} + +fn create_binding_value(field_type: &str, not_null: i64, name: Ident) -> TokenStream { + match field_type { + "BLOB" => { + quote! { #name.into() } + } + "INTEGER" => quote! { #name.into() }, + "REAL" | "DOUBLE" => quote! { #name.into() }, + "TEXT" => match not_null { + 1 => quote! { + + #name.to_string().into() + }, + 0 => quote! { + match #name { + Some(val) => val.to_string().into(), + None => static_sqlite::Value::Null + } + }, + _ => unreachable!(), + }, + _ => unimplemented!("Sqlite param not supported"), + } +} + +#[derive(Debug, Clone)] +struct TypeHintedToken { + name: String, + alias: String, + column_type: String, + not_null: i64, +} + +#[derive(Debug, Clone)] +enum TypedToken { + FromTypeHint(TypeHintedToken), + FromSchemaRow(SchemaRow), +} + +/* + * Parses a type hint and returns a TypedColumnOrParameter + * + * If the alias is in the form of __ or ____ then it is a type hint + * Otherwise it is a column name + * + */ +fn parse_type_hinted_column_name(alias: &str, schema_rows: &Vec<&SchemaRow>) -> TypedToken { + let parts = alias.split("__").collect::>(); + let result = match parts.len() { + 1 => TypedToken::FromSchemaRow( + match schema_rows.iter().find(|row| &row.column_name == alias) { + Some(row) => (**row).clone(), + None => panic!("Column {:?} referenced in binding or column not found in schema, maybe you forgot to add the type hint?", alias), + } + ), + 2 => TypedToken::FromTypeHint(TypeHintedToken { + alias: alias.to_string(), + name: parts[0].to_string(), + column_type: parts[1].to_string(), + not_null: 1, + }), + 3 => TypedToken::FromTypeHint(TypeHintedToken { + alias: alias.to_string(), + name: parts[0].to_string(), + column_type: parts[1].to_string(), + not_null: match parts[2].to_lowercase().as_str() { + "nullable" => 0, + "not_null" => 1, + _ => panic!("Invalid type hint: {:?}, last part must be nullable or not_null", alias), + }, + }), + _ => panic!("Invalid type hint: {:?}", alias), + }; + result +} + fn join_table_names(expr: &&SqlExpr) -> Vec { let mut output = vec![]; visit_relations(&expr.statements, |rel| { @@ -533,16 +650,37 @@ fn structs_tokens(span: Span, schema: &Schema) -> Vec { .iter() .map(|(table, cols)| { let ident = proc_macro2::Ident::new(&table, span); - struct_tokens(span, &ident, cols) + let typed_tokens: Vec = cols + .iter() + .map(|col| TypedToken::FromSchemaRow(col.clone())) + .collect(); + struct_tokens(span, &ident, &typed_tokens) }) .collect() } -fn struct_tokens(span: Span, ident: &Ident, cols: &Vec) -> TokenStream { - let struct_fields = cols.iter().map(|row| { - let field_type = field_type(row); - let name = Ident::new(&row.column_name, span); - let optional = match (row.not_null, row.pk) { +fn struct_tokens(span: Span, ident: &Ident, output_typed: &[TypedToken]) -> TokenStream { + let struct_fields = output_typed.iter().map(|row| { + let field_type = match row { + TypedToken::FromTypeHint(type_hint) => { + field_type_from_datatype_name(&type_hint.column_type) + } + TypedToken::FromSchemaRow(schema_row) => field_type(schema_row), + }; + let name = match row { + TypedToken::FromTypeHint(type_hint) => Ident::new(&type_hint.name, span), + TypedToken::FromSchemaRow(schema_row) => Ident::new(&schema_row.column_name, span), + }; + let optional = match ( + match row { + TypedToken::FromTypeHint(type_hint) => type_hint.not_null, + TypedToken::FromSchemaRow(schema_row) => schema_row.not_null, + }, + match row { + TypedToken::FromTypeHint(_) => 0, + TypedToken::FromSchemaRow(schema_row) => schema_row.pk, + }, + ) { (0, 0) => true, (0, 1) | (1, 0) | (1, 1) => false, _ => unreachable!(), @@ -553,9 +691,21 @@ fn struct_tokens(span: Span, ident: &Ident, cols: &Vec) -> TokenStrea false => quote! { pub #name: #field_type }, } }); - let match_stmt = cols.iter().map(|field| { - let name = Ident::new(&field.column_name, span); - let lit_str = LitStr::new(&field.column_name, span); + let match_stmt = output_typed.iter().map(|row| { + let name = Ident::new( + match row { + TypedToken::FromTypeHint(type_hint) => &type_hint.name, + TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, + }, + span, + ); + let lit_str = LitStr::new( + match row { + TypedToken::FromTypeHint(type_hint) => &type_hint.alias, + TypedToken::FromSchemaRow(schema_row) => &schema_row.column_name, + }, + span, + ); quote! { #lit_str => row.#name = value.try_into()? @@ -584,7 +734,11 @@ fn struct_tokens(span: Span, ident: &Ident, cols: &Vec) -> TokenStrea } fn field_type(row: &SchemaRow) -> TokenStream { - match row.column_type.as_str() { + field_type_from_datatype_name(&row.column_type) +} + +fn field_type_from_datatype_name(datatype_name: &str) -> TokenStream { + match datatype_name { "BLOB" => quote! { Vec }, "INTEGER" => quote! { i64 }, "REAL" | "DOUBLE" => quote! { f64 }, diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 791e729..d9fd026 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,5 +1,6 @@ +use futures::pin_mut; +use futures::StreamExt; use static_sqlite::{sql, FirstRow, Result, Sqlite}; - #[tokio::test] async fn option_type_works() -> Result<()> { sql! { @@ -12,7 +13,7 @@ async fn option_type_works() -> Result<()> { let insert_row = r#" insert into Row (txt) values (:txt) returning * "#; - } + }; let db = static_sqlite::open(":memory:").await?; let _k = migrate(&db).await?; @@ -24,6 +25,84 @@ async fn option_type_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn stream_works() -> Result<()> { + sql! { + let migrate = r#" + create table Row ( + txt text + ) + "#; + + let insert_row = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_rows_stream = r#" + select * from Row + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + insert_row(&db, Some("test1")).await?.first_row()?; + insert_row(&db, Some("test2")).await?.first_row()?; + insert_row(&db, Some("test3")).await?.first_row()?; + insert_row(&db, Some("test4")).await?.first_row()?; + + let f = select_rows_stream(&db).await?; + + pin_mut!(f); + + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test1".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test2".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test3".into())); + assert_eq!(f.next().await.unwrap().unwrap().txt, Some("test4".into())); + + Ok(()) +} + +#[tokio::test] +async fn query_first_works() -> Result<()> { + sql! { + let migrate = r#" + create table Row ( + id integer primary key autoincrement, + txt text NOT NULL + ) + "#; + + let insert_row_first = r#" + insert into Row (txt) values (:txt) returning * + "#; + + let select_row_first = r#" + select * from Row where id = :id + "#; + } + + let db = static_sqlite::open(":memory:").await?; + migrate(&db).await?; + + assert_eq!(insert_row_first(&db, "test1").await?.unwrap().txt, "test1"); + assert_eq!(insert_row_first(&db, "test2").await?.unwrap().txt, "test2"); + assert_eq!(insert_row_first(&db, "test3").await?.unwrap().txt, "test3"); + assert_eq!(insert_row_first(&db, "test4").await?.unwrap().txt, "test4"); + + match select_row_first(&db, 1).await? { + Some(row) => assert_eq!(row.txt, "test1"), + None => panic!("Row 1 not found"), + } + + match select_row_first(&db, 2).await? { + Some(row) => assert_eq!(row.txt, "test2"), + None => panic!("Row 2 not found"), + } + + Ok(()) +} + #[tokio::test] async fn it_works() -> Result<()> { sql! { @@ -118,7 +197,8 @@ async fn it_works() -> Result<()> { Some(2.0), Some(vec![0xFE, 0xED]), ) - .await?.first_row()?; + .await? + .first_row()?; assert_eq!( row, @@ -162,6 +242,7 @@ async fn readme_works() -> Result<()> { values (:name) returning * "#; + } let db = static_sqlite::open(":memory:").await?; @@ -174,6 +255,57 @@ async fn readme_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn transaction_works() -> Result<()> { + sql! { + let migrate = r#" + create table Item ( + id integer primary key + ); + "#; + + let insert_item = r#" + insert into Item (id) + values (:id) + returning * + "#; + + let get_item_first = r#" + select id from Item where id = :id + "#; + + } + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + + // being; insert; commmit + db.begin_transaction().await?; + insert_item(&db, 1).await?; + let item1 = get_item_first(&db, 1).await?; + assert_eq!(item1.is_some(), true); + db.commit_transaction().await?; + + // begin; insert; rollback + db.begin_transaction().await?; + + insert_item(&db, 2).await?; + let item2_in_transaction = get_item_first(&db, 2).await?; + assert_eq!(item2_in_transaction.is_some(), true); + + db.rollback_transaction().await?; + + let item2_after_rollback = get_item_first(&db, 2).await?; + assert_eq!(item2_after_rollback.is_some(), false); + + // rollback without begin + match db.rollback_transaction().await { + Ok(_) => panic!("should fail because no transaction is in progress"), + Err(_) => (), + } + + Ok(()) +} + #[tokio::test] async fn crud_works() -> Result<()> { sql! { @@ -201,8 +333,8 @@ async fn crud_works() -> Result<()> { let all_users = r#" select id, name from User "#; - } + } let db = static_sqlite::open(":memory:").await?; let _ = migrate(&db).await?; let user = insert_user(&db, "swlkr").await?.first_row()?; @@ -226,6 +358,162 @@ async fn crud_works() -> Result<()> { Ok(()) } +#[tokio::test] +async fn parameters_that_are_not_in_the_schema_work() -> Result<()> { + sql! { + let migrate = r#" + create table User ( + id integer primary key, + name text unique not null + ); + + create table Post ( + id integer primary key, + user_id integer not null references User(id), + name text unique not null + ); + "#; + + let insert_user = r#" + insert into User (name) values (:name) returning * + "#; + + let insert_post = r#" + insert into Post (user_id, name) values (:user_id, :name) returning * + "#; + let select_posts = r#" + select * from Post where id = :id AND id = :id__INTEGER AND name = :id__INTEGER AND name = :name AND :ff__TEXT="sdd" + "#; + } + + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + let user1 = insert_user(&db, "user1").await?.first_row()?; + insert_post(&db, user1.id, "user 1 - post1") + .await? + .first_row()?; + insert_post(&db, user1.id, "user 1 - post2") + .await? + .first_row()?; + let user2 = insert_user(&db, "user2").await?.first_row()?; + insert_post(&db, user2.id, "user 2 - post1") + .await? + .first_row()?; + insert_post(&db, user2.id, "user 2 - post2") + .await? + .first_row()?; + + let posts = select_posts(&db, 1, 2, "Hello", "sdd").await?; + println!("{:?}", posts); + + Ok(()) +} + +#[tokio::test] +async fn example_friendshipworks() -> Result<()> { + use static_sqlite::{self, sql}; + + sql! { + let migrate = r#" + create table User ( + id integer primary key, + name text unique not null + ); + + create table Friendship ( + id integer primary key, + user_id integer not null references User(id), + friend_id integer not null references User(id) + ); + "#; + + let insert_user = r#" + insert into User (name) + values (:name) + returning * + "#; + let create_friendship = r#" + insert into Friendship (user_id, friend_id) + values (:user_id, :friend_id) + returning * + "#; + let get_friendship = r#" + SELECT + u1.name as friend1_name__TEXT, + u2.name as friend2_name__TEXT + FROM Friendship, User as u1, User as u2 + WHERE Friendship.user_id = u1.id + AND Friendship.friend_id = u2.id + AND Friendship.id = :friendship_id__INTEGER + "#; + + + } + + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + insert_user(&db, "swlkr").await?; + insert_user(&db, "toolbar23").await?; + create_friendship(&db, 1, 2).await?; + + let friends = get_friendship(&db, 1).await?; + + assert_eq!(friends.len(), 1); + assert_eq!(friends.first().unwrap().friend1_name, "swlkr"); + assert_eq!(friends.first().unwrap().friend2_name, "toolbar23"); + + Ok(()) +} + +#[tokio::test] +async fn duplicate_column_names_in_one_query_work() -> Result<()> { + sql! { + let migrate = r#" + CREATE TABLE IF NOT EXISTS Identifiers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + entity_type INTEGER NOT NULL, + identifier_type TEXT NOT NULL, + identifier_value TEXT NOT NULL, + UNIQUE(entity_type, identifier_type, identifier_value) + ); + + CREATE TABLE IF NOT EXISTS MappingChanges ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + from_identifier INTEGER NOT NULL REFERENCES Identifiers(id), + to_identifier_previous INTEGER NOT NULL REFERENCES Identifiers(id), + to_identifier_new INTEGER NOT NULL REFERENCES Identifiers(id), + timestamp INTEGER NOT NULL + ); + "#; + + let get_changes = r#" + SELECT + mc.id, + mc.timestamp, + f.entity_type, + f.identifier_type, + f.identifier_value, + op.identifier_type as old_type__TEXT__NULLABLE, + op.identifier_value as old_value__TEXT__NULLABLE, + n.identifier_type as new_type__TEXT__NULLABLE, + n.identifier_value as new_value__TEXT__NULLABLE + FROM MappingChanges mc, Identifiers f, Identifiers op, Identifiers n + WHERE mc.from_identifier = f.id + AND mc.to_identifier_previous = op.id + AND mc.to_identifier_new = n.id + AND mc.timestamp > :since__INTEGER + ORDER BY mc.timestamp ASC + "#; + } + + let db = static_sqlite::open(":memory:").await?; + let _ = migrate(&db).await?; + let changes = get_changes(&db, 1).await?; + println!("{:?}", changes); + + Ok(()) +} + #[test] fn ui() { let t = trybuild::TestCases::new();