pytorch/test/inductor/test_quantization.py
Menglu Yu 788d9cb2d7 [3/n][Optimus][Auto-AC][reland] Support any fp8 quantization type and set scaling as the default (#154057)
Summary:
This is a reland of D74910193.
We change the dtype in the unit test to torch.float8_e5m2, since the dtype used previously is not supported.
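
For context, a minimal tensor-level sketch of what scaled fp8 (e5m2) quantization of a saved activation looks like. The names are illustrative only, not the Optimus pass's API; the real pass rewrites the ATen graph rather than calling helpers like these.

```
import torch

x = torch.randn(16, 10, dtype=torch.bfloat16)  # a saved activation

fp8_max = torch.finfo(torch.float8_e5m2).max          # 57344.0 for e5m2
scale = fp8_max / x.abs().max().float().clamp(min=1e-12)

# Quantize for storage; with "use_scaling": False the pass presumably
# casts directly, i.e. x.to(torch.float8_e5m2) with no per-tensor scale.
x_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e5m2)

# Dequantize when the activation is needed again in the backward pass.
x_restored = (x_fp8.to(torch.float32) / scale).to(torch.bfloat16)
```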

Test Plan:
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization
```
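
The file ends with a standard run_tests() entry point guarded on IS_LINUX and HAS_GPU, so in an OSS checkout the test should also be runnable directly (assuming a Linux machine with a supported GPU):

```
python test/inductor/test_quantization.py
```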

Differential Revision: D75169792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154057
Approved by: https://github.com/Mingming-Ding
2025-05-22 18:26:34 +00:00


# Owner(s): ["module: inductor"]
import logging

import numpy as np

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu

log = logging.getLogger(__name__)

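# Toy modules used as compile targets below: a matmul over a tanh activation,
# a small feed-forward ReLU stack, and a bare functional layer_norm. Each
# gives the post-grad activation-quantization pass candidate activations to
# quantize between the forward and backward graphs.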
class TargetCPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor

class FeedforwardNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        tanh_x = torch.tanh(x)
        # fc2's output is discarded on the next reassignment; fc3 consumes the
        # tanh branch instead, so fc2's parameters end up with no gradient
        # (the None case handled in compare_dict_tensors below).
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(tanh_x))
        x = self.fc4(x)
        return x

class LayernormNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, normalized_shape, weight, bias):
        x = torch.nn.functional.layer_norm(
            input=input,
            normalized_shape=normalized_shape,
            weight=weight,
            bias=bias,
            eps=1e-5,
        )
        return x

class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            # torch.compile wraps the module, so traced parameter names carry
            # an "_orig_mod." prefix.
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            # if both of them are None, continue
            if (
                not isinstance(ref_dict[key1], torch.Tensor)
                and not isinstance(res_dict[key2], torch.Tensor)
                and ref_dict[key1] is None
                and res_dict[key2] is None
            ):
                log.info(
                    "None found with key1 and value 1: %s, %s, key2 and value2 %s, %s",
                    key1,
                    ref_dict[key1],
                    key2,
                    res_dict[key2],
                )
                continue
            elif not torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol, equal_nan=True
            ):
                log.info(
                    "gradient mismatch for eager and compiled modules, with eager: %s and compiled: %s",
                    ref_dict[key1],
                    res_dict[key2],
                )
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

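    # The post-grad fusion config below drives the activation-quantization
    # pass. Option semantics are inferred from the names used here:
    # "quant_type" is the fp8 dtype activations are cast to, "use_scaling"
    # toggles per-tensor scaling, "size_in_mb" is a minimum-size threshold
    # (0.0 quantizes everything), "exclude_primals" skips graph inputs, and
    # "allowed_dtypes" lists source dtypes eligible for quantization.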
    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {
                "quant_type": "torch.float8_e5m2",
                "use_scaling": True,
                "size_in_mb": 0.0,
                "exclude_primals": True,
                "allowed_dtypes": "torch.bfloat16;torch.float32",
            },
        },
    )
    def test_activation_quantization_aten_with_scaling(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand(
                (16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            torch.rand(
                (10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        # The pass should have fired once in the forward graph and once in the
        # backward graph.
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

        module = FeedforwardNN().to(GPU_TYPE)
        X = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
        input = [
            torch.from_numpy(X).to(GPU_TYPE),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

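    # Same flow as the scaling test above, but with "use_scaling": False
    # (presumably a direct cast to fp8 with no per-tensor scale), exercised on
    # a functional layer_norm module.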
    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {
                "quant_type": "torch.float8_e5m2",
                "use_scaling": False,
                "size_in_mb": 0.0,
                "exclude_primals": True,
                "allowed_dtypes": "torch.bfloat16;torch.float32",
            },
        },
    )
    def test_activation_quantization_aten_without_scaling(self):
        counters.clear()
        module = LayernormNN().to(GPU_TYPE)
        normalized_shape = [256]
        input = [
            torch.randn(
                (1, 3, 256), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            normalized_shape,
            torch.randn(
                *normalized_shape,
                requires_grad=True,
                device=GPU_TYPE,
                dtype=torch.bfloat16,
            ),
            torch.randn(
                *normalized_shape,
                requires_grad=True,
                device=GPU_TYPE,
                dtype=torch.bfloat16,
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()