Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Currently, the device-bias linter only targets functions decorated with @requires_gpu. This PR adds support for two new detection scenarios:

1. Detect device-bias code in functions decorated with @requires_triton.
2. Detect device-bias code for entire test suites that are defined as shared across GPUs. For example:

```
if __name__ == "__main__":
    if HAS_GPU:
        run_tests()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159949
Approved by: https://github.com/EikanWang, https://github.com/jansel
455 lines · 14 KiB · Python
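For illustration, here is a hedged sketch (a hypothetical test, not taken from this PR or from the file below) of the kind of device-biased code the linter is meant to flag once `@requires_triton` functions are covered, assuming "device bias" means hard-coding a specific backend such as `"cuda"` where the device-agnostic `GPU_TYPE` should be used:

```python
# Hypothetical sketch: under the assumption above, the linter would flag the
# hard-coded "cuda" because the test is decorated with @requires_triton and is
# meant to run on any GPU backend via GPU_TYPE.
import torch
import torch._dynamo.test_case
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton


class ExampleSuite(torch._dynamo.test_case.TestCase):
    @requires_triton()
    def test_add_one(self):
        biased = torch.ones(4, device="cuda")      # device-biased: would be flagged
        portable = torch.ones(4, device=GPU_TYPE)  # device-agnostic: preferred
        self.assertEqual(biased.cpu() + 1, portable.cpu() + 1)
```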
# Owner(s): ["module: dynamo"]
import contextlib
import dis
import unittest

import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import IS_FBCODE
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton
from torch.utils._triton import (
    has_triton_experimental_host_tma,
    has_triton_tensor_descriptor_host_tma,
)


def _filter_instructions(instructions, opname):
    return list(filter(lambda x: x.opname == opname, instructions))

class ReconstructTest(torch._dynamo.test_case.TestCase):
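    # The tests below register a temporary Dynamo bytecode hook and inspect the
    # modified bytecode Dynamo emits for a compiled frame. In particular,
    # BUILD_MAP's argval is the number of key/value pairs being rebuilt, so it
    # tells us how much of a dict Dynamo chose to reconstruct.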
    @contextlib.contextmanager
    def register_bytecode_hook(self, fn):
        def hook(code, out_code):
            fn(list(dis.get_instructions(out_code)))
            return None

        torch._dynamo.reset()
        handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
        try:
            yield
        finally:
            handle.remove()

    def test_ConstDict_optimize_reconstruct(self):
        """
        Emit code to reconstruct only the key that changed
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct only d[40]
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t}
        d_opt = d.copy()
        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_pop_reconstruct(self):
        """
        If something is pop'ed from the dict, we reconstruct everything
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything
            self.assertEqual(build_map[0].argval, 2)

        def f(d, t):
            d.pop(2)
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()

        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_popitem_reconstruct(self):
        """
        If something is pop'ed from the dict, we reconstruct everything
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d.popitem()

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()

        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_popitem_reconstruct_graph_break(self):
        """
        If something is pop'ed from the dict, we reconstruct everything.
        Calling dict.popitem will graph break.
        """

        def f(d, t):
            d.popitem()

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()

        f(d, t)

        opt_f = torch.compile(backend="eager")(f)
        opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_del_reconstruct(self):
        """
        If something is deleted from the dict, we reconstruct everything
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything
            self.assertEqual(build_map[0].argval, 2)

        def f(d, t):
            del d[2]
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()

        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_get_reconstruct(self):
        """
        dict.get shouldn't affect anything
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            self.assertEqual(build_map[0].argval, 1)
            load_const = _filter_instructions(instructions, "LOAD_CONST")
            self.assertNotIn(123, load_const)

        def f(d, t):
            d[456] = d.get(456) + t

        t = torch.randn(3, 4)
        d = {123: t, 456: t + 1}
        d_opt = d.copy()

        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_ConstDict_clear_reconstruct(self):
        """
        If dict.clear() is used, we reconstruct everything
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d.clear()
            d[3] = t + 3

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()

        f(d, t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
            self.assertEqual(d, d_opt)

    def test_create_dict_reconstruct(self):
        """
        If dict is created inside a function, everything needs to be reconstructed
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything
            self.assertEqual(build_map[0].argval, 2)

        def f(t):
            return {1: t, 2: t + 1}

        t = torch.randn(3, 4)
        d = f(t)

        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            d_opt = opt_f(t)
            self.assertEqual(d, d_opt)

    @unittest.skipIf(
        IS_FBCODE, "capturing functional_call is not enabled by default in FB_CODE"
    )
    def test_functional_call_reconstruct(self):
        """
        PyTorch shouldn't codegen any key/value when functional_call is used
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            # don't reconstruct anything
            self.assertEqual(len(build_map), 0)

        m = torch.nn.Linear(3, 3)
        new_bias = torch.randn(3)
        new_weight = torch.randn(3, 3)

        def fn(new_weight, new_bias, x):
            return torch.func.functional_call(
                m, {"weight": new_weight, "bias": new_bias}, x
            )

        x = torch.randn(2, 3)
        expected = torch.nn.functional.linear(x, new_weight, new_bias)
        with self.register_bytecode_hook(hook):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            got = opt_fn(new_weight, new_bias, x)
            self.assertEqual(expected, got)

    @unittest.skipIf(
        IS_FBCODE, "capturing functional_call is not enabled by default in FB_CODE"
    )
    def test_functional_call_reconstruct_2(self):
        """
        PyTorch shouldn't codegen any key/value when functional_call is used
        """

        def hook(instructions: list[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            # don't reconstruct anything
            self.assertEqual(len(build_map), 0)

        class DummyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.ModuleDict(
                    {
                        "b": torch.nn.ModuleDict(
                            {
                                "c": torch.nn.ModuleDict(
                                    {
                                        "d": torch.nn.ModuleDict(
                                            {"e": torch.nn.Linear(10, 10, bias=False)}
                                        )
                                    }
                                )
                            }
                        )
                    }
                )

            def forward(self, x):
                return self.a.b.c.d.e(x)

        model = DummyModule()

        def fn(model, states, x):
            return torch.func.functional_call(model, states, x)

        x = torch.randn(2, 3)
        states = model.state_dict()
        x = torch.randn(10, 10)
        expected = fn(model, states, x)
        with self.register_bytecode_hook(hook):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            got = opt_fn(model, states, x)
            self.assertEqual(expected, got)

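    # The tests below wrap functions and methods with torch.no_grad() and trigger
    # a graph break inside them; the asserts after the break check that the frame
    # is still being compiled, i.e. that reconstructing the wrapped callable and
    # the grad-mode state did not fail and cause the frame to be skipped.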
    def test_graph_break_in_wrapped_user_function(self):
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            assert torch.compiler.is_compiling()
            assert not torch.is_grad_enabled()
            return x + 2

        @torch.compile(backend="eager")
        def gn(x):
            x = torch.no_grad()(fn)(x)
            # reconstruction failure would cause a skipped frame
            assert torch.compiler.is_compiling()
            assert torch.is_grad_enabled()
            return x

        inp = torch.randn(3)
        self.assertEqual(gn(inp), inp + 3)

    def test_graph_break_in_wrapped_user_method(self):
        class Foo:
            def __init__(self):
                self.a = 1
                self.b = 2

            def fn(self, x):
                x = x + self.a
                torch._dynamo.graph_break()
                assert torch.compiler.is_compiling()
                assert not torch.is_grad_enabled()
                return x + self.b

        obj = Foo()

        @torch.compile(backend="eager")
        def gn(x):
            obj.fn = torch.no_grad()(obj.fn)
            x = obj.fn(x)
            # reconstruction failure would cause a skipped frame
            assert torch.compiler.is_compiling()
            assert torch.is_grad_enabled()
            return x

        inp = torch.randn(3)
        self.assertEqual(gn(inp), inp + 3)

    def test_graph_break_in_wrapped_nested_function(self):
        @torch.compile(backend="eager")
        def gn(x):
            a = 1
            b = 2

            @torch.no_grad()
            def fn(x):
                x = x + a
                torch._dynamo.graph_break()
                assert torch.compiler.is_compiling()
                assert not torch.is_grad_enabled()
                return x + b

            x = fn(x)
            # reconstruction failure would cause a skipped frame
            assert torch.compiler.is_compiling()
            assert torch.is_grad_enabled()
            return x

        inp = torch.randn(3)
        self.assertEqual(gn(inp), inp + 3)

    def test_graph_break_in_wrapped_skipped_function(self):
        from torch._dynamo import trace_rules
        from torch._dynamo.testing import _skipped_function_for_test_reconstruct
        from torch._dynamo.variables import SkipFunctionVariable

        self.assertIs(
            trace_rules.lookup(_skipped_function_for_test_reconstruct),
            SkipFunctionVariable,
        )

        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            assert torch.compiler.is_compiling()
            assert not torch.is_grad_enabled()
            return x + 2

        @torch.compile(backend="eager")
        def gn(x):
            x = torch.no_grad()(_skipped_function_for_test_reconstruct)(fn, x)
            # reconstruction failure would cause a skipped frame
            assert torch.compiler.is_compiling()
            assert torch.is_grad_enabled()
            return x

        inp = torch.randn(3)
        self.assertEqual(gn(inp), inp + 3)

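    # The tests below create Triton host-side TMA descriptors inside the compiled
    # function and check that Dynamo can reconstruct them when they escape the
    # compiled region, covering both the experimental and the stable descriptor API.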
    @requires_triton()
    @unittest.skipIf(
        not has_triton_experimental_host_tma(),
        "Test requires triton.tools.experimental_descriptor API",
    )
    def test_tma_experimental_reconstruct(self):
        import triton

        def create_tma(tensor):
            tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
                tensor.data_ptr(),
                tensor.size(0),
                tensor.size(1),
                32,
                32,
                tensor.element_size(),
            )
            return tensor + 1, tma

        x = torch.randn(128, 128, device=GPU_TYPE)

        ref = create_tma(x)
        res = torch.compile(create_tma, backend="eager")(x)
        self.assertEqual(ref[1].desc, res[1].desc)

    @requires_triton()
    @unittest.skipIf(
        not has_triton_tensor_descriptor_host_tma(),
        "Test requires triton.tools.tensor_descriptor API",
    )
    def test_tma_stable_reconstruct(self):
        import triton

        def create_tma(tensor):
            tma = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
                tensor,
                [32, 32],
            )
            return tensor + 1, tma

        x = torch.randn(128, 128, device=GPU_TYPE)

        ref = create_tma(x)
        res = torch.compile(create_tma, backend="eager")(x)
        self.assertEqual(ref, res)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()