[3/n][Optimus][Auto-AC][reland] Support any fp8 quantization type and set scaling as the default" (#154057)

Summary:
This is a reland of D74910193.
We change the dtype to torch.float8_e5m2 in unit test since it is not supported.

Test Plan:
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization
```

Differential Revision: D75169792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154057
Approved by: https://github.com/Mingming-Ding
This commit is contained in:
Menglu Yu
2025-05-22 18:26:31 +00:00
committed by PyTorch MergeBot
parent c2660d29a5
commit 788d9cb2d7
2 changed files with 97 additions and 31 deletions

View File

@ -47,6 +47,21 @@ class FeedforwardNN(torch.nn.Module):
return x
class LayernormNN(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, normalized_shape, weight, bias):
x = torch.nn.functional.layer_norm(
input=input,
normalized_shape=normalized_shape,
weight=weight,
bias=bias,
eps=1e-5,
)
return x
class TestQuantization(TestCase):
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
@ -106,10 +121,11 @@ class TestQuantization(TestCase):
"use_scaling": True,
"size_in_mb": 0.0,
"exclude_primals": True,
"allowed_dtypes": "torch.bfloat16;torch.float32",
},
},
)
def test_activation_quantization_aten(self):
def test_activation_quantization_aten_with_scaling(self):
counters.clear()
module = TargetCPModule().to(GPU_TYPE)
input = [
@ -159,6 +175,59 @@ class TestQuantization(TestCase):
self.assertTrue(torch.allclose(ref, res))
counters.clear()
@requires_gpu()
@torch._inductor.config.patch(
pre_grad_fusion_options={},
post_grad_fusion_options={
"activation_quantization_aten_pass": {
"quant_type": "torch.float8_e5m2",
"use_scaling": False,
"size_in_mb": 0.0,
"exclude_primals": True,
"allowed_dtypes": "torch.bfloat16;torch.float32",
},
},
)
def test_activation_quantization_aten_without_scaling(self):
counters.clear()
module = LayernormNN().to(GPU_TYPE)
normalized_shape = [256]
input = [
torch.randn(
(1, 3, 256), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
),
normalized_shape,
torch.randn(
*normalized_shape,
requires_grad=True,
device=GPU_TYPE,
dtype=torch.bfloat16,
),
torch.randn(
*normalized_shape,
requires_grad=True,
device=GPU_TYPE,
dtype=torch.bfloat16,
),
]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
self.compare_pred(module, traced, input)
ref.sum().backward()
res.sum().backward()
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
self.assertEqual(
counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
)
self.assertEqual(
counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
)
self.assertTrue(torch.allclose(ref, res))
counters.clear()
if __name__ == "__main__":
if IS_LINUX and HAS_GPU:

View File

@ -379,7 +379,7 @@ def calculate_quantization_scaling(
scale_node = graph.call_function(
torch.ops.prims.convert_element_type.default,
args=(mul_node, torch.float32),
name="scale_" + str(node.name),
name="fp8_scale_" + str(node.name),
)
scale_node.meta["val"] = torch.ops.prims.convert_element_type.default(
mul_node.meta["val"], torch.float32
@ -444,7 +444,7 @@ def perform_quantization(
quant_activation_node = graph.call_function(
torch.ops.prims.convert_element_type.default,
args=(clamp_max_scaled_node, quant_type),
name="quant_" + str(node.name),
name="fp8_quant_" + str(node.name),
)
quant_activation_node.meta[
"val"
@ -513,13 +513,8 @@ def calculate_range(dtype: torch.dtype) -> tuple:
Returns:
tuple: A tuple containing the minimum and maximum values.
"""
if dtype == torch.float8_e5m2:
# 8-bit floating-point format with e5m2 layout
min_val = -57344.0
max_val = 57344.0
else:
raise ValueError(f"Unsupported dtype: {dtype}")
return min_val, max_val
info = torch.finfo(dtype)
return info.min, info.max
def quantize_activation_fw(graph: torch.fx.Graph) -> None:
@ -535,7 +530,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
# case: use scaling
if torch._inductor.config.post_grad_fusion_options[
"activation_quantization_aten_pass"
].get("use_scaling", False):
].get("use_scaling", True):
# calculating the scale
scale_node = calculate_quantization_scaling(
graph, node, clamp_max, 1e-12
@ -554,7 +549,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
quant_node = graph.call_function(
torch.ops.prims.convert_element_type.default,
args=(node, quant_type),
name="quant_" + str(node.name),
name="fp8_quant_" + str(node.name),
)
quant_node.meta[
"val"
@ -595,12 +590,13 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
# case: use scaling
with graph.inserting_after(node):
# find corresponding scale node
scale_name = "scale_" + node.name.replace("quant_", "")
scale_name = "fp8_scale_" + node.name.replace("fp8_quant_", "")
scale_node = next(
bwd_input
for bwd_input in bw_inputs
if bwd_input.name == scale_name
)
with graph.inserting_after(scale_node):
activation_node = graph.call_function(
torch.ops.prims.convert_element_type.default,
args=(node, dequant_type),
@ -613,7 +609,7 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
activation_node.meta["tensor_meta"] = extract_tensor_metadata(
activation_node.meta["val"]
)
with graph.inserting_after(scale_node):
with graph.inserting_after(activation_node):
divided_target_node_32 = graph.call_function(
torch.ops.aten.div.Tensor,
args=(activation_node, scale_node),
@ -725,11 +721,22 @@ def enable_activation_quantization(
),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "before_activation_quantization_bwd_aten_pass",
"encoding": "string",
},
payload_fn=lambda: bwd_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
quant_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
# update the corresponding bwd_inputs due to the fwd_outputs quantization
for fwd_node in quant_fwd_module_outputs:
if "quant_" in fwd_node.name:
bwd_input = bwd_module_inputs[fwd_node.name.replace("quant_", "")]
if "fp8_quant_" in fwd_node.name:
bwd_input = bwd_module_inputs[fwd_node.name.replace("fp8_quant_", "")]
with bwd_module.graph.inserting_after(bwd_input):
quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
dequant_type = bwd_input.meta["dequant_type"]
@ -741,34 +748,24 @@ def enable_activation_quantization(
# update the bwd_inputs if quantization with scaling is used
if torch._inductor.config.post_grad_fusion_options[
"activation_quantization_aten_pass"
].get("use_scaling", False):
].get("use_scaling", True):
quant_bwd_module_inputs = list(bwd_module.graph.find_nodes(op="placeholder"))
# update the corresponding bwd input nodes find the last non-tangent node
bwd_input_loc = list(bwd_module_inputs.values())[-1]
for bw_input in reversed(bwd_module_inputs.values()):
bwd_input_loc = quant_bwd_module_inputs[-1]
for bw_input in reversed(quant_bwd_module_inputs):
if not _is_tangent(bw_input):
bwd_input_loc = bw_input
break
scaled_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
for fwd_node in scaled_fwd_module_outputs:
if "scale_" in fwd_node.name:
if "fp8_scale_" in fwd_node.name:
# fwd node is a scale node
with bwd_module.graph.inserting_after(bwd_input_loc):
scale_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
scale_bwd_input.meta.update(fwd_node.meta)
bwd_input_loc = scale_bwd_input
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "before_activation_quantization_bwd_aten_pass",
"encoding": "string",
},
payload_fn=lambda: bwd_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
quantize_activation_bw(bwd_module.graph)
trace_structured(