[FP8] Extend per-token-group quantization support to QuantFP8 (#24342)

Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Authored by: Tahsin Tunan, 2025-09-17 07:31:06 +06:00
Committed by: GitHub
Parent: 493b10f8bf
Commit: cef32104b4
5 changed files with 444 additions and 61 deletions

View File

@@ -2,14 +2,25 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import itertools
 from typing import Callable
+from unittest.mock import patch

+import pandas as pd
 import torch

-from vllm import _custom_ops as ops
-from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.triton_utils import triton
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+
+
+def with_triton_mode(fn):
+    """Temporarily force the Triton fallback path"""
+
+    def wrapped(*args, **kwargs):
+        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+            return fn(*args, **kwargs)
+
+    return wrapped


 # TODO(luka): use standalone_compile utility
@@ -21,78 +32,236 @@ def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
     return inner


-torch._dynamo.config.recompile_limit = 8888
-compilation_config = CompilationConfig(custom_ops=["none"])
-with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
-    torch_per_token_quant_fp8 = torch.compile(
-        QuantFP8(False, GroupShape.PER_TOKEN),
-        fullgraph=True,
-        dynamic=False,  # recompile for different shapes
-    )
+def bench_compile(fn: Callable):
+    # recompile for different shapes
+    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

-# First dim is explicitly dynamic to simulate vLLM usage
-torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
+    # First dim is explicitly dynamic to simulate vLLM usage
+    return with_dyn_arg(fwd, 0, 0)


-def cuda_per_token_quant_fp8(
-    input: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input)
+torch._dynamo.config.recompile_limit = 8888


-def calculate_diff(batch_size: int, seq_len: int):
-    """Calculate difference between Triton and CUDA implementations."""
+def calculate_diff(
+    batch_size: int,
+    hidden_size: int,
+    group_shape: GroupShape,
+    dtype: torch.dtype,
+):
+    """Calculate the difference between Inductor and CUDA implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
+    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

-    torch_out, torch_scale = torch_per_token_quant_fp8(x)
-    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
+    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
+    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
+    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

-    if torch.allclose(
-        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
-    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
+    out_allclose = lambda o1, o2: torch.allclose(
+        o1.to(torch.float32),
+        o2.to(torch.float32),
+        rtol=1e-3,
+        atol=1e-5,
+    )
+    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
+
+    if (
+        out_allclose(cuda_out, torch_out)
+        and scale_allclose(cuda_scale, torch_scale)
+        and out_allclose(cuda_out, torch_eager_out)
+        and scale_allclose(cuda_scale, torch_eager_scale)
+    ):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")


-batch_size_range = [1, 16, 32, 64, 128]
-seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-
-configs = list(itertools.product(batch_size_range, seq_len_range))
+configs = []


-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len"],
-        x_vals=configs,
-        line_arg="provider",
-        line_vals=["torch", "cuda"],
-        line_names=["Torch", "CUDA"],
-        styles=[("blue", "-"), ("green", "-")],
-        ylabel="us",
-        plot_name="per-token-dynamic-quant-fp8-performance",
-        args={},
-    )
-)
-def benchmark_quantization(batch_size, seq_len, provider):
-    dtype = torch.float16
+def benchmark_quantization(
+    batch_size,
+    hidden_size,
+    provider,
+    group_shape: GroupShape,
+    col_major: bool,
+    dtype: torch.dtype,
+):
     device = torch.device("cuda")

-    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

     quantiles = [0.5, 0.2, 0.8]
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

     if provider == "torch":
-        fn = lambda: torch_per_token_quant_fp8(x.clone())
+        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
     elif provider == "cuda":
-        fn = lambda: cuda_per_token_quant_fp8(x.clone())
+        fn = lambda: quant_fp8.forward_cuda(x.clone())
+    elif provider == "triton":
+        if not group_shape.is_per_group():
+            # Triton only supported for per-group
+            return 0, 0, 0
+
+        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

     ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms


+# TODO(luka) extract to utils
+def compute_geomean_speedups(
+    df: pd.DataFrame,
+    baseline_col: str,
+    speedup_cols: list[str],
+    groupby_cols: list[str] | None = None,
+) -> pd.DataFrame:
+    """
+    Compute geometric mean speedups over a baseline column.
+
+    Args:
+        df: Input dataframe
+        baseline_col: Column to use as baseline
+        speedup_cols: Columns to compute speedups for
+        groupby_cols: Columns to group by. If None, compute over entire df.
+
+    Returns:
+        pd.DataFrame with geometric mean speedups
+    """
+    from scipy.stats import gmean
+
+    def geo_speedup(group: pd.DataFrame) -> pd.Series:
+        ratios = {
+            col: (group[baseline_col] / group[col]).values for col in speedup_cols
+        }
+        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})
+
+    if groupby_cols is None:
+        result = geo_speedup(df).to_frame().T
+    else:
+        result = (
+            df.groupby(groupby_cols)
+            .apply(geo_speedup, include_groups=False)
+            .reset_index()
+        )
+
+    return result
+
+
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=4096)
-    benchmark_quantization.run(print_data=True)
+    parser = FlexibleArgumentParser(
+        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
+    )
+    parser.add_argument("-c", "--check", action="store_true")
+    parser.add_argument(
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+    )
+    parser.add_argument(
+        "--hidden-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+    )
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+    )
+    parser.add_argument(
+        "--group-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Group sizes for GroupShape(1,N) to benchmark. "
+        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
+    )
+    parser.add_argument(
+        "--no-column-major",
+        action="store_true",
+        help="Disable column-major scales testing",
+    )
+
+    args = parser.parse_args()
+    assert args
+
+    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
+
+    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
+    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+
+    if args.group_sizes is not None:
+        group_shapes = []
+        for size in args.group_sizes:
+            if size == 0:
+                group_shapes.append(GroupShape.PER_TENSOR)
+            elif size == -1:
+                group_shapes.append(GroupShape.PER_TOKEN)
+            else:
+                group_shapes.append(GroupShape(1, size))
+    else:
+        group_shapes = [
+            GroupShape.PER_TENSOR,
+            GroupShape.PER_TOKEN,
+            GroupShape(1, 64),
+            GroupShape(1, 128),
+        ]
+
+    column_major_scales = [False] if args.no_column_major else [True, False]
+
+    config_gen = itertools.product(
+        group_shapes,
+        column_major_scales,
+        batch_sizes,
+        hidden_sizes,
+    )
+
+    # filter out column-major scales for non-group, reverse order
+    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))
+
+    print(f"Running {len(configs)} configurations:")
+    print(f"  Hidden sizes: {hidden_sizes}")
+    print(f"  Batch sizes: {batch_sizes}")
+    print(f"  Group shapes: {[str(g) for g in group_shapes]}")
+    print(f"  Column major scales: {column_major_scales}")
+    print()
+
+    if args.check:
+        for group_shape in group_shapes:
+            group_size = group_shape[1]
+            print(f"{group_size=}")
+            calculate_diff(
+                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
+            )
+
+    benchmark = triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["torch", "cuda", "triton"],
+            line_names=["Torch (Compiled)", "CUDA", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
+            ylabel="us",
+            plot_name="QuantFP8 performance",
+            args={},
+        )
+    )(benchmark_quantization)
+
+    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)
+
+    # Print geomean speedups
+    geo_table_grouped = compute_geomean_speedups(
+        df,
+        baseline_col="Torch (Compiled)",
+        speedup_cols=["CUDA", "Triton"],
+        groupby_cols=["col_major", "group_shape"],
+    )
+
+    print("Speedup over Torch (Compiled)")
+    print(geo_table_grouped.to_string(index=False))
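
As a usage sketch only (not part of the diff): the perf_report wrapper above calls benchmark_quantization once per (hidden_size, batch_size, col_major, group_shape) tuple in configs, so a single point of the sweep can be reproduced by calling the undecorated function directly, assuming this runs inside the script on a CUDA machine with Triton available:

# Hypothetical single-configuration run; returned values are already in microseconds,
# in the order (median, max, min) as produced by benchmark_quantization above.
if torch.cuda.is_available():
    ms, max_ms, min_ms = benchmark_quantization(
        batch_size=16,
        hidden_size=64,
        provider="cuda",
        group_shape=GroupShape(1, 128),
        col_major=False,
        dtype=torch.float16,
    )
    print(f"median {ms:.2f} us (min {min_ms:.2f}, max {max_ms:.2f})")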

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "batch_size,hidden_dim,group_size",
    [
        (16, 256, 32),  # Small
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
                                      group_size: int, seed: int) -> None:
    """Test QuantFP8 group quantization with various configurations.

    Tests both CUDA and native implementations, column-major scales,
    and verifies consistency between implementations.
    """
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    expected_num_groups = (hidden_dim + group_size - 1) // group_size
    is_divisible = hidden_dim % group_size == 0

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    # 1. Test native implementation (always available)
    x_quant_native, scales_native = quant_op.forward_native(x.clone())
    assert x_quant_native.shape == x.shape
    assert scales_native.shape == (batch_size, expected_num_groups)

    # 2. Test column-major scales configuration
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x.clone())
    assert scales_col.shape == (expected_num_groups, batch_size)

    # 3. Test CUDA implementation (only for divisible dimensions)
    if is_divisible:
        x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
        assert x_quant_cuda.shape == x.shape
        assert scales_cuda.shape == (batch_size, expected_num_groups)

        # Verify CUDA/native consistency
        assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

        # Quantized values should mostly match
        diff_count = (x_quant_cuda != x_quant_native).sum().item()
        diff_ratio = diff_count / x_quant_cuda.numel()
        assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
    current_platform.seed_everything(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 512
    x_3d = torch.randn(
        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
                       dtype=torch.bfloat16,
                       device="cuda") * 8

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3,
                               hidden_dim // group_size)

    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
                                   batch3)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
    current_platform.seed_everything(seed)

    batch_size = 16
    group_size = 64

    # Test with single group (group_size >= hidden_dim)
    x_small = torch.randn(
        (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
    assert x_quant_small.shape == x_small.shape
    assert scales_small.shape == (batch_size, 1)

    # Test with zero inputs
    x_zero = torch.zeros((batch_size, 256),
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
    assert x_quant_zero.shape == x_zero.shape
    assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

    # Test very large values
    x_large = torch.full((batch_size, 256),
                         1000.0,
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
    assert x_quant_large.shape == x_large.shape

    # FP8 max is typically 448 or 224, so scales should be > 1
    assert (scales_large > 1.0).all(), "Large values should have scales > 1"
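
The tests above check shapes and CUDA/native agreement; as a supplementary hand-check of the underlying math (a sketch, assuming FP8 E4M3 with a maximum representable value of 448.0 and a CUDA device), each group's scale should simply be the group's absolute maximum divided by the FP8 maximum:

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Assumed value of torch.finfo(torch.float8_e4m3fn).max; some ROCm builds use 224.
FP8_MAX = 448.0

x = torch.randn((4, 128), dtype=torch.bfloat16, device="cuda")
quant_op = QuantFP8(static=False,
                    group_shape=GroupShape(1, 64),
                    column_major_scales=False)
_, scales = quant_op.forward_native(x.clone())

# Reference: per-group absmax / FP8_MAX (the clamp floor only matters near zero).
ref = x.view(4, 2, 64).abs().amax(dim=-1).float() / FP8_MAX
assert torch.allclose(scales, ref, rtol=1e-6, atol=1e-8)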

View File

@@ -32,9 +32,11 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+    _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform

View File

@@ -23,28 +23,39 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
 @CustomOp.register("quant_fp8")
 class QuantFP8(CustomOp):
     """
-    Quantize input tensor to per-tensor or per-token FP8.
+    Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
     This CustomOp supports both static and dynamic quantization.
     """

     def __init__(self,
                  static: bool,
                  group_shape: GroupShape,
-                 num_token_padding: Optional[int] = None):
+                 num_token_padding: Optional[int] = None,
+                 column_major_scales: bool = False):
         """
         :param static: static or dynamic quantization
-        :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR)
-        :param num_token_padding: Pad the token dimension of output to this size
+        :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
+            or arbitrary block size)
+        :param num_token_padding: Pad the token dimension of output to this
+            size
+        :param column_major_scales: For group quantization, output scales in
+            column major format
         """
         super().__init__()
-        self.num_token_padding = num_token_padding
-        assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
-        assert not static or group_shape == GroupShape.PER_TENSOR, \
-            "Only per-tensor scales supported for static quantization."
         self.static = static
         self.group_shape = group_shape
-        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
+        self.num_token_padding = num_token_padding
+        self.column_major_scales = column_major_scales
+
+        self.is_group_quant = group_shape.is_per_group()
+        if self.is_group_quant:
+            assert not static, "Group quantization only supports dynamic mode"
+            self.group_size = group_shape.col
+        else:
+            assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
+            assert not static or group_shape == GroupShape.PER_TENSOR, \
+                "Only per-tensor scales supported for static quantization."
+            self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN

     def forward_cuda(
         self,
@@ -52,11 +63,19 @@ class QuantFP8(CustomOp):
         scale: Optional[torch.Tensor] = None,
         scale_ub: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.is_group_quant:
+            assert scale is None, "Group quantization is always dynamic"
+            from vllm.model_executor.layers.quantization.utils import fp8_utils
+            return fp8_utils.per_token_group_quant_fp8(
+                x,
+                group_size=self.group_size,
+                column_major_scales=self.column_major_scales,
+                dtype=_FP8_DTYPE)
+
         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static and self.group_shape
                                     == GroupShape.PER_TOKEN
                                     and scale_ub.numel() == 1)

         return ops.scaled_fp8_quant(
             x,
             scale,
@@ -70,6 +89,10 @@ class QuantFP8(CustomOp):
         scale: Optional[torch.Tensor] = None,
         scale_ub: Optional[torch.Tensor] = None,
     ):
+        if self.is_group_quant:
+            assert scale is None, "Group quantization is always dynamic"
+            return self._quantize_group_native(x)
+
         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static and self.group_shape
                                     == GroupShape.PER_TOKEN
@@ -84,8 +107,7 @@ class QuantFP8(CustomOp):
             else:
                 x_max = x.abs().max().unsqueeze(-1).to(torch.float32)

-            scale = x_max / _FP8_MAX
-            scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR)
+            scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)

         # Even for dynamic per-token scales,
         # reciprocal performs slightly better than division
@@ -101,3 +123,34 @@
             out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)

         return out, scale
+
+    def _quantize_group_native(
+            self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        orig_shape = x.shape
+        hidden_dim = x.shape[-1]
+        num_groups = (hidden_dim + self.group_size - 1) // self.group_size
+        padded_dim = num_groups * self.group_size
+
+        if padded_dim != hidden_dim:
+            padding = padded_dim - hidden_dim
+            x = F.pad(x, (0, padding), mode='constant', value=0.0)
+
+        x_grouped = x.view(-1, num_groups, self.group_size)
+        absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
+        scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+        x_scaled = x_grouped / scales
+        x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
+
+        x_quant = x_quant.view(-1, padded_dim)
+        if padded_dim != hidden_dim:
+            x_quant = x_quant[..., :hidden_dim]
+        x_quant = x_quant.view(orig_shape)
+
+        scales = scales.squeeze(-1)
+        scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
+
+        if self.column_major_scales:
+            scales = scales.transpose(-2, -1).contiguous()
+
+        return x_quant, scales
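
A minimal usage sketch of the new per-group path (not part of the diff; assumes a CUDA device and an FP8 E4M3 build). forward_cuda delegates to per_token_group_quant_fp8, while forward_native uses the pure-PyTorch fallback shown above:

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

quant = QuantFP8(static=False,
                 group_shape=GroupShape(1, 128),
                 column_major_scales=True)

x = torch.randn((32, 1024), dtype=torch.bfloat16, device="cuda")
x_fp8, scales = quant.forward_native(x)

print(x_fp8.shape)   # torch.Size([32, 1024]), FP8 dtype
print(scales.shape)  # torch.Size([8, 32]); transposed because column_major_scales=True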

View File

@@ -34,6 +34,15 @@ class GroupShape(_GroupShape):
     PER_TENSOR: ClassVar['GroupShape']
     PER_TOKEN: ClassVar['GroupShape']

+    def is_per_tensor(self) -> bool:
+        return self.row == -1 and self.col == -1
+
+    def is_per_token(self) -> bool:
+        return self.row == 1 and self.col == -1
+
+    def is_per_group(self) -> bool:
+        return self.row == 1 and self.col >= 1
+

 GroupShape.PER_TENSOR = GroupShape(-1, -1)
 GroupShape.PER_TOKEN = GroupShape(1, -1)
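
A quick illustration (a sketch, not part of the diff) of how the new predicates classify the sentinel shapes defined above versus an explicit per-group shape:

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

assert GroupShape.PER_TENSOR.is_per_tensor()
assert GroupShape.PER_TOKEN.is_per_token()
assert GroupShape(1, 128).is_per_group()
assert not GroupShape(1, 128).is_per_token()
assert not GroupShape.PER_TOKEN.is_per_group()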