Because XPU CI does not gate community pull requests, we've observed that contributors often hardcode `"cuda"` in functions decorated with `@requires_gpu()` when adding new test cases. Those tests then fail on XPU and break XPU CI.
This PR adds a linter that detects these hardcoded CUDA device references automatically. Example output:
```
Error (TEST_DEVICE_BIAS) [device-bias]
`@requires_gpu` function should not hardcode device='cuda'
11670 | .contiguous()
11671 | )
11672 |
>>> 11673 | inp = torch.rand((64, 64), device="cuda") * 2 - 1
11674 | boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])
11675 |
11676 | self.common(fn, (inp, boundaries), check_lowp=False)
Error (TEST_DEVICE_BIAS) [device-bias]
`@requires_gpu` function should not hardcode .cuda() call
11700 | self.assertEqual(ref, res)
11701 |
11702 | for offset2 in (0, 1, 2, 3, 4):
>>> 11703 | base2 = torch.randn(64 * 64 + 64, dtype=torch.float32).cuda()
11704 | inp2 = torch.as_strided(base2, (64, 64), (64, 1), offset2)
11705 | ref2 = fn(inp2)
11706 | res2 = fn_c(inp2)
Error (TEST_DEVICE_BIAS) [device-bias]
`@requires_gpu` function should not hardcode torch.device('cuda:0')
11723 | return x.sin() + x.cos()
11724 |
11725 | base = torch.randn(
>>> 11726 | 64 * 64 + 64, dtype=torch.float32, device=torch.device("cuda:0")
11727 | )
11728 |
11729 | inp1 = torch.as_strided(base, (32, 32), (32, 1), 4)
Error (TEST_DEVICE_BIAS) [device-bias]
`@requires_gpu` function should not hardcode .to('cuda') call
11771 | torch.manual_seed(42)
11772 | base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
11773 | torch.manual_seed(42)
>>> 11774 | base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32).to("cuda")
11775 |
11776 | inp = torch.as_strided(base, size, stride, offset)
11777 | inp_ref = torch.as_strided(base_ref, size, stride, offset)
```
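The commit message does not include the linter's implementation. As a rough illustration of how such a check can work, here is a minimal sketch that walks the Python AST, finds functions decorated with `@requires_gpu` or `@requires_gpu()`, and flags the four patterns reported above: a `device="cuda"` keyword, a `.cuda()` call, `torch.device("cuda:...")`, and `.to("cuda")`. All names in it (`lint_source`, `check_call`, `has_requires_gpu`, `is_cuda_literal`, `PREFIX`, `SAMPLE`) are hypothetical; this is not the code merged in the PR, and it matches the decorator by name only, so an aliased import would be missed.

```python
# Hypothetical sketch of an AST-based device-bias check; not the PR's linter.
import ast
from typing import List, Optional, Tuple

PREFIX = "`@requires_gpu` function should not hardcode "


def is_cuda_literal(node: ast.AST) -> bool:
    """True for string constants such as "cuda" or "cuda:0"."""
    return (
        isinstance(node, ast.Constant)
        and isinstance(node.value, str)
        and (node.value == "cuda" or node.value.startswith("cuda:"))
    )


def has_requires_gpu(func: ast.AST) -> bool:
    """True if decorated with @requires_gpu or @requires_gpu(...), by name only."""
    for dec in func.decorator_list:
        target = dec.func if isinstance(dec, ast.Call) else dec
        if isinstance(target, ast.Name) and target.id == "requires_gpu":
            return True
        if isinstance(target, ast.Attribute) and target.attr == "requires_gpu":
            return True
    return False


def check_call(call: ast.Call) -> Optional[str]:
    """Return a message if this call hardcodes a CUDA device, else None."""
    # device="cuda" on any call, e.g. torch.rand(..., device="cuda")
    for kw in call.keywords:
        if kw.arg == "device" and is_cuda_literal(kw.value):
            return PREFIX + f"device='{kw.value.value}'"
    func = call.func
    if not isinstance(func, ast.Attribute):
        return None
    first = call.args[0] if call.args else None
    # tensor.cuda()
    if func.attr == "cuda":
        return PREFIX + ".cuda() call"
    # tensor.to("cuda")
    if func.attr == "to" and first is not None and is_cuda_literal(first):
        return PREFIX + f".to('{first.value}') call"
    # torch.device("cuda:0")
    if (
        func.attr == "device"
        and isinstance(func.value, ast.Name)
        and func.value.id == "torch"
        and first is not None
        and is_cuda_literal(first)
    ):
        return PREFIX + f"torch.device('{first.value}')"
    return None


def lint_source(source: str) -> List[Tuple[int, str]]:
    """Return sorted (lineno, message) pairs for @requires_gpu functions."""
    findings = set()  # a set, so calls in nested decorated functions count once
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not has_requires_gpu(node):
            continue
        # Scan only the body, so the decorator's own arguments are ignored.
        for stmt in node.body:
            for sub in ast.walk(stmt):
                if isinstance(sub, ast.Call):
                    msg = check_call(sub)
                    if msg is not None:
                        findings.add((sub.lineno, msg))
    return sorted(findings)


SAMPLE = '''
@requires_gpu()
def test_bucketize(self):
    inp = torch.rand((64, 64), device="cuda") * 2 - 1
    base = torch.randn(64, dtype=torch.float32).cuda()
    ref = torch.randn(8, device=torch.device("cuda:0"))
    out = torch.randn(8).to("cuda")
    ok = torch.randn(8, device=self.device)  # device-agnostic, not flagged

def test_cpu_only():
    x = torch.randn(8).cuda()  # not decorated, not flagged
'''

if __name__ == "__main__":
    for lineno, msg in lint_source(SAMPLE):
        print(f"line {lineno}: {msg}")
```

Matching on the syntax tree rather than on raw text keeps docstrings and comments that mention "cuda" from producing findings, at the cost of missing device strings that are built dynamically (for example, read from a variable). The `device=self.device` line in `SAMPLE` shows the device-agnostic form the linter steers contributors toward, matching line 11772 in the output above.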
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152948
Approved by: https://github.com/EikanWang, https://github.com/cyyever, https://github.com/malfet, https://github.com/jansel