mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

Port functorch decomps over and fix some tests

Still some stuff to fix up, will finish later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76621
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
parent 786903ea29
commit fb24614011
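The test changes below exercise PyTorch's decomposition table, which this PR ports functorch's decompositions into. As a rough orientation, a hedged sketch of looking up and running a registered decomposition; it assumes the table lives in torch._decomp (as it does in current PyTorch), and the silu example is illustrative, not taken from this diff:

import torch
from torch._decomp import decomposition_table

aten = torch.ops.aten

# Hedged sketch: a decomposition is just a Python function keyed by an aten overload.
silu = aten.silu.default
if silu in decomposition_table:
    x = torch.randn(8)
    # The decomposed result should match the real op (up to tolerance).
    torch.testing.assert_close(decomposition_table[silu](x), torch.ops.aten.silu(x))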
@@ -1,5 +1,6 @@
 # Owner(s): ["module: primTorch"]

+from collections import defaultdict
 from torch import Tensor
 import torch.autograd
 from torch.utils._python_dispatch import enable_python_mode
@@ -26,6 +27,8 @@ import functools
 from functools import partial
 import unittest

+aten = torch.ops.aten
+

 # TODO: this isn't going to work with non-aten namespaces
 def overload_to_aten_name(overload):
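For orientation on the module-level aten handle and overload_to_aten_name above: torch.ops.aten.<op> is an overload packet and torch.ops.aten.<op>.<overload> a concrete overload. A hedged sketch of mapping an overload back to its aten name via its schema; this is an assumption about how such a helper can be written, not the body from this file:

import torch

aten = torch.ops.aten

def overload_to_aten_name_sketch(overload):
    # Hypothetical helper: an OpOverload exposes its schema, whose name is
    # namespace-qualified, e.g. "aten::add" for aten.add.Tensor.
    namespace, name = overload._schema.name.split("::")
    assert namespace == "aten"  # per the TODO above, other namespaces need more work
    return name

print(overload_to_aten_name_sketch(aten.add.Tensor))   # add
print(overload_to_aten_name_sketch(aten.rsub.Scalar))  # rsub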
@@ -161,8 +164,8 @@ def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
     )


-def op_assert_equal(test_case, op, a, b, args, kwargs):
-    assert a.dtype == b.dtype
+def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
+    assert orig.dtype == decomp.dtype, f"Operation: {op}"
     # Before adding an entry to this table, make sure your decomposition is right :)
     tol_table = {
         # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
@@ -172,11 +175,12 @@ def op_assert_equal(test_case, op, a, b, args, kwargs):
             1e-3,
         ),
     }
-    if (b.dtype, op) in tol_table:
-        rtol, atol = tol_table[(b.dtype, op)]
+    if (decomp.dtype, op) in tol_table:
+        rtol, atol = tol_table[(decomp.dtype, op)]
     else:
-        rtol, atol = _getDefaultRtolAndAtol(a.dtype, b.dtype)
-    test_case.assertEqual(a, b, rtol=rtol, atol=atol)
+        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
+
+    test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")


 # Given f, returns an f' such that:
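Taken together, the two hunks above rename op_assert_equal's arguments to orig/decomp and make failures report the op, args, and kwargs. A hedged, standalone sketch of the same comparison pattern, with a per-(dtype, op) tolerance override falling back to defaults; it uses torch.testing.assert_close rather than the TestCase helpers, and the table entries are placeholders:

import torch

# Placeholder tolerance overrides, keyed like the tol_table above: (result dtype, op).
TOL_TABLE = {
    (torch.float16, torch.ops.aten.mm.default): (1e-3, 1e-3),
}

def assert_decomp_matches(op, orig, decomp, args=(), kwargs=None):
    assert orig.dtype == decomp.dtype, f"Operation: {op}"
    rtol, atol = TOL_TABLE.get((decomp.dtype, op), (None, None))  # None -> dtype defaults
    torch.testing.assert_close(
        decomp, orig, rtol=rtol, atol=atol,
        msg=f"{op}\nargs = {args}\nkwargs = {kwargs or {}}",
    )

a, b = torch.randn(8, 8), torch.randn(8, 8)
assert_decomp_matches(torch.ops.aten.mm.default, torch.mm(a, b), a @ b, args=(a, b))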
@@ -260,45 +264,31 @@ def normalize_op_input_output(f, sample, requires_grad=True):


CROSS_REF_EXCLUDE_SET = {
    (
        "cpu",
        torch.bfloat16,
        "nn.functional.layer_norm",
    ),  # "batch_norm" not implemented for 'BFloat16'
    ("cpu", torch.bfloat16, "addmm"),  # decomposition loses precision
    ("cpu", torch.bfloat16, "softmax"),  # needs relaxed prec
    ("cpu", torch.bfloat16, "log_softmax"),  # needs relaxed prec
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    ("cuda", torch.float16, "nn.functional.dropout"),
    ("cuda", torch.bfloat16, "nn.functional.dropout"),
    ("cuda", torch.float64, "nn.functional.dropout"),
    ("cuda", torch.float32, "nn.functional.dropout"),
    # decomp has problem even with opmath
    ("cuda", torch.bfloat16, "nn.functional.layer_norm"),
    ("cuda", torch.float16, "nn.functional.layer_norm"),
    ("cuda", torch.float16, "nn.functional.dropout"),
    ("cuda", torch.bfloat16, "nn.functional.dropout"),
    # decomp doesn't return correct dtype
    ("cuda", torch.float64, "nn.functional.instance_norm"),
    ("cuda", torch.float32, "nn.functional.instance_norm"),
    ("cuda", torch.float64, "nn.functional.dropout"),
    ("cuda", torch.float32, "nn.functional.dropout"),
    ("cuda", torch.float64, "nn.functional.batch_norm"),
    ("cuda", torch.float32, "nn.functional.batch_norm"),
    ("cuda", torch.bfloat16, "nn.functional.batch_norm"),
    ("cuda", torch.float16, "nn.functional.batch_norm"),
    ("cuda", torch.bfloat16, "nn.functional.instance_norm"),
    ("cuda", torch.float16, "nn.functional.instance_norm"),
    # complex is not handled
    (None, torch.complex64, "var"),
    (None, torch.complex128, "var"),
    (None, torch.complex64, "nn.functional.tanhshrink"),
    (None, torch.complex128, "nn.functional.tanhshrink"),
    (None, torch.complex32, "sigmoid"),
    (None, torch.complex64, "sigmoid"),
    (None, torch.complex128, "sigmoid"),
    (None, torch.complex64, "tanh"),
    (None, torch.complex128, "tanh"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage
"""
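CROSS_REF_EXCLUDE_SET entries are (device type, dtype, OpInfo name) triples, with None standing in for "any device". A hedged sketch of how a cross-ref test typically consults such a set to skip unsupported combinations; the helper and data below are illustrative, not the code from this file:

import torch

def is_excluded(device_type, dtype, op_name, exclude_set):
    # None in the device slot acts as a wildcard, matching any device.
    return (device_type, dtype, op_name) in exclude_set or (None, dtype, op_name) in exclude_set

EXCLUDE = {
    ("cuda", torch.float16, "nn.functional.dropout"),
    (None, torch.complex64, "var"),
}
assert is_excluded("cuda", torch.float16, "nn.functional.dropout", EXCLUDE)
assert is_excluded("cpu", torch.complex64, "var", EXCLUDE)      # wildcard device
assert not is_excluded("cpu", torch.float16, "nn.functional.dropout", EXCLUDE)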
@@ -309,6 +299,21 @@ def check_coverage():
 atexit.register(check_coverage)
 """

+# Helpful snippet for Horace to create his google sheet :)
+"""
+import atexit
+def dump_ops():
+    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
+        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
+            f.write(f'{op.__name__}\n')
+            g.write(f'{count}\n')
+    with open('run_decompositions.txt', 'w') as f:
+        for op in sorted([i.__name__ for i in all_decomposed]):
+            f.write(f'{op}\n')
+
+atexit.register(dump_ops)
+"""


 class TestDecomp(TestCase):
     longMessage = True
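The check_coverage helper named in the hunk header above stays in the file as a commented-out docstring snippet; its body is not part of this hunk. One hedged guess at what such a coverage report can look like, given that all_decomposed records every decomposition the suite actually ran and decomposition_table (assumed here to come from torch._decomp) lists everything registered:

import atexit
from torch._decomp import decomposition_table

all_decomposed = set()  # filled in while the cross-ref tests run

def check_coverage_sketch():
    # Hypothetical report: registered decompositions that no test exercised.
    missing = decomposition_table.keys() - all_decomposed
    print(f"{len(missing)} registered decompositions were never hit:")
    for op in sorted(missing, key=str):
        print(f"  {op}")

atexit.register(check_coverage_sketch)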
@@ -363,6 +368,7 @@ class TestDecomp(TestCase):
                 self.rel_tol = saved_rel_tol

                 called.add(func)
+                all_called[func] += 1

                 # Stuff we shouldn't bother testing
                 # (TODO: remove detach from the decomp table?)
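called and all_called are filled in from inside a __torch_dispatch__ hook, so every aten call made while a test runs gets recorded. A hedged, self-contained sketch of that counting idea using today's TorchDispatchMode API (the file at this point in history still used enable_python_mode, imported in the first hunk; the mode class below is an illustration):

from collections import defaultdict

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpCounter(TorchDispatchMode):
    # Counts every aten op dispatched while the mode is active.
    def __init__(self):
        super().__init__()
        self.counts = defaultdict(int)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.counts[func] += 1
        return func(*args, **(kwargs or {}))

with OpCounter() as counter:
    x = torch.randn(4, 4)
    (x @ x).relu().sum()

for op, n in sorted(counter.counts.items(), key=lambda kv: str(kv[0])):
    print(f"{op}: {n}")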
@@ -387,7 +393,6 @@ class TestDecomp(TestCase):
                decomposition = decomposition_table[func]

                do_relative_check = test_dtype in [torch.float16, torch.bfloat16]

                real_out_unflat = func(*args, **kwargs)
                real_out, _ = tree_flatten(real_out_unflat)
                decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
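The hunk above is the core of the cross-ref check: run the real op, run its decomposition on the same args, flatten both outputs with tree_flatten, and compare leaf by leaf. A hedged standalone version of that pattern (the op and tolerances are placeholders, and decomposition_table is assumed to come from torch._decomp):

import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_flatten

func = torch.ops.aten.native_layer_norm.default
args = (torch.randn(2, 3, 8), [8], torch.ones(8), torch.zeros(8), 1e-5)

if func in decomposition_table:
    real_out, _ = tree_flatten(func(*args))
    decomp_out, _ = tree_flatten(decomposition_table[func](*args))
    for orig, decomp in zip(real_out, decomp_out):
        torch.testing.assert_close(decomp, orig, rtol=1e-4, atol=1e-4)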
@@ -474,11 +479,11 @@ class TestDecomp(TestCase):
         )

     def test_torchscriptable(self, device):
-        skip_list = []
+        skip_list = [aten.rsub.Scalar]
         for op, decomposition in decomposition_table.items():
             if op in skip_list:
                 continue
-            torch.jit.script(decomposition)
+            f = torch.jit.script(decomposition)


 instantiate_device_type_tests(TestDecomp, globals())
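test_torchscriptable just tries to torch.jit.script every registered decomposition, now skipping aten.rsub.Scalar. A minimal hedged illustration of that scriptability check on a hand-written decomposition-style function (the function below is an example, not one of the registered decomps):

import torch

def silu_like(x: torch.Tensor) -> torch.Tensor:
    # Decomposition-style function: expresses the op using only other ops.
    return x * torch.sigmoid(x)

scripted = torch.jit.script(silu_like)  # raises if the function is not scriptable
x = torch.randn(16)
torch.testing.assert_close(scripted(x), torch.nn.functional.silu(x))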