Add python and C++ support for LPPool3d (#114199)

Add Python and C++ support for LPPool3d. Fixes #114114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114199
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Wongboo
2023-12-08 18:18:40 +00:00
committed by PyTorch MergeBot
parent 1c3a4a864c
commit 68f74dd162
21 changed files with 306 additions and 15 deletions
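Before the per-file diffs, a quick orientation: the change adds torch.nn.LPPool3d and torch.nn.functional.lp_pool3d, the 3D counterparts of the existing 1D/2D power-average pooling APIs. A minimal usage sketch (assumes a PyTorch build containing this commit; shapes are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Power-2 pooling over 3x3x3 windows with stride 2.
# (stride defaults to kernel_size when omitted.)
m = nn.LPPool3d(2, 3, stride=2)
x = torch.randn(20, 16, 50, 44, 31)  # (N, C, D, H, W)
y = m(x)                             # (20, 16, 24, 21, 15)

# The functional form takes the same arguments.
y2 = F.lp_pool3d(x, norm_type=2, kernel_size=3, stride=2)
assert torch.allclose(y, y2)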


@@ -3116,6 +3116,7 @@ coverage_ignore_classes = [
     "FractionalMaxPool3d",
     "LPPool1d",
     "LPPool2d",
+    "LPPool3d",
     "MaxPool1d",
     "MaxPool2d",
     "MaxPool3d",


@@ -40,6 +40,7 @@ Pooling functions
     max_unpool3d
     lp_pool1d
     lp_pool2d
+    lp_pool3d
     adaptive_max_pool1d
     adaptive_max_pool2d
     adaptive_max_pool3d


@@ -103,6 +103,7 @@ Pooling layers
     nn.FractionalMaxPool3d
     nn.LPPool1d
     nn.LPPool2d
+    nn.LPPool3d
     nn.AdaptiveMaxPool1d
     nn.AdaptiveMaxPool2d
     nn.AdaptiveMaxPool3d


@@ -421,6 +421,7 @@ l1_loss
     log_softmax
     lp_pool1d
     lp_pool2d
+    lp_pool3d
     lstm_cell
     margin_ranking_loss
     max_pool1d_with_indices


@@ -254,17 +254,35 @@ TEST_F(FunctionalTest, LPPool2d) {
   int stride = 2;
   std::vector<int64_t> kernel_size({2, 3});
-  auto x = torch::ones({1, 2, 5});
+  auto x = torch::ones({1, 1, 2, 5});
   auto y = F::lp_pool2d(
       x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
   auto expected =
-      (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
+      (torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
        (kernel_size[0] * kernel_size[1]))
           .pow(1. / norm_type);
-  ASSERT_EQ(y.ndimension(), 3);
+  ASSERT_EQ(y.ndimension(), 4);
   ASSERT_TRUE(torch::allclose(y, expected));
-  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
 }
 
+TEST_F(FunctionalTest, LPPool3d) {
+  int norm_type = 2;
+  int stride = 2;
+  std::vector<int64_t> kernel_size({1, 2, 3});
+  auto x = torch::ones({1, 1, 1, 2, 5});
+  auto y = F::lp_pool3d(
+      x, F::LPPool3dFuncOptions(norm_type, kernel_size).stride(stride));
+  auto expected =
+      (torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
+       (kernel_size[0] * kernel_size[1] * kernel_size[2]))
+          .pow(1. / norm_type);
+  ASSERT_EQ(y.ndimension(), 5);
+  ASSERT_TRUE(torch::allclose(y, expected));
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
+}
+
 TEST_F(FunctionalTest, CosineSimilarity) {
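The expected values in these tests follow directly from the definition: with an all-ones input, each window sums kernel_size[0] * kernel_size[1] * kernel_size[2] = 1 * 2 * 3 copies of 1^2, so every output element is sqrt(6) ≈ 2.4495, and with stride 2 the 1x2x5 spatial extent pools down to 1x1x2. The same check from Python (a sketch, assuming a build with this change):

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 1, 2, 5)
y = F.lp_pool3d(x, norm_type=2, kernel_size=(1, 2, 3), stride=2)
print(y.shape)      # torch.Size([1, 1, 1, 1, 2])
print(y.flatten())  # tensor([2.4495, 2.4495]), i.e. sqrt(6)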


@@ -533,16 +533,34 @@ TEST_F(ModulesTest, LPPool2d) {
   std::vector<int64_t> kernel_size({2, 3});
   LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride));
-  auto x = torch::ones({1, 2, 5});
+  auto x = torch::ones({1, 1, 2, 5});
   auto y = model(x);
   auto expected =
-      (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
+      (torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
        (kernel_size[0] * kernel_size[1]))
           .pow(1. / norm_type);
-  ASSERT_EQ(y.ndimension(), 3);
+  ASSERT_EQ(y.ndimension(), 4);
   ASSERT_TRUE(torch::allclose(y, expected));
-  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
 }
 
+TEST_F(ModulesTest, LPPool3d) {
+  int norm_type = 2;
+  int stride = 2;
+  std::vector<int64_t> kernel_size({1, 2, 3});
+  LPPool3d model(LPPool3dOptions(norm_type, kernel_size).stride(stride));
+  auto x = torch::ones({1, 1, 1, 2, 5});
+  auto y = model(x);
+  auto expected =
+      (torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
+       (kernel_size[0] * kernel_size[1] * kernel_size[2]))
+          .pow(1. / norm_type);
+  ASSERT_EQ(y.ndimension(), 5);
+  ASSERT_TRUE(torch::allclose(y, expected));
+  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
+}
+
 TEST_F(ModulesTest, Identity) {
@@ -4779,6 +4797,14 @@ TEST_F(ModulesTest, PrettyPrintLPPool) {
                             .stride({5, 6})
                             .ceil_mode(true))),
       "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)");
+  ASSERT_EQ(
+      c10::str(LPPool3d(2, std::vector<int64_t>({1, 2, 3}))),
+      "torch::nn::LPPool3d(norm_type=2, kernel_size=[1, 2, 3], stride=[1, 2, 3], ceil_mode=false)");
+  ASSERT_EQ(
+      c10::str(LPPool3d(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5}))
+                            .stride({5, 6, 7})
+                            .ceil_mode(true))),
+      "torch::nn::LPPool3d(norm_type=1, kernel_size=[3, 4, 5], stride=[5, 6, 7], ceil_mode=true)");
 }
 
 TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {


@@ -29,6 +29,7 @@ torch::nn::FractionalMaxPool2d|Yes|No
 torch::nn::FractionalMaxPool3d|Yes|No
 torch::nn::LPPool1d|Yes|No
 torch::nn::LPPool2d|Yes|No
+torch::nn::LPPool3d|Yes|No
 torch::nn::AdaptiveMaxPool1d|Yes|No
 torch::nn::AdaptiveMaxPool2d|Yes|No
 torch::nn::AdaptiveMaxPool3d|Yes|No
@@ -173,6 +174,7 @@ F::max_unpool2d|Yes|No
 F::max_unpool3d|Yes|No
 F::lp_pool1d|Yes|No
 F::lp_pool2d|Yes|No
+F::lp_pool3d|Yes|No
 F::adaptive_max_pool1d|Yes|No
 F::adaptive_max_pool2d|Yes|No
 F::adaptive_max_pool3d|Yes|No


@@ -4238,6 +4238,7 @@ class TestFunctionalTracing(JitTestCase):
         "max_pool3d": PROXY_ITERABLE,
         "lp_pool2d": PROXY_ITERATED,
+        "lp_pool3d": PROXY_ITERATED,
         "max_unpool1d": PROXY_ITERATED,
         "max_unpool2d": PROXY_ITERATED,
         "max_unpool3d": PROXY_ITERATED,


@@ -83,6 +83,7 @@ def build_constructor_arg_db():
         torch.nn.L1Loss: ((), {}),
         torch.nn.LPPool1d: ((2, 3), {}),
         torch.nn.LPPool2d: ((2, 3), {}),
+        torch.nn.LPPool3d: ((2, 3), {}),
         torch.nn.LSTM: ((5, 10), {}),
         torch.nn.LSTMCell: ((5, 10), {}),
         torch.nn.LayerNorm: ((2,), {}),


@@ -1093,6 +1093,56 @@ inline Tensor lp_pool2d(
       options.ceil_mode());
 }
 
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+namespace detail {
+inline Tensor lp_pool3d(
+    const Tensor& input,
+    double norm_type,
+    ExpandingArray<3> kernel_size,
+    ExpandingArray<3> stride,
+    bool ceil_mode) {
+  int kd = (*kernel_size)[0];
+  int kw = (*kernel_size)[1];
+  int kh = (*kernel_size)[2];
+  Tensor out = detail::avg_pool3d(
+      input.pow(norm_type),
+      kernel_size,
+      stride,
+      /*padding=*/0,
+      ceil_mode,
+      /*count_include_pad=*/true,
+      /*divisor_override=*/c10::nullopt);
+
+  return (torch::sign(out) * relu(torch::abs(out)))
+      .mul(kd * kw * kh)
+      .pow(1. / norm_type);
+}
+} // namespace detail
+#endif /* DOXYGEN_SHOULD_SKIP_THIS */
+
+/// See
+/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.lp_pool3d
+/// about the exact behavior of this functional.
+///
+/// See the documentation for `torch::nn::functional::LPPool3dFuncOptions` class
+/// to learn what optional arguments are supported for this functional.
+///
+/// Example:
+/// ```
+/// namespace F = torch::nn::functional;
+/// F::lp_pool3d(x, F::LPPool3dFuncOptions(3, {3, 3, 5}).stride(3));
+/// ```
+inline Tensor lp_pool3d(
+    const Tensor& input,
+    const LPPool3dFuncOptions& options) {
+  return detail::lp_pool3d(
+      input,
+      options.norm_type(),
+      options.kernel_size(),
+      options.stride(),
+      options.ceil_mode());
+}
+
 } // namespace functional
 } // namespace nn
 } // namespace torch


@@ -747,5 +747,33 @@ class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> {
 /// learn about PyTorch's module storage semantics.
 TORCH_MODULE(LPPool2d);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies the LPPool3d function element-wise.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LPPool3d to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// LPPool3d model(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5})).stride(
+/// {5, 6, 7}).ceil_mode(true));
+/// ```
+class TORCH_API LPPool3dImpl : public LPPoolImpl<3, LPPool3dImpl> {
+ public:
+  using LPPoolImpl<3, LPPool3dImpl>::LPPoolImpl;
+
+  Tensor forward(const Tensor& input);
+};
+
+/// A `ModuleHolder` subclass for `LPPool3dImpl`.
+/// See the documentation for `LPPool3dImpl` class to learn what methods it
+/// provides, and examples of how to use `LPPool3d` with
+/// `torch::nn::LPPool3dOptions`. See the documentation for `ModuleHolder` to
+/// learn about PyTorch's module storage semantics.
+TORCH_MODULE(LPPool3d);
+
 } // namespace nn
 } // namespace torch


@@ -541,6 +541,15 @@ using LPPool1dOptions = LPPoolOptions<1>;
 /// ```
 using LPPool2dOptions = LPPoolOptions<2>;
 
+/// `LPPoolOptions` specialized for the `LPPool3d` module.
+///
+/// Example:
+/// ```
+/// LPPool3d model(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5})).stride(
+/// {5, 6, 7}).ceil_mode(true));
+/// ```
+using LPPool3dOptions = LPPoolOptions<3>;
+
 namespace functional {
 /// Options for `torch::nn::functional::lp_pool1d`.
 ///
@@ -569,5 +578,19 @@ namespace functional {
 using LPPool2dFuncOptions = LPPool2dOptions;
 } // namespace functional
 
+namespace functional {
+/// Options for `torch::nn::functional::lp_pool3d`.
+///
+/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what
+/// arguments are supported.
+///
+/// Example:
+/// ```
+/// namespace F = torch::nn::functional;
+/// F::lp_pool3d(x, F::LPPool3dFuncOptions(2, {2, 3, 4}).stride(2));
+/// ```
+using LPPool3dFuncOptions = LPPool3dOptions;
+} // namespace functional
+
 } // namespace nn
 } // namespace torch


@@ -429,5 +429,16 @@ Tensor LPPool2dImpl::forward(const Tensor& input) {
 template class LPPoolImpl<2, LPPool2dImpl>;
 
+Tensor LPPool3dImpl::forward(const Tensor& input) {
+  return F::detail::lp_pool3d(
+      input,
+      options.norm_type(),
+      options.kernel_size(),
+      options.stride(),
+      options.ceil_mode());
+}
+
+template class LPPoolImpl<3, LPPool3dImpl>;
+
 } // namespace nn
 } // namespace torch


@@ -25,6 +25,7 @@ template struct MaxUnpoolOptions<3>;
 template struct LPPoolOptions<1>;
 template struct LPPoolOptions<2>;
+template struct LPPoolOptions<3>;
 
 } // namespace nn
 } // namespace torch


@@ -1022,6 +1022,33 @@ def max_unpool3d(
     return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)
 
+def lp_pool3d(
+    input: Tensor, norm_type: Union[int, float],
+    kernel_size: BroadcastingList3[int],
+    stride: Optional[BroadcastingList3[int]] = None,
+    ceil_mode: bool = False
+) -> Tensor:
+    r"""
+    Apply a 3D power-average pooling over an input signal composed of several input planes.
+
+    If the sum of all inputs to the power of `p` is
+    zero, the gradient is set to zero as well.
+
+    See :class:`~torch.nn.LPPool3d` for details.
+    """
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            lp_pool3d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode
+        )
+    kd, kw, kh = utils._triple(kernel_size)
+    if stride is not None:
+        out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
+    else:
+        out = avg_pool3d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode)
+
+    return (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type)
+
+
 def lp_pool2d(
     input: Tensor, norm_type: Union[int, float],
     kernel_size: BroadcastingList2[int],
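This implementation (mirrored by detail::lp_pool3d in the C++ header earlier in this commit) reduces LP pooling to average pooling: avg_pool3d(input.pow(p)) times the window volume kd * kw * kh recovers the per-window sum of x^p, and the final pow(1/p) yields (sum x^p)^(1/p). The sign(out) * relu(abs(out)) factor equals out for every value; its purpose is gradient routing, since relu zeroes the gradient wherever the windowed sum is zero, as the docstring promises. (The kd, kw, kh unpacking order is immaterial here, since only the product of the three is used.) A quick equivalence check against an explicit unfold-based reference; the reference is illustrative, not part of this change:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, 8)
p, k, s = 2.0, 2, 2

# Explicit windowed sum of x**p via unfold, then the p-th root.
win = x.pow(p).unfold(2, k, s).unfold(3, k, s).unfold(4, k, s)
ref = win.sum(dim=(-3, -2, -1)).pow(1.0 / p)

out = F.lp_pool3d(x, p, k, stride=s)
assert torch.allclose(out, ref, atol=1e-5)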


@@ -127,6 +127,13 @@ def lp_pool2d(
     stride: Union[Optional[_size], Optional[int]] = ...,
     ceil_mode: bool = ...,
 ) -> Tensor: ...
+def lp_pool3d(
+    input: Tensor,
+    norm_type: float,
+    kernel_size: _size_3_t,
+    stride: Union[Optional[_size], Optional[int]] = ...,
+    ceil_mode: bool = ...,
+) -> Tensor: ...
 def adaptive_max_pool1d_with_indices(
     input: Tensor,
     output_size: _size,


@@ -13,7 +13,7 @@ from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, \
     SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
 from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
 from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
-    MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
+    MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, LPPool3d, \
     AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
 from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
     LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
@@ -48,8 +48,8 @@ __all__ = [
     'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
     'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
     'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
-    'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
-    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
+    'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
+    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
     'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
     'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
     'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',


@@ -10,8 +10,8 @@ from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t,
 
 __all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
            'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d',
-           'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
-           'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
+           'LPPool2d', 'LPPool3d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
+           'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
 
 class _MaxPoolNd(Module):
     __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
@@ -953,8 +953,8 @@ class LPPool2d(_LPPoolNd):
         ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
 
     Shape:
-        - Input: :math:`(N, C, H_{in}, W_{in})`
-        - Output: :math:`(N, C, H_{out}, W_{out})`, where
+        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
 
         .. math::
             H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
@@ -981,6 +981,64 @@ class LPPool2d(_LPPoolNd):
                            self.stride, self.ceil_mode)
 
 
+class LPPool3d(_LPPoolNd):
+    r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
+
+    On each window, the function computed is:
+
+    .. math::
+        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
+
+    - At p = :math:`\infty`, one gets Max Pooling
+    - At p = 1, one gets Sum Pooling (which is proportional to average pooling)
+
+    The parameters :attr:`kernel_size`, :attr:`stride` can either be:
+
+    - a single ``int`` -- in which case the same value is used for the height, width and depth dimension
+    - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+      the second `int` for the height dimension and the third `int` for the width dimension
+
+    .. note:: If the sum to the power of `p` is zero, the gradient of this function is
+              not defined. This implementation will set the gradient to zero in this case.
+
+    Args:
+        kernel_size: the size of the window
+        stride: the stride of the window. Default value is :attr:`kernel_size`
+        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+    Shape:
+        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
+          :math:`(C, D_{out}, H_{out}, W_{out})`, where
+
+        .. math::
+            D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+        .. math::
+            H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+        .. math::
+            W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+    Examples::
+
+        >>> # power-2 pool of square window of size=3, stride=2
+        >>> m = nn.LPPool3d(2, 3, stride=2)
+        >>> # pool of non-square window of power 1.2
+        >>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
+        >>> input = torch.randn(20, 16, 50, 44, 31)
+        >>> output = m(input)
+    """
+
+    kernel_size: _size_3_t
+    stride: _size_3_t
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.lp_pool3d(input, float(self.norm_type), self.kernel_size,
+                           self.stride, self.ceil_mode)
+
+
 class _AdaptiveMaxPoolNd(Module):
     __constants__ = ['output_size', 'return_indices']
     return_indices: bool
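The two special cases called out in the docstring are easy to confirm numerically: at p = 1 the p-th root is a no-op and LP pooling is plain sum pooling (avg_pool3d scaled by the window volume), while large p approaches max pooling. A sketch; the norm values and tolerance below are illustrative:

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 4, 4, 4) + 0.5  # positive values, (N, C, D, H, W)

# p = 1: sum pooling, i.e. avg_pool3d scaled by the 2*2*2 window volume.
assert torch.allclose(F.lp_pool3d(x, 1.0, 2), F.avg_pool3d(x, 2) * 8)

# Large p: close to max pooling (within a factor of at most 8**(1/p) from ties).
assert torch.allclose(F.lp_pool3d(x, 100.0, 2), F.max_pool3d(x, 2), atol=0.05)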


@@ -866,6 +866,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.nn.functional.logsigmoid: lambda input: -1,
     torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
     torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
+    torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
     torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
                                               reduce=None, reduction='mean': -1),
     torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,


@@ -1467,6 +1467,11 @@ def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
         ModuleInput(
             constructor_input=FunctionInput(2, 2, 2),
             forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
         ModuleInput(
             constructor_input=FunctionInput(1.5, 2),
             forward_input=FunctionInput(make_input((1, 3, 7, 7))),
@@ -1474,6 +1479,25 @@ def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
     ]
 
+
+def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))),
+            desc='norm'),
+    ]
+
+
 def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -3073,6 +3097,14 @@ module_db: List[ModuleInfo] = [
                        device_type='mps',
                    ),)
                ),
+    ModuleInfo(torch.nn.LPPool3d,
+               module_inputs_func=module_inputs_torch_nn_LPPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(skipIfMps),)
+               ),
     ModuleInfo(torch.nn.MaxPool1d,
                module_inputs_func=module_inputs_torch_nn_MaxPool1d,
                skips=(
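The no_batch_dim entries above exercise the unbatched (C, D, H, W) layout documented in the new module's Shape section; batched and unbatched inputs share the same avg_pool3d-based path. A sketch (assuming a build with this change):

import torch
import torch.nn as nn

m = nn.LPPool3d(2, 2, stride=2)
batched = m(torch.randn(1, 3, 7, 7, 7))  # (N, C, D, H, W) -> (1, 3, 3, 3, 3)
unbatched = m(torch.randn(3, 7, 7, 7))   # (C, D, H, W)    -> (3, 3, 3, 3)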


@@ -115,6 +115,7 @@ nn_functional_tests = [
     ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
     ('lp_pool1d', (S, S, S), (2., 3, 2,)),
     ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
+    ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
     ('adaptive_max_pool1d', (S, S, S), (5,)),
     ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
     ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),