Compare commits

...

10 Commits

6 changed files with 133 additions and 12 deletions

View File

@@ -157,6 +157,7 @@ dtensor_fails = {
     xfail("cholesky_solve"),
     xfail("combinations"),
     xfail("complex"),
+    xfail("convolution_backward"),
     xfail("count_nonzero"),
     xfail("cross"),
     xfail("cummax"),

View File

@@ -8218,6 +8218,9 @@ symbolic_aot_autograd_failures = {
         "nn.functional.fractional_max_pool3d", ""
     ),  # rand() received an invalid combination of arguments - g...
     xfail("trace", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail(
+        "convolution_backward", ""
+    ),  # Cannot call sizes() on tensor with symbolic sizes/strides
     decorate(
         "linalg.householder_product",
         decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),

View File

@@ -381,6 +381,7 @@ CHECK_STRIDES_SKIPS = {
     # channel_last and channel_last_3d related failures
     aten.convolution.default,
+    aten.convolution_backward.default,
     # following ops fails if include_storage_offset = True, but these are a bit edge casey
     # we should still fix them, leaving them here for tracking.

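Context for the skip above (a sketch, not part of the diff): the comment attributes these failures to channels_last layouts, where the same logical shape carries different strides than the contiguous layout, so exact stride comparison against the meta/fake result is brittle. A minimal illustration of the difference:

import torch

x = torch.empty(2, 3, 4, 5)
print(x.stride())                                        # (60, 20, 5, 1) -- contiguous NCHW
print(x.to(memory_format=torch.channels_last).stride())  # (60, 1, 15, 3) -- channels_last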
View File

@@ -3514,8 +3514,13 @@ def meta_convolution_backward(
     if output_mask[1]:
         backend_grad_weight = grad_output_.new_empty(weight_.size())
     if output_mask[2]:
-        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
+        if bias_sizes_opt:
+            backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
+        else:
+            if transposed:
+                backend_grad_bias = grad_output_.new_empty(weight_.size(1) * groups)
+            else:
+                backend_grad_bias = grad_output_.new_empty(weight_.size(0))
     return (backend_grad_input, backend_grad_weight, backend_grad_bias)

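For context (a hedged sketch with made-up sizes, not part of the diff): the new branch handles a missing bias_sizes_opt by falling back to the output-channel count, which is weight_.size(1) * groups for a transposed convolution and weight_.size(0) otherwise. A quick eager-mode check of that rule:

import torch
import torch.nn.functional as F

groups = 2
# Transposed-conv weights are laid out as (in_channels, out_channels // groups, kH, kW),
# so the bias -- and hence grad_bias -- has weight.size(1) * groups elements.
weight = torch.randn(4, 3, 2, 2, requires_grad=True)
bias = torch.randn(weight.size(1) * groups, requires_grad=True)
x = torch.randn(1, 4, 5, 5)
F.conv_transpose2d(x, weight, bias, groups=groups).sum().backward()
print(bias.grad.shape)  # torch.Size([6]) == weight.size(1) * groups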
View File

@@ -253,18 +253,31 @@ std::vector<Shape> compute_shape_convolution_backward(
     at::IntArrayRef output_padding,
     int64_t groups,
     ::std::array<bool, 3> output_mask) {
-  if (bias_sizes.has_value()) {
-    return {
-        Shape(input.scalar_type(), input.sizes().vec()),
-        Shape(weight.scalar_type(), weight.sizes().vec()),
-        Shape(grad_output.scalar_type(), bias_sizes.value().vec())};
+  Shape s0, s1, s2;
+  if (output_mask[0]) {
+    s0 = Shape(input.scalar_type(), input.sizes().vec());
   } else {
-    // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is
-    // empty
-    return {
-        Shape(input.scalar_type(), input.sizes().vec()),
-        Shape(weight.scalar_type(), weight.sizes().vec())};
+    s0 = Shape();
   }
+  if (output_mask[1]) {
+    s1 = Shape(weight.scalar_type(), weight.sizes().vec());
+  } else {
+    s1 = Shape();
+  }
+  if (output_mask[2]) {
+    if (bias_sizes.has_value()) {
+      s2 = Shape(grad_output.scalar_type(), bias_sizes.value().vec());
+    } else {
+      if (transposed) {
+        s2 = Shape(grad_output.scalar_type(), weight.size(1) * groups);
+      } else {
+        s2 = Shape(grad_output.scalar_type(), weight.size(0));
+      }
+    }
+  } else {
+    s2 = Shape();
+  }
+  return {s0, s1, s2};
 }
 
 std::vector<Shape> compute_shape_convolution(

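For context (a sketch with made-up sizes, not part of the diff): the rewritten shape function mirrors eager semantics, where output_mask selects which of (grad_input, grad_weight, grad_bias) convolution_backward actually computes; deselected slots come back undefined (None from Python):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
grad_out = torch.randn_like(F.conv2d(x, w))
grads = torch.ops.aten.convolution_backward.default(
    grad_out, x, w,
    [4],                     # bias_sizes
    [1, 1], [0, 0], [1, 1],  # stride, padding, dilation
    False, [0, 0], 1,        # transposed, output_padding, groups
    [True, True, False],     # output_mask: grad_bias not requested
)
print([None if g is None else g.shape for g in grads])
# e.g. [torch.Size([1, 3, 8, 8]), torch.Size([4, 3, 3, 3]), None]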
View File

@@ -8032,6 +8032,56 @@ def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwar
     for case, scale in product(cases, scale_vals):
         yield SampleInput(make_arg(case), make_mask(case), scale)
 
+
+def sample_inputs_convolution_backward(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    def get_in_dim(out_dim, pad, dilation, kernel, stride):
+        return (stride * (out_dim - 1)) + 1 + (dilation * (kernel - 1)) - (2 * pad)
+
+    def get_random_conv_bwd_inputs(num_cases):
+        pad = dilation = stride = kernel = 2
+        for (is_transposed, input_bias, groups, tensor_sizes) in zip(
+            [True, False],
+            [True, False],
+            [1, 4],
+            [[M, M, M, L, L], [M, M, S, L, L]]
+        ):
+            [N, C_in, C_out, H_out, W_out] = tensor_sizes
+            C_in = (C_in // groups) * groups
+            C_out = (C_out // groups) * groups
+            H_in = get_in_dim(H_out, pad, dilation, kernel, stride)
+            W_in = get_in_dim(W_out, pad, dilation, kernel, stride)
+            if is_transposed:
+                grad_output = make_arg([N, C_in, H_in, W_in]),
+                args = (
+                    make_arg([N, C_out, H_out, W_out]),
+                    make_arg([C_out, C_in // groups, kernel, kernel]),
+                )
+                bias_size = [C_in * groups] if input_bias else None
+            else:
+                grad_output = make_arg([N, C_out, H_out, W_out]),
+                args = (
+                    make_arg([N, C_in, H_in, W_in]),
+                    make_arg([C_out, C_in // groups, kernel, kernel]),
+                )
+                bias_size = [C_out] if input_bias else None
+            kwargs = {
+                "bias_sizes": bias_size,
+                "stride": [stride, stride],
+                "padding": [pad, pad],
+                "dilation": [dilation, dilation],
+                "transposed": is_transposed,
+                "output_padding": [0],
+                "groups": groups,
+                "output_mask": [True, True, True],
+            }
+            yield (grad_output, args, kwargs)
+
+    for grad_output, args, kwargs in get_random_conv_bwd_inputs(5):
+        yield SampleInput(grad_output[0], args=args, kwargs=kwargs)
+
+
 def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
     def make_input(shape):
         return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -20808,6 +20858,54 @@ op_db: list[OpInfo] = [
             ),
         ),
     ),
+    OpInfo(
+        "convolution_backward",
+        op=torch.ops.aten.convolution_backward.default,
+        aten_name="convolution_backward",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=5e-4, rtol=2e-6)}),
+                'TestCommon', 'test_noncontiguous_samples',
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=2e-3, rtol=2e-3)}),
+                'TestCommon', 'test_noncontiguous_samples', active_if=TEST_WITH_ROCM
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=2e-3, rtol=3e-3)}),
+                'TestOperators',
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=5e-4, rtol=2e-5)}),
+                'TestCompositeCompliance',
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
+                'TestCompositeCompliance', active_if=TEST_WITH_ROCM
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=2e-3, rtol=6e-1)}),
+                'TestInductorOpInfo', 'test_comprehensive',
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=5e-4, rtol=5e-4)}),
+                'TestVmapOperatorsOpInfo', 'test_vmap_exhaustive',
+            ),
+        ),
+        skips=(
+            DecorateInfo(unittest.expectedFailure,
+                         'TestConsistency', 'test_output_match', device_type="mps"),
+            DecorateInfo(unittest.expectedFailure,
+                         'TestConsistency', 'test_output_grad_match', device_type="mps"),
+        ),
+        sample_inputs_func=sample_inputs_convolution_backward
+    ),
     OpInfo(
         "nn.functional.dropout2d",
         op=lambda input, *args, **kwargs: