# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950

import contextlib
import logging
import unittest

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._higher_order_ops import InvokeQuant
from torch._inductor import config
from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    Ignored,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)
from torch._inductor.utils import is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    skipIfXpu,
    TestCase,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


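# InvokeQuant wraps a function so that, under torch.compile, it is traced into a
# torch.ops.higher_order.invoke_quant call in the captured graph (see the
# FileCheck assertions below). A single shared instance is reused across tests.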
invoke_quant_tracer = InvokeQuant()


@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeQuant(TestCase):
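    # Overridden by the concrete backend subclasses at the bottom of the file;
    # the base class itself is deleted before test collection.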
    backend = ""

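    # Baseline: compiling fn, which routes gn through invoke_quant, should match
    # running gn eagerly.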
    def test_simple(self):
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return invoke_quant_tracer(
                gn, x, y, scheme="nf4", quant_options=invoke_quant_tracer
            )[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
        self.assertEqual(ref, res)

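    # Constructing the InvokeQuant wrapper inline at trace time, with
    # codegen_low_precision disabled, should work the same way as using the
    # module-level tracer.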
    def test_construct_inline(self):
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return InvokeQuant(codegen_low_precision=False)(gn, x, y, scheme="nf4")[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
        self.assertEqual(ref, res)

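    # Same as above, but with a default-constructed InvokeQuant wrapper.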
    def test_inline(self):
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return InvokeQuant()(gn, x, y, scheme="nf4")[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        ref = gn(x, y)[0]

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        res = torch.compile(fn, backend=self.backend)(x_clone, y_clone)
        self.assertEqual(ref, res)

    def test_multiple(self):
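        # Enable post-grad graph logging so the compiled graph can be inspected
        # through assertLogs below.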
        torch._logging.set_logs(post_grad_graphs=True)

        def gn(x, y):
            return torch.mul(x, y) + y

        def fn(x, y, z):
            o1 = invoke_quant_tracer(gn, x, y, scheme="nf4")
            o2 = invoke_quant_tracer(gn, y, z, scheme="nf4")
            return o1 + o2

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        z = torch.randn(8, requires_grad=False)
        ref = fn(x, y, z)

        log_context = (
            contextlib.nullcontext()
            if self.backend != "inductor"
            else self.assertLogs(logger="torch._inductor", level=logging.DEBUG)
        )

        with log_context as log:
            res = torch.compile(fn, backend=self.backend)(x, y, z)

        self.assertEqual(ref, res)

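        # Inductor's post-grad graph should contain two distinct subgraphs and
        # two invoke_quant calls, each tagged with the "nf4" scheme.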
        if self.backend == "inductor":
            logs = "\n".join(r.getMessage() for r in log.records)
            f = FileCheck()
            f.check("AFTER POST GRAD")
            f.check("subgraph0").check("subgraph1")
            for _ in range(2):
                f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
            f.run(logs)


class TestInvokeQuantEager(TestInvokeQuant):
    backend = "eager"


class TestInvokeQuantAotEager(TestInvokeQuant):
    backend = "aot_eager"


class TestInvokeQuantInductor(TestInvokeQuant):
    backend = "inductor"

    def test_pattern_matching(self):
        counter = 0

        test_pass = PatternMatcherPass()

        def my_pass(g):
            return test_pass.apply(g)

        def gn(x, y):
            return torch.mul(x, y) + y

        def fn(x, y, z):
            return invoke_quant_tracer(gn, x, y, scheme="nf4") @ z

        def fn_no_match(x, y, z):
            return invoke_quant_tracer(gn, x, y) @ z

        x = torch.randn(64, 64, requires_grad=False)
        y = torch.randn(64, 64, requires_grad=False)
        z = torch.randn(64, 64, requires_grad=False)

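        # Match mm(invoke_quant(...), arg), but only when the invoke_quant call
        # carries scheme="nf4"; each hit bumps the counter.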
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
        def quant_matching(match: Match, *args, **kwargs):
            nonlocal counter
            counter += 1

        with unittest.mock.patch(
            "torch._inductor.config.post_grad_custom_pre_pass", my_pass
        ):
            torch.compile(fn)(x, y, z)
            self.assertTrue(counter == 1)

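            # fn_no_match omits the scheme kwarg, so the registered pattern
            # should not fire and the counter stays at 1.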
            torch.compile(fn_no_match)(x, y, z)
            self.assertTrue(counter == 1)

    @skipIfXpu(
        msg="MM Triton template fusion does not work on XPU because the fusion"
        " brings no speedup; unskip once #146568 is fixed."
    )
    @requires_gpu()
    @config.patch(prologue_fusion=True)
    def test_prologue(self):
        if not is_big_gpu():
            raise unittest.SkipTest("requires a large gpu to max-autotune")

        def gn(x, y):
            return torch.mul(x, y) + (y - 1)

        def fn(x, y, z):
            return (
                invoke_quant_tracer(
                    gn, x, y, scheme="nf4", quant_options=invoke_quant_tracer
                )
                @ z
            )

        x = torch.randn(
            64, 64, requires_grad=False, device=GPU_TYPE, dtype=torch.float16
        )
        # fill y with ones so gn is a no-op (gn(x, y) == x) and the reference
        # numerics match exactly
        y = torch.randn(
            64, 64, requires_grad=False, device=GPU_TYPE, dtype=torch.float16
        ).fill_(1.0)
        z = torch.randn(
            64, 64, requires_grad=False, device=GPU_TYPE, dtype=torch.float16
        )
        ref = gn(x, y) @ z

        x_clone = x.clone().detach().requires_grad_(False)
        y_clone = y.clone().detach().requires_grad_(False)
        z_clone = z.clone().detach().requires_grad_(False)
        torch._dynamo.reset()
        with torch.no_grad(), config.patch(max_autotune_gemm_backends="TRITON"):
            fn_c = torch.compile(fn, mode="max-autotune-no-cudagraphs")
            res, code = run_and_get_code(fn_c, x_clone, y_clone, z_clone)

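        # The pointwise prologue should be fused into the Triton GEMM's k-loop
        # and stay in low precision (no fp32 upcast) before feeding tl.dot.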
        FileCheck().check("k_idx in range").check_not("tl.float32").check(
            "tl.dot"
        ).run(code[0])

        self.assertEqual(ref, res)


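# Drop the abstract base class so the test runner only collects the concrete
# backend-specific variants defined above.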
del TestInvokeQuant

if __name__ == "__main__":
    run_tests()