Fix autocast context manager when there is an exception (#159565)

Summary: When an exception occurs inside a context manager, __exit__(exc_type, exc_val, exc_tb) must either return False or properly propagate the exception. Previously, while tracing, we did not actually run the exit node, so we ended up swallowing the exception in a very confusing way, as outlined in https://github.com/pytorch/pytorch/issues/153202. This PR fixes that.
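
For context, here is a minimal sketch of the Python context-manager contract involved (the ExampleCM class below is hypothetical and not part of this PR): when the body of a with block raises, Python calls __exit__(exc_type, exc_val, exc_tb); returning False (or None) lets the exception propagate, while returning a truthy value suppresses it.

# Hypothetical ExampleCM, only to illustrate the __exit__ contract described above.
class ExampleCM:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False tells Python to re-raise any exception from the with-body.
        return False

try:
    with ExampleCM():
        raise ValueError("inner error")
except ValueError:
    # Reached because __exit__ returned False instead of swallowing the error.
    print("exception propagated as expected")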

Test Plan:
New test case.

Rollback Plan:

Differential Revision: D79348382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159565
Approved by: https://github.com/zou3519, https://github.com/yushangdi
Author: Tugsbayasgalan (Tugsuu) Manlaibaatar
Date: 2025-08-01 02:12:24 +00:00
Committed by: PyTorch MergeBot
Parent: 83e2ea8135
Commit: 2ac45c2752
2 changed files with 53 additions and 2 deletions


@@ -14946,6 +14946,51 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(x.sin(), ep.module()(x))
         pytree._deregister_pytree_node(torch.FunctionSchema)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    def test_exception(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
+                self.register_buffer("buffer", torch.ones(4, 4))
+                self.register_buffer("param", torch.ones(4, 4))
+
+            def forward(self, x):
+                token_ids = torch.randint(0, 10, (4,), device=x.device)
+                embedded = self.embedding(token_ids).sum()
+                return self.buffer.sum() + self.param.sum() + x.sum() + embedded
+
+        class BarModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod = Model()
+
+            def forward(self, x):
+                if "cuda" in str(x.device):
+                    mod = self.mod.to(x.device)
+                    return mod(x)
+                else:
+                    return x.sum()
+
+        class BarBar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod = BarModel()
+
+            def forward(self, x):
+                with torch.amp.autocast(device_type="cuda"):
+                    y = self.mod(x)
+                return y
+
+        with torch.no_grad():
+            with self.assertRaisesRegex(RuntimeError, "Couldn't swap Embedding.weight"):
+                _ = torch.export.export(
+                    BarBar(),
+                    (),
+                    {"x": torch.randn(4, 4, 4, device="cuda")},
+                    strict=False,
+                ).module()
+
     def test_export_for_training_with_state_dict_hooks(self):
         def _state_dict_pre_hook(mod, prefix, keep_vars):
             mod._buffers["test"] = torch.Tensor([1])


@@ -397,7 +397,10 @@ class autocast:
                         self._enabled,
                         self._cache_enabled,
                     )
-                    return mode.__torch_function__(torch.amp._enter_autocast, (), args)
+                    mode.__torch_function__(torch.amp._enter_autocast, (), args)
+                    return self
         return self
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
         if torch._jit_internal.is_scripting():
@@ -420,7 +423,10 @@ class autocast:
                     mode,
                     torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
                 ):
-                    return mode.__torch_function__(torch.amp._exit_autocast, (), ())
+                    mode.__torch_function__(torch.amp._exit_autocast, (), ())
+                    # This is very important because the above line doesn't actually
+                    # run the exit code, so it ends up swallowing exceptions.
+                    return False
         return False
 
     def __call__(self, func):
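
A rough usage sketch (illustrative only, not taken from this PR) of the contract the fix preserves: an exception raised inside an autocast region reaches the caller instead of being silently swallowed, and after this change the same holds while the region is traced for export.

import torch

try:
    with torch.amp.autocast(device_type="cpu"):
        raise RuntimeError("boom")
except RuntimeError as e:
    # autocast.__exit__ returns False, so the exception propagates normally.
    print(f"caught: {e}")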