Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[export] build the infra to rollout predispatch export. (#122326)
Test Plan:
fbcode:caffe2/test/quantization:test_quantization
fbcode:bolt/nn/executorch/backends/tests:qnn_test
fbcode:on_device_ai/helios/compiler_tests/...
fbcode:pyspeech/tests:pyspeech_utils_test_oss
fbcode:caffe2/test:quantization_pt2e_qat
fbcode:on_device_ai/Assistant/Jarvis/tests:test_custom_ops
fbcode:modai/test:test_modai
fbcode:executorch/exir/backend/test:test_partitioner

Differential Revision: D55133846

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122326
Approved by: https://github.com/tugsbayasgalan
commit b1fa0ce4aa, committed by PyTorch MergeBot
parent 4b535906aa
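The change below is transparent to callers: capture_pre_autograd_graph keeps its signature, and the new internal export_api_rollout_check() gate only decides whether the graph comes from the legacy torch._dynamo.export path or from torch.export._trace._export(..., pre_dispatch=True). A minimal caller-side sketch (the toy module and example inputs are illustrative, not taken from this PR):

    import torch
    from torch._export import capture_pre_autograd_graph

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    # Returns an nn.Module backed by a pre-dispatch ATen graph; which export
    # backend produced it is decided internally by export_api_rollout_check().
    example_inputs = (torch.randn(1, 3),)
    gm = capture_pre_autograd_graph(M().eval(), example_inputs)
    print(gm(*example_inputs).shape)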
@@ -127,71 +127,75 @@ def capture_pre_autograd_graph(
     """
     from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
 
-    log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
+    from torch._utils_internal import export_api_rollout_check
 
     assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
 
     if kwargs is None:
         kwargs = {}
 
-    # Do not decompose dropout for exported models, because in eval mode the dropout
-    # op disappears from the graph, which makes it difficult to switch to train mode.
-    # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
-    decomp_table = {
-        op: op.decompose
-        for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
-        if op != torch.ops.aten.dropout.default
-    }
-    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
-        m = torch._dynamo.export(
-            f,
-            dynamic_shapes=dynamic_shapes,
-            assume_static_by_default=True,
-            tracing_mode="symbolic",
-            decomposition_table=decomp_table,
-            pre_dispatch=True,
-            aten_graph=True,
-            _log_export_usage=False,
-        )(
-            *args,
-            **kwargs,
-        )[0]
-
-        _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
-
-        m.meta["inline_constraints"] = {
-            k: v
-            for k, v in fake_mode.shape_env.var_to_range.items()
-            if re.match(r"^[if]\d+$", str(k))
-        }
-
-        if isinstance(f, torch.nn.Module):
-            from torch.export._trace import _restore_state_dict
-            _restore_state_dict(f, m)
-
-        flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
-        range_constraints = _process_constraints(fake_mode, m, 0, flat_args)
-
-        module = _create_stateful_graph_module(
-            m,
-            range_constraints=range_constraints,
-        )
-
-        error_message = \
-            """
-            Calling train() or eval() is not supported for exported models.
-            Alternatively, you may override these methods to do custom user behavior as follows:
-
-                def _my_train(self, mode: bool = True):
-                    ...
-
-                def _my_eval(self):
-                    ...
-
-                model.train = types.MethodType(_my_train, model)
-                model.eval = types.MethodType(_my_eval, model)
-            """
+    if export_api_rollout_check():
+        module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module()
+    else:
+        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
+
+        # Do not decompose dropout for exported models, because in eval mode the dropout
+        # op disappears from the graph, which makes it difficult to switch to train mode.
+        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
+        decomp_table = {
+            op: op.decompose
+            for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
+            if op != torch.ops.aten.dropout.default
+        }
+        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
+            m = torch._dynamo.export(
+                f,
+                dynamic_shapes=dynamic_shapes,
+                assume_static_by_default=True,
+                tracing_mode="symbolic",
+                decomposition_table=decomp_table,
+                pre_dispatch=True,
+                aten_graph=True,
+                _log_export_usage=False,
+            )(
+                *args,
+                **kwargs,
+            )[0]
+
+            _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
+
+            m.meta["inline_constraints"] = {
+                k: v
+                for k, v in fake_mode.shape_env.var_to_range.items()
+                if re.match(r"^[if]\d+$", str(k))
+            }
+
+            if isinstance(f, torch.nn.Module):
+                from torch.export._trace import _restore_state_dict
+                _restore_state_dict(f, m)
+
+            flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
+            range_constraints = _process_constraints(fake_mode, m, 0, flat_args)
+
+            module = _create_stateful_graph_module(
+                m,
+                range_constraints=range_constraints,
+            )
+
+            error_message = \
+                """
+                Calling train() or eval() is not supported for exported models.
+                Alternatively, you may override these methods to do custom user behavior as follows:
+
+                    def _my_train(self, mode: bool = True):
+                        ...
+
+                    def _my_eval(self):
+                        ...
+
+                    model.train = types.MethodType(_my_train, model)
+                    model.eval = types.MethodType(_my_eval, model)
+                """
 
         def _train(self, mode: bool = True):
             raise NotImplementedError(error_message)
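For reference, the train()/eval() override pattern that the error message above tells users to apply looks like this in practice (illustrative sketch; gm stands for the module returned by capture_pre_autograd_graph):

    import types

    def _my_train(self, mode: bool = True):
        # custom behavior instead of the unsupported default train()
        return self

    def _my_eval(self):
        # custom behavior instead of the unsupported default eval()
        return self

    gm.train = types.MethodType(_my_train, gm)
    gm.eval = types.MethodType(_my_eval, gm)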
@@ -100,6 +100,10 @@ def log_torchscript_usage(api: str):
     return
 
 
+def export_api_rollout_check() -> bool:
+    return False
+
+
 def justknobs_check(name: str) -> bool:
     """
     This function can be used to killswitch functionality in FB prod,
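With the OSS stub above, export_api_rollout_check() returns False, so open-source builds keep the existing torch._dynamo.export path; internal builds can substitute their own implementation of this hook to turn the new predispatch export path on. A sketch of exercising the new branch locally by patching the hook in a test (the test itself is hypothetical, not part of this PR):

    from unittest import mock

    import torch
    from torch._export import capture_pre_autograd_graph

    def test_rollout_path():
        m = torch.nn.Linear(4, 4).eval()
        # capture_pre_autograd_graph imports the hook from torch._utils_internal
        # at call time, so patching the module attribute flips the branch.
        with mock.patch("torch._utils_internal.export_api_rollout_check", return_value=True):
            gm = capture_pre_autograd_graph(m, (torch.randn(2, 4),))
        assert isinstance(gm, torch.nn.Module)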