Fix autocast context manager when there is exception (#159565)
Summary: When an exception occurs inside the context manager, __exit__(exc_type, exc_val, exc_tb) needs to return False (or re-raise) so the exception properly propagates. Previously, while tracing, we did not actually run the exit node, so the exception was swallowed in a very confusing way, as outlined in https://github.com/pytorch/pytorch/issues/153202. This PR fixes that.

Test Plan: new test case

Rollback Plan:

Differential Revision: D79348382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159565
Approved by: https://github.com/zou3519, https://github.com/yushangdi
committed by: PyTorch MergeBot
parent: 83e2ea8135
commit: 2ac45c2752
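Before the diff, a quick refresher on the contract the summary relies on. This is a minimal, PyTorch-independent sketch (the class names are invented for illustration): Python treats a falsy return from __exit__ as "let the exception propagate" and a truthy return as "suppress it", which is exactly how an __exit__ whose real exit logic never runs can silently swallow errors.

class Swallowing:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Truthy return value: Python suppresses the in-flight exception.
        return exc_type is not None


class Propagating:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Falsy return value: the exception propagates to the caller.
        return False


with Swallowing():
    raise RuntimeError("never seen by the caller")

try:
    with Propagating():
        raise RuntimeError("visible")
except RuntimeError as err:
    print(err)  # prints "visible"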
@@ -14946,6 +14946,51 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(x.sin(), ep.module()(x))
         pytree._deregister_pytree_node(torch.FunctionSchema)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    def test_exception(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
+                self.register_buffer("buffer", torch.ones(4, 4))
+                self.register_buffer("param", torch.ones(4, 4))
+
+            def forward(self, x):
+                token_ids = torch.randint(0, 10, (4,), device=x.device)
+                embedded = self.embedding(token_ids).sum()
+                return self.buffer.sum() + self.param.sum() + x.sum() + embedded
+
+        class BarModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod = Model()
+
+            def forward(self, x):
+                if "cuda" in str(x.device):
+                    mod = self.mod.to(x.device)
+                    return mod(x)
+                else:
+                    return x.sum()
+
+        class BarBar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod = BarModel()
+
+            def forward(self, x):
+                with torch.amp.autocast(device_type="cuda"):
+                    y = self.mod(x)
+                return y
+
+        with torch.no_grad():
+            with self.assertRaisesRegex(RuntimeError, "Couldn't swap Embedding.weight"):
+                _ = torch.export.export(
+                    BarBar(),
+                    (),
+                    {"x": torch.randn(4, 4, 4, device="cuda")},
+                    strict=False,
+                ).module()
+
     def test_export_for_training_with_state_dict_hooks(self):
         def _state_dict_pre_hook(mod, prefix, keep_vars):
             mod._buffers["test"] = torch.Tensor([1])
@@ -397,7 +397,10 @@ class autocast:
                         self._enabled,
                         self._cache_enabled,
                     )
-                    return mode.__torch_function__(torch.amp._enter_autocast, (), args)
+                    mode.__torch_function__(torch.amp._enter_autocast, (), args)
+                    return self
+
         return self
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
         if torch._jit_internal.is_scripting():
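For context on the __enter__ hunk above: the old code returned the result of mode.__torch_function__(...), so "with autocast(...) as ac" would bind ac to whatever the mode returned rather than to the autocast object itself. Here is a hedged, PyTorch-independent sketch of the fixed pattern (the Ctx class and the hook stand-in are invented for illustration):

class Ctx:
    def __enter__(self):
        _hook_result = object()  # stand-in for a hook call such as mode.__torch_function__(...)
        return self              # return the manager itself, not the hook's return value

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False             # never suppress exceptions raised in the with-block

with Ctx() as c:
    assert isinstance(c, Ctx)    # "as" binds to the context manager, as callers expect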
@@ -420,7 +423,10 @@ class autocast:
                     mode,
                     torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
                 ):
-                    return mode.__torch_function__(torch.amp._exit_autocast, (), ())
+                    mode.__torch_function__(torch.amp._exit_autocast, (), ())
+                    # This is very important: the above line doesn't actually run
+                    # the exit code, so returning its result would swallow exceptions.
+                    return False
         return False
 
     def __call__(self, func):
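Taken together, the two autocast hunks mean an exception raised inside an autocast region is no longer swallowed when a pre-dispatch torch-function mode is active during tracing. A small sketch of the expected behavior in eager mode (device_type="cpu" is chosen here only so the snippet runs without a GPU; it is not part of the patch):

import torch

try:
    with torch.amp.autocast(device_type="cpu"):
        raise RuntimeError("boom")
except RuntimeError as err:
    print("propagated:", err)  # __exit__ returns False, so the error reaches the caller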