Fix autocast for non-strict export (#137495)

Summary:

Add tests covering the autocast and set_grad nodes produced by export_for_training. In export_for_training we do not wrap the autocast and set_grad regions into a HOP, but the graph should still contain the set_grad_enabled/autocast nodes.
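
For illustration, a minimal sketch of what the new `export_for_training` checks look like; the model below (combining a `no_grad` region and an `autocast` region) is illustrative rather than the exact test model:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        with torch.no_grad():
            y = x.sin().sum()
        with torch.autocast(device_type="cpu", dtype=torch.float16, enabled=True):
            z = y.sin().sum()
        return z


# export_for_training keeps the context managers as plain graph nodes
# instead of wrapping the regions into higher-order ops.
gm = torch.export.export_for_training(Model(), (torch.randn(4, 4),)).module()
assert "set_grad_enabled" in gm.code
assert "autocast" in gm.code
```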

Add support for autocast in non-strict export. Previously, `_enter_autocast` and `_exit_autocast` nodes did not show up in the export graph when `strict=False` was used; a usage sketch follows the implementation notes below.

- In autocast's enter and exit functions, we dispatch to `PreDispatchTorchFunctionMode.__torch_function__`.
  If `PreDispatchTorchFunctionMode` is in the function mode stack, the call stack looks like the one below. This is mostly the same call stack as strict mode, except that strict mode enters [here](https://www.internalfb.com/code/fbsource/[0d4f1135cacdb26c6e01d5dce1ce52a15d61ee48]/xplat/caffe2/torch/_dynamo/variables/ctx_manager.py?lines=806).
```
- torch.amp.autocast.__enter__()'s torch.overrides.handle_torch_function
- torch.fx.experimental.proxy_tensor.TorchFunctionMetadataMode.__torch_function__
- torch.amp._enter_autocast()'s torch.overrides.handle_torch_function
- PreDispatchTorchFunctionMode.__torch_function__
```
- In `PreDispatchTorchFunctionMode.__torch_function__`, we create the autocast nodes.
- To match the strict-mode behavior, the input node of each `_exit_autocast` node must be the corresponding `_enter_autocast` node. This requires maintaining a stack in `PreDispatchTorchFunctionMode`.
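
As a rough usage sketch of the new non-strict behavior (the model body approximates the one in the test), the autocast region is now preserved as a `wrap_with_autocast` HOP:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        y = x.sin().sum()
        with torch.autocast(device_type="cpu", dtype=torch.float16, enabled=True):
            z = y.sin().sum()
        return z


# Before this change the autocast region was silently dropped under
# strict=False; now it shows up as a wrap_with_autocast higher-order op.
ep = torch.export.export(Model(), (torch.randn(4, 4),), strict=False)
assert "torch.ops.higher_order.wrap_with_autocast" in ep.graph_module.code
```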

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_export_with_autocast
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_export_with_set_grad
```

Differential Revision: D64016023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137495
Approved by: https://github.com/bdhirsh
Author: Shangdi Yu
Date: 2024-10-16 17:38:57 +00:00
Committed by: PyTorch MergeBot
Parent: 7ba706c74e
Commit: a47bb4a393
3 changed files with 55 additions and 8 deletions


@@ -6788,6 +6788,11 @@ def forward(self, x, b_t, y):
"torch.ops.higher_order.wrap_with_set_grad_enabled",
ep.graph_module.code,
)
gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
self.assertIn(
"set_grad_enabled",
gm.code,
)
# T203671967
@testing.expectedFailureRetraceability # autocast nodes not created after re-tracing
@@ -6799,23 +6804,26 @@ def forward(self, x, b_t, y):
):
y = x.sin().sum()
with torch.autocast(
device_type="cpu", dtype=torch.float64, enabled=True
device_type="cpu", dtype=torch.float16, enabled=True
):
z = y.sin().sum()
return z
model = Model()
ep = export(model, (torch.randn(4, 4),), {})
# _export_for_traininig is using pre_dispatch=False
# Therefore the autocast calls are not replaced with a hop.
# non_strict doesn't have autocast nodes
if not is_non_strict_test(self._testMethodName) and not is_training_ir_test(
self._testMethodName
):
# autocast nodes do not exist after run_decomposition()
if not is_training_ir_test(self._testMethodName):
self.assertIn(
"torch.ops.higher_order.wrap_with_autocast",
ep.graph_module.code,
)
# _export_for_traininig is using pre_dispatch=False
# Therefore the autocast calls are not replaced with a hop.
gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
self.assertIn(
"autocast",
gm.code,
)
def test_export_as_backend(self):
def f(x, y):


@@ -355,6 +355,24 @@ class autocast:
torch.autocast_increment_nesting()
torch.set_autocast_cache_enabled(self._cache_enabled)
# only dispatch to PreDispatchTorchFunctionMode to avoid exposing this
# API to other functional modes. We only expose to PreDispatchTorchFunctionMode
# for preserving autocast in torch.export.export.
if torch._C._is_torch_function_mode_enabled():
stacks = torch.overrides._get_current_function_mode_stack()
for mode in stacks:
if isinstance(
mode,
torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
):
args = (
self.device,
self.fast_dtype,
self._enabled,
self._cache_enabled,
)
return mode.__torch_function__(torch.amp._enter_autocast, (), args)
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
if torch._jit_internal.is_scripting():
return
@@ -365,6 +383,18 @@ class autocast:
torch.set_autocast_enabled(self.device, self.prev)
torch.set_autocast_dtype(self.device, self.prev_fastdtype)
torch.set_autocast_cache_enabled(self.prev_cache_enabled)
# only dispatch to PreDispatchTorchFunctionMode to avoid exposing this
# API to other functional modes. We only expose to PreDispatchTorchFunctionMode
# for preserving autocast in torch.export.export.
if torch._C._is_torch_function_mode_enabled():
stacks = torch.overrides._get_current_function_mode_stack()
for mode in stacks:
if isinstance(
mode,
torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
):
return mode.__torch_function__(torch.amp._exit_autocast, (), ())
return False
def __call__(self, func):


@@ -1251,10 +1251,14 @@ _temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager
class PreDispatchTorchFunctionMode(TorchFunctionMode):
def __init__(self, tracer: _ProxyTracer) -> None:
self.tracer = tracer
# The input to torch.amp.autocast_mode._exit_autocast graph node should be the
# enter_autocast node. So we have to save the enter autocast node here, and assign it
# to the exit_autocast call_function node.
self.enter_autocast_nodes: List[torch.fx.Node] = []
def __torch_function__(
self,
func: OpOverload,
func: Union[OpOverload, Callable],
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
@@ -1265,7 +1269,12 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
# instead of hardcoding it here.
# T203648563
if func == torch.amp.autocast_mode._exit_autocast:
enter_node = self.enter_autocast_nodes.pop()
args = (enter_node,)
node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type]
if func == torch.amp.autocast_mode._enter_autocast:
self.enter_autocast_nodes.append(node)
if func in [
torch._C._set_grad_enabled,
torch.amp.autocast_mode._enter_autocast,