Enable aten-aten decomps (#85921)

Invokes aten-aten decomps under a re-entrant FakeMode. These decomps are already used in other places (e.g. for AOT Autograd tracing), so it's good to unify the path fake tensor takes with them and get additional testing of the decompositions. There is also an instance where we return different devices for cpu/cuda inputs, which this fixes ([batch_norm](https://github.com/pytorch/pytorch/blob/master/torch/_decomp/decompositions.py#L1374))
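
As a rough repro sketch of the batch_norm case (this mirrors the regression test added below, pulled out as a standalone snippet; `CrossRefFakeMode` runs each op both for real and with fake tensors and compares the results' metadata, including device):

```python
import torch

# Minimal sketch mirroring the test added in this PR: an eval-mode BatchNorm
# run under CrossRefFakeMode, so any device mismatch between the real kernel
# and the fake-tensor/decomposition path is flagged.
with torch._subclasses.CrossRefFakeMode():
    m = torch.nn.Sequential(
        torch.nn.BatchNorm2d(10),
        torch.nn.ReLU(),
    )
    m.eval()
    out = m(torch.randn([2, 10, 8, 8]))
```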

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85921
Approved by: https://github.com/ezyang
Elias Ellison
2022-10-07 18:01:13 +00:00
committed by PyTorch MergeBot
parent af9c6bc851
commit d3f7c34cb3
8 changed files with 64 additions and 35 deletions

View File

@@ -880,7 +880,6 @@ symbolic_aot_autograd_failures = {
xfail('mvlgamma', 'mvlgamma_p_3'), # aten.digamma_.default - couldn't find symbolic meta function/decom...
xfail('mvlgamma', 'mvlgamma_p_5'), # aten.digamma_.default - couldn't find symbolic meta function/decom...
xfail('nanmedian', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
- xfail('native_layer_norm', ''), # could not find kernel
xfail('nn.functional._scaled_dot_product_attention', ''), # Cannot call sizes() on tensor with symbolic ...
xfail('nn.functional.adaptive_avg_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.adaptive_avg_pool2d', ''), # aten._adaptive_avg_pool2d_backward.default - couldn't ...
@@ -923,7 +922,6 @@ symbolic_aot_autograd_failures = {
xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
xfail('nn.functional.kl_div', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.l1_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
- xfail('nn.functional.layer_norm', ''), # could not find kernel
xfail('nn.functional.linear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.local_response_norm', ''), # aten.fill.Scalar - couldn't find symbolic meta functio...
xfail('nn.functional.max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides

View File

@@ -464,6 +464,15 @@ class FakeTensorConstHandling(TestCase):
inputs = [a, b]
ref = fn(inputs)
+ def test_fake_tensor_batch_norm_cpu(self):
+     with torch._subclasses.CrossRefFakeMode():
+         m = torch.nn.Sequential(
+             torch.nn.BatchNorm2d(10),
+             torch.nn.ReLU(),
+         )
+         m.eval()
+         out = m(torch.randn([2, 10, 8, 8]))
def test_shared_storage_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])

View File

@@ -1811,10 +1811,6 @@ fake_backward_xfails = fake_tensor_stride_failing_ops | {
"linalg.norm",
"linalg.svd",
"linalg.svdvals",
"nn.functional.binary_cross_entropy_with_logits",
"nn.functional.huber_loss",
"nn.functional.logsigmoid",
"nn.functional.multilabel_soft_margin_loss",
"pca_lowrank",
"roll",
"svd_lowrank",

View File

@@ -843,9 +843,8 @@ def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModul
def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
- def _add_meta_to_tls_dispatch_include() -> None: ...
+ def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
def _meta_in_tls_dispatch_include() -> _bool: ...
- def _remove_meta_from_tls_dispatch_include() -> None: ...
def _has_storage(x: Tensor) -> _bool: ...
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
# NB: There is no Capsule type in typing, see

View File

@@ -18,6 +18,10 @@ decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+ # decompositions which have been disabled as meta kernel implementations,
+ # usually due to mismatching strides, aliasing, or other inconsistent property
+ _disabled_meta_decomps = set()
def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False):
"""
@@ -105,6 +109,11 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
name = op_overload._schema.name
if op_overload._schema.overload_name:
name += "." + op_overload._schema.overload_name
+ if disable_meta:
+     global _disabled_meta_decomps
+     _disabled_meta_decomps.add(op_overload)
if (
not disable_meta
# TorchScript dumps a bunch of extra nonsense overloads
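
For context, the `disable_meta` flag recorded above is consumed by fake tensor (see the fake_tensor.py hunk further down) to skip decomps that should not serve as meta kernels. A hypothetical usage sketch — the op choice, decomposition body, and `my_decomps` registry are illustrative, not taken from this diff:

```python
import torch
from torch._decomp import register_decomposition, _disabled_meta_decomps

aten = torch.ops.aten

# Hypothetical sketch of disable_meta: register a decomposition (here into a
# private registry so we don't collide with the built-in table), but mark the
# op so fake tensor will not use the decomposition as a meta kernel, e.g.
# because its output strides differ from the eager kernel.
my_decomps = {}

@register_decomposition(aten.trace, registry=my_decomps, disable_meta=True)
def trace_decomp(self: torch.Tensor) -> torch.Tensor:
    return torch.sum(torch.diagonal(self))

assert aten.trace.default in _disabled_meta_decomps
```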

View File

@@ -1284,12 +1284,8 @@ def native_layer_norm_backward(
if M <= 0 or N <= 0:
return (
input.new_zeros(input_shape) if output_mask[0] else None,
- input.new_zeros(input_shape[axis:])
- if output_mask[1] and weight_cast
- else None,
- input.new_zeros(input_shape[axis:])
- if output_mask[2] and bias_cast
- else None,
+ input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
+ input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
)
x_hat = (input_cast - mean) * rstd

View File

@@ -115,6 +115,20 @@ def get_schema_info(func):
return torch._C._SchemaInfo(func._schema) # type: ignore[attr-defined]
+ # many of the decompositions registered to torch/_prims do not at the moment model
+ # aliasing or strides, so as an incremental step, just enable the decompositions in
+ # torch/_decomp/decompositions.py.
+ # decomps are used for aot autograd tracing so we would like to unify on their
+ # implementation and add additional testing to them
+ @functools.lru_cache(None)
+ def torch_decomp_decompositions(func):
+     from torch._decomp import decomposition_table
+     decompositions = torch._decomp.decompositions
+     decomp_attrs = [getattr(decompositions, attr) for attr in dir(decompositions)]
+     return decomposition_table[func] in decomp_attrs
def tree_flatten_only(ty: Type[T], pytree: PyTree):
flat_vals, _ = tree_flatten(pytree)
return [elem for elem in flat_vals if isinstance(elem, ty)]
@@ -302,7 +316,8 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
input_device = new_kwargs["device"]
out_device = input_device if input_device else new_kwargs["input"].device
new_kwargs["device"] = torch.device("meta")
- r = func(*args, **new_kwargs)
+ inp = new_kwargs.pop("input")
+ r = func(inp, **new_kwargs)
return fake_mode.fake_tensor_converter(fake_mode, r, out_device)
@@ -329,7 +344,7 @@ def to_copy(fake_mode, func, *args, **kwargs):
input_device = new_kwargs.pop("device", None)
out_device = input_device if input_device else new_kwargs["input"].device
- with no_dispatch(), in_kernel_invocation_manager(fake_mode):
+ with in_kernel_invocation_manager(fake_mode):
input = new_kwargs.pop("input").to("meta")
return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device)
@@ -417,18 +432,19 @@ def nyi(fake_mode, func, *args, **kwargs):
@contextlib.contextmanager
def in_kernel_invocation_manager(fake_mode):
    # See: note [Fake Tensor Dispatch Keys]
+   prev_in_kernel = fake_mode.in_kernel_invocation
    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
-   prev = fake_mode.in_kernel_invocation
+   assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
    guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
    fake_mode.in_kernel_invocation = True
-   if not meta_in_tls:
-       torch._C._add_meta_to_tls_dispatch_include()
+   torch._C._set_meta_in_tls_dispatch_include(True)
    try:
        yield
    finally:
-       fake_mode.in_kernel_invocation = prev
-       if not meta_in_tls:
-           torch._C._remove_meta_from_tls_dispatch_include()
+       fake_mode.in_kernel_invocation = prev_in_kernel
+       torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
        del guard
class FakeTensor(torch.Tensor):
@@ -728,6 +744,8 @@ class FakeTensorMode(TorchDispatchMode):
# is written to must be invalidated
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
+ from torch._decomp import _disabled_meta_decomps, decomposition_table
# IDK: feels bad man, sym_numel on as_strided infinite loops otherwise
if (
has_symbolic_sizes
@@ -735,7 +753,6 @@
):
# TODO: Find better approach for this
# Avoid circular import
- from torch._decomp import decomposition_table
from torch._meta_registrations import meta_table
with no_dispatch():
@@ -759,6 +776,15 @@
if r is not NotImplemented:
return r
+ if (
+     func in decomposition_table
+     and torch_decomp_decompositions(func)
+     and func not in _disabled_meta_decomps
+     and all(not e.is_sparse for e in flat_arg_fake_tensors)
+ ):
+     with self:
+         return decomposition_table[func](*args, **kwargs)
# prims already wrap FakeTensor inputs to FakeTensor outputs
# and do device logic, we dont need do anything but run them
# and ensure that Meta kernels are dispatched to (see)
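
In effect, ops whose aten-aten decompositions live in torch/_decomp/decompositions.py are now faked by re-entrantly running the decomposition under the mode, so every op the decomposition calls is itself intercepted and produces fake tensors with the expected metadata. A rough usage sketch (illustrative, not part of the diff; layer_norm is one of the ops whose xfails are removed above):

```python
import torch
from torch._subclasses import FakeTensorMode

# Illustrative sketch: under FakeTensorMode, layer_norm is handled by
# re-entrantly running its decomposition from torch/_decomp/decompositions.py,
# producing a FakeTensor with the expected shape/device metadata.
with FakeTensorMode():
    x = torch.randn(2, 10)
    out = torch.nn.functional.layer_norm(x, [10])
    assert isinstance(out, torch._subclasses.FakeTensor)
    assert out.shape == (2, 10)
```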

View File

@@ -1386,25 +1386,21 @@ Call this whenever a new thread is created in order to propagate values from
  py_module.def(
      "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
- py_module.def("_add_meta_to_tls_dispatch_include", []() {
+ py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+   if (meta_in_tls) {
      local_keyset.included_ = local_keyset.included_ | key_set;
-     c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
- });
- py_module.def("_remove_meta_from_tls_dispatch_include", []() {
-   auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-   c10::DispatchKeySet key_set({at::DispatchKey::Meta});
-   auto k = key_set.highestBackendKey();
-   local_keyset.included_ = local_keyset.included_.remove_backend(k);
+   } else {
+     local_keyset.included_ =
+         local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
+   }
    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
  });
  py_module.def("_meta_in_tls_dispatch_include", []() {
    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-   c10::DispatchKeySet key_set({at::DispatchKey::Meta});
-   auto k = key_set.highestBackendKey();
-   return local_keyset.included_.has_backend(k);
+   return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
  });
py_module.def("_dump_local_tls_set", []() {