Add OpInfo tests for torch.{dot, vdot, bmm, mv} (#56409)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56409

Reviewed By: nikithamalgifb

Differential Revision: D27870769

Pulled By: anjali411

fbshipit-source-id: a1a0e89856529a4739c7612c5b1e3c5ed2569126
Committed by: Facebook GitHub Bot
Parent: e4faebca0d
Commit: 062e70590c
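Editorial note: the OpInfo entries added below feed PyTorch's generated test suites (the TestCommon, TestGradients, and TestOpInfo classes named in the skips; at the time they lived in test/test_ops.py). The snippet below is a minimal hand-rolled sketch of the kind of coverage each entry buys, not the actual harness:

```python
import torch
from torch.autograd import gradcheck

# Double precision keeps gradcheck's finite-difference comparison stable.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = torch.randn(5, dtype=torch.double, requires_grad=True)
m = torch.randn(3, 5, dtype=torch.double, requires_grad=True)

assert gradcheck(torch.dot, (x, y))   # analytic and numeric grads agree
assert gradcheck(torch.vdot, (x, y))
assert gradcheck(torch.mv, (m, x))
```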
@@ -5385,10 +5385,8 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
                 'expand', 'rot90', 'transpose',
                 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                 'chunk', 'split', 'split_with_sizes', 'zero_',
-                '__radd__', 'sum', 'mul',
-                '__rmul__', 'dot', 'vdot', 'matmul',
-                'bmm', 'mv', 'ger', 'diagonal', 'fill_', 'sub',
-                'mean', 'inverse', 'linalg.tensorinv', 'matrix_exp',
+                '__radd__', 'mul', '__rmul__', 'matmul',
+                'diagonal', 'fill_', 'sub',
                 'narrow', 'swapaxes', 'swapdims', 'tensor_split',
                 'baddbmm'] + complex_list_filter + separate_complex_tests
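For context on why these ops sat in the legacy complex-test list at all: dot and vdot differ exactly in their complex semantics, which is what the list exercised. A quick illustration (values worked out by hand):

```python
import torch

a = torch.tensor([1 + 2j, 3 - 1j])
b = torch.tensor([2 - 1j, 1 + 1j])
print(torch.dot(a, b))   # tensor(8.+5.j)  -- plain sum(a * b)
print(torch.vdot(a, b))  # tensor(2.-1.j)  -- conjugates a first: sum(a.conj() * b)
```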
@@ -658,6 +658,36 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
     else:
         return (input, )
 
+def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((M, S, M, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((M, M, S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
+def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
+    return (
+        SampleInput(
+            make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            args=(
+                make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad),
+            )
+        ),
+    )
+
 def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (((S,), (S, M), (M,), 1, 1, False),
                   ((S,), (S, M), (M,), 0.2, 0.6, False),
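A shape sanity check for the three helpers above. S and M are small test sizes defined elsewhere in common_methods_invocations.py; the concrete values here are stand-ins:

```python
import torch

S, M = 5, 10  # stand-in sizes
assert torch.mv(torch.randn(S, M), torch.randn(M)).shape == (S,)
assert torch.bmm(torch.randn(M, S, M), torch.randn(M, M, S)).shape == (M, S, S)
assert torch.dot(torch.randn(S), torch.randn(S)).shape == ()   # 0-d scalar
assert torch.vdot(torch.randn(S), torch.randn(S)).shape == ()
```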
@@ -3047,8 +3077,7 @@ op_db: List[OpInfo] = [
     OpInfo('addbmm',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if CUDA11OrLater else []),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            skips=(
                # addbmm does not correctly warn when resizing out= inputs
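The one-line change above should be a behavioral no-op, assuming the two torch.testing helpers expand as their names suggest. A manual expansion of the dtype sets, not a call into those helpers:

```python
import torch

floating = {torch.float32, torch.float64}            # floating_types()
complexes = {torch.complex64, torch.complex128}

# Old spelling: floating_types_and(float16, complex64, complex128, ...)
old_spelling = floating | {torch.float16, torch.complex64, torch.complex128}
# New spelling: floating_and_complex_types_and(float16, ...)
new_spelling = (floating | complexes) | {torch.float16}

# The conditional bfloat16 addition is identical on both sides.
assert old_spelling == new_spelling
```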
@@ -3061,6 +3090,50 @@ op_db: List[OpInfo] = [
                SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
                         device_type='cuda', active_if=not SM53OrLater)),
            sample_inputs_func=sample_inputs_addbmm),
+    OpInfo('dot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # dot does not handle correctly out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('vdot',
+           dtypes=all_types_and_complex_and(torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           skips=(
+               # vdot does not handle correctly out= dtypes
+               # https://github.com/pytorch/pytorch/issues/55561
+               SkipInfo('TestCommon', 'test_out'),
+           ),
+           sample_inputs_func=sample_inputs_dot_vdot),
+    OpInfo('bmm',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           assert_autodiffed=True,
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.bfloat16, ),
+                        device_type='cuda', active_if=not SM53OrLater)),
+           sample_inputs_func=sample_inputs_bmm),
+    OpInfo('mv',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           skips=(
+               # bmm does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out'),
+               SkipInfo('TestOpInfo', 'test_supported_backward', dtypes=(torch.float16,)),
+               # mv calls into addmv which doesn't fully support float16
+               # RuntimeError: "addmv_impl_cpu" not implemented for 'Half'
+               SkipInfo('TestOpInfo', 'test_supported_dtypes', dtypes=(torch.float16,)),),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_mv),
     OpInfo('addr',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            # Reference: https://github.com/pytorch/pytorch/issues/50747
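The float16 skips in the mv entry reflect a genuine CPU gap at the time of this commit: mv dispatches to addmv, whose CPU half-precision kernel was missing. The quoted error can be reproduced directly (later builds may have fixed this):

```python
import torch

m = torch.randn(3, 4).half()  # float16, CPU
v = torch.randn(4).half()
try:
    torch.mv(m, v)
except RuntimeError as err:
    # At the time: "addmv_impl_cpu" not implemented for 'Half'
    print(err)
```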
@@ -5270,10 +5343,6 @@ def method_tests():
     ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
     ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), ident,
      {'beta': 0.2, 'alpha': 0.6}),
-    ('dot', (L,), ((L,),), '', (True,)),
-    ('vdot', (L,), ((L,),),),
-    ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
-    ('mv', (S, M), ((M,),), '', (True,)),
     ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
     ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
     ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
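Each removed method_tests() tuple maps onto one of the new OpInfo entries: for example, ('bmm', (M, S, M), ((M, M, S),), '', (True,)) encoded a self tensor of shape (M, S, M), one positional argument of shape (M, M, S), and asserted autodiff. A runnable restatement of that case with stand-in sizes:

```python
import torch

M, S = 4, 3  # stand-in sizes for the test constants
lhs = torch.randn(M, S, M, dtype=torch.double, requires_grad=True)
rhs = torch.randn(M, M, S, dtype=torch.double, requires_grad=True)
# The same check the OpInfo-driven TestGradients now generates automatically.
assert torch.autograd.gradcheck(torch.bmm, (lhs, rhs))
```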