diff --git a/Makefile b/Makefile
index 155b7ea1..477f0ffe 100644
--- a/Makefile
+++ b/Makefile
@@ -5,6 +5,7 @@ INTEG_API_INVOKE := RestApiUrl HttpApiUrl
 INTEG_EXTENSIONS := extension-fn extension-trait logs-trait
 # Using musl to run extensions on both AL1 and AL2
 INTEG_ARCH := x86_64-unknown-linux-musl
+RIE_MAX_CONCURRENCY ?= 4
 
 define uppercase
 $(shell sed -r 's/(^|-)(\w)/\U\2/g' <<< $(1))
@@ -111,4 +112,8 @@ fmt:
	cargo +nightly fmt --all
 
 test-rie:
-	./scripts/test-rie.sh $(EXAMPLE)
\ No newline at end of file
+	./scripts/test-rie.sh $(EXAMPLE)
+
+# Run RIE in Lambda Managed Instance (LMI) mode with concurrent polling.
+test-rie-lmi:
+	RIE_MAX_CONCURRENCY=$(RIE_MAX_CONCURRENCY) ./scripts/test-rie.sh $(EXAMPLE)
diff --git a/examples/basic-lambda-concurrent/Cargo.toml b/examples/basic-lambda-concurrent/Cargo.toml
new file mode 100644
index 00000000..180b4830
--- /dev/null
+++ b/examples/basic-lambda-concurrent/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "basic-lambda-concurrent"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+lambda_runtime = { path = "../../lambda-runtime" }
+serde = "1.0.219"
+tokio = { version = "1", features = ["macros"] }
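A note before the example's source (next diff): the manifest above does not enable serde's `derive` feature itself; the derive macros used below presumably reach the example through lambda_runtime's own serde dependency via Cargo feature unification. The wire format the example accepts and returns can be sanity-checked in isolation. A minimal sketch, with serde_json added only for the illustration:

use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct Request {
    command: String,
}

#[derive(Serialize)]
struct Response {
    req_id: String,
    msg: String,
}

fn main() -> Result<(), serde_json::Error> {
    // The documented test input for the example:
    let req: Request = serde_json::from_str(r#"{ "command": "do something" }"#)?;
    assert_eq!(req.command, "do something");

    // The runtime serializes the handler's Response back to JSON the same way.
    let resp = Response {
        req_id: "ID".to_string(),
        msg: format!("Command {} executed.", req.command),
    };
    println!("{}", serde_json::to_string(&resp)?);
    Ok(())
}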
diff --git a/examples/basic-lambda-concurrent/src/main.rs b/examples/basic-lambda-concurrent/src/main.rs
new file mode 100644
index 00000000..018a2dba
--- /dev/null
+++ b/examples/basic-lambda-concurrent/src/main.rs
@@ -0,0 +1,74 @@
+// This example requires the following input to succeed:
+// { "command": "do something" }
+
+use lambda_runtime::{service_fn, tracing, Error, LambdaEvent};
+use serde::{Deserialize, Serialize};
+
+/// This is a made-up example. Requests come into the runtime as unicode
+/// strings in JSON format, which can map to any structure that implements
+/// `serde::Deserialize`. The runtime pays no attention to the contents of the request payload.
+#[derive(Deserialize)]
+struct Request {
+    command: String,
+}
+
+/// This is also a made-up example of what a response structure may look like.
+/// There is no restriction on what it can be. The runtime requires responses
+/// to be serialized into JSON. The runtime pays no attention
+/// to the contents of the response payload.
+#[derive(Serialize)]
+struct Response {
+    req_id: String,
+    msg: String,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Error> {
+    // required to enable CloudWatch error logging by the runtime
+    tracing::init_default_subscriber();
+
+    let func = service_fn(my_handler);
+    if let Err(err) = lambda_runtime::run_concurrent(func).await {
+        eprintln!("run error: {:?}", err);
+        return Err(err);
+    }
+    Ok(())
+}
+
+pub(crate) async fn my_handler(event: LambdaEvent<Request>) -> Result<Response, Error> {
+    // extract some useful info from the request
+    let command = event.payload.command;
+
+    // prepare the response
+    let resp = Response {
+        req_id: event.context.request_id,
+        msg: format!("Command {command} executed."),
+    };
+
+    // return `Response` (it will be serialized to JSON automatically by the runtime)
+    Ok(resp)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{my_handler, Request};
+    use lambda_runtime::{Context, LambdaEvent};
+
+    #[tokio::test]
+    async fn response_is_good_for_simple_input() {
+        let id = "ID";
+
+        let mut context = Context::default();
+        context.request_id = id.to_string();
+
+        let payload = Request {
+            command: "X".to_string(),
+        };
+        let event = LambdaEvent { payload, context };
+
+        let result = my_handler(event).await.unwrap();
+
+        assert_eq!(result.msg, "Command X executed.");
+        assert_eq!(result.req_id, id.to_string());
+    }
+}
diff --git a/examples/basic-lambda/src/main.rs b/examples/basic-lambda/src/main.rs
index d3f2a3cd..396c3afd 100644
--- a/examples/basic-lambda/src/main.rs
+++ b/examples/basic-lambda/src/main.rs
@@ -28,7 +28,10 @@ async fn main() -> Result<(), Error> {
     tracing::init_default_subscriber();
 
     let func = service_fn(my_handler);
-    lambda_runtime::run(func).await?;
+    if let Err(err) = lambda_runtime::run(func).await {
+        eprintln!("run error: {:?}", err);
+        return Err(err);
+    }
     Ok(())
 }
 
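Both examples pass a plain async fn to `service_fn`. An equivalent closure form is useful when the handler needs to capture shared state; note that `run_concurrent` requires the service to be `Clone`, so clones of it may run on several workers at once and any captured state should be thread-safe. A sketch, not part of the diff (the `greeting` state is hypothetical):

use lambda_runtime::{service_fn, Error, LambdaEvent};
use serde_json::Value;
use std::sync::Arc;

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Hypothetical shared state captured by the handler closure. Arc keeps the
    // closure cheap to clone, one copy per polling worker.
    let greeting = Arc::new("hello".to_string());

    let func = service_fn(move |event: LambdaEvent<Value>| {
        let greeting = greeting.clone();
        async move {
            Ok::<Value, Error>(serde_json::json!({
                "msg": format!("{} {}", greeting, event.context.request_id),
            }))
        }
    });

    lambda_runtime::run_concurrent(func).await
}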
diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs
index 60e279c7..1d44d6b0 100644
--- a/lambda-http/src/lib.rs
+++ b/lambda-http/src/lib.rs
@@ -102,7 +102,7 @@ use std::{
 };
 
 mod streaming;
-pub use streaming::{run_with_streaming_response, StreamAdapter};
+pub use streaming::{run_with_streaming_response, run_with_streaming_response_concurrent, StreamAdapter};
 
 /// Type alias for `http::Request`s with a fixed [`Body`](enum.Body.html) type
 pub type Request = http::Request<Body>;
@@ -151,6 +151,18 @@ pub struct Adapter<'a, R, S> {
     _phantom_data: PhantomData<&'a R>,
 }
 
+impl<'a, R, S> Clone for Adapter<'a, R, S>
+where
+    S: Clone,
+{
+    fn clone(&self) -> Self {
+        Self {
+            service: self.service.clone(),
+            _phantom_data: PhantomData,
+        }
+    }
+}
+
 impl<'a, R, S, E> From<S> for Adapter<'a, R, S>
 where
     S: Service<Request, Response = R, Error = E>,
@@ -203,6 +215,24 @@ where
     lambda_runtime::run(Adapter::from(handler)).await
 }
 
+/// Starts the Lambda Rust runtime in a mode that is compatible with
+/// Lambda Managed Instances (concurrent invocations).
+///
+/// When `AWS_LAMBDA_MAX_CONCURRENCY` is set to a value greater than 1, this
+/// will spawn `AWS_LAMBDA_MAX_CONCURRENCY` worker tasks, each running its own
+/// `/next` polling loop. When the environment variable is unset or `<= 1`,
+/// it falls back to the same sequential behavior as [`run`], so the same
+/// handler can run on both classic Lambda and Lambda Managed Instances.
+pub async fn run_concurrent<R, S, E>(handler: S) -> Result<(), Error>
+where
+    S: Service<Request, Response = R, Error = E> + Clone + Send + 'static,
+    S::Future: Send + 'static,
+    R: IntoResponse + Send + Sync + 'static,
+    E: std::fmt::Debug + Into<Diagnostic> + Send + 'static,
+{
+    lambda_runtime::run_concurrent(Adapter::from(handler)).await
+}
+
 #[cfg(test)]
 mod test_adapter {
     use std::task::{Context, Poll};
diff --git a/lambda-http/src/streaming.rs b/lambda-http/src/streaming.rs
index ed61c773..a729206c 100644
--- a/lambda-http/src/streaming.rs
+++ b/lambda-http/src/streaming.rs
@@ -10,7 +10,7 @@ pub use http::{self, Response};
 use http_body::Body;
 use lambda_runtime::{
     tower::{
-        util::{MapRequest, MapResponse},
+        util::{BoxCloneService, MapRequest, MapResponse},
         ServiceBuilder, ServiceExt,
     },
     Diagnostic,
@@ -93,14 +93,33 @@ where
     B::Error: Into<Error> + Send + Debug,
 {
     ServiceBuilder::new()
-        .map_request(|req: LambdaEvent<LambdaRequest>| {
-            let event: Request = req.payload.into();
-            event.with_lambda_context(req.context)
-        })
+        .map_request(event_to_request as fn(LambdaEvent<LambdaRequest>) -> Request)
         .service(handler)
         .map_response(into_stream_response)
 }
 
+/// Builds a streaming-aware Tower service from a `Service<Request>` that can be
+/// cloned and sent across tasks. This is used by the concurrent HTTP entrypoint.
+#[allow(clippy::type_complexity)]
+fn into_stream_service_boxed<S, B, E>(
+    handler: S,
+) -> BoxCloneService<LambdaEvent<LambdaRequest>, StreamResponse<BodyStream<B>>, E>
+where
+    S: Service<Request, Response = Response<B>, Error = E> + Clone + Send + 'static,
+    S::Future: Send + 'static,
+    E: Debug + Into<Diagnostic> + Send + 'static,
+    B: Body + Unpin + Send + 'static,
+    B::Data: Into<Bytes> + Send,
+    B::Error: Into<Error> + Send + Debug,
+{
+    let svc = ServiceBuilder::new()
+        .map_request(event_to_request as fn(LambdaEvent<LambdaRequest>) -> Request)
+        .service(handler)
+        .map_response(into_stream_response);
+
+    BoxCloneService::new(svc)
+}
+
 /// Converts an `http::Response` into a streaming Lambda response.
 fn into_stream_response<B>(res: Response<B>) -> StreamResponse<BodyStream<B>>
 where
@@ -128,6 +147,11 @@ where
     }
 }
 
+fn event_to_request(req: LambdaEvent<LambdaRequest>) -> Request {
+    let event: Request = req.payload.into();
+    event.with_lambda_context(req.context)
+}
+
 /// Runs the Lambda runtime with a handler that returns **streaming** HTTP
 /// responses.
 ///
@@ -147,6 +171,24 @@ where
     lambda_runtime::run(into_stream_service(handler)).await
 }
 
+/// Runs the Lambda runtime with a handler that returns **streaming** HTTP
+/// responses, in a mode that is compatible with Lambda Managed Instances.
+///
+/// This uses a cloneable, boxed service internally so it can be driven by the
+/// concurrent runtime. When `AWS_LAMBDA_MAX_CONCURRENCY` is not set or `<= 1`,
+/// it falls back to the same sequential behavior as [`run_with_streaming_response`].
+pub async fn run_with_streaming_response_concurrent<S, B, E>(handler: S) -> Result<(), Error>
+where
+    S: Service<Request, Response = Response<B>, Error = E> + Clone + Send + 'static,
+    S::Future: Send + 'static,
+    E: Debug + Into<Diagnostic> + Send + 'static,
+    B: Body + Unpin + Send + 'static,
+    B::Data: Into<Bytes> + Send,
+    B::Error: Into<Error> + Send + Debug,
+{
+    lambda_runtime::run_concurrent(into_stream_service_boxed(handler)).await
+}
+
 pin_project_lite::pin_project! {
     #[non_exhaustive]
     pub struct BodyStream<B> {
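`BoxCloneService` is the tower utility doing the heavy lifting here: it erases the concrete service type while preserving `Clone`, which is what lets the concurrent runtime hand one copy of the streaming service to each worker. A standalone sketch of that primitive, independent of this crate:

use std::convert::Infallible;
use tower::util::BoxCloneService;
use tower::{service_fn, Service, ServiceExt};

#[tokio::main]
async fn main() {
    // Any Clone + Send service can be boxed this way...
    let svc = service_fn(|name: String| async move { Ok::<_, Infallible>(format!("hello {name}")) });
    let boxed: BoxCloneService<String, String, Infallible> = BoxCloneService::new(svc);

    // ...and the boxed service can itself be cloned, one copy per worker task.
    let mut for_worker_a = boxed.clone();
    let mut for_worker_b = boxed;

    let a = for_worker_a.ready().await.unwrap().call("a".to_string()).await.unwrap();
    let b = for_worker_b.ready().await.unwrap().call("b".to_string()).await.unwrap();
    assert_eq!((a.as_str(), b.as_str()), ("hello a", "hello b"));
}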
diff --git a/lambda-runtime-api-client/src/lib.rs b/lambda-runtime-api-client/src/lib.rs
index 3df616ab..86cc715f 100644
--- a/lambda-runtime-api-client/src/lib.rs
+++ b/lambda-runtime-api-client/src/lib.rs
@@ -41,6 +41,7 @@ impl Client {
         ClientBuilder {
             connector: HttpConnector::new(),
             uri: None,
+            pool_size: None,
         }
     }
 }
@@ -59,11 +60,16 @@ impl Client {
         self.client.request(req).map_err(Into::into).boxed()
     }
 
-    /// Create a new client with a given base URI and HTTP connector.
-    fn with(base: Uri, connector: HttpConnector) -> Self {
-        let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
-            .http1_max_buf_size(1024 * 1024)
-            .build(connector);
+    /// Create a new client with a given base URI, HTTP connector, and optional pool size hint.
+    fn with(base: Uri, connector: HttpConnector, pool_size: Option<usize>) -> Self {
+        let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new());
+        builder.http1_max_buf_size(1024 * 1024);
+
+        if let Some(size) = pool_size {
+            builder.pool_max_idle_per_host(size);
+        }
+
+        let client = builder.build(connector);
         Self { base, client }
     }
 
@@ -94,6 +100,7 @@ impl Client {
 pub struct ClientBuilder {
     connector: HttpConnector,
     uri: Option<Uri>,
+    pool_size: Option<usize>,
 }
 
 impl ClientBuilder {
@@ -102,6 +109,7 @@ impl ClientBuilder {
         ClientBuilder {
             connector,
             uri: self.uri,
+            pool_size: self.pool_size,
         }
     }
 
@@ -111,6 +119,14 @@ impl ClientBuilder {
         Self { uri: Some(uri), ..self }
     }
 
+    /// Provide a pool size hint for the underlying Hyper client.
+    pub fn with_pool_size(self, pool_size: usize) -> Self {
+        Self {
+            pool_size: Some(pool_size),
+            ..self
+        }
+    }
+
     /// Create the new client to interact with the Runtime API.
     pub fn build(self) -> Result<Client, Error> {
         let uri = match self.uri {
@@ -120,7 +136,7 @@ impl ClientBuilder {
                 uri.try_into().expect("Unable to convert to URL")
             }
         };
-        Ok(Client::with(uri, self.connector))
+        Ok(Client::with(uri, self.connector, self.pool_size))
     }
 }
 
@@ -182,4 +198,17 @@ mod tests {
             &req.uri().to_string()
         );
     }
+
+    #[test]
+    fn builder_accepts_pool_size() {
+        let base = "http://localhost:9001";
+        let expected: Uri = base.parse().unwrap();
+        let client = Client::builder()
+            .with_pool_size(4)
+            .with_endpoint(base.parse().unwrap())
+            .build()
+            .unwrap();
+
+        assert_eq!(client.base, expected);
+    }
 }
diff --git a/lambda-runtime/src/layers/api_client.rs b/lambda-runtime/src/layers/api_client.rs
index d44a84f2..7113ee0a 100644
--- a/lambda-runtime/src/layers/api_client.rs
+++ b/lambda-runtime/src/layers/api_client.rs
@@ -44,6 +44,18 @@ where
     }
 }
 
+impl<S> Clone for RuntimeApiClientService<S>
+where
+    S: Clone,
+{
+    fn clone(&self) -> Self {
+        Self {
+            inner: self.inner.clone(),
+            client: self.client.clone(),
+        }
+    }
+}
+
 #[pin_project(project = RuntimeApiClientFutureProj)]
 pub enum RuntimeApiClientFuture<F> {
     First(#[pin] F, Arc<Client>),
diff --git a/lambda-runtime/src/layers/api_response.rs b/lambda-runtime/src/layers/api_response.rs
index 453f8b4c..5bb3c96f 100644
--- a/lambda-runtime/src/layers/api_response.rs
+++ b/lambda-runtime/src/layers/api_response.rs
@@ -51,6 +51,27 @@
+impl<S, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError> Clone
+    for RuntimeApiResponseService<
+        S,
+        EventPayload,
+        Response,
+        BufferedResponse,
+        StreamingResponse,
+        StreamItem,
+        StreamError,
+    >
+where
+    S: Clone,
+{
+    fn clone(&self) -> Self {
+        Self {
+            inner: self.inner.clone(),
+            _phantom: PhantomData,
+        }
+    }
+}
+
 impl<S, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError> Service<LambdaInvocation>
     for RuntimeApiResponseService<
         S,
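A note on why these are hand-written `Clone` impls rather than `#[derive(Clone)]`: the derive would add a `Clone` bound on every type parameter, including the ones that only appear inside `PhantomData`, so the derived impl would be uselessly restrictive for services like `RuntimeApiResponseService`. A minimal illustration with hypothetical names:

use std::marker::PhantomData;

struct NotClone;

// With derive(Clone), `Wrapper<S, T>: Clone` would require `T: Clone` even
// though no value of type `T` is stored. The manual impl bounds only `S`.
struct Wrapper<S, T> {
    inner: S,
    _phantom: PhantomData<T>,
}

impl<S: Clone, T> Clone for Wrapper<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _phantom: PhantomData,
        }
    }
}

fn main() {
    let w: Wrapper<u8, NotClone> = Wrapper { inner: 1, _phantom: PhantomData };
    let _copy = w.clone(); // compiles because only `S = u8` must be Clone
}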
diff --git a/lambda-runtime/src/layers/trace.rs b/lambda-runtime/src/layers/trace.rs
index e93927b1..4a3ad3d9 100644
--- a/lambda-runtime/src/layers/trace.rs
+++ b/lambda-runtime/src/layers/trace.rs
@@ -25,6 +25,7 @@ impl<S> Layer<S> for TracingLayer {
 }
 
 /// Tower service returned by [TracingLayer].
+#[derive(Clone)]
 pub struct TracingService<S> {
     inner: S,
 }
diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs
index cbcd0a9e..c1bf2c70 100644
--- a/lambda-runtime/src/lib.rs
+++ b/lambda-runtime/src/lib.rs
@@ -59,6 +59,9 @@ pub struct Config {
     pub log_stream: String,
     /// The name of the Amazon CloudWatch Logs group for the function.
     pub log_group: String,
+    /// Maximum concurrent invocations for Lambda managed-concurrency environments.
+    /// Populated from `AWS_LAMBDA_MAX_CONCURRENCY` when present.
+    pub max_concurrency: Option<u32>,
 }
 
 type RefConfig = Arc<Config>;
@@ -75,8 +78,17 @@ impl Config {
             version: env::var("AWS_LAMBDA_FUNCTION_VERSION").expect("Missing AWS_LAMBDA_FUNCTION_VERSION env var"),
             log_stream: env::var("AWS_LAMBDA_LOG_STREAM_NAME").unwrap_or_default(),
             log_group: env::var("AWS_LAMBDA_LOG_GROUP_NAME").unwrap_or_default(),
+            max_concurrency: env::var("AWS_LAMBDA_MAX_CONCURRENCY")
+                .ok()
+                .and_then(|v| v.parse::<u32>().ok())
+                .filter(|&c| c > 0),
         }
     }
+
+    /// Returns true if concurrent runtime mode should be enabled.
+    pub fn is_concurrent(&self) -> bool {
+        self.max_concurrency.map(|c| c > 1).unwrap_or(false)
+    }
 }
 
 /// Return a new [`ServiceFn`] with a closure that takes an event and context as separate arguments.
@@ -126,6 +138,50 @@ where
     runtime.run().await
 }
 
+/// Starts the Lambda Rust runtime in a mode that is compatible with
+/// Lambda Managed Instances (concurrent invocations).
+///
+/// When `AWS_LAMBDA_MAX_CONCURRENCY` is set to a value greater than 1, this
+/// will spawn `AWS_LAMBDA_MAX_CONCURRENCY` worker tasks, each running its own
+/// `/next` polling loop. When the environment variable is unset or `<= 1`, it
+/// falls back to the same sequential behavior as [`run`], so the same handler
+/// can run on both classic Lambda and Lambda Managed Instances.
+///
+/// If you need more control over the runtime and want to add custom middleware,
+/// use the [Runtime] type directly.
+///
+/// # Example
+/// ```no_run
+/// use lambda_runtime::{Error, service_fn, LambdaEvent};
+/// use serde_json::Value;
+///
+/// #[tokio::main]
+/// async fn main() -> Result<(), Error> {
+///     let func = service_fn(func);
+///     lambda_runtime::run_concurrent(func).await?;
+///     Ok(())
+/// }
+///
+/// async fn func(event: LambdaEvent<Value>) -> Result<Value, Error> {
+///     Ok(event.payload)
+/// }
+/// ```
+pub async fn run_concurrent<A, F, R, B, S, D, E>(handler: F) -> Result<(), Error>
+where
+    F: Service<LambdaEvent<A>, Response = R> + Clone + Send + 'static,
+    F::Future: Future<Output = Result<R, F::Error>> + Send + 'static,
+    F::Error: Into<Diagnostic> + fmt::Debug,
+    A: for<'de> Deserialize<'de> + Send + 'static,
+    R: IntoFunctionResponse<B, S> + Send + 'static,
+    B: Serialize + Send + 'static,
+    S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
+    D: Into<Bytes> + Send + 'static,
+    E: Into<Error> + Send + Debug + 'static,
+{
+    let runtime = Runtime::new(handler).layer(layers::TracingLayer::new());
+    runtime.run_concurrent().await
+}
+
 /// Spawns a task that will execute a provided async closure when the process
 /// receives unix graceful shutdown signals. If the closure takes longer than 500ms
 /// to execute, an unhandled `SIGKILL` signal might be received.
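The parse chain in `Config::from_env` is deliberately forgiving: anything that is unset, non-numeric, or zero disables the feature rather than failing startup, and `1` counts as configured-but-sequential. A quick sketch of the same chain applied to raw strings:

fn parse_max_concurrency(raw: Option<&str>) -> Option<u32> {
    // Mirrors Config::from_env: ok() -> parse::<u32>() -> filter(> 0).
    raw.and_then(|v| v.parse::<u32>().ok()).filter(|&c| c > 0)
}

fn main() {
    assert_eq!(parse_max_concurrency(Some("4")), Some(4));
    assert_eq!(parse_max_concurrency(Some("1")), Some(1)); // set, but is_concurrent() stays false
    assert_eq!(parse_max_concurrency(Some("0")), None); // zero is filtered out
    assert_eq!(parse_max_concurrency(Some("abc")), None); // parse failure disables the mode
    assert_eq!(parse_max_concurrency(None), None); // unset
}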
diff --git a/lambda-runtime/src/runtime.rs b/lambda-runtime/src/runtime.rs
index 517ee64f..ab10fc46 100644
--- a/lambda-runtime/src/runtime.rs
+++ b/lambda-runtime/src/runtime.rs
@@ -4,13 +4,20 @@ use crate::{
     types::{invoke_request_id, IntoFunctionResponse, LambdaEvent},
     Config, Context, Diagnostic,
 };
+use futures::stream::FuturesUnordered;
 use http_body_util::BodyExt;
 use lambda_runtime_api_client::{BoxError, Client as ApiClient};
 use serde::{Deserialize, Serialize};
-use std::{env, fmt::Debug, future::Future, sync::Arc};
+use std::{
+    env,
+    fmt::Debug,
+    future::Future,
+    io,
+    sync::{Arc, OnceLock},
+};
 use tokio_stream::{Stream, StreamExt};
 use tower::{Layer, Service, ServiceExt};
-use tracing::trace;
+use tracing::{error, trace, warn};
 
 /* ----------------------------------------- INVOCATION ---------------------------------------- */
 
@@ -55,6 +62,9 @@ pub struct Runtime<S> {
     client: Arc<ApiClient>,
 }
 
+/// One-time marker to log X-Ray behavior in concurrent mode.
+static XRAY_LOGGED: OnceLock<()> = OnceLock::new();
+
 impl<F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
     Runtime<
         RuntimeApiClientService<
@@ -92,7 +102,13 @@ where
     pub fn new(handler: F) -> Self {
         trace!("Loading config from env");
         let config = Arc::new(Config::from_env());
-        let client = Arc::new(ApiClient::builder().build().expect("Unable to create a runtime client"));
+        let pool_size = config.max_concurrency.unwrap_or(1).max(1) as usize;
+        let client = Arc::new(
+            ApiClient::builder()
+                .with_pool_size(pool_size)
+                .build()
+                .expect("Unable to create a runtime client"),
+        );
         Self {
             service: wrap_handler(handler, client.clone()),
             config,
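The supervisor added in the next hunk builds on two properties worth keeping in mind: `tokio::spawn` returns a `JoinHandle` that is itself a future resolving to `Result<T, JoinError>` (so a panicking task surfaces as an `Err` instead of tearing down the process), and `FuturesUnordered` yields results in completion order. A minimal sketch of those primitives, independent of this crate:

use futures::{stream::FuturesUnordered, StreamExt};
use std::time::Duration;

#[tokio::main]
async fn main() {
    let mut workers = FuturesUnordered::new();
    for ms in [30u64, 10, 20] {
        workers.push(tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(ms)).await;
            ms
        }));
    }

    // Results arrive in completion order (10, 20, 30), not insertion order,
    // and a panicking task would surface here as Err(JoinError).
    while let Some(joined) = workers.next().await {
        match joined {
            Ok(ms) => println!("worker slept {ms}ms"),
            Err(join_err) => eprintln!("worker panicked: {join_err}"),
        }
    }
}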
@@ -137,6 +153,92 @@ impl
     }
 }
 
+impl<S> Runtime<S>
+where
+    S: Service<LambdaInvocation, Response = (), Error = BoxError> + Clone + Send + 'static,
+    S::Future: Send,
+{
+    /// Start the runtime in concurrent mode when configured for Lambda managed-concurrency.
+    ///
+    /// If `AWS_LAMBDA_MAX_CONCURRENCY` is not set or is `<= 1`, this falls back to the
+    /// sequential `run_with_incoming` loop so that the same handler can run on both
+    /// classic Lambda and Lambda Managed Instances.
+    pub async fn run_concurrent(self) -> Result<(), BoxError> {
+        if self.config.is_concurrent() {
+            let max_concurrency = self.config.max_concurrency.unwrap_or(1);
+            Self::run_concurrent_inner(self.service, self.config, self.client, max_concurrency).await
+        } else {
+            let incoming = incoming(&self.client);
+            Self::run_with_incoming(self.service, self.config, incoming).await
+        }
+    }
+
+    /// Concurrent processing using N independent long-poll loops (for Lambda managed-concurrency).
+    async fn run_concurrent_inner(
+        service: S,
+        config: Arc<Config>,
+        client: Arc<ApiClient>,
+        max_concurrency: u32,
+    ) -> Result<(), BoxError> {
+        let limit = max_concurrency as usize;
+
+        let mut workers = FuturesUnordered::new();
+        for _ in 1..limit {
+            workers.push(tokio::spawn(concurrent_worker_loop(
+                service.clone(),
+                config.clone(),
+                client.clone(),
+            )));
+        }
+        workers.push(tokio::spawn(concurrent_worker_loop(service, config, client)));
+
+        // Track the first infrastructure error. A single worker failing should
+        // not terminate the whole runtime (LMI keeps running with the remaining
+        // healthy workers). We only return an error once there are no workers
+        // left (i.e., we cannot keep at least 1 worker alive).
+        //
+        // Note: Handler errors (Err returned from user code) do NOT trigger this.
+        // They are reported to Lambda via /invocation/{id}/error and the worker
+        // continues. This only captures unrecoverable runtime failures like
+        // API client failures, runtime panics, etc.
+        let mut first_error: Option<BoxError> = None;
+        let mut remaining_workers = limit;
+        while let Some(result) = futures::StreamExt::next(&mut workers).await {
+            remaining_workers = remaining_workers.saturating_sub(1);
+            match result {
+                Ok(Ok(())) => {
+                    // `concurrent_worker_loop` runs indefinitely, so an Ok return indicates
+                    // an unexpected worker exit; we still decrement because the task is gone.
+                    warn!(remaining_workers, "Concurrent worker exited unexpectedly without error");
+                    if first_error.is_none() {
+                        first_error = Some(Box::new(io::Error::other(
+                            "all concurrent workers exited unexpectedly without error",
+                        )));
+                    }
+                }
+                Ok(Err(err)) => {
+                    error!(error = %err, remaining_workers, "Concurrent worker exited with error");
+                    if first_error.is_none() {
+                        first_error = Some(err);
+                    }
+                }
+                Err(join_err) => {
+                    let err: BoxError = Box::new(join_err);
+                    error!(error = %err, remaining_workers, "Concurrent worker panicked");
+                    if first_error.is_none() {
+                        first_error = Some(err);
+                    }
+                }
+            }
+        }
+
+        match first_error {
+            Some(err) => Err(err),
+            None => Ok(()),
+        }
+    }
+}
+
 impl<S> Runtime<S>
 where
     S: Service<LambdaInvocation, Response = (), Error = BoxError>,
@@ -158,30 +260,7 @@ where
         while let Some(next_event_response) = incoming.next().await {
             trace!("New event arrived (run loop)");
             let event = next_event_response?;
-            let (parts, incoming) = event.into_parts();
-
-            #[cfg(debug_assertions)]
-            if parts.status == http::StatusCode::NO_CONTENT {
-                // Ignore the event if the status code is 204.
-                // This is a way to keep the runtime alive when
-                // there are no events pending to be processed.
-                continue;
-            }
-
-            // Build the invocation such that it can be sent to the service right away
-            // when it is ready
-            let body = incoming.collect().await?.to_bytes();
-            let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
-            let invocation = LambdaInvocation { parts, body, context };
-
-            // Setup Amazon's default tracing data
-            amzn_trace_env(&invocation.context);
-
-            // Wait for service to be ready
-            let ready = service.ready().await?;
-
-            // Once ready, call the service which will respond to the Lambda runtime API
-            ready.call(invocation).await?;
+            process_invocation(&mut service, &config, event, true).await?;
         }
         Ok(())
     }
@@ -233,6 +312,73 @@ fn incoming(
     }
 }
 
+/// Creates a future that polls the `/next` endpoint.
+async fn next_event_future(client: Arc<ApiClient>) -> Result<http::Response<hyper::body::Incoming>, BoxError> {
+    let req = NextEventRequest.into_req()?;
+    client.call(req).await
+}
+
+async fn concurrent_worker_loop<S>(
+    mut service: S,
+    config: Arc<Config>,
+    client: Arc<ApiClient>,
+) -> Result<(), BoxError>
+where
+    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
+    S::Future: Send,
+{
+    loop {
+        let event = match next_event_future(client.clone()).await {
+            Ok(event) => event,
+            Err(e) => {
+                warn!(error = %e, "Error polling /next, retrying");
+                continue;
+            }
+        };
+
+        process_invocation(&mut service, &config, event, false).await?;
+    }
+}
+
+async fn process_invocation<S>(
+    service: &mut S,
+    config: &Arc<Config>,
+    event: http::Response<hyper::body::Incoming>,
+    set_amzn_trace_env: bool,
+) -> Result<(), BoxError>
+where
+    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
+{
+    let (parts, incoming) = event.into_parts();
+
+    #[cfg(debug_assertions)]
+    if parts.status == http::StatusCode::NO_CONTENT {
+        // Ignore the event if the status code is 204.
+        // This is a way to keep the runtime alive when
+        // there are no events pending to be processed.
+        return Ok(());
+    }
+
+    // Build the invocation such that it can be sent to the service right away
+    // when it is ready
+    let body = incoming.collect().await?.to_bytes();
+    let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
+    let invocation = LambdaInvocation { parts, body, context };
+
+    if set_amzn_trace_env {
+        // Setup Amazon's default tracing data
+        amzn_trace_env(&invocation.context);
+    } else {
+        // Inform users that X-Ray is available via context, not env var, in concurrent mode.
+        XRAY_LOGGED.get_or_init(|| {
+            trace!("Concurrent mode: _X_AMZN_TRACE_ID is not set; use context.xray_trace_id");
+        });
+    }
+
+    // Wait for service to be ready
+    let ready = service.ready().await?;
+
+    // Once ready, call the service which will respond to the Lambda runtime API
+    ready.call(invocation).await?;
+    Ok(())
+}
+
 fn amzn_trace_env(ctx: &Context) {
     match &ctx.xray_trace_id {
         Some(trace_id) => env::set_var("_X_AMZN_TRACE_ID", trace_id),
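For orientation, each worker's poll is a plain GET against the Runtime API's next-invocation endpoint; the path appears verbatim in the test server further down. A sketch of the request that `NextEventRequest::into_req` is expected to produce, with an illustrative base address (on Lambda the host comes from `AWS_LAMBDA_RUNTIME_API`):

use http::{Method, Request};

// Hypothetical helper; the base URI here is illustrative only.
fn next_event_request(base: &str) -> http::Result<Request<()>> {
    Request::builder()
        .method(Method::GET)
        .uri(format!("{base}/2018-06-01/runtime/invocation/next"))
        .body(())
}

fn main() {
    let req = next_event_request("http://127.0.0.1:9001").unwrap();
    // The Runtime API holds this request open until an invocation is available,
    // then answers with the payload plus headers such as
    // lambda-runtime-aws-request-id and lambda-runtime-deadline-ms.
    assert_eq!(req.uri().path(), "/2018-06-01/runtime/invocation/next");
}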
@@ -251,13 +397,28 @@ mod endpoint_tests {
         requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
         Config, Diagnostic, Error, Runtime,
     };
+    use bytes::Bytes;
     use futures::future::BoxFuture;
-    use http::{HeaderValue, StatusCode};
-    use http_body_util::BodyExt;
+    use http::{HeaderValue, Method, Request, Response, StatusCode};
+    use http_body_util::{BodyExt, Full};
     use httpmock::prelude::*;
+    use hyper::{body::Incoming, service::service_fn};
+    use hyper_util::{
+        rt::{tokio::TokioIo, TokioExecutor},
+        server::conn::auto::Builder as ServerBuilder,
+    };
     use lambda_runtime_api_client::Client;
-    use std::{env, sync::Arc};
+    use std::{
+        convert::Infallible,
+        env,
+        sync::{
+            atomic::{AtomicUsize, Ordering},
+            Arc,
+        },
+        time::Duration,
+    };
+    use tokio::{net::TcpListener, sync::Notify};
     use tokio_stream::StreamExt;
 
     #[tokio::test]
@@ -456,6 +617,7 @@ mod endpoint_tests {
             version: "1".to_string(),
             log_stream: "test_stream".to_string(),
             log_group: "test_log".to_string(),
+            max_concurrency: None,
         });
         let client = Arc::new(client);
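One behavioral consequence worth calling out before the tests: in concurrent mode `process_invocation` deliberately does not call `env::set_var("_X_AMZN_TRACE_ID", ...)` (an environment variable is process-global, so concurrent workers would race on it) and only logs that fact once via `XRAY_LOGGED`. Handlers that need the trace id should read it from the invocation context instead; a sketch:

use lambda_runtime::{Error, LambdaEvent};
use serde_json::Value;

// In concurrent mode the runtime does not set the process-wide
// _X_AMZN_TRACE_ID variable, so read the id from the context.
async fn handler(event: LambdaEvent<Value>) -> Result<Value, Error> {
    if let Some(trace_id) = &event.context.xray_trace_id {
        // Attach the id to your own spans/logs instead of reading the env var.
        println!("trace id: {trace_id}");
    }
    Ok(event.payload)
}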
env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", v); + } else { + env::remove_var("AWS_LAMBDA_LOG_STREAM_NAME"); + } + if let Some(v) = prev_log_group { + env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", v); + } else { + env::remove_var("AWS_LAMBDA_LOG_GROUP_NAME"); + } + if let Some(v) = prev_max { + env::set_var("AWS_LAMBDA_MAX_CONCURRENCY", v); + } else { + env::remove_var("AWS_LAMBDA_MAX_CONCURRENCY"); + } + } + + #[tokio::test] + async fn concurrent_worker_crash_does_not_stop_other_workers() -> Result<(), Error> { + let next_calls = Arc::new(AtomicUsize::new(0)); + let response_calls = Arc::new(AtomicUsize::new(0)); + let first_error_served = Arc::new(Notify::new()); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base: http::Uri = format!("http://{addr}").parse().unwrap(); + + let server_handle = { + let next_calls = next_calls.clone(); + let response_calls = response_calls.clone(); + let first_error_served = first_error_served.clone(); + tokio::spawn(async move { + loop { + let (tcp, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => return, + }; + + let next_calls = next_calls.clone(); + let response_calls = response_calls.clone(); + let first_error_served = first_error_served.clone(); + let service = service_fn(move |req: Request| { + let next_calls = next_calls.clone(); + let response_calls = response_calls.clone(); + let first_error_served = first_error_served.clone(); + async move { + let (parts, body) = req.into_parts(); + let method = parts.method; + let path = parts.uri.path().to_string(); + + if method == Method::POST { + // Drain request body to support keep-alive. + let _ = body.collect().await; + } + + if method == Method::GET && path == "/2018-06-01/runtime/invocation/next" { + let call_index = next_calls.fetch_add(1, Ordering::SeqCst); + match call_index { + // First worker errors (missing request id header). + 0 => { + first_error_served.notify_one(); + let res = Response::builder() + .status(StatusCode::OK) + .header("lambda-runtime-deadline-ms", "1542409706888") + .body(Full::new(Bytes::from_static(b"{}"))) + .unwrap(); + return Ok::<_, Infallible>(res); + } + // Second worker should keep running and process an invocation, even if another worker errors. + 1 => { + first_error_served.notified().await; + let res = Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/json") + .header("lambda-runtime-aws-request-id", "good-request") + .header("lambda-runtime-deadline-ms", "1542409706888") + .body(Full::new(Bytes::from_static(b"{}"))) + .unwrap(); + return Ok::<_, Infallible>(res); + } + // Finally, error the remaining worker so the runtime can terminate and the test can assert behavior. 
+ 2 => { + let res = Response::builder() + .status(StatusCode::OK) + .header("lambda-runtime-deadline-ms", "1542409706888") + .body(Full::new(Bytes::from_static(b"{}"))) + .unwrap(); + return Ok::<_, Infallible>(res); + } + _ => { + let res = Response::builder() + .status(StatusCode::NO_CONTENT) + .body(Full::new(Bytes::new())) + .unwrap(); + return Ok::<_, Infallible>(res); + } + } + } + + if method == Method::POST && path.ends_with("/response") { + response_calls.fetch_add(1, Ordering::SeqCst); + let res = Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::new())) + .unwrap(); + return Ok::<_, Infallible>(res); + } + + let res = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Full::new(Bytes::new())) + .unwrap(); + Ok::<_, Infallible>(res) + } + }); + + let io = TokioIo::new(tcp); + tokio::spawn(async move { + if let Err(err) = ServerBuilder::new(TokioExecutor::new()) + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {err:?}"); + } + }); + } + }) + }; + + async fn func(event: crate::LambdaEvent) -> Result { + Ok(event.payload) + } + + let handler = crate::service_fn(func); + let client = Arc::new(Client::builder().with_endpoint(base).build()?); + let runtime = Runtime { + client: client.clone(), + config: Arc::new(Config { + function_name: "test_fn".to_string(), + memory: 128, + version: "1".to_string(), + log_stream: "test_stream".to_string(), + log_group: "test_log".to_string(), + max_concurrency: Some(2), + }), + service: wrap_handler(handler, client), + }; + + let res = tokio::time::timeout(Duration::from_secs(2), runtime.run_concurrent()).await; + assert!(res.is_ok(), "run_concurrent timed out"); + assert!( + res.unwrap().is_err(), + "expected runtime to terminate once all workers crashed" + ); + + assert_eq!( + response_calls.load(Ordering::SeqCst), + 1, + "expected remaining worker to keep running after a worker crash" + ); + + server_handle.abort(); + Ok(()) + } } diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index 5e5f487a..03cbfad0 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -104,13 +104,23 @@ impl Context { /// and the incoming request data. pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result { let client_context: Option = if let Some(value) = headers.get("lambda-runtime-client-context") { - serde_json::from_str(value.to_str()?)? + let raw = value.to_str()?; + if raw.is_empty() { + None + } else { + Some(serde_json::from_str(raw)?) + } } else { None }; let identity: Option = if let Some(value) = headers.get("lambda-runtime-cognito-identity") { - serde_json::from_str(value.to_str()?)? + let raw = value.to_str()?; + if raw.is_empty() { + None + } else { + Some(serde_json::from_str(raw)?) + } } else { None }; diff --git a/scripts/test-rie.sh b/scripts/test-rie.sh index 911cb390..3561a8ee 100755 --- a/scripts/test-rie.sh +++ b/scripts/test-rie.sh @@ -2,16 +2,23 @@ set -euo pipefail EXAMPLE=${1:-basic-lambda} +# Optional: set RIE_MAX_CONCURRENCY to enable LMI mode (emulates AWS_LAMBDA_MAX_CONCURRENCY) +RIE_MAX_CONCURRENCY=${RIE_MAX_CONCURRENCY:-} echo "Building Docker image with RIE for example: $EXAMPLE..." docker build -f Dockerfile.rie --build-arg EXAMPLE=$EXAMPLE -t rust-lambda-rie-test . echo "Starting RIE container on port 9000..." 
diff --git a/scripts/test-rie.sh b/scripts/test-rie.sh
index 911cb390..3561a8ee 100755
--- a/scripts/test-rie.sh
+++ b/scripts/test-rie.sh
@@ -2,16 +2,23 @@
 set -euo pipefail
 
 EXAMPLE=${1:-basic-lambda}
+# Optional: set RIE_MAX_CONCURRENCY to enable LMI mode (emulates AWS_LAMBDA_MAX_CONCURRENCY)
+RIE_MAX_CONCURRENCY=${RIE_MAX_CONCURRENCY:-}
 
 echo "Building Docker image with RIE for example: $EXAMPLE..."
 docker build -f Dockerfile.rie --build-arg EXAMPLE=$EXAMPLE -t rust-lambda-rie-test .
 
 echo "Starting RIE container on port 9000..."
-docker run -p 9000:8080 rust-lambda-rie-test &
+if [ -n "$RIE_MAX_CONCURRENCY" ]; then
+    echo "Enabling LMI mode with AWS_LAMBDA_MAX_CONCURRENCY=$RIE_MAX_CONCURRENCY"
+    docker run -p 9000:8080 -e AWS_LAMBDA_MAX_CONCURRENCY="$RIE_MAX_CONCURRENCY" rust-lambda-rie-test &
+else
+    docker run -p 9000:8080 rust-lambda-rie-test &
+fi
 CONTAINER_PID=$!
 
 echo "Container started. Test with:"
-if [ "$EXAMPLE" = "basic-lambda" ]; then
+if [ "$EXAMPLE" = "basic-lambda" ] || [ "$EXAMPLE" = "basic-lambda-concurrent" ]; then
     echo "curl -XPOST 'http://localhost:9000/2015-03-31/functions/function/invocations' -d '{\"command\": \"test from RIE\"}' -H 'Content-Type: application/json'"
 else
     echo "For example '$EXAMPLE', check examples/$EXAMPLE/src/main.rs for the expected payload format."
@@ -19,4 +26,4 @@ fi
 echo ""
 echo "Press Ctrl+C to stop the container."
 
-wait $CONTAINER_PID
\ No newline at end of file
+wait $CONTAINER_PID
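With these pieces in place the LMI flow can be exercised end to end: `make test-rie-lmi EXAMPLE=basic-lambda-concurrent` builds the image and starts RIE with `AWS_LAMBDA_MAX_CONCURRENCY=4` (the Makefile default; override with `RIE_MAX_CONCURRENCY=<n>`), and the curl command the script prints drives invocations through the concurrent runtime, while plain `make test-rie` keeps exercising the sequential fallback.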