From 8047421fbb607d70ede13b9cd5a60b7b8bdfe348 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Thu, 7 Aug 2025 22:19:11 -0700 Subject: [PATCH] [Linter] Expanding the scope of detecting device-bias code. (#159949) Currently, the device-bias linter only targets functions decorated with @requires_gpu. This PR adds support for two new detection scenarios: 1. Detect device-bias code in functions decorated with @requires_triton. 2. Detect device-bias code for entire test suites that are defined as shared across GPUs. For example: ``` if __name__ == "__main__": if HAS_GPU: run_tests() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159949 Approved by: https://github.com/EikanWang, https://github.com/jansel --- test/dynamo/test_aot_autograd_cache.py | 6 +- test/dynamo/test_reconstruct.py | 6 +- test/inductor/test_aot_inductor.py | 8 +- test/inductor/test_codecache.py | 4 +- test/inductor/test_inplace_padding.py | 4 +- test/inductor/test_max_autotune.py | 84 +++++++++++-------- test/inductor/test_memory.py | 4 +- test/inductor/test_op_dtype_prop.py | 8 +- test/inductor/test_triton_heuristics.py | 2 +- .../adapters/test_device_bias_linter.py | 81 +++++++++++++----- 10 files changed, 132 insertions(+), 75 deletions(-) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 0d4a1f01f9a3..d26e4b31917e 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -447,8 +447,8 @@ class AOTAutogradCacheTests(InductorTestCase): def fn(x, y): return (x * 2, y @ y) - a = torch.rand(25, device="cuda") - b = torch.rand(5, 5, device="cuda") + a = torch.rand(25, device=GPU_TYPE) + b = torch.rand(5, 5, device=GPU_TYPE) compiled_fn = torch.compile(fn, backend="inductor") self.assertEqual(fn(a, b), compiled_fn(a, b)) @@ -822,7 +822,7 @@ class AOTAutogradCacheTests(InductorTestCase): def fn(a): return MyAutogradFunction.apply(a) - a = torch.randn(5, device="cuda", requires_grad=True) + a = torch.randn(5, device=GPU_TYPE, requires_grad=True) a2 = a.clone().detach_().requires_grad_(True) compiled_fn = torch.compile(fn, backend="inductor") result = compiled_fn(a) diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 0cafaf9878e6..9f3d41964195 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -7,7 +7,7 @@ import unittest import torch import torch._dynamo.test_case from torch.testing._internal.common_utils import IS_FBCODE -from torch.testing._internal.inductor_utils import requires_triton +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton from torch.utils._triton import ( has_triton_experimental_host_tma, has_triton_tensor_descriptor_host_tma, @@ -420,7 +420,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase): ) return tensor + 1, tma - x = torch.randn(128, 128, device="cuda") + x = torch.randn(128, 128, device=GPU_TYPE) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) @@ -441,7 +441,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase): ) return tensor + 1, tma - x = torch.randn(128, 128, device="cuda") + x = torch.randn(128, 128, device=GPU_TYPE) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index e0218cd9d8be..9fa13dc180f9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -552,7 +552,7 @@ class AOTInductorTestsTemplate: 
triton.set_allocator( lambda size, align, stream: torch.empty( - size, dtype=torch.int8, device="cuda" + size, dtype=torch.int8, device=GPU_TYPE ) ) @@ -5235,9 +5235,9 @@ class AOTInductorTestsTemplate: return z example_inputs = ( - torch.randn(10, 20, device="cuda"), - torch.randn(20, 30, device="cuda"), - torch.randn(10, 30, device="cuda"), + torch.randn(10, 20, device=GPU_TYPE), + torch.randn(20, 30, device=GPU_TYPE), + torch.randn(10, 30, device=GPU_TYPE), ) model = Model() kernel_calls = [ diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 8e53725dd159..3597663431fd 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -2801,8 +2801,8 @@ class TestAutotuneCache(TestCase): def fn(x, y): return (x + y).relu() - x = torch.randn(100, 100).cuda() - y = torch.randn(100, 100).cuda() + x = torch.randn(100, 100).to(GPU_TYPE) + y = torch.randn(100, 100).to(GPU_TYPE) with config.patch( { diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index 46d5cf61121e..7ddd0dd4441b 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -233,9 +233,9 @@ class InplacePaddingTest(TestCase): loss.backward() return loss - x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() - y = torch.randint(0, V, (B * T,)).cuda() + y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) opt_f = torch.compile(f) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 93165fa2dcec..ff1d8c3fb875 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -142,8 +142,16 @@ class TestMaxAutotune(TestCase): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() - b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) with config.patch( { @@ -166,8 +174,8 @@ class TestMaxAutotune(TestCase): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -194,8 +202,8 @@ class TestMaxAutotune(TestCase): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. 
All dims are @@ -261,9 +269,17 @@ class TestMaxAutotune(TestCase): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() - b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) with config.patch( { @@ -286,9 +302,9 @@ class TestMaxAutotune(TestCase): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -315,9 +331,9 @@ class TestMaxAutotune(TestCase): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -362,15 +378,15 @@ class TestMaxAutotune(TestCase): # Create large matrices to ensure we use all possible sms size = 2560 - a = torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + a = torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) b = ( - torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) .transpose(0, 1) .contiguous() .transpose(0, 1) ) - scale_a = torch.tensor(1, dtype=torch.float32, device="cuda") - scale_b = torch.tensor(1, dtype=torch.float32, device="cuda") + scale_a = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) + scale_b = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) args = ( (a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b) @@ -949,9 +965,9 @@ class TestMaxAutotune(TestCase): loss.backward() return loss - x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() - y = torch.randint(0, V, (B * T,)).cuda() + y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) import torch._inductor.utils as inductor_utils @@ -985,8 +1001,8 @@ class TestMaxAutotune(TestCase): M, N, K = sizes - a = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=dtype, device="cuda", requires_grad=True) + a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True) + b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True) possible_splits = range(2, min(K // M, K // N) + 1) @@ -1083,10 +1099,10 @@ class TestMaxAutotune(TestCase): return (a_in @ b).relu() a = torch.randn( - 32, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + 32, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) b = torch.randn( - 32768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True + 32768, 64, dtype=torch.bfloat16, device=GPU_TYPE, 
requires_grad=True ) torch._dynamo.reset() @@ -1126,9 +1142,11 @@ class TestMaxAutotune(TestCase): a_in = torch.cat([a for _ in range(256)], dim=0) return (a_in @ b).relu().sum() - a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + a = torch.randn( + 8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True + ) b = torch.randn( - 64, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + 64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) torch._dynamo.reset() @@ -1168,8 +1186,8 @@ class TestMaxAutotune(TestCase): a = a.transpose(0, 1) return a @ b - a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16) - b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16) + a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16) b = b[:, :1096] @@ -1522,8 +1540,8 @@ class TestMaxAutotune(TestCase): for M, N, K in shapes: get_k_splits.cache_clear() use_decompose_k_choice.cache_clear() - a = torch.randn(M, K, dtype=torch.float16, device="cuda") - b = torch.randn(K, N, dtype=torch.float16, device="cuda") + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE) with config.patch( { @@ -1560,8 +1578,8 @@ class TestMaxAutotune(TestCase): M, N, K = (1024, 1024, 1024) - a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) with mock.patch( "torch._inductor.template_registry.get_template_heuristic" diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 2231b94316b3..81f7ea03d3bb 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -379,8 +379,8 @@ class TestOperatorReorderForPeakMemory(TestCase): return out, out2, inp2 @ inp2 - inp = torch.rand([256, 256], device="cuda") - inp2 = torch.rand([256, 256], device="cuda") + inp = torch.rand([256, 256], device=GPU_TYPE) + inp2 = torch.rand([256, 256], device=GPU_TYPE) def replace_foreach(gm): nodes = gm.find_nodes( diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 458d64aa41d5..6f7eec601666 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -260,7 +260,7 @@ class TestCase(InductorTestCase): def fn(x, y): return x % y, x / y - x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2)) + x, y = (torch.rand([8], dtype=torch.float16, device=GPU_TYPE) for _ in range(2)) out, code = run_and_get_code(torch.compile(fn), x, y) @@ -271,7 +271,7 @@ class TestCase(InductorTestCase): @config.patch("test_configs.runtime_triton_dtype_assert", True) def test_constant(self): def fn(): - return (torch.full((2, 3), 3.1416, device="cuda", dtype=torch.float16),) + return (torch.full((2, 3), 3.1416, device=GPU_TYPE, dtype=torch.float16),) out, code = run_and_get_code(torch.compile(fn)) FileCheck().check("static_assert").check_same(".dtype").run(code[0]) @@ -284,7 +284,7 @@ class TestCase(InductorTestCase): def fn(x): return torch.any(x) - x = torch.rand([40], device="cuda").to(torch.bool) + x = torch.rand([40], device=GPU_TYPE).to(torch.bool) out, code = run_and_get_code(torch.compile(fn), x) self.assertEqual(fn(x), out) @@ -293,7 
+293,7 @@ class TestCase(InductorTestCase): def test_assoc_scan(self): from torch._higher_order_ops.associative_scan import associative_scan - x = torch.randn(10, device="cuda") + x = torch.randn(10, device=GPU_TYPE) # dtype check correctly associative_scan( lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise" diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index a9f898a36af5..4c2a04678b88 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -257,7 +257,7 @@ class TestTritonHeuristics(TestCase): def fn(x): return triton_sqr(x) - x = torch.randn(32, device="cuda") + x = torch.randn(32, device=GPU_TYPE) ref = fn(x) res = torch.compile(fn)(x) self.assertEqual(ref, res) diff --git a/tools/linter/adapters/test_device_bias_linter.py b/tools/linter/adapters/test_device_bias_linter.py index 00786ef3df86..a2079e4fe810 100644 --- a/tools/linter/adapters/test_device_bias_linter.py +++ b/tools/linter/adapters/test_device_bias_linter.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """ This lint verifies that every Python test file (file that matches test_*.py or -*_test.py in the test folder) has a cuda hard code in `requires_gpu()` -decorated function to ensure that the test not fail on other GPU. - +*_test.py in the test folder) does not hard-code cuda in a `requires_gpu()` or +`requires_triton()` decorated function or in an `if HAS_GPU:` guarded main section, +to ensure that the test does not fail on other GPU devices. """ from __future__ import annotations @@ -39,21 +39,59 @@ class LintMessage(NamedTuple): DEVICE_BIAS = ["cuda", "xpu", "mps"] +GPU_RELATED_DECORATORS = {"requires_gpu", "requires_triton"} + + +def is_main_has_gpu(tree: ast.AST) -> bool: + def _contains_has_gpu(node: ast.AST) -> bool: + if isinstance(node, ast.Name) and node.id in ["HAS_GPU", "RUN_GPU"]: + return True + elif isinstance(node, ast.BoolOp): + return any(_contains_has_gpu(value) for value in node.values) + elif isinstance(node, ast.UnaryOp): + return _contains_has_gpu(node.operand) + elif isinstance(node, ast.Compare): + return _contains_has_gpu(node.left) or any( + _contains_has_gpu(comp) for comp in node.comparators + ) + elif isinstance(node, (ast.IfExp, ast.Call)): + return False + return False + + for node in ast.walk(tree): + # Detect if __name__ == "__main__": + if isinstance(node, ast.If): + if ( + isinstance(node.test, ast.Compare) + and isinstance(node.test.left, ast.Name) + and node.test.left.id == "__name__" + ): + if any( + isinstance(comp, ast.Constant) and comp.value == "__main__" + for comp in node.test.comparators + ): + for inner_node in node.body: + if isinstance(inner_node, ast.If) and _contains_has_gpu( + inner_node.test + ): + return True + return False class DeviceBiasVisitor(ast.NodeVisitor): - def __init__(self, filename: str): + def __init__(self, filename: str, is_gpu_test_suite: bool) -> None: self.filename = filename self.lint_messages: list[LintMessage] = [] + self.is_gpu_test_suite = is_gpu_test_suite - def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool: + def _has_proper_decorator(self, node: ast.FunctionDef) -> bool: for d in node.decorator_list: - if isinstance(d, ast.Name) and d.id == "requires_gpu": + if isinstance(d, ast.Name) and d.id in GPU_RELATED_DECORATORS: return True if ( isinstance(d, ast.Call) and isinstance(d.func, ast.Name) - and d.func.id == "requires_gpu" + and d.func.id in GPU_RELATED_DECORATORS ): return True return False @@ -62,7 +100,6 @@ class
DeviceBiasVisitor(ast.NodeVisitor): def _check_keyword_device(self, subnode: ast.keyword, msg_prefix: str) -> None: if subnode.arg != "device": return - val = subnode.value if isinstance(val, ast.Constant) and any( bias in val.value for bias in DEVICE_BIAS @@ -124,15 +161,7 @@ class DeviceBiasVisitor(ast.NodeVisitor): f"{msg_prefix} `with torch.device('{ctx_expr.args[0].value}')`, suggest to use torch.device(GPU_TYPE)", ) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Check if the function is decorated with @requires_gpu, which indicates - # that the function is intended to run on GPU devices (e.g., CUDA or XPU), - # but ensure it does not hardcode the device to CUDA. - if not self._has_requires_gpu_decorator(node): - self.generic_visit(node) - return - - msg_prefix = "`@requires_gpu` function should not hardcode" + def _check_node(self, node: ast.AST, msg_prefix: str) -> None: for subnode in ast.walk(node): if isinstance(subnode, ast.keyword): self._check_keyword_device(subnode, msg_prefix) @@ -143,6 +172,16 @@ class DeviceBiasVisitor(ast.NodeVisitor): elif isinstance(subnode, ast.With): self._check_with_statement(subnode, msg_prefix) + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if self._has_proper_decorator(node): + msg_prefix = ( + "`@requires_gpu` or `@requires_triton` function should not hardcode" + ) + self._check_node(node, msg_prefix) + elif self.is_gpu_test_suite: + # If the test suite is guarded by HAS_GPU in main(), we still need to check for device bias + msg_prefix = "The test suite is shared among GPUs, should not hardcode" + self._check_node(node, msg_prefix) self.generic_visit(node) def record(self, node: ast.AST, message: str) -> None: @@ -165,16 +204,16 @@ def check_file(filename: str) -> list[LintMessage]: with open(filename) as f: source = f.read() tree = ast.parse(source, filename=filename) - checker = DeviceBiasVisitor(filename) + is_gpu_test_suite = is_main_has_gpu(tree) + checker = DeviceBiasVisitor(filename, is_gpu_test_suite) checker.visit(tree) - return checker.lint_messages def main() -> None: parser = argparse.ArgumentParser( - description="Detect Device bias in python functions decorated with [require_gpu]" - " that may potentially break support for other GPU devices.", + description="Detect device bias in functions decorated with requires_gpu/requires_triton" + " or in test suites guarded by a HAS_GPU block in main() that may break other GPU devices.", fromfile_prefix_chars="@", ) parser.add_argument(