+
+ .. literalinclude:: jax_current_scaling_example.py
+ :language: python
+ :start-after: # START_CURRENT_SCALING_EXAMPLE
+ :end-before: # END_CURRENT_SCALING_EXAMPLE
\ No newline at end of file
diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg
new file mode 100644
index 00000000000..294fca318bc
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg
@@ -0,0 +1,55 @@
+
+
diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg
new file mode 100644
index 00000000000..f984e1dd310
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg
@@ -0,0 +1,78 @@
+
+
+
diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg
new file mode 100644
index 00000000000..bf86a29a6c9
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg
@@ -0,0 +1,164 @@
+
+
diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg
new file mode 100644
index 00000000000..a07f596a000
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg
@@ -0,0 +1,110 @@
+
diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg
new file mode 100644
index 00000000000..5416b5f4c36
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg
@@ -0,0 +1,75 @@
+
+
+
diff --git a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py
new file mode 100644
index 00000000000..2d7c7ed9c41
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# START_CURRENT_SCALING_EXAMPLE
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import DenseGeneral
+from transformer_engine.common.recipe import Float8CurrentScaling, Format
+
+# Create FP8 Current Scaling recipe
+# Available formats:
+# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass
+# - Format.E4M3 -- E4M3 for both forward and backward pass
+recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)
+
+with te.autocast(enabled=True, recipe=recipe):
+ # Create and initialize layer
+ layer = DenseGeneral(features=1024)
+ key = jax.random.PRNGKey(0)
+ x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
+ params = layer.init(key, x)
+
+ # Forward and backward pass
+ def loss_fn(params):
+ output = layer.apply(params, x)
+ return output.sum()
+
+ loss, grads = jax.value_and_grad(loss_fn)(params)
+
+# END_CURRENT_SCALING_EXAMPLE
diff --git a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py
new file mode 100644
index 00000000000..1eef7cf9a99
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# START_CURRENT_SCALING_EXAMPLE
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Float8CurrentScaling, Format
+
+# Create FP8 Current Scaling recipe
+# Available formats:
+# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass
+# - Format.E4M3 -- E4M3 for both forward and backward pass
+recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)
+
+# Create a simple linear layer with bfloat16 parameters
+layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+
+# Forward and backward pass
+inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
+
+with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ output = layer(inp)
+ loss = output.sum()
+
+loss.backward()
+
+# END_CURRENT_SCALING_EXAMPLE
diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst
new file mode 100644
index 00000000000..772ed73fab7
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst
@@ -0,0 +1,172 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+FP8 Delayed Scaling
+===================================
+
+FP8 Delayed Scaling estimates scaling factors from historical amax values rather than computing them
+from the tensor being quantized. This reduces the number of tensor reads per quantization from two to one, cutting memory traffic.
+
+Both this recipe and :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` use
+the same FP8 formats (E4M3/E5M2) with one float32 scaling factor per tensor.
+Reading the FP8 Current Scaling documentation first is recommended.
+
+Quantization with delayed scaling factors
+-----------------------------------------
+
+FP8 Current Scaling requires two tensor reads per quantization: one to compute amax,
+one to cast. FP8 Delayed Scaling eliminates the first read by predicting the scaling factor
+from historical amax values - hence *delayed* (using past values) versus *current* (using present values).
+
+The quantization process works as follows:
+
+1. **Compute scaling factor from history** (no tensor read needed):
+ The scaling factor is derived from stored ``amax_history`` using the formula:
+
+ ``scaling_factor = FP8_MAX / amax``
+
+ where ``amax`` is computed from history using either ``max`` (default) or ``most_recent`` algorithm.
+
+2. **Quantize the tensor** (one tensor read):
+ Apply the scaling factor and cast to FP8. Values exceeding FP8 range are clipped.
+
+3. **Update history**:
+ Record the actual amax from this quantization for future iterations.
+
+Each module maintains an ``amax_history`` tensor of configurable length (``amax_history_len``)
+for each quantized tensor.
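+
+A minimal sketch of step 1, assuming the E4M3 format (``FP8_MAX = 448``); the tensor names here
+are illustrative, not Transformer Engine internals:
+
+.. code-block:: python
+
+    import torch
+
+    FP8_MAX = 448.0                               # maximum representable value in E4M3
+    amax_history = torch.tensor([2.1, 3.4, 2.8])  # toy history of recorded amax values
+
+    amax = amax_history.max()        # "max" algorithm; "most_recent" takes the newest entry instead
+    scaling_factor = FP8_MAX / amax  # no read of the tensor being quantized is needed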
+
+.. raw:: html
+ :file: img/scaling_comparison.svg
+
+*Figure 1. Comparison of FP8 Current Scaling and FP8 Delayed Scaling quantization processes.*
+
+Amax History Management
+-----------------------
+
+The ``amax_history`` buffer acts as a sliding window of recent amax values.
+Position 0 serves as a staging area for the current amax, while positions 1 to N-1
+store the history from oldest to newest. Each quantization writes the observed amax
+to position 0, and after the pass completes, the history is rotated:
+
+.. code-block:: text
+
+ Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1] (amax_N = current, amax_1 = oldest)
+ After rotation: [0, amax_2, ..., amax_N-1, amax_N] (amax_1 dropped, amax_N appended)
+
+The effective history length is ``amax_history_len - 1`` since position 0 is reserved
+for the staging area.
+
+The implementation differs between PyTorch and JAX:
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ Each module creates two ``amax_history`` tensors, initialized to zero:
+
+ - Forward: shape ``(amax_history_len, num_gemms * 3)`` — three FP8 tensors per GEMM (input, weight, output)
+ - Backward: shape ``(amax_history_len, num_gemms * 2)`` — two FP8 tensors per GEMM (grad_output, grad_input)
+
+ During the first forward pass, modules register their ``amax_history`` tensors
+ to a **global buffer** associated with the autocast context. When the context exits,
+ a single CUDA kernel processes all registered tensors at once - performing both
+ amax reduction across GPUs and history rotation.
+
+ This batched approach (one kernel for all tensors instead of one kernel per tensor)
+ minimizes kernel launch overhead.
+
+ .. tab:: JAX
+
+ Each quantizer maintains its own ``amax_history`` as a Flax variable with shape ``(amax_history_len,)``.
+ There is no global buffer - each quantizer updates independently.
+
+ The rotation is performed per-quantizer using ``jnp.roll``:
+
+ .. code-block:: python
+
+ updated_amax_history = jnp.roll(amax_history, -1, -1)
+ amax_history = updated_amax_history.at[0].set(0.0)
+
+Here's how to use FP8 Delayed Scaling in PyTorch and JAX:
+
+.. tabs::
+
+    .. tab:: PyTorch
+
+        .. literalinclude:: pytorch_delayed_scaling_example.py
+            :language: python
+            :start-after: # START_DELAYED_SCALING_EXAMPLE
+            :end-before: # END_DELAYED_SCALING_EXAMPLE
+
+    .. tab:: JAX
+
+        .. literalinclude:: jax_delayed_scaling_example.py
+            :language: python
+            :start-after: # START_DELAYED_SCALING_EXAMPLE
+            :end-before: # END_DELAYED_SCALING_EXAMPLE
+
+
+Distributed Training
+--------------------
+
+Since FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling,
+transpose gather is not supported. However, amax reduction works slightly differently in different frameworks.
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ Amax reduction is controlled by two parameters:
+
+        - ``reduce_amax`` in the recipe: enables/disables reduction (required for sequence parallelism and context parallelism)
+ - ``amax_reduction_group`` in ``autocast``: specifies the process group for reduction
+
+ We recommend reducing amax across all GPUs where the tensor is sharded,
+ including data parallel ranks.
+
+ .. literalinclude:: pytorch_delayed_scaling_distributed_example.py
+ :language: python
+ :start-after: # START_AMAX_REDUCTION_EXAMPLE
+ :end-before: # END_AMAX_REDUCTION_EXAMPLE
+
+ In data parallel training, some modules may not execute on certain ranks
+ (e.g., MoE experts that receive no tokens). This is handled as follows:
+
+ - **First iteration**: All modules must execute on all ranks to register
+ their ``amax_history`` tensors in the global buffer. Mismatched registration
+ causes the ``all_reduce`` to hang due to different tensor sizes across ranks.
+ - **Subsequent iterations**: The ``autocast`` context must be entered and exited
+ on all ranks (this triggers the collective reduction). Individual modules can be
+ skipped - if no rank executes a module, its history is not rotated and scale
+ remains unchanged.
+
+
+ .. tab:: JAX
+
+ Amax reduction is always enabled and managed automatically.
+ Reduction scope: all parallelism axes except pipeline parallelism (TP, SP, DP/FSDP).
+
+ .. literalinclude:: jax_delayed_scaling_distributed_example.py
+ :language: python
+ :start-after: # START_AMAX_REDUCTION_EXAMPLE
+ :end-before: # END_AMAX_REDUCTION_EXAMPLE
+
+Supported devices
+-----------------
+
+Ada and later (SM 8.9+)
\ No newline at end of file
diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg b/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg
new file mode 100644
index 00000000000..aff4ba0da38
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg
@@ -0,0 +1,82 @@
+
+
+
diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py
new file mode 100644
index 00000000000..48f6944ac1f
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# START_AMAX_REDUCTION_EXAMPLE
+import transformer_engine.jax as te
+from transformer_engine.common.recipe import DelayedScaling
+
+# Amax reduction scope is managed internally
+recipe = DelayedScaling(reduce_amax=True) # Must be True in JAX
+
+with te.autocast(enabled=True, recipe=recipe):
+ output = layer.apply(params, inp)
+
+# END_AMAX_REDUCTION_EXAMPLE
diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py
new file mode 100644
index 00000000000..0500e2d40d1
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import jax
+
+# Requires Ada (SM89) or newer for FP8 support
+cc = jax.devices()[0].device_kind
+assert (
+ "RTX 40" in cc
+ or "RTX 5" in cc
+ or "Ada" in cc
+ or "L40" in cc
+ or "H100" in cc
+ or "H200" in cc
+ or "GH" in cc
+ or "B100" in cc
+ or "B200" in cc
+ or "GB" in cc
+), "This example requires SM89 (Ada) or newer"
+
+# START_DELAYED_SCALING_EXAMPLE
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import DenseGeneral
+from transformer_engine.common.recipe import DelayedScaling
+
+# Create FP8 Delayed Scaling recipe
+recipe = DelayedScaling(
+ margin=0, # Margin for scaling factor computation (default: 0)
+ amax_history_len=1024, # Length of amax history window (default: 1024)
+ amax_compute_algo="max", # How to compute amax from history (default: "max")
+)
+
+with te.autocast(enabled=True, recipe=recipe):
+ # Initialize layer and data
+ layer = DenseGeneral(features=1024)
+ key = jax.random.PRNGKey(0)
+ x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
+ params = layer.init(key, x)
+
+ # Forward and backward pass
+ def loss_fn(params):
+ output = layer.apply(params, x)
+ return output.sum()
+
+ loss, grads = jax.value_and_grad(loss_fn)(params)
+
+# END_DELAYED_SCALING_EXAMPLE
diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py
new file mode 100644
index 00000000000..2c99fe1a2cf
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# START_AMAX_REDUCTION_EXAMPLE
+import torch.distributed as dist
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import DelayedScaling
+
+# Create process group for amax reduction (e.g., all 8 GPUs)
+amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7])
+
+recipe = DelayedScaling(reduce_amax=True)
+
+with te.autocast(recipe=recipe, amax_reduction_group=amax_reduction_group):
+ output = model(inp)
+
+# END_AMAX_REDUCTION_EXAMPLE
diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py
new file mode 100644
index 00000000000..628f368641f
--- /dev/null
+++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or newer for FP8 support
+assert torch.cuda.get_device_capability()[0] >= 9 or (
+ torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9
+), "This example requires SM89 (Ada) or newer"
+
+# START_DELAYED_SCALING_EXAMPLE
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import DelayedScaling
+
+# Create FP8 Delayed Scaling recipe
+recipe = DelayedScaling(
+ margin=0, # Margin for scaling factor computation (default: 0)
+ amax_history_len=1024, # Length of amax history window (default: 1024)
+ amax_compute_algo="max", # How to compute amax from history (default: "max")
+)
+
+# Create a linear layer with bfloat16 parameters
+layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+
+# Forward and backward pass
+inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
+
+with te.autocast(enabled=True, recipe=recipe):
+ output = layer(inp)
+ loss = output.sum()
+
+loss.backward()
+
+# END_DELAYED_SCALING_EXAMPLE
diff --git a/docs/features/low_precision_training/index.rst b/docs/features/low_precision_training/index.rst
new file mode 100644
index 00000000000..39fba078811
--- /dev/null
+++ b/docs/features/low_precision_training/index.rst
@@ -0,0 +1,17 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+Low precision training
+===================================
+
+.. toctree::
+
+ introduction/introduction.rst
+ performance_considerations/performance_considerations.rst
+ fp8_current_scaling/fp8_current_scaling.rst
+ fp8_delayed_scaling/fp8_delayed_scaling.rst
+ fp8_blockwise_scaling/fp8_blockwise_scaling.rst
+ mxfp8/mxfp8.rst
+ nvfp4/nvfp4.rst
\ No newline at end of file
diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py
new file mode 100644
index 00000000000..1c0e91a338f
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/autocast_jax.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import jax
+
+# Requires Ada (SM89) or newer for FP8 support
+cc = jax.devices()[0].device_kind
+assert (
+ "RTX 40" in cc
+ or "RTX 5" in cc
+ or "Ada" in cc
+ or "L40" in cc
+ or "H100" in cc
+ or "H200" in cc
+ or "GH" in cc
+ or "B100" in cc
+ or "B200" in cc
+ or "GB" in cc
+), "This example requires SM89 (Ada) or newer"
+
+# START_AUTOCAST_BASIC
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import TransformerLayer
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+from transformer_engine.common.recipe import DelayedScaling, Format
+
+# Set up recipe
+recipe = DelayedScaling()
+
+# Model initialization must happen inside autocast
+with global_shard_guard(MeshResource()):
+ with te.autocast(enabled=True, recipe=recipe):
+ layer = TransformerLayer(
+ hidden_size=1024,
+ mlp_hidden_size=4096,
+ num_attention_heads=16,
+ )
+
+ init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
+ x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16)
+ params = layer.init({"params": init_key, "dropout": dropout_key}, x)
+
+ # Forward and backward pass (both inside autocast for JAX)
+ def loss_fn(params):
+ output = layer.apply(params, x, rngs={"dropout": dropout_key})
+ return output.sum()
+
+ loss, grads = jax.value_and_grad(loss_fn)(params)
+
+# END_AUTOCAST_BASIC
+
+
+# START_AUTOCAST_SEQUENTIAL
+
+encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
+decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
+
+with global_shard_guard(MeshResource()):
+ with te.autocast(enabled=True, recipe=encoder_recipe):
+ encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
+ encoder_params = encoder.init({"params": init_key, "dropout": dropout_key}, x)
+ hidden = encoder.apply(encoder_params, x, rngs={"dropout": dropout_key})
+
+ with te.autocast(enabled=True, recipe=decoder_recipe):
+ decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
+ decoder_params = decoder.init({"params": init_key, "dropout": dropout_key}, hidden)
+ output = decoder.apply(decoder_params, hidden, rngs={"dropout": dropout_key})
+
+# END_AUTOCAST_SEQUENTIAL
+
+
+# START_AUTOCAST_NESTED
+
+outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
+inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
+
+with global_shard_guard(MeshResource()):
+ with te.autocast(enabled=True, recipe=outer_recipe):
+ # layer1 uses outer_recipe
+ layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
+ params1 = layer1.init({"params": init_key, "dropout": dropout_key}, x)
+ hidden = layer1.apply(params1, x, rngs={"dropout": dropout_key})
+
+ with te.autocast(enabled=True, recipe=inner_recipe):
+ # layer2 uses inner_recipe (overrides outer)
+ layer2 = TransformerLayer(
+ hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16
+ )
+ params2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden)
+ hidden = layer2.apply(params2, hidden, rngs={"dropout": dropout_key})
+
+ # layer3 uses outer_recipe again
+ layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
+ params3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden)
+ output = layer3.apply(params3, hidden, rngs={"dropout": dropout_key})
+
+# END_AUTOCAST_NESTED
diff --git a/docs/features/low_precision_training/introduction/autocast_pytorch.py b/docs/features/low_precision_training/introduction/autocast_pytorch.py
new file mode 100644
index 00000000000..17d813b3fa9
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/autocast_pytorch.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or newer for FP8 support
+assert torch.cuda.get_device_capability()[0] >= 9 or (
+ torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9
+), "This example requires SM89 (Ada) or newer"
+
+# START_AUTOCAST_BASIC
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import DelayedScaling, Format
+
+recipe = DelayedScaling()
+layer = te.Linear(1024, 1024)
+inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda")
+
+with te.autocast(enabled=True, recipe=recipe):
+ output = layer(inp)
+
+# .backward() is called outside of autocast
+loss = output.sum()
+loss.backward()
+
+# END_AUTOCAST_BASIC
+
+
+# START_AUTOCAST_SEQUENTIAL
+
+encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
+decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
+
+encoder = te.Linear(1024, 1024)
+decoder = te.Linear(1024, 1024)
+
+with te.autocast(enabled=True, recipe=encoder_recipe):
+ hidden = encoder(inp)
+
+with te.autocast(enabled=True, recipe=decoder_recipe):
+ output = decoder(hidden)
+
+# END_AUTOCAST_SEQUENTIAL
+
+
+# START_AUTOCAST_NESTED
+
+outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
+inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
+
+layer1 = te.Linear(1024, 1024)
+layer2 = te.Linear(1024, 1024)
+layer3 = te.Linear(1024, 1024)
+
+with te.autocast(enabled=True, recipe=outer_recipe):
+ # layer1 uses outer_recipe
+ x = layer1(inp)
+
+ with te.autocast(enabled=True, recipe=inner_recipe):
+ # layer2 uses inner_recipe (overrides outer)
+ x = layer2(x)
+
+ # layer3 uses outer_recipe again
+ output = layer3(x)
+
+# END_AUTOCAST_NESTED
diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py
new file mode 100644
index 00000000000..14647daa1b1
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# START_BF16_FP16_TRAINING
+
+import jax
+import jax.numpy as jnp
+from transformer_engine.jax.flax import TransformerLayer
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+
+
+def run_forward_backward(params_dtype, compute_dtype):
+ # Create TransformerLayer
+ layer = TransformerLayer(
+ hidden_size=1024,
+ mlp_hidden_size=4096,
+ num_attention_heads=16,
+ dtype=params_dtype,
+ )
+
+ # Initialize parameters
+ init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
+ x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype)
+
+ # TransformerLayer requires mesh resource context
+ with global_shard_guard(MeshResource()):
+ params = layer.init({"params": init_key, "dropout": dropout_key}, x)
+
+ # Forward and backward pass
+ def loss_fn(params):
+ output = layer.apply(params, x, rngs={"dropout": dropout_key})
+ assert output.dtype == compute_dtype
+ return output.sum()
+
+ loss, grads = jax.value_and_grad(loss_fn)(params)
+
+
+run_forward_backward(jnp.float32, jnp.float32) # high precision training
+run_forward_backward(jnp.float32, jnp.bfloat16) # bfloat16 training with master weights in FP32
+run_forward_backward(jnp.bfloat16, jnp.bfloat16) # bfloat16 training with weights in BF16
+
+# END_BF16_FP16_TRAINING
diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py
new file mode 100644
index 00000000000..8779f0bff03
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# START_BF16_FP16_TRAINING
+
+import torch
+import transformer_engine.pytorch as te
+from contextlib import nullcontext
+
+
+def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled):
+ if grad_scaler_enabled:
+ grad_scaler = torch.amp.GradScaler("cuda")
+
+ layer = te.TransformerLayer(
+ hidden_size=1024,
+ ffn_hidden_size=4096,
+ num_attention_heads=16,
+ params_dtype=params_dtype,
+ )
+ x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda")
+
+ autocast_ctx = (
+ torch.autocast(device_type="cuda", dtype=autocast_precision)
+ if autocast_precision is not None
+ else nullcontext()
+ )
+ with autocast_ctx:
+ output = layer(x)
+        # Output dtype follows the autocast precision when autocast is active, else the parameter dtype
+        assert output.dtype == (
+            autocast_precision if autocast_precision is not None else params_dtype
+        )
+ loss = output.sum()
+ if grad_scaler_enabled:
+ grad_scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+
+run_forward_backward(torch.float32, torch.float32, False) # high precision training
+run_forward_backward(
+ torch.float32, torch.bfloat16, False
+) # bfloat16 training with master weights in FP32
+run_forward_backward(
+ torch.float32, torch.float16, True
+) # fp16 training with master weights in FP32, needs loss scaling
+run_forward_backward(
+ torch.bfloat16, torch.bfloat16, False
+) # bfloat16 training with weights in BF16
+
+# END_BF16_FP16_TRAINING
diff --git a/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg b/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg
new file mode 100644
index 00000000000..e1861ebc1cb
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg
@@ -0,0 +1,172 @@
+
+
diff --git a/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg b/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg
new file mode 100644
index 00000000000..a6c46b364d1
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg
@@ -0,0 +1,183 @@
+
diff --git a/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg b/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg
new file mode 100644
index 00000000000..b231fefd903
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg
@@ -0,0 +1,112 @@
+
+
diff --git a/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg
new file mode 100644
index 00000000000..708e6ea50f7
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg
@@ -0,0 +1,105 @@
+
+
diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst
new file mode 100644
index 00000000000..8a5d6c7acae
--- /dev/null
+++ b/docs/features/low_precision_training/introduction/introduction.rst
@@ -0,0 +1,277 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+Introduction
+===================================
+
+Transformer Engine accelerates deep learning by leveraging low precision formats on NVIDIA GPUs.
+This chapter introduces mixed precision training and FP8 support.
+
+
+Training in BF16/FP16
+---------------------
+
+Deep learning traditionally uses 32-bit floating-point (FP32) numbers.
+NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage.
+Let's compare these formats.
+
+.. raw:: html
+ :file: img/fp_formats_comparison.svg
+
+*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.*
+
+The key differences between these formats are:
+
+* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format
+* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but reduced precision
+* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16
+
+BF16's advantage is that it shares the same exponent range as FP32,
+making it easier to convert between the two formats without overflow/underflow issues.
+FP16 offers better precision for smaller values but has a more limited dynamic range,
+which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling `__ for more details.
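+
+The practical difference in range and granularity can be inspected directly; here is a small
+sketch using ``torch.finfo`` (any dtype introspection utility would show the same numbers):
+
+.. code-block:: python
+
+    import torch
+
+    for dtype in (torch.float32, torch.bfloat16, torch.float16):
+        info = torch.finfo(dtype)
+        # info.max: largest representable value (range); info.eps: spacing at 1.0 (granularity)
+        print(dtype, info.max, info.eps)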
+
+**Mixed precision**
+
+Not all operations can run in reduced precision.
+Modern deep learning frameworks use *mixed precision training*, where:
+
+* *Low precision* is used for matrix multiplications and other compute-heavy operations, which remain numerically stable at lower precision,
+* *High precision (FP32)* must be used for numerically sensitive operations to maintain training stability. These include layer normalization, softmax, and loss computations—operations that involve division or exponentiation, where small rounding errors can amplify and propagate through the network, leading to gradient instability or degraded convergence.
+
+**Master weights**
+
+Mixed precision training also raises the question of how to store model weights.
+Lower precision formats like FP16 and BF16 have limited representational granularity,
+which becomes problematic during gradient updates.
+When a small gradient update is added to a much larger weight stored in low precision,
+the result may round back to the original value if the update falls below the format's precision threshold.
+Moreover, some elements of the gradient itself can be too small to be represented in low precision.
+
+The solution is to maintain *master weights* in FP32.
+During training, weights are cast to lower precision for forward and backward passes,
+but the gradient updates are applied to the full-precision master copy.
+This ensures that even small gradients accumulate correctly over time.
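+
+The rounding problem is easy to reproduce; a small sketch (values chosen for illustration only)
+showing an update that is lost in BF16 but preserved in an FP32 master copy:
+
+.. code-block:: python
+
+    import torch
+
+    weight_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
+    master_fp32 = torch.tensor(1.0, dtype=torch.float32)
+    update = 1e-3  # smaller than BF16's spacing near 1.0 (~7.8e-3)
+
+    weight_bf16 += update   # rounds back to 1.0 -- the update is lost
+    master_fp32 += update   # accumulates correctly
+
+    print(weight_bf16.item())  # 1.0
+    print(master_fp32.item())  # ~1.001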
+
+There are two common software approaches to storing master weights:
+
+* *In the optimizer*:
+ The model holds low-precision weights,
+ while the optimizer maintains FP32 copies alongside momentum and other state.
+ During each step,
+ the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights.
+    This makes it easier to shard master weights together with other optimizer state, for example in the ZeRO optimizer.
+
+* *In the model*:
+ The model stores weights directly in FP32,
+ and they are cast to lower precision on-the-fly during forward and backward passes.
+ This approach works seamlessly with any standard optimizer, requiring no special support.
+
+.. raw:: html
+ :file: img/master_weights_approaches.svg
+
+*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.*
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ The PyTorch API of Transformer Engine provides two mechanisms to control precision:
+
+ * **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor.
+ * **Computation precision**: Use the ``torch.autocast`` context manager.
+
+        If the parameters are stored in lower precision and no autocast is used, the computation runs in that lower precision;
+        the input is cast to the parameter precision inside the layer before the computation.
+        When ``torch.autocast`` is active, the output precision matches the autocast precision.
+
+ .. literalinclude:: bf16_fp16_training_pytorch.py
+ :language: python
+ :start-after: # START_BF16_FP16_TRAINING
+ :end-before: # END_BF16_FP16_TRAINING
+
+
+ .. tab:: JAX
+
+ The JAX API of Transformer Engine provides two mechanisms to control precision:
+
+ * **Weight precision**: Use the ``dtype`` argument in any TE layer constructor.
+ * **Computation precision**: Determined by the dtype of the input tensor.
+
+ For training with master weights in FP32 and computation in BF16,
+ cast the input tensor to BF16 before passing it to the layer.
+
+ .. literalinclude:: bf16_fp16_training_jax.py
+ :language: python
+ :start-after: # START_BF16_FP16_TRAINING
+ :end-before: # END_BF16_FP16_TRAINING
+
+
+
+Lower precisions
+----------------
+
+Transformer Engine's primary feature is supporting even lower precision than BF16/FP16, such as FP8, MXFP8, NVFP4, etc.
+These formats are more complex to handle than BF16/FP16 – they require scaling factors to
+properly represent the full range of values in a tensor. Sometimes there is one scaling factor per tensor,
+sometimes one per block of values. A precision format combined with the training logic built around it
+is called a **recipe**.
+
+In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later.
+Let's now see how we can train in lower precisions in supported frameworks.
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision.
+ It's similar to the ``torch.autocast`` context manager, but tailored for low precision training.
+ The most important argument is the ``recipe`` argument, which accepts objects inheriting from
+ :class:`~transformer_engine.common.recipe.Recipe`.
+
+ Forward computations need to be performed inside the ``autocast`` context manager,
+ while the ``.backward()`` call should be outside of it.
+
+ Here is a basic example:
+
+        .. note::
+
+
+ Needs to be run on SM89+ (Ada or newer)
+
+
+ .. literalinclude:: autocast_pytorch.py
+ :language: python
+ :start-after: # START_AUTOCAST_BASIC
+ :end-before: # END_AUTOCAST_BASIC
+
+ You can use multiple recipes in the same model in the following ways:
+
+ **Sequential contexts** – apply different recipes to different parts of your model:
+
+        .. note::
+
+
+ Needs to be run on SM89+ (Ada or newer)
+
+
+ .. literalinclude:: autocast_pytorch.py
+ :language: python
+ :start-after: # START_AUTOCAST_SEQUENTIAL
+ :end-before: # END_AUTOCAST_SEQUENTIAL
+
+ **Nested contexts** – the inner context overrides the outer one for its scope:
+
+        .. note::
+
+
+ Needs to be run on SM89+ (Ada or newer)
+
+
+ .. literalinclude:: autocast_pytorch.py
+ :language: python
+ :start-after: # START_AUTOCAST_NESTED
+ :end-before: # END_AUTOCAST_NESTED
+
+
+ .. tab:: JAX
+
+ The JAX API of Transformer Engine provides an ``autocast`` context manager similar to PyTorch.
+ The key difference is that in JAX, model initialization must happen inside the ``autocast`` context
+ to properly capture quantization metadata in the parameter tree.
+
+ Additionally, JAX requires a ``global_shard_guard(MeshResource())`` context (even for single GPU)
+ and the ``mesh_resource`` argument in the ``autocast`` call.
+
+ Here is a basic example:
+
+        .. note::
+
+
+ Needs to be run on SM89+ (Ada or newer)
+
+
+ .. literalinclude:: autocast_jax.py
+ :language: python
+ :start-after: # START_AUTOCAST_BASIC
+ :end-before: # END_AUTOCAST_BASIC
+
+ You can use multiple recipes in the same model in the following ways:
+
+ **Sequential contexts** – apply different recipes to different parts of your model:
+
+        .. note::
+
+
+ Needs to be run on SM89+ (Ada or newer)
+
+
+ .. literalinclude:: autocast_jax.py
+ :language: python
+ :start-after: # START_AUTOCAST_SEQUENTIAL
+ :end-before: # END_AUTOCAST_SEQUENTIAL
+
+ **Nested contexts** – the inner context overrides the outer one for its scope:
+
+        .. note::
+
+
+ Needs to be run on SM89+ (Ada or newer)
+
+
+ .. literalinclude:: autocast_jax.py
+ :language: python
+ :start-after: # START_AUTOCAST_NESTED
+ :end-before: # END_AUTOCAST_NESTED
+
+**Mixed precision with 8- or 4-bit precisions**
+
+From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision*
+and to FP32/BF16/FP16 as *high precision*. This terminology will be
+used throughout the rest of the documentation.
+
+Not all operations run in low precision:
+
+- **Non-attention linear operations**: run in low precision.
+- **Attention computations**: run in high precision by default (some recipes allow low precision as an option).
+- **Other operations** (layer normalization, softmax, etc.): run in high precision.
+
+Within high-precision operations, there are two categories:
+
+- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``.
+- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings.
+
+.. raw:: html
+ :file: img/mixed_precision_operations.svg
+
+*Figure 3: Default single-device forward pass of TransformerLayer operations precision – only linear operations (outside of dot product attention) are in lower precision.*
+
+**Linear layer data flow**
+
+Let's see how the data flow of a linear layer works by default on a single H100 GPU with FP8 precision:
+
+H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in **TN** layout (Transpose-NoTranspose),
+so GEMM with tensors ``A`` and ``B`` returns ``B * A^T``.
+
+*Forward pass*
+
+* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created.
+* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created.
+* FP8 GEMM with layout **TN** is run with ``weight`` and ``input`` tensors.
+* Outputs – ``input * weight^T`` tensor – are returned in high precision.
+
+*Backward pass*
+
+* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created.
+* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients.
+* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients.
+* Input gradients – ``output_grad * weight`` tensor – are returned in high precision.
+* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision.
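+
+The three GEMMs can be summarized with a high-precision sketch (this mirrors the math above;
+in the actual FP8 path each operand is quantized, with the transposes noted above, before its GEMM):
+
+.. code-block:: python
+
+    import torch
+
+    x = torch.randn(32, 1024)    # input
+    w = torch.randn(2048, 1024)  # weight (out_features, in_features)
+    dy = torch.randn(32, 2048)   # gradient w.r.t. the layer output
+
+    y = x @ w.T      # forward GEMM:  input * weight^T
+    dx = dy @ w      # dgrad GEMM:    output_grad * weight
+    dw = dy.T @ x    # wgrad GEMM:    output_grad^T * input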
+
+
+.. raw:: html
+ :file: img/fp8_linear_flow.svg
+
+*Figure 4: Forward pass of a Linear layer with low precision data flow.*
diff --git a/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg b/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg
new file mode 100644
index 00000000000..30f16d9a718
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg
@@ -0,0 +1,177 @@
+
+
diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg
new file mode 100644
index 00000000000..42ea0308bb8
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg
@@ -0,0 +1,266 @@
+
diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg
new file mode 100644
index 00000000000..6e4ed44d56e
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg
@@ -0,0 +1,190 @@
+
+
diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg
new file mode 100644
index 00000000000..d8489ecc4f6
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg
@@ -0,0 +1,101 @@
+
+
diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg
new file mode 100644
index 00000000000..3b81ff0a36d
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg
@@ -0,0 +1,63 @@
+
diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py
new file mode 100644
index 00000000000..d41b1ecfe47
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import jax
+
+# Check for Blackwell or newer GPU
+gpu = jax.devices("gpu")[0]
+major, minor = gpu.compute_capability.split(".")
+assert int(major) >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}"
+
+# START_MXFP8_EXAMPLE
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import DenseGeneral
+from transformer_engine.common.recipe import MXFP8BlockScaling, Format
+
+# Create MXFP8 recipe
+recipe = MXFP8BlockScaling(
+ fp8_format=Format.E4M3, # FP8 format (default: E4M3, E5M2 not supported)
+)
+
+with te.autocast(enabled=True, recipe=recipe):
+ # Initialize layer and data
+ layer = DenseGeneral(features=1024)
+ key = jax.random.PRNGKey(0)
+ x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
+ params = layer.init(key, x)
+
+ # Forward and backward pass
+ def loss_fn(params):
+ output = layer.apply(params, x)
+ return output.sum()
+
+ loss, grads = jax.value_and_grad(loss_fn)(params)
+
+# END_MXFP8_EXAMPLE
diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst
new file mode 100644
index 00000000000..b0c80e837cc
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst
@@ -0,0 +1,199 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+MXFP8
+=====
+
+
+MXFP8 (Microscaling FP8) is an enhanced FP8 blockwise scaling recipe that leverages native hardware
+acceleration on Blackwell GPUs (SM 10.0+). By using one scaling factor per 32 consecutive values
+(rather than 128), MXFP8 delivers finer-grained quantization with improved numerical precision.
+
+
+
+Data Format
+-----------
+
+The representation of an FP8 tensor element ``x`` in MXFP8 precision is given by:
+
+.. code-block:: python
+
+ x = x_fp8 * s_block
+
+where
+
+* ``x_fp8`` is the FP8 value in E4M3 format,
+* ``s_block`` is a local **E8M0** scaling factor shared by a block of 32 elements.
+
+
+**FP8 format**
+
+Like FP8 Blockwise Scaling, E4M3 is used by default for both forward and backward passes.
+The finer-grained scaling provides sufficient dynamic range without requiring the E5M2 format.
+The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward).
+Pure E5M2 training is not supported.
+
+
+**Block size**
+
+Block size is 32.
+Blocks are one-dimensional, containing 32 consecutive values. No 2D scaling is performed.
+
+There are some assumptions on the dimensions of the tensor:
+
+* the tensor must have at least 2 dimensions,
+* the last dimension must be divisible by 32,
+* the product of all dimensions except the last must be divisible by 32.
+
+
+**Scaling factors**
+
+Scaling factors are stored as E8M0 (8 exponent bits, 0 mantissa bits), which inherently represents
+powers of 2. This differs from FP8 Blockwise Scaling, which uses 32-bit floating point numbers
+optionally constrained to powers of 2. Note that FP32 also has 8 exponent bits, so the representable
+ranges are similar when the power-of-2 constraint is enabled.
+
+Each block's scaling factor is computed through the following steps:
+
+1. Find the maximum absolute value (``amax_block``) across all 32 elements in the block.
+2. Compute the E8M0 biased exponent: ``e = float_to_e8m0(amax_block / max_fp8)``, where ``max_fp8 = 448``
+ (the maximum representable value in E4M3 format).
+
+ Since E8M0 and FP32 share the same exponent bias (127), ``float_to_e8m0`` simply extracts
+ the 8-bit exponent from the FP32 representation, rounding up if the mantissa is non-zero.
+
+3. The scaling factor is ``s_block = 2^(e - 127)``.
+
+This ensures that the largest value in each block fits within the FP8 representable range without overflow.
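+
+A minimal sketch of this computation for a single 1x32 block; ``float_to_e8m0`` below is an
+illustrative helper following the description above, not the Transformer Engine kernel:
+
+.. code-block:: python
+
+    import struct
+    import numpy as np
+
+    def float_to_e8m0(x: float) -> int:
+        # Extract the 8-bit biased exponent from the FP32 bit pattern,
+        # rounding up when the mantissa is non-zero.
+        bits = struct.unpack("<I", struct.pack("<f", x))[0]
+        exponent = (bits >> 23) & 0xFF
+        if bits & 0x7FFFFF:
+            exponent += 1
+        return exponent
+
+    max_fp8 = 448.0                                  # maximum representable value in E4M3
+    block = np.random.randn(32).astype(np.float32)   # one block of 32 consecutive elements
+    amax_block = float(np.abs(block).max())
+
+    e = float_to_e8m0(amax_block / max_fp8)          # E8M0 biased exponent
+    s_block = 2.0 ** (e - 127)                       # power-of-2 scaling factor
+    x_scaled = block / s_block                       # then cast elementwise to E4M3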
+
+
+.. raw:: html
+ :file: img/fp8_1d_scaling.svg
+
+*Figure 1. MXFP8 uses one E8M0 scaling factor per 32 consecutive elements, providing fine-grained
+quantization and compact scaling factor representation.*
+
+
+Handling transposes
+-------------------
+
+Blackwell architecture supports multiple FP8 GEMM layouts (TN, NT, NN), so columnwise usage
+does not require explicit transposition. However, rowwise and columnwise quantizations are different:
+
+- *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks).
+- *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks).
+
+Because the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors
+are numerically different: neither can be derived from the other. Both must be quantized
+independently from full-precision data.
+
+.. raw:: html
+ :file: img/mxfp8_row_col.svg
+
+*Figure 2. MXFP8 rowwise vs columnwise quantization layout.*
+
+
+Swizzling scaling factors
+-------------------------
+
+Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation.
+MXFP8 GEMMs require scaling factors in a specific hardware layout
+(see `cuBLAS documentation `__).
+The conversion to this GEMM-ready layout is called *swizzling*. Because swizzled scaling factors
+cannot be communicated across devices, Transformer Engine performs swizzling after any required
+communication, just before each GEMM operation.
+
+.. raw:: html
+ :file: img/mxfp8_swizzle_both_tensors.svg
+
+*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.*
+
+
+Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles.
+Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical
+slice of scaling factors. In row-major storage, these vertical slices are scattered in memory
+with gaps between each row. The hardware requires them to be stored contiguously.
+
+.. raw:: html
+ :file: img/mxfp8_tensor_scaling_layout.svg
+
+*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.*
+
+Swizzling transforms the layout to meet hardware requirements by:
+
+1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another.
+2. **Permuting** the 4-byte elements within each block.
+
+Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order:
+
+.. code-block:: text
+
+ 0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127
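+
+The interleaved order can be generated with a short sketch (the permutation only, not the full
+swizzle kernel):
+
+.. code-block:: python
+
+    # For one block of 128 four-byte groups, interleave rows 0..31 of the four columns.
+    swizzled_order = [32 * col + row for row in range(32) for col in range(4)]
+    # -> [0, 32, 64, 96, 1, 33, 65, 97, ..., 31, 63, 95, 127]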
+
+
+.. raw:: html
+ :file: img/mxfp8_scale_linearize_and_swizzle.svg
+
+*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).*
+
+For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks.
+
+
+
+Distributed training
+--------------------
+
+**Scale synchronization**
+
+The blockwise scaled tensor does not need any scale synchronization among the nodes.
+This is because each scaling factor is local to its 32-element block,
+unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Delayed Scaling <../fp8_delayed_scaling/fp8_delayed_scaling>` where a single global scale applies to the entire tensor, even when sharded.
+
+**Quantized all-gather**
+
+All-gather of columnwise tensors is supported and necessary because:
+
+- columnwise quantized tensors cannot be computed from rowwise quantized ones (as mentioned earlier),
+- gathering high-precision tensors is avoided in most cases for performance reasons.
+
+
+Examples
+--------
+
+Here's how to use the MXFP8 recipe in PyTorch and JAX:
+
+.. tabs::
+
+    .. tab:: PyTorch
+
+        .. literalinclude:: pytorch_mxfp8_example.py
+            :language: python
+            :start-after: # START_MXFP8_EXAMPLE
+            :end-before: # END_MXFP8_EXAMPLE
+
+    .. tab:: JAX
+
+        .. literalinclude:: jax_mxfp8_example.py
+            :language: python
+            :start-after: # START_MXFP8_EXAMPLE
+            :end-before: # END_MXFP8_EXAMPLE
+
+
+Supported devices
+-----------------
+
+Blackwell and later (SM 10.0+)
\ No newline at end of file
diff --git a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py
new file mode 100644
index 00000000000..19891083b45
--- /dev/null
+++ b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Check for Blackwell or newer GPU
+major, minor = torch.cuda.get_device_capability()
+assert major >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}"
+
+# START_MXFP8_EXAMPLE
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import MXFP8BlockScaling, Format
+
+# Create MXFP8 recipe
+recipe = MXFP8BlockScaling(
+ fp8_format=Format.E4M3, # E4M3 (default) or HYBRID; pure E5M2 not supported
+)
+
+# Create a linear layer with bfloat16 parameters
+layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+
+# Forward and backward pass
+inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
+
+with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ output = layer(inp)
+ loss = output.sum()
+
+loss.backward()
+
+# END_MXFP8_EXAMPLE
diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg
new file mode 100644
index 00000000000..3e215551a78
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg
@@ -0,0 +1,118 @@
+
+
+
diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg
new file mode 100644
index 00000000000..05e67b78896
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg
@@ -0,0 +1,186 @@
+
\ No newline at end of file
diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg
new file mode 100644
index 00000000000..30363d6ce22
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg
@@ -0,0 +1,208 @@
+
diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg
new file mode 100644
index 00000000000..68f6bf90390
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg
@@ -0,0 +1,91 @@
+
+
+
diff --git a/docs/features/low_precision_training/nvfp4/img/rht.svg b/docs/features/low_precision_training/nvfp4/img/rht.svg
new file mode 100644
index 00000000000..0250c27ae54
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/img/rht.svg
@@ -0,0 +1,138 @@
+
diff --git a/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg b/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg
new file mode 100644
index 00000000000..eb745f6e84f
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg
@@ -0,0 +1,95 @@
+
diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py
new file mode 100644
index 00000000000..99a16f21a79
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import jax
+
+# Check for Blackwell or newer GPU
+gpu = jax.devices("gpu")[0]
+major, minor = gpu.compute_capability.split(".")
+assert int(major) >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}"
+
+# START_NVFP4_EXAMPLE
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import DenseGeneral
+from transformer_engine.common.recipe import NVFP4BlockScaling, Format
+
+# Define NVFP4 recipe
+recipe = NVFP4BlockScaling(
+ fp8_format=Format.E4M3,
+ use_2d_weight_quantization=True,
+ use_rht=True,
+)
+
+with te.autocast(enabled=True, recipe=recipe):
+ # Initialize layer and data
+ layer = DenseGeneral(features=1024)
+ key = jax.random.PRNGKey(0)
+ x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
+ params = layer.init(key, x)
+
+ # Forward and backward pass
+ def loss_fn(params):
+ output = layer.apply(params, x)
+ return output.sum()
+
+ loss, grads = jax.value_and_grad(loss_fn)(params)
+
+# END_NVFP4_EXAMPLE
diff --git a/docs/features/low_precision_training/nvfp4/nvfp4.rst b/docs/features/low_precision_training/nvfp4/nvfp4.rst
new file mode 100644
index 00000000000..cc8ad6e7470
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/nvfp4.rst
@@ -0,0 +1,261 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+NVFP4
+===================================
+
+NVFP4 is the first 4-bit recipe introduced in Transformer Engine –
+please refer to the `NVFP4 paper `__ for more details.
+It is a more complex recipe than the previous ones – apart from the new data format,
+it introduces several features that improve training stability.
+
+Data Format
+----------------------
+
+The NVFP4 datatype consists of 1 sign bit, 2 exponent bits, and 1 mantissa bit (E2M1).
+It can represent values with magnitudes up to ±6.
+NVFP4 uses a hierarchical block scaling approach where multiple scaling factors are combined to recover the high precision value.
+
+.. raw:: html
+ :file: img/nvfp4_vs_fp8.svg
+
+*Figure 1. Bit layout comparison between standard FP8 formats (E4M3 and E5M2) and NVFP4 (E2M1).*
+
+
+The representation of an NVFP4 tensor element ``x`` is given by:
+
+.. code-block:: python
+
+ x = x_e2m1 * s_block * s_global
+
+where
+
+* ``x_e2m1`` is the 4-bit value,
+* ``s_block`` is a local **FP8 E4M3** scaling factor shared by a block of 16 consecutive elements,
+* ``s_global`` is a global **FP32** scaling factor applied to the entire tensor.
+
+**Scaling Factor Computation**
+
+The scaling factors are computed as follows:
+
+1. Global scaling factor (``s_global``):
+
+.. code-block:: python
+
+ s_global = global_amax / (fp8_max * fp4_max)
+ # where:
+ # - global_amax: maximum absolute value across the entire tensor
+ # - fp8_max: maximum representable value in FP8 E4M3 (448.0)
+ # - fp4_max: maximum representable value in NVFP4 E2M1 (6.0)
+
+2. Block scaling factor (``s_block``):
+
+.. code-block:: python
+
+ s_block = (block_amax / fp4_max) / s_global
+ # where:
+ # - block_amax: maximum absolute value within the block
+ # - fp4_max: maximum representable value in NVFP4 E2M1 (6.0)
+ # - s_block is stored in FP8 E4M3 format
+
+
+.. raw:: html
+ :file: img/nvfp4_hierarchical_scaling.svg
+
+*Figure 2. NVFP4 hierarchical scaling structure showing the combination of block-level and global scaling factors.*
+
+This hierarchical structure uses fine-grained block scaling
+to adapt to local magnitude variations and global scaling
+to handle the overall dynamic range.
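+
+To make the two-level scaling concrete, here is a minimal NumPy sketch of the quantization
+described above. It only illustrates the formulas -- it is not the Transformer Engine
+implementation (the E4M3 storage of ``s_block`` and blocks with an amax of zero are not modeled):
+
+.. code-block:: python
+
+    import numpy as np
+
+    FP4_MAX = 6.0    # max magnitude representable in E2M1
+    FP8_MAX = 448.0  # max magnitude representable in E4M3
+    E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
+
+    x = np.random.randn(4, 64).astype(np.float32)
+
+    # Global scale from the tensor-wide amax
+    s_global = np.abs(x).max() / (FP8_MAX * FP4_MAX)
+
+    # One block scale per 16 consecutive elements
+    blocks = x.reshape(-1, 16)
+    block_amax = np.abs(blocks).max(axis=1, keepdims=True)
+    s_block = (block_amax / FP4_MAX) / s_global
+
+    # Quantize: scale into the E2M1 range and round to the nearest representable value
+    scaled = blocks / (s_block * s_global)
+    idx = np.abs(np.abs(scaled)[..., None] - E2M1_VALUES).argmin(axis=-1)
+    x_e2m1 = np.sign(scaled) * E2M1_VALUES[idx]
+
+    # Dequantize: x ~= x_e2m1 * s_block * s_global
+    x_hat = (x_e2m1 * s_block * s_global).reshape(x.shape)
+    print("max abs error:", np.abs(x - x_hat).max())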
+
+**2D weight scaling**
+
+NVFP4 can be:
+
+* 1 dimensional - each block of 16 consecutive elements shares a scaling factor,
+* 2 dimensional - each block of 16x16 elements shares a scaling factor.
+
+By default, NVFP4 uses 2D scaling for weights and 1D scaling for activations and gradients.
+Set ``disable_2d_quantization=True`` in the recipe configuration to force 1D scaling for weights as well (activations and gradients always use 1D).
+The motivation for using 2D scaling for weights is to ensure that rowwise and columnwise
+quantized tensors are numerically equivalent.
+Please refer to the `NVFP4 paper `__ for more details.
+
+
+Stochastic Rounding
+-------------------
+
+Stochastic rounding is applied when casting scaled values to NVFP4 format. Instead of deterministic rounding
+(always rounding to nearest even value), each scaled value is probabilistically rounded to one of the two
+nearest representable NVFP4 values. The probability of rounding to a given value decreases linearly
+with the distance to that value (equivalently, it is proportional to the distance to the other candidate),
+which ensures that the expected value of the quantized tensor equals the original value,
+eliminating systematic quantization bias during training.
+Stochastic rounding is hardware-accelerated using native GPU instructions introduced with the
+Blackwell architecture.
+
+.. raw:: html
+ :file: img/stochastic_rounding.svg
+
+*Figure 3. Stochastic rounding illustration. Given a value* ``x`` *to be quantized, and the two nearest
+representable NVFP4 values* ``v1`` *(lower) and* ``v2`` *(higher), deterministic rounding always
+rounds to the nearest value, while stochastic rounding probabilistically rounds to either value.
+If* ``x`` *is 40% of the way from* ``v1`` *to* ``v2``, *there is a 60% chance of rounding to* ``v1``
+*and a 40% chance of rounding to* ``v2``.
+
+Stochastic rounding is enabled only for gradients. It can be disabled by setting
+``disable_stochastic_rounding=True`` in the recipe configuration.
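+
+The unbiasedness argument can be checked with a few lines of NumPy. This is only an
+illustration -- the hypothetical ``stochastic_round`` helper below is not part of the
+Transformer Engine API, which uses hardware instructions instead:
+
+.. code-block:: python
+
+    import numpy as np
+
+    rng = np.random.default_rng(0)
+
+    def stochastic_round(x, v1, v2):
+        """Round x (with v1 <= x <= v2) to v1 or v2; the closer candidate is picked
+        with higher probability."""
+        p_up = (x - v1) / (v2 - v1)  # probability of rounding up to v2
+        return np.where(rng.random(x.shape) < p_up, v2, v1)
+
+    x = np.full(100_000, 1.2)        # 40% of the way from 1.0 to 1.5
+    rounded = stochastic_round(x, 1.0, 1.5)
+    print(rounded.mean())            # ~1.2 -- unbiased in expectation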
+
+
+Random Hadamard Transform
+--------------------------
+
+Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**,
+smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4.
+RHT is applied to columnwise quantization of inputs and gradients, which are operands
+for the **wgrad GEMM**. This GEMM – according to the paper – is particularly sensitive
+to quantization errors, hence the additional outlier smoothing.
+RHT is supported only for BF16 inputs/gradients; other dtypes will raise an error.
+
+The transform is defined as:
+
+.. math::
+
+ x' = x H
+
+where :math:`H` is the RHT matrix defined below. The quantization scale factor is computed
+from the rotated tensor :math:`x'`.
+
+**Hadamard matrix**
+
+The :math:`d \times d` Hadamard matrix has elements :math:`\pm 1` and satisfies :math:`H_d H_d^T = d I`.
+When normalized by :math:`1/\sqrt{d}`, the matrix becomes orthogonal and can be applied
+to both operands of a matrix multiplication:
+
+.. math::
+
+ C = (AH)(H^T B) = AB
+
+where the transforms cancel within the dot-product since :math:`H H^T = I`.
+
+**Sign matrix**
+
+In the RHT implementation, a :math:`d`-dimensional diagonal sign matrix :math:`S_d` is applied
+together with the Hadamard matrix:
+
+.. math::
+
+ H = \frac{1}{\sqrt{d}} S_d H_d
+
+where diagonal entries of :math:`S_d` are :math:`\{-1, 1\}` and flip the signs of different rows of :math:`H_d`.
+As described in the paper, a single random sign vector is shared across all linear layers throughout training.
+In the implementation, this vector is fixed and the RHT matrix is computed once at initialization and cached.
+
+**Tiled implementation**
+
+The Hadamard transform is performed in a tiled approach along the last dimension of the tensor.
+For an :math:`m \times k` tensor, the data is reshaped to :math:`(mk/d) \times d`
+and multiplied by the :math:`d \times d` matrix :math:`H`. In this implementation, :math:`d = 16`.
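+
+A minimal NumPy sketch of this tiled transform is shown below (illustration only, not the
+Transformer Engine kernel; the Hadamard matrix is built with the Sylvester construction):
+
+.. code-block:: python
+
+    import numpy as np
+
+    d = 16
+    # Sylvester construction of the d x d Hadamard matrix (d is a power of two)
+    H_d = np.array([[1.0]])
+    while H_d.shape[0] < d:
+        H_d = np.block([[H_d, H_d], [H_d, -H_d]])
+
+    rng = np.random.default_rng(0)
+    signs = rng.choice([-1.0, 1.0], size=d)   # fixed random sign vector, shared across layers
+    H = (np.diag(signs) @ H_d) / np.sqrt(d)   # H = S_d H_d / sqrt(d), orthogonal
+
+    # Tiled application along the last dimension of an (m, k) tensor
+    m, k = 8, 64
+    x = rng.standard_normal((m, k))
+    x_rht = (x.reshape(-1, d) @ H).reshape(m, k)
+
+    # The rotations cancel inside a GEMM: (A H)(H^T B) == A B
+    A = rng.standard_normal((4, d))
+    B = rng.standard_normal((d, 4))
+    print(np.allclose((A @ H) @ (H.T @ B), A @ B))   # True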
+
+
+.. raw:: html
+ :file: img/rht.svg
+
+*Figure 4. WGRAD GEMM pipeline comparison: without RHT (left) and with RHT applied (right).*
+
+Handling transposes
+-------------------
+
+Like :doc:`MXFP8 <../mxfp8/mxfp8>`, NVFP4 requires both rowwise and columnwise quantized tensors
+for different GEMM operands. Unlike MXFP8 which supports multiple layouts (TN, NT, NN),
+**NVFP4 GEMM only supports the TN layout**.
+
+NVFP4 stores columnwise data and scaling factors in a **transposed layout**:
+
+- **Rowwise**: data ``[A, B]`` with 1×16 horizontal blocks, ``scales`` shape ``[A, B/16]``
+- **Columnwise**: data ``[B, A]`` (transposed) with 1×16 horizontal blocks, ``scales`` shape ``[B, A/16]``
+
+Scale tensors are padded for hardware alignment: first dimension to a multiple of 128,
+second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B/16, 4)]``).
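+
+As a quick sanity check of these layout rules, the snippet below computes the expected data and
+scale shapes for a hypothetical 1024 x 3072 tensor (shapes only, no actual quantization):
+
+.. code-block:: python
+
+    def roundup(x, m):
+        return ((x + m - 1) // m) * m
+
+    A, B = 1024, 3072
+    rowwise_data  = (A, B)                                  # 1x16 blocks along rows
+    rowwise_scale = (roundup(A, 128), roundup(B // 16, 4))
+    colwise_data  = (B, A)                                  # stored transposed
+    colwise_scale = (roundup(B, 128), roundup(A // 16, 4))
+    print(rowwise_scale, colwise_scale)  # (1024, 192) (3072, 64)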
+
+.. raw:: html
+ :file: img/nvfp4_row_col.svg
+
+*Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.*
+
+
+Swizzling scaling factors
+-------------------------
+
+NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations,
+similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences:
+
+- Block size is 16 (vs 32 for MXFP8).
+- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed
+  columnwise layout, a single rowwise swizzle kernel handles both cases.
+- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8).
+
+
+Distributed training
+--------------------
+
+**Amax reduction**
+
+Block scaling factors (``s_block``) do not require synchronization between nodes,
+as each scaling factor is local to its block of 16 elements.
+However, the global scaling factor (``s_global``) requires amax synchronization for gathered tensors.
+For tensors that are gathered (e.g., input and gradient in sequence parallelism),
+amax reduction is performed before quantization.
+If before synchronization there was ``amax_1`` on node 1,
+``amax_2`` on node 2, etc., after synchronization there will be ``max(amax_1, amax_2, ...)`` on all nodes.
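+
+A conceptual sketch of this synchronization, using plain NumPy as a stand-in for the MAX
+reduction that is actually performed with a collective operation:
+
+.. code-block:: python
+
+    import numpy as np
+
+    # One shard of the gathered tensor per rank (hypothetical shapes)
+    local_shards = [np.random.randn(1024, 256) for _ in range(4)]
+    local_amax = [np.abs(s).max() for s in local_shards]
+
+    global_amax = max(local_amax)           # result of the MAX reduction
+    s_global = global_amax / (448.0 * 6.0)  # fp8_max * fp4_max; identical on every rank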
+
+**Quantized all-gather**
+
+All-gather of columnwise tensors is supported. To enable quantized all-gather,
+all nodes must use the same ``s_global``, which is computed from the synchronized global amax.
+This is automatically enabled for column-parallel and row-parallel linear layers.
+
+.. raw:: html
+ :file: img/nvfp4_all_gather.svg
+
+*Figure 6. Quantization and all-gather flow for NVFP4 showing amax synchronization and hierarchical scaling.*
+
+Examples
+--------
+
+Here's how to use the NVFP4 recipe in PyTorch and JAX. The examples show how to configure features like 2D weight quantization and Random Hadamard Transform (RHT):
+
+.. tabs::
+
+   .. tab:: PyTorch
+
+      .. literalinclude:: pytorch_nvfp4_example.py
+         :language: python
+         :start-after: # START_NVFP4_EXAMPLE
+         :end-before: # END_NVFP4_EXAMPLE
+
+   .. tab:: JAX
+
+      .. literalinclude:: jax_nvfp4_example.py
+         :language: python
+         :start-after: # START_NVFP4_EXAMPLE
+         :end-before: # END_NVFP4_EXAMPLE
+
+
+Supported devices
+-----------------
+
+Blackwell and later (SM 10.0+)
diff --git a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py
new file mode 100644
index 00000000000..c34845ae2aa
--- /dev/null
+++ b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Check for Blackwell or newer GPU
+major, minor = torch.cuda.get_device_capability()
+assert major >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}"
+
+# START_NVFP4_EXAMPLE
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import NVFP4BlockScaling, Format
+
+# Define NVFP4 recipe
+# Key features like 2D weight quantization and RHT can be enabled here
+recipe = NVFP4BlockScaling(
+ fp8_format=Format.E4M3,
+ use_2d_weight_quantization=True,
+ use_rht=True,
+)
+
+# Create a linear layer with bfloat16 parameters
+layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+
+# Forward and backward pass
+inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
+
+with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ output = layer(inp)
+ loss = output.sum()
+
+loss.backward()
+
+# END_NVFP4_EXAMPLE
diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py
new file mode 100644
index 00000000000..4bcee273972
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+
+# START_FUSED_LAYERS
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import LayerNorm, DenseGeneral, LayerNormDenseGeneral
+from transformer_engine.common.recipe import DelayedScaling
+
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
+
+# Example 1: Separate LayerNorm and DenseGeneral layers
+layer_norm = LayerNorm()
+dense = DenseGeneral(features=1024)
+
+# Initialize parameters
+ln_params = layer_norm.init(key, x)
+dense_params = dense.init(key, x)
+
+# Two separate operations
+normalized = layer_norm.apply(ln_params, x)
+output_separate = dense.apply(dense_params, normalized)
+
+# Example 2: Fused LayerNormDenseGeneral layer
+fused_layer = LayerNormDenseGeneral(features=1024)
+
+# Initialize and apply with FP8 autocast
+recipe = DelayedScaling()
+with te.autocast(enabled=True, recipe=recipe):
+ fused_params = fused_layer.init(key, x)
+ output_fused, _ = fused_layer.apply(fused_params, x) # Returns (output, ln_output)
+
+# The fused layer is more efficient as it combines LayerNorm and quantization
+
+# END_FUSED_LAYERS
diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py
new file mode 100644
index 00000000000..1a9a1baf2ed
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+cc = torch.cuda.get_device_capability()
+assert (cc[0] == 8 and cc[1] >= 9) or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
+
+# START_FUSED_LAYERS
+
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import DelayedScaling
+
+# Example 1: Separate LayerNorm and Linear layers
+layer_norm = te.LayerNorm(1024)
+linear = te.Linear(1024, 1024)
+
+inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
+
+# Two separate operations: LayerNorm produces FP32, then Linear quantizes it
+normalized = layer_norm(inp)
+output_separate = linear(normalized)
+
+# Example 2: Fused LayerNormLinear layer
+fused_layer = te.LayerNormLinear(1024, 1024, params_dtype=torch.bfloat16)
+
+# Single operation: LayerNorm output is directly quantized
+recipe = DelayedScaling()
+with te.autocast(enabled=True, recipe=recipe):
+ output_fused = fused_layer(inp)
+
+# The fused layer is more efficient as it avoids redundant quantization
+
+# END_FUSED_LAYERS
diff --git a/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg b/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg
new file mode 100644
index 00000000000..8b7ffb5b50b
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg
@@ -0,0 +1,120 @@
+
+
diff --git a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg
new file mode 100644
index 00000000000..df5102090e9
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg
@@ -0,0 +1,218 @@
+
+
diff --git a/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg b/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg
new file mode 100644
index 00000000000..6f9bc4d5a10
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg
@@ -0,0 +1,122 @@
+
+
diff --git a/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg b/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg
new file mode 100644
index 00000000000..5b61ac24788
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg
@@ -0,0 +1,159 @@
+
+
+
diff --git a/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg b/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg
new file mode 100644
index 00000000000..194b1237e14
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg
@@ -0,0 +1,181 @@
+
+
+
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out
new file mode 100644
index 00000000000..a57b4931b46
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out
@@ -0,0 +1,3 @@
+# START_MEMORY_USAGE_1
+Memory usage after forward pass: 6.00 MB
+# END_MEMORY_USAGE_1
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py
new file mode 100644
index 00000000000..d5d1aabb7ed
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+
+print("# START_MEMORY_USAGE_1")
+
+import jax
+import jax.numpy as jnp
+from transformer_engine.jax.flax import DenseGeneral
+
+
+def get_gpu_memory_mb():
+ """Get current GPU memory usage in MB."""
+ jax.effects_barrier()
+ stats = jax.local_devices()[0].memory_stats()
+ return stats["bytes_in_use"] / (1024**2) if stats else 0.0
+
+
+def measure_memory():
+ key = jax.random.PRNGKey(0)
+ jax.clear_caches()
+
+ init_memory = get_gpu_memory_mb()
+
+ # Initialize layer with BF16 parameters
+ layer = DenseGeneral(features=1024, dtype=jnp.bfloat16)
+ x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
+ params = layer.init(key, x)
+
+ # Forward pass in high precision
+ output = layer.apply(params, x)
+
+ mem_after_forward = get_gpu_memory_mb() - init_memory
+ return mem_after_forward
+
+
+# Warmup run
+measure_memory()
+
+# Actual measurement
+mem_after_forward = measure_memory()
+print(f"Memory usage after forward pass: {mem_after_forward:.2f} MB")
+
+print("# END_MEMORY_USAGE_1")
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out
new file mode 100644
index 00000000000..f977460e84c
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out
@@ -0,0 +1,11 @@
+/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden.
+ Overriding a previously registered kernel for the same operator and the same dispatch key
+ operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
+ registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
+ dispatch key: ADInplaceOrView
+ previous kernel: no debug info
+ new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
+ self.m.impl(
+# START_MEMORY_USAGE_1
+Memory usage after forward pass: 6.00 MB
+# END_MEMORY_USAGE_1
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py
new file mode 100644
index 00000000000..5e7f2ae1770
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+cc = torch.cuda.get_device_capability()
+assert (cc[0] == 8 and cc[1] >= 9) or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
+
+print("# START_MEMORY_USAGE_1")
+# START_MEMORY_USAGE_1
+import torch
+import transformer_engine.pytorch as te
+
+
+def measure_memory():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ init_memory = torch.cuda.memory_allocated()
+ layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+ memory = torch.cuda.memory_allocated() - init_memory
+
+ inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
+ out = layer(inp)
+ mem_after_forward = torch.cuda.memory_allocated() - init_memory
+
+ return memory, mem_after_forward
+
+
+# Warmup run
+measure_memory()
+
+# Actual measurement
+memory, mem_after_forward = measure_memory()
+print(f"Memory usage after forward pass: {mem_after_forward/1024**2:.2f} MB")
+# END_MEMORY_USAGE_1
+print("# END_MEMORY_USAGE_1")
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out
new file mode 100644
index 00000000000..85ee423022e
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out
@@ -0,0 +1,3 @@
+# START_MEMORY_USAGE_2
+Memory usage after forward pass: 6.01 MB
+# END_MEMORY_USAGE_2
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py
new file mode 100644
index 00000000000..378a7c1e06c
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+
+print("# START_MEMORY_USAGE_2")
+
+import jax
+import jax.numpy as jnp
+import transformer_engine.jax as te
+from transformer_engine.jax.flax import DenseGeneral
+from transformer_engine.common.recipe import DelayedScaling
+
+
+def get_gpu_memory_mb():
+ """Get current GPU memory usage in MB."""
+ jax.effects_barrier()
+ stats = jax.local_devices()[0].memory_stats()
+ return stats["bytes_in_use"] / (1024**2) if stats else 0.0
+
+
+def measure_memory():
+ key = jax.random.PRNGKey(0)
+ recipe = DelayedScaling()
+ jax.clear_caches()
+
+ init_memory = get_gpu_memory_mb()
+
+ # Initialize layer with BF16 parameters (outside autocast)
+ layer = DenseGeneral(features=1024, dtype=jnp.bfloat16)
+ x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
+
+ # Forward pass with FP8 compute
+ with te.autocast(enabled=True, recipe=recipe):
+ params = layer.init(key, x)
+ output = layer.apply(params, x)
+
+ mem_after_forward = get_gpu_memory_mb() - init_memory
+ return mem_after_forward
+
+
+# Warmup run
+measure_memory()
+
+# Actual measurement
+mem_after_forward = measure_memory()
+print(f"Memory usage after forward pass: {mem_after_forward:.2f} MB")
+
+print("# END_MEMORY_USAGE_2")
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out
new file mode 100644
index 00000000000..9f7fa90ca16
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out
@@ -0,0 +1,11 @@
+/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden.
+ Overriding a previously registered kernel for the same operator and the same dispatch key
+ operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
+ registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
+ dispatch key: ADInplaceOrView
+ previous kernel: no debug info
+ new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
+ self.m.impl(
+# START_MEMORY_USAGE_2
+Memory after forward pass: 8.02 MB
+# END_MEMORY_USAGE_2
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py
new file mode 100644
index 00000000000..276bde42022
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+cc = torch.cuda.get_device_capability()
+assert (cc[0] == 8 and cc[1] >= 9) or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
+
+print("# START_MEMORY_USAGE_2")
+# START_MEMORY_USAGE_2
+import torch
+import transformer_engine.pytorch as te
+
+
+def measure_memory():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ init_memory = torch.cuda.memory_allocated()
+ layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+ inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
+
+ with te.autocast(enabled=True):
+ out = layer(inp)
+ mem_after_forward = torch.cuda.memory_allocated() - init_memory
+
+ return mem_after_forward
+
+
+# Warmup run
+measure_memory()
+
+# Actual measurement
+mem_after_forward = measure_memory()
+print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB")
+# END_MEMORY_USAGE_2
+print("# END_MEMORY_USAGE_2")
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out
new file mode 100644
index 00000000000..9ccba3d3e60
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out
@@ -0,0 +1,11 @@
+/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden.
+ Overriding a previously registered kernel for the same operator and the same dispatch key
+ operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
+ registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
+ dispatch key: ADInplaceOrView
+ previous kernel: no debug info
+ new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
+ self.m.impl(
+# START_MEMORY_USAGE_3
+Memory after forward pass: 6.02 MB
+# END_MEMORY_USAGE_3
diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py
new file mode 100644
index 00000000000..d603da2809d
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+cc = torch.cuda.get_device_capability()
+assert (cc[0] == 8 and cc[1] >= 9) or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
+
+print("# START_MEMORY_USAGE_3")
+# START_MEMORY_USAGE_3
+import torch
+import transformer_engine.pytorch as te
+
+
+def measure_memory():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ init_memory = torch.cuda.memory_allocated()
+
+ # FP8 forward and backward with FP8 weights
+ with te.quantized_model_init(enabled=True), torch.no_grad():
+ layer_fp8 = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
+ memory = torch.cuda.memory_allocated() - init_memory
+
+ inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
+ with te.autocast(enabled=True):
+ out = layer_fp8(inp)
+
+ mem_after_forward = torch.cuda.memory_allocated() - init_memory
+
+ return memory, mem_after_forward
+
+
+# Warmup run
+measure_memory()
+
+# Actual measurement
+memory, mem_after_forward = measure_memory()
+print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB")
+# END_MEMORY_USAGE_3
+print("# END_MEMORY_USAGE_3")
diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst
new file mode 100644
index 00000000000..0409bc336b9
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst
@@ -0,0 +1,480 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+Performance Considerations
+===================================
+
+.. _handling_transposes:
+
+Handling transposes
+-------------------
+
+In the previous chapter we demonstrated that for FP8 on the Hopper architecture,
+some tensors need to be physically transposed in memory to perform the required GEMMs.
+Dealing with transposes in low precision Transformer training is a bit tricky.
+Let's start by introducing the concept of *tensor usages*.
+
+**Tensor usages**
+
+Each quantized tensor may have two usages:
+
+- *rowwise usage* -- used in matrix multiplication when consecutive elements of a row are accessed,
+- *columnwise usage* -- used in matrix multiplication when consecutive elements of a column are accessed.
+
+To understand what accessing consecutive elements means, let's consider two matrices ``A`` and ``B``
+and analyze how their elements are accessed during multiplication.
+
+For NN (non-transposed, non-transposed) multiplication ``C = A * B``, the formula is ``C_ij = sum_k(A_ik * B_kj)``.
+To compute element ``C_ij``, we iterate over the i-th row of ``A`` (elements ``A_i0, A_i1, ...``)
+and the j-th column of ``B`` (elements ``B_0j, B_1j, ...``). Thus, ``A`` is accessed rowwise
+and ``B`` is accessed columnwise.
+
+For NT (non-transposed, transposed) multiplication ``C = A * B^T``, the formula changes to ``C_ij = sum_k(A_ik * B_jk)``.
+Now we iterate over the i-th row of ``A`` and the j-th row of ``B`` (elements ``B_j0, B_j1, ...``).
+Both tensors are accessed rowwise.
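+
+The two access patterns can be written out directly with NumPy; the assertions below check the
+formulas for a single output element (a small illustration, not TE code):
+
+.. code-block:: python
+
+    import numpy as np
+
+    A = np.random.randn(4, 8)
+    B = np.random.randn(8, 5)
+    B2 = np.random.randn(5, 8)
+
+    # NN: C_ij = sum_k(A_ik * B_kj) -> rows of A, columns of B
+    C_nn = A @ B
+    # NT: C_ij = sum_k(A_ik * B_jk) -> rows of A, rows of B2
+    C_nt = A @ B2.T
+
+    i, j = 2, 3
+    assert np.isclose(C_nn[i, j], np.sum(A[i, :] * B[:, j]))
+    assert np.isclose(C_nt[i, j], np.sum(A[i, :] * B2[j, :]))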
+
+The figure below illustrates these access patterns:
+
+.. figure:: img/gemm_access_pattern.svg
+ :align: center
+ :alt: Matrix multiplication access pattern showing rowwise access for first tensor and columnwise access for second tensor
+
+ Figure 1: Access patterns in matrix multiplication for matrices in ``A * B`` and ``A * B^T`` operations.
+
+Based on the visualization above, we can derive general rules for when each matrix
+is accessed in rowwise or columnwise fashion. The key insight is that:
+
+- The **first tensor** in a matrix multiplication is accessed along its rows (rowwise) when non-transposed,
+ or along its columns (columnwise) when transposed.
+- The **second tensor** follows the opposite pattern: columnwise when non-transposed, rowwise when transposed.
+
+.. table:: Table 1: Summary of tensor access patterns based on transpose state.
+   :align: center
+
+   +----------------+--------------+---------------+
+   |                | First tensor | Second tensor |
+   +================+==============+===============+
+   | Non-transposed | rowwise      | columnwise    |
+   +----------------+--------------+---------------+
+   | Transposed     | columnwise   | rowwise       |
+   +----------------+--------------+---------------+
+
+**Input, weight and output gradient usages**
+
+Now let's apply these rules to a Linear layer. During training, a Linear layer performs
+three GEMM operations: one in the forward pass and two in the backward pass.
+
+
+.. table:: Table 2: Tensor access patterns for GEMM operations in a Linear layer during training.
+   :align: center
+
+   +-----------------+-----------------------------------+-------------------------+---------------------+
+   | GEMM            | Formula                           | First tensor usage      | Second tensor usage |
+   +=================+===================================+=========================+=====================+
+   | Forward         | ``output = input * weight^T``     | input: rowwise          | weight: rowwise     |
+   +-----------------+-----------------------------------+-------------------------+---------------------+
+   | Weight gradient | ``wgrad = output_grad^T * input`` | output_grad: columnwise | input: columnwise   |
+   +-----------------+-----------------------------------+-------------------------+---------------------+
+   | Input gradient  | ``dgrad = output_grad * weight``  | output_grad: rowwise    | weight: columnwise  |
+   +-----------------+-----------------------------------+-------------------------+---------------------+
+
+An important observation is that the **forward pass uses only rowwise tensors** - both input
+and weight are accessed rowwise.
+
+The backward pass introduces columnwise access. For weight gradient, both output gradient and input
+are accessed columnwise. For input gradient, output gradient is rowwise while weight is columnwise.
+
+As a result, each tensor (input, weight, output gradient) needs both rowwise and columnwise
+usages during training. This has implications for memory layout and transpose operations.
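+
+The three GEMMs from Table 2 can be written out with NumPy (batch dimension flattened into the
+first dimension; a shape sketch only, the sizes are arbitrary):
+
+.. code-block:: python
+
+    import numpy as np
+
+    N, in_features, out_features = 32, 1024, 512
+    inp = np.random.randn(N, in_features)
+    weight = np.random.randn(out_features, in_features)
+    output_grad = np.random.randn(N, out_features)
+
+    output = inp @ weight.T       # forward:         input rowwise, weight rowwise
+    wgrad = output_grad.T @ inp   # weight gradient: output_grad columnwise, input columnwise
+    dgrad = output_grad @ weight  # input gradient:  output_grad rowwise, weight columnwise
+
+    print(output.shape, wgrad.shape, dgrad.shape)  # (32, 512) (512, 1024) (32, 1024)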
+
+
+**Architecture differences**
+
+The physical memory layout requirements for rowwise and columnwise usages differ between architectures
+and recipes. For FP8 tensors:
+
+- *Hopper*: cannot efficiently access elements in columnwise fashion, so columnwise tensors need to be physically transposed in memory.
+- *Blackwell*: supports columnwise access natively, so no transpose is needed.
+
+We will see that for most recipes and devices, rowwise and columnwise usages require different tensors.
+Thus, by *rowwise tensor* and *columnwise tensor* we mean the tensors used for the rowwise and columnwise usages, respectively.
+
+.. figure:: img/hopper_vs_blackwell_layout.svg
+ :align: center
+ :alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper
+
+ Figure 2: On Blackwell, rowwise and columnwise usages share the same memory layout.
+ On Hopper, columnwise usage requires a physical transpose.
+
+**Quantization fusions**
+
+This section is relevant only for recipes for which columnwise tensors
+are different from rowwise tensors.
+
+Note that performing rowwise and columnwise quantization at the same time
+enables some fusions, which usually lead to better performance.
+We showcase three example scenarios of producing quantized tensors in rowwise and columnwise usages;
+TE will use the best possible fusion for a given recipe and TE module configuration:
+
+1. *Computation of quantized tensor in both rowwise and columnwise usages in a single kernel in forward pass*.
+
+   This is the fastest option, but since the columnwise copy is saved for the backward pass,
+   it may increase memory usage if the high precision tensor also needs to be saved for backward -
+   for example, when it is the attention output, which is saved anyway.
+
+2. *Computation of quantized tensor in rowwise usage in forward pass and fused quantization to produce columnwise usage in backward pass*.
+
+   This is usually slower than the previous scenario, since the high precision tensor needs to be read twice.
+   It is used, for example, when the high precision tensor is gathered in both forward and backward passes
+   and quantized tensor gather is not implemented for the given recipe.
+
+3. *Computation of quantized tensor in rowwise usage in forward pass and transpose to columnwise usage in backward pass*.
+
+   This is not possible for all recipes, but when it is, it is more memory efficient than Scenario 1.
+
+Transformer Engine uses the best possible fusion internally, so users do not need to worry about the details.
+We describe these scenarios here to explain the memory consequences of the different fusion strategies.
+
+.. raw:: html
+ :file: img/transpose_fusion.svg
+
+*Figure 3: Three scenarios of producing quantized tensors in rowwise and columnwise usages.*
+
+
+
+Memory usage
+------------
+
+This section discusses memory usage in low precision training.
+Contrary to intuition, FP8 training does not always reduce memory compared to BF16/FP16.
+
+*Master weights*
+
+Transformer Engine stores weights in high precision and quantizes them to low precision before each GEMM.
+Moreover, one can choose the precision of the weights stored in the model -- FP32 or
+BF16/FP16 -- or not store high precision weights in the model at all. There are multiple scenarios to consider;
+three of them are listed below:
+
+1. model weights are stored in FP32 and quantized to low precision before each GEMM,
+2. model weights are stored in BF16/FP16 and quantized to low precision before each GEMM, while master weights in the optimizer are kept in FP32,
+3. model weights are stored directly in low precision, while master weights in the optimizer are kept in FP32.
+
+Note that all these scenarios may have different memory footprints.
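+
+A back-of-the-envelope comparison of the weight-related memory per parameter in the three
+scenarios (transient quantized copies and other optimizer states are ignored):
+
+.. code-block:: python
+
+    N = 1024 * 1024  # parameters in a 1024x1024 layer
+
+    scenarios = {
+        # In scenario 1, the FP32 model weights also serve as the master copy.
+        "1. FP32 model weights": N * 4,
+        "2. BF16 model weights + FP32 master weights": N * 2 + N * 4,
+        "3. low precision (1-byte FP8) model weights + FP32 master weights": N * 1 + N * 4,
+    }
+    for name, num_bytes in scenarios.items():
+        print(f"{name}: {num_bytes / 2**20:.1f} MiB")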
+
+*Activations saved for backward*
+
+Unlike weights, activations do not require a high precision copy for optimizer updates.
+As shown in Table 2, the input needs rowwise usage in forward and columnwise usage
+for weight gradient computation in backward — so it must be saved between passes.
+
+The memory impact depends on which scenario from Figure 3 is used.
+Additionally, on architectures where rowwise and columnwise usages share the same memory layout
+(e.g., FP8 on Blackwell, as shown in Figure 2), a single quantized tensor serves both usages,
+reducing memory overhead compared to architectures that require separate tensors.
+
+Output gradients, on the other hand, are computed during backward and do not need to be saved —
+both rowwise and columnwise usages are produced on the fly as needed.
+
+The FP8 examples below are analyzed on Hopper (SM90) or Ada (SM89) architecture, where rowwise
+and columnwise tensors require separate memory layouts.
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ **1. Baseline: high precision forward pass**
+
+ Let's start with a forward pass in higher precision to establish a baseline.
+
+ .. raw:: html
+
+
+ Needs to be run on SM89 (Ada) or SM90 (Hopper)
+
+
+ .. container:: program-output
+
+ .. literalinclude:: memory_usage_3_pytorch.out
+ :language: text
+ :start-after: # START_MEMORY_USAGE_3
+ :end-before: # END_MEMORY_USAGE_3
+
+ Total memory usage is ``1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 6 MB``.
+ Note that columnwise FP8 weight is not computed during initialization with ``torch.no_grad()``.
+ It will be computed on the first backward pass from the rowwise FP8 weight.
+
+ **4. Saving original input instead of quantized**
+
+ By default, TE saves the columnwise quantized input for the backward pass (needed for weight gradient).
+ However, when the high precision input is already being saved (e.g., for a residual connection),
+ keeping an additional quantized copy wastes memory.
+
+ The ``save_original_input=True`` option tells the layer to reference the original high precision input
+ instead of caching a separate quantized copy. The input is re-quantized during backward when needed.
+ Below is an example with a residual block where input is kept for the addition:
+
+ .. raw:: html
+
+
+ Needs to be run on SM89 (Ada) or SM90 (Hopper)
+
+
+ .. container:: program-output
+
+ .. literalinclude:: memory_usage_2_jax.out
+ :language: text
+ :start-after: # START_MEMORY_USAGE_2
+ :end-before: # END_MEMORY_USAGE_2
+
+ In JAX, unlike PyTorch, FP8 weights are not cached between forward passes.
+ Weights are stored in BF16 and quantized to FP8 on-the-fly during each forward pass.
+ This means the memory usage is similar to the baseline.
+
+ .. note::
+
+ JAX does not currently support storing model weights directly in FP8 format
+ like PyTorch's ``quantized_model_init``. Weights are always stored in high precision
+ (BF16/FP32) and quantized to FP8 during computation.
+
+Fused layers
+------------
+
+
+Transformer Engine provides fused layers such as ``LayerNormLinear`` and ``LayerNormMLP``
+that enable kernel fusion optimizations. One key optimization is fusing layer normalization
+with quantization.
+
+In a typical Transformer architecture, LayerNorm precedes a Linear layer. Without fusion,
+the LayerNorm outputs in FP32, and the Linear layer must then quantize this input before
+performing the GEMM — adding overhead. With ``LayerNormLinear``, these operations are fused
+into a single kernel: the LayerNorm output is quantized directly, eliminating the separate
+quantization step and reducing memory bandwidth.
+
+
+.. raw:: html
+ :file: img/fused_layers.svg
+
+*Figure 4: Comparison of separate LayerNorm and Linear layers versus fused LayerNormLinear layer, showing reduced quantization overhead.*
+
+
+Let's see how we can use fused layers in different frameworks.
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ In PyTorch, Transformer Engine provides fused layers like ``LayerNormLinear`` and ``LayerNormMLP``.
+ These layers combine normalization and linear operations with optimized quantization.
+
+ .. raw:: html
+
+
+ Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer)
+
+
+ .. literalinclude:: fused_layers_pytorch.py
+ :language: python
+ :start-after: # START_FUSED_LAYERS
+ :end-before: # END_FUSED_LAYERS
+
+ The fused ``LayerNormLinear`` layer is particularly efficient in FP8 training because
+ it avoids an intermediate quantization step. The LayerNorm output is directly quantized
+ for the GEMM operation, reducing memory bandwidth and improving performance.
+
+ .. tab:: JAX
+
+ In JAX, Transformer Engine provides fused layers like ``LayerNormDenseGeneral`` and ``LayerNormMLP``.
+ These layers combine normalization and dense operations with optimized quantization.
+
+ .. raw:: html
+
+
+ Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer)
+
+
+ .. literalinclude:: fused_layers_jax.py
+ :language: python
+ :start-after: # START_FUSED_LAYERS
+ :end-before: # END_FUSED_LAYERS
+
+ The fused ``LayerNormDenseGeneral`` layer is particularly efficient in FP8 training because
+ it avoids an intermediate quantization step. The LayerNorm output is directly quantized
+ for the GEMM operation, reducing memory bandwidth and improving performance.
+
+
+Distributed training
+--------------------
+
+Transformer Engine handles collective operations internally, so users typically don't need to manage
+the interaction between communication and low precision computation.
+
+Recall that each Linear layer involves six tensors: weight, input, output, and their gradients.
+Of these, output and gradients are returned in high precision, and weights are generally not
+communicated (except in FSDP, which is outside the scope of this section). This leaves two
+tensors where low precision communication matters: **input** and **output gradient**.
+
+For sequence parallelism, TE supports all-gather of quantized tensors. This provides several benefits:
+
+1. *Reduced memory* — no need to store high precision tensors for backward pass.
+2. *Reduced communication* — smaller tensors mean less data to transfer.
+3. *Parallelized quantization* — quantization work is distributed across GPUs.
+
+Support varies by recipe — for example, columnwise quantized all-gather is not available
+for all configurations.
+
+The figure below illustrates one possible all-gather scenario for input and output gradient tensors.
+Actual behavior depends on the recipe and module configuration.
+
+.. raw:: html
+ :file: img/sequence_parallel_quantization.svg
+
+*Figure 5: All-gather of quantized tensors for input and gradient tensors.
+This is one possible scenario — actual behavior varies depending on the recipe and module configuration.*
+
+
diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out
new file mode 100644
index 00000000000..c7545c4ee7e
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out
@@ -0,0 +1,12 @@
+/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden.
+ Overriding a previously registered kernel for the same operator and the same dispatch key
+ operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
+ registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
+ dispatch key: ADInplaceOrView
+ previous kernel: no debug info
+ new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
+ self.m.impl(
+# START_SAVE_ORIGINAL_INPUT
+save_original_input=False: 25.0 MB
+save_original_input=True: 24.0 MB
+# END_SAVE_ORIGINAL_INPUT
diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py
new file mode 100644
index 00000000000..869be8e763c
--- /dev/null
+++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import torch
+
+# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
+cc = torch.cuda.get_device_capability()
+assert (cc[0] == 8 and cc[1] >= 9) or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
+
+print("# START_SAVE_ORIGINAL_INPUT")
+# START_SAVE_ORIGINAL_INPUT
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Float8CurrentScaling
+
+recipe = Float8CurrentScaling()
+
+
+def residual_block(layer, inp):
+ """Residual connection: input is saved for addition after linear."""
+ out = layer(inp)
+ return out + inp # inp must be kept for this addition
+
+
+def measure_memory(use_save_original):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ layer = te.Linear(
+ 1024, 1024, params_dtype=torch.bfloat16, save_original_input=use_save_original
+ )
+ inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+
+ with te.autocast(enabled=True, recipe=recipe):
+ out = residual_block(layer, inp)
+ out.sum().backward()
+
+ return torch.cuda.max_memory_allocated() / 1024**2
+
+
+# Warmup runs
+measure_memory(False)
+measure_memory(True)
+
+# Actual measurements
+for use_save_original in [False, True]:
+ peak = measure_memory(use_save_original)
+ print(f"save_original_input={use_save_original}: {peak:.1f} MB")
+# END_SAVE_ORIGINAL_INPUT
+print("# END_SAVE_ORIGINAL_INPUT")
diff --git a/docs/features/other_optimizations/cpu_offloading/cpu_offloading.rst b/docs/features/other_optimizations/cpu_offloading/cpu_offloading.rst
new file mode 100644
index 00000000000..29968e661c7
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/cpu_offloading.rst
@@ -0,0 +1,230 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+CPU Offloading
+===================================
+
+.. note::
+
+ CPU Offloading in Transformer Engine is currently available only for **PyTorch**.
+ It supports all PyTorch modules, not just TE layers.
+
+CPU offloading moves activation tensors from GPU to CPU memory during the
+forward pass and reloads them during backward. Transfers are **asynchronous**,
+enabling significant GPU memory savings with minimal overhead.
+
+Unlike activation checkpointing, offloading avoids recomputation — activations
+are stored on CPU instead of being recalculated, making it faster when
+CPU-GPU bandwidth is sufficient.
+
+
+Hardware Support
+----------------
+
+CPU offloading benefits greatly from fast CPU-GPU interconnects.
+The faster the link, the more effectively transfer time can be hidden
+behind computation.
+
+.. raw:: html
+ :file: img/pcie_vs_nvlink.svg
+
+*Figure 1. Traditional PCIe system vs GH200 Superchip with NVLink-C2C.*
+
+Traditional **PCIe Gen5 x16** systems offer **128 GB/s** bidirectional bandwidth
+between CPU and GPU, which limits offloading benefits.
+
+With **NVLink-C2C** (GH200), bandwidth jumps to **900 GB/s** bidirectional,
+making offloading increasingly attractive on modern NVIDIA superchips.
+The GH200 pairs a Grace CPU with 512 GB LPDDR5X memory and a Hopper GPU
+with 96 GB or 141 GB HBM3e, providing ample CPU memory for offloading activations.
+
+Offloading/reloading consumes HBM bandwidth, which may compete with
+computation — even when transfers are asynchronous. At full speed, this takes
+up to **900 GB/s** of HBM bandwidth. However, GH200's HBM3e provides **~4.9 TB/s**,
+so offloading/reloading uses less than 20%, making the impact on compute minimal.
+
+CPU Offloading in Transformer Engine
+------------------------------------
+
+Transformer Engine supports CPU offloading of activations for sequences of layers, where each layer
+consumes the output of the previous one — which is the case for most LLM architectures.
+These layers do not need to be TE layers —
+they can be arbitrary PyTorch modules. The API is as follows:
+
+.. code-block:: python
+
+ def get_cpu_offload_context(
+ enabled: bool = False,
+ num_layers: Optional[int] = 1,
+ model_layers: int = 1,
+ manual_synchronization: bool = False,
+ retain_pinned_cpu_buffers: bool = False,
+ offload_stream: Optional[torch.cuda.Stream] = None,
+ ) -> Union[Tuple[ContextManager, Callable], Tuple[ContextManager, Callable, ManualOffloadSynchronizer]]:
+ ...
+
+The ``model_layers`` parameter must always be set to the total number of layers in the model.
+There are two modes of operation:
+
+1. **Default scheduling** — set ``num_layers`` to the number of layers to offload.
+ The algorithm automatically schedules offload/reload operations to overlap with computation.
+
+2. **Manual synchronization** — set ``manual_synchronization=True`` (do not set ``num_layers``).
+ This mode provides explicit control over when to start offload/reload using the returned ``ManualOffloadSynchronizer``.
+
+The :func:`transformer_engine.pytorch.get_cpu_offload_context` function returns:
+
+- **context manager** — wrap each layer's forward pass with it to enable activation capture.
+- **sync function** — call on the output tensor after each layer, as shown in the example below.
+
+The example below shows how to offload activations for a sequence of ``torch.nn.Linear`` layers using the default scheduling algorithm:
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ .. literalinclude:: pytorch_basic_offload_example.py
+ :language: python
+ :start-after: # START_BASIC_EXAMPLE
+ :end-before: # END_BASIC_EXAMPLE
+
+
+Default Offloading Scheduling
+-----------------------------
+
+Default scheduling is enabled when ``manual_synchronization=False`` (the default).
+The ``num_layers`` parameter must be specified to set the number of layers to offload.
+The algorithm then automatically determines when to offload and reload activations
+to maximize overlap with computation.
+
+For ``num_layers`` offloaded layers out of ``model_layers`` total layers:
+
+- First ``num_layers`` layers are offloaded to CPU.
+- Offloading starts as soon as tensors are saved for backward — it does not wait
+ for the layer's forward pass to complete.
+- At most ``(model_layers - num_layers)`` sets of activations are on GPU at any time;
+ both compute and reload may be stalled to enforce this limit.
+- Reloading must complete by the time the tensor is needed for the layer's backward pass.
+
+Specifying a low enough ``num_layers`` enables full overlap of computation
+and offload/reload. The following two scenarios illustrate this — one with full overlap, and one with stalls.
+
+.. raw:: html
+ :file: img/scheduling.svg
+
+*Figure 2. With* ``num_layers=2`` *and* ``model_layers=5`` *, at most 3 sets of activations are on GPU. Layer 1 offloading starts during its forward pass (when the first tensor is saved for backward). Offloading fully overlaps with forward, reloading fully overlaps with backward.*
+
+When ``num_layers`` is too high, the GPU memory limit forces stalls:
+
+.. raw:: html
+ :file: img/scheduling_stall.svg
+
+*Figure 3. With* ``num_layers=3`` *and* ``model_layers=5`` *, at most 2 sets of activations can be on GPU (5-3=2), which causes stalls. In forward, Layer 4 cannot start until Layer 2 is offloaded, otherwise there would be 3 sets of activations on GPU (Layers 2, 3, 4). In backward, Layer 3 cannot start immediately — its activations are still on CPU and must be reloaded first. Some tensors may finish reloading earlier, allowing parts of the layer (e.g., a sublayer) to run while the rest waits. The same applies to Layers 2 and 1.*
+
+
+Manual Synchronization
+----------------------
+
+For custom scheduling, set ``manual_synchronization=True``
+and pass a custom ``offload_stream``. This returns a ``ManualOffloadSynchronizer``
+with explicit control over transfers and allows synchronization via stream operations.
+
+This mode is useful when training does not follow the standard "all forwards then all backwards"
+pattern — for example, in pipeline parallelism. Having access to the ``offload_stream`` enables
+custom synchronization logic (e.g., waiting, recording events) tailored to the specific workload.
+
+The ``ManualOffloadSynchronizer`` object provides the following methods:
+
+- ``start_offload_layer(layer_id)`` — queue async GPU→CPU copies on the offload stream.
+ Before each copy, the offload stream waits for an event recorded when that tensor
+ was saved for backward.
+- ``release_activation_forward_gpu_memory(layer_id)`` — wait for the offload to complete
+ and release GPU memory.
+- ``start_reload_layer(layer_id)`` — queue async CPU→GPU copies on the offload stream.
+ When tensors are accessed in backward, compute stream waits for each tensor's reload
+ to complete.
+
+To skip offloading for a specific layer, simply do not call any of these methods for that layer.
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ The example demonstrates:
+
+ 1. **Forward pass**: After each layer, call ``start_offload_layer(i)`` to begin
+ async copy of layer ``i``'s activations to CPU.
+ 2. **Release GPU memory**: Call ``offload_stream.synchronize()`` to wait for all
+ offloads to finish, then ``release_activation_forward_gpu_memory(i)`` to free
+ the GPU tensors.
+ 3. **Before backward**: Call ``start_reload_layer(i)`` to begin async reload.
+ The compute stream will automatically wait for each tensor to be reloaded
+ before it's accessed in backward.
+
+ .. literalinclude:: pytorch_manual_offload_example.py
+ :language: python
+ :start-after: # START_MANUAL_EXAMPLE
+ :end-before: # END_MANUAL_EXAMPLE
+
+
+CPU Offloading and CUDA Graphs
+------------------------------
+
+CPU offloading works with CUDA graphs — async copies and stream synchronization
+are GPU operations that can be captured and replayed, even when accessing
+pinned CPU memory (via PCIe DMA, without CPU involvement).
+
+.. note::
+
+ The entire forward and backward pass must be captured in a single graph.
+ Per-layer graph capture is not supported due to cross-layer synchronization.
+
+.. note::
+
+ Allocating pinned CPU memory is currently not graphable. Use
+ ``retain_pinned_cpu_buffers=True`` and run a warm-up iteration before
+ capture to pre-allocate buffers that are reused during replay.
+
+.. tabs::
+
+ .. tab:: PyTorch
+
+ .. literalinclude:: pytorch_cuda_graphs_example.py
+ :language: python
+ :start-after: # START_CUDA_GRAPHS_EXAMPLE
+ :end-before: # END_CUDA_GRAPHS_EXAMPLE
+
+Caveats
+-------
+
+.. warning::
+
+ **Heuristic activation detection**:
+
+ CPU Offloading is implemented using
+ `PyTorch saved tensors hooks `_.
+ PyTorch saves various tensors for backward — not just activations, but also weights and other data.
+
+ Activation detection is heuristic: all CUDA tensors that are not ``torch.nn.Parameter`` are offloaded.
+ For TE layers, tensors that should not be offloaded are manually excluded.
+ For non-TE layers, no such exclusions exist, so some tensors may remain pinned in GPU memory
+ even after being copied to CPU (e.g., if the layer stores references in ``ctx``),
+ resulting in wasted bandwidth with no memory savings.
+
+.. warning::
+
+ **Memory layout changes**:
+
+ Offloading/reloading can change tensor memory layout and relations:
+
+ 1. Views of the same storage may be restored as separate allocations.
+ 2. Adjacent tensors may not be adjacent after reload.
+
+ CUDA kernels that rely on specific memory layout may produce unexpected results.
+ To mitigate (1), non-trivial views are excluded from offloading by default.
+ TE attention kernels are an exception — they use internal handling that is tested and supported.
+ Issue (2) is not mitigated — custom kernels that assume adjacent tensors share
+ contiguous memory may still fail.
+
diff --git a/docs/features/other_optimizations/cpu_offloading/img/pcie_vs_nvlink.svg b/docs/features/other_optimizations/cpu_offloading/img/pcie_vs_nvlink.svg
new file mode 100644
index 00000000000..bcca76917a9
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/img/pcie_vs_nvlink.svg
@@ -0,0 +1,111 @@
+
\ No newline at end of file
diff --git a/docs/features/other_optimizations/cpu_offloading/img/scheduling.svg b/docs/features/other_optimizations/cpu_offloading/img/scheduling.svg
new file mode 100644
index 00000000000..19255c3474d
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/img/scheduling.svg
@@ -0,0 +1,110 @@
+
diff --git a/docs/features/other_optimizations/cpu_offloading/img/scheduling_stall.svg b/docs/features/other_optimizations/cpu_offloading/img/scheduling_stall.svg
new file mode 100644
index 00000000000..cd2d1a660c6
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/img/scheduling_stall.svg
@@ -0,0 +1,143 @@
+
diff --git a/docs/features/other_optimizations/cpu_offloading/pytorch_basic_offload_example.py b/docs/features/other_optimizations/cpu_offloading/pytorch_basic_offload_example.py
new file mode 100644
index 00000000000..a9e0f6278b3
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/pytorch_basic_offload_example.py
@@ -0,0 +1,32 @@
+# START_BASIC_EXAMPLE
+import torch
+from transformer_engine.pytorch import get_cpu_offload_context
+
+# Setup
+num_layers = 12
+offloaded_layers = 3
+layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)]
+x = torch.randn(16, 1024, 1024, device="cuda")
+
+# Get offloading context and sync function
+cpu_offload_context, sync_function = get_cpu_offload_context(
+ enabled=True,
+ model_layers=num_layers,
+ num_layers=offloaded_layers,
+)
+
+# Forward pass
+for i in range(num_layers):
+ # Context manager captures tensors saved for backward.
+ # These tensors will be offloaded to CPU asynchronously.
+ with cpu_offload_context:
+ x = layers[i](x)
+
+ # sync_function must be called after each layer's forward pass.
+ # This cannot be done inside the context manager because
+ # it needs the output tensor after the layer has finished.
+ x = sync_function(x)
+
+loss = x.sum()
+loss.backward()
+# END_BASIC_EXAMPLE
diff --git a/docs/features/other_optimizations/cpu_offloading/pytorch_cuda_graphs_example.py b/docs/features/other_optimizations/cpu_offloading/pytorch_cuda_graphs_example.py
new file mode 100644
index 00000000000..506d916842c
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/pytorch_cuda_graphs_example.py
@@ -0,0 +1,43 @@
+# START_CUDA_GRAPHS_EXAMPLE
+import torch
+from transformer_engine.pytorch import get_cpu_offload_context, make_graphed_callables
+
+# Setup
+num_layers = 12
+offloaded_layers = 3
+layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)]
+
+# Enable offloading with retained buffers for CUDA graphs
+cpu_offload_context, sync_function = get_cpu_offload_context(
+ enabled=True,
+ model_layers=num_layers,
+ num_layers=offloaded_layers,
+ retain_pinned_cpu_buffers=True,
+)
+
+
+# Wrap layers in a module that uses offloading
+class OffloadedModel(torch.nn.Module):
+ def __init__(self, layers):
+ super().__init__()
+ self.layers = torch.nn.ModuleList(layers)
+
+ def forward(self, x):
+ for layer in self.layers:
+ with cpu_offload_context:
+ x = layer(x)
+ x = sync_function(x)
+ return x
+
+
+model = OffloadedModel(layers)
+sample_input = (torch.randn(16, 1024, 1024, device="cuda"),)
+
+# Create graphed callable (warmup is handled internally)
+graphed_model = make_graphed_callables(model, sample_input)
+
+# Use the graphed model
+x = torch.randn(16, 1024, 1024, device="cuda")
+out = graphed_model(x)
+out.sum().backward()
+# END_CUDA_GRAPHS_EXAMPLE
diff --git a/docs/features/other_optimizations/cpu_offloading/pytorch_manual_offload_example.py b/docs/features/other_optimizations/cpu_offloading/pytorch_manual_offload_example.py
new file mode 100644
index 00000000000..af1c76604c6
--- /dev/null
+++ b/docs/features/other_optimizations/cpu_offloading/pytorch_manual_offload_example.py
@@ -0,0 +1,37 @@
+# START_MANUAL_EXAMPLE
+import torch
+from transformer_engine.pytorch import get_cpu_offload_context
+
+# Setup
+num_layers = 12
+layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)]
+x = torch.randn(16, 1024, 1024, device="cuda")
+
+offload_stream = torch.cuda.Stream()
+cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
+ enabled=True,
+ model_layers=num_layers,
+ manual_synchronization=True,
+ offload_stream=offload_stream,
+)
+
+# Forward pass - manually trigger offload after each layer
+for i in range(num_layers):
+ with cpu_offload_context:
+ x = layers[i](x)
+ x = sync_function(x)
+ manual_controller.start_offload_layer(i)
+
+# Wait for offloads, then release GPU memory
+offload_stream.synchronize()
+for i in range(num_layers):
+ manual_controller.release_activation_forward_gpu_memory(i)
+
+# Start reloading before backward (reverse order, since backward visits the last layer first)
+for i in range(num_layers - 1, -1, -1):
+ manual_controller.start_reload_layer(i)
+
+# Backward pass
+loss = x.sum()
+loss.backward()
+# END_MANUAL_EXAMPLE
diff --git a/docs/features/other_optimizations/index.rst b/docs/features/other_optimizations/index.rst
new file mode 100644
index 00000000000..9338bc52ced
--- /dev/null
+++ b/docs/features/other_optimizations/index.rst
@@ -0,0 +1,11 @@
+..
+ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+Other optimizations
+===================================
+
+.. toctree::
+
+ cpu_offloading/cpu_offloading.rst
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 37d21c2a5dd..73fc2554ac1 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -39,6 +39,15 @@ Transformer Engine documentation
api/common
api/framework
+
+.. toctree::
+ :hidden:
+ :caption: Features
+
+ features/low_precision_training/index.rst
+ features/other_optimizations/index.rst
+
+
.. toctree::
:hidden:
:caption: Examples and Tutorials