Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
Author: Edward Z. Yang
Date: 2023-07-29 10:51:26 -04:00
Committed by: PyTorch MergeBot
Parent: a4ebc61f15
Commit: 3bf922a6ce
163 changed files with 8472 additions and 4412 deletions
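
For context on the hunks below: UFMT runs usort (import sorting) followed by black (code formatting), so the changes in this commit are purely mechanical. In the PyTorch repo the formatter is typically applied through lintrunner's UFMT check, or by running ufmt directly on the listed paths. A hypothetical before/after sketch (not code from this commit; the function name is illustrative) of the transformations involved:

# Before formatting (hypothetical input):
#
#   import torch
#   import itertools
#   __all__ = ['chained_sizes']
#   def chained_sizes(tensors):
#       return list(itertools.chain.from_iterable(tuple(t.size()) for t in tensors if isinstance(t, torch.Tensor)))
#
# usort groups and sorts the imports (standard library first, then third-party),
# while black normalizes string quotes to double quotes, puts two blank lines
# before top-level definitions, and wraps lines longer than 88 columns:
import itertools

import torch

__all__ = ["chained_sizes"]


def chained_sizes(tensors):
    return list(
        itertools.chain.from_iterable(
            tuple(t.size()) for t in tensors if isinstance(t, torch.Tensor)
        )
    )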

torch/_dispatch/python.py

@@ -1,13 +1,14 @@
-import torch._C
-from contextlib import contextmanager
-import unittest.mock
-import torch
-import torch.utils._pytree as pytree
 import itertools
+import unittest.mock
+from contextlib import contextmanager
 from typing import Iterator
-import torch._ops
-
-__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
+
+import torch
+import torch._C
+import torch._ops
+import torch.utils._pytree as pytree
+
+__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
 
 no_python_dispatcher = torch._C._DisablePythonDispatcher
 enable_python_dispatcher = torch._C._EnablePythonDispatcher
@@ -15,6 +16,7 @@ enable_pre_dispatch = torch._C._EnablePreDispatch
 
 CROSSREF_FUNCTIONALIZE = False
 
+
 def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
     """
     Warning: the set of overloads this will report is very subtle. It is precisely
@@ -40,9 +42,12 @@ def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
         for overload in packet:
             yield getattr(packet, overload)
 
+
 @contextmanager
 def suspend_functionalization():
-    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
+    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
+        torch._C.DispatchKey.Functionalize
+    )
     f_rv = torch._C._functionalization_reapply_views_tls()
     if f_tls:
         torch._disable_functionalization()
@@ -52,12 +57,18 @@ def suspend_functionalization():
         if f_tls:
             torch._enable_functionalization(reapply_views=f_rv)
 
+
 def check_tensor_metadata_matches(nv, rv, desc):
     assert callable(desc)
     assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
     assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
-    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
-    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
+    same_strides, idx = torch._prims_common.check_significant_strides(
+        nv, rv, only_cuda=False
+    )
+    assert (
+        same_strides
+    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
 
+
 def check_metadata_matches(n, r, desc):
     assert callable(desc)
@@ -71,6 +82,7 @@ def check_metadata_matches(n, r, desc):
             continue
         check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
 
+
 class Lit:
     def __init__(self, s):
         self.s = s
@@ -78,14 +90,19 @@ class Lit:
     def __repr__(self):
         return self.s
 
+
 def _fmt(a: object) -> object:
     if isinstance(a, torch.Tensor):
-        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
+        return Lit(
+            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
+        )
     else:
         return a
 
+
 def make_crossref_functionalize(op, final_key):
     from torch._subclasses.fake_tensor import FakeTensorMode
+
     # This case is pretty weird, suppress it for now
     if op == torch.ops.aten.lift_fresh.default:
         return final_key
@@ -117,7 +134,9 @@ def make_crossref_functionalize(op, final_key):
 
         with suspend_functionalization():
             f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
-            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
+            orig_f_args, orig_f_kwargs = pytree.tree_map(
+                maybe_detach, (f_args, f_kwargs)
+            )
             with fake_mode:
                 f_r = op(*f_args, **f_kwargs)
         r = op._op_dk(final_key, *args, **kwargs)
@@ -126,14 +145,20 @@ def make_crossref_functionalize(op, final_key):
             fmt_args = ", ".join(
                 itertools.chain(
                     (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
-                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
+                    (
+                        f"{k}={pytree.tree_map(_fmt, v)}"
+                        for k, v in orig_f_kwargs.items()
+                    ),
                 )
             )
             return f"{op}({fmt_args})"
+
         check_metadata_matches(f_r, r, desc)
         return r
+
     return handler
 
+
 # NB: enabling this is slow, don't do it in a hot loop. This is purely
 # for debugging purposes.
 @contextmanager
@@ -142,7 +167,8 @@ def enable_crossref_functionalize():
         op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
     try:
         with enable_python_dispatcher(), unittest.mock.patch(
-            'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
+            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
+        ):
             yield
     finally:
         for op in all_py_loaded_overloads():
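
For readers unfamiliar with the file being reformatted: torch/_dispatch/python.py exposes the context managers listed in __all__ plus the crossref-functionalize debugging mode shown above, which re-runs Functionalize-key ops on fake tensors and asserts that the sizes, dtypes, and strides match the real outputs. A minimal usage sketch (not part of this commit; the function and inputs are illustrative, and behavior is unchanged by this formatting-only change):

import torch
from torch._dispatch.python import enable_crossref_functionalize


def f(x):
    return torch.add(x, 1)


x = torch.randn(4, 4)
# The crossref handler only fires for ops dispatched through the Functionalize
# key, e.g. under torch.func.functionalize; as the comment in the diff notes,
# this is slow and intended for debugging only.
with enable_crossref_functionalize():
    y = torch.func.functionalize(f)(x)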