[training ir migration] Fix ReorderConvertTest (#134010)

Summary:
Change ReorderConvertTest to work with the new `capture_pre_autograd_graph` implementation using D61175223.

Note that now `ReorderConvertTest` doesn't work with the old `capture_pre_autograd_graph` anymore.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/passes/tests:optimize_test -- -r ReorderConvertTest
```

Differential Revision: D61507772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134010
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Shangdi Yu
2024-08-21 04:48:43 +00:00
committed by PyTorch MergeBot
parent e8fc1e0118
commit 8337b4d96e

View File

@ -107,7 +107,6 @@ def capture_pre_autograd_graph(
"""
from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
from torch._utils_internal import export_api_rollout_check
from torch._export.non_strict_utils import make_constraints
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._unlift import _create_stateful_graph_module
@ -123,79 +122,72 @@ def capture_pre_autograd_graph(
if kwargs is None:
kwargs = {}
if export_api_rollout_check():
@lru_cache
def print_export_warning():
log.warning("Using torch.export._trace._export")
print_export_warning()
module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module()
else:
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
decomp_table = {
op: op.decompose
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
decomp_table = {
op: op.decompose
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
assume_static_by_default=True,
tracing_mode="symbolic",
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
_log_export_usage=False,
)(
*args,
**kwargs,
)[0]
_, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
m.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
assume_static_by_default=True,
tracing_mode="symbolic",
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
_log_export_usage=False,
)(
*args,
**kwargs,
)[0]
_, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
if isinstance(f, torch.nn.Module):
from torch.export._trace import _restore_state_dict
_restore_state_dict(f, m)
m.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
combined_args = _combine_args(f, args, kwargs)
range_constraints = make_constraints(
fake_mode,
m,
combined_args,
dynamic_shapes,
0,
)
if isinstance(f, torch.nn.Module):
from torch.export._trace import _restore_state_dict
_restore_state_dict(f, m)
module = _create_stateful_graph_module(
m,
range_constraints=range_constraints,
)
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
combined_args = _combine_args(f, args, kwargs)
range_constraints = make_constraints(
fake_mode,
m,
combined_args,
dynamic_shapes,
0,
)
error_message = \
"""
Calling train() or eval() is not supported for exported models.
Alternatively, you may override these methods to do custom user behavior as follows:
module = _create_stateful_graph_module(
m,
range_constraints=range_constraints,
)
def _my_train(self, mode: bool = True):
...
error_message = \
"""
Calling train() or eval() is not supported for exported models.
Alternatively, you may override these methods to do custom user behavior as follows:
def _my_eval(self):
...
def _my_train(self, mode: bool = True):
...
def _my_eval(self):
...
model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
"""
model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
"""
def _train(self, mode: bool = True):
raise NotImplementedError(error_message)