pytorch/torch/_dynamo/debug_utils.py
Edward Z. Yang a867e6db85 Add newline before minified repro path (#103083)
Minor QOL change.  This log message is pushed into my history by the
backtrace, which is a pain because if I tab up in tmux I can no longer
paste it without line breaks.  This makes it more convenient to use tmux
copy mode to get only the file (as I get the entire line this way.)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103083
Approved by: https://github.com/albanD
2023-06-06 17:09:44 +00:00

import copy
import functools
import getpass
import itertools
import logging
import os
import subprocess
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Callable, Optional, Sequence, TypeVar
import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
from . import config
from .utils import clone_inputs, get_debug_dir
log = logging.getLogger(__name__)
T = TypeVar("T")
inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()
if use_buck:
    import libfb.py.build_info  # type: ignore[import]

extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])

BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split
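
# Illustrative usage (hypothetical path; BuckTargetWriter is only meaningful
# inside an fbcode checkout):
#
#     writer = BuckTargetWriter("/data/users/me/fbcode/scripts/repro.py")
#     cmd = writer.write()  # writes scripts/TARGETS and logs the buck2 repro command
#     # cmd == ["buck2", "run", "@mode/dev-nosan", "//scripts:repro"]
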
def minifier_dir():
    path = os.path.join(get_debug_dir(), "minifier")
    if path is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path
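
# Note: together with get_minifier_repro_path() below, this is where minified
# repros are written; with default settings the directory typically ends up
# under ./torch_compile_debug/<run dir>/minifier (the exact layout comes from
# get_debug_dir()).
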
MAX_CONSTANT_NUMEL_INLINE = 4
class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            # Serialize full data for small buffers
            if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
                from torch._tensor_str import PRINT_OPTS

                assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
                tensor_str = repr(buffer)
            elif torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            maybe_device = ""
            if param.is_cuda:
                maybe_device = ', device="cuda"'
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f" {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
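
# Illustrative example (the module below is hypothetical, not part of this file):
# converting a small traced module into a standalone "Repro" source string.
#
#     class MLP(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lin = torch.nn.Linear(4, 4)
#
#         def forward(self, x):
#             return torch.relu(self.lin(x))
#
#     gm = torch.fx.symbolic_trace(MLP())
#     print(NNModuleToString.convert(gm))  # prints a runnable Repro class definition
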
@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
        cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
        model_str += f"{comment}\n"
    except FileNotFoundError:
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str
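
# The returned comment block looks roughly like this (hardware dependent,
# shown only for illustration):
#
#     # CUDA Info:
#     # nvcc: NVIDIA (R) Cuda compiler driver ...
#     # GPU Hardware Info:
#     # NVIDIA A100-SXM4-40GB : 8
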
def generate_config_string(*, stable_output=False):
    import torch._functorch.config
    import torch._inductor.config

    if stable_output:
        return "# config omitted due to stable_output=True"

    return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config

{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
"""
def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")
def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to:\n%s", minified_repro_path)

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)
    except OSError as e:
        log.exception(e)
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
class AccuracyError(Exception):
    pass
def clone_inputs_retaining_gradness(example_inputs):
    """
    This clone of the inputs is different from utils.clone_input. In the minifier,
    all the tensors are leaf tensors when creating a new graph, so we set the
    requires_grad field without checking the leafness of the tensor.
    """
    cloned_inputs = clone_inputs(example_inputs)
    for idx in range(len(example_inputs)):
        if isinstance(cloned_inputs[idx], torch.Tensor):
            cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
    return cloned_inputs
def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.

    When disable_clone is True, we will use args as-is without cloning.
    This is higher fidelity but we may destroy the args in the process.
    """
    from torch._functorch.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    if not disable_clone:
        args = clone_inputs_retaining_gradness(args)

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # TorchInductor returned callable expects lists. So, boxing the call.
    orig_named_parameters = getattr(gm, "named_parameters", None)
    orig_named_buffers = getattr(gm, "named_buffers", None)
    if not hasattr(gm, "_boxed_call") and (
        orig_named_parameters is not None or orig_named_buffers is not None
    ):
        gm = make_boxed_func(gm)
        if orig_named_parameters is not None:
            gm.named_parameters = orig_named_parameters
        if orig_named_buffers is not None:
            gm.named_buffers = orig_named_buffers

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)
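
# A minimal sketch of how this helper is typically driven (module and inputs
# here are made up for illustration):
#
#     mod = torch.nn.Linear(4, 2)
#     args = [torch.randn(8, 4, requires_grad=True)]
#     results = run_fwd_maybe_bwd(mod, args)  # forward + backward, results collected
#     fwd_only = run_fwd_maybe_bwd(mod, args, only_fwd=True)
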
def same_two_models(
    gm,
    opt_gm,
    example_inputs,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    """
    Check that two models have the same accuracy.

    require_fp64: if True, raise an error if we are unable to calculate the fp64 reference
    ignore_non_fp: if True, do not compare outputs which are not floating point. This
        is mostly useful for the minifier (which wants to avoid quantizing floating point
        error into integer/boolean error)
    """
    from .eval_frame import OptimizedModule
    from .testing import (
        named_buffers_for_optimized_module,
        named_parameters_for_optimized_module,
    )
    from .utils import same

    if isinstance(gm, OptimizedModule):
        gm.named_parameters = named_parameters_for_optimized_module(gm)
        gm.named_buffers = named_buffers_for_optimized_module(gm)

    if isinstance(opt_gm, OptimizedModule):
        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    fp64_ref = None
    if config.same_two_models_use_fp64:
        try:
            fp64_model, fp64_examples = cast_to_fp64(
                copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
            )
            fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
        except Exception:
            if require_fp64:
                raise RuntimeError("Could not generate fp64 outputs")
            log.warning("Could not generate fp64 outputs")

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            (
                "While minifying the program in accuracy minification mode, "
                "ran into a runtime exception which is likely an unrelated issue."
                " Skipping this graph."
            )
        )
        return True

    passing = same(
        ref,
        res,
        fp64_ref,
        tol=config.repro_tolerance,
        equal_nan=True,
        ignore_non_fp=ignore_non_fp,
    )
    return passing
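
# Illustrative accuracy check (model and backend below are placeholders, not
# part of this module):
#
#     mod = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
#     opt_mod = torch.compile(mod, backend="aot_eager")
#     inputs = [torch.randn(2, 8)]
#     assert same_two_models(mod, opt_mod, inputs, only_fwd=True)
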
def cast_convert_element_type_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
    model.graph.lint()
    model.recompile()
    return model
def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # take care of convert_element_type explicitly
        model = cast_convert_element_type_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)
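
# For example (hypothetical inputs, shown only to make the contract concrete):
#
#     model64, inputs64 = cast_to_fp64(copy.deepcopy(model), [torch.randn(2, 3), 7])
#     # floating point tensors are cast to float64; non-tensor and integer
#     # inputs pass through unchanged
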
def backend_accuracy_fails(
    gm,
    example_inputs,
    compiler_fn,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        return not same_two_models(
            gm,
            compiled_gm,
            example_inputs,
            only_fwd,
            require_fp64=require_fp64,
            ignore_non_fp=ignore_non_fp,
        )
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            (
                "While minifying the program in accuracy minification mode, "
                "ran into a runtime exception which is likely an unrelated issue."
                " Skipping this graph"
            )
        )
        return False
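
# Sketch of the contract (gm/example_inputs are placeholders): compiler_fn takes
# a GraphModule plus example inputs and returns a compiled callable; the
# identity "backend" below should therefore never fail accuracy.
#
#     def trivial_compiler(gm, example_inputs):
#         return gm
#
#     assert not backend_accuracy_fails(gm, example_inputs, trivial_compiler, only_fwd=True)
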
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                          REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Helper functions for computing what the default values of tensor
# values should be. These all coincide with factory functions, e.g., torch.empty


def _stride_or_default(
    stride: Optional[Sequence[int]], *, shape: Sequence[int]
) -> Sequence[int]:
    return stride if stride is not None else utils.make_contiguous_strides_for(shape)


def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
    return lambda x: x if x is not None else d


_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
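
# For instance, _dtype_or_default(None) == torch.float32 while
# _dtype_or_default(torch.int64) == torch.int64; the same pattern applies to
# device, storage_offset, requires_grad and is_leaf.
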
class NopInputReader:
    def __init__(self):
        self.total = 0

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        self.total += 1

    def tensor(self, *args, **kwargs):
        pass

    def symint(self, *args, **kwargs):
        pass
# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    def __init__(self, save_dir=None, *, pbar=None):
        # If None, we will generate random data instead. It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        self.args = []
        self.pbar = pbar

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device? But failing this
                    # way would be very mysterious! Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage,
        shape,
        stride=None,
        *,
        storage_offset=None,
        dtype=None,
        requires_grad=None,
        is_leaf=None,
        **metadata,
    ):
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val):
        self.args.append(val)
        return val  # for BC
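
# In a generated repro, an InputReader is driven by the emitted load_args()
# function, roughly like this (hash and sizes are made-up placeholders):
#
#     def load_args(reader):
#         buf0 = reader.storage("e2b39c716c0d701dd4f9b0cb0ced0be1", 64)
#         reader.tensor(buf0, (4, 4))  # arg0
#         reader.symint(16)  # s0
#
#     reader = InputReader(save_dir=None)  # save_dir=None => random data
#     load_args(reader)
#     args = reader.args
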
# Here is our writer strategy:
#  1. We will stream all of the inputs to disk
#  2. You can now deterministically randomize the inputs, or reload
#     the inputs from disk
#  3. You can YOLO run the script without the inputs, in which case
#     we'll fill the inputs with random data and pray. This is the
#     legacy behavior, but it's also useful if you want to find out
#     if we're so broken even random inputs trigger it
#  4. We could offer an in process "check if the randomized thing
#     works too" but this is delicate so we don't do it
class InputWriter:
    def __init__(self, save_dir, *, stable_hash=False):
        self._lines = []
        # TODO: consider ensuring tensor and storage counters line up?
        self.storage_counter = itertools.count()
        self.save_dir = save_dir
        self.store = (
            ContentStoreWriter(save_dir, stable_hash=stable_hash)
            if save_dir is not None
            else None
        )
        self.seen_storages = {}

    def lines(self):
        r = [
            "def load_args(reader):",
        ]
        r.extend(f"    {l}" for l in self._lines)
        # In case we need to change the internal format of load_args
        # in an FC-breaking way
        r.append("load_args._version = 0")
        return r
    # Storages are untyped, but we need to initialize them with data if
    # we don't have the real data, so we give a hint saying what kind
    # of initialization may be appropriate
    #
    # If we had a FakeTensor, device_hint tells us what device should be
    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
        ws = StorageWeakRef(untyped_storage)
        v = self.seen_storages.get(ws)
        if v is not None:
            return v
        v = f"buf{next(self.storage_counter)}"
        maybe_dtype_hint = ""
        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
        # TODO: being optional on device is kind of pointless as the default
        # is CPU but most repros we care about are CUDA
        maybe_device = ""
        device = untyped_storage.device
        if device.type == "meta":
            assert device_hint is not None
            device = device_hint
        if _device_or_default(None) != device:
            maybe_device = f", device={device!r}"
        nbytes = untyped_storage.nbytes()
        storage_hash = None
        if self.store is not None and untyped_storage.device.type != "meta":
            storage_hash = self.store.write_storage(untyped_storage)
        self._lines.append(
            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
        )
        self.seen_storages[ws] = v
        return v

    def tensor(self, name, t) -> None:
        storage = self.storage(
            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
        )
        args = []
        # NB: this is positional, must come first
        if _stride_or_default(None, shape=t.shape) != t.stride():
            args.append(str(tuple(t.stride())))
        if _dtype_or_default(None) != t.dtype:
            args.append(f"dtype={t.dtype!r}")
        if _storage_offset_or_default(None) != t.storage_offset():
            args.append(f"storage_offset={t.storage_offset()!r}")
        tensor_metadata = torch._utils.get_tensor_metadata(t)
        if tensor_metadata:
            args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
        if _requires_grad_or_default(None) != t.requires_grad:
            args.append(f"requires_grad={t.requires_grad!r}")
        is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
        if _is_leaf_or_default(None) != is_leaf:
            args.append(f"is_leaf={is_leaf!r}")
        self._lines.append(
            "reader.tensor("
            + ", ".join([storage, str(tuple(t.shape)), *args])
            + f") # {name}"
        )

    # TODO: this doesn't actually symint atm
    def symint(self, name, val) -> None:
        if isinstance(val, torch.SymInt):
            val = val.node.hint
        self._lines.append(f"reader.symint({val!r}) # {name}")