mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/102979 Approved by: https://github.com/wconstab
1706 lines
51 KiB
Python
1706 lines
51 KiB
Python
import collections
|
|
import contextlib
|
|
import copy
|
|
import cProfile
|
|
import dataclasses
|
|
import datetime
|
|
import dis
|
|
import enum
|
|
import functools
|
|
import gc
|
|
import inspect
|
|
import itertools
|
|
import logging
|
|
import math
|
|
import operator
|
|
import os
|
|
import pstats
|
|
import sys
|
|
import textwrap
|
|
import time
|
|
import types
|
|
import typing
|
|
import weakref
|
|
from contextlib import contextmanager
|
|
from functools import lru_cache, wraps
|
|
from typing import Any, Dict, Tuple, Union
|
|
|
|
import torch._logging
|
|
from torch._guards import detect_fake_mode # noqa: F401
|
|
from . import config
|
|
|
|
try:
|
|
import numpy as np
|
|
|
|
HAS_NUMPY = True
|
|
except ModuleNotFoundError:
|
|
np = None # type: ignore[assignment]
|
|
HAS_NUMPY = False
|
|
|
|
try:
|
|
import torch_np
|
|
|
|
HAS_NUMPY_TORCH_INTEROP = True
|
|
except ModuleNotFoundError:
|
|
torch_np = None
|
|
HAS_NUMPY_TORCH_INTEROP = False
|
|
|
|
import importlib
|
|
|
|
import torch
|
|
import torch._functorch.config
|
|
import torch.fx.experimental.symbolic_shapes
|
|
import torch.utils.checkpoint
|
|
from torch import fx
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
from torch._subclasses.fake_tensor import FakeTensor
|
|
from torch.nn.modules.lazy import LazyModuleMixin
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
counters = collections.defaultdict(collections.Counter)
|
|
troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
|
|
nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
|
|
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
# profiling compilation time
|
|
compilation_metrics = collections.OrderedDict()
|
|
|
|
timer_counter = itertools.count()
|
|
|
|
|
|
def tabulate(rows, headers):
|
|
try:
|
|
import tabulate
|
|
|
|
return tabulate.tabulate(rows, headers=headers)
|
|
except ImportError:
|
|
return "\n".join(
|
|
", ".join(map(str, row)) for row in itertools.chain([headers], rows)
|
|
)
|
|
|
|
|
|
def dynamo_profiled(func):
|
|
@wraps(func)
|
|
def profile_wrapper(*args, **kwargs):
|
|
global timer_counter
|
|
datafn = (
|
|
func.__name__ + f"{next(timer_counter)}.profile"
|
|
) # Name the data file sensibly
|
|
prof = cProfile.Profile()
|
|
prof.enable()
|
|
retval = prof.runcall(func, *args, **kwargs)
|
|
prof.disable()
|
|
print(f"### Cprofile for {func.__name__} iter {next(timer_counter)} ###")
|
|
ps = pstats.Stats(prof)
|
|
ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
|
|
ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)
|
|
prof.dump_stats(datafn)
|
|
return retval
|
|
|
|
return profile_wrapper
|
|
|
|
|
|
frame_phase_timing = collections.OrderedDict()
|
|
|
|
curr_frame = 0
|
|
|
|
|
|
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
|
|
def increment_frame():
|
|
global curr_frame
|
|
curr_frame = curr_frame + 1
|
|
|
|
|
|
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
|
|
def reset_frame_count():
|
|
global curr_frame
|
|
frame_phase_timing.clear()
|
|
curr_frame = 0
|
|
|
|
|
|
op_count = 0
|
|
|
|
|
|
def increment_op_count(cnt):
|
|
global op_count
|
|
op_count += cnt
|
|
|
|
|
|
# Print a report of time spent so far
|
|
# Ex:
|
|
# TIMING:
|
|
# entire_frame_compile:8.574629999999999
|
|
# backend_compile:5.26806
|
|
def print_time_report():
|
|
total = 0
|
|
total_by_key = {}
|
|
for frame, timings in frame_phase_timing.items():
|
|
for key, timing in timings.items():
|
|
total += timing
|
|
if key not in total_by_key:
|
|
total_by_key[key] = timing
|
|
else:
|
|
total_by_key[key] += timing
|
|
|
|
out = "TIMING:"
|
|
for key, value in total_by_key.items():
|
|
out = f"{out} {key}:{round(value, 5)}"
|
|
|
|
print(out)
|
|
|
|
|
|
# dynamo_timed API works as a function decorator
|
|
# By wrapping a function in dynamo_timed, we can store a record in compilation_metrics
|
|
# where the key is the functions name.
|
|
# For example:
|
|
#
|
|
# @dynamo_timed
|
|
# def _foo(...):
|
|
#
|
|
# Would show up as an entry in our timing dict:
|
|
# OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
|
|
# This is extremely useful for granular debugging.
|
|
#
|
|
# For a higher-level mode, pass a phase_name into dynamo_timed
|
|
# phase_names record an extra record into a separate compilation timing structure,
|
|
# one keyed on frame+name rather than function.
|
|
# The frame is incremented outside of this function, in def increment_frame() above.
|
|
def dynamo_timed(original_function=None, phase_name=None):
|
|
def dynamo_timed_inner(func):
|
|
@wraps(func)
|
|
def time_wrapper(*args, **kwargs):
|
|
key = func.__qualname__
|
|
if key not in compilation_metrics:
|
|
compilation_metrics[key] = []
|
|
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
|
|
t0 = time.time()
|
|
r = func(*args, **kwargs)
|
|
time_spent = time.time() - t0
|
|
# print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
|
|
compilation_metrics[key].append(time_spent)
|
|
if phase_name:
|
|
frame_key = str(curr_frame)
|
|
if frame_key not in frame_phase_timing:
|
|
frame_phase_timing[frame_key] = {}
|
|
assert (
|
|
phase_name not in frame_phase_timing[frame_key]
|
|
), f"Duplicate phase name {phase_name} for frame {frame_key}"
|
|
frame_phase_timing[frame_key][phase_name] = time_spent
|
|
return r
|
|
|
|
return time_wrapper
|
|
|
|
if original_function:
|
|
return dynamo_timed_inner(original_function)
|
|
return dynamo_timed_inner
|
|
|
|
|
|
def compile_times(repr="str", aggregate=False):
|
|
"""
|
|
Get metrics about torchdynamo frontend/backend compilation times.
|
|
|
|
Accumulates information from functions tagged with `@dynamo_timed`.
|
|
|
|
repr='str' returns a printable string for user interaction, and 'csv'
|
|
returns headers, rows which can be logged for output
|
|
|
|
aggregate causes values from multiple compilations (e.g. split graphs)
|
|
to be accumulated into one value. If false, expect more than one value
|
|
per metric.
|
|
"""
|
|
|
|
def fmt_fn(values, item_fn=lambda x: x):
|
|
if aggregate:
|
|
return item_fn(sum(values))
|
|
return ", ".join(map(item_fn, values))
|
|
|
|
if repr == "str":
|
|
rows = [
|
|
(k, fmt_fn(compilation_metrics[k], item_fn=lambda x: f"{x:.4f}"))
|
|
for k in compilation_metrics
|
|
]
|
|
out = "TorchDynamo compilation metrics:\n"
|
|
out += tabulate(rows, headers=("Function", "Runtimes (s)"))
|
|
return out
|
|
elif repr == "csv":
|
|
values = [
|
|
fmt_fn(v, item_fn=lambda x: f"{x:.6f}")
|
|
for v in compilation_metrics.values()
|
|
]
|
|
headers = list(compilation_metrics.keys())
|
|
return headers, values
|
|
|
|
|
|
tensortype_to_dtype = {
|
|
torch.FloatTensor: (torch.float32, torch.float),
|
|
torch.DoubleTensor: (torch.float64, torch.double),
|
|
torch.HalfTensor: (torch.float16, torch.half),
|
|
torch.BFloat16Tensor: (torch.bfloat16,),
|
|
torch.ByteTensor: (torch.uint8,),
|
|
torch.CharTensor: (torch.int8,),
|
|
torch.LongTensor: (torch.int64, torch.long),
|
|
torch.IntTensor: (torch.int32, torch.int),
|
|
torch.ShortTensor: (torch.int16, torch.short),
|
|
torch.BoolTensor: (torch.bool,),
|
|
}
|
|
|
|
|
|
class DuplicateWarningChecker:
|
|
def __init__(self, maxsize=4096):
|
|
self.maxsize = maxsize
|
|
self.reset()
|
|
|
|
def reset(self):
|
|
self.set = collections.OrderedDict()
|
|
|
|
def add(self, key):
|
|
if key in self.set:
|
|
self.set.move_to_end(key, last=True)
|
|
if not config.verbose:
|
|
return False
|
|
else:
|
|
self.set[key] = None
|
|
while len(self.set) > self.maxsize:
|
|
self.set.popitem(last=False)
|
|
return True
|
|
|
|
|
|
graph_break_dup_warning_checker = DuplicateWarningChecker()
|
|
|
|
|
|
def setup_compile_debug():
|
|
compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
|
|
|
|
if compile_debug:
|
|
torch._logging.set_logs(
|
|
dynamo=logging.DEBUG,
|
|
aot=logging.DEBUG,
|
|
inductor=logging.DEBUG,
|
|
output_code=True, # this is off by default
|
|
)
|
|
return add_file_handler()
|
|
|
|
return contextlib.ExitStack()
|
|
|
|
|
|
def reset_graph_break_dup_checker():
|
|
graph_break_dup_warning_checker.reset()
|
|
|
|
|
|
def add_file_handler():
|
|
log_path = os.path.join(get_debug_dir(), "torchdynamo")
|
|
if not os.path.exists(log_path):
|
|
os.makedirs(log_path)
|
|
|
|
log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log"))
|
|
logger = logging.getLogger("torch._dynamo")
|
|
logger.addHandler(log_file_handler)
|
|
|
|
exitstack = contextlib.ExitStack()
|
|
exitstack.callback(lambda: logger.removeHandler(log_file_handler))
|
|
return exitstack
|
|
|
|
|
|
def setup_log_file():
|
|
exitstack = contextlib.ExitStack()
|
|
if config.log_file_name is not None:
|
|
log_file_handler = logging.FileHandler(config.log_file_name)
|
|
for logger in logging.get_loggers():
|
|
logger.addHandler(log_file_handler)
|
|
exitstack.callback(lambda: logger.removeHandler(log_file_handler))
|
|
return exitstack
|
|
|
|
return exitstack
|
|
|
|
|
|
def gen_record_file_name(exc, code):
|
|
return f"{get_debug_dir()}/error_recordings/\
|
|
{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
|
|
|
|
|
|
def write_record_to_file(filename, exec_record):
|
|
try:
|
|
if os.path.exists(filename):
|
|
log.warning(
|
|
"Unable to write execution record %s; file already exists.", filename
|
|
)
|
|
else:
|
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
|
with open(filename, "wb") as f:
|
|
exec_record.dump(f)
|
|
except Exception:
|
|
log.error("Unable to write execution record %s", filename, exc_info=1)
|
|
|
|
|
|
def count_calls(g: fx.Graph):
|
|
c = 0
|
|
for n in g.nodes:
|
|
if "call" in n.op:
|
|
c += 1
|
|
return c
|
|
|
|
|
|
def identity(x):
|
|
return x
|
|
|
|
|
|
def nothing(*args, **kwargs):
|
|
pass
|
|
|
|
|
|
class ExactWeakKeyDictionary:
|
|
"""Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality"""
|
|
|
|
def __init__(self):
|
|
self.values = dict()
|
|
self.refs = dict()
|
|
|
|
def __getitem__(self, key):
|
|
return self.values[id(key)]
|
|
|
|
def get(self, key, default=None):
|
|
return self.values.get(id(key), default)
|
|
|
|
def __contains__(self, key):
|
|
return id(key) in self.values
|
|
|
|
def __setitem__(self, key, value):
|
|
idx = id(key)
|
|
if idx not in self.refs:
|
|
self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx))
|
|
self.values[idx] = value
|
|
|
|
def _remove_id(self, idx):
|
|
if idx in self.values:
|
|
del self.values[idx]
|
|
if idx in self.refs:
|
|
del self.refs[idx]
|
|
|
|
def clear(self):
|
|
self.refs.clear()
|
|
self.values.clear()
|
|
|
|
|
|
def istype(obj, allowed_types):
|
|
"""isinstance() without subclasses"""
|
|
if isinstance(allowed_types, (tuple, list, set)):
|
|
return type(obj) in allowed_types
|
|
return type(obj) is allowed_types
|
|
|
|
|
|
def is_typing(value):
|
|
if sys.version_info < (3, 9):
|
|
return isinstance(value, typing._GenericAlias)
|
|
else:
|
|
return isinstance(
|
|
value, (typing._SpecialGenericAlias, typing._UnionGenericAlias)
|
|
)
|
|
|
|
|
|
def is_numpy_int_type(value):
|
|
if HAS_NUMPY:
|
|
return istype(
|
|
value,
|
|
(
|
|
np.int8,
|
|
np.int16,
|
|
np.int32,
|
|
np.int64,
|
|
np.uint8,
|
|
np.uint16,
|
|
np.uint32,
|
|
np.uint64,
|
|
),
|
|
)
|
|
else:
|
|
return False
|
|
|
|
|
|
def is_numpy_float_type(value):
|
|
if HAS_NUMPY:
|
|
return istype(
|
|
value,
|
|
(
|
|
np.float16,
|
|
np.float32,
|
|
np.float64,
|
|
),
|
|
)
|
|
else:
|
|
return False
|
|
|
|
|
|
def is_numpy_ndarray(value):
|
|
if HAS_NUMPY:
|
|
return istype(value, np.ndarray)
|
|
else:
|
|
return False
|
|
|
|
|
|
def istensor(obj):
|
|
"""Check of obj is a tensor"""
|
|
tensor_list = (
|
|
torch.Tensor,
|
|
torch.nn.Parameter,
|
|
*config.traceable_tensor_subclasses,
|
|
)
|
|
tensor_list = tensor_list + (torch._subclasses.FakeTensor,)
|
|
return istype(obj, tensor_list)
|
|
|
|
|
|
def is_lazy_module(mod):
|
|
return isinstance(mod, LazyModuleMixin)
|
|
|
|
|
|
@functools.lru_cache(4096)
|
|
def print_once(*args):
|
|
print(*args)
|
|
|
|
|
|
def make_cell(val=None):
|
|
"""Some black magic to create a cell object that usually only exists in a closure"""
|
|
x = val
|
|
|
|
def f():
|
|
return x
|
|
|
|
assert len(f.__closure__) == 1
|
|
return f.__closure__[0]
|
|
|
|
|
|
def proxy_args_kwargs(args, kwargs):
|
|
try:
|
|
proxy_args = tuple(arg.as_proxy() for arg in args)
|
|
proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
|
|
return proxy_args, proxy_kwargs
|
|
except NotImplementedError as e:
|
|
from .exc import unimplemented
|
|
from .variables.base import typestr
|
|
|
|
raise unimplemented(
|
|
f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
|
|
) from e
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class CleanupHook:
|
|
"""Remove a global variable when hook is called"""
|
|
|
|
scope: Dict[str, Any]
|
|
name: str
|
|
|
|
def __call__(self, *args):
|
|
CleanupManager.count -= 1
|
|
del self.scope[self.name]
|
|
|
|
@staticmethod
|
|
def create(scope, name, val):
|
|
assert name not in scope
|
|
CleanupManager.count += 1
|
|
scope[name] = val
|
|
return CleanupHook(scope, name)
|
|
|
|
|
|
class CleanupManager(ExactWeakKeyDictionary):
|
|
count = 0
|
|
|
|
def _remove_id(self, idx):
|
|
for hook in self.values[idx]:
|
|
hook()
|
|
super()._remove_id(idx)
|
|
|
|
|
|
CleanupManager.instance = CleanupManager()
|
|
|
|
|
|
def clone_tensor(x):
|
|
"""Clone the tensor and its gradient"""
|
|
y = x.clone().requires_grad_(x.requires_grad)
|
|
if x.is_leaf and x.grad is not None:
|
|
y.grad = x.grad.clone()
|
|
return y
|
|
|
|
|
|
def clone_input(x, *, dtype=None):
|
|
"""copy while preserving strides"""
|
|
# TODO: this is questionable
|
|
if isinstance(x, torch._subclasses.FakeTensor):
|
|
# this func fails on fake tensors in __torch_dispatch__
|
|
return x
|
|
|
|
def torch_clone(x):
|
|
y = torch.clone(x)
|
|
if x.is_leaf:
|
|
y.requires_grad_(x.requires_grad)
|
|
if x.is_leaf and x.grad is not None:
|
|
y.grad = clone_input(x.grad, dtype=dtype)
|
|
if hasattr(x, "_dynamo_dynamic_indices"):
|
|
y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy()
|
|
return y
|
|
|
|
with torch.no_grad():
|
|
if x.device.type == "xla":
|
|
# Access data_ptr() for a xla tensor will cause crash
|
|
return torch_clone(x)
|
|
|
|
needed_size = sum(
|
|
(shape - 1) * stride for shape, stride in zip(x.size(), x.stride())
|
|
)
|
|
if x.is_quantized:
|
|
result = torch.empty_quantized((needed_size + 32,), x)
|
|
else:
|
|
result = torch.empty(
|
|
needed_size + 32, dtype=dtype or x.dtype, device=x.device
|
|
)
|
|
cache_line_offset = (
|
|
(x.data_ptr() - result.data_ptr()) % 32
|
|
) // x.element_size()
|
|
result.as_strided_(x.size(), x.stride(), cache_line_offset)
|
|
try:
|
|
result.copy_(x.clone())
|
|
if x.is_leaf:
|
|
result.requires_grad_(x.requires_grad)
|
|
if x.is_leaf and x.grad is not None:
|
|
result.grad = clone_input(x.grad, dtype=dtype)
|
|
except RuntimeError:
|
|
# RuntimeError: unsupported operation: more than one element of the written-to
|
|
# tensor refers to a single memory location. Please clone() the tensor before
|
|
# performing the operation.
|
|
return torch_clone(x)
|
|
if hasattr(x, "_dynamo_dynamic_indices"):
|
|
result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy()
|
|
return result
|
|
|
|
|
|
def clone_inputs(example_inputs):
|
|
if type(example_inputs) is dict:
|
|
res = dict(example_inputs)
|
|
for key, value in res.items():
|
|
if isinstance(value, tuple):
|
|
res[key] = clone_inputs(value)
|
|
else:
|
|
assert isinstance(value, torch.Tensor), type(value)
|
|
res[key] = clone_input(value)
|
|
return res
|
|
|
|
res = list(example_inputs)
|
|
for i in range(len(res)):
|
|
if isinstance(res[i], torch.Tensor):
|
|
res[i] = clone_input(res[i])
|
|
return res
|
|
|
|
|
|
@contextmanager
|
|
def preserve_rng_state():
|
|
rng = torch.clone(torch.random.get_rng_state())
|
|
if torch.cuda.is_available():
|
|
cuda_rng = torch.clone(torch.cuda.get_rng_state())
|
|
try:
|
|
yield
|
|
finally:
|
|
torch.random.set_rng_state(rng)
|
|
if torch.cuda.is_available():
|
|
torch.cuda.set_rng_state(cuda_rng)
|
|
|
|
|
|
def is_jit_model(model0):
|
|
return isinstance(
|
|
model0,
|
|
(
|
|
torch.jit._trace.TopLevelTracedModule,
|
|
torch.jit._script.RecursiveScriptModule,
|
|
torch.jit.ScriptFunction,
|
|
torch.jit.ScriptModule,
|
|
),
|
|
)
|
|
|
|
|
|
def torchscript(model, example_inputs, verbose=False):
|
|
if is_jit_model(model):
|
|
# already done?
|
|
return model
|
|
|
|
try:
|
|
return torch.jit.trace(model, example_inputs)
|
|
except Exception:
|
|
try:
|
|
return torch.jit.script(model)
|
|
except Exception:
|
|
if verbose:
|
|
log.exception("jit error")
|
|
else:
|
|
log.error("Both torch.jit.trace and torch.jit.script failed")
|
|
return None
|
|
|
|
|
|
def getfile(obj):
|
|
try:
|
|
return inspect.getfile(obj)
|
|
except TypeError:
|
|
return None
|
|
|
|
|
|
def is_namedtuple(obj):
|
|
"""Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
|
|
return is_namedtuple_cls(type(obj))
|
|
|
|
|
|
def is_namedtuple_cls(cls):
|
|
"""Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
|
|
try:
|
|
if issubclass(cls, tuple):
|
|
bases = getattr(cls, "__bases__", []) or [None]
|
|
module = getattr(cls, "__module__", None)
|
|
return module == "torch.return_types" or (
|
|
bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields")
|
|
)
|
|
except TypeError:
|
|
pass
|
|
return False
|
|
|
|
|
|
@functools.lru_cache(1)
|
|
def namedtuple_fields(cls):
|
|
"""Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple"""
|
|
if cls is slice:
|
|
return ["start", "stop", "step"]
|
|
|
|
assert issubclass(cls, tuple)
|
|
if hasattr(cls, "_fields"):
|
|
# normal namedtuples
|
|
return cls._fields
|
|
|
|
@dataclasses.dataclass
|
|
class Marker:
|
|
index: int
|
|
|
|
# frustrating ones e.g. torch.return_types.max
|
|
assert cls.__module__ == "torch.return_types"
|
|
obj = cls(map(Marker, range(cls.n_fields)))
|
|
fields = [None] * cls.n_fields
|
|
for name in dir(obj):
|
|
if name[0] != "_" and isinstance(getattr(obj, name), Marker):
|
|
fields[getattr(obj, name).index] = name
|
|
return fields
|
|
|
|
|
|
def checkpoint_params(gm):
|
|
with torch.no_grad():
|
|
rng_state = torch.clone(torch.random.get_rng_state())
|
|
if torch.cuda.is_available():
|
|
cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
|
|
saved_state = []
|
|
for param in itertools.chain(gm.parameters(), gm.buffers()):
|
|
saved_state.append((param, param._version, torch.clone(param)))
|
|
|
|
def restore():
|
|
with torch.no_grad():
|
|
torch.random.set_rng_state(rng_state)
|
|
if torch.cuda.is_available():
|
|
torch.cuda.set_rng_state(cuda_rng_state)
|
|
for param, version, original_value in saved_state:
|
|
if param._version != version:
|
|
param.copy_(original_value)
|
|
|
|
return restore
|
|
|
|
|
|
def timed(model, example_inputs, times=1):
|
|
if torch.cuda.is_available():
|
|
synchronize = torch.cuda.synchronize
|
|
else:
|
|
synchronize = nothing
|
|
|
|
synchronize()
|
|
gc.collect()
|
|
torch.manual_seed(1337)
|
|
t0 = time.perf_counter()
|
|
for _ in range(times):
|
|
result = model(*example_inputs)
|
|
synchronize()
|
|
t1 = time.perf_counter()
|
|
return result, t1 - t0
|
|
|
|
|
|
def check_is_cuda(gm, example_inputs):
|
|
return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True)))
|
|
|
|
|
|
@lru_cache(32)
|
|
def rot_n_helper(n):
|
|
assert n > 1
|
|
vars = [f"v{i}" for i in range(n)]
|
|
rotated = reversed(vars[-1:] + vars[:-1])
|
|
fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})")
|
|
fn.__name__ = f"rot_{n}_helper"
|
|
return fn
|
|
|
|
|
|
def is_safe_constant(v):
|
|
if istype(v, (tuple, frozenset)):
|
|
return all(map(is_safe_constant, v))
|
|
return isinstance(v, (enum.Enum, type)) or istype(
|
|
v,
|
|
(
|
|
types.CodeType,
|
|
int,
|
|
float,
|
|
bool,
|
|
str,
|
|
bytes,
|
|
type(None),
|
|
slice,
|
|
type(type),
|
|
torch.device,
|
|
torch.dtype,
|
|
),
|
|
)
|
|
|
|
|
|
def check_constant_args(args, kwargs):
|
|
return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values()))
|
|
|
|
|
|
def check_unspec_python_args(args, kwargs):
|
|
from .variables.constant import ConstantVariable
|
|
from .variables.tensor import UnspecializedPythonVariable
|
|
|
|
unspec_count = 0
|
|
for x in itertools.chain(args, kwargs.values()):
|
|
if isinstance(x, UnspecializedPythonVariable):
|
|
unspec_count += 1
|
|
elif not isinstance(x, (UnspecializedPythonVariable, ConstantVariable)):
|
|
return False
|
|
else:
|
|
pass
|
|
|
|
return unspec_count > 0
|
|
|
|
|
|
def specialize_args_kwargs(tx, args, kwargs):
|
|
specialized_args = []
|
|
specialized_kwargs = {}
|
|
for x in args:
|
|
specialized_args.append(x.as_specialized(tx))
|
|
for k, v in kwargs.items():
|
|
specialized_kwargs.update({k: v.as_specialized(tx)})
|
|
return specialized_args, specialized_kwargs
|
|
|
|
|
|
dict_values = type(dict().values())
|
|
odict_values = type(collections.OrderedDict().values())
|
|
tuple_iterator = type(iter(tuple()))
|
|
tuple_iterator_len = tuple_iterator.__length_hint__
|
|
object_new = object.__new__
|
|
|
|
|
|
def nn_module_new(cls):
|
|
obj = object_new(cls)
|
|
torch.nn.Module.__init__(obj)
|
|
return obj
|
|
|
|
|
|
def product(it):
|
|
return functools.reduce(operator.mul, it, 1)
|
|
|
|
|
|
def tuple_iterator_getitem(it, index):
|
|
_, (obj,), start = it.__reduce__()
|
|
return obj[start + index]
|
|
|
|
|
|
def enum_repr(value, local):
|
|
enum_name = str(value)
|
|
|
|
name, val = enum_name.split(".")
|
|
scope = "L" if local else "G"
|
|
local_name = f'{scope}["{name}"].{val}'
|
|
return local_name
|
|
|
|
|
|
def dict_param_key_ids(value):
|
|
return {
|
|
id(k) for k in value.keys() if isinstance(k, (torch.nn.Parameter, torch.Tensor))
|
|
}
|
|
|
|
|
|
def dict_const_keys(value):
|
|
return {
|
|
k for k in value.keys() if not isinstance(k, (torch.nn.Parameter, torch.Tensor))
|
|
}
|
|
|
|
|
|
def dict_const_keys_repr(const_keys, *, local):
|
|
if any(isinstance(k, enum.Enum) for k in const_keys):
|
|
# To workaround repr(Enum) returning invalid global reference before python 3.11
|
|
# by calling enum_repr and removing quotes to render enum in guard code.
|
|
const_keys_str = f"{ {enum_repr(k, local=local) if isinstance(k, enum.Enum) else repr(k) for k in const_keys} }".replace(
|
|
"'", ""
|
|
)
|
|
else:
|
|
const_keys_str = f"{const_keys!r}"
|
|
return const_keys_str
|
|
|
|
|
|
def global_key_name(key):
|
|
return f"__dict_key_{id(key)}"
|
|
|
|
|
|
from torch._subclasses import ( # noqa: F401
|
|
FakeTensorMode,
|
|
UnsupportedFakeTensorException,
|
|
)
|
|
|
|
|
|
def wrap_fake_exception(fn):
|
|
try:
|
|
return fn()
|
|
except UnsupportedFakeTensorException as e:
|
|
from .exc import unimplemented
|
|
|
|
msg = f"Unsupported: {e.reason} with fake tensor propagation."
|
|
log.warning(msg)
|
|
raise unimplemented(msg) from e
|
|
|
|
|
|
def deepcopy_to_fake_tensor(obj, fake_mode):
|
|
with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode):
|
|
return wrap_fake_exception(lambda: copy.deepcopy(obj))
|
|
|
|
|
|
def rmse(ref, res):
|
|
"""
|
|
Calculate root mean squared error
|
|
"""
|
|
return torch.sqrt(torch.mean(torch.square(ref - res)))
|
|
|
|
|
|
def same(
|
|
ref,
|
|
res,
|
|
fp64_ref=None,
|
|
cos_similarity=False,
|
|
tol=1e-4,
|
|
equal_nan=False,
|
|
exact_dtype=True,
|
|
relax_numpy_equality=False,
|
|
ignore_non_fp=False,
|
|
log_error=log.error,
|
|
):
|
|
"""Check correctness to see if ref and res match"""
|
|
if fp64_ref is None:
|
|
fp64_ref = ref
|
|
if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
|
|
assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}"
|
|
return len(ref) == len(res) and all(
|
|
same(
|
|
ai,
|
|
bi,
|
|
fp64_refi,
|
|
cos_similarity,
|
|
tol,
|
|
equal_nan,
|
|
exact_dtype,
|
|
relax_numpy_equality,
|
|
ignore_non_fp,
|
|
log_error=log_error,
|
|
)
|
|
for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
|
|
)
|
|
elif isinstance(ref, dict):
|
|
assert isinstance(res, dict)
|
|
assert set(ref.keys()) == set(
|
|
res.keys()
|
|
), f"keys mismatch {set(ref.keys())} == {set(res.keys())}"
|
|
for k in sorted(ref.keys()):
|
|
if not (
|
|
same(
|
|
ref[k],
|
|
res[k],
|
|
fp64_ref[k],
|
|
cos_similarity=cos_similarity,
|
|
tol=tol,
|
|
equal_nan=equal_nan,
|
|
exact_dtype=exact_dtype,
|
|
relax_numpy_equality=relax_numpy_equality,
|
|
ignore_non_fp=ignore_non_fp,
|
|
log_error=log_error,
|
|
)
|
|
):
|
|
log_error("Accuracy failed for key name %s", k)
|
|
return False
|
|
return True
|
|
elif isinstance(ref, torch.Tensor):
|
|
assert not isinstance(ref, torch._subclasses.FakeTensor)
|
|
assert not isinstance(res, torch._subclasses.FakeTensor)
|
|
|
|
if ref.is_sparse:
|
|
assert res.is_sparse
|
|
ref = ref.to_dense()
|
|
res = res.to_dense()
|
|
assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
|
|
if exact_dtype:
|
|
if ref.dtype != res.dtype:
|
|
log_error("dtype mismatch %s, %s", ref.dtype, res.dtype)
|
|
return False
|
|
if ref.dtype == torch.bool:
|
|
if ignore_non_fp:
|
|
return True
|
|
# triton stores bool as int8, so add this for more accurate checking
|
|
r = torch.allclose(
|
|
ref.to(dtype=torch.uint8),
|
|
res.to(dtype=torch.uint8),
|
|
atol=tol,
|
|
rtol=tol,
|
|
equal_nan=equal_nan,
|
|
)
|
|
if not r:
|
|
log_error("Accuracy failed: uint8 tensor did not match")
|
|
return r
|
|
if cos_similarity:
|
|
ref = ref.flatten().to(torch.float32)
|
|
res = res.flatten().to(torch.float32)
|
|
if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True):
|
|
# early exit that handles zero/nan better
|
|
# cosine_similarity(zeros(10), zeros(10), dim=0) is 0
|
|
return True
|
|
score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6)
|
|
if score < 0.99:
|
|
log.warning("Similarity score=%s", score.cpu().detach().item())
|
|
return score >= 0.99
|
|
else:
|
|
if not exact_dtype:
|
|
ref = ref.to(res.dtype)
|
|
|
|
# First try usual allclose
|
|
if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan):
|
|
return True
|
|
|
|
# Check error from fp64 version
|
|
if fp64_ref.dtype == torch.float64:
|
|
ref_error = rmse(fp64_ref, ref).item()
|
|
res_error = rmse(fp64_ref, res).item()
|
|
multiplier = 2.0
|
|
|
|
if (
|
|
fp64_ref.numel() < 1000
|
|
or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
|
|
# large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
|
|
or tol >= 2 * 1e-2
|
|
):
|
|
# In the presence of noise, noise might dominate our error
|
|
# metric for smaller tensors.
|
|
# Similary, for 1x1 kernels, there seems to be high noise with amp.
|
|
multiplier = 3.0
|
|
|
|
passes_test = res_error <= (multiplier * ref_error + tol / 10.0)
|
|
if not passes_test:
|
|
log_error(
|
|
"RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s",
|
|
res_error,
|
|
ref_error,
|
|
res.size(),
|
|
)
|
|
# import pdb; pdb.set_trace()
|
|
return passes_test
|
|
|
|
if ignore_non_fp:
|
|
return True
|
|
|
|
log_error("Accuracy failed: allclose not within tol=%s", tol)
|
|
return False
|
|
elif isinstance(ref, (str, int, type(None), bool, torch.device)):
|
|
if ignore_non_fp:
|
|
return True
|
|
r = ref == res
|
|
if not r:
|
|
log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res)
|
|
return r
|
|
elif isinstance(ref, float):
|
|
r = math.isclose(ref, res, rel_tol=tol, abs_tol=tol)
|
|
if not r:
|
|
log_error(
|
|
"Accuracy failed (float): %s != %s (within tol=%s)", ref, res, tol
|
|
)
|
|
return r
|
|
elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
|
|
if relax_numpy_equality and not (
|
|
is_numpy_int_type(res) or is_numpy_float_type(res)
|
|
):
|
|
ref = ref.item()
|
|
r = (type(ref) is type(res)) and (ref == res)
|
|
if not r:
|
|
log_error("Accuracy failed (numpy): %s != %s", ref, res)
|
|
return r
|
|
elif is_numpy_ndarray(ref):
|
|
return (type(ref) is type(res)) and (ref == res).all()
|
|
elif type(ref).__name__ in (
|
|
"MaskedLMOutput",
|
|
"Seq2SeqLMOutput",
|
|
"CausalLMOutputWithCrossAttentions",
|
|
"LongformerMaskedLMOutput",
|
|
"Instances",
|
|
"SquashedNormal",
|
|
"Boxes",
|
|
"Normal",
|
|
"TanhTransform",
|
|
"Foo",
|
|
"Variable",
|
|
):
|
|
assert type(ref) is type(res)
|
|
return all(
|
|
same(
|
|
getattr(ref, key),
|
|
getattr(res, key),
|
|
getattr(fp64_ref, key),
|
|
cos_similarity=cos_similarity,
|
|
tol=tol,
|
|
equal_nan=equal_nan,
|
|
exact_dtype=exact_dtype,
|
|
relax_numpy_equality=relax_numpy_equality,
|
|
ignore_non_fp=ignore_non_fp,
|
|
log_error=log_error,
|
|
)
|
|
for key in ref.__dict__.keys()
|
|
)
|
|
else:
|
|
raise RuntimeError(f"unsupported type: {type(ref).__name__}")
|
|
|
|
|
|
def format_func_info(code):
|
|
short_filename = code.co_filename.split("/")[-1]
|
|
return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})"
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def disable_cache_limit():
|
|
prior = config.cache_size_limit
|
|
config.cache_size_limit = sys.maxsize
|
|
|
|
try:
|
|
yield
|
|
finally:
|
|
config.cache_size_limit = prior
|
|
|
|
|
|
# map from transformed code back to original user code
|
|
orig_code_map = ExactWeakKeyDictionary()
|
|
|
|
# keep a record of code_obj -> list of guard failure reasons for logging
|
|
guard_failures = collections.defaultdict(list)
|
|
|
|
# Keep a record of graph break reasons for logging
|
|
graph_break_reasons = list()
|
|
|
|
# keep record of compiled code, if we are in "error if recompile"
|
|
# to track code that dynamo has compiled previously
|
|
seen_code_map = ExactWeakKeyDictionary()
|
|
|
|
|
|
class CompileProfiler:
|
|
"""Utility for profiling how and what dynamo would compile.
|
|
|
|
Can be used for
|
|
* diagnosing recompilation issues
|
|
* determining an appropriate compile cache limit
|
|
* (TODO)confirming which functions got compiled/skipped
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.frame_count = 0
|
|
self.op_count = 0
|
|
self.backend_ctx_ctor = lambda: disable_cache_limit()
|
|
|
|
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
|
self.frame_count += 1
|
|
for node in gm.graph.nodes:
|
|
if "call" in node.op:
|
|
self.op_count += 1
|
|
return gm.forward
|
|
|
|
def __enter__(self):
|
|
self.old_report_guard_failure = config.report_guard_failures
|
|
config.report_guard_failures = True
|
|
return self
|
|
|
|
def __exit__(self, typ, val, traceback):
|
|
config.report_guard_failures = self.old_report_guard_failure
|
|
|
|
def get_metrics(self):
|
|
return {"guard_failures": guard_failures}
|
|
|
|
def report(self):
|
|
metrics = self.get_metrics()
|
|
gf = metrics["guard_failures"]
|
|
|
|
def num_recompiles(code):
|
|
return len(gf[code])
|
|
|
|
def recompile_reasons(code):
|
|
return "\n".join([str(x) for x in gf[code]])
|
|
|
|
summarized_gf = [
|
|
[format_func_info(code), num_recompiles(code), recompile_reasons(code)]
|
|
for code in gf
|
|
]
|
|
|
|
def graph_break_report():
|
|
if "graph_break" in counters:
|
|
graph_breaks = counters["graph_break"]
|
|
return tabulate(
|
|
[[msg, graph_breaks[msg]] for msg in graph_breaks],
|
|
headers=["Graph Break Reason", "Count"],
|
|
)
|
|
|
|
def recompilation_report():
|
|
if len(gf):
|
|
max_recompiles = max([num_recompiles(code) for code in gf])
|
|
recomp_table = tabulate(
|
|
summarized_gf,
|
|
headers=["Function", "Recompiles", "Recompile Reasons"],
|
|
)
|
|
return recomp_table + textwrap.dedent(
|
|
f"""
|
|
|
|
Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited.
|
|
"""
|
|
)
|
|
|
|
report = textwrap.dedent(
|
|
"""
|
|
Torchdynamo Profiler Report
|
|
===========================
|
|
|
|
Graph Breaks
|
|
------------
|
|
Graph breaks happen when torchdynamo encounters code it can't safely trace.
|
|
If you want to find out why breaks are happening, check below for each break reason
|
|
You may gain additional insight by passing `fullgraph=True` to torch.compile,
|
|
to stop at the first break.
|
|
|
|
"""
|
|
)
|
|
report += graph_break_report() or "No graph breaks detected."
|
|
report += textwrap.dedent(
|
|
"""
|
|
|
|
Recompilation
|
|
-------------
|
|
These subgraphs were recompiled more than once due to guard failures
|
|
Guard failures indicate some condition assumed to be static by the tracer changed,
|
|
making it unsafe to reuse the compiled program.
|
|
|
|
"""
|
|
)
|
|
report += recompilation_report() or "No recompilation detected.\n"
|
|
return report
|
|
|
|
|
|
# return same dir unless user changes config between calls
|
|
@functools.lru_cache(None)
|
|
def _get_debug_dir(root_dir):
|
|
dir_name = (
|
|
"run_"
|
|
+ datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
|
|
# use pid to avoid conflicts among ranks
|
|
+ "-pid_"
|
|
+ str(os.getpid())
|
|
)
|
|
return os.path.join(root_dir, dir_name)
|
|
|
|
|
|
def get_debug_dir():
|
|
debug_root = config.debug_dir_root
|
|
return _get_debug_dir(debug_root)
|
|
|
|
|
|
def get_fake_value(node, tx):
|
|
"""
|
|
Run the computation represented by `node` using fake tensors and return the result.
|
|
"""
|
|
from .exc import (
|
|
TorchRuntimeError,
|
|
unimplemented,
|
|
Unsupported,
|
|
UserError,
|
|
UserErrorType,
|
|
)
|
|
|
|
op = node.op
|
|
|
|
def fake_wrapper(e):
|
|
if isinstance(e, torch.Tensor):
|
|
assert isinstance(e, FakeTensor)
|
|
return e
|
|
|
|
def visit(n: torch.fx.Node):
|
|
return n.meta["example_value"]
|
|
|
|
args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
|
|
args = tree_map(fake_wrapper, args)
|
|
kwargs = tree_map(fake_wrapper, kwargs)
|
|
|
|
nnmodule = None
|
|
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
|
|
# If the first argument is nn.Module, should copy to fake mode.
|
|
args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:])
|
|
|
|
if op == "call_module":
|
|
nnmodule = tx.output.nn_modules[node.target]
|
|
|
|
if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"):
|
|
# In the case of a lazy module, we want to run
|
|
# the pre-hooks which initialize it.
|
|
# Afterwards, lazy module deletes its pre-hooks
|
|
# to avoid treating it as lazy on subsequent recompile.
|
|
nnmodule._infer_parameters(nnmodule, args)
|
|
|
|
# no matter it's lazy module or not, we should copy to fake mode.
|
|
nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
|
|
|
|
try:
|
|
with tx.fake_mode, enable_python_dispatcher():
|
|
return wrap_fake_exception(
|
|
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
|
|
)
|
|
except Unsupported:
|
|
raise
|
|
except RuntimeError as e:
|
|
cause = e
|
|
if e.__cause__ is not None:
|
|
cause = e.__cause__
|
|
|
|
if isinstance(
|
|
cause, torch._subclasses.fake_tensor.DataDependentOutputException
|
|
):
|
|
unimplemented(f"data dependent operator: {cause.func}")
|
|
elif isinstance(
|
|
cause, torch._subclasses.fake_tensor.DynamicOutputShapeException
|
|
):
|
|
unimplemented(f"dynamic shape operator: {cause.func}")
|
|
elif isinstance(
|
|
cause, torch._subclasses.fake_tensor.UnsupportedOperatorException
|
|
):
|
|
unimplemented(
|
|
f"unsupported operator: {cause.func} (see "
|
|
"https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
|
|
" for how to fix)"
|
|
)
|
|
elif isinstance(
|
|
cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
|
|
):
|
|
unimplemented("guard on data-dependent symbolic int/float")
|
|
elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError):
|
|
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, e.args[0]) from e
|
|
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
|
|
|
|
|
|
def run_node(tracer, node, args, kwargs, nnmodule):
|
|
"""
|
|
Runs a given node, with the given args and kwargs.
|
|
|
|
Behavior is dicatated by a node's op.
|
|
|
|
run_node is useful for extracting real values out of nodes.
|
|
See get_real_value for more info on common usage.
|
|
|
|
Note: The tracer arg is only used for 'get_attr' ops
|
|
Note: The nnmodule arg is only used for 'call_module' ops
|
|
|
|
Nodes that are not call_function, call_method, call_module, or get_attr will
|
|
raise an AssertionError.
|
|
"""
|
|
op = node.op
|
|
try:
|
|
if op == "call_function":
|
|
return node.target(*args, **kwargs)
|
|
elif op == "call_method":
|
|
return getattr(args[0], node.target)(*args[1:], **kwargs)
|
|
elif op == "call_module":
|
|
assert nnmodule is not None
|
|
return nnmodule(*args, **kwargs)
|
|
elif op == "get_attr":
|
|
return tracer.get_submodule(node.target)
|
|
elif op == "placeholder":
|
|
assert "example_value" in node.meta
|
|
return node.meta["example_value"]
|
|
except Exception as e:
|
|
fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n"
|
|
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
|
|
|
|
raise AssertionError(op)
|
|
|
|
|
|
def get_real_value(node, tracer):
|
|
"""
|
|
Run the actual computation represented by `node` and return the result.
|
|
This will execute any dependent nodes in the graph as well.
|
|
"""
|
|
from .exc import TorchRuntimeError
|
|
|
|
cache = tracer.real_value_cache
|
|
if node in cache:
|
|
return cache[node]
|
|
|
|
op = node.op
|
|
args, kwargs = torch.fx.node.map_arg(
|
|
(node.args, node.kwargs),
|
|
lambda n: get_real_value(n, tracer),
|
|
)
|
|
|
|
if op == "call_module":
|
|
nn_module = tracer.output_graph.nn_modules[node.target]
|
|
if not is_lazy_module(nn_module):
|
|
nn_module = copy.deepcopy(nn_module)
|
|
else:
|
|
# In the case of a lazy module, we want to run
|
|
# the pre-hooks which initialize it
|
|
nn_module(*args, **kwargs)
|
|
else:
|
|
nn_module = None
|
|
|
|
try:
|
|
real_value = run_node(tracer, node, args, kwargs, nn_module)
|
|
cache[node] = real_value
|
|
except RuntimeError as e:
|
|
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
|
|
return real_value
|
|
|
|
|
|
def assert_no_fake_params_or_buffers(gm):
|
|
from torch._subclasses.fake_tensor import FakeTensorConfig
|
|
|
|
def stack_or_hint(t):
|
|
if FakeTensorConfig.debug:
|
|
import traceback
|
|
|
|
return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}"
|
|
else:
|
|
return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors."
|
|
|
|
for name, buffer in gm.named_buffers():
|
|
assert not isinstance(
|
|
buffer, torch._subclasses.FakeTensor
|
|
), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}"
|
|
for name, param in gm.named_parameters():
|
|
assert not isinstance(
|
|
param, torch._subclasses.FakeTensor
|
|
), f"Unexpected fake param {name} {stack_or_hint(param)}"
|
|
|
|
|
|
def fqn(obj: Any):
|
|
"""
|
|
Returns the fully qualified name of the object.
|
|
"""
|
|
return f"{obj.__module__}.{obj.__qualname__}"
|
|
|
|
|
|
def ifdyn(count1, count2):
|
|
if torch._dynamo.config.dynamic_shapes:
|
|
return count1
|
|
else:
|
|
return count2
|
|
|
|
|
|
def ifdynstaticdefault(count1, count2):
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
return count1
|
|
else:
|
|
return count2
|
|
|
|
|
|
def ifunspec(count1, count2):
|
|
if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int:
|
|
return count1
|
|
else:
|
|
return count2
|
|
|
|
|
|
def import_submodule(mod: types.ModuleType):
|
|
"""
|
|
Ensure all the files in a given submodule are imported
|
|
"""
|
|
for filename in sorted(os.listdir(os.path.dirname(mod.__file__))):
|
|
if filename.endswith(".py") and filename[0] != "_":
|
|
importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
|
|
|
|
|
|
def object_has_getattribute(value: Any):
|
|
try:
|
|
if isinstance(
|
|
inspect.getattr_static(type(value), "__getattribute__"),
|
|
types.FunctionType,
|
|
):
|
|
return True
|
|
except AttributeError:
|
|
pass
|
|
return False
|
|
|
|
|
|
def get_custom_getattr(value: Any):
|
|
try:
|
|
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
|
|
except AttributeError:
|
|
getattr_fn = None
|
|
if getattr_fn is torch.nn.Module.__getattr__:
|
|
# ignore this case of getattr
|
|
getattr_fn = None
|
|
return getattr_fn
|
|
|
|
|
|
class TensorStaticReason(enum.Enum):
|
|
PARAMETER = 2
|
|
CONFIG_NOT_DYN = 3
|
|
NOT_TENSOR = 4
|
|
NN_MODULE_PROPERTY = 5
|
|
|
|
|
|
def tensor_static_reason_to_message(reason: TensorStaticReason):
|
|
if reason == TensorStaticReason.PARAMETER:
|
|
return "mark_dynamic on parameter, parameters are always static today."
|
|
if reason == TensorStaticReason.CONFIG_NOT_DYN:
|
|
return "mark_dynamic usage with dynamic_shapes=False is not yet supported"
|
|
if reason == TensorStaticReason.NOT_TENSOR:
|
|
return "mark_dynamic on a non tensor, how did this happen?"
|
|
if reason == TensorStaticReason.NN_MODULE_PROPERTY:
|
|
return "tensor is static because it is nn module associated."
|
|
raise AssertionError(f"Illegal reason {reason}")
|
|
|
|
|
|
def tensor_always_has_static_shape(
|
|
tensor: Union[torch.Tensor, Any], is_tensor: bool, guard_source: "GuardSource"
|
|
) -> Tuple[bool, TensorStaticReason]:
|
|
"""
|
|
Given a tensor, source, and is_tensor flag, determine if a shape should be static.
|
|
|
|
Args:
|
|
tensor - the real tensor to evaluate, parameters force a static shape.
|
|
is_tensor - internal dynamo check, esentially "is_tensor": target_cls is TensorVariable,
|
|
tensors not in a TensorVariable for whatever reason are forced static.
|
|
|
|
Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
|
|
The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
|
|
"""
|
|
if type(tensor) is torch.nn.Parameter:
|
|
return True, TensorStaticReason.PARAMETER
|
|
if config.dynamic_shapes is False:
|
|
return True, TensorStaticReason.CONFIG_NOT_DYN
|
|
if not is_tensor:
|
|
return True, TensorStaticReason.NOT_TENSOR
|
|
if guard_source.is_nn_module():
|
|
return True, TensorStaticReason.NN_MODULE_PROPERTY
|
|
return False, None
|
|
|
|
|
|
class LazyString:
|
|
def __init__(self, func, *args, **kwargs):
|
|
self.func = func
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
|
|
def __str__(self):
|
|
return self.func(*self.args, **self.kwargs)
|
|
|
|
|
|
def lazy_format_graph_code(name, gm, maybe_id=None):
|
|
def format_name():
|
|
if maybe_id is not None:
|
|
return f"{name} {maybe_id}"
|
|
else:
|
|
return name
|
|
|
|
return LazyString(
|
|
lambda: _format_graph_code(
|
|
f"===== {format_name()} =====\n",
|
|
gm.forward.__code__.co_filename,
|
|
gm.print_readable(print_output=False),
|
|
)
|
|
)
|
|
|
|
|
|
def _format_graph_code(name, filename, graph_str):
|
|
return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
|
|
|
|
|
|
def lazy_format_graph_tabular(fn_name, gm):
|
|
def inner():
|
|
try:
|
|
from tabulate import tabulate # TODO: Check that this is installed
|
|
except ImportError:
|
|
return (
|
|
"Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n"
|
|
+ str(lazy_format_graph_code(fn_name, gm))
|
|
)
|
|
|
|
node_specs = [
|
|
[n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes
|
|
]
|
|
graph_str = tabulate(
|
|
node_specs, headers=["opcode", "name", "target", "args", "kwargs"]
|
|
)
|
|
return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str)
|
|
|
|
return LazyString(inner)
|
|
|
|
|
|
def format_bytecode(prefix, name, filename, line_no, code):
|
|
return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"
|
|
|
|
|
|
forward_hook_names = ["_forward_pre_hooks", "_forward_hooks"]
|
|
backward_hook_names = ["_backward_pre_hooks", "_backward_hooks"]
|
|
state_dict_hook_names = [
|
|
"_state_dict_pre_hooks",
|
|
"_state_dict_hooks",
|
|
"_load_state_dict_pre_hooks",
|
|
"_load_state_dict_post_hooks",
|
|
]
|
|
all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names
|
|
|
|
|
|
def nnmodule_has_hooks(
|
|
mod,
|
|
check_forward_hooks=False,
|
|
check_backward_hooks=False,
|
|
check_state_dict_hooks=False,
|
|
):
|
|
"""
|
|
Sometimes its useful to differentiate between types of hooks such as forward/backward/pre
|
|
hooks executed during module.__call__, and state_dict hooks which are executed separately.
|
|
"""
|
|
hook_dicts_to_check = []
|
|
check_all_hooks = (
|
|
not check_forward_hooks
|
|
and not check_backward_hooks
|
|
and not check_state_dict_hooks
|
|
)
|
|
if check_forward_hooks or check_all_hooks:
|
|
hook_dicts_to_check.extend(forward_hook_names)
|
|
if check_backward_hooks or check_all_hooks:
|
|
hook_dicts_to_check.extend(backward_hook_names)
|
|
if check_state_dict_hooks:
|
|
hook_dicts_to_check.extend(state_dict_hook_names)
|
|
return any(len(getattr(mod, x)) > 0 for x in hook_dicts_to_check if hasattr(mod, x))
|
|
|
|
|
|
def to_numpy_helper(value):
|
|
"""Convert tensor and torch_np.ndarray to numpy.ndarray."""
|
|
if isinstance(value, torch_np.ndarray):
|
|
return value.tensor.numpy()
|
|
elif isinstance(value, torch.Tensor):
|
|
return value.numpy()
|
|
elif isinstance(value, (tuple, list)):
|
|
return type(value)(to_numpy_helper(obj) for obj in value)
|
|
else:
|
|
return value
|
|
|
|
|
|
def numpy_to_tensor(value):
|
|
"""Convert torch_np.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert."""
|
|
if isinstance(value, torch_np.ndarray):
|
|
return value.tensor
|
|
elif isinstance(value, (tuple, list)):
|
|
return type(value)(numpy_to_tensor(obj) for obj in value)
|
|
else:
|
|
return value
|
|
|
|
|
|
class numpy_to_tensor_wrapper:
|
|
def __init__(self, f):
|
|
self.f = f
|
|
self.__name__ = "wrapped_" + self.f.__name__
|
|
|
|
def __repr__(self):
|
|
return f"<Wrapped function <original {self.f.__name__}>>"
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
out = self.f(*args, **kwargs)
|
|
return numpy_to_tensor(out)
|
|
|
|
|
|
def numpy_attr_wrapper(obj, name):
|
|
if isinstance(obj, torch_np.ndarray):
|
|
out = getattr(obj, name)
|
|
return numpy_to_tensor(out)
|
|
elif isinstance(obj, torch.Tensor):
|
|
out = getattr(torch_np.ndarray(obj), name)
|
|
return numpy_to_tensor(out)
|
|
|
|
|
|
def defake(x):
|
|
if not isinstance(x, FakeTensor):
|
|
return x
|
|
if x._has_symbolic_sizes_strides:
|
|
size = [
|
|
s.node.shape_env.size_hint(s.node.expr)
|
|
if isinstance(s, torch.SymInt)
|
|
else s
|
|
for s in x.size()
|
|
]
|
|
stride = [
|
|
s.node.shape_env.size_hint(s.node.expr)
|
|
if isinstance(s, torch.SymInt)
|
|
else s
|
|
for s in x.stride()
|
|
]
|
|
else:
|
|
size = x.size()
|
|
stride = x.stride()
|
|
y = torch.empty_strided(
|
|
size,
|
|
stride,
|
|
dtype=x.dtype,
|
|
device=x.device,
|
|
requires_grad=x.requires_grad,
|
|
)
|
|
y.zero_()
|
|
return y
|
|
|
|
|
|
# NB: The dictionary has to be created lazily after TorchPatcher is called so
|
|
# that we pick up the disabled torch.utils.checkpoint wrapper. Therefore, it is
|
|
# sitting in a separate function.
|
|
@functools.lru_cache(None)
|
|
def higher_order_op_converter():
|
|
import torch._higher_order_ops.wrap
|
|
|
|
return {
|
|
torch.utils.checkpoint.checkpoint: torch._higher_order_ops.wrap.wrap_activation_checkpoint,
|
|
}
|
|
|
|
|
|
def requires_higher_order_op(obj):
|
|
return obj in higher_order_op_converter()
|
|
|
|
|
|
def get_higher_order_op(obj):
|
|
if (
|
|
obj is torch.utils.checkpoint.checkpoint
|
|
and not torch._functorch.config.functionalize_rng_ops
|
|
):
|
|
from .exc import unimplemented
|
|
|
|
# TODO - functionalize_rng_ops flags cannot be turned ON by default
|
|
# because 1) Performance concerns - seed and offset are read and passed
|
|
# to each AOT graph 2) Inductor has rand-specific optimizations and
|
|
# there is work remaining to compose them together with
|
|
# functionalization.
|
|
#
|
|
# Until we make it ON by default, we will have to ask users to turn on
|
|
# this flag manually. TODO - Revisit if there is a simpler way to
|
|
# resolve this problem.
|
|
torch._logging.warning_once(
|
|
log,
|
|
"torch.compile on activation checkpointing is an experimental feature. "
|
|
"Please manually set torch._functorch.config.functionalize_rng_ops=True "
|
|
"to run torch.compile with activation checkpointing. Without this flag, "
|
|
"checkpointed function will not get compiled and fallback to eager.",
|
|
)
|
|
unimplemented(
|
|
"torch.compile requires functioanlization of rng ops to be turned on"
|
|
)
|
|
return higher_order_op_converter().get(obj)
|