Fix autocast for non-strict export (#137495)
Summary:

Add support for autocast in non-strict export, and add tests for the autocast and set_grad nodes produced by export_for_training. In export_for_training we do not wrap the autocast and set_grad regions into a HOP, but the graph should still contain the set_grad_enabled/autocast nodes.

Previously, the `_enter_autocast` and `_exit_autocast` nodes did not show up in the export graph when `strict=False` was used.

- In autocast's enter and exit functions, we dispatch to `PreDispatchTorchFunctionMode.__torch_function__`. If `PreDispatchTorchFunctionMode` is on the function mode stack, the call stack looks like the one below. This is mostly the same call stack as strict mode, except that strict mode enters [here](https://www.internalfb.com/code/fbsource/[0d4f1135cacdb26c6e01d5dce1ce52a15d61ee48]/xplat/caffe2/torch/_dynamo/variables/ctx_manager.py?lines=806).

  ```
  - torch.amp.autocast.__enter__()'s torch.overrides.handle_torch_function
  - torch.fx.experimental.proxy_tensor.TorchFunctionMetadataMode.__torch_function__
  - torch.amp._enter_autocast()'s torch.overrides.handle_torch_function
  - PreDispatchTorchFunctionMode.__torch_function__
  ```

- In `PreDispatchTorchFunctionMode.__torch_function__`, we create the autocast nodes.
- To match the strict-mode behavior, we make the input of the `_exit_autocast` node the corresponding `_enter_autocast` node. This requires maintaining a stack in `PreDispatchTorchFunctionMode`.

Test Plan:

```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_with_autocast
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_with_set_grad
```

Differential Revision: D64016023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137495
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 7ba706c74e
Commit: a47bb4a393
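
As a quick illustration of the behavior this change enables, here is a minimal sketch adapted from the tests added in this PR (the module, input shape, and dtype are illustrative, not taken verbatim from the commit):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # the autocast region should be preserved in the exported graph,
        # now also when exporting with strict=False
        with torch.autocast(device_type="cpu", dtype=torch.float16, enabled=True):
            y = x.sin().sum()
        return y

example = (torch.randn(4, 4),)

# non-strict export: the autocast region is captured and wrapped into the
# wrap_with_autocast higher-order op, matching strict-mode behavior
ep = torch.export.export(M(), example, strict=False)
assert "wrap_with_autocast" in ep.graph_module.code

# export_for_training traces with pre_dispatch=False, so the autocast calls
# stay as plain autocast nodes instead of being wrapped into a HOP
gm = torch.export.export_for_training(M(), example).module()
assert "autocast" in gm.code
```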
```diff
@@ -6788,6 +6788,11 @@ def forward(self, x, b_t, y):
             "torch.ops.higher_order.wrap_with_set_grad_enabled",
             ep.graph_module.code,
         )
+        gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
+        self.assertIn(
+            "set_grad_enabled",
+            gm.code,
+        )
 
     # T203671967
     @testing.expectedFailureRetraceability  # autocast nodes not created after re-tracing
```
```diff
@@ -6799,23 +6804,26 @@ def forward(self, x, b_t, y):
                 ):
                     y = x.sin().sum()
                 with torch.autocast(
-                    device_type="cpu", dtype=torch.float64, enabled=True
+                    device_type="cpu", dtype=torch.float16, enabled=True
                 ):
                     z = y.sin().sum()
                 return z
 
         model = Model()
         ep = export(model, (torch.randn(4, 4),), {})
-        # _export_for_traininig is using pre_dispatch=False
-        # Therefore the autocast calls are not replaced with a hop.
-        # non_strict doesn't have autocast nodes
-        if not is_non_strict_test(self._testMethodName) and not is_training_ir_test(
-            self._testMethodName
-        ):
+        # autocast nodes do not exist after run_decomposition()
+        if not is_training_ir_test(self._testMethodName):
             self.assertIn(
                 "torch.ops.higher_order.wrap_with_autocast",
                 ep.graph_module.code,
             )
+        # _export_for_traininig is using pre_dispatch=False
+        # Therefore the autocast calls are not replaced with a hop.
+        gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
+        self.assertIn(
+            "autocast",
+            gm.code,
+        )
 
     def test_export_as_backend(self):
         def f(x, y):
```
```diff
@@ -355,6 +355,24 @@ class autocast:
         torch.autocast_increment_nesting()
         torch.set_autocast_cache_enabled(self._cache_enabled)
+
+        # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this
+        # API to other functional modes. We only expose to PreDispatchTorchFunctionMode
+        # for preserving autocast in torch.export.export.
+        if torch._C._is_torch_function_mode_enabled():
+            stacks = torch.overrides._get_current_function_mode_stack()
+            for mode in stacks:
+                if isinstance(
+                    mode,
+                    torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
+                ):
+                    args = (
+                        self.device,
+                        self.fast_dtype,
+                        self._enabled,
+                        self._cache_enabled,
+                    )
+                    return mode.__torch_function__(torch.amp._enter_autocast, (), args)
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
         if torch._jit_internal.is_scripting():
             return
```
```diff
@@ -365,6 +383,18 @@ class autocast:
         torch.set_autocast_enabled(self.device, self.prev)
         torch.set_autocast_dtype(self.device, self.prev_fastdtype)
         torch.set_autocast_cache_enabled(self.prev_cache_enabled)
+
+        # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this
+        # API to other functional modes. We only expose to PreDispatchTorchFunctionMode
+        # for preserving autocast in torch.export.export.
+        if torch._C._is_torch_function_mode_enabled():
+            stacks = torch.overrides._get_current_function_mode_stack()
+            for mode in stacks:
+                if isinstance(
+                    mode,
+                    torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
+                ):
+                    return mode.__torch_function__(torch.amp._exit_autocast, (), ())
         return False
 
     def __call__(self, func):
```
```diff
@@ -1251,10 +1251,14 @@ _temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
 class PreDispatchTorchFunctionMode(TorchFunctionMode):
     def __init__(self, tracer: _ProxyTracer) -> None:
         self.tracer = tracer
+        # The input to torch.amp.autocast_mode._exit_autocast graph node should be the
+        # enter_autocast node. So we have to save the enter autocast node here, and assign it
+        # to the exit_autocast call_function node.
+        self.enter_autocast_nodes: List[torch.fx.Node] = []
 
     def __torch_function__(
         self,
-        func: OpOverload,
+        func: Union[OpOverload, Callable],
         types: Tuple[torch._C._TensorMeta, ...],
         args: Tuple[object, ...] = (),
         kwargs: Optional[Dict[str, object]] = None,
```
```diff
@@ -1265,7 +1269,12 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
         # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
         # instead of hardcoding it here.
         # T203648563
+        if func == torch.amp.autocast_mode._exit_autocast:
+            enter_node = self.enter_autocast_nodes.pop()
+            args = (enter_node,)
         node = self.tracer.create_node("call_function", func, args, {})  # type: ignore[arg-type]
+        if func == torch.amp.autocast_mode._enter_autocast:
+            self.enter_autocast_nodes.append(node)
         if func in [
             torch._C._set_grad_enabled,
             torch.amp.autocast_mode._enter_autocast,
```
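
To make the stack-based pairing concrete, below is a small conceptual sketch (not code from this PR) that emits `_enter_autocast`/`_exit_autocast` call_function nodes into a plain `torch.fx.Graph` using the same LIFO rule the mode implements: each `_exit_autocast` node takes the most recent unmatched `_enter_autocast` node as its sole input.

```python
import torch
import torch.fx as fx

# conceptual sketch of the pairing logic kept in PreDispatchTorchFunctionMode:
# enter nodes are pushed onto a stack, and each exit node pops its partner
graph = fx.Graph()
enter_stack: list[fx.Node] = []

def emit(func, args=()):
    if func is torch.amp.autocast_mode._exit_autocast:
        # the exit node consumes the matching enter node as its only input
        args = (enter_stack.pop(),)
    node = graph.call_function(func, args)
    if func is torch.amp.autocast_mode._enter_autocast:
        enter_stack.append(node)
    return node

# two nested autocast regions (device, dtype, enabled, cache_enabled are illustrative)
outer = emit(torch.amp.autocast_mode._enter_autocast, ("cpu", torch.float16, True, False))
inner = emit(torch.amp.autocast_mode._enter_autocast, ("cpu", torch.bfloat16, True, False))
exit_inner = emit(torch.amp.autocast_mode._exit_autocast)
exit_outer = emit(torch.amp.autocast_mode._exit_autocast)

# LIFO pairing: the inner exit points at the inner enter, the outer exit at the outer enter
assert exit_inner.args == (inner,)
assert exit_outer.args == (outer,)
```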