Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Linter] Expanding the scope of detecting device-bias code. (#159949)
Currently, the device-bias linter only targets functions decorated with @requires_gpu. This PR adds support for two new detection scenarios:

1. Detect device-bias code in functions decorated with @requires_triton.
2. Detect device-bias code for entire test suites that are defined as shared across GPUs, for example:

```
if __name__ == "__main__":
    if HAS_GPU:
        run_tests()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159949
Approved by: https://github.com/EikanWang, https://github.com/jansel
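As a rough illustrative sketch (not part of this commit; the class name, test name, and shapes are made up), this is the kind of test file the expanded linter now covers: a `@requires_triton` test that hardcodes `"cuda"`, inside a suite whose main section is guarded by `HAS_GPU`:

```python
import torch

from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_triton


class ExampleSharedGpuTest(TestCase):  # hypothetical test class, for illustration only
    @requires_triton()
    def test_scale(self):
        # The linter flags this hardcoded device string ...
        bad = torch.randn(8, device="cuda")
        # ... and device=GPU_TYPE is the device-agnostic spelling it expects.
        good = torch.randn(8, device=GPU_TYPE)
        self.assertEqual(bad.shape, good.shape)


if __name__ == "__main__":
    # Because the main section is guarded by HAS_GPU, the whole suite is treated
    # as shared across GPUs, so hardcoded "cuda"/"xpu"/"mps" anywhere in its test
    # methods is flagged as well.
    if HAS_GPU:
        run_tests()
```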
Committed by: PyTorch MergeBot
Parent: 4183d4ff3d
Commit: 8047421fbb
@@ -447,8 +447,8 @@ class AOTAutogradCacheTests(InductorTestCase):
         def fn(x, y):
             return (x * 2, y @ y)
 
-        a = torch.rand(25, device="cuda")
-        b = torch.rand(5, 5, device="cuda")
+        a = torch.rand(25, device=GPU_TYPE)
+        b = torch.rand(5, 5, device=GPU_TYPE)
 
         compiled_fn = torch.compile(fn, backend="inductor")
         self.assertEqual(fn(a, b), compiled_fn(a, b))
@@ -822,7 +822,7 @@ class AOTAutogradCacheTests(InductorTestCase):
         def fn(a):
             return MyAutogradFunction.apply(a)
 
-        a = torch.randn(5, device="cuda", requires_grad=True)
+        a = torch.randn(5, device=GPU_TYPE, requires_grad=True)
         a2 = a.clone().detach_().requires_grad_(True)
         compiled_fn = torch.compile(fn, backend="inductor")
         result = compiled_fn(a)
@@ -7,7 +7,7 @@ import unittest
 import torch
 import torch._dynamo.test_case
 from torch.testing._internal.common_utils import IS_FBCODE
-from torch.testing._internal.inductor_utils import requires_triton
+from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton
 from torch.utils._triton import (
     has_triton_experimental_host_tma,
     has_triton_tensor_descriptor_host_tma,
@@ -420,7 +420,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
             )
             return tensor + 1, tma
 
-        x = torch.randn(128, 128, device="cuda")
+        x = torch.randn(128, 128, device=GPU_TYPE)
 
         ref = create_tma(x)
         res = torch.compile(create_tma, backend="eager")(x)
@@ -441,7 +441,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
             )
             return tensor + 1, tma
 
-        x = torch.randn(128, 128, device="cuda")
+        x = torch.randn(128, 128, device=GPU_TYPE)
 
         ref = create_tma(x)
         res = torch.compile(create_tma, backend="eager")(x)
@@ -552,7 +552,7 @@ class AOTInductorTestsTemplate:
 
         triton.set_allocator(
             lambda size, align, stream: torch.empty(
-                size, dtype=torch.int8, device="cuda"
+                size, dtype=torch.int8, device=GPU_TYPE
             )
         )
 
@@ -5235,9 +5235,9 @@ class AOTInductorTestsTemplate:
             return z
 
         example_inputs = (
-            torch.randn(10, 20, device="cuda"),
-            torch.randn(20, 30, device="cuda"),
-            torch.randn(10, 30, device="cuda"),
+            torch.randn(10, 20, device=GPU_TYPE),
+            torch.randn(20, 30, device=GPU_TYPE),
+            torch.randn(10, 30, device=GPU_TYPE),
         )
         model = Model()
         kernel_calls = [
@@ -2801,8 +2801,8 @@ class TestAutotuneCache(TestCase):
         def fn(x, y):
             return (x + y).relu()
 
-        x = torch.randn(100, 100).cuda()
-        y = torch.randn(100, 100).cuda()
+        x = torch.randn(100, 100).to(GPU_TYPE)
+        y = torch.randn(100, 100).to(GPU_TYPE)
 
         with config.patch(
             {
@@ -233,9 +233,9 @@ class InplacePaddingTest(TestCase):
             loss.backward()
             return loss
 
-        x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16()
+        x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16()
         x.retain_grad()
-        y = torch.randint(0, V, (B * T,)).cuda()
+        y = torch.randint(0, V, (B * T,)).to(GPU_TYPE)
 
         opt_f = torch.compile(f)
 
@@ -142,8 +142,16 @@ class TestMaxAutotune(TestCase):
             return torch.mm(a, b)
 
         M, N, K = 21, 31, 11
-        a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
-        b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
+        a = (
+            torch.randn(*((K, M) if a_transposed else (M, K)))
+            .to(torch.float16)
+            .to(GPU_TYPE)
+        )
+        b = (
+            torch.randn(*((N, K) if b_transposed else (K, N)))
+            .to(torch.float16)
+            .to(GPU_TYPE)
+        )
 
         with config.patch(
             {
@@ -166,8 +174,8 @@ class TestMaxAutotune(TestCase):
             return torch.mm(a, b)
 
         M, N, K = 21, 31, 11
-        a = torch.randn(M, K).to(torch.float16).cuda()
-        b = torch.randn(K, N).to(torch.float16).cuda()
+        a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE)
+        b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE)
 
         with (
             self.assertRaises(BackendCompilerFailed) as context,
@@ -194,8 +202,8 @@ class TestMaxAutotune(TestCase):
             return torch.mm(a, b)
 
         M, N, K = 21, 31, 11
-        a = torch.randn(M, K).to(torch.float16).cuda()
-        b = torch.randn(K, N).to(torch.float16).cuda()
+        a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE)
+        b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE)
 
         # TMA requires 16-byte alignment: here we repeat the dims
         # by the factor of 8, as float16 is 2-byte. All dims are
@@ -261,9 +269,17 @@ class TestMaxAutotune(TestCase):
             return torch.addmm(x, a, b)
 
         M, N, K = 21, 31, 11
-        a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
-        b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
-        x = torch.randn(N).to(torch.float16).cuda()
+        a = (
+            torch.randn(*((K, M) if a_transposed else (M, K)))
+            .to(torch.float16)
+            .to(GPU_TYPE)
+        )
+        b = (
+            torch.randn(*((N, K) if b_transposed else (K, N)))
+            .to(torch.float16)
+            .to(GPU_TYPE)
+        )
+        x = torch.randn(N).to(torch.float16).to(GPU_TYPE)
 
         with config.patch(
             {
@@ -286,9 +302,9 @@ class TestMaxAutotune(TestCase):
             return torch.addmm(x, a, b)
 
         M, N, K = 21, 31, 11
-        a = torch.randn(M, K).to(torch.float16).cuda()
-        b = torch.randn(K, N).to(torch.float16).cuda()
-        x = torch.randn(N).to(torch.float16).cuda()
+        a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE)
+        b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE)
+        x = torch.randn(N).to(torch.float16).to(GPU_TYPE)
 
         with (
             self.assertRaises(BackendCompilerFailed) as context,
@@ -315,9 +331,9 @@ class TestMaxAutotune(TestCase):
            return torch.addmm(x, a, b)
 
         M, N, K = 21, 31, 11
-        a = torch.randn(M, K).to(torch.float16).cuda()
-        b = torch.randn(K, N).to(torch.float16).cuda()
-        x = torch.randn(N).to(torch.float16).cuda()
+        a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE)
+        b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE)
+        x = torch.randn(N).to(torch.float16).to(GPU_TYPE)
 
         # TMA requires 16-byte alignment: here we repeat the dims
         # by the factor of 8, as float16 is 2-byte. All dims are
@@ -362,15 +378,15 @@ class TestMaxAutotune(TestCase):
 
         # Create large matrices to ensure we use all possible sms
         size = 2560
-        a = torch.randn(size, size, device="cuda", dtype=torch.bfloat16)
+        a = torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16)
         b = (
-            torch.randn(size, size, device="cuda", dtype=torch.bfloat16)
+            torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16)
             .transpose(0, 1)
             .contiguous()
             .transpose(0, 1)
         )
-        scale_a = torch.tensor(1, dtype=torch.float32, device="cuda")
-        scale_b = torch.tensor(1, dtype=torch.float32, device="cuda")
+        scale_a = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE)
+        scale_b = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE)
 
         args = (
             (a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b)
@@ -949,9 +965,9 @@ class TestMaxAutotune(TestCase):
             loss.backward()
             return loss
 
-        x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16()
+        x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16()
         x.retain_grad()
-        y = torch.randint(0, V, (B * T,)).cuda()
+        y = torch.randint(0, V, (B * T,)).to(GPU_TYPE)
 
         import torch._inductor.utils as inductor_utils
 
@@ -985,8 +1001,8 @@ class TestMaxAutotune(TestCase):
 
         M, N, K = sizes
 
-        a = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True)
-        b = torch.randn(K, N, dtype=dtype, device="cuda", requires_grad=True)
+        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+        b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
 
         possible_splits = range(2, min(K // M, K // N) + 1)
 
@@ -1083,10 +1099,10 @@ class TestMaxAutotune(TestCase):
             return (a_in @ b).relu()
 
         a = torch.randn(
-            32, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True
+            32, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
         )
         b = torch.randn(
-            32768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
+            32768, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
         )
 
         torch._dynamo.reset()
@@ -1126,9 +1142,11 @@ class TestMaxAutotune(TestCase):
             a_in = torch.cat([a for _ in range(256)], dim=0)
             return (a_in @ b).relu().sum()
 
-        a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+        a = torch.randn(
+            8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
+        )
         b = torch.randn(
-            64, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True
+            64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
         )
 
         torch._dynamo.reset()
@@ -1168,8 +1186,8 @@ class TestMaxAutotune(TestCase):
             a = a.transpose(0, 1)
             return a @ b
 
-        a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16)
-        b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16)
+        a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
+        b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)
 
         b = b[:, :1096]
 
@@ -1522,8 +1540,8 @@ class TestMaxAutotune(TestCase):
         for M, N, K in shapes:
             get_k_splits.cache_clear()
             use_decompose_k_choice.cache_clear()
-            a = torch.randn(M, K, dtype=torch.float16, device="cuda")
-            b = torch.randn(K, N, dtype=torch.float16, device="cuda")
+            a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE)
+            b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE)
 
             with config.patch(
                 {
@@ -1560,8 +1578,8 @@ class TestMaxAutotune(TestCase):
 
         M, N, K = (1024, 1024, 1024)
 
-        a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True)
-        b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)
+        a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE, requires_grad=True)
+        b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE, requires_grad=True)
 
         with mock.patch(
             "torch._inductor.template_registry.get_template_heuristic"
@@ -379,8 +379,8 @@ class TestOperatorReorderForPeakMemory(TestCase):
 
             return out, out2, inp2 @ inp2
 
-        inp = torch.rand([256, 256], device="cuda")
-        inp2 = torch.rand([256, 256], device="cuda")
+        inp = torch.rand([256, 256], device=GPU_TYPE)
+        inp2 = torch.rand([256, 256], device=GPU_TYPE)
 
         def replace_foreach(gm):
             nodes = gm.find_nodes(
@@ -260,7 +260,7 @@ class TestCase(InductorTestCase):
         def fn(x, y):
             return x % y, x / y
 
-        x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2))
+        x, y = (torch.rand([8], dtype=torch.float16, device=GPU_TYPE) for _ in range(2))
 
         out, code = run_and_get_code(torch.compile(fn), x, y)
 
@@ -271,7 +271,7 @@ class TestCase(InductorTestCase):
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
         def fn():
-            return (torch.full((2, 3), 3.1416, device="cuda", dtype=torch.float16),)
+            return (torch.full((2, 3), 3.1416, device=GPU_TYPE, dtype=torch.float16),)
 
         out, code = run_and_get_code(torch.compile(fn))
         FileCheck().check("static_assert").check_same(".dtype").run(code[0])
@@ -284,7 +284,7 @@ class TestCase(InductorTestCase):
         def fn(x):
             return torch.any(x)
 
-        x = torch.rand([40], device="cuda").to(torch.bool)
+        x = torch.rand([40], device=GPU_TYPE).to(torch.bool)
         out, code = run_and_get_code(torch.compile(fn), x)
         self.assertEqual(fn(x), out)
 
@@ -293,7 +293,7 @@ class TestCase(InductorTestCase):
     def test_assoc_scan(self):
         from torch._higher_order_ops.associative_scan import associative_scan
 
-        x = torch.randn(10, device="cuda")
+        x = torch.randn(10, device=GPU_TYPE)
         # dtype check correctly
         associative_scan(
             lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise"
@@ -257,7 +257,7 @@ class TestTritonHeuristics(TestCase):
         def fn(x):
             return triton_sqr(x)
 
-        x = torch.randn(32, device="cuda")
+        x = torch.randn(32, device=GPU_TYPE)
         ref = fn(x)
         res = torch.compile(fn)(x)
         self.assertEqual(ref, res)
@@ -1,9 +1,9 @@
 #!/usr/bin/env python3
 """
 This lint verifies that every Python test file (file that matches test_*.py or
-*_test.py in the test folder) has a cuda hard code in `requires_gpu()`
-decorated function to ensure that the test not fail on other GPU.
-
+*_test.py in the test folder) has a cuda hard code in `requires_gpu()` or
+`requires_triton()` decorated function or `if HAS_GPU:` guarded main section,
+to ensure that the test not fail on other GPU devices.
 """
 
 from __future__ import annotations
@@ -39,21 +39,59 @@ class LintMessage(NamedTuple):
 
 
 DEVICE_BIAS = ["cuda", "xpu", "mps"]
+GPU_RELATED_DECORATORS = {"requires_gpu", "requires_triton"}
+
+
+def is_main_has_gpu(tree: ast.AST) -> bool:
+    def _contains_has_gpu(node: ast.AST) -> bool:
+        if isinstance(node, ast.Name) and node.id in ["HAS_GPU", "RUN_GPU"]:
+            return True
+        elif isinstance(node, ast.BoolOp):
+            return any(_contains_has_gpu(value) for value in node.values)
+        elif isinstance(node, ast.UnaryOp):
+            return _contains_has_gpu(node.operand)
+        elif isinstance(node, ast.Compare):
+            return _contains_has_gpu(node.left) or any(
+                _contains_has_gpu(comp) for comp in node.comparators
+            )
+        elif isinstance(node, (ast.IfExp, ast.Call)):
+            return False
+        return False
+
+    for node in ast.walk(tree):
+        # Detect if __name__ == "__main__":
+        if isinstance(node, ast.If):
+            if (
+                isinstance(node.test, ast.Compare)
+                and isinstance(node.test.left, ast.Name)
+                and node.test.left.id == "__name__"
+            ):
+                if any(
+                    isinstance(comp, ast.Constant) and comp.value == "__main__"
+                    for comp in node.test.comparators
+                ):
+                    for inner_node in node.body:
+                        if isinstance(inner_node, ast.If) and _contains_has_gpu(
+                            inner_node.test
+                        ):
+                            return True
+    return False
 
 
 class DeviceBiasVisitor(ast.NodeVisitor):
-    def __init__(self, filename: str):
+    def __init__(self, filename: str, is_gpu_test_suite: bool) -> None:
         self.filename = filename
         self.lint_messages: list[LintMessage] = []
+        self.is_gpu_test_suite = is_gpu_test_suite
 
-    def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool:
+    def _has_proper_decorator(self, node: ast.FunctionDef) -> bool:
         for d in node.decorator_list:
-            if isinstance(d, ast.Name) and d.id == "requires_gpu":
+            if isinstance(d, ast.Name) and d.id in GPU_RELATED_DECORATORS:
                 return True
             if (
                 isinstance(d, ast.Call)
                 and isinstance(d.func, ast.Name)
-                and d.func.id == "requires_gpu"
+                and d.func.id in GPU_RELATED_DECORATORS
             ):
                 return True
         return False
@@ -62,7 +100,6 @@ class DeviceBiasVisitor(ast.NodeVisitor):
     def _check_keyword_device(self, subnode: ast.keyword, msg_prefix: str) -> None:
         if subnode.arg != "device":
             return
-
         val = subnode.value
         if isinstance(val, ast.Constant) and any(
             bias in val.value for bias in DEVICE_BIAS
@@ -124,15 +161,7 @@ class DeviceBiasVisitor(ast.NodeVisitor):
                 f"{msg_prefix} `with torch.device('{ctx_expr.args[0].value}')`, suggest to use torch.device(GPU_TYPE)",
             )
 
-    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
-        # Check if the function is decorated with @requires_gpu, which indicates
-        # that the function is intended to run on GPU devices (e.g., CUDA or XPU),
-        # but ensure it does not hardcode the device to CUDA.
-        if not self._has_requires_gpu_decorator(node):
-            self.generic_visit(node)
-            return
-
-        msg_prefix = "`@requires_gpu` function should not hardcode"
+    def _check_node(self, node: ast.AST, msg_prefix: str) -> None:
         for subnode in ast.walk(node):
             if isinstance(subnode, ast.keyword):
                 self._check_keyword_device(subnode, msg_prefix)
@@ -143,6 +172,16 @@ class DeviceBiasVisitor(ast.NodeVisitor):
             elif isinstance(subnode, ast.With):
                 self._check_with_statement(subnode, msg_prefix)
 
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+        if self._has_proper_decorator(node):
+            msg_prefix = (
+                "`@requires_gpu` or `@requires_triton` function should not hardcode"
+            )
+            self._check_node(node, msg_prefix)
+        elif self.is_gpu_test_suite:
+            # If the function is guarded by HAS_GPU in main(), we still need to check for device bias
+            msg_prefix = "The test suites is shared amount GPUS, should not hardcode"
+            self._check_node(node, msg_prefix)
         self.generic_visit(node)
 
     def record(self, node: ast.AST, message: str) -> None:
@@ -165,16 +204,16 @@ def check_file(filename: str) -> list[LintMessage]:
     with open(filename) as f:
         source = f.read()
     tree = ast.parse(source, filename=filename)
-    checker = DeviceBiasVisitor(filename)
+    is_gpu_test_suite = is_main_has_gpu(tree)
+    checker = DeviceBiasVisitor(filename, is_gpu_test_suite)
     checker.visit(tree)
 
     return checker.lint_messages
 
 
 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="Detect Device bias in python functions decorated with [require_gpu]"
-        " that may potentially break support for other GPU devices.",
+        description="Detect Device bias in functions decorated with requires_gpu/requires_triton"
+        " or guarded by HAS_GPU block in main() that may break other GPU devices.",
         fromfile_prefix_chars="@",
     )
     parser.add_argument(