Port functorch decomps over and fix some tests

Still some stuff to fix up, will finish later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76621
Approved by: https://github.com/ezyang
Author: Horace He
Date: 2022-05-01 08:48:48 +00:00
Committed by: PyTorch MergeBot
parent 786903ea29
commit fb24614011
3 changed files with 390 additions and 91 deletions
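For orientation, the test file changed below cross-references each decomposed aten op against the real kernel: it runs both on the same inputs and requires the results to match within per-dtype tolerances. A minimal, self-contained sketch of that idea follows; the decomposition here is a hand-written toy (silu expressed via sigmoid), not one of the decompositions this PR ports.

import torch

def silu_decomp(x: torch.Tensor) -> torch.Tensor:
    # A "decomposition" is plain Python written in terms of other ops that
    # already have kernels: silu(x) = x * sigmoid(x).
    return x * torch.sigmoid(x)

x = torch.randn(128)
real = torch.nn.functional.silu(x)     # the real kernel
ref = silu_decomp(x)                   # the decomposition
# The test below performs the analogous comparison for every decomposed op
# across its sample inputs, with per-dtype tolerances.
torch.testing.assert_close(real, ref)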


@@ -1,5 +1,6 @@
 # Owner(s): ["module: primTorch"]
+from collections import defaultdict
 from torch import Tensor
 import torch.autograd
 from torch.utils._python_dispatch import enable_python_mode
@@ -26,6 +27,8 @@ import functools
 from functools import partial
 import unittest
+aten = torch.ops.aten
 # TODO: this isn't going to work with non-aten namespaces
 def overload_to_aten_name(overload):
@@ -161,8 +164,8 @@ def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
     )
-def op_assert_equal(test_case, op, a, b, args, kwargs):
-    assert a.dtype == b.dtype
+def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
+    assert orig.dtype == decomp.dtype, f"Operation: {op}"
     # Before adding an entry to this table, make sure your decomposition is right :)
     tol_table = {
         # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
@@ -172,11 +175,12 @@ def op_assert_equal(test_case, op, a, b, args, kwargs):
             1e-3,
         ),
     }
-    if (b.dtype, op) in tol_table:
-        rtol, atol = tol_table[(b.dtype, op)]
+    if (decomp.dtype, op) in tol_table:
+        rtol, atol = tol_table[(decomp.dtype, op)]
     else:
-        rtol, atol = _getDefaultRtolAndAtol(a.dtype, b.dtype)
-    test_case.assertEqual(a, b, rtol=rtol, atol=atol)
+        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
+    test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")
 # Given f, returns an f' such that:
@@ -260,45 +264,31 @@ def normalize_op_input_output(f, sample, requires_grad=True):
 CROSS_REF_EXCLUDE_SET = {
     (
         "cpu",
         torch.bfloat16,
         "nn.functional.layer_norm",
     ), # "batch_norm" not implemented for 'BFloat16'
-    ("cpu", torch.bfloat16, "addmm"), # decomposition loses precision
-    ("cpu", torch.bfloat16, "softmax"), # needs relaxed prec
-    ("cpu", torch.bfloat16, "log_softmax"), # needs relaxed prec
     # CUBLAS_STATUS_NOT_SUPPORTED when calling
     # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
     # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
     # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
     # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
     ("cuda", torch.bfloat16, "nn.functional.bilinear"),
+    # randomness
+    ("cuda", torch.float16, "nn.functional.dropout"),
+    ("cuda", torch.bfloat16, "nn.functional.dropout"),
+    ("cuda", torch.float64, "nn.functional.dropout"),
+    ("cuda", torch.float32, "nn.functional.dropout"),
-    # decomp has problem even with opmath
-    ("cuda", torch.bfloat16, "nn.functional.layer_norm"),
-    ("cuda", torch.float16, "nn.functional.layer_norm"),
-    ("cuda", torch.float16, "nn.functional.dropout"),
-    ("cuda", torch.bfloat16, "nn.functional.dropout"),
-    # decomp doesn't return correct dtype
-    ("cuda", torch.float64, "nn.functional.instance_norm"),
-    ("cuda", torch.float32, "nn.functional.instance_norm"),
-    ("cuda", torch.float64, "nn.functional.dropout"),
-    ("cuda", torch.float32, "nn.functional.dropout"),
-    ("cuda", torch.float64, "nn.functional.batch_norm"),
-    ("cuda", torch.float32, "nn.functional.batch_norm"),
-    ("cuda", torch.bfloat16, "nn.functional.batch_norm"),
-    ("cuda", torch.float16, "nn.functional.batch_norm"),
-    ("cuda", torch.bfloat16, "nn.functional.instance_norm"),
-    ("cuda", torch.float16, "nn.functional.instance_norm"),
     # complex is not handled
     (None, torch.complex64, "var"),
     (None, torch.complex128, "var"),
     (None, torch.complex64, "nn.functional.tanhshrink"),
     (None, torch.complex128, "nn.functional.tanhshrink"),
     (None, torch.complex32, "sigmoid"),
     (None, torch.complex64, "sigmoid"),
     (None, torch.complex128, "sigmoid"),
     (None, torch.complex64, "tanh"),
     (None, torch.complex128, "tanh"),
 }
 all_decomposed = set()
+all_called = defaultdict(int)
 # Helpful snippet for testing coverage
 """
@@ -309,6 +299,21 @@ def check_coverage():
 atexit.register(check_coverage)
 """
+# Helpful snippet for Horace to create his google sheet :)
+"""
+import atexit
+def dump_ops():
+    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
+        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
+            f.write(f'{op.__name__}\n')
+            g.write(f'{count}\n')
+    with open('run_decompositions.txt', 'w') as f:
+        for op in sorted([i.__name__ for i in all_decomposed]):
+            f.write(f'{op}\n')
+atexit.register(dump_ops)
+"""
 class TestDecomp(TestCase):
     longMessage = True
@@ -363,6 +368,7 @@ class TestDecomp(TestCase):
                     self.rel_tol = saved_rel_tol
                 called.add(func)
+                all_called[func] += 1
                 # Stuff we shouldn't bother testing
                 # (TODO: remove detach from the decomp table?)
@@ -387,7 +393,6 @@ class TestDecomp(TestCase):
                 decomposition = decomposition_table[func]
                 do_relative_check = test_dtype in [torch.float16, torch.bfloat16]
                 real_out_unflat = func(*args, **kwargs)
                 real_out, _ = tree_flatten(real_out_unflat)
                 decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
@@ -474,11 +479,11 @@ class TestDecomp(TestCase):
     )
     def test_torchscriptable(self, device):
-        skip_list = []
+        skip_list = [aten.rsub.Scalar]
         for op, decomposition in decomposition_table.items():
             if op in skip_list:
                 continue
-            torch.jit.script(decomposition)
+            f = torch.jit.script(decomposition)
 instantiate_device_type_tests(TestDecomp, globals())
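The last hunk is shown only in fragments, so here is a self-contained sketch of the same scriptability check. It assumes decomposition_table is imported from torch._decomp (the import is outside this excerpt) and skips aten.rsub.Scalar exactly as the diff does.

import torch
from torch._decomp import decomposition_table

aten = torch.ops.aten

def check_torchscriptable():
    # Decompositions are plain Python functions over aten ops, so each one is
    # expected to compile under TorchScript; rsub.Scalar is the one entry this
    # commit skips.
    skip_list = [aten.rsub.Scalar]
    for op, decomposition in decomposition_table.items():
        if op in skip_list:
            continue
        torch.jit.script(decomposition)  # raises if the decomposition is not scriptable

check_torchscriptable()  # on other PyTorch versions the skip list may need to differ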