Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007

Committed by: PyTorch MergeBot
Parent: a4ebc61f15
Commit: 3bf922a6ce
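
UFMT pairs usort (import sorting) with black (code formatting), so every hunk below is mechanical: usort regroups the imports into standard-library and first-party blocks, and black splits statements that overflow its 88-column limit. A change like this is normally regenerated rather than hand-edited; a minimal sketch, assuming ufmt is installed and run from the repository root (the path is inferred from the mock.patch target in the last hunk, not stated on this page):

    # reformat the module in place with usort + black
    ufmt format torch/_dispatch/python.py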
--- a/torch/_dispatch/python.py
+++ b/torch/_dispatch/python.py
@@ -1,13 +1,14 @@
-import torch._C
-from contextlib import contextmanager
-import unittest.mock
-import torch
-import torch.utils._pytree as pytree
 import itertools
+import unittest.mock
+from contextlib import contextmanager
 from typing import Iterator
-import torch._ops
 
-__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
+import torch
+import torch._C
+import torch._ops
+import torch.utils._pytree as pytree
+
+__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
 
 no_python_dispatcher = torch._C._DisablePythonDispatcher
 enable_python_dispatcher = torch._C._EnablePythonDispatcher
@@ -15,6 +16,7 @@ enable_pre_dispatch = torch._C._EnablePreDispatch
 
 CROSSREF_FUNCTIONALIZE = False
 
+
 def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
     """
     Warning: the set of overloads this will report is very subtle. It is precisely
@@ -40,9 +42,12 @@ def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
             for overload in packet:
                 yield getattr(packet, overload)
 
+
 @contextmanager
 def suspend_functionalization():
-    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
+    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
+        torch._C.DispatchKey.Functionalize
+    )
     f_rv = torch._C._functionalization_reapply_views_tls()
     if f_tls:
         torch._disable_functionalization()
@@ -52,12 +57,18 @@ def suspend_functionalization():
     if f_tls:
         torch._enable_functionalization(reapply_views=f_rv)
 
+
 def check_tensor_metadata_matches(nv, rv, desc):
     assert callable(desc)
     assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
     assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
-    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
-    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
+    same_strides, idx = torch._prims_common.check_significant_strides(
+        nv, rv, only_cuda=False
+    )
+    assert (
+        same_strides
+    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
 
+
 def check_metadata_matches(n, r, desc):
     assert callable(desc)
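
For orientation between hunks: check_tensor_metadata_matches is the assertion helper the crossref mode uses to compare a functionalized rerun against the real result. A minimal sketch of calling it directly (the tensors and the desc callback are invented for illustration; only the helper itself comes from this file):

    import torch
    from torch._dispatch.python import check_tensor_metadata_matches

    # Same sizes, dtypes, and significant strides, so every assert passes;
    # desc is only evaluated if an assertion fails.
    nv = torch.empty_strided((2, 3), (3, 1))
    rv = torch.empty_strided((2, 3), (3, 1))
    check_tensor_metadata_matches(nv, rv, lambda: "example op")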
@@ -71,6 +82,7 @@ def check_metadata_matches(n, r, desc):
             continue
         check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
 
+
 class Lit:
     def __init__(self, s):
         self.s = s
@@ -78,14 +90,19 @@ class Lit:
     def __repr__(self):
         return self.s
 
+
 def _fmt(a: object) -> object:
     if isinstance(a, torch.Tensor):
-        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
+        return Lit(
+            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
+        )
     else:
         return a
 
+
 def make_crossref_functionalize(op, final_key):
     from torch._subclasses.fake_tensor import FakeTensorMode
+
     # This case is pretty weird, suppress it for now
     if op == torch.ops.aten.lift_fresh.default:
         return final_key
@@ -117,7 +134,9 @@ def make_crossref_functionalize(op, final_key):
 
         with suspend_functionalization():
             f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
-            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
+            orig_f_args, orig_f_kwargs = pytree.tree_map(
+                maybe_detach, (f_args, f_kwargs)
+            )
             with fake_mode:
                 f_r = op(*f_args, **f_kwargs)
         r = op._op_dk(final_key, *args, **kwargs)
@@ -126,14 +145,20 @@ def make_crossref_functionalize(op, final_key):
             fmt_args = ", ".join(
                 itertools.chain(
                     (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
-                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
+                    (
+                        f"{k}={pytree.tree_map(_fmt, v)}"
+                        for k, v in orig_f_kwargs.items()
+                    ),
                 )
             )
             return f"{op}({fmt_args})"
 
         check_metadata_matches(f_r, r, desc)
         return r
 
     return handler
 
+
 # NB: enabling this is slow, don't do it in a hot loop. This is purely
 # for debugging purposes.
 @contextmanager
@@ -142,7 +167,8 @@ def enable_crossref_functionalize():
         op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
     try:
         with enable_python_dispatcher(), unittest.mock.patch(
-                'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
+            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
+        ):
             yield
     finally:
         for op in all_py_loaded_overloads():
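
Taken together: enable_crossref_functionalize uncaches the Functionalize dispatch key for every loaded overload, then flips CROSSREF_FUNCTIONALIZE under the Python dispatcher, so each op run inside the context is also executed functionalized on fake tensors and its output metadata (sizes, dtypes, significant strides) is asserted to match the real result. A hypothetical debugging session (the function f and its input are invented for illustration), keeping in mind the NB above that this is slow:

    import torch
    from torch._dispatch.python import enable_crossref_functionalize

    def f(x):
        # in-place update plus a view: the kind of op sequence
        # functionalization rewrites, and thus worth cross-checking
        return x.add_(1).view(-1)

    with enable_crossref_functionalize():
        out = f(torch.ones(2, 2))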