diff --git a/build_tools/jax.py b/build_tools/jax.py
index df78bf3e2f3..ec0b4aaef45 100644
--- a/build_tools/jax.py
+++ b/build_tools/jax.py
@@ -19,8 +19,33 @@ def install_requirements() -> List[str]:


 def test_requirements() -> List[str]:
-    """Test dependencies for TE/JAX extensions."""
-    return ["numpy", "triton"]
+    """Test dependencies for TE/JAX extensions.
+
+    Triton Package Selection:
+        The triton package is selected based on the NVTE_USE_PYTORCH_TRITON environment variable:
+
+        Default (NVTE_USE_PYTORCH_TRITON unset or "0"):
+            Returns 'triton' - OpenAI's standard package from PyPI.
+            Install with: pip install triton
+
+        NVTE_USE_PYTORCH_TRITON=1:
+            Returns 'pytorch-triton' - for mixed JAX+PyTorch environments.
+            Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121
+
+        Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder.
+    """
+    use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in (
+        "1",
+        "true",
+        "yes",
+    )
+
+    triton_package = "pytorch-triton" if use_pytorch_triton else "triton"
+
+    return [
+        "numpy",
+        triton_package,
+    ]


 def xla_path() -> str:
diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index b03ef04fa42..19abd7d8293 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -13,7 +13,17 @@


 def install_requirements() -> List[str]:
-    """Install dependencies for TE/PyTorch extensions."""
+    """Install dependencies for TE/PyTorch extensions.
+
+    IMPORTANT - PyTorch Index Required for pytorch-triton:
+        These dependencies MUST be installed using PyTorch's package index:
+
+        pip install pytorch-triton --index-url https://download.pytorch.org/whl/
+
+        - pytorch-triton is only available from PyTorch's index (not PyPI)
+        - The 'pytorch-triton' package on PyPI is a placeholder that will fail
+        - torch.compile() requires pytorch-triton, not OpenAI's 'triton' package
+    """
     return [
         "torch>=2.1",
         "einops",
@@ -22,7 +32,7 @@ def install_requirements() -> List[str]:
         "packaging",
         "pydantic",
         "nvdlfw-inspect",
-        "triton",
+        "pytorch-triton",
     ]
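(Illustration, not part of the patch.) The NVTE_USE_PYTORCH_TRITON handling added to build_tools/jax.py above boils down to a small string check; a self-contained sketch, where select_triton_requirement is a hypothetical helper mirroring test_requirements():

    def select_triton_requirement(env: dict) -> str:
        # Hypothetical mirror of the env-var parsing in test_requirements() above.
        use_pytorch_triton = env.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ("1", "true", "yes")
        return "pytorch-triton" if use_pytorch_triton else "triton"

    assert select_triton_requirement({}) == "triton"
    assert select_triton_requirement({"NVTE_USE_PYTORCH_TRITON": "0"}) == "triton"
    assert select_triton_requirement({"NVTE_USE_PYTORCH_TRITON": "1"}) == "pytorch-triton"
    assert select_triton_requirement({"NVTE_USE_PYTORCH_TRITON": "true"}) == "pytorch-triton"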
diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py
index 13a36421bf1..6e635d3a2a2 100644
--- a/transformer_engine/jax/triton_extensions/__init__.py
+++ b/transformer_engine/jax/triton_extensions/__init__.py
@@ -9,7 +9,33 @@

 IMPORTANT: This module requires Triton to be installed. If you don't have Triton,
 use transformer_engine.jax.cpp_extensions instead (CUDA/FFI based primitives).
-Install Triton: pip install triton
+
+Triton Package Options:
+-----------------------
+There are two compatible Triton packages:
+
+1. Standard 'triton' from OpenAI (recommended for JAX-only environments):
+   pip install triton
+
+2. 'pytorch-triton' from PyTorch's index (for mixed JAX+PyTorch environments):
+   pip install torch --index-url https://download.pytorch.org/whl/cu121
+   # pytorch-triton is automatically installed as a dependency
+
+   Both packages work with JAX Triton kernels. The pytorch-triton package
+   has version format "X.Y.Z+" (e.g., "3.0.0+45fff310c8").
+
+WARNING: Do NOT run 'pip install pytorch-triton' directly! The package on PyPI
+is a placeholder that will fail with "RuntimeError: Should never be installed".
+The real pytorch-triton only comes bundled with PyTorch from PyTorch's index.
+
+
+Environment Variables:
+    NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton
+        for JAX Triton kernels (suppresses compatibility warnings). Set this
+        when both JAX and PyTorch are installed in the same environment.
+
+        Example:
+            export NVTE_USE_PYTORCH_TRITON=1

 Usage:
@@ -23,6 +49,11 @@ def lowering(ctx, x, **kwargs):

     # Use permutation functions
     from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map
+
+    # Check Triton package info
+    from transformer_engine.jax.triton_extensions import get_triton_info
+    info = get_triton_info()
+    print(f"Using Triton {info['version']} from {info['source']}")
 """

 from .utils import *
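(Illustration, not part of the patch.) The "X.Y.Z+" commit-suffixed version format described above is what the detection added to utils.py below keys on; a standalone sketch of that heuristic, with looks_like_pytorch_triton as a hypothetical name:

    def looks_like_pytorch_triton(version: str) -> bool:
        # Real pytorch-triton wheels carry a commit-SHA local version, e.g. "3.0.0+45fff310c8".
        return "+" in version and len(version.split("+")[-1]) >= 8

    assert looks_like_pytorch_triton("3.0.0+45fff310c8")
    assert not looks_like_pytorch_triton("3.1.0")  # standard OpenAI triton
    assert not looks_like_pytorch_triton("0.0.1")  # PyPI placeholder version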
diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py
index 12d6a9e3de4..361e17ef3ee 100644
--- a/transformer_engine/jax/triton_extensions/utils.py
+++ b/transformer_engine/jax/triton_extensions/utils.py
@@ -6,9 +6,13 @@

 This module provides utility functions for integrating Triton kernels into JAX
 primitives. Triton is only imported when this module is used.
+
+Triton Package Compatibility --> see __init__.py
 """

 import hashlib
+import os
+import warnings
 from typing import Any, Callable, Mapping
 import zlib
@@ -17,6 +21,102 @@

 import jax.numpy as jnp

+
+# Placeholder package version on PyPI that should never be used
+_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1"
+
+
+def _detect_triton_package():
+    """Detect which Triton package is installed and validate compatibility.
+
+    Returns:
+        tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool)
+
+    The function detects:
+        - None: Triton not installed
+        - Standard triton from OpenAI (versions like "3.1.0")
+        - Real pytorch-triton from PyTorch's index (versions like "3.0.0+45fff310c8")
+        - Placeholder pytorch-triton from PyPI (version "0.0.1" - broken, raises RuntimeError)
+    """
+    try:
+        import triton
+
+        triton_version = getattr(triton, "__version__", "unknown")
+    except ImportError:
+        return None, False, False
+    except RuntimeError as e:
+        # The placeholder pytorch-triton package from PyPI raises:
+        # RuntimeError: "Should never be installed"
+        if "Should never be installed" in str(e):
+            return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True
+        raise
+
+    # Check for placeholder package (version 0.0.1 from PyPI)
+    is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION
+
+    # Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8"
+    is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8
+
+    return triton_version, is_pytorch_triton, is_placeholder
+
+
+def _check_triton_compatibility():
+    """Check Triton package compatibility and emit warnings if necessary."""
+    triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package()
+
+    # Handle placeholder package from PyPI
+    if is_placeholder:
+        raise ImportError(
+            "Detected the placeholder 'pytorch-triton' package (version 0.0.1) from PyPI.\n"
+            "This is NOT a functional Triton installation.\n\n"
+            "The placeholder package exists to prevent namespace conflicts. To fix this:\n\n"
+            "Option 1 - Use standard Triton (recommended for JAX-only environments):\n"
+            "  pip uninstall pytorch-triton triton\n"
+            "  pip install triton\n\n"
+            "Option 2 - Use real pytorch-triton (for mixed JAX+PyTorch environments):\n"
+            "  pip uninstall pytorch-triton triton\n"
+            "  pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
+            "  # pytorch-triton is automatically installed as a torch dependency\n\n"
+            "Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n"
+            "the broken placeholder. The real pytorch-triton only comes from PyTorch's index."
+        )
+
+    if triton_version is None:
+        raise ImportError(
+            "Triton is required for transformer_engine.jax.triton_extensions.\n\n"
+            "Option 1 - Install standard Triton (recommended for JAX-only):\n"
+            "  pip install triton\n\n"
+            "Option 2 - Install PyTorch with pytorch-triton (for mixed environments):\n"
+            "  pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n"
+            "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
+        )
+
+    use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower()
+    use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes")
+
+    if is_pytorch_triton:
+        if use_pytorch_triton_explicit:
+            # User explicitly opted in - suppress the compatibility warning
+            pass  # Silent acknowledgment, no warning needed
+        else:
+            # pytorch-triton detected but user didn't explicitly opt in
+            warnings.warn(
+                f"Detected pytorch-triton package (version {triton_version}) instead of the"
+                " standard 'triton' package from OpenAI. This typically happens when PyTorch is"
+                " installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton"
+                " kernels. To suppress this warning, set:\n  export"
+                " NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n  - Use"
+                " separate virtual environments for JAX and PyTorch, or\n  - Use"
+                " transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)",
+                category=UserWarning,
+                stacklevel=3,
+            )
+
+    return triton_version, is_pytorch_triton
+
+
+# Perform compatibility check and get triton info
+_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility()
+
 try:
     from jax._src.lib import gpu_triton
     from triton.compiler import compiler as tc
@@ -30,12 +130,35 @@
     ) from e


-__all__ = ["triton_call_lowering"]
+__all__ = ["triton_call_lowering", "get_triton_info"]


 # Triton kernel cache (module-level, shared across all kernels)
 _TRITON_KERNEL_CACHE = {}


+def get_triton_info():
+    """Get information about the installed Triton package.
+
+    Returns:
+        dict: Dictionary containing:
+            - version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8")
+            - is_pytorch_triton (bool): True if using real pytorch-triton from PyTorch's index
+            - is_openai_triton (bool): True if using standard triton from OpenAI/PyPI
+            - env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set and pytorch-triton is in use
+            - source (str): "pytorch" or "openai" indicating the package source
+    """
+    use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower()
+    env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes")
+
+    return {
+        "version": _TRITON_VERSION,
+        "is_pytorch_triton": _IS_PYTORCH_TRITON,
+        "is_openai_triton": not _IS_PYTORCH_TRITON,
+        "env_acknowledged": env_acknowledged and _IS_PYTORCH_TRITON,
+        "source": "pytorch" if _IS_PYTORCH_TRITON else "openai",
+    }
+
+
 def get_triton_dtype(aval):
     """Convert JAX dtype to Triton type string.
@@ -142,17 +265,31 @@ def compile_triton(
     )

     # Create kernel object for JAX
-    kernel = gpu_triton.TritonKernel(
-        compiled.name,
-        num_warps,
-        compiled.metadata.shared,
-        compiled.asm["ptx"],
-        "",  # ttir
-        compute_capability,
-        1,
-        1,
-        1,  # cluster_dims
-    )
+    # From jax/jaxlib/gpu/triton_kernels.cc:
+    from packaging import version
+
+    if version.parse(jax.__version__) >= version.parse("0.8.2"):
+        kernel = gpu_triton.TritonKernel(
+            compiled.name,  # arg0: kernel_name (str)
+            num_warps,  # arg1: num_warps (int)
+            num_ctas,  # arg2: num_ctas (int)
+            compiled.metadata.shared,  # arg3: shared_mem_bytes (int)
+            compiled.asm["ptx"],  # arg4: ptx (str)
+            "",  # arg5: ttir (str) - empty
+            compute_capability,  # arg6: compute_capability (int)
+        )
+    else:
+        kernel = gpu_triton.TritonKernel(
+            compiled.name,
+            num_warps,
+            compiled.metadata.shared,
+            compiled.asm["ptx"],
+            "",  # ttir
+            compute_capability,
+            1,
+            1,
+            1,  # cluster_dims
+        )

     _TRITON_KERNEL_CACHE[cache_key] = kernel
     return kernel
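(Illustration, not part of the patch.) A minimal sketch of how downstream code might consume the new get_triton_info() helper once this patch is applied and a working Triton is installed:

    from transformer_engine.jax.triton_extensions import get_triton_info

    info = get_triton_info()
    if info["is_pytorch_triton"] and not info["env_acknowledged"]:
        # pytorch-triton detected but NVTE_USE_PYTORCH_TRITON=1 was not set.
        print(f"Using pytorch-triton {info['version']}; "
              "set NVTE_USE_PYTORCH_TRITON=1 to silence the compatibility warning.")
    else:
        print(f"Using Triton {info['version']} from {info['source']}")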