mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[3/n][Optimus][Auto-AC][reland] Support any fp8 quantization type and set scaling as the default" (#154057)
Summary: This is a reland of D74910193. We change the dtype to torch.float8_e5m2 in unit test since it is not supported. Test Plan: ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization ``` Differential Revision: D75169792 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154057 Approved by: https://github.com/Mingming-Ding
This commit is contained in:
committed by
PyTorch MergeBot
parent
c2660d29a5
commit
788d9cb2d7
@ -47,6 +47,21 @@ class FeedforwardNN(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class LayernormNN(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input, normalized_shape, weight, bias):
|
||||
x = torch.nn.functional.layer_norm(
|
||||
input=input,
|
||||
normalized_shape=normalized_shape,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
eps=1e-5,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class TestQuantization(TestCase):
|
||||
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
|
||||
if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
|
||||
@ -106,10 +121,11 @@ class TestQuantization(TestCase):
|
||||
"use_scaling": True,
|
||||
"size_in_mb": 0.0,
|
||||
"exclude_primals": True,
|
||||
"allowed_dtypes": "torch.bfloat16;torch.float32",
|
||||
},
|
||||
},
|
||||
)
|
||||
def test_activation_quantization_aten(self):
|
||||
def test_activation_quantization_aten_with_scaling(self):
|
||||
counters.clear()
|
||||
module = TargetCPModule().to(GPU_TYPE)
|
||||
input = [
|
||||
@ -159,6 +175,59 @@ class TestQuantization(TestCase):
|
||||
self.assertTrue(torch.allclose(ref, res))
|
||||
counters.clear()
|
||||
|
||||
@requires_gpu()
|
||||
@torch._inductor.config.patch(
|
||||
pre_grad_fusion_options={},
|
||||
post_grad_fusion_options={
|
||||
"activation_quantization_aten_pass": {
|
||||
"quant_type": "torch.float8_e5m2",
|
||||
"use_scaling": False,
|
||||
"size_in_mb": 0.0,
|
||||
"exclude_primals": True,
|
||||
"allowed_dtypes": "torch.bfloat16;torch.float32",
|
||||
},
|
||||
},
|
||||
)
|
||||
def test_activation_quantization_aten_without_scaling(self):
|
||||
counters.clear()
|
||||
|
||||
module = LayernormNN().to(GPU_TYPE)
|
||||
normalized_shape = [256]
|
||||
input = [
|
||||
torch.randn(
|
||||
(1, 3, 256), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
|
||||
),
|
||||
normalized_shape,
|
||||
torch.randn(
|
||||
*normalized_shape,
|
||||
requires_grad=True,
|
||||
device=GPU_TYPE,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
torch.randn(
|
||||
*normalized_shape,
|
||||
requires_grad=True,
|
||||
device=GPU_TYPE,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
]
|
||||
traced = torch.compile(module)
|
||||
ref = module(*input)
|
||||
res = traced(*input)
|
||||
self.compare_pred(module, traced, input)
|
||||
ref.sum().backward()
|
||||
res.sum().backward()
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
|
||||
)
|
||||
self.assertTrue(torch.allclose(ref, res))
|
||||
counters.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if IS_LINUX and HAS_GPU:
|
||||
|
@ -379,7 +379,7 @@ def calculate_quantization_scaling(
|
||||
scale_node = graph.call_function(
|
||||
torch.ops.prims.convert_element_type.default,
|
||||
args=(mul_node, torch.float32),
|
||||
name="scale_" + str(node.name),
|
||||
name="fp8_scale_" + str(node.name),
|
||||
)
|
||||
scale_node.meta["val"] = torch.ops.prims.convert_element_type.default(
|
||||
mul_node.meta["val"], torch.float32
|
||||
@ -444,7 +444,7 @@ def perform_quantization(
|
||||
quant_activation_node = graph.call_function(
|
||||
torch.ops.prims.convert_element_type.default,
|
||||
args=(clamp_max_scaled_node, quant_type),
|
||||
name="quant_" + str(node.name),
|
||||
name="fp8_quant_" + str(node.name),
|
||||
)
|
||||
quant_activation_node.meta[
|
||||
"val"
|
||||
@ -513,13 +513,8 @@ def calculate_range(dtype: torch.dtype) -> tuple:
|
||||
Returns:
|
||||
tuple: A tuple containing the minimum and maximum values.
|
||||
"""
|
||||
if dtype == torch.float8_e5m2:
|
||||
# 8-bit floating-point format with e5m2 layout
|
||||
min_val = -57344.0
|
||||
max_val = 57344.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
return min_val, max_val
|
||||
info = torch.finfo(dtype)
|
||||
return info.min, info.max
|
||||
|
||||
|
||||
def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
||||
@ -535,7 +530,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
||||
# case: use scaling
|
||||
if torch._inductor.config.post_grad_fusion_options[
|
||||
"activation_quantization_aten_pass"
|
||||
].get("use_scaling", False):
|
||||
].get("use_scaling", True):
|
||||
# calculating the scale
|
||||
scale_node = calculate_quantization_scaling(
|
||||
graph, node, clamp_max, 1e-12
|
||||
@ -554,7 +549,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
||||
quant_node = graph.call_function(
|
||||
torch.ops.prims.convert_element_type.default,
|
||||
args=(node, quant_type),
|
||||
name="quant_" + str(node.name),
|
||||
name="fp8_quant_" + str(node.name),
|
||||
)
|
||||
quant_node.meta[
|
||||
"val"
|
||||
@ -595,12 +590,13 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
|
||||
# case: use scaling
|
||||
with graph.inserting_after(node):
|
||||
# find corresponding scale node
|
||||
scale_name = "scale_" + node.name.replace("quant_", "")
|
||||
scale_name = "fp8_scale_" + node.name.replace("fp8_quant_", "")
|
||||
scale_node = next(
|
||||
bwd_input
|
||||
for bwd_input in bw_inputs
|
||||
if bwd_input.name == scale_name
|
||||
)
|
||||
with graph.inserting_after(scale_node):
|
||||
activation_node = graph.call_function(
|
||||
torch.ops.prims.convert_element_type.default,
|
||||
args=(node, dequant_type),
|
||||
@ -613,7 +609,7 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
|
||||
activation_node.meta["tensor_meta"] = extract_tensor_metadata(
|
||||
activation_node.meta["val"]
|
||||
)
|
||||
with graph.inserting_after(scale_node):
|
||||
with graph.inserting_after(activation_node):
|
||||
divided_target_node_32 = graph.call_function(
|
||||
torch.ops.aten.div.Tensor,
|
||||
args=(activation_node, scale_node),
|
||||
@ -725,11 +721,22 @@ def enable_activation_quantization(
|
||||
),
|
||||
)
|
||||
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "before_activation_quantization_bwd_aten_pass",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: bwd_module.print_readable(
|
||||
print_output=False, include_stride=True, include_device=True
|
||||
),
|
||||
)
|
||||
|
||||
quant_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
|
||||
# update the corresponding bwd_inputs due to the fwd_outputs quantization
|
||||
for fwd_node in quant_fwd_module_outputs:
|
||||
if "quant_" in fwd_node.name:
|
||||
bwd_input = bwd_module_inputs[fwd_node.name.replace("quant_", "")]
|
||||
if "fp8_quant_" in fwd_node.name:
|
||||
bwd_input = bwd_module_inputs[fwd_node.name.replace("fp8_quant_", "")]
|
||||
with bwd_module.graph.inserting_after(bwd_input):
|
||||
quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
|
||||
dequant_type = bwd_input.meta["dequant_type"]
|
||||
@ -741,34 +748,24 @@ def enable_activation_quantization(
|
||||
# update the bwd_inputs if quantization with scaling is used
|
||||
if torch._inductor.config.post_grad_fusion_options[
|
||||
"activation_quantization_aten_pass"
|
||||
].get("use_scaling", False):
|
||||
].get("use_scaling", True):
|
||||
quant_bwd_module_inputs = list(bwd_module.graph.find_nodes(op="placeholder"))
|
||||
# update the corresponding bwd input nodes find the last non-tangent node
|
||||
bwd_input_loc = list(bwd_module_inputs.values())[-1]
|
||||
for bw_input in reversed(bwd_module_inputs.values()):
|
||||
bwd_input_loc = quant_bwd_module_inputs[-1]
|
||||
for bw_input in reversed(quant_bwd_module_inputs):
|
||||
if not _is_tangent(bw_input):
|
||||
bwd_input_loc = bw_input
|
||||
break
|
||||
|
||||
scaled_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
|
||||
for fwd_node in scaled_fwd_module_outputs:
|
||||
if "scale_" in fwd_node.name:
|
||||
if "fp8_scale_" in fwd_node.name:
|
||||
# fwd node is a scale node
|
||||
with bwd_module.graph.inserting_after(bwd_input_loc):
|
||||
scale_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
|
||||
scale_bwd_input.meta.update(fwd_node.meta)
|
||||
bwd_input_loc = scale_bwd_input
|
||||
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "before_activation_quantization_bwd_aten_pass",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: bwd_module.print_readable(
|
||||
print_output=False, include_stride=True, include_device=True
|
||||
),
|
||||
)
|
||||
|
||||
quantize_activation_bw(bwd_module.graph)
|
||||
|
||||
trace_structured(
|
||||
|
Reference in New Issue
Block a user