mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI Minifier] Save EP instead of graphs (#141159)
Summary: `repro.py` can have nested graph modules, e.g. ``` class Repro(torch.nn.Module): def __init__(self) -> None: super().__init__() self.true_graph_0 = GraphModule() def forward(self): true_graph_0 = self.true_graph_0 return (true_graph_0,) ``` So dumping the string doesn’t always work. So, 1) we use exported program in repro.py instead 2) we still dump the graph module string, but only put it in comments We also added two flags to `minifier_launcher.py` - `minifier-export-mode`: whether strict or non-strict export is used in the minifier - `skip-export-error`: intermediate graphs that cannot be exported will be skipped. Test Plan: ``` buck2 run fbcode//caffe2/test/inductor:minifier_utils_cpu -- -r string python test/inductor/test_minifier.py ``` Differential Revision: D66175257 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141159 Approved by: https://github.com/henrylhtsang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ca9813ea14
commit
f28bac76f5
@ -6,6 +6,7 @@ import torch._dynamo.config as dynamo_config
|
||||
import torch._inductor.config as inductor_config
|
||||
from torch._dynamo.test_minifier_common import MinifierTestBase
|
||||
from torch._inductor import config
|
||||
from torch.export import load as export_load
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_JETSON,
|
||||
IS_MACOS,
|
||||
@ -214,16 +215,15 @@ with torch.no_grad():
|
||||
)
|
||||
def test_aoti_cpu_compile_error(self):
|
||||
res = self._test_aoti("cpu", "CppCompileError")
|
||||
ep_file_path = res.get_exported_program_path()
|
||||
gm = export_load(ep_file_path).module()
|
||||
self.assertExpectedInline(
|
||||
res.repro_module(),
|
||||
str(gm.code).strip(),
|
||||
"""\
|
||||
class Repro(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, linear):
|
||||
relu = torch.ops.aten.relu.default(linear); linear = None
|
||||
return (relu,)""",
|
||||
def forward(self, linear):
|
||||
linear, = fx_pytree.tree_flatten_spec(([linear], {}), self._in_spec)
|
||||
relu = torch.ops.aten.relu.default(linear); linear = None
|
||||
return pytree.tree_unflatten((relu,), self._out_spec)""",
|
||||
)
|
||||
|
||||
@requires_gpu
|
||||
@ -236,16 +236,15 @@ class Repro(torch.nn.Module):
|
||||
)
|
||||
def test_aoti_gpu_compile_error(self):
|
||||
res = self._test_aoti(GPU_TYPE, "SyntaxError")
|
||||
ep_file_path = res.get_exported_program_path()
|
||||
gm = export_load(ep_file_path).module()
|
||||
self.assertExpectedInline(
|
||||
res.repro_module(),
|
||||
str(gm.code).strip(),
|
||||
"""\
|
||||
class Repro(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, linear):
|
||||
relu = torch.ops.aten.relu.default(linear); linear = None
|
||||
return (relu,)""",
|
||||
def forward(self, linear):
|
||||
linear, = fx_pytree.tree_flatten_spec(([linear], {}), self._in_spec)
|
||||
relu = torch.ops.aten.relu.default(linear); linear = None
|
||||
return pytree.tree_unflatten((relu,), self._out_spec)""",
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,6 +1,10 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import torch
|
||||
from torch._dynamo.repro.aoti import export_for_aoti_minifier
|
||||
from torch._dynamo.repro.aoti import (
|
||||
AOTIMinifierError,
|
||||
export_for_aoti_minifier,
|
||||
get_module_string,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
@ -19,8 +23,72 @@ class MinifierUtilsTests(TestCase):
|
||||
# Here we obtained a graph with invalid output by symbolic_trace for simplicity,
|
||||
# it can also obtained from running functorch.compile.minifier on an exported graph.
|
||||
traced = torch.fx.symbolic_trace(model)
|
||||
gm = export_for_aoti_minifier(traced, (torch.randn(2, 2),))
|
||||
self.assertTrue(gm is None)
|
||||
for strict in [True, False]:
|
||||
gm = export_for_aoti_minifier(traced, (torch.randn(2, 2),), strict=strict)
|
||||
self.assertTrue(gm is None)
|
||||
|
||||
def test_non_exportable(self):
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.sum()
|
||||
|
||||
model = SimpleModel()
|
||||
# Force export failure by providing an input with in-compatible shapes
|
||||
inputs = (torch.randn(2), torch.randn(2))
|
||||
for strict in [True, False]:
|
||||
gm = export_for_aoti_minifier(
|
||||
model, inputs, strict=strict, skip_export_error=True
|
||||
)
|
||||
print(gm)
|
||||
self.assertTrue(gm is None)
|
||||
|
||||
with self.assertRaises(AOTIMinifierError):
|
||||
export_for_aoti_minifier(
|
||||
model, inputs, strict=strict, skip_export_error=False
|
||||
)
|
||||
|
||||
def test_convert_module_to_string(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, flag):
|
||||
flag = flag.item()
|
||||
|
||||
def true_fn(x):
|
||||
return x.clone()
|
||||
|
||||
return torch.cond(flag > 0, true_fn, true_fn, [x])
|
||||
|
||||
inputs = (
|
||||
torch.rand(28, 28),
|
||||
torch.tensor(1),
|
||||
)
|
||||
|
||||
model = M()
|
||||
gm = torch.export.export(model, inputs, strict=False).module()
|
||||
|
||||
# TODO: make NNModuleToString.convert() generate string for nested submodules.
|
||||
model_string = get_module_string(gm)
|
||||
self.assertExpectedInline(
|
||||
model_string.strip(),
|
||||
"""\
|
||||
# from torch.nn import *
|
||||
# class Repro(torch.nn.Module):
|
||||
# def __init__(self) -> None:
|
||||
# super().__init__()
|
||||
# self.true_graph_0 = <lambda>()
|
||||
# self.false_graph_0 = <lambda>()
|
||||
|
||||
|
||||
|
||||
# def forward(self, x, flag):
|
||||
# x, flag, = fx_pytree.tree_flatten_spec(([x, flag], {}), self._in_spec)
|
||||
# item = torch.ops.aten.item.default(flag); flag = None
|
||||
# gt = item > 0; item = None
|
||||
# true_graph_0 = self.true_graph_0
|
||||
# false_graph_0 = self.false_graph_0
|
||||
# cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]); gt = true_graph_0 = false_graph_0 = x = None
|
||||
# getitem = cond[0]; cond = None
|
||||
# return pytree.tree_unflatten((getitem,), self._out_spec)""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -4,11 +4,12 @@ import functools
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import textwrap
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.debug_utils import (
|
||||
@ -19,13 +20,12 @@ from torch._dynamo.debug_utils import (
|
||||
helper_for_dump_minify,
|
||||
InputReader,
|
||||
minifier_dir,
|
||||
NNModuleToString,
|
||||
NopInputReader,
|
||||
)
|
||||
from torch.export import ExportedProgram
|
||||
from torch.hub import tqdm
|
||||
|
||||
from .after_aot import generate_compiler_repro_string
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -53,32 +53,78 @@ def dump_to_minify(
|
||||
os.makedirs(subdir, exist_ok=True)
|
||||
save_graph_repro_ep(
|
||||
out,
|
||||
exported_program,
|
||||
compiler_name,
|
||||
exported_program=exported_program,
|
||||
save_dir=subdir,
|
||||
command="minify",
|
||||
options=options,
|
||||
config_patches=options,
|
||||
)
|
||||
return helper_for_dump_minify(out.getvalue())
|
||||
|
||||
|
||||
def get_module_string(gm):
|
||||
def _convert_to_comment(s_):
|
||||
s = s_.split("\n")
|
||||
if len(s) == 1:
|
||||
return "# " + s_
|
||||
first = s.pop(0)
|
||||
for i in range(len(s)):
|
||||
line = s[i]
|
||||
if line.strip() != "":
|
||||
s[i] = "# " + line
|
||||
else:
|
||||
s[i] = ""
|
||||
s = "\n".join(s)
|
||||
s = first + "\n" + s
|
||||
return s
|
||||
|
||||
module_string = NNModuleToString.convert(gm)
|
||||
return _convert_to_comment(module_string)
|
||||
|
||||
|
||||
def save_graph_repro_ep(
|
||||
fd,
|
||||
exported_program: ExportedProgram,
|
||||
compiler_name,
|
||||
*,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
exported_program: Optional[ExportedProgram] = None,
|
||||
gm: Optional[torch.nn.Module] = None,
|
||||
args: Optional[Tuple[Any]] = None,
|
||||
config_patches: Optional[Dict[str, str]] = None,
|
||||
stable_output=False,
|
||||
save_dir=None,
|
||||
command="run",
|
||||
accuracy=None,
|
||||
check_str=None,
|
||||
module_in_comment=False,
|
||||
strict=False,
|
||||
):
|
||||
# Save graph for reproducing the error.
|
||||
# Either exported_program or gm will be saved, depending on which one is defined.
|
||||
# Only one of exported_program and gm should be defined.
|
||||
|
||||
if exported_program is None and gm is None:
|
||||
raise AOTIMinifierError("One of exported_program and gm must be defined")
|
||||
if exported_program is not None and gm is not None:
|
||||
raise AOTIMinifierError("Only one of exported_program and gm can be defined")
|
||||
if gm is not None and args is None:
|
||||
raise AOTIMinifierError("If gm is defined, args should also be defined")
|
||||
|
||||
if exported_program is None:
|
||||
assert gm is not None
|
||||
assert args is not None
|
||||
exported_program = torch.export.export(gm, args, strict=strict)
|
||||
elif gm is None:
|
||||
gm = exported_program.module()
|
||||
|
||||
# save a graph preview using gm
|
||||
module_string = get_module_string(gm)
|
||||
fd.write(module_string)
|
||||
|
||||
# save a graph repro using exported_program
|
||||
fd.write(
|
||||
generate_compiler_repro_exported_program(
|
||||
exported_program,
|
||||
options=options,
|
||||
options=config_patches,
|
||||
stable_output=stable_output,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
@ -94,53 +140,14 @@ def save_graph_repro_ep(
|
||||
)
|
||||
|
||||
|
||||
def save_graph_repro_string(
|
||||
fd,
|
||||
def dump_compiler_graph_state(
|
||||
gm,
|
||||
args,
|
||||
compiler_name,
|
||||
*,
|
||||
config_patches=None,
|
||||
stable_output=False,
|
||||
save_dir=None,
|
||||
command="run",
|
||||
accuracy=None,
|
||||
tracing_mode=None,
|
||||
check_str=None,
|
||||
):
|
||||
# save a graph repro by dumping the `gm` as a string
|
||||
if any(
|
||||
isinstance(arg, torch.fx.experimental._backward_state.BackwardState)
|
||||
for arg in args
|
||||
):
|
||||
fd.write(
|
||||
"Repro is not generated due to existence of BackwardState in graph input"
|
||||
)
|
||||
return
|
||||
fd.write(
|
||||
generate_compiler_repro_string(
|
||||
gm,
|
||||
args,
|
||||
stable_output=stable_output,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
)
|
||||
if accuracy is None:
|
||||
accuracy = "_accuracy" in compiler_name
|
||||
fd.write("if __name__ == '__main__':\n")
|
||||
fd.write(" from torch._dynamo.repro.aoti import run_repro, repro_load_args\n")
|
||||
fd.write(
|
||||
f" config_patches={config_patches}\n"
|
||||
f" with torch.no_grad():\n"
|
||||
f" args = repro_load_args(load_args, save_dir={save_dir!r})\n"
|
||||
f" exported_program = torch.export.export(mod, args)\n"
|
||||
f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, "
|
||||
f"save_dir={save_dir!r}, check_str={check_str!r})\n"
|
||||
)
|
||||
|
||||
|
||||
def dump_compiler_graph_state(
|
||||
gm, args, compiler_name, *, config_patches=None, accuracy=None
|
||||
strict=False,
|
||||
):
|
||||
subdir = os.path.join(minifier_dir(), "checkpoints")
|
||||
if not os.path.exists(subdir):
|
||||
@ -149,16 +156,17 @@ def dump_compiler_graph_state(
|
||||
log.warning(
|
||||
"Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name
|
||||
)
|
||||
# exported_program = torch.export.export(gm, tuple(args))
|
||||
with open(file_name, "w") as fd:
|
||||
save_graph_repro_string(
|
||||
save_graph_repro_ep(
|
||||
fd,
|
||||
gm,
|
||||
args,
|
||||
compiler_name,
|
||||
gm=gm,
|
||||
args=tuple(args),
|
||||
config_patches=config_patches,
|
||||
save_dir=subdir,
|
||||
accuracy=accuracy,
|
||||
module_in_comment=True,
|
||||
strict=strict,
|
||||
)
|
||||
curdir = os.getcwd()
|
||||
repro_path = os.path.join(curdir, "repro.py")
|
||||
@ -281,7 +289,9 @@ def repro_run(options, exported_program, config_patches):
|
||||
synchronize() # ensure segfaults are surfaced
|
||||
|
||||
|
||||
def export_for_aoti_minifier(gm, tuple_inputs) -> Optional[torch.nn.Module]:
|
||||
def export_for_aoti_minifier(
|
||||
gm, tuple_inputs, strict=False, skip_export_error=True
|
||||
) -> Optional[torch.nn.Module]:
|
||||
# Some graphs cannot be used for AOTI/export (illegal graphs), these should be
|
||||
# considered as graphs that don't fail in the minifier, so the minifier keeps searching.
|
||||
# In these case, we return None. Otherwise, we return the exported graph module.
|
||||
@ -290,20 +300,29 @@ def export_for_aoti_minifier(gm, tuple_inputs) -> Optional[torch.nn.Module]:
|
||||
#
|
||||
# Please add to this list of illegal graphs if you change the implementation here.
|
||||
# - graph output is not allowed by export
|
||||
#
|
||||
# If skip_export_error=True, then the errors in export will not be raised, and the minifier
|
||||
# will keep exploring and ignore this graph.
|
||||
from torch._dynamo.exc import UserError, UserErrorType
|
||||
|
||||
try:
|
||||
ep = torch.export.export(gm, tuple_inputs)
|
||||
ep = torch.export.export(gm, tuple_inputs, strict=strict)
|
||||
gm = ep.module()
|
||||
return gm
|
||||
except UserError as e:
|
||||
# graph output is not allowed by export
|
||||
if e.error_type == UserErrorType.INVALID_OUTPUT:
|
||||
return None
|
||||
else:
|
||||
raise AOTIMinifierError(e) from e
|
||||
except Exception as e:
|
||||
if skip_export_error:
|
||||
return None
|
||||
if isinstance(e, UserError) and e.error_type == UserErrorType.INVALID_OUTPUT:
|
||||
# graph output is not allowed by export when strict=True
|
||||
return None
|
||||
if isinstance(e, RuntimeError):
|
||||
# graph output is not allowed by export when strict=False
|
||||
pattern = r"Found .* in output, which is not a known type\."
|
||||
if re.search(pattern, str(e)) is not None:
|
||||
return None
|
||||
raise AOTIMinifierError(e) from e
|
||||
# we should never reach here
|
||||
return None
|
||||
|
||||
|
||||
def repro_minify(options, exported_program, config_patches):
|
||||
@ -312,6 +331,9 @@ def repro_minify(options, exported_program, config_patches):
|
||||
|
||||
mod, args, kwargs = repro_common(options, exported_program)
|
||||
compiler_name = "aot_inductor"
|
||||
assert options.minifier_export_mode in ["dynamo", "python"]
|
||||
strict = options.minifier_export_mode == "dynamo"
|
||||
skip_export_error = options.skip_export_error
|
||||
|
||||
from torch.cuda import synchronize
|
||||
|
||||
@ -325,7 +347,9 @@ def repro_minify(options, exported_program, config_patches):
|
||||
def module_fails(gm, flat_example_inputs, check_str=None):
|
||||
# we have to export first so the in_spec and out_spec are populated
|
||||
tuple_inputs = tuple(flat_example_inputs)
|
||||
gm = export_for_aoti_minifier(gm, tuple_inputs)
|
||||
gm = export_for_aoti_minifier(
|
||||
gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
|
||||
)
|
||||
|
||||
# Some graphs cannot be used for AOTI/export (illegal graphs), these should be
|
||||
# considered as graphs that don't fail in the minifier, so the minifier keeps searching.
|
||||
@ -356,6 +380,7 @@ def repro_minify(options, exported_program, config_patches):
|
||||
dump_compiler_graph_state,
|
||||
compiler_name=compiler_name,
|
||||
config_patches=config_patches,
|
||||
strict=strict,
|
||||
),
|
||||
save_dir=options.save_dir,
|
||||
offload_to_disk=options.offload_to_disk,
|
||||
@ -367,8 +392,6 @@ def repro_minify(options, exported_program, config_patches):
|
||||
|
||||
def run_repro(
|
||||
exported_program,
|
||||
# load_args,
|
||||
# kwargs: Dict[str, Any],
|
||||
*,
|
||||
config_patches: Optional[Dict[str, str]] = None,
|
||||
command="run",
|
||||
@ -376,6 +399,8 @@ def run_repro(
|
||||
save_dir=None,
|
||||
tracing_mode=None,
|
||||
check_str=None,
|
||||
minifier_export_mode="python",
|
||||
skip_export_error=True,
|
||||
**more_kwargs,
|
||||
):
|
||||
for k in more_kwargs:
|
||||
@ -465,6 +490,21 @@ default settings on this script:
|
||||
default=check_str,
|
||||
help="require minified program to fail with error containing this string",
|
||||
)
|
||||
parser_minify.add_argument(
|
||||
"--minifier-export-mode",
|
||||
type=str,
|
||||
default=minifier_export_mode,
|
||||
help=(
|
||||
"The export mode used in minifier, either dynamo or python."
|
||||
"`dynamo` corresponds to strict=True, and `python` corresponds to strict=False."
|
||||
),
|
||||
)
|
||||
parser_minify.add_argument(
|
||||
"--skip-export-error",
|
||||
type=bool,
|
||||
default=skip_export_error,
|
||||
help="Skip intermediate graphs that cannot be exported.",
|
||||
)
|
||||
|
||||
# Run the repro in the context of minification, inverting exit code meaning
|
||||
parser_minifier_query = subparsers.add_parser(
|
||||
|
Reference in New Issue
Block a user