diff --git a/benchmarks/quantized_matmul_kernel/autotune.py b/benchmarks/quantized_matmul_kernel/autotune.py
new file mode 100644
index 000000000000..f2d1858df64c
--- /dev/null
+++ b/benchmarks/quantized_matmul_kernel/autotune.py
@@ -0,0 +1,132 @@
+# To run: python benchmarks/quantized_matmul_kernel/autotune.py 2>&1 | tee out.txt
+# Then, in out.txt, find the lines containing "Add to table:", replace that marker
+# with 4 spaces, and copy the result into the TUNED_BLOCK_SIZES table in the
+# Pallas kernel.
+import time
+from typing import List
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax import random
+import numpy as np
+import functools
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
+    quantized_matmul_int8,
+    get_tuned_block_sizes,
+    TUNED_BLOCK_SIZES,
+)
+
+
+def _quantize_tensor(x, n_bits: int = 8, dim: int = -1):
+  """Symmetric integer quantization; the scale is the abs-max over axis `dim`."""
+  max_val = jnp.amax(jnp.abs(x), axis=dim, keepdims=True)
+  int_min = -2**(n_bits - 1)
+  int_max = 2**(n_bits - 1) - 1
+  scale = max_val / int_max
+  x_int = jnp.clip(jnp.rint(x / scale), int_min, int_max).astype(jnp.int8)
+  return x_int, scale.astype(x.dtype)
+
+
+def find_factors_multiple_of_128(n: int) -> List[int]:
+  """Finds all factors of an integer n that are also multiples of 128.
+
+  Args:
+    n: The integer for which to find factors.
+
+  Returns:
+    A list of integers that are factors of n and are multiples of 128.
+    Returns an empty list if n is 0 or if no such factors exist.
+    Handles negative input by taking the absolute value.
+  """
+  if n == 0:
+    return []
+
+  # Work with the absolute value for factor finding.
+  n_abs = abs(n)
+
+  # We want factors f with n_abs % f == 0 and f % 128 == 0, so it suffices to
+  # check the multiples of 128 (128, 256, 384, ...) up to n_abs.
+  factors = []
+  potential_factor = 128
+  while potential_factor <= n_abs:
+    if n_abs % potential_factor == 0:
+      factors.append(potential_factor)
+    potential_factor += 128
+
+  return factors
+
+# Benchmarking script starts.
+
+# One-off values for quick testing:
+# batch_sizes = [16, 32]
+# out_in_features = [(128, 256), (256, 128)]
+# batch_block_sizes = [128, 256]
+
+# Full sweep:
+batch_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048]
+out_in_features = [(6144, 4096), (4096, 4096), (28672, 4096), (4096, 14336),
+                   (1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]
+batch_block_sizes = [128, 256, 512, 1024, 2048]
+
+for bs in batch_sizes:
+  for n_output_features, n_input_features in out_in_features:
+    out_block_sizes = find_factors_multiple_of_128(n_output_features)
+    in_block_sizes = find_factors_multiple_of_128(n_input_features)
+    print(f'Benchmarking w8a8 matmul with bs={bs}, n_output_features={n_output_features}, n_input_features={n_input_features}, for block sizes: {out_block_sizes=}, {in_block_sizes=}')
+
+    dtype = jnp.bfloat16
+    prng_key = jax.random.key(1234)
+    k0, k1 = jax.random.split(prng_key, 2)
+    x = jax.random.normal(k0, (bs, n_input_features), dtype=dtype)
+    w = jax.random.normal(k1, (n_output_features, n_input_features), dtype=dtype)
+    w_w8a8_jax, scalar_jax = _quantize_tensor(w, n_bits=8, dim=-1)
+    scalar_jax = scalar_jax.squeeze()
+    assert scalar_jax.shape == (n_output_features,)
+
+    vmem_limit_80_mb = 80 * 1024 * 1024
+    best_time = None
+    best_batch_block_size = None
+    best_out_block_size = None
+    best_in_block_size = None
+    skip_trial = False
+    for batch_block_size in batch_block_sizes:
+      # Skip batch block sizes larger than the batch once we already have a result.
+      if bs < batch_block_size and best_time is not None:
+        continue
+      for out_block_size in out_block_sizes:
+        for in_block_size in in_block_sizes:
+          skip_trial = False
+          print(f'Benchmarking w8a8 matmul bs={bs}, n_output_features={n_output_features}, n_input_features={n_input_features} with batch_block_size={batch_block_size}, out_block_size={out_block_size}, in_block_size={in_block_size}', flush=True)
+          for _ in range(10):  # warm up
+            try:
+              quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size, vmem_limit_bytes=vmem_limit_80_mb).block_until_ready()
+            except Exception as e:
+              print(f'Failed to run quantized_matmul with batch_block_size={batch_block_size}, out_block_size={out_block_size}, in_block_size={in_block_size} due to {e}', flush=True)
+              skip_trial = True
+              break
+          if skip_trial:
+            continue
+          num_iterations = 30
+          start_time = time.perf_counter_ns()
+          for _ in range(num_iterations):
+            quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size, vmem_limit_bytes=vmem_limit_80_mb).block_until_ready()
+          end_time = time.perf_counter_ns()
+          elapsed_time = (end_time - start_time) / num_iterations
+          print(f'Benchmarked w8a8 matmul with batch_block_size={batch_block_size}, out_block_size={out_block_size}, in_block_size={in_block_size}, time={elapsed_time} ns')
+          if best_time is None or elapsed_time < best_time:
+            best_time = elapsed_time
+            best_batch_block_size = batch_block_size
+            best_out_block_size = out_block_size
+            best_in_block_size = in_block_size
+    print(f'Best batch_block_size={best_batch_block_size}, out_block_size={best_out_block_size}, in_block_size={best_in_block_size}, time={best_time} ns')
+    print(f'Add to table: (6, {bs}, {n_output_features}, {n_input_features}, \'{jnp.dtype(dtype).name}\', {True}): ({best_batch_block_size}, {best_out_block_size}, {best_in_block_size}),')
+
+# key should be: bs, n_output_features, n_input_features, dtype, quantize_activation
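The autotuner prints one "Add to table: ..." line per configuration, and the header comment of autotune.py describes a manual post-processing step: grep those lines out of out.txt, replace the marker with a 4-space indent, and paste the result into TUNED_BLOCK_SIZES. Below is a minimal sketch of that step, assuming the log file is named out.txt as in the header comment; the script and its helper name are illustrative and not part of the patch.

# extract_table_rows.py (hypothetical helper, not part of the patch)
import sys

MARKER = 'Add to table: '

def extract_table_rows(log_path: str) -> list:
  """Returns TUNED_BLOCK_SIZES rows recovered from an autotune log."""
  rows = []
  with open(log_path) as f:
    for line in f:
      idx = line.find(MARKER)
      if idx != -1:
        # Drop everything up to and including the marker, indent by 4 spaces.
        rows.append('    ' + line[idx + len(MARKER):].rstrip())
  return rows

if __name__ == '__main__':
  for row in extract_table_rows(sys.argv[1] if len(sys.argv) > 1 else 'out.txt'):
    print(row)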
diff --git a/benchmarks/quantized_matmul_kernel/benchmark.py b/benchmarks/quantized_matmul_kernel/benchmark.py
new file mode 100644
index 000000000000..e83ea146f0fe
--- /dev/null
+++ b/benchmarks/quantized_matmul_kernel/benchmark.py
@@ -0,0 +1,74 @@
+import time
+from typing import List
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax import random
+import numpy as np
+import functools
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
+    quantized_matmul_int8,
+)
+
+# Each case maps batch size -> number of occurrences in the workload, plus the
+# (n_output_features, n_input_features) weight shapes to benchmark.
+cases = {
+    "TP=1": {
+        "bs_nums": {16: 127, 128: 901, 256: 1, 512: 512, 1024: 5, 2048: 846},
+        "out_in_features": [(6144, 4096), (4096, 4096), (28672, 4096), (4096, 14336)],
+    },
+    "TP=8": {
+        "bs_nums": {16: 1016, 64: 64, 128: 1440, 512: 8, 1024: 32, 2048: 7032},
+        "out_in_features": [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)],
+    },
+}
+
+# One-off values for quick testing:
+# cases = {
+#     "TP=1": {
+#         "bs_nums": {16: 1, 128: 2},
+#         "out_in_features": [(128, 128)],
+#     },
+#     "TP=8": {
+#         "bs_nums": {16: 1, 128: 2},
+#         "out_in_features": [(256, 256)],
+#     },
+# }
+
+
+def _quantize_tensor(x, n_bits: int = 8, dim: int = -1):
+  max_val = jnp.amax(jnp.abs(x), axis=dim, keepdims=True)
+  int_min = -2**(n_bits - 1)
+  int_max = 2**(n_bits - 1) - 1
+  scale = max_val / int_max
+  x_int = jnp.clip(jnp.rint(x / scale), int_min, int_max).astype(jnp.int8)
+  return x_int, scale.astype(x.dtype)
+
+
+def run_benchmark(bs_nums, out_in_features: List[tuple]):
+  elapsed_time_ms = 0
+  print(f"Running benchmark with bs_nums: {bs_nums} and out_in_features: {out_in_features}")
+  for bs, num_occur in bs_nums.items():
+    for n_output_features, n_input_features in out_in_features:
+      dtype = jnp.bfloat16
+      prng_key = random.key(1234)
+      k0, k1 = jax.random.split(prng_key, 2)
+      x = jax.random.normal(k0, (bs, n_input_features), dtype=dtype)
+      w = jax.random.normal(k1, (n_output_features, n_input_features), dtype=dtype)
+      w_w8a8_jax, scalar_jax = _quantize_tensor(w, n_bits=8, dim=-1)
+      scalar_jax = scalar_jax.squeeze()
+      assert scalar_jax.shape == (n_output_features,)
+      num_warmup = 5
+      for _ in range(num_warmup):
+        quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True).block_until_ready()
+      start_time = time.perf_counter_ns()
+      for _ in range(num_occur):
+        quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True).block_until_ready()
+      end_time = time.perf_counter_ns()
+      elapsed_time_ms += (end_time - start_time) / 1e6
+  return elapsed_time_ms
+
+
+for case, value in cases.items():
+  bs_nums = value["bs_nums"]
+  out_in_features = value["out_in_features"]
+  elapsed_time_ms = run_benchmark(bs_nums, out_in_features)
+  print(f"Benchmarking {case} took {elapsed_time_ms:.2f} ms")
diff --git a/benchmarks/quantized_matmul_kernel/microbenchmark.py b/benchmarks/quantized_matmul_kernel/microbenchmark.py
new file mode 100644
index 000000000000..21efa645f5d8
--- /dev/null
+++ b/benchmarks/quantized_matmul_kernel/microbenchmark.py
@@ -0,0 +1,2 @@
+# A microbenchmark is planned, but since the kernel is expected to live in g3, the
+# microbenchmark should live in g3 as well and use xprof, so none is added here for now.
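benchmark.py weights each (batch size, weight shape) pair by its occurrence count in bs_nums, so the reported number is a workload-level total in milliseconds rather than a per-call latency. If a reference point is wanted, the same weighting can be applied to a plain bf16 matmul; the sketch below assumes jnp.dot on a (bs, n_in) activation and an (n_out, n_in) weight as that baseline and mirrors run_benchmark's warm-up and timing loop. The run_baseline name and the baseline choice are illustrative, not part of the patch.

# Hypothetical extension (not part of the patch): time a plain bf16 matmul with
# the same bs_nums weighting so the kernel's total can be turned into a speedup.
import time
import jax
import jax.numpy as jnp

def run_baseline(bs_nums, out_in_features):
  """Returns the weighted total time (ms) of jnp.dot(x, w.T) in bf16."""
  elapsed_time_ms = 0
  dot = jax.jit(lambda x, w: jnp.dot(x, w.T))  # same (bs, in) x (out, in) layout
  for bs, num_occur in bs_nums.items():
    for n_out, n_in in out_in_features:
      k0, k1 = jax.random.split(jax.random.key(1234), 2)
      x = jax.random.normal(k0, (bs, n_in), dtype=jnp.bfloat16)
      w = jax.random.normal(k1, (n_out, n_in), dtype=jnp.bfloat16)
      for _ in range(5):  # warm up / compile
        dot(x, w).block_until_ready()
      start = time.perf_counter_ns()
      for _ in range(num_occur):
        dot(x, w).block_until_ready()
      elapsed_time_ms += (time.perf_counter_ns() - start) / 1e6
  return elapsed_time_ms

# Usage: speedup = run_baseline(bs_nums, out_in_features) / run_benchmark(bs_nums, out_in_features)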
diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py
index e2d4903ae260..6693b8ddb6b3 100644
--- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py
@@ -212,70 +212,70 @@ def quantized_matmul_int8(
 # - out_block_size
 # - in_block_size
 TUNED_BLOCK_SIZES = {
-    (6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 16, 4096, 4096, 'bfloat16', True): (128, 2048, 2048),
-    (6, 16, 28672, 4096, 'bfloat16', True): (128, 1792, 2048),
-    (6, 16, 4096, 14336, 'bfloat16', True): (128, 512, 7168),
-    (6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
-    (6, 16, 8192, 1024, 'bfloat16', True): (128, 8192, 256),
-    (6, 16, 7168, 8192, 'bfloat16', True): (128, 512, 8192),
+    (6, 16, 6144, 4096, 'bfloat16', True): (128, 6144, 1024),
+    (6, 16, 4096, 4096, 'bfloat16', True): (128, 4096, 1024),
+    (6, 16, 28672, 4096, 'bfloat16', True): (128, 3584, 2048),
+    (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 1792),
+    (6, 16, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 16, 8192, 1024, 'bfloat16', True): (128, 1024, 1024),
+    (6, 16, 7168, 8192, 'bfloat16', True): (128, 7168, 1024),
     (6, 16, 8192, 3584, 'bfloat16', True): (128, 2048, 3584),
     (6, 32, 6144, 4096, 'bfloat16', True): (128, 1536, 4096),
-    (6, 32, 4096, 4096, 'bfloat16', True): (128, 4096, 2048),
-    (6, 32, 28672, 4096, 'bfloat16', True): (128, 2048, 2048),
-    (6, 32, 4096, 14336, 'bfloat16', True): (128, 256, 14336),
-    (6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
-    (6, 32, 8192, 1024, 'bfloat16', True): (128, 1024, 1024),
-    (6, 32, 7168, 8192, 'bfloat16', True): (128, 3584, 2048),
-    (6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 64, 6144, 4096, 'bfloat16', True): (128, 512, 4096),
-    (6, 64, 4096, 4096, 'bfloat16', True): (128, 2048, 2048),
-    (6, 64, 28672, 4096, 'bfloat16', True): (128, 1792, 2048),
+    (6, 32, 4096, 4096, 'bfloat16', True): (128, 4096, 1024),
+    (6, 32, 28672, 4096, 'bfloat16', True): (128, 4096, 2048),
+    (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 1792),
+    (6, 32, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 32, 8192, 1024, 'bfloat16', True): (128, 8192, 256),
+    (6, 32, 7168, 8192, 'bfloat16', True): (128, 1792, 4096),
+    (6, 32, 8192, 3584, 'bfloat16', True): (128, 8192, 896),
+    (6, 64, 6144, 4096, 'bfloat16', True): (128, 1536, 2048),
+    (6, 64, 4096, 4096, 'bfloat16', True): (128, 2048, 1024),
+    (6, 64, 28672, 4096, 'bfloat16', True): (128, 3584, 2048),
     (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 1024),
-    (6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 4096),
-    (6, 64, 8192, 1024, 'bfloat16', True): (128, 4096, 1024),
-    (6, 64, 7168, 8192, 'bfloat16', True): (128, 512, 8192),
+    (6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 64, 8192, 1024, 'bfloat16', True): (128, 8192, 1024),
+    (6, 64, 7168, 8192, 'bfloat16', True): (128, 3584, 2048),
     (6, 64, 8192, 3584, 'bfloat16', True): (128, 2048, 1792),
-    (6, 128, 6144, 4096, 'bfloat16', True): (128, 1536, 4096),
-    (6, 128, 4096, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 128, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 128, 4096, 14336, 'bfloat16', True): (128, 2048, 2048),
-    (6, 128, 1280, 8192, 'bfloat16', True): (128, 640, 4096),
+    (6, 128, 6144, 4096, 'bfloat16', True): (128, 6144, 1024),
+    (6, 128, 4096, 4096, 'bfloat16', True): (128, 4096, 2048),
+    (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 512),
+    (6, 128, 4096, 14336, 'bfloat16', True): (128, 2048, 3584),
+    (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 1024),
     (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 128, 7168, 8192, 'bfloat16', True): (128, 896, 8192),
-    (6, 128, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 256, 6144, 4096, 'bfloat16', True): (256, 1536, 4096),
-    (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
-    (6, 256, 28672, 4096, 'bfloat16', True): (256, 896, 4096),
-    (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 2048),
-    (6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 4096),
-    (6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
-    (6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
-    (6, 256, 8192, 3584, 'bfloat16', True): (256, 1024, 3584),
+    (6, 128, 7168, 8192, 'bfloat16', True): (128, 1792, 4096),
+    (6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 896),
+    (6, 256, 6144, 4096, 'bfloat16', True): (256, 3072, 4096),
+    (6, 256, 4096, 4096, 'bfloat16', True): (256, 2048, 4096),
+    (6, 256, 28672, 4096, 'bfloat16', True): (256, 3584, 4096),
+    (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 1792),
+    (6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 2048),
+    (6, 256, 8192, 1024, 'bfloat16', True): (256, 4096, 1024),
+    (6, 256, 7168, 8192, 'bfloat16', True): (256, 1792, 4096),
+    (6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
     (6, 512, 6144, 4096, 'bfloat16', True): (512, 2048, 4096),
-    (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
-    (6, 512, 28672, 4096, 'bfloat16', True): (512, 1792, 4096),
-    (6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 1792),
+    (6, 512, 4096, 4096, 'bfloat16', True): (512, 4096, 512),
+    (6, 512, 28672, 4096, 'bfloat16', True): (512, 4096, 4096),
+    (6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 2048),
     (6, 512, 1280, 8192, 'bfloat16', True): (512, 1280, 2048),
     (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
-    (6, 512, 7168, 8192, 'bfloat16', True): (512, 1792, 4096),
-    (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
-    (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 1024, 4096),
-    (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
-    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 3584, 4096),
-    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 2048, 2048),
-    (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 1280, 2048),
-    (6, 1024, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
-    (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
+    (6, 512, 7168, 8192, 'bfloat16', True): (512, 7168, 512),
+    (6, 512, 8192, 3584, 'bfloat16', True): (512, 8192, 512),
+    (6, 1024, 6144, 4096, 'bfloat16', True): (512, 6144, 4096),
+    (6, 1024, 4096, 4096, 'bfloat16', True): (256, 4096, 4096),
+    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
+    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 4096, 1792),
+    (6, 1024, 1280, 8192, 'bfloat16', True): (512, 1280, 4096),
+    (6, 1024, 8192, 1024, 'bfloat16', True): (512, 8192, 1024),
+    (6, 1024, 7168, 8192, 'bfloat16', True): (512, 7168, 1024),
+    (6, 1024, 8192, 3584, 'bfloat16', True): (256, 8192, 3584),
     (6, 2048, 6144, 4096, 'bfloat16', True): (256, 6144, 4096),
-    (6, 2048, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
+    (6, 2048, 4096, 4096, 'bfloat16', True): (512, 4096, 4096),
     (6, 2048, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
-    (6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 1024),
-    (6, 2048, 1280, 8192, 'bfloat16', True): (512, 1280, 8192),
+    (6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
+    (6, 2048, 1280, 8192, 'bfloat16', True): (2048, 1280, 1024),
     (6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 2048, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
-    (6, 2048, 8192, 3584, 'bfloat16', True): (2048, 2048, 3584),
+    (6, 2048, 7168, 8192, 'bfloat16', True): (256, 7168, 8192),
+    (6, 2048, 8192, 3584, 'bfloat16', True): (512, 8192, 3584),
 }
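For readers scanning the table: each key is (6, batch size, n_output_features, n_input_features, dtype name, quantize_activation), where the leading 6 presumably tags the TPU generation the values were tuned on, and each value is (batch_block_size, out_block_size, in_block_size), matching the comments above the table and the "Add to table:" lines emitted by autotune.py. The kernel's own get_tuned_block_sizes is not shown in this hunk; the sketch below only illustrates how such a table is typically consulted, and the fallback block sizes are a guess.

# Illustrative lookup only; the real get_tuned_block_sizes lives in
# quantized_matmul_kernel.py and is not reproduced in this hunk.
from typing import Dict, Tuple

Key = Tuple[int, int, int, int, str, bool]
BlockSizes = Tuple[int, int, int]

def lookup_block_sizes(table: Dict[Key, BlockSizes],
                       tpu_version: int,
                       bs: int,
                       n_output_features: int,
                       n_input_features: int,
                       dtype_name: str,
                       quantize_activation: bool,
                       default: BlockSizes = (256, 1024, 1024)) -> BlockSizes:
  """Returns (batch_block_size, out_block_size, in_block_size) for the config.

  Falls back to `default` (a hypothetical choice, not a tuned value) when the
  exact configuration was never autotuned.
  """
  key = (tpu_version, bs, n_output_features, n_input_features, dtype_name,
         quantize_activation)
  return table.get(key, default)

# Example against the table above:
# lookup_block_sizes(TUNED_BLOCK_SIZES, 6, 128, 4096, 4096, 'bfloat16', True)
# -> (128, 4096, 2048)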