Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[3/n][Optimus][Auto-AC][reland] Support any fp8 quantization type and set scaling as the default (#154057)
Summary: This is a reland of D74910193. We changed the dtype to torch.float8_e5m2 in the unit test, since the previously used dtype is not supported.

Test Plan:
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization
```

Differential Revision: D75169792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154057
Approved by: https://github.com/Mingming-Ding
Committed by: PyTorch MergeBot
Parent: c2660d29a5
Commit: 788d9cb2d7
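For orientation, the pass is driven entirely by Inductor's post-grad fusion options. Below is a minimal sketch of enabling it; the option names and values mirror the unit tests in this diff, while `TinyMLP` is a hypothetical stand-in for a real model, and the `size_in_mb` comment is an assumption from the option's name. With this PR, `use_scaling` may be omitted, since `True` becomes the default.

```python
import torch

class TinyMLP(torch.nn.Module):  # hypothetical stand-in for a real model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(256, 256)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc(x))

# Sketch only: option names/values are taken from the unit tests in this diff.
with torch._inductor.config.patch(
    pre_grad_fusion_options={},
    post_grad_fusion_options={
        "activation_quantization_aten_pass": {
            "quant_type": "torch.float8_e5m2",  # any fp8 dtype, passed by name
            "use_scaling": True,                # scaled quantization, the new default
            "size_in_mb": 0.0,                  # size threshold (assumption from the name)
            "exclude_primals": True,
            "allowed_dtypes": "torch.bfloat16;torch.float32",
        },
    },
):
    compiled = torch.compile(TinyMLP().cuda().bfloat16())
    out = compiled(torch.randn(8, 256, device="cuda", dtype=torch.bfloat16))
```

The first three hunks below extend the unit tests; the remaining hunks change the pass itself.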
```diff
@@ -47,6 +47,21 @@ class FeedforwardNN(torch.nn.Module):
         return x
 
 
+class LayernormNN(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, normalized_shape, weight, bias):
+        x = torch.nn.functional.layer_norm(
+            input=input,
+            normalized_shape=normalized_shape,
+            weight=weight,
+            bias=bias,
+            eps=1e-5,
+        )
+        return x
+
+
 class TestQuantization(TestCase):
     def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
         if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
```
```diff
@@ -106,10 +121,11 @@ class TestQuantization(TestCase):
                 "use_scaling": True,
                 "size_in_mb": 0.0,
                 "exclude_primals": True,
+                "allowed_dtypes": "torch.bfloat16;torch.float32",
             },
         },
     )
-    def test_activation_quantization_aten(self):
+    def test_activation_quantization_aten_with_scaling(self):
         counters.clear()
         module = TargetCPModule().to(GPU_TYPE)
         input = [
```
```diff
@@ -159,6 +175,59 @@ class TestQuantization(TestCase):
         self.assertTrue(torch.allclose(ref, res))
         counters.clear()
+
+    @requires_gpu()
+    @torch._inductor.config.patch(
+        pre_grad_fusion_options={},
+        post_grad_fusion_options={
+            "activation_quantization_aten_pass": {
+                "quant_type": "torch.float8_e5m2",
+                "use_scaling": False,
+                "size_in_mb": 0.0,
+                "exclude_primals": True,
+                "allowed_dtypes": "torch.bfloat16;torch.float32",
+            },
+        },
+    )
+    def test_activation_quantization_aten_without_scaling(self):
+        counters.clear()
+
+        module = LayernormNN().to(GPU_TYPE)
+        normalized_shape = [256]
+        input = [
+            torch.randn(
+                (1, 3, 256), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
+            ),
+            normalized_shape,
+            torch.randn(
+                *normalized_shape,
+                requires_grad=True,
+                device=GPU_TYPE,
+                dtype=torch.bfloat16,
+            ),
+            torch.randn(
+                *normalized_shape,
+                requires_grad=True,
+                device=GPU_TYPE,
+                dtype=torch.bfloat16,
+            ),
+        ]
+        traced = torch.compile(module)
+        ref = module(*input)
+        res = traced(*input)
+        self.compare_pred(module, traced, input)
+        ref.sum().backward()
+        res.sum().backward()
+        self.compare_parameters(module, traced)
+        self.compare_gradients(module, traced)
+        self.assertEqual(
+            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
+        )
+        self.assertEqual(
+            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
+        )
+        self.assertTrue(torch.allclose(ref, res))
+        counters.clear()
 
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:
```
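Both the with-scaling and without-scaling tests share one skeleton: run the module eagerly and under `torch.compile`, compare predictions, parameters, and gradients, then assert that the `activation_quantization_fwd_aten_pass` and `activation_quantization_bwd_aten_pass` counters each fired exactly once. As the summary notes, `torch.float8_e5m2` is the quantization type exercised here because the dtype from the original landing attempt is unsupported.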
The remaining hunks are in a second file, the Inductor activation-quantization pass implementation:
```diff
@@ -379,7 +379,7 @@ def calculate_quantization_scaling(
     scale_node = graph.call_function(
         torch.ops.prims.convert_element_type.default,
         args=(mul_node, torch.float32),
-        name="scale_" + str(node.name),
+        name="fp8_scale_" + str(node.name),
     )
     scale_node.meta["val"] = torch.ops.prims.convert_element_type.default(
         mul_node.meta["val"], torch.float32
```
```diff
@@ -444,7 +444,7 @@ def perform_quantization(
     quant_activation_node = graph.call_function(
         torch.ops.prims.convert_element_type.default,
         args=(clamp_max_scaled_node, quant_type),
-        name="quant_" + str(node.name),
+        name="fp8_quant_" + str(node.name),
     )
     quant_activation_node.meta[
         "val"
```
```diff
@@ -513,13 +513,8 @@ def calculate_range(dtype: torch.dtype) -> tuple:
     Returns:
         tuple: A tuple containing the minimum and maximum values.
     """
-    if dtype == torch.float8_e5m2:
-        # 8-bit floating-point format with e5m2 layout
-        min_val = -57344.0
-        max_val = 57344.0
-    else:
-        raise ValueError(f"Unsupported dtype: {dtype}")
-    return min_val, max_val
+    info = torch.finfo(dtype)
+    return info.min, info.max
 
 
 def quantize_activation_fw(graph: torch.fx.Graph) -> None:
```
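The rewrite replaces the hardcoded e5m2 bounds with `torch.finfo`, which knows every fp8 format's representable range; this is what lets the pass accept any fp8 quantization type. A quick sanity check, assuming a PyTorch build that includes the fp8 dtypes:

```python
import torch

# finfo reproduces the constants the old branch hardcoded for e5m2 ...
info = torch.finfo(torch.float8_e5m2)
assert (info.min, info.max) == (-57344.0, 57344.0)

# ... and covers other fp8 variants with no extra branches or ValueError.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
```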
```diff
@@ -535,7 +530,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
             # case: use scaling
             if torch._inductor.config.post_grad_fusion_options[
                 "activation_quantization_aten_pass"
-            ].get("use_scaling", False):
+            ].get("use_scaling", True):
                 # calculating the scale
                 scale_node = calculate_quantization_scaling(
                     graph, node, clamp_max, 1e-12
```
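Flipping the `use_scaling` default to `True` makes scaled quantization the standard path: the forward pass scales each activation toward the fp8 range before the dtype cast, and the backward pass (later hunks) casts back and divides by the saved scale. Below is a self-contained sketch of that per-tensor scheme, assuming the usual amax-based scale; it illustrates the arithmetic, not the pass's actual FX-graph construction.

```python
import torch

def fp8_scaled_roundtrip(
    x: torch.Tensor,
    quant_type: torch.dtype = torch.float8_e5m2,
    eps: float = 1e-12,  # guards all-zero tensors, cf. the 1e-12 passed above
) -> torch.Tensor:
    """Per-tensor scaled fp8 quantize + dequantize (illustrative only)."""
    fp8_max = torch.finfo(quant_type).max
    # Scale so the largest magnitude lands near the fp8 maximum.
    amax = x.abs().amax().to(torch.float32).clamp(min=eps)
    scale = fp8_max / amax
    quant = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(quant_type)
    # Dequantize: cast back, then undo the scale (the backward pass below uses
    # convert_element_type followed by aten.div for the same effect).
    return quant.to(x.dtype) / scale.to(x.dtype)

x = torch.randn(1, 3, 256, dtype=torch.bfloat16)
print((x - fp8_scaled_roundtrip(x)).abs().max())  # small round-trip error
```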
```diff
@@ -554,7 +549,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
                 quant_node = graph.call_function(
                     torch.ops.prims.convert_element_type.default,
                     args=(node, quant_type),
-                    name="quant_" + str(node.name),
+                    name="fp8_quant_" + str(node.name),
                 )
                 quant_node.meta[
                     "val"
```
```diff
@@ -595,12 +590,13 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
             # case: use scaling
             with graph.inserting_after(node):
                 # find corresponding scale node
-                scale_name = "scale_" + node.name.replace("quant_", "")
+                scale_name = "fp8_scale_" + node.name.replace("fp8_quant_", "")
                 scale_node = next(
                     bwd_input
                     for bwd_input in bw_inputs
                     if bwd_input.name == scale_name
                 )
+            with graph.inserting_after(scale_node):
                 activation_node = graph.call_function(
                     torch.ops.prims.convert_element_type.default,
                     args=(node, dequant_type),
```
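The forward and backward halves are stitched together purely by node names: the forward emits `fp8_quant_<name>` and `fp8_scale_<name>` outputs, and the backward recovers the scale for each quantized activation by rewriting one prefix into the other, as in this illustrative snippet (the node name is hypothetical):

```python
# Round trip of the naming convention the pass relies on:
quant_name = "fp8_quant_layer_norm_0"
scale_name = "fp8_scale_" + quant_name.replace("fp8_quant_", "")
assert scale_name == "fp8_scale_layer_norm_0"
```

Renaming both prefixes in the same PR is what keeps this lookup consistent; renaming only one side would silently break the `next(...)` search above.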
```diff
@@ -613,7 +609,7 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
                 activation_node.meta["tensor_meta"] = extract_tensor_metadata(
                     activation_node.meta["val"]
                 )
-            with graph.inserting_after(scale_node):
+            with graph.inserting_after(activation_node):
                 divided_target_node_32 = graph.call_function(
                     torch.ops.aten.div.Tensor,
                     args=(activation_node, scale_node),
```
```diff
@@ -725,11 +721,22 @@ def enable_activation_quantization(
         ),
     )
 
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "before_activation_quantization_bwd_aten_pass",
+            "encoding": "string",
+        },
+        payload_fn=lambda: bwd_module.print_readable(
+            print_output=False, include_stride=True, include_device=True
+        ),
+    )
+
     quant_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
     # update the corresponding bwd_inputs due to the fwd_outputs quantization
     for fwd_node in quant_fwd_module_outputs:
-        if "quant_" in fwd_node.name:
-            bwd_input = bwd_module_inputs[fwd_node.name.replace("quant_", "")]
+        if "fp8_quant_" in fwd_node.name:
+            bwd_input = bwd_module_inputs[fwd_node.name.replace("fp8_quant_", "")]
             with bwd_module.graph.inserting_after(bwd_input):
                 quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
                 dequant_type = bwd_input.meta["dequant_type"]
```
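Note the relocation: the `before_activation_quantization_bwd_aten_pass` artifact is now logged here, and the final hunk below removes the old call site, which ran after the scale placeholders had already been inserted, presumably so the "before" snapshot captures the backward graph prior to these input rewrites.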
```diff
@@ -741,34 +748,24 @@ def enable_activation_quantization(
     # update the bwd_inputs if quantization with scaling is used
     if torch._inductor.config.post_grad_fusion_options[
         "activation_quantization_aten_pass"
-    ].get("use_scaling", False):
+    ].get("use_scaling", True):
+        quant_bwd_module_inputs = list(bwd_module.graph.find_nodes(op="placeholder"))
         # update the corresponding bwd input nodes find the last non-tangent node
-        bwd_input_loc = list(bwd_module_inputs.values())[-1]
-        for bw_input in reversed(bwd_module_inputs.values()):
+        bwd_input_loc = quant_bwd_module_inputs[-1]
+        for bw_input in reversed(quant_bwd_module_inputs):
             if not _is_tangent(bw_input):
                 bwd_input_loc = bw_input
                 break
 
         scaled_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
         for fwd_node in scaled_fwd_module_outputs:
-            if "scale_" in fwd_node.name:
+            if "fp8_scale_" in fwd_node.name:
                 # fwd node is a scale node
                 with bwd_module.graph.inserting_after(bwd_input_loc):
                     scale_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
                     scale_bwd_input.meta.update(fwd_node.meta)
                     bwd_input_loc = scale_bwd_input
-
-    trace_structured(
-        "artifact",
-        metadata_fn=lambda: {
-            "name": "before_activation_quantization_bwd_aten_pass",
-            "encoding": "string",
-        },
-        payload_fn=lambda: bwd_module.print_readable(
-            print_output=False, include_stride=True, include_device=True
-        ),
-    )
 
     quantize_activation_bw(bwd_module.graph)
 
     trace_structured(
```