[export] build the infra to rollout predispatch export. (#122326)

Test Plan:
fbcode:caffe2/test/quantization:test_quantization
fbcode:bolt/nn/executorch/backends/tests:qnn_test
fbcode:on_device_ai/helios/compiler_tests/...
fbcode:pyspeech/tests:pyspeech_utils_test_oss
fbcode:caffe2/test:quantization_pt2e_qat
fbcode:on_device_ai/Assistant/Jarvis/tests:test_custom_ops
fbcode:modai/test:test_modai
fbcode:executorch/exir/backend/test:test_partitioner

Differential Revision: D55133846

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122326
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen
2024-03-22 00:55:10 +00:00
committed by PyTorch MergeBot
parent 4b535906aa
commit b1fa0ce4aa
2 changed files with 58 additions and 50 deletions

View File

@ -127,71 +127,75 @@ def capture_pre_autograd_graph(
"""
from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
from torch._utils_internal import export_api_rollout_check
assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
if kwargs is None:
kwargs = {}
# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
decomp_table = {
op: op.decompose
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
assume_static_by_default=True,
tracing_mode="symbolic",
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
_log_export_usage=False,
)(
*args,
**kwargs,
)[0]
if export_api_rollout_check():
module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module()
else:
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
_, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
m.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
decomp_table = {
op: op.decompose
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
assume_static_by_default=True,
tracing_mode="symbolic",
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
_log_export_usage=False,
)(
*args,
**kwargs,
)[0]
if isinstance(f, torch.nn.Module):
from torch.export._trace import _restore_state_dict
_restore_state_dict(f, m)
_, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
range_constraints = _process_constraints(fake_mode, m, 0, flat_args)
m.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}
module = _create_stateful_graph_module(
m,
range_constraints=range_constraints,
)
if isinstance(f, torch.nn.Module):
from torch.export._trace import _restore_state_dict
_restore_state_dict(f, m)
error_message = \
"""
Calling train() or eval() is not supported for exported models.
Alternatively, you may override these methods to do custom user behavior as follows:
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
range_constraints = _process_constraints(fake_mode, m, 0, flat_args)
def _my_train(self, mode: bool = True):
...
module = _create_stateful_graph_module(
m,
range_constraints=range_constraints,
)
def _my_eval(self):
...
error_message = \
"""
Calling train() or eval() is not supported for exported models.
Alternatively, you may override these methods to do custom user behavior as follows:
model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
"""
def _my_train(self, mode: bool = True):
...
def _my_eval(self):
...
model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
"""
def _train(self, mode: bool = True):
raise NotImplementedError(error_message)

View File

@ -100,6 +100,10 @@ def log_torchscript_usage(api: str):
return
def export_api_rollout_check() -> bool:
return False
def justknobs_check(name: str) -> bool:
"""
This function can be used to killswitch functionality in FB prod,