From bd56c983d6fe8ff93bddd5faaf8d96e01c90fd83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 28 Feb 2025 18:20:11 -0500 Subject: [PATCH] [torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902) Signed-off-by: luka --- tests/compile/backend.py | 13 +- tests/compile/test_functionalization.py | 8 +- tests/compile/test_fusion.py | 107 +++++++------- vllm/compilation/noop_elimination.py | 135 ++++++++++++++++++ vllm/compilation/pass_manager.py | 8 +- vllm/compilation/reshapes.py | 90 ------------ vllm/compilation/vllm_inductor_pass.py | 18 ++- vllm/config.py | 13 +- .../layers/quantization/utils/w8a8_utils.py | 7 +- 9 files changed, 239 insertions(+), 160 deletions(-) create mode 100644 vllm/compilation/noop_elimination.py delete mode 100644 vllm/compilation/reshapes.py diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 74bc58a2dd..64416eb136 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -13,21 +13,26 @@ class TestBackend: This class provides a simple Inductor backend that can be used for testing. It takes a list of custom passes and runs them after Inductor's passes. It also saves the graph before and after the custom passes for inspection. + + Inductor config can be modified directly by editing the inductor_config + property. This can be helpful for adding passes like the + 'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'. """ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) from torch._inductor import config - self.current_config = config.shallow_copy_dict() - self.current_config['force_disable_caches'] = True - self.current_config['post_grad_custom_post_pass'] = self.post_pass + self.inductor_config = config.shallow_copy_dict() + self.inductor_config['force_disable_caches'] = True + self.inductor_config['post_grad_custom_post_pass'] = self.post_pass def __call__(self, graph: fx.GraphModule, example_inputs): + self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx return compile_fx(graph, example_inputs, - config_patches=self.current_config) + config_patches=self.inductor_config) def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 8f50405226..9f9b2d06b2 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func -from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig from .backend import TestBackend @@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, torch.set_default_device("cuda") config = CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_reshape=True) - reshape_pass = RedundantReshapesPass(config) + enable_noop=True) + noop_pass = NoOpEliminationPass(config) fusion_pass = FusionPass.instance(config) - passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass] + passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass] 
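As an aside for readers using the helper above: the usual pattern is to build the custom passes from a PassConfig, hand them to TestBackend, and compile through it. A minimal sketch under assumptions (TinyModel and the input shape are invented for illustration, and the vLLM repo root is assumed to be on PYTHONPATH so tests/compile/backend.py is importable):

import torch
from torch import nn

from tests.compile.backend import TestBackend
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig


class TinyModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A reshape to the same 2D shape: the kind of no-op the pass removes.
        return x.reshape(-1, x.shape[-1]) + 1


config = CompilationConfig.PassConfig(enable_noop=True)
backend = TestBackend(NoOpEliminationPass(config))

compiled = torch.compile(TinyModel(), backend=backend)
compiled(torch.randn(8, 16))

# TestBackend keeps the FX graphs for inspection:
#   backend.graph_pre_compile - graph captured before compile_fx runs
#   backend.graph_pre_pass    - graph right before the custom passes
#   backend.graph_post_pass   - graph right after the custom passes

The inductor_config attribute can also be edited between construction and compilation, e.g. to register a 'post_grad_custom_pre_pass' as the updated docstring suggests.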
func_pass = FixFunctionalizationPass(config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index c14f0caab5..89abc00176 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,23 +5,25 @@ import torch from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs +import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, FusionPass, QuantKey) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe -from vllm.compilation.reshapes import RedundantReshapesPass -from vllm.config import CompilationConfig +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear) + CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity) from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, static: bool, *args, - **kwargs): + def __init__(self, hidden_size: int, eps: float, static: bool, + cutlass_fp8_enabled: bool, *args, **kwargs): super().__init__(*args, **kwargs) + self.cutlass_fp8_enabled = cutlass_fp8_enabled self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] if static: @@ -41,7 +43,8 @@ class TestModel(torch.nn.Module): self.w[0], self.wscale[0], self.scale[0], - use_per_token_if_dynamic=True) + use_per_token_if_dynamic=True, + cutlass_fp8_supported=self.cutlass_fp8_enabled) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) @@ -49,7 +52,8 @@ class TestModel(torch.nn.Module): self.w[1], self.wscale[1], self.scale[1], - use_per_token_if_dynamic=True) + use_per_token_if_dynamic=True, + cutlass_fp8_supported=self.cutlass_fp8_enabled) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -59,60 +63,67 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("cutlass_fp8_enabled", + [True, False] if CUTLASS_FP8_SUPPORTED else [False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, + cutlass_fp8_enabled): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) + maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True) - reshape_pass = RedundantReshapesPass(config) - fusion_pass = FusionPass.instance(config) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) + with vllm.config.set_current_vllm_config(vllm_config): + # Reshape pass is needed for the fusion pass to work + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_noop=True) + noop_pass = NoOpEliminationPass(config) + fusion_pass = FusionPass.instance(config) - backend = TestBackend(reshape_pass, 
fusion_pass) - model = TestModel(hidden_size, eps, static) + backend = TestBackend(noop_pass, fusion_pass) + model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) - # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) - torch._dynamo.mark_dynamic(x, 0) + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) - result = model(x) + result = model(x) - model2 = torch.compile(model, backend=backend) - result2 = model2(x) + model2 = torch.compile(model, backend=backend) + result2 = model2(x) - # Higher tol for dynamic, even higher for bfloat16 - if static: - ATOL, RTOL = (1e-3, 1e-3) - elif dtype == torch.float16: - ATOL, RTOL = (2e-3, 2e-3) - else: - ATOL, RTOL = (1e-2, 1e-2) + # Higher tol for dynamic, even higher for bfloat16 + if static: + ATOL, RTOL = (1e-3, 1e-3) + elif dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes - # static is per-tensor, dynamic is per-token - key = QuantKey(dtype=FP8_DTYPE, - static=static, - per_tensor=static, - symmetric=True) - rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] - add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] - fp8_quant = QUANT_OPS[key] + # static is per-tensor, dynamic is per-token + key = QuantKey(dtype=FP8_DTYPE, + static=static, + per_tensor=static, + symmetric=True) + rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] + add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] + fp8_quant = QUANT_OPS[key] - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, rms_quant) is None - assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None - find_auto_fn(pre_nodes, fp8_quant) + # In pre-nodes, fp8 quant should be there and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, rms_quant) is None + assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None + find_auto_fn(pre_nodes, fp8_quant) - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, rms_quant) - find_auto_fn(post_nodes, add_rms_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + # In post-nodes, fused kernels should be there and fp8 quant should not + find_auto_fn(post_nodes, rms_quant) + find_auto_fn(post_nodes, add_rms_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py new file mode 100644 index 0000000000..19127e933e --- /dev/null +++ b/vllm/compilation/noop_elimination.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Union + +import torch.fx +from torch import SymInt + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class NoOpEliminationPass(VllmInductorPass): + """ + This is an inductor pass that removes redundant reshape/slice operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. 
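As an aside before the example graphs: the reshape referred to here is the flatten that apply_fp8_linear performs before quantizing, roughly input.view(-1, input.shape[-1]); for an input that is already 2D it is an identity view. A quick eager-mode illustration of that claim (the shapes are arbitrary):

import torch

x = torch.randn(128, 4096, dtype=torch.float16)

# apply_fp8_linear flattens its input to 2D before quantization,
# roughly: input_2d = input.view(-1, input.shape[-1])
x_2d = x.view(-1, x.shape[-1])

# For an already-2D input this is a pure no-op: same shape, same storage.
assert x_2d.shape == x.shape
assert x_2d.data_ptr() == x.data_ptr()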
Additionally, torch internal no-op elimination pass does + not handle certain slice variants. + + Example graph 1: + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Example graph 2: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) + at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...) + out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0) + + Can be replaced with: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) + out: "f16[s0, 4096]" = at[1] + + TODO(luka): This is currently tested in test_fusion, + but separate tests could be good. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_noop_elimination") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if self.all_dims_equivalent(shape, input_shape): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice.Tensor): + input, dim_index, start, end = node.args[:4] + input_shape = input.meta["val"].shape + i_dim = input_shape[dim_index] + + if start == 0 and self.dims_equivalent(end, i_dim): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice_scatter.default): + base, view, dim_index, start, end = node.args[:5] + base_shape = base.meta["val"].shape + view_shape = view.meta["val"].shape + + view_dim = view_shape[dim_index] + + # Check that view fully covers base and the full view is used + # (if the view fully covered the base after slicing but was not + # fully used, we could replace slice_scatter with a simple slice + # but that's a niche case). + if (base_shape == view_shape and start == 0 + and self.dims_equivalent(end, view_dim)): + node.replace_all_uses_with(view) + graph.erase_node(node) + count += 1 + + logger.debug("Removed %s no-op reshapes and slices", count) + self.dump_graph(graph, "after_noop_elimination") + self.end_and_log() + + def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]]): + return all( + self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape/slice + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. 
The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 52f8c3b1ec..b012346c35 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass from .inductor_pass import InductorPass -from .reshapes import RedundantReshapesPass +from .noop_elimination import NoOpEliminationPass logger = init_logger(__name__) @@ -36,7 +36,7 @@ class PostGradPassManager(Parent): The order of the post-grad post-passes is: 1. passes (constructor parameter) - 2. default passes (RedundantReshapesPass, FusionPass) + 2. default passes (NoopEliminationPass, FusionPass) 3. config["post_grad_custom_post_pass"] (if it exists) 4. fix_functionalization This way, all passes operate on a functionalized graph. @@ -54,8 +54,8 @@ class PostGradPassManager(Parent): def configure(self, pass_config: CompilationConfig.PassConfig): self.pass_config = pass_config - if pass_config.enable_reshape: - self.passes += [RedundantReshapesPass(pass_config)] + if pass_config.enable_noop: + self.passes += [NoOpEliminationPass(pass_config)] if pass_config.enable_fusion: self.passes += [FusionPass.instance(pass_config)] diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py deleted file mode 100644 index 292baae852..0000000000 --- a/vllm/compilation/reshapes.py +++ /dev/null @@ -1,90 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -import torch.fx -from torch import SymInt - -from vllm.logger import init_logger - -from .fx_utils import is_func -from .vllm_inductor_pass import VllmInductorPass - -logger = init_logger(__name__) - - -class RedundantReshapesPass(VllmInductorPass): - """ - This is an inductor pass that removes redundant reshape operations. - It is required for RMSNorm-quant fusion to work properly. - That's because apply_fp8_linear adds a reshape, which is redundant - in the 2D-case. - - Example graph: - - getitem_1: "f16[s0, 4096]" = ... - view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) - at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) - out: "f8e4m3fn[s0, 4096]" = at[1] - - Can be replaced with: - getitem_1: "f16[s0, 4096]" = ... - at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) 
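Stepping out of the deleted-file listing for a moment: the dims_equivalent helper shown above can be exercised directly for its integer cases; case 3 needs a traced FX node, so it is only noted in a comment. A small sketch, assuming only that the class is importable as introduced in this patch:

from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig

noop = NoOpEliminationPass(CompilationConfig.PassConfig(enable_noop=True))

assert noop.dims_equivalent(4096, 4096)      # case 1: equal integers
assert noop.dims_equivalent(-1, 4096)        # case 2: -1 (inferred) matches anything
assert not noop.dims_equivalent(2048, 4096)  # mismatched integers are not equivalent
# Case 3 (a torch.fx.Node whose meta["val"] is the same SymInt as the input dim)
# needs a traced graph; per the TODO above it is exercised indirectly via test_fusion.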
- out: "f8e4m3fn[s0, 4096]" = at[1] - """ - - def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_reshapes") - count = 0 - # Remove no-op reshapes/views: - for node in graph.nodes: - if is_func(node, torch.ops.aten.reshape.default): - input, shape = node.args[:2] - input_shape = input.meta["val"].shape - if len(shape) != len(input_shape): - # Reshape changing rank, skip - continue - - if shape.count(-1) > 1: - # Invalid reshape args, skip - continue - - if all( - self.dims_equivalent(s, i_s) - for s, i_s in zip(shape, input_shape)): - node.replace_all_uses_with(input) - graph.erase_node(node) - count += 1 - - logger.debug("Removed %s no-op reshapes", count) - - self.dump_graph(graph, "after_reshapes") - self.end_and_log() - - def dims_equivalent(self, dim: Union[int, torch.fx.Node], - i_dim: Union[int, SymInt]) -> bool: - """ - This function checks if two dimensions are equivalent. - :param dim: The dimension arg to reshape - :param i_dim: The corresponding dimension in the input tensor - :return: Are the dimensions equivalent? - - There are three cases in which the dimensions are equivalent: - 1. The dimensions are equal (both integers) - 2. The reshape dimension is -1 (i.e. inferred) - 3. The dimensions both correspond to the same SymInt - - While case 2 does not guarantee the dimensions are equal, - they are equal if all other dimensions are equal. - - In case 3, the reshape dimension is a torch.fx.Node, - and its value is a SymInt. That value is equal to the - input dimension. - - """ - # Case 1 and 2 - if dim == i_dim or dim == -1: - return True - # Case 3 - return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 1d2597e427..98ed6f1472 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -28,8 +28,8 @@ class VllmInductorPass(InductorPass): self.config = config self.pass_name = self.__class__.__name__ - def dump_graph(self, graph: torch.fx.Graph, stage: str): - if stage in self.config.dump_graph_stages: + def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + if stage in self.config.dump_graph_stages or always: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" @@ -49,3 +49,17 @@ class VllmInductorPass(InductorPass): self._end_time = time.perf_counter_ns() duration_ms = float(self._end_time - self._start_time) / 1.0e6 logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class PrinterInductorPass(VllmInductorPass): + + def __init__(self, + name: str, + config: CompilationConfig.PassConfig, + always=False): + super().__init__(config) + self.name = name + self.always = always + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, self.name, always=self.always) diff --git a/vllm/config.py b/vllm/config.py index 78d02b0173..c710847344 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel): Each pass defines its own stages (before, after, maybe in-between). - dump_graph_dir: directory to dump the graphs. Default is . - enable_fusion: whether to enable the custom fusion pass. - - enable_reshape: whether to enable the custom reshape elimination pass. - TODO better pass enabling system. + - enable_noop: whether to enable the custom no-op elimination pass. 
+ TODO(luka) better pass enabling system. """ dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True - enable_reshape: bool = True + enable_noop: bool = True def uuid(self): """ @@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel): Do not include dump_graph_* in the hash - they don't affect compilation. """ - dict_ = self.model_dump( - include={"enable_fusion", "enable_reshape"}) + dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).digest() def model_post_init(self, __context: Any) -> None: - if not self.enable_reshape and self.enable_fusion: + if not self.enable_noop and self.enable_fusion: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "RMSNorm + quant (fp8) fusion might not work") @@ -3411,7 +3410,7 @@ class VllmConfig: self.compilation_config.use_inductor = True self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False - self.compilation_config.pass_config.enable_reshape = False + self.compilation_config.pass_config.enable_noop = False self.compilation_config.level = CompilationLevel.PIECEWISE self._set_cudagraph_sizes() diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 0f93b7f6c4..8072f30776 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops +from vllm.config import CompilationLevel, get_current_vllm_config from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting @@ -161,10 +162,14 @@ def apply_fp8_linear( # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. + config = get_current_vllm_config().compilation_config + do_pad = config.level < CompilationLevel.PIECEWISE qinput, x_scale = ops.scaled_fp8_quant( input_2d, input_scale, - num_token_padding=17, + num_token_padding=17 if do_pad else None, use_per_token_if_dynamic=use_per_token_if_dynamic) per_tensor_weights = (weight_scale.numel() == 1)
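Tying the w8a8_utils change back to the test: the padding branch consults the currently-set vLLM config, so the non-cutlass fp8 path skips padding only when a PIECEWISE-level config is active. A rough sketch of just that decision, mirroring how the test installs a config (it checks the flag only and does not run a matmul):

from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         get_current_vllm_config, set_current_vllm_config)

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))

with set_current_vllm_config(vllm_config):
    level = get_current_vllm_config().compilation_config.level
    do_pad = level < CompilationLevel.PIECEWISE
    # Under PIECEWISE compilation padding is skipped so shapes stay dynamic for
    # torch.compile; at lower levels scaled_fp8_quant still pads the token
    # dimension (num_token_padding=17).
    assert not do_pad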