pytorch/test/higher_order_ops/test_invoke_quant.py
# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950

import contextlib
import logging
import unittest

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._higher_order_ops import InvokeQuant
from torch._inductor import config
from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    Ignored,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)
from torch._inductor.utils import is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    skipIfXpu,
    TestCase,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


invoke_quant_tracer = InvokeQuant()

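# The tests below are backend-agnostic; the concrete subclasses near the bottom
# of the file rerun them with `backend` set to "eager", "aot_eager", and "inductor".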
@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeQuant(TestCase):
    backend = ""

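    # Route a small pointwise subgraph through invoke_quant under torch.compile
    # and check numerics against the eager reference.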
    def test_simple(self):
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return invoke_quant_tracer(
                gn, x, y, scheme="nf4", quant_options=invoke_quant_tracer
            )[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
        self.assertEqual(ref, res)

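    # Same as test_simple, but constructs the InvokeQuant object inline at the
    # call site with codegen_low_precision disabled.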
    def test_construct_inline(self):
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return InvokeQuant(codegen_low_precision=False)(gn, x, y, scheme="nf4")[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
        self.assertEqual(ref, res)

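    # Construct a default InvokeQuant object inline rather than reusing the
    # module-level tracer.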
    def test_inline(self):
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return InvokeQuant()(gn, x, y, scheme="nf4")[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
        self.assertEqual(ref, res)

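    # Two invoke_quant calls should show up as two separate subgraphs in the
    # post-grad graph; on the inductor backend this is verified via the
    # post_grad_graphs logs.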
    def test_multiple(self):
        torch._logging.set_logs(post_grad_graphs=True)

        def gn(x, y):
            return torch.mul(x, y) + y

        def fn(x, y, z):
            o1 = invoke_quant_tracer(gn, x, y, scheme="nf4")
            o2 = invoke_quant_tracer(gn, y, z, scheme="nf4")
            return o1 + o2

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        z = torch.randn(8, requires_grad=False)
        ref = fn(x, y, z)

        log_context = (
            contextlib.nullcontext()
            if self.backend != "inductor"
            else self.assertLogs(logger="torch._inductor", level=logging.DEBUG)
        )
        with log_context as log:
            res = torch.compile(fn, backend=self.backend)(x, y, z)

        self.assertEqual(ref, res)

        if self.backend == "inductor":
            logs = "\n".join(r.getMessage() for r in log.records)
            f = FileCheck()
            f.check("AFTER POST GRAD")
            f.check("subgraph0").check("subgraph1")
            for _ in range(2):
                f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
            f.run(logs)

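# Concrete parameterizations of the shared tests, one per compile backend.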
class TestInvokeQuantEager(TestInvokeQuant):
    backend = "eager"


class TestInvokeQuantAotEager(TestInvokeQuant):
    backend = "aot_eager"


class TestInvokeQuantInductor(TestInvokeQuant):
    backend = "inductor"

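    # invoke_quant nodes should stay visible to the post-grad pattern matcher:
    # mm(invoke_quant(...), z) is matched only when scheme="nf4" is present.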
    def test_pattern_matching(self):
        counter = 0
        test_pass = PatternMatcherPass()

        def my_pass(g):
            return test_pass.apply(g)

        def gn(x, y):
            return torch.mul(x, y) + y

        def fn(x, y, z):
            return invoke_quant_tracer(gn, x, y, scheme="nf4") @ z

        def fn_no_match(x, y, z):
            return invoke_quant_tracer(gn, x, y) @ z

        x = torch.randn(64, 64, requires_grad=False)
        y = torch.randn(64, 64, requires_grad=False)
        z = torch.randn(64, 64, requires_grad=False)

        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
        def quant_matching(match: Match, *args, **kwargs):
            nonlocal counter
            counter += 1

        with unittest.mock.patch(
            "torch._inductor.config.post_grad_custom_pre_pass", my_pass
        ):
            torch.compile(fn)(x, y, z)
            self.assertTrue(counter == 1)

            torch.compile(fn_no_match)(x, y, z)
            self.assertTrue(counter == 1)

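    # With prologue fusion enabled, the quantization subgraph is expected to be
    # fused into the Triton mm template's prologue; the FileCheck below asserts
    # the generated kernel loops over k and reaches tl.dot without upcasting to
    # tl.float32.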
    @skipIfXpu(
        msg="MM Triton template fusion does not work for XPU because the fusion"
        " brings no speedup; unskip once #146568 is fixed."
    )
    @requires_gpu()
    @config.patch(prologue_fusion=True)
    def test_prologue(self):
        if not is_big_gpu():
            raise unittest.SkipTest("requires large gpu to max-autotune")

        def gn(x, y):
            return torch.mul(x, y) + (y - 1)

        def fn(x, y, z):
            return (
                invoke_quant_tracer(
                    gn, x, y, scheme="nf4", quant_options=invoke_quant_tracer
                )
                @ z
            )

        x = torch.randn(
            64, 64, requires_grad=False, device=GPU_TYPE, dtype=torch.float16
        )
        # make this a no-op to ensure equivalent numerics
        y = torch.randn(
            64, 64, requires_grad=False, device=GPU_TYPE, dtype=torch.float16
        ).fill_(1.0)
        z = torch.randn(
            64, 64, requires_grad=False, device=GPU_TYPE, dtype=torch.float16
        )

        ref = gn(x, y) @ z

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        z_clone = z.clone().detach().requires_grad_(False)

        torch._dynamo.reset()
        with torch.no_grad(), config.patch(max_autotune_gemm_backends="TRITON"):
            fn_c = torch.compile(fn, mode="max-autotune-no-cudagraphs")
            res, code = run_and_get_code(fn_c, x_clone, y_clone, z_clone)

        FileCheck().check("k_idx in range").check_not("tl.float32").check(
            "tl.dot"
        ).run(code[0])

        self.assertEqual(ref, res)

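# Only the backend-specific subclasses should be collected; drop the shared base
# class so it does not also run with an empty backend string.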
del TestInvokeQuant


if __name__ == "__main__":
    run_tests()