[training ir migration] Fix ReorderConvertTest (#134010)
Summary: Change ReorderConvertTest to work with the new `capture_pre_autograd_graph` implementation using D61175223. Note that `ReorderConvertTest` no longer works with the old `capture_pre_autograd_graph`.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/passes/tests:optimize_test -- -r ReorderConvertTest
```

Differential Revision: D61507772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134010
Approved by: https://github.com/tugsbayasgalan
committed by PyTorch MergeBot
parent e8fc1e0118
commit 8337b4d96e
@@ -107,7 +107,6 @@ def capture_pre_autograd_graph(
    """
    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
    from torch._utils_internal import export_api_rollout_check
    from torch._export.non_strict_utils import make_constraints
    from torch._subclasses.functional_tensor import FunctionalTensor
    from torch.export._unlift import _create_stateful_graph_module
@@ -123,79 +122,72 @@ def capture_pre_autograd_graph(
    if kwargs is None:
        kwargs = {}

    if export_api_rollout_check():
        @lru_cache
        def print_export_warning():
            log.warning("Using torch.export._trace._export")
        print_export_warning()
        module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module()
    else:
        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

        # Do not decompose dropout for exported models, because in eval mode the dropout
        # op disappears from the graph, which makes it difficult to switch to train mode.
        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
        decomp_table = {
            op: op.decompose
            for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
            if op != torch.ops.aten.dropout.default
        }
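As an aside on the dropout carve-out above: keeping `aten.dropout` out of `decomp_table` means the exported graph retains an explicit dropout node that later passes can rewrite. A minimal sketch of what that preserves, not part of this diff; the example module is hypothetical:

```python
# Illustrative only: shows that aten.dropout survives export because it is
# excluded from decomp_table above.
import torch
from torch._export import capture_pre_autograd_graph

class WithDropout(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(x)

gm = capture_pre_autograd_graph(WithDropout(), (torch.randn(8, 8),))
# The undecomposed dropout node is still visible in the graph, so a later
# pass (e.g. torch.ao.quantization.move_exported_model_to_eval, if available)
# can flip it between train and eval behavior after export.
assert any(
    node.target is torch.ops.aten.dropout.default
    for node in gm.graph.nodes
    if node.op == "call_function"
)
```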
        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
            m = torch._dynamo.export(
                f,
                dynamic_shapes=dynamic_shapes,
                assume_static_by_default=True,
                tracing_mode="symbolic",
                decomposition_table=decomp_table,
                pre_dispatch=True,
                aten_graph=True,
                _log_export_usage=False,
            )(
                *args,
                **kwargs,
            )[0]

            _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

            m.meta["inline_constraints"] = {
                k: v
                for k, v in fake_mode.shape_env.var_to_range.items()
                if re.match(r"^[if]\d+$", str(k))
            }
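The `^[if]\d+$` filter above keeps only shape-env symbols whose names look like `i0` or `f1` (the data-dependent symbols minted during tracing) and drops ordinary input-shape symbols such as `s0`. A small self-contained sketch of the filter, with made-up ranges:

```python
import re

# Made-up shape-env contents; only the key naming convention matters here.
var_to_range = {
    "s0": (2, 1024),    # input-shape symbol: filtered out
    "i0": (0, 100),     # data-dependent integer symbol: kept
    "f1": (0.0, 1.0),   # data-dependent float symbol: kept
}

inline_constraints = {
    k: v for k, v in var_to_range.items() if re.match(r"^[if]\d+$", str(k))
}
assert set(inline_constraints) == {"i0", "f1"}
```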
            if isinstance(f, torch.nn.Module):
                from torch.export._trace import _restore_state_dict
                _restore_state_dict(f, m)
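`_restore_state_dict` re-attaches the original module's parameters and buffers under their eager-mode names. A hedged sketch of the observable effect, with a hypothetical example module:

```python
# Illustrative only. After export, parameter names are expected to match the
# eager module's fully qualified names because _restore_state_dict copies
# them over.
import torch
from torch._export import capture_pre_autograd_graph

class Tiny(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

eager = Tiny()
gm = capture_pre_autograd_graph(eager, (torch.randn(2, 4),))
# Expect ['lin.bias', 'lin.weight'] on both sides.
print(sorted(gm.state_dict().keys()))
print(sorted(eager.state_dict().keys()))
```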
            flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
            combined_args = _combine_args(f, args, kwargs)
            range_constraints = make_constraints(
                fake_mode,
                m,
                combined_args,
                dynamic_shapes,
                0,
            )

            module = _create_stateful_graph_module(
                m,
                range_constraints=range_constraints,
            )
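`make_constraints` is where user-provided `dynamic_shapes` turn into the `range_constraints` handed to `_create_stateful_graph_module` above. A hedged usage sketch; the module, shapes, and bounds are made up:

```python
# Illustrative only: how a dynamic_shapes spec reaches make_constraints.
import torch
from torch._export import capture_pre_autograd_graph
from torch.export import Dim

class Double(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x * 2

batch = Dim("batch", min=2, max=64)  # symbolic batch dimension with bounds
gm = capture_pre_autograd_graph(
    Double(),
    (torch.randn(4, 8),),
    dynamic_shapes={"x": {0: batch}},  # dim 0 of x may vary in [2, 64]
)
# Internally, make_constraints folds these bounds into range_constraints,
# which are then passed to _create_stateful_graph_module.
```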
    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to do custom user behavior as follows:

        def _my_train(self, mode: bool = True):
            ...

        def _my_eval(self):
            ...

        model.train = types.MethodType(_my_train, model)
        model.eval = types.MethodType(_my_eval, model)
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)
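The `error_message` above doubles as a recipe. A runnable version of it, where the `_my_train`/`_my_eval` names come from the message itself and the example module is hypothetical:

```python
# Runnable version of the override recipe embedded in error_message above.
import types
import torch
from torch._export import capture_pre_autograd_graph

class AddOne(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x + 1

model = capture_pre_autograd_graph(AddOne(), (torch.randn(2),))

def _my_train(self, mode: bool = True):
    ...  # custom behavior; the default _train above raises NotImplementedError

def _my_eval(self):
    ...

model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
model.eval()  # no longer raises
```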