mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 9790d90e4b0288796ab44a6b4979db0a67580ba8. Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet due to Breaks internal builds, see D39076487
512 lines
19 KiB
Python
512 lines
19 KiB
Python
# Owner(s): ["module: primTorch"]
|
|
|
|
from collections import defaultdict
|
|
from torch import Tensor
|
|
import torch.autograd
|
|
from torch.utils._python_dispatch import enable_torch_dispatch_mode
|
|
from torch._decomp import decomposition_table
|
|
|
|
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
|
from torch.utils._mode_utils import no_dispatch
|
|
from torch.testing._internal.common_utils import (
|
|
is_iterable_of_tensors,
|
|
TestCase,
|
|
skipIfCrossRef,
|
|
suppress_warnings,
|
|
TEST_WITH_ASAN,
|
|
run_tests,
|
|
skipIfSlowGradcheckEnv,
|
|
)
|
|
from torch.testing._internal.common_device_type import (
|
|
onlyNativeDeviceTypes,
|
|
ops,
|
|
instantiate_device_type_tests,
|
|
)
|
|
from torch.testing._internal.common_methods_invocations import op_db
|
|
|
|
import itertools
|
|
import functools
|
|
from functools import partial
|
|
import unittest
|
|
|
|
# Short module-level alias for the ``torch.ops.aten`` operator namespace.
aten = torch.ops.aten
|
|
|
|
|
|
# TODO: this isn't going to work with non-aten namespaces
|
|
def overload_to_aten_name(overload):
    """Return the operator name of ``overload`` with its namespace removed.

    The schema name is split on ``"::"`` and the piece immediately after the
    first separator is returned, e.g. ``"aten::add"`` becomes ``"add"``.
    """
    qualified_name = overload._schema.name
    name_parts = qualified_name.split("::")
    return name_parts[1]
|
|
|
|
|
|
# All operators that can have decomp tests
|
|
# Names (without namespace) of every overload registered in the
# decomposition table.
decomposition_names = set(map(overload_to_aten_name, decomposition_table))

# OpInfo entries whose forward op or backward op has a decomposition.
_decomp_test_ops = list(
    filter(
        lambda op: (
            op.aten_name in decomposition_names
            or op.aten_backward_name in decomposition_names
        ),
        op_db,
    )
)
|
|
|
|
|
|
def diff_arg(arg, requires_grad=True):
    """Decide whether ``arg`` counts as a differentiable input.

    With ``requires_grad=True`` a tensor qualifies when its ``requires_grad``
    flag is set; with ``requires_grad=False`` it qualifies when its dtype is
    floating point or complex.  An iterable of tensors qualifies only when
    every member does.  Anything else is not differentiable.

    Raises:
        RuntimeError: if an iterable of tensors mixes qualifying and
            non-qualifying members.
    """

    def is_differentiable_arg(tensor):
        if not requires_grad:
            return tensor.is_floating_point() or tensor.is_complex()
        return tensor.requires_grad

    if not is_iterable_of_tensors(arg):
        return isinstance(arg, Tensor) and is_differentiable_arg(arg)

    member_flags = [is_differentiable_arg(member) for member in arg]
    if all(member_flags):
        return True
    if not any(member_flags):
        return False
    raise RuntimeError("NYI: The test runner can't handle this")
|
|
|
|
|
|
# Version of autograd.grad with some differences:
|
|
# - pytree inputs are allowed (but leaves of the pytree have to all
|
|
# be tensors)
|
|
# - if an input is not used as part of derivatives, we will return a
|
|
# zero-filled tensor for the result
|
|
def _autograd_grad(
|
|
outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
|
|
):
|
|
inputs, inputs_spec = tree_flatten(inputs)
|
|
diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
|
|
if grad_outputs is None:
|
|
diff_outputs = tuple(out for out in outputs if out.requires_grad)
|
|
else:
|
|
diff_grad_outputs = [
|
|
(out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
|
|
]
|
|
if len(diff_grad_outputs) == 0:
|
|
diff_outputs, grad_outputs = (), ()
|
|
else:
|
|
diff_outputs, grad_outputs = zip(*diff_grad_outputs)
|
|
grad_inputs = torch.autograd.grad(
|
|
diff_outputs,
|
|
diff_inputs,
|
|
grad_outputs,
|
|
retain_graph=retain_graph,
|
|
create_graph=create_graph,
|
|
allow_unused=True,
|
|
)
|
|
result = []
|
|
grad_inputs_iter = iter(grad_inputs)
|
|
for inp in inputs:
|
|
if inp.requires_grad:
|
|
grad_input = next(grad_inputs_iter)
|
|
if grad_input is None:
|
|
result.append(torch.zeros_like(inp))
|
|
else:
|
|
result.append(grad_input)
|
|
else:
|
|
result.append(torch.zeros_like(inp))
|
|
return tree_unflatten(result, inputs_spec)
|
|
|
|
|
|
def _as_tuple(val):
|
|
if isinstance(val, tuple):
|
|
return val
|
|
return (val,)
|
|
|
|
|
|
def ref_vjp_no_create(f, *primals):
    """Evaluate ``f(*primals)`` and return ``(result, vjp_fn)``.

    ``vjp_fn(cotangents)`` computes the gradients of the result with respect
    to ``primals`` via ``_autograd_grad`` with ``create_graph=False``.
    """
    primal_out = f(*primals)

    def wrapped(cotangents):
        outputs = _as_tuple(primal_out)
        grad_outputs = _as_tuple(cotangents)
        return _autograd_grad(outputs, primals, grad_outputs, create_graph=False)

    return primal_out, wrapped
|
|
|
|
|
|
# Maps a dtype to its (rtol, atol) pair.
dtype_precisions = {
    # real dtypes
    torch.float16: (1e-3, 1e-5),
    torch.bfloat16: (1.6e-2, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    # complex dtypes
    torch.complex32: (1e-3, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}
|
|
# Returns the "default" rtol and atol for comparing scalars or
|
|
# tensors of the given dtypes.
|
|
|
|
|
|
def _getDefaultRtolAndAtol(dtype0, dtype1):
    """Return the looser ``(rtol, atol)`` of the two dtypes' defaults.

    Each tolerance is the maximum of the values listed for ``dtype0`` and
    ``dtype1`` in ``dtype_precisions``; a dtype missing from that table
    contributes ``0`` to both.
    """
    no_tolerance = (0, 0)
    rtol0, atol0 = dtype_precisions.get(dtype0, no_tolerance)
    rtol1, atol1 = dtype_precisions.get(dtype1, no_tolerance)
    return max(rtol0, rtol1), max(atol0, atol1)
|
|
|
|
|
|
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    """Check that ``decomp`` is no further from ``ref`` than ``orig`` is.

    ``orig`` and ``decomp`` must agree in dtype and shape.  For a floating
    point ``ref`` the maximum absolute error of each against ``ref`` is
    computed, and the decomposition may exceed the original's error by at
    most a small per-(dtype, op) ``atol``.  For any other ``ref`` dtype the
    two outputs are compared directly through ``test_case.assertEqual``.

    Args:
        test_case: object providing ``assertEqual``.
        op: the operator under test; used as a table key and in messages.
        test_dtype: dtype the test was instantiated with.
        i: index of this output, used in messages only.
        orig: output of the original op.
        decomp: output of the decomposition.
        ref: reference output to measure both against.
        args, kwargs: the call arguments, echoed in failure messages.

    Raises:
        RuntimeError: if the decomposition's error exceeds the original's
            error plus ``atol``.
    """
    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
    if 0 in (orig.numel(), decomp.numel()):
        # Nothing to compare element-wise; both sides must be empty.
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation: {op}"

    # Extra slack granted to specific ops, identical for both half dtypes.
    per_op_slack = {
        torch.ops.aten.native_layer_norm.default: 1e-5,
        torch.ops.aten.native_batch_norm.default: 1e-5,
        torch.ops.aten.linalg_vector_norm.default: 1e-6,
    }
    tol_table = {
        (half_dtype, slack_op): slack
        for slack_op, slack in per_op_slack.items()
        for half_dtype in (torch.bfloat16, torch.float16)
    }

    if not ref.is_floating_point():
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )
        return

    orig_diff = (orig - ref).abs().max()
    decomp_diff = (decomp - ref).abs().max()
    atol = tol_table.get((test_dtype, op), 1e-7)
    if decomp_diff > orig_diff + atol:
        raise RuntimeError(
            f"Difference from float64 is larger with decomposition {op.__name__}"
            f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
            f"atol = {atol}\n"
            f"args = {args}\n"
            f"kwargs = {kwargs}"
        )
|
|
|
|
|
|
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    """Assert that ``decomp`` equals ``orig`` within dtype-based tolerances.

    The dtypes of the two outputs must match.  Tolerances come from a small
    override table keyed by ``(decomp.dtype, op)``; when there is no override
    the defaults from ``_getDefaultRtolAndAtol`` are used.

    Args:
        test_case: object providing ``assertEqual``.
        op: the operator under test; used as a table key and in messages.
        test_dtype: dtype the test was instantiated with.  Kept for interface
            compatibility; tolerances are selected by the dtype of the
            tensors actually being compared.
        orig: output of the original op.
        decomp: output of the decomposition.
        args, kwargs: the call arguments, echoed in failure messages.
    """
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
    }
    # Use the same key for the membership test and the lookup.  Testing with
    # ``test_dtype`` but indexing with ``decomp.dtype`` raised KeyError (or
    # skipped the override) whenever the two dtypes differed.
    override = tol_table.get((decomp.dtype, op))
    if override is not None:
        rtol, atol = override
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )
|
|
|
|
|
|
# Given f, returns an f' such that:
|
|
# - f' takes only positional arguments
|
|
# - All arguments to f' are floating-point Tensors
|
|
# - All outputs of f' are floating-point Tensors
|
|
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    """Wrap ``f`` so it takes only its differentiable tensors positionally.

    ``args`` is flattened and the leaves selected by ``diff_arg`` become the
    primals.  The returned wrapper substitutes fresh primals back into the
    flattened arguments, rebuilds the original structure, calls
    ``f(*args, **kwargs)``, optionally post-processes the result, and -- for
    tuple results -- keeps only floating point or complex tensors.

    Returns:
        ``(wrapped, primals)`` where ``primals`` is the tuple of selected
        leaves from ``args``.
    """
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        position
        for position, leaf in enumerate(flat_args)
        if diff_arg(leaf, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[position] for position in diff_argnums)

    @functools.wraps(f)
    def wrapped(*new_primals):
        # Start from the captured arguments and overwrite the primal slots.
        patched = list(flat_args)
        for position, value in zip(diff_argnums, new_primals):
            patched[position] = value
        call_args = tree_unflatten(patched, args_spec)

        out = f(*call_args, **kwargs)
        if output_process_fn_grad is not None:
            out = output_process_fn_grad(out)
        if not isinstance(out, tuple):
            return out

        # TODO: Remove the following hack for namedtuples
        out = tuple(out)
        out = tuple(
            item
            for item in out
            if isinstance(item, Tensor)
            and (item.is_floating_point() or item.is_complex())
        )
        assert len(out) > 0
        return out

    return wrapped, primals
|
|
|
|
|
|
# NB: This also upcasts dtype arguments
|
|
# TODO: handle complex correctly
|
|
def upcast_tensor(x, dtype=torch.float32):
    """Convert floating point tensors and low-precision dtype objects to ``dtype``.

    * A tensor with a floating point dtype is converted with ``x.to(dtype=dtype)``.
    * A ``torch.dtype`` equal to float16, bfloat16 or float32 is replaced by
      ``dtype``.
    * Everything else is returned unchanged.
    """
    if isinstance(x, Tensor):
        if x.dtype.is_floating_point:
            return x.to(dtype=dtype)
        return x
    low_precision = (torch.float16, torch.bfloat16, torch.float)
    if isinstance(x, torch.dtype) and x in low_precision:
        return dtype
    return x
|
|
|
|
|
|
def normalize_op_input_output(f, sample, requires_grad=True):
    """Apply ``normalize_op_input_output2`` to the fields of a sample input.

    The positional arguments are ``sample.input`` followed by ``sample.args``;
    ``sample.kwargs`` and ``sample.output_process_fn_grad`` are forwarded as
    they are.
    """
    positional = (sample.input, *sample.args)
    return normalize_op_input_output2(
        f,
        positional,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
# (device type, dtype, op name) triples to skip; ``None`` in the first two
# positions acts as a wildcard when the set is consulted.
CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    (None, None, "new_empty"),
    # decomp has problem even with opmath
    # doesn't work
    ("cuda", torch.bfloat16, "nn.functional.embedding"),
    # CompositeAutogradImplicit
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    (None, None, "meshgrid"),
} | {
    # randomness
    ("cuda", dropout_dtype, "nn.functional.dropout")
    for dropout_dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float64,
        torch.float32,
    )
}
|
|
|
|
# Every op that was run through its decomposition, accumulated across all
# tests in this module.
all_decomposed = set()
# Number of times each op was seen by the cross-reference dispatch hook.
all_called = defaultdict(int)
|
|
|
|
# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""
|
|
|
|
# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""
|
|
|
|
|
|
def any_unsupported(args, kwargs):
    """Return True if any leaf of ``args``/``kwargs`` should bypass the check.

    A leaf is unsupported when it is a plain ``torch.Tensor`` or
    ``torch.nn.Parameter`` that is sparse (COO or CSR), mkldnn, quantized,
    nested or a functional tensor, or when it is any other tensor-like
    object such as a Tensor subclass.
    """

    def test_unsupported(t):
        if type(t) in (torch.Tensor, torch.nn.Parameter):
            # These are all things that we haven't coded decompositions
            # to handle correctly.  Maybe they should.
            special_layouts = (
                t.is_sparse_csr,
                t.is_sparse,
                t.is_mkldnn,
                t.is_quantized,
                t.is_nested,
                torch._is_functional_tensor(t),
            )
            return any(special_layouts)
        # Decompositions will generally change the behavior of Tensor-like
        # subclasses, so bypass tests in this case too
        return bool(torch.overrides.is_tensor_like(t))

    arg_leaves = tree_flatten(args)[0]
    kwarg_leaves = tree_flatten(kwargs)[0]
    return any(map(test_unsupported, itertools.chain(arg_leaves, kwarg_leaves)))
|
|
|
|
|
|
@skipIfSlowGradcheckEnv
class TestDecomp(TestCase):
    """Cross-reference tests for the entries of ``decomposition_table``.

    While an op from the OpInfo database runs, every dispatched call that has
    a registered decomposition is executed both natively and through the
    decomposition, and the two results are compared.
    """

    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

    def do_cross_ref(self, device, dtype, op, *, run_all):
        """Run ``op`` on its sample inputs with decomposition cross-checking.

        Args:
            device: device the test was instantiated for.
            dtype: dtype the test was instantiated for.
            op: the OpInfo entry under test.
            run_all: when True, the forward (and backward, if autograd is
                usable) is exercised even if the op itself has no
                decomposition, and the "was it actually decomposed" check is
                skipped for the backward and the non-autograd paths.  When
                False, those paths run only for ops listed in
                ``decomposition_names`` and the check is enforced.
        """
        device_type = torch.device(device).type
        exclusion_keys = (
            (device_type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        )
        if any(key in CROSS_REF_EXCLUDE_SET for key in exclusion_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        test_dtype = dtype

        # We check the correctness of each decomposition right after running it.
        # So, when we encounter a decomposition, we run the function normally, and
        # then run the decomposition, and ensure they're identical.
        called = set()
        decomposed = set()

        saved_precision = self.precision
        saved_rel_tol = self.rel_tol

        class DecompCrossRefMode(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # Run the checking logic with dispatch disabled so the ops it
                # issues are not intercepted again.
                with no_dispatch():
                    return cls._torch_dispatch(func, types, args, kwargs)

            @classmethod
            def _torch_dispatch(cls, func, types, args=(), kwargs=None):
                # Restore the tolerances captured when the test started.
                self.precision = saved_precision
                self.rel_tol = saved_rel_tol

                called.add(func)
                all_called[func] += 1

                # Stuff we shouldn't bother testing
                # (TODO: remove detach from the decomp table?)
                skip_check = (
                    func not in decomposition_table
                    or func in [
                        torch.ops.aten.detach.default,
                        # non-deterministic ops
                        torch.ops.aten.new_empty.default,
                        torch.ops.aten.new_empty.SymInt,
                    ]
                    or any_unsupported(args, kwargs)
                )
                if skip_check:
                    return func(*args, **kwargs)

                decomposed.add(func)
                all_decomposed.add(func)

                # We take 2 main strategies for verifying correctness/numerical stability of decompositions
                # The first one is simply tolerance checking between decomp_out and pytorch_out
                # However, for fp16/bf16 and reductions, this becomes very
                # finicky, as there are not many guarantees we can make.
                # So, for fp16/bf16, we instead compare the difference of
                # {decomp_out, pytorch_out_64} and {pytorch_out,
                # pytorch_out_64}. In other words, we compare how far the
                # decomposition and pytorch are from the "ground truth" (i.e.
                # fp64). If the decomposition results in more error, we error

                decomposition = decomposition_table[func]

                do_relative_check = test_dtype in [torch.float16, torch.bfloat16]
                real_out_unflat = func(*args, **kwargs)
                real_out, _ = tree_flatten(real_out_unflat)
                decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
                assert len(real_out) == len(decomp_out)

                if do_relative_check:
                    upcast = partial(upcast_tensor, dtype=torch.float64)
                    real_out_double, _ = tree_flatten(
                        func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                    )
                    for i, (orig, decomp, ref) in enumerate(
                        zip(real_out, decomp_out, real_out_double)
                    ):
                        if not isinstance(orig, torch.Tensor):
                            assert type(orig) is type(decomp)
                            assert orig == decomp
                            continue
                        op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
                else:
                    for orig, decomp in zip(real_out, decomp_out):
                        if not isinstance(orig, torch.Tensor):
                            assert type(orig) is type(decomp)
                            assert orig == decomp
                            continue
                        op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)

                return real_out_unflat

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(device_type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm.  The problem is not
            # complex32 per-se (which is supported by data movement only ops)
            # but that when we do backwards we expect other ops like add to work
            and dtype != torch.complex32
        )
        samples = op.sample_inputs(device, test_dtype, requires_grad=requires_grad)

        def check_decomposed(aten_name):
            self.assertTrue(
                any(overload_to_aten_name(c) == aten_name for c in decomposed),
                msg=(f"aten.{aten_name} was not decomposed, saw calls for: "
                     f"{', '.join(map(str, list(called)))}. If your op is "
                     f"CompositeImplicitAutograd you should skip this test "
                     "by updating CROSS_REF_EXCLUDE_SET.")
            )

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()
        for sample_input in samples:
            if requires_grad:
                if None in sample_input.args:
                    continue

                fn, primals = normalize_op_input_output(func, sample_input)

                # Once https://github.com/pytorch/pytorch/pull/75965/ lands,
                # the called list can be stored on the mode object instance
                # and no explicit clearing is necessary, as a fresh mode will
                # be created for each region.  Until then both per-region
                # sets are reset by hand so that the failure message only
                # lists calls from the region being checked.
                decomposed.clear()
                called.clear()
                with enable_torch_dispatch_mode(DecompCrossRefMode):
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    check_decomposed(aten_name)

                if op.aten_backward_name in decomposition_names or run_all:
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    decomposed.clear()
                    called.clear()
                    with enable_torch_dispatch_mode(DecompCrossRefMode):
                        decomp_vjp_fn(cotangents)
                    if not run_all:
                        check_decomposed(op.aten_backward_name)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                decomposed.clear()
                called.clear()
                with enable_torch_dispatch_mode(DecompCrossRefMode):
                    func(*args, **kwargs)
                if not run_all:
                    check_decomposed(aten_name)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )
|
|
|
|
|
|
# Register TestDecomp with the device-type test machinery, passing this
# module's globals (presumably so the generated per-device test classes are
# placed in this module -- confirm against the helper's documentation).
instantiate_device_type_tests(TestDecomp, globals())

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
|