mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Fix autocast context manager when there is an exception (#159565)
Summary: When an exception occurs inside a context manager, we need to either return False or properly propagate the exception via __exit__(exc_type, exc_val). Previously, while tracing, we did not actually run the exit node, so we ended up swallowing the exception in a very weird way, as outlined in https://github.com/pytorch/pytorch/issues/153202. This PR fixes it. Test Plan: new test case. Rollback Plan: Differential Revision: D79348382. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159565 Approved by: https://github.com/zou3519, https://github.com/yushangdi
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							83e2ea8135
						
					
				
				
					commit
					2ac45c2752
				
			| @ -14946,6 +14946,51 @@ class GraphModule(torch.nn.Module): | ||||
|         self.assertEqual(x.sin(), ep.module()(x)) | ||||
|         pytree._deregister_pytree_node(torch.FunctionSchema) | ||||
|  | ||||
|     @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") | ||||
|     def test_exception(self): | ||||
|         class Model(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) | ||||
|                 self.register_buffer("buffer", torch.ones(4, 4)) | ||||
|                 self.register_buffer("param", torch.ones(4, 4)) | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 token_ids = torch.randint(0, 10, (4,), device=x.device) | ||||
|                 embedded = self.embedding(token_ids).sum() | ||||
|                 return self.buffer.sum() + self.param.sum() + x.sum() + embedded | ||||
|  | ||||
|         class BarModel(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.mod = Model() | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 if "cuda" in str(x.device): | ||||
|                     mod = self.mod.to(x.device) | ||||
|                     return mod(x) | ||||
|                 else: | ||||
|                     return x.sum() | ||||
|  | ||||
|         class BarBar(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.mod = BarModel() | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 with torch.amp.autocast(device_type="cuda"): | ||||
|                     y = self.mod(x) | ||||
|                 return y | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             with self.assertRaisesRegex(RuntimeError, "Couldn't swap Embedding.weight"): | ||||
|                 _ = torch.export.export( | ||||
|                     BarBar(), | ||||
|                     (), | ||||
|                     {"x": torch.randn(4, 4, 4, device="cuda")}, | ||||
|                     strict=False, | ||||
|                 ).module() | ||||
|  | ||||
|     def test_export_for_training_with_state_dict_hooks(self): | ||||
|         def _state_dict_pre_hook(mod, prefix, keep_vars): | ||||
|             mod._buffers["test"] = torch.Tensor([1]) | ||||
|  | ||||
| @ -397,7 +397,10 @@ class autocast: | ||||
|                         self._enabled, | ||||
|                         self._cache_enabled, | ||||
|                     ) | ||||
|                     return mode.__torch_function__(torch.amp._enter_autocast, (), args) | ||||
|                     mode.__torch_function__(torch.amp._enter_autocast, (), args) | ||||
|                     return self | ||||
|  | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override] | ||||
|         if torch._jit_internal.is_scripting(): | ||||
| @ -420,7 +423,10 @@ class autocast: | ||||
|                     mode, | ||||
|                     torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, | ||||
|                 ): | ||||
|                     return mode.__torch_function__(torch.amp._exit_autocast, (), ()) | ||||
|                     mode.__torch_function__(torch.amp._exit_autocast, (), ()) | ||||
|                     # This is very important because the above line actually doesn't | ||||
|                     # run exit code, so it ends up swallowing exceptions. | ||||
|                     return False | ||||
|         return False | ||||
|  | ||||
|     def __call__(self, func): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user