mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add meta function for _upsample_bilinear2d_aa (#94982)
Differential Revision: D43353000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/94982 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
17d0b7f532
commit
039b4c8809
@ -98,6 +98,7 @@ dtensor_fails = {
|
||||
xfail("__rsub__"),
|
||||
xfail("_native_batch_norm_legit"),
|
||||
xfail("_softmax_backward_data"),
|
||||
xfail("_upsample_bilinear2d_aa"),
|
||||
xfail("addbmm"),
|
||||
xfail("addmv"),
|
||||
xfail("addr"),
|
||||
|
@ -2542,6 +2542,7 @@ symbolic_aot_autograd_failures = {
|
||||
xfail('trapz', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
|
||||
xfail('unflatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
|
||||
xfail('var', ''), # Cannot call numel() on tensor with symbolic sizes/strides
|
||||
xfail('var', 'unbiased'), # Cannot call numel() on tensor with symbolic sizes/strides
|
||||
xfail('var_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides
|
||||
|
@ -1080,6 +1080,7 @@ class TestOperators(TestCase):
|
||||
xfail('nn.functional.dropout3d', ''),
|
||||
xfail('as_strided_scatter', ''),
|
||||
xfail('masked.cumprod', ''),
|
||||
xfail("_upsample_bilinear2d_aa"), # hit vmap fallback, which is disabled
|
||||
}))
|
||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
||||
def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
|
||||
@ -1188,6 +1189,7 @@ class TestOperators(TestCase):
|
||||
xfail("native_batch_norm"),
|
||||
xfail("_native_batch_norm_legit"),
|
||||
xfail("native_dropout_backward"),
|
||||
xfail("_upsample_bilinear2d_aa"), # hit vmap fallback, which is disabled
|
||||
}))
|
||||
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
|
||||
if not op.supports_autograd:
|
||||
|
@ -3751,6 +3751,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
# RuntimeError: Expected all tensors to be on the same device,
|
||||
# but found at least two devices, cuda:0 and cpu!
|
||||
xfail('ge', device_type='cuda'),
|
||||
xfail('_upsample_bilinear2d_aa'),
|
||||
}))
|
||||
def test_op_has_batch_rule(self, device, dtype, op):
|
||||
# needs to be fixed
|
||||
|
@ -329,6 +329,8 @@ CROSS_REF_EXCLUDE_SET = {
|
||||
(None, None, "norm"),
|
||||
# native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
|
||||
(None, None, "native_batch_norm"),
|
||||
|
||||
(None, None, "_upsample_bilinear2d_aa"),
|
||||
}
|
||||
|
||||
CROSS_REF_BACKWARD_EXCLUDE_SET = {
|
||||
|
@ -1508,7 +1508,7 @@ class TestNormalizeOperators(JitTestCase):
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
def test_normalize_operator_exhaustive(self, device, dtype, op):
|
||||
# These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
|
||||
fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot"}
|
||||
fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa"}
|
||||
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
if isinstance(op.op, torch._ops.OpOverload):
|
||||
self.skipTest("normalize operator doesn't work on torch.ops")
|
||||
|
@ -2686,6 +2686,18 @@ def gru_impl(
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
|
||||
@register_decomposition(aten._upsample_bilinear2d_aa.vec)
|
||||
@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd)
|
||||
def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors):
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scale_h = get_scale_value(scale_factors, 0)
|
||||
scale_w = get_scale_value(scale_factors, 1)
|
||||
return torch.ops.aten._upsample_bilinear2d_aa(
|
||||
input, osize, align_corners, scale_h, scale_w
|
||||
)
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_bilinear2d.vec)
|
||||
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
|
||||
|
@ -2639,6 +2639,22 @@ def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
|
||||
).contiguous()
|
||||
|
||||
|
||||
@register_meta(aten._upsample_bilinear2d_aa.default)
|
||||
def meta_upsample_bilinear2d_aa(
|
||||
input, output_size, align_corners, scales_h=None, scales_w=None
|
||||
):
|
||||
full_output_size = upsample_common_check(
|
||||
input.size(), output_size, num_spatial_dims=2
|
||||
)
|
||||
check(
|
||||
input.numel() != 0 or all([size > 0 for size in input.size()[1:]]),
|
||||
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
|
||||
)
|
||||
return input.new_empty(full_output_size).to(
|
||||
memory_format=utils.suggest_memory_format(input)
|
||||
)
|
||||
|
||||
|
||||
# We must also trigger meta registrations from PrimTorch ref
|
||||
# decompositions
|
||||
import torch._refs
|
||||
|
@ -18,7 +18,7 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_dtype import (
|
||||
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
|
||||
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
|
||||
all_types, empty_types, complex_types_and, integral_types
|
||||
all_types, empty_types, complex_types_and, integral_types, floating_types_and_half
|
||||
)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
|
||||
@ -4076,6 +4076,23 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
|
||||
yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6)
|
||||
|
||||
|
||||
def sample_inputs_upsample_aten(mode, self, device, dtype, requires_grad, **kwargs):
|
||||
N = 6
|
||||
C = 3
|
||||
H = 10
|
||||
W = 20
|
||||
S = 3
|
||||
L = 5
|
||||
|
||||
input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype,
|
||||
requires_grad=requires_grad, low=-1, high=1)
|
||||
|
||||
yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None)
|
||||
yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None)
|
||||
yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9])
|
||||
yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0])
|
||||
|
||||
|
||||
def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
|
||||
N = 5
|
||||
for _ in range(1, N):
|
||||
@ -12368,6 +12385,25 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
),
|
||||
supports_out=False),
|
||||
OpInfo('_upsample_bilinear2d_aa',
|
||||
op=torch.ops.aten._upsample_bilinear2d_aa,
|
||||
aten_name='_upsample_bilinear2d_aa',
|
||||
supports_autograd=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
dtypes=floating_types_and(torch.uint8),
|
||||
dtypesIfCUDA=floating_types_and_half(),
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
sample_inputs_func=partial(sample_inputs_upsample_aten, 'bilinear'),
|
||||
supports_out=False,
|
||||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
|
||||
)),
|
||||
OpInfo(
|
||||
"nn.functional.soft_margin_loss",
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
|
Reference in New Issue
Block a user