132 changes: 132 additions & 0 deletions benchmarks/quantized_matmul_kernel/autotune.py
@@ -0,0 +1,132 @@
# To run: python benchmarks/quantized_matmul_kernel/autotune.py 2>&1 | tee out.txt
# Then extract the lines containing "Add to table:" from out.txt, replace that prefix with 4 spaces, and paste the resulting entries into the TUNED_BLOCK_SIZES table in the Pallas kernel.
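# A sketch of that extraction step in Python (hypothetical helper, not used by this
# script; assumes the log was saved to out.txt as in the command above):
#
#   with open('out.txt') as f:
#     for line in f:
#       if 'Add to table:' in line:
#         print('    ' + line.split('Add to table: ', 1)[1], end='')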
import time
from typing import List

import jax
import jax.numpy as jnp
from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
    quantized_matmul_int8,
)

def _quantize_tensor(x, n_bits: int = 8, dim: int = -1):
  max_val = jnp.amax(jnp.abs(x), axis=dim, keepdims=True)
  int_min = -2**(n_bits - 1)
  int_max = 2**(n_bits - 1) - 1
  scale = max_val / int_max
  x_int = jnp.clip(jnp.rint(x / scale), int_min, int_max).astype(jnp.int8)
  return x_int, scale.astype(x.dtype)
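# Note: this is symmetric per-channel int8 quantization. For w of shape
# (n_output_features, n_input_features) and dim=-1, scale has shape
# (n_output_features, 1) and w is approximately w_int * scale.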

def find_factors_multiple_of_128(n: int) -> List[int]:
  """
  Finds all factors of an integer n that are also multiples of 128.

  Args:
    n: The integer for which to find factors.

  Returns:
    A list of integers that are factors of n and are multiples of 128.
    Returns an empty list if n is 0 or if no such factors exist.
    Handles negative input by taking the absolute value.
  """
  # Handle edge case for 0
  if n == 0:
    return []

  # Work with the absolute value for factor finding
  n_abs = abs(n)

  # We are looking for factors f such that n_abs % f == 0 AND f % 128 == 0.
  # This means f must be a multiple of 128.
  # So, we only need to check multiples of 128 up to n_abs.

  factors = []
  multiplier = 1
  potential_factor = 128 * multiplier

  # Iterate through multiples of 128 (128, 256, 384, ...)
  # as long as they are less than or equal to n_abs
  while potential_factor <= n_abs:
    # Check if this multiple of 128 is a factor of n_abs
    if n_abs % potential_factor == 0:
      factors.append(potential_factor)

    # Move to the next multiple of 128
    multiplier += 1
    potential_factor = 128 * multiplier  # Calculate the next potential factor

  return factors
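# Example: find_factors_multiple_of_128(4096) == [128, 256, 512, 1024, 2048, 4096].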

# Benchmarking script starts.

# one off
# batch_sizes = [16, 32]
# out_in_features = [(128, 256), (256, 128)]
# batch_block_sizes = [128, 256]

# for real
batch_sizes = [16, 32, 64, 128, 256, 512, 1024, 2048]
out_in_features = [(6144, 4096), (4096, 4096), (28672, 4096), (4096, 14336), (1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]
batch_block_sizes = [128, 256, 512, 1024, 2048]

for bs in batch_sizes:
  for n_output_features, n_input_features in out_in_features:
    out_block_sizes = find_factors_multiple_of_128(n_output_features)
    in_block_sizes = find_factors_multiple_of_128(n_input_features)
    print(f'Benchmarking w8a8 matmul with bs={bs}, n_output_features={n_output_features}, n_input_features={n_input_features}, for block sizes: {out_block_sizes=}, {in_block_sizes=}')

    dtype = jnp.bfloat16
    prng_key = jax.random.key(1234)
    k0, k1 = jax.random.split(prng_key, 2)
    x = jax.random.normal(k0, (bs, n_input_features), dtype=dtype)
    w = jax.random.normal(k1, (n_output_features, n_input_features), dtype=dtype)
    w_w8a8_jax, scalar_jax = _quantize_tensor(w, n_bits=8, dim=-1)
    scalar_jax = scalar_jax.squeeze()
    assert scalar_jax.shape == (n_output_features,)

    vmem_limit_80_mb = 80 * 1024 * 1024
    best_time = None
    best_batch_block_size = None
    best_out_block_size = None
    best_in_block_size = None
    skip_trial = False
    for batch_block_size in batch_block_sizes:
      if bs < batch_block_size and best_time is not None:
        continue
      for out_block_size in out_block_sizes:
        for in_block_size in in_block_sizes:
          skip_trial = False
          print(f'Benchmarking w8a8 matmul bs={bs}, n_output_features={n_output_features}, n_input_features={n_input_features} with batch_block_size={batch_block_size}, out_block_size={out_block_size}, in_block_size={in_block_size}', flush=True)
          for _ in range(10):  # warming up
            try:
              quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size, vmem_limit_bytes=vmem_limit_80_mb).block_until_ready()
            except Exception as e:
              print(f'Failed to run quantized_matmul with batch_block_size={batch_block_size}, out_block_size={out_block_size}, in_block_size={in_block_size} due to {e}', flush=True)
              skip_trial = True
              break
          if skip_trial:
            continue
          num_iterations = 30
          start_time = time.perf_counter_ns()
          for _ in range(num_iterations):
            quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size, vmem_limit_bytes=vmem_limit_80_mb).block_until_ready()
          end_time = time.perf_counter_ns()
          elapsed_time = (end_time - start_time) / num_iterations
          print(f'Benchmarked w8a8 matmul with batch_block_size={batch_block_size}, out_block_size={out_block_size}, in_block_size={in_block_size}, time={elapsed_time}ns')
          if best_time is None or elapsed_time < best_time:
            best_time = elapsed_time
            best_batch_block_size = batch_block_size
            best_out_block_size = out_block_size
            best_in_block_size = in_block_size
    print(f'Best batch_block_size={best_batch_block_size}, out_block_size={best_out_block_size}, in_block_size={best_in_block_size}, time={best_time}ns')
    print(f'Add to table: (6, {bs}, {n_output_features}, {n_input_features}, \'{jnp.dtype(dtype).name}\', {True}): ({best_batch_block_size}, {best_out_block_size}, {best_in_block_size}),')

# Table key: (tpu_version, bs, n_output_features, n_input_features, dtype, quantize_activation).
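# A minimal sketch of how a tuned entry is consumed (hypothetical; the real lookup
# lives in the kernel module alongside TUNED_BLOCK_SIZES):
#
#   key = (6, bs, n_output_features, n_input_features, jnp.dtype(dtype).name, True)
#   batch_block, out_block, in_block = TUNED_BLOCK_SIZES[key]
#   out = quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True,
#                               batch_block_size=batch_block, out_block_size=out_block,
#                               in_block_size=in_block)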
74 changes: 74 additions & 0 deletions benchmarks/quantized_matmul_kernel/benchmark.py
@@ -0,0 +1,74 @@
import time
from typing import List

import jax
import jax.numpy as jnp
from jax import random
from torch_xla.experimental.pallas_kernels.quantized_matmul_kernel import (
    quantized_matmul_int8,
)
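# Each case below maps a batch size to how many times a matmul with that batch
# size occurs in the modeled workload (bs_nums), together with the weight shapes
# (n_output_features, n_input_features) to benchmark; TP is the assumed tensor
# parallelism degree.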

cases = {
    "TP=1": {
        "bs_nums": {16: 127, 128: 901, 256: 1, 512: 512, 1024: 5, 2048: 846},
        "out_in_features": [(6144, 4096), (4096, 4096), (28672, 4096), (4096, 14336)],
    },
    "TP=8": {
        "bs_nums": {16: 1016, 64: 64, 128: 1440, 512: 8, 1024: 32, 2048: 7032},
        "out_in_features": [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)],
    },
}

# one off, for testing
# cases = {
#     "TP=1": {
#         "bs_nums": {16: 1, 128: 2},
#         "out_in_features": [(128, 128)],
#     },
#     "TP=8": {
#         "bs_nums": {16: 1, 128: 2},
#         "out_in_features": [(256, 256)],
#     },
# }

def _quantize_tensor(x, n_bits: int = 8, dim: int = -1):
  max_val = jnp.amax(jnp.abs(x), axis=dim, keepdims=True)
  int_min = -2**(n_bits - 1)
  int_max = 2**(n_bits - 1) - 1
  scale = max_val / int_max
  x_int = jnp.clip(jnp.rint(x / scale), int_min, int_max).astype(jnp.int8)
  return x_int, scale.astype(x.dtype)

def run_benchmark(bs_nums, out_in_features: List[tuple]):
  elapsed_time_ms = 0
  print(f"Running benchmark with bs_nums: {bs_nums} and out_in_features: {out_in_features}")
  for bs, num_occur in bs_nums.items():
    for n_output_features, n_input_features in out_in_features:
      dtype = jnp.bfloat16
      prng_key = random.key(1234)
      k0, k1 = jax.random.split(prng_key, 2)
      x = jax.random.normal(k0, (bs, n_input_features), dtype=dtype)
      w = jax.random.normal(k1, (n_output_features, n_input_features), dtype=dtype)
      w_w8a8_jax, scalar_jax = _quantize_tensor(w, n_bits=8, dim=-1)
      scalar_jax = scalar_jax.squeeze()
      assert scalar_jax.shape == (n_output_features,)
      num_warmup = 5
      for _ in range(num_warmup):
        quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True).block_until_ready()
      start_time = time.perf_counter_ns()
      for _ in range(num_occur):
        quantized_matmul_int8(x, w_w8a8_jax, scalar_jax, quantize_activation=True).block_until_ready()
      end_time = time.perf_counter_ns()
      elapsed_time_ms += (end_time - start_time) / 1e6
  return elapsed_time_ms

for case, value in cases.items():
  bs_nums = value["bs_nums"]
  out_in_features = value["out_in_features"]
  elapsed_time_ms = run_benchmark(bs_nums, out_in_features)
  print(f"Benchmarking {case} took {elapsed_time_ms:.2f} ms")
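# For context, a bf16 (unquantized) matmul baseline measured the same way could look
# like the sketch below. This is a hypothetical addition for illustration only; it is
# not called by this script and is not part of the kernel benchmark itself.
def run_bf16_baseline(bs_nums, out_in_features: List[tuple]):
  elapsed_time_ms = 0
  # Plain bf16 matmul (x @ w.T), jitted so dispatch overhead is comparable.
  matmul = jax.jit(lambda x, w: jnp.dot(x, w.T))
  for bs, num_occur in bs_nums.items():
    for n_output_features, n_input_features in out_in_features:
      k0, k1 = jax.random.split(random.key(1234), 2)
      x = jax.random.normal(k0, (bs, n_input_features), dtype=jnp.bfloat16)
      w = jax.random.normal(k1, (n_output_features, n_input_features), dtype=jnp.bfloat16)
      for _ in range(5):  # warm up / trigger compilation
        matmul(x, w).block_until_ready()
      start_time = time.perf_counter_ns()
      for _ in range(num_occur):
        matmul(x, w).block_until_ready()
      elapsed_time_ms += (time.perf_counter_ns() - start_time) / 1e6
  return elapsed_time_ms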
2 changes: 2 additions & 0 deletions benchmarks/quantized_matmul_kernel/microbenchmark.py
@@ -0,0 +1,2 @@
# A microbenchmark is planned, but since the kernel is expected to reside in g3, the microbenchmark should also live in g3 and use xprof.
# Therefore no microbenchmark is added here for now.
106 changes: 53 additions & 53 deletions torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py
@@ -212,70 +212,70 @@ def quantized_matmul_int8(
# - out_block_size
# - in_block_size
TUNED_BLOCK_SIZES = {
-    (6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
-    (6, 16, 4096, 4096, 'bfloat16', True): (128, 2048, 2048),
-    (6, 16, 28672, 4096, 'bfloat16', True): (128, 1792, 2048),
-    (6, 16, 4096, 14336, 'bfloat16', True): (128, 512, 7168),
-    (6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
-    (6, 16, 8192, 1024, 'bfloat16', True): (128, 8192, 256),
-    (6, 16, 7168, 8192, 'bfloat16', True): (128, 512, 8192),
+    (6, 16, 6144, 4096, 'bfloat16', True): (128, 6144, 1024),
+    (6, 16, 4096, 4096, 'bfloat16', True): (128, 4096, 1024),
+    (6, 16, 28672, 4096, 'bfloat16', True): (128, 3584, 2048),
+    (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 1792),
+    (6, 16, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 16, 8192, 1024, 'bfloat16', True): (128, 1024, 1024),
+    (6, 16, 7168, 8192, 'bfloat16', True): (128, 7168, 1024),
    (6, 16, 8192, 3584, 'bfloat16', True): (128, 2048, 3584),
    (6, 32, 6144, 4096, 'bfloat16', True): (128, 1536, 4096),
-    (6, 32, 4096, 4096, 'bfloat16', True): (128, 4096, 2048),
-    (6, 32, 28672, 4096, 'bfloat16', True): (128, 2048, 2048),
-    (6, 32, 4096, 14336, 'bfloat16', True): (128, 256, 14336),
-    (6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
-    (6, 32, 8192, 1024, 'bfloat16', True): (128, 1024, 1024),
-    (6, 32, 7168, 8192, 'bfloat16', True): (128, 3584, 2048),
-    (6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 64, 6144, 4096, 'bfloat16', True): (128, 512, 4096),
-    (6, 64, 4096, 4096, 'bfloat16', True): (128, 2048, 2048),
-    (6, 64, 28672, 4096, 'bfloat16', True): (128, 1792, 2048),
+    (6, 32, 4096, 4096, 'bfloat16', True): (128, 4096, 1024),
+    (6, 32, 28672, 4096, 'bfloat16', True): (128, 4096, 2048),
+    (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 1792),
+    (6, 32, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 32, 8192, 1024, 'bfloat16', True): (128, 8192, 256),
+    (6, 32, 7168, 8192, 'bfloat16', True): (128, 1792, 4096),
+    (6, 32, 8192, 3584, 'bfloat16', True): (128, 8192, 896),
+    (6, 64, 6144, 4096, 'bfloat16', True): (128, 1536, 2048),
+    (6, 64, 4096, 4096, 'bfloat16', True): (128, 2048, 1024),
+    (6, 64, 28672, 4096, 'bfloat16', True): (128, 3584, 2048),
    (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 1024),
-    (6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 4096),
-    (6, 64, 8192, 1024, 'bfloat16', True): (128, 4096, 1024),
-    (6, 64, 7168, 8192, 'bfloat16', True): (128, 512, 8192),
+    (6, 64, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
+    (6, 64, 8192, 1024, 'bfloat16', True): (128, 8192, 1024),
+    (6, 64, 7168, 8192, 'bfloat16', True): (128, 3584, 2048),
    (6, 64, 8192, 3584, 'bfloat16', True): (128, 2048, 1792),
-    (6, 128, 6144, 4096, 'bfloat16', True): (128, 1536, 4096),
-    (6, 128, 4096, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 128, 28672, 4096, 'bfloat16', True): (128, 1024, 4096),
-    (6, 128, 4096, 14336, 'bfloat16', True): (128, 2048, 2048),
-    (6, 128, 1280, 8192, 'bfloat16', True): (128, 640, 4096),
+    (6, 128, 6144, 4096, 'bfloat16', True): (128, 6144, 1024),
+    (6, 128, 4096, 4096, 'bfloat16', True): (128, 4096, 2048),
+    (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 512),
+    (6, 128, 4096, 14336, 'bfloat16', True): (128, 2048, 3584),
+    (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 1024),
    (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
-    (6, 128, 7168, 8192, 'bfloat16', True): (128, 896, 8192),
-    (6, 128, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
-    (6, 256, 6144, 4096, 'bfloat16', True): (256, 1536, 4096),
-    (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
-    (6, 256, 28672, 4096, 'bfloat16', True): (256, 896, 4096),
-    (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 2048),
-    (6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 4096),
-    (6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
-    (6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
-    (6, 256, 8192, 3584, 'bfloat16', True): (256, 1024, 3584),
+    (6, 128, 7168, 8192, 'bfloat16', True): (128, 1792, 4096),
+    (6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 896),
+    (6, 256, 6144, 4096, 'bfloat16', True): (256, 3072, 4096),
+    (6, 256, 4096, 4096, 'bfloat16', True): (256, 2048, 4096),
+    (6, 256, 28672, 4096, 'bfloat16', True): (256, 3584, 4096),
+    (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 1792),
+    (6, 256, 1280, 8192, 'bfloat16', True): (256, 1280, 2048),
+    (6, 256, 8192, 1024, 'bfloat16', True): (256, 4096, 1024),
+    (6, 256, 7168, 8192, 'bfloat16', True): (256, 1792, 4096),
+    (6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
    (6, 512, 6144, 4096, 'bfloat16', True): (512, 2048, 4096),
-    (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
-    (6, 512, 28672, 4096, 'bfloat16', True): (512, 1792, 4096),
-    (6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 1792),
+    (6, 512, 4096, 4096, 'bfloat16', True): (512, 4096, 512),
+    (6, 512, 28672, 4096, 'bfloat16', True): (512, 4096, 4096),
+    (6, 512, 4096, 14336, 'bfloat16', True): (512, 4096, 2048),
    (6, 512, 1280, 8192, 'bfloat16', True): (512, 1280, 2048),
    (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
-    (6, 512, 7168, 8192, 'bfloat16', True): (512, 1792, 4096),
-    (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
-    (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 1024, 4096),
-    (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
-    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 3584, 4096),
-    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 2048, 2048),
-    (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 1280, 2048),
-    (6, 1024, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
-    (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 2048, 3584),
+    (6, 512, 7168, 8192, 'bfloat16', True): (512, 7168, 512),
+    (6, 512, 8192, 3584, 'bfloat16', True): (512, 8192, 512),
+    (6, 1024, 6144, 4096, 'bfloat16', True): (512, 6144, 4096),
+    (6, 1024, 4096, 4096, 'bfloat16', True): (256, 4096, 4096),
+    (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
+    (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 4096, 1792),
+    (6, 1024, 1280, 8192, 'bfloat16', True): (512, 1280, 4096),
+    (6, 1024, 8192, 1024, 'bfloat16', True): (512, 8192, 1024),
+    (6, 1024, 7168, 8192, 'bfloat16', True): (512, 7168, 1024),
+    (6, 1024, 8192, 3584, 'bfloat16', True): (256, 8192, 3584),
    (6, 2048, 6144, 4096, 'bfloat16', True): (256, 6144, 4096),
-    (6, 2048, 4096, 4096, 'bfloat16', True): (1024, 2048, 4096),
+    (6, 2048, 4096, 4096, 'bfloat16', True): (512, 4096, 4096),
    (6, 2048, 28672, 4096, 'bfloat16', True): (1024, 4096, 4096),
-    (6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 1024),
-    (6, 2048, 1280, 8192, 'bfloat16', True): (512, 1280, 8192),
+    (6, 2048, 4096, 14336, 'bfloat16', True): (1024, 4096, 2048),
+    (6, 2048, 1280, 8192, 'bfloat16', True): (2048, 1280, 1024),
    (6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
-    (6, 2048, 7168, 8192, 'bfloat16', True): (1024, 1792, 8192),
-    (6, 2048, 8192, 3584, 'bfloat16', True): (2048, 2048, 3584),
+    (6, 2048, 7168, 8192, 'bfloat16', True): (256, 7168, 8192),
+    (6, 2048, 8192, 3584, 'bfloat16', True): (512, 8192, 3584),
}

