mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Revert "Add error checking for padding modules (#106147)"
This reverts commit 0547b6279d6f7249c0e588508c2561589514d3aa. Reverted https://github.com/pytorch/pytorch/pull/106147 on behalf of https://github.com/jeanschmidt due to sadly it is breaking internal builds, and I can't coordinate a FF due to timezone differences ([comment](https://github.com/pytorch/pytorch/pull/106147#issuecomment-1661870970))
This commit is contained in:
@ -2416,60 +2416,6 @@ rnn_gru_lstm_module_info_decorators = (
|
||||
|
||||
# Start of module error inputs functions.
|
||||
|
||||
def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
is_constant = kwargs.get('is_constant', False)
|
||||
|
||||
return [
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
|
||||
forward_input=FunctionInput(make_input((2, 3, 4, 5))),
|
||||
),
|
||||
error_on=ModuleErrorEnum.FORWARD_ERROR,
|
||||
error_type=ValueError,
|
||||
error_regex=r"expected 2D or 3D input \(got 4D input\)",
|
||||
|
||||
),
|
||||
]
|
||||
|
||||
def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
is_constant = kwargs.get('is_constant', False)
|
||||
|
||||
return [
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
|
||||
forward_input=FunctionInput(make_input((2, 3))),
|
||||
),
|
||||
error_on=ModuleErrorEnum.FORWARD_ERROR,
|
||||
error_type=ValueError,
|
||||
error_regex=r"expected 3D or 4D input \(got 2D input\)",
|
||||
|
||||
),
|
||||
]
|
||||
|
||||
def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
is_constant = kwargs.get('is_constant', False)
|
||||
|
||||
return [
|
||||
ErrorModuleInput(
|
||||
ModuleInput(
|
||||
constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
|
||||
forward_input=FunctionInput(make_input((2, 3))),
|
||||
),
|
||||
error_on=ModuleErrorEnum.FORWARD_ERROR,
|
||||
error_type=ValueError,
|
||||
error_regex=r"expected 4D or 5D input \(got 2D input\)",
|
||||
|
||||
),
|
||||
]
|
||||
|
||||
def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
samples = [
|
||||
@ -3300,13 +3246,11 @@ module_db: List[ModuleInfo] = [
|
||||
decorators=rnn_gru_lstm_module_info_decorators),
|
||||
ModuleInfo(torch.nn.ReflectionPad1d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
|
||||
),
|
||||
ModuleInfo(torch.nn.ReflectionPad2d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReflectionPad2d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
|
||||
@ -3317,7 +3261,6 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ReflectionPad3d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReflectionPad3d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
|
||||
@ -3328,13 +3271,11 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ReplicationPad1d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReplicationPad1d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
|
||||
),
|
||||
ModuleInfo(torch.nn.ReplicationPad2d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReplicationPad2d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
|
||||
@ -3345,7 +3286,6 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ReplicationPad3d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReplicationPad3d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
|
||||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
skips=(
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
|
||||
@ -3363,13 +3303,11 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ZeroPad1d,
|
||||
module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
|
||||
),
|
||||
ModuleInfo(torch.nn.ZeroPad2d,
|
||||
module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# Fails with channels last test on MPS backend
|
||||
@ -3377,7 +3315,6 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ZeroPad3d,
|
||||
module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# Fails with channels last test on MPS backend
|
||||
@ -3405,13 +3342,11 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ConstantPad1d,
|
||||
module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
|
||||
module_error_inputs_func=partial(module_error_inputs_torch_nn_Pad1d, is_constant=True),
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
|
||||
),
|
||||
ModuleInfo(torch.nn.ConstantPad2d,
|
||||
module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
|
||||
module_error_inputs_func=partial(module_error_inputs_torch_nn_Pad2d, is_constant=True),
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# Fails with channels last test on MPS backend
|
||||
@ -3419,7 +3354,6 @@ module_db: List[ModuleInfo] = [
|
||||
),
|
||||
ModuleInfo(torch.nn.ConstantPad3d,
|
||||
module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
|
||||
module_error_inputs_func=partial(module_error_inputs_torch_nn_Pad3d, is_constant=True),
|
||||
skips=(
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# Fails with channels last test on MPS backend
|
||||
|
||||
Reference in New Issue
Block a user