mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] use torch.amp.autocast instead of torch.cuda.amp.autocast (#134291)
torch.cuda.amp.autocast / torch.cpu.amp.autocast are deprecated and spew a ton of warnings when these tests run. This PR: Update to just use torch.amp.autocast(device). Note: this uncovers a bug in the test: when `device` is CUDA, it actually shows up as "cuda:0" - so previously, this test was _always_ using `torch.cpu.amp.autocast` even for `cuda` device. This PR fixes this, and uncovers additional bugs in `pinverse` and `linalg.pinv`; `linalg.pinv` was already failing before on CPU, but now the test also catches failures on CUDA, (and this PR adds to the skipped-test list). Pull Request resolved: https://github.com/pytorch/pytorch/pull/134291 Approved by: https://github.com/YuqingJ
This commit is contained in:
committed by
PyTorch MergeBot
parent
a1061009c9
commit
d433a603af
@ -2338,6 +2338,7 @@ fake_autocast_device_skips = defaultdict(dict)
|
||||
|
||||
# TODO: investigate/fix
|
||||
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
|
||||
fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}
|
||||
|
||||
|
||||
dynamic_output_op_tests = (
|
||||
@ -2575,12 +2576,14 @@ class TestFakeTensor(TestCase):
|
||||
|
||||
@ops(op_db, dtypes=OpDTypes.any_one)
|
||||
def test_fake_autocast(self, device, dtype, op):
|
||||
if op.name in fake_autocast_device_skips[device]:
|
||||
device_type = torch.device(device).type
|
||||
if op.name in fake_autocast_device_skips[device_type]:
|
||||
self.skipTest("Skip failing test")
|
||||
context = (
|
||||
torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast
|
||||
)
|
||||
self._test_fake_helper(device, dtype, op, context)
|
||||
|
||||
def context_fn():
|
||||
return torch.amp.autocast(device_type)
|
||||
|
||||
self._test_fake_helper(device, dtype, op, context_fn)
|
||||
|
||||
def _test_fake_crossref_helper(self, device, dtype, op, context):
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
|
||||
Reference in New Issue
Block a user