Add meta function for _upsample_bilinear2d_aa (#94982)

Differential Revision: D43353000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94982
Approved by: https://github.com/ezyang
Author: Yanan Cao (PyTorch)
Date: 2023-02-19 07:11:18 +00:00
Committed by: PyTorch MergeBot
Parent: 17d0b7f532
Commit: 039b4c8809
9 changed files with 73 additions and 2 deletions
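
For context, a short usage sketch (not part of the diff) of what the new meta registration enables: once aten._upsample_bilinear2d_aa.default has a meta kernel, the op can be run on "meta" tensors to infer output shape, dtype, and layout without allocating or computing any real data, which is what tracing stacks such as AOTAutograd and inductor rely on. This assumes a build that includes this commit.

import torch

# Shape inference only: nothing is allocated or computed on the meta device.
x = torch.empty(6, 3, 10, 20, device="meta")
out = torch.ops.aten._upsample_bilinear2d_aa.default(x, [5, 5], False)
print(out.shape, out.device)  # expected: torch.Size([6, 3, 5, 5]) meta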


@@ -98,6 +98,7 @@ dtensor_fails = {
    xfail("__rsub__"),
    xfail("_native_batch_norm_legit"),
    xfail("_softmax_backward_data"),
    xfail("_upsample_bilinear2d_aa"),
    xfail("addbmm"),
    xfail("addmv"),
    xfail("addr"),


@@ -2542,6 +2542,7 @@ symbolic_aot_autograd_failures = {
    xfail('trapz', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('triangular_solve', ''),  # aten.triangular_solve.default - couldn't find symbolic meta function/de...
    xfail('unflatten', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('_upsample_bilinear2d_aa'),  # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
    xfail('var', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
    xfail('var', 'unbiased'),  # Cannot call numel() on tensor with symbolic sizes/strides
    xfail('var_mean', ''),  # Cannot call numel() on tensor with symbolic sizes/strides


@@ -1080,6 +1080,7 @@ class TestOperators(TestCase):
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('masked.cumprod', ''),
        xfail("_upsample_bilinear2d_aa"),  # hit vmap fallback, which is disabled
    }))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1188,6 +1189,7 @@ class TestOperators(TestCase):
        xfail("native_batch_norm"),
        xfail("_native_batch_norm_legit"),
        xfail("native_dropout_backward"),
        xfail("_upsample_bilinear2d_aa"),  # hit vmap fallback, which is disabled
    }))
    def test_vmapvjp_has_batch_rule(self, device, dtype, op):
        if not op.supports_autograd:


@@ -3751,6 +3751,7 @@ class TestVmapOperatorsOpInfo(TestCase):
        # RuntimeError: Expected all tensors to be on the same device,
        # but found at least two devices, cuda:0 and cpu!
        xfail('ge', device_type='cuda'),
        xfail('_upsample_bilinear2d_aa'),
    }))
    def test_op_has_batch_rule(self, device, dtype, op):
        # needs to be fixed


@@ -329,6 +329,8 @@ CROSS_REF_EXCLUDE_SET = {
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),
    (None, None, "_upsample_bilinear2d_aa"),
}

CROSS_REF_BACKWARD_EXCLUDE_SET = {


@@ -1508,7 +1508,7 @@ class TestNormalizeOperators(JitTestCase):
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot"}
        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa"}
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        if isinstance(op.op, torch._ops.OpOverload):
            self.skipTest("normalize operator doesn't work on torch.ops")


@@ -2686,6 +2686,18 @@ def gru_impl(
    return out, torch.stack(final_hiddens, 0)


@register_decomposition(aten._upsample_bilinear2d_aa.vec)
@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd)
def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors):
    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
    scale_h = get_scale_value(scale_factors, 0)
    scale_w = get_scale_value(scale_factors, 1)
    return torch.ops.aten._upsample_bilinear2d_aa(
        input, osize, align_corners, scale_h, scale_w
    )


@register_decomposition(aten.upsample_bilinear2d.vec)
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
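
The .vec decomposition above first resolves output_size/scale_factors into a concrete spatial size, then calls the .default overload with per-dimension scales. As a rough illustration (not the library helpers themselves; _compute_osize below is a simplified stand-in for upsample_compute_output_size):

import math

def _compute_osize(input_size, output_size, scale_factors):
    # input_size is (N, C, H, W); only the two spatial dims are resized.
    if output_size is not None:
        return list(output_size)
    return [int(math.floor(dim * scale)) for dim, scale in zip(input_size[2:], scale_factors)]

print(_compute_osize((6, 3, 10, 20), None, [1.7, 0.9]))  # [17, 18]
print(_compute_osize((6, 3, 10, 20), [5, 5], None))      # [5, 5]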


@@ -2639,6 +2639,22 @@ def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    ).contiguous()


@register_meta(aten._upsample_bilinear2d_aa.default)
def meta_upsample_bilinear2d_aa(
    input, output_size, align_corners, scales_h=None, scales_w=None
):
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    check(
        input.numel() != 0 or all([size > 0 for size in input.size()[1:]]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
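
One detail in the meta kernel above: the output inherits the input's suggested memory format. A quick hedged check of that behavior on the meta device (again assuming a build that includes this registration):

import torch

# A channels_last input should yield a channels_last meta output.
x = torch.empty(2, 3, 8, 8, device="meta").to(memory_format=torch.channels_last)
out = torch.ops.aten._upsample_bilinear2d_aa.default(x, [16, 16], False)
print(out.shape)                                             # torch.Size([2, 3, 16, 16])
print(out.is_contiguous(memory_format=torch.channels_last))  # expected: True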


@@ -18,7 +18,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
    all_types, empty_types, complex_types_and, integral_types
    all_types, empty_types, complex_types_and, integral_types, floating_types_and_half
)
from torch.testing._internal.common_device_type import \
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -4076,6 +4076,23 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
        yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6)


def sample_inputs_upsample_aten(mode, self, device, dtype, requires_grad, **kwargs):
    N = 6
    C = 3
    H = 10
    W = 20
    S = 3
    L = 5
    input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype,
                               requires_grad=requires_grad, low=-1, high=1)

    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None)
    yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None)
    yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9])
    yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0])


def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
    N = 5
    for _ in range(1, N):
@@ -12368,6 +12385,25 @@ op_db: List[OpInfo] = [
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
           ),
           supports_out=False),
    OpInfo('_upsample_bilinear2d_aa',
           op=torch.ops.aten._upsample_bilinear2d_aa,
           aten_name='_upsample_bilinear2d_aa',
           supports_autograd=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           dtypes=floating_types_and(torch.uint8),
           dtypesIfCUDA=floating_types_and_half(),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           sample_inputs_func=partial(sample_inputs_upsample_aten, 'bilinear'),
           supports_out=False,
           skips=(
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
               DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
               DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
           )),
    OpInfo(
        "nn.functional.soft_margin_loss",
        dtypes=floating_types_and(torch.bfloat16),
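
For illustration, a hedged sketch (not part of the PR) of how a test harness consumes the new OpInfo entry: each SampleInput's kwargs (output_size, align_corners, scale_factors) line up with the .vec overload of the operator, so the samples can be passed straight to op.op.

import torch
from torch.testing._internal.common_methods_invocations import op_db

op = next(o for o in op_db if o.name == "_upsample_bilinear2d_aa")
for sample in op.sample_inputs("cpu", torch.float32):
    # kwargs match the .vec overload's schema, so the overload packet resolves to it.
    out = op.op(sample.input, **sample.kwargs)
    print(tuple(out.shape))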