diff --git a/Cargo.lock b/Cargo.lock index 8de1cf8b..fbabb147 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3252,6 +3252,12 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "downcast-rs" version = "2.0.2" @@ -3577,6 +3583,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" + [[package]] name = "fs-err" version = "3.2.0" @@ -5251,6 +5263,32 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "multer" version = "3.1.0" @@ -5990,6 +6028,32 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -7251,6 +7315,7 @@ dependencies = [ "chrono", "error-stack", "error-stack-trace", + "mockall", "serde", "serde_json", "snafu", @@ -7414,6 +7479,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 5ad311ae..73e4ac71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,7 +92,7 @@ url = "2.5" utoipa = { version = "5.3.1", features = ["uuid", "chrono"] } utoipa-axum = { version = "0.2.0" } utoipa-swagger-ui = { version = "9", features = ["axum"] } -uuid = { version = "1.10.0", features = ["v4", "serde"] } +uuid = { version = "1.10.0", features = ["v4", "v7", "serde"] } validator = { version = "0.20.0", features = ["derive"] } mockall = "0.13.1" insta = { version = "1.44.1", features = ["json", "filters", "redactions"] } diff --git a/crates/api-snowflake-rest/Cargo.toml b/crates/api-snowflake-rest/Cargo.toml index 1f397f4e..daa22a55 100644 --- a/crates/api-snowflake-rest/Cargo.toml +++ b/crates/api-snowflake-rest/Cargo.toml @@ -7,6 +7,7 @@ license-file.workspace = true [features] default = [] retry-disable = [] +traces-test-log = [] [dependencies] api-snowflake-rest-sessions = { path = "../api-snowflake-rest-sessions" } diff --git a/crates/api-snowflake-rest/src/tests/create_test_server.rs b/crates/api-snowflake-rest/src/tests/create_test_server.rs index 96b3a1be..6ce3703f 100644 --- a/crates/api-snowflake-rest/src/tests/create_test_server.rs +++ b/crates/api-snowflake-rest/src/tests/create_test_server.rs @@ -12,7 +12,8 @@ use std::sync::{Arc, Condvar, Mutex}; use std::thread; use std::time::Duration; use tokio::runtime::Builder; -use tracing_subscriber::fmt::format::FmtSpan; +#[cfg(feature = "traces-test-log")] +use tracing_subscriber::{fmt, fmt::format::FmtSpan}; static INIT: std::sync::Once = std::sync::Once::new(); @@ -140,10 +141,13 @@ fn setup_tracing() { .with_targets(targets_with_level(&DISABLED_TARGETS, LevelFilter::OFF)) .with_default(LevelFilter::TRACE), ), - ) + ); + + #[cfg(feature = "traces-test-log")] + let registry = registry // Logs filtering .with( - tracing_subscriber::fmt::layer() + fmt::layer() .with_writer( std::fs::OpenOptions::new() .create(true) diff --git a/crates/build-info/build.rs b/crates/build-info/build.rs index 130736bd..49b11a6d 100644 --- a/crates/build-info/build.rs +++ b/crates/build-info/build.rs @@ -45,10 +45,11 @@ fn main() { println!("cargo:rustc-env=BUILD_TIMESTAMP={build_timestamp}"); // Rerun build script if git HEAD changes - println!("cargo:rerun-if-changed=.git/HEAD"); + // Should point to the root of the repository + println!("cargo:rerun-if-changed=../../.git/HEAD"); // Also rerun if the current branch ref changes if let Some(branch_ref) = run_git_command(&["symbolic-ref", "HEAD"]) { - let ref_path = format!(".git/{branch_ref}"); + let ref_path = format!("../../.git/{branch_ref}"); println!("cargo:rerun-if-changed={ref_path}"); } } diff --git a/crates/embucket-lambda/Cargo.toml b/crates/embucket-lambda/Cargo.toml index 516023d7..72f1d4dc 100644 --- a/crates/embucket-lambda/Cargo.toml +++ b/crates/embucket-lambda/Cargo.toml @@ -45,6 +45,7 @@ retry-disable = ["api-snowflake-rest/retry-disable"] streaming = [] rest-catalog = ["executor/rest-catalog"] dedicated-executor = ["executor/dedicated-executor"] +state-store-query = ["executor/state-store-query"] [package.metadata.lambda] # Default binary to deploy diff --git a/crates/embucketd/Cargo.toml b/crates/embucketd/Cargo.toml index a4510782..e7e73965 100644 --- a/crates/embucketd/Cargo.toml +++ b/crates/embucketd/Cargo.toml @@ -52,3 +52,4 @@ retry-disable = ["api-snowflake-rest/retry-disable"] rest-catalog = ["executor/rest-catalog"] dedicated-executor = ["executor/dedicated-executor"] state-store = ["executor/state-store"] +state-store-query = ["executor/state-store-query"] diff --git a/crates/executor/Cargo.toml b/crates/executor/Cargo.toml index 494d6bd5..7602bb22 100644 --- a/crates/executor/Cargo.toml +++ b/crates/executor/Cargo.toml @@ -5,9 +5,11 @@ edition = "2024" license-file.workspace = true [features] +default = [] rest-catalog = ["catalog/rest-catalog"] dedicated-executor = [] state-store = [] +state-store-query = ["state-store"] [dependencies] catalog-metastore = { path = "../catalog-metastore" } diff --git a/crates/executor/src/error.rs b/crates/executor/src/error.rs index 47a56c67..edbab26c 100644 --- a/crates/executor/src/error.rs +++ b/crates/executor/src/error.rs @@ -1,5 +1,5 @@ use super::snowflake_error::SnowflakeError; -use crate::query_types::{QueryId, QueryStatus}; +use crate::query_types::{ExecutionStatus, QueryId}; use catalog::error::Error as CatalogError; use datafusion_common::DataFusionError; use error_stack_trace; @@ -592,7 +592,7 @@ pub enum Error { }, #[snafu(display("Query {query_id} result notify error: {error}"))] - QueryStatusRecv { + ExecutionStatusRecv { query_id: QueryId, #[snafu(source)] error: tokio::sync::watch::error::RecvError, @@ -601,10 +601,10 @@ pub enum Error { }, #[snafu(display("Query {query_id} status notify error: {error}"))] - NotifyQueryStatus { + NotifyExecutionStatus { query_id: QueryId, #[snafu(source)] - error: tokio::sync::watch::error::SendError, + error: tokio::sync::watch::error::SendError, #[snafu(implicit)] location: Location, }, diff --git a/crates/executor/src/error_code.rs b/crates/executor/src/error_code.rs index 6146815d..f182da74 100644 --- a/crates/executor/src/error_code.rs +++ b/crates/executor/src/error_code.rs @@ -6,8 +6,10 @@ use std::fmt::Display; // For reference: https://github.com/snowflakedb/snowflake-cli/blob/main/src/snowflake/cli/api/errno.py // Some of our error codes may be mapped to Snowflake error codes +// Do not set values for error codes, they are assigned in Display trait #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum ErrorCode { + None, Db, Metastore, #[cfg(feature = "state-store")] @@ -35,6 +37,10 @@ pub enum ErrorCode { EntityNotFound(Entity, OperationOn), Other, UnsupportedFeature, + Timeout, + Cancelled, + LimitExceeded, + QueryTask, } impl Display for ErrorCode { @@ -42,6 +48,8 @@ impl Display for ErrorCode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let code = match self { Self::UnsupportedFeature => 2, + Self::Timeout => 630, + Self::Cancelled => 684, Self::HistoricalQueryError => 1001, Self::DataFusionSqlParse => 1003, Self::DataFusionSql => 2003, diff --git a/crates/executor/src/lib.rs b/crates/executor/src/lib.rs index 7541893f..cb62f87d 100644 --- a/crates/executor/src/lib.rs +++ b/crates/executor/src/lib.rs @@ -6,6 +6,7 @@ pub mod error; pub mod error_code; pub mod models; pub mod query; +pub mod query_task_result; pub mod query_types; pub mod running_queries; pub mod service; @@ -18,7 +19,7 @@ pub mod utils; pub mod tests; pub use error::{Error, Result}; -pub use query_types::{QueryId, QueryStatus}; +pub use query_types::{ExecutionStatus, QueryId}; pub use running_queries::RunningQueryId; pub use snowflake_error::SnowflakeError; diff --git a/crates/executor/src/models.rs b/crates/executor/src/models.rs index 53c5d5bd..2faba8e8 100644 --- a/crates/executor/src/models.rs +++ b/crates/executor/src/models.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use std::sync::Arc; use uuid::Uuid; -#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct QueryContext { pub database: Option, pub schema: Option, @@ -18,6 +18,21 @@ pub struct QueryContext { pub ip_address: Option, } +// Add own Default implementation to avoid getting default (zeroed) Uuid. +// This compromise is against rules, since this default is not deterministic. +impl Default for QueryContext { + fn default() -> Self { + Self { + database: None, + schema: None, + worksheet_id: None, + query_id: Uuid::now_v7(), + request_id: None, + ip_address: None, + } + } +} + impl QueryContext { #[must_use] pub fn new( @@ -29,9 +44,7 @@ impl QueryContext { database, schema, worksheet_id, - query_id: QueryId::default(), - request_id: None, - ip_address: None, + ..Default::default() } } diff --git a/crates/executor/src/query_task_result.rs b/crates/executor/src/query_task_result.rs new file mode 100644 index 00000000..d5c6ed2c --- /dev/null +++ b/crates/executor/src/query_task_result.rs @@ -0,0 +1,87 @@ +use super::error as ex_error; +use super::error::Result; +use super::error_code::ErrorCode; +use super::models::QueryResult; +use super::query_types::ExecutionStatus; +use super::snowflake_error::SnowflakeError; +use snafu::ResultExt; +use tokio::task::JoinError; +use uuid::Uuid; + +// pub type TaskFuture = tokio::task::JoinHandle>; + +pub struct ExecutionTaskResult { + pub result: Result, + pub execution_status: ExecutionStatus, + pub error_code: Option, +} + +impl ExecutionTaskResult { + pub fn from_query_result(query_id: Uuid, result: Result) -> Self { + let execution_status = result + .as_ref() + .map_or_else(|_| ExecutionStatus::Fail, |_| ExecutionStatus::Success); + let error_code = match result.as_ref() { + Ok(_) => None, + Err(err) => Some(SnowflakeError::from_executor_error(err).error_code()), + }; + // set query execution status to successful or failed + Self { + result: result.context(ex_error::QueryExecutionSnafu { query_id }), + execution_status, + error_code, + } + } + + #[must_use] + pub fn from_query_limit_exceeded(query_id: Uuid) -> Self { + Self { + result: ex_error::ConcurrencyLimitSnafu + .fail() + .context(ex_error::QueryExecutionSnafu { query_id }), + execution_status: ExecutionStatus::Incident, + error_code: Some(ErrorCode::LimitExceeded), + } + } + + #[must_use] + pub fn from_failed_query_task(query_id: Uuid, task_error: JoinError) -> Self { + Self { + result: Err(task_error) + .context(ex_error::QuerySubtaskJoinSnafu) + .context(ex_error::QueryExecutionSnafu { query_id }), + execution_status: ExecutionStatus::Incident, + error_code: Some(ErrorCode::QueryTask), + } + } + + #[must_use] + pub fn from_cancelled_query_task(query_id: Uuid) -> Self { + Self { + result: ex_error::QueryCancelledSnafu { query_id } + .fail() + .context(ex_error::QueryExecutionSnafu { query_id }), + execution_status: ExecutionStatus::Fail, + error_code: Some(ErrorCode::Cancelled), + } + } + + #[must_use] + pub fn from_timeout_query_task(query_id: Uuid) -> Self { + Self { + result: ex_error::QueryTimeoutSnafu + .fail() + .context(ex_error::QueryExecutionSnafu { query_id }), + execution_status: ExecutionStatus::Fail, + error_code: Some(ErrorCode::Timeout), + } + } + + #[cfg(feature = "state-store-query")] + pub fn assign_query_attributes(&self, query: &mut state_store::Query) { + query.set_execution_status(self.execution_status); + if let Some(error_code) = self.error_code { + query.set_error_code(error_code.to_string()); + } + } +} diff --git a/crates/executor/src/query_types.rs b/crates/executor/src/query_types.rs index 61b63a52..2c6d85d8 100644 --- a/crates/executor/src/query_types.rs +++ b/crates/executor/src/query_types.rs @@ -1,14 +1,16 @@ -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; -use uuid::Uuid; +pub type QueryId = uuid::Uuid; -pub type QueryId = Uuid; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum QueryStatus { - Running, - Successful, - Failed, - Cancelled, - TimedOut, +cfg_if::cfg_if! { + if #[cfg(feature = "state-store-query")] { + pub use state_store::ExecutionStatus; + } else { + use serde::{Deserialize, Serialize}; + use std::fmt::Debug; + #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] + pub enum ExecutionStatus { + Success, + Fail, + Incident, + } + } } diff --git a/crates/executor/src/running_queries.rs b/crates/executor/src/running_queries.rs index 30844a19..cf2ff674 100644 --- a/crates/executor/src/running_queries.rs +++ b/crates/executor/src/running_queries.rs @@ -1,6 +1,6 @@ use super::error::{self as ex_error, Result}; use super::models::QueryResult; -use crate::query_types::{QueryId, QueryStatus}; +use crate::query_types::{ExecutionStatus, QueryId}; use dashmap::DashMap; use snafu::{OptionExt, ResultExt}; use std::sync::Arc; @@ -20,8 +20,8 @@ pub struct RunningQuery { pub result_handle: Option>>, pub cancellation_token: CancellationToken, // user can be notified when query is finished - tx: watch::Sender, - rx: watch::Receiver, + tx: watch::Sender>, + rx: watch::Receiver>, } #[derive(Debug, Clone)] @@ -33,7 +33,7 @@ pub enum RunningQueryId { impl RunningQuery { #[must_use] pub fn new(query_id: QueryId) -> Self { - let (tx, rx) = watch::channel(QueryStatus::Running); + let (tx, rx) = watch::channel(None); Self { query_id, request_id: None, @@ -76,9 +76,9 @@ impl RunningQuery { )] pub fn notify_query_finished( &self, - status: QueryStatus, - ) -> std::result::Result<(), watch::error::SendError> { - self.tx.send(status) + status: ExecutionStatus, + ) -> std::result::Result<(), watch::error::SendError>> { + self.tx.send(Some(status)) } #[tracing::instrument( @@ -89,14 +89,14 @@ impl RunningQuery { )] pub async fn wait_query_finished( &self, - ) -> std::result::Result { + ) -> std::result::Result { // use loop here to bypass default query status we posted at init // it should not go to the actual loop and should resolve as soon as results are ready let mut rx = self.rx.clone(); loop { rx.changed().await?; let status = *rx.borrow(); - if status != QueryStatus::Running { + if let Some(status) = status { break Ok(status); } } @@ -131,7 +131,7 @@ impl RunningQueriesRegistry { skip(self), err )] - pub async fn wait_query_finished(&self, query_id: QueryId) -> Result { + pub async fn wait_query_finished(&self, query_id: QueryId) -> Result { let running_query = self .queries .get(&query_id) @@ -139,7 +139,7 @@ impl RunningQueriesRegistry { running_query .wait_query_finished() .await - .context(ex_error::QueryStatusRecvSnafu { query_id }) + .context(ex_error::ExecutionStatusRecvSnafu { query_id }) } } @@ -149,7 +149,7 @@ pub trait RunningQueries: Send + Sync { fn add(&self, running_query: RunningQuery); fn remove(&self, query_id: QueryId) -> Result; fn abort(&self, query_id: QueryId) -> Result<()>; - fn notify_query_finished(&self, query_id: QueryId, status: QueryStatus) -> Result<()>; + fn notify_query_finished(&self, query_id: QueryId, status: ExecutionStatus) -> Result<()>; fn locate_query_id(&self, running_query_id: RunningQueryId) -> Result; fn count(&self) -> usize; } @@ -197,7 +197,7 @@ impl RunningQueries for RunningQueriesRegistry { skip(self), err )] - fn notify_query_finished(&self, query_id: QueryId, status: QueryStatus) -> Result<()> { + fn notify_query_finished(&self, query_id: QueryId, status: ExecutionStatus) -> Result<()> { let running_query = self .queries .get(&query_id) diff --git a/crates/executor/src/service.rs b/crates/executor/src/service.rs index 820bbf3c..703476cd 100644 --- a/crates/executor/src/service.rs +++ b/crates/executor/src/service.rs @@ -26,7 +26,8 @@ use super::error::{self as ex_error, Result}; use super::models::{QueryContext, QueryResult}; use super::running_queries::{RunningQueries, RunningQueriesRegistry, RunningQuery}; use super::session::UserSession; -use crate::query_types::{QueryId, QueryStatus}; +use crate::query_task_result::ExecutionTaskResult; +use crate::query_types::QueryId; use crate::running_queries::RunningQueryId; use crate::session::{SESSION_INACTIVITY_EXPIRATION_SECONDS, to_unix}; use crate::tracing::SpanTracer; @@ -153,6 +154,27 @@ pub struct CoreExecutionService { } impl CoreExecutionService { + #[cfg(feature = "state-store")] + pub async fn new_test_executor( + metastore: Arc, + state_store: Arc, + config: Arc, + ) -> Result { + Self::initialize_datafusion_tracer(); + + let catalog_list = Self::catalog_list(metastore.clone(), &config).await?; + let runtime_env = Self::runtime_env(&config, catalog_list.clone())?; + Ok(Self { + metastore, + df_sessions: Arc::new(RwLock::new(HashMap::new())), + config, + catalog_list, + runtime_env, + queries: Arc::new(RunningQueriesRegistry::new()), + state_store, + }) + } + #[tracing::instrument( name = "CoreExecutionService::new", level = "debug", @@ -485,35 +507,45 @@ impl ExecutionService for CoreExecutionService { async fn submit( &self, session_id: &str, - query: &str, + query_text: &str, query_context: QueryContext, ) -> Result { let user_session = self.get_session(session_id).await?; - #[cfg(feature = "state-store")] - { - let query_record = Query::new( - query, - query_context.query_id, - session_id, - query_context.request_id, - ); - self.state_store - .put_query(query_record) - .await - .context(ex_error::StateStoreSnafu)?; - let _ = self - .state_store - .get_query(&query_context.query_id.to_string()) - .await - .context(ex_error::StateStoreSnafu)?; + cfg_if::cfg_if! { + if #[cfg(feature = "state-store-query")] { + let mut query = Query::new( + query_text, + query_context.query_id, + session_id, + query_context.request_id, + ); + let query_id = query_context.query_id; + } else { + let query_id = Uuid::now_v7(); + } } if self.queries.count() >= self.config.max_concurrency_level { - return ex_error::ConcurrencyLimitSnafu.fail(); + let limit_exceeded = ExecutionTaskResult::from_query_limit_exceeded(query_id); + #[cfg(feature = "state-store-query")] + { + // query created with failed status already + limit_exceeded.assign_query_attributes(&mut query); + self.state_store + .put_query(&query) + .await + .context(ex_error::StateStoreSnafu)?; + } + // here we always return error, but Ok should fit Result type too + return limit_exceeded.result.map(|_| query_id); } - let query_id = Uuid::new_v4(); + #[cfg(feature = "state-store-query")] + self.state_store + .put_query(&query) + .await + .context(ex_error::StateStoreSnafu)?; // Record the result as part of the current span. tracing::Span::current() @@ -521,11 +553,7 @@ impl ExecutionService for CoreExecutionService { .record("with_timeout_secs", self.config.query_timeout_secs); let request_id = query_context.request_id; - let query = query.to_string(); - let query_timeout_secs = self.config.query_timeout_secs; - let queries_clone = self.queries.clone(); let query_token = CancellationToken::new(); - let query_token_clone = query_token.clone(); let task_span = tracing::info_span!("spawn_query_task"); @@ -535,87 +563,103 @@ impl ExecutionService for CoreExecutionService { query_id = %query_id, session_id = %session_id ); - let handle = tokio::spawn(async move { - let sub_task_span = tracing::info_span!("spawn_query_sub_task"); - let mut query_obj = user_session.query(query, query_context.with_query_id(query_id)); - - // Create nested task so in case of abort/timeout it can be aborted - // and result is handled properly (status / query result saved) - let subtask_fut = task::spawn(async move { - query_obj.execute().instrument(sub_task_span).await - }); - let subtask_abort_handle = subtask_fut.abort_handle(); - - // wait for any future to be resolved - let (query_result, query_status) = tokio::select! { - finished = subtask_fut => { - match finished { - Ok(inner_result) => { - // set query execution status to successful or failed - let status = inner_result.as_ref().map_or_else(|_| QueryStatus::Failed, |_| QueryStatus::Successful); - (inner_result.context(ex_error::QueryExecutionSnafu { - query_id, - }), status) - }, - Err(error) => { - tracing::error!("Query {query_id} sub task join error: {error:?}"); - (Err(ex_error::Error::QuerySubtaskJoin { error, location: snafu::location!() }).context(ex_error::QueryExecutionSnafu { - query_id, - }), QueryStatus::Failed) - }, + let handle = tokio::spawn({ + #[cfg(feature = "state-store-query")] + let state_store = self.state_store.clone(); + let query_text = query_text.to_string(); + let query_timeout = Duration::from_secs(self.config.query_timeout_secs); + let queries_registry = self.queries.clone(); + let query_token = query_token.clone(); + async move { + let sub_task_span = tracing::info_span!("spawn_query_sub_task"); + let mut query_obj = user_session.query(query_text, query_context); + + // Create nested task so in case of abort/timeout it can be aborted + // and result is handled properly (status / query result saved) + let task_future = + task::spawn(async move { query_obj.execute().instrument(sub_task_span).await }); + + let subtask_abort_handle = task_future.abort_handle(); + // wait for any future to be resolved + let execution_result = tokio::select! { + finished = task_future => { + match finished { + Ok(inner_result) => ExecutionTaskResult::from_query_result(query_id, inner_result), + Err(task_error) => { + tracing::error!("Query {query_id} sub task join error: {task_error:?}"); + ExecutionTaskResult::from_failed_query_task(query_id, task_error) + }, + } + }, + () = query_token.cancelled() => { + tracing::info_span!("abort_cancelled_query"); + subtask_abort_handle.abort(); + ExecutionTaskResult::from_cancelled_query_task(query_id) + }, + // Execute the query with a timeout to prevent long-running or stuck queries + // from blocking system resources indefinitely. If the timeout is exceeded, + // convert the timeout into a standard QueryTimeout error so it can be handled + // and recorded like any other execution failure + () = tokio::time::sleep(query_timeout) => { + tracing::info_span!("query_timeout_received_do_abort"); + subtask_abort_handle.abort(); + ExecutionTaskResult::from_timeout_query_task(query_id) + } + }; + + let _ = tracing::info_span!( + "finished_query_status", + query_id = query_id.to_string(), + query_status = format!("{:?}", execution_result.execution_status), + error_code = format!("{:?}", execution_result.error_code), + ) + .entered(); + + cfg_if::cfg_if! { + if #[cfg(feature = "state-store-query")] { + execution_result.assign_query_attributes(&mut query); + // just log error and do not raise it from task + if let Err(err) = state_store.update_query(&query).await { + tracing::error!("Failed to update query {query_id}: {err:?}"); + } + } else { + user_session.record_query_id(query_id); } - }, - () = query_token.cancelled() => { - tracing::info_span!("query_cancelled_do_abort"); - subtask_abort_handle.abort(); - (ex_error::QueryCancelledSnafu { query_id }.fail().context(ex_error::QueryExecutionSnafu { - query_id, - }), QueryStatus::Cancelled) - }, - // Execute the query with a timeout to prevent long-running or stuck queries - // from blocking system resources indefinitely. If the timeout is exceeded, - // convert the timeout into a standard QueryTimeout error so it can be handled - // and recorded like any other execution failure - () = tokio::time::sleep(Duration::from_secs(query_timeout_secs)) => { - tracing::info_span!("query_timeout_received_do_abort"); - subtask_abort_handle.abort(); - (ex_error::QueryTimeoutSnafu.fail().context(ex_error::QueryExecutionSnafu { - query_id, - }), QueryStatus::TimedOut) } - }; - let _ = tracing::info_span!("finished_query_status", - query_id = query_id.to_string(), - query_status = format!("{query_status:?}"), - ) - .entered(); - - user_session.record_query_id(query_id); - - // Notify subscribers query finishes and result is ready. - // Do not immediately remove query from running queries registry - // as RunningQuery contains result handle that caller should consume. - queries_clone.notify_query_finished(query_id, query_status)?; - - // Discard results after short timeout, to prevent memory leaks - tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(TIMEOUT_DISCARD_INTERVAL_SECONDS)).await; - let running_query = queries_clone.remove(query_id); - if let Ok(RunningQuery {result_handle: Some(result_handle), ..}) = running_query { - tracing::debug!("Discarding '{query_status:?}' result for query {query_id}"); - let _ = result_handle.await; - } - }); + // Notify subscribers query finishes and result is ready. + // Do not immediately remove query from running queries registry + // as RunningQuery contains result handle that caller should consume. + queries_registry.notify_query_finished(query_id, execution_result.execution_status)?; + + // Discard results after short timeout, to prevent memory leaks + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(TIMEOUT_DISCARD_INTERVAL_SECONDS)).await; + let running_query = queries_registry.remove(query_id); + if let Ok(RunningQuery { + result_handle: Some(result_handle), + .. + }) = running_query + { + tracing::debug!( + "Discard execution result '{:?}' for query {query_id}", + execution_result.execution_status + ); + let _ = result_handle.await; + } + }); - query_result - }.instrument(alloc_span).instrument(task_span)); + execution_result.result + } + .instrument(alloc_span) + .instrument(task_span) + }); self.queries.add( RunningQuery::new(query_id) .with_request_id(request_id) .with_result_handle(handle) - .with_cancellation_token(query_token_clone), + .with_cancellation_token(query_token), ); Ok(query_id) diff --git a/crates/executor/src/tests/mod.rs b/crates/executor/src/tests/mod.rs index db704fe7..8ed1fd62 100644 --- a/crates/executor/src/tests/mod.rs +++ b/crates/executor/src/tests/mod.rs @@ -3,3 +3,5 @@ pub mod s3_tables; pub mod service; pub mod snowflake_errors; pub mod sql; +#[cfg(feature = "state-store-query")] +pub mod statestore_queries_unittest; diff --git a/crates/executor/src/tests/statestore_queries_unittest.rs b/crates/executor/src/tests/statestore_queries_unittest.rs new file mode 100644 index 00000000..671259bd --- /dev/null +++ b/crates/executor/src/tests/statestore_queries_unittest.rs @@ -0,0 +1,48 @@ +use crate::models::QueryContext; +use crate::service::{CoreExecutionService, ExecutionService}; +use crate::utils::Config; +use catalog_metastore::InMemoryMetastore; +use state_store::{MockStateStore, SessionRecord, StateStore}; +use std::sync::Arc; + +const TEST_SESSION_ID: &str = "test_session_id"; + +// it stucks without multithread +#[allow(clippy::expect_used)] +#[tokio::test] +async fn test_query_lifecycle() { + let mut state_store_mock = MockStateStore::new(); + state_store_mock + .expect_put_new_session() + .returning(|_| Ok(())); + state_store_mock + .expect_get_session() + .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); + state_store_mock.expect_put_query().returning(|_| Ok(())); + state_store_mock.expect_update_query().returning(|_| Ok(())); + + let state_store: Arc = Arc::new(state_store_mock); + + let metastore = Arc::new(InMemoryMetastore::new()); + let execution_svc = CoreExecutionService::new_test_executor( + metastore, + state_store, + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"); + + execution_svc + .create_session(TEST_SESSION_ID) + .await + .expect("Failed to create session"); + + let _ = execution_svc + .query( + TEST_SESSION_ID, + "SELECT 1 AS a, 2.0 AS b, '3' AS c WHERE False", + QueryContext::default(), + ) + .await + .expect("Failed to execute query"); +} diff --git a/crates/queries/Cargo.toml b/crates/queries/Cargo.toml index 896ea10d..6f4bd785 100644 --- a/crates/queries/Cargo.toml +++ b/crates/queries/Cargo.toml @@ -4,6 +4,10 @@ version = "0.1.0" edition = "2024" license-file = { workspace = true } +[features] +default = [] +tests = [] # tests not included by default + [dependencies] error-stack-trace = { path = "../error-stack-trace" } error-stack = { path = "../error-stack" } diff --git a/crates/queries/src/lib.rs b/crates/queries/src/lib.rs index 23daaa16..7284ab5f 100644 --- a/crates/queries/src/lib.rs +++ b/crates/queries/src/lib.rs @@ -2,7 +2,7 @@ pub mod error; pub mod models; pub mod operations; -#[cfg(test)] +#[cfg(all(feature = "tests", test))] pub mod tests; pub use models::QuerySource; diff --git a/crates/state-store/Cargo.toml b/crates/state-store/Cargo.toml index 83c6e6b6..a5d30a88 100644 --- a/crates/state-store/Cargo.toml +++ b/crates/state-store/Cargo.toml @@ -16,6 +16,7 @@ serde = { workspace = true } serde_json = {workspace = true} snafu = { workspace = true } chrono = { workspace = true } +mockall = { workspace = true } uuid = { workspace = true } [lints] diff --git a/crates/state-store/src/lib.rs b/crates/state-store/src/lib.rs index a4afb098..ee82f9d1 100644 --- a/crates/state-store/src/lib.rs +++ b/crates/state-store/src/lib.rs @@ -5,6 +5,6 @@ pub mod state_store; pub use config::DynamoDbConfig; pub use error::{Error, Result}; -pub use models::{SessionRecord, Variable, ViewRecord}; +pub use models::{ExecutionStatus, Query, SessionRecord, Variable, ViewRecord}; pub use state_store::DynamoDbStateStore; -pub use state_store::StateStore; +pub use state_store::{MockStateStore, StateStore}; diff --git a/crates/state-store/src/models.rs b/crates/state-store/src/models.rs index 168eea87..bb517b33 100644 --- a/crates/state-store/src/models.rs +++ b/crates/state-store/src/models.rs @@ -23,30 +23,19 @@ impl Display for Entities { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] -pub enum QueryStatus { - #[default] - Created, - Queued, - Running, - LimitExceeded, - Successful, - Failed, - Cancelled, - TimedOut, +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExecutionStatus { + Success, + Fail, + Incident, } -impl Display for QueryStatus { +impl Display for ExecutionStatus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let value = match self { - Self::Created => "created", - Self::Queued => "queued", - Self::Running => "running", - Self::LimitExceeded => "limit_exceeded", - Self::Successful => "successful", - Self::Failed => "failed", - Self::Cancelled => "cancelled", - Self::TimedOut => "timed_out", + Self::Success => "success", + Self::Fail => "fail", + Self::Incident => "incident", }; write!(f, "{value}") } @@ -127,7 +116,6 @@ pub struct Variable { pub struct Query { pub query_id: Uuid, pub request_id: Option, - pub query_status: QueryStatus, pub query_text: String, pub session_id: String, #[serde(default, skip_serializing_if = "Option::is_none")] @@ -159,7 +147,7 @@ pub struct Query { #[serde(default, skip_serializing_if = "Option::is_none")] pub query_tag: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub execution_status: Option, + pub execution_status: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub error_code: Option, #[serde(default, skip_serializing_if = "Option::is_none")] @@ -315,4 +303,14 @@ impl Query { pub fn entity(&self) -> String { Entities::Query.to_string() } + + // Why? warning: this could be a `const fn` + #[allow(clippy::missing_const_for_fn)] + pub fn set_execution_status(&mut self, status: ExecutionStatus) { + self.execution_status = Some(status); + } + + pub fn set_error_code(&mut self, error_code: String) { + self.error_code = Some(error_code); + } } diff --git a/crates/state-store/src/state_store.rs b/crates/state-store/src/state_store.rs index acd19daf..f6f69546 100644 --- a/crates/state-store/src/state_store.rs +++ b/crates/state-store/src/state_store.rs @@ -23,6 +23,7 @@ const QUERY_ID_INDEX: &str = "GSI_QUERY_ID_INDEX"; const REQUEST_ID_INDEX: &str = "GSI_REQUEST_ID_INDEX"; const SESSION_ID_INDEX: &str = "GSI_SESSION_ID_INDEX"; +#[mockall::automock] #[async_trait::async_trait] pub trait StateStore: Send + Sync { async fn put_new_session(&self, session_id: &str) -> Result<()>; @@ -30,12 +31,12 @@ pub trait StateStore: Send + Sync { async fn get_session(&self, session_id: &str) -> Result; async fn delete_session(&self, session_id: &str) -> Result<()>; async fn update_session(&self, session: SessionRecord) -> Result<()>; - async fn put_query(&self, query: Query) -> Result<()>; + async fn put_query(&self, query: &Query) -> Result<()>; async fn get_query(&self, query_id: &str) -> Result; async fn get_query_by_request_id(&self, request_id: &str) -> Result; async fn get_queries_by_session_id(&self, session_id: &str) -> Result>; async fn delete_query(&self, query_id: &str) -> Result<()>; - async fn update_query(&self, query: Query) -> Result<()>; + async fn update_query(&self, query: &Query) -> Result<()>; } /// `DynamoDB` single-table client. @@ -217,7 +218,7 @@ impl StateStore for DynamoDbStateStore { self.put_session(session).await } - async fn put_query(&self, query: Query) -> Result<()> { + async fn put_query(&self, query: &Query) -> Result<()> { let mut item = HashMap::new(); let pk = Self::query_pk(&query.start_time); let sk = Self::query_sk(&query.start_time); @@ -284,7 +285,7 @@ impl StateStore for DynamoDbStateStore { Ok(()) } - async fn update_query(&self, query: Query) -> Result<()> { + async fn update_query(&self, query: &Query) -> Result<()> { self.put_query(query).await } }