mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add python and C++ support for LPPool3d (#114199)
Add python and C++ support for LPPool3d to Fixes #114114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114199 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c3a4a864c
commit
68f74dd162
@ -3116,6 +3116,7 @@ coverage_ignore_classes = [
|
|||||||
"FractionalMaxPool3d",
|
"FractionalMaxPool3d",
|
||||||
"LPPool1d",
|
"LPPool1d",
|
||||||
"LPPool2d",
|
"LPPool2d",
|
||||||
|
"LPPool3d",
|
||||||
"MaxPool1d",
|
"MaxPool1d",
|
||||||
"MaxPool2d",
|
"MaxPool2d",
|
||||||
"MaxPool3d",
|
"MaxPool3d",
|
||||||
|
|||||||
@ -40,6 +40,7 @@ Pooling functions
|
|||||||
max_unpool3d
|
max_unpool3d
|
||||||
lp_pool1d
|
lp_pool1d
|
||||||
lp_pool2d
|
lp_pool2d
|
||||||
|
lp_pool3d
|
||||||
adaptive_max_pool1d
|
adaptive_max_pool1d
|
||||||
adaptive_max_pool2d
|
adaptive_max_pool2d
|
||||||
adaptive_max_pool3d
|
adaptive_max_pool3d
|
||||||
|
|||||||
@ -103,6 +103,7 @@ Pooling layers
|
|||||||
nn.FractionalMaxPool3d
|
nn.FractionalMaxPool3d
|
||||||
nn.LPPool1d
|
nn.LPPool1d
|
||||||
nn.LPPool2d
|
nn.LPPool2d
|
||||||
|
nn.LPPool3d
|
||||||
nn.AdaptiveMaxPool1d
|
nn.AdaptiveMaxPool1d
|
||||||
nn.AdaptiveMaxPool2d
|
nn.AdaptiveMaxPool2d
|
||||||
nn.AdaptiveMaxPool3d
|
nn.AdaptiveMaxPool3d
|
||||||
|
|||||||
@ -421,6 +421,7 @@ l1_loss
|
|||||||
log_softmax
|
log_softmax
|
||||||
lp_pool1d
|
lp_pool1d
|
||||||
lp_pool2d
|
lp_pool2d
|
||||||
|
lp_pool3d
|
||||||
lstm_cell
|
lstm_cell
|
||||||
margin_ranking_loss
|
margin_ranking_loss
|
||||||
max_pool1d_with_indices
|
max_pool1d_with_indices
|
||||||
|
|||||||
@ -254,17 +254,35 @@ TEST_F(FunctionalTest, LPPool2d) {
|
|||||||
int stride = 2;
|
int stride = 2;
|
||||||
std::vector<int64_t> kernel_size({2, 3});
|
std::vector<int64_t> kernel_size({2, 3});
|
||||||
|
|
||||||
auto x = torch::ones({1, 2, 5});
|
auto x = torch::ones({1, 1, 2, 5});
|
||||||
auto y = F::lp_pool2d(
|
auto y = F::lp_pool2d(
|
||||||
x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
|
x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
|
||||||
auto expected =
|
auto expected =
|
||||||
(torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
|
(torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
|
||||||
(kernel_size[0] * kernel_size[1]))
|
(kernel_size[0] * kernel_size[1]))
|
||||||
.pow(1. / norm_type);
|
.pow(1. / norm_type);
|
||||||
|
|
||||||
ASSERT_EQ(y.ndimension(), 3);
|
ASSERT_EQ(y.ndimension(), 4);
|
||||||
ASSERT_TRUE(torch::allclose(y, expected));
|
ASSERT_TRUE(torch::allclose(y, expected));
|
||||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
|
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FunctionalTest, LPPool3d) {
|
||||||
|
int norm_type = 2;
|
||||||
|
int stride = 2;
|
||||||
|
std::vector<int64_t> kernel_size({1, 2, 3});
|
||||||
|
|
||||||
|
auto x = torch::ones({1, 1, 1, 2, 5});
|
||||||
|
auto y = F::lp_pool3d(
|
||||||
|
x, F::LPPool3dFuncOptions(norm_type, kernel_size).stride(stride));
|
||||||
|
auto expected =
|
||||||
|
(torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
|
||||||
|
(kernel_size[0] * kernel_size[1] * kernel_size[2]))
|
||||||
|
.pow(1. / norm_type);
|
||||||
|
|
||||||
|
ASSERT_EQ(y.ndimension(), 5);
|
||||||
|
ASSERT_TRUE(torch::allclose(y, expected));
|
||||||
|
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FunctionalTest, CosineSimilarity) {
|
TEST_F(FunctionalTest, CosineSimilarity) {
|
||||||
|
|||||||
@ -533,16 +533,34 @@ TEST_F(ModulesTest, LPPool2d) {
|
|||||||
std::vector<int64_t> kernel_size({2, 3});
|
std::vector<int64_t> kernel_size({2, 3});
|
||||||
|
|
||||||
LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride));
|
LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride));
|
||||||
auto x = torch::ones({1, 2, 5});
|
auto x = torch::ones({1, 1, 2, 5});
|
||||||
auto y = model(x);
|
auto y = model(x);
|
||||||
auto expected =
|
auto expected =
|
||||||
(torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
|
(torch::pow(torch::tensor({{{{1, 1}}}}, torch::kFloat), norm_type) *
|
||||||
(kernel_size[0] * kernel_size[1]))
|
(kernel_size[0] * kernel_size[1]))
|
||||||
.pow(1. / norm_type);
|
.pow(1. / norm_type);
|
||||||
|
|
||||||
ASSERT_EQ(y.ndimension(), 3);
|
ASSERT_EQ(y.ndimension(), 4);
|
||||||
ASSERT_TRUE(torch::allclose(y, expected));
|
ASSERT_TRUE(torch::allclose(y, expected));
|
||||||
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
|
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ModulesTest, LPPool3d) {
|
||||||
|
int norm_type = 2;
|
||||||
|
int stride = 2;
|
||||||
|
std::vector<int64_t> kernel_size({1, 2, 3});
|
||||||
|
|
||||||
|
LPPool3d model(LPPool3dOptions(norm_type, kernel_size).stride(stride));
|
||||||
|
auto x = torch::ones({1, 1, 1, 2, 5});
|
||||||
|
auto y = model(x);
|
||||||
|
auto expected =
|
||||||
|
(torch::pow(torch::tensor({{{{{1, 1}}}}}, torch::kFloat), norm_type) *
|
||||||
|
(kernel_size[0] * kernel_size[1] * kernel_size[2]))
|
||||||
|
.pow(1. / norm_type);
|
||||||
|
|
||||||
|
ASSERT_EQ(y.ndimension(), 5);
|
||||||
|
ASSERT_TRUE(torch::allclose(y, expected));
|
||||||
|
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 1, 1, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModulesTest, Identity) {
|
TEST_F(ModulesTest, Identity) {
|
||||||
@ -4779,6 +4797,14 @@ TEST_F(ModulesTest, PrettyPrintLPPool) {
|
|||||||
.stride({5, 6})
|
.stride({5, 6})
|
||||||
.ceil_mode(true))),
|
.ceil_mode(true))),
|
||||||
"torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)");
|
"torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)");
|
||||||
|
ASSERT_EQ(
|
||||||
|
c10::str(LPPool3d(2, std::vector<int64_t>({1, 2, 3}))),
|
||||||
|
"torch::nn::LPPool3d(norm_type=2, kernel_size=[1, 2, 3], stride=[1, 2, 3], ceil_mode=false)");
|
||||||
|
ASSERT_EQ(
|
||||||
|
c10::str(LPPool3d(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5}))
|
||||||
|
.stride({5, 6, 7})
|
||||||
|
.ceil_mode(true))),
|
||||||
|
"torch::nn::LPPool3d(norm_type=1, kernel_size=[3, 4, 5], stride=[5, 6, 7], ceil_mode=true)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {
|
TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {
|
||||||
|
|||||||
@ -29,6 +29,7 @@ torch::nn::FractionalMaxPool2d|Yes|No
|
|||||||
torch::nn::FractionalMaxPool3d|Yes|No
|
torch::nn::FractionalMaxPool3d|Yes|No
|
||||||
torch::nn::LPPool1d|Yes|No
|
torch::nn::LPPool1d|Yes|No
|
||||||
torch::nn::LPPool2d|Yes|No
|
torch::nn::LPPool2d|Yes|No
|
||||||
|
torch::nn::LPPool3d|Yes|No
|
||||||
torch::nn::AdaptiveMaxPool1d|Yes|No
|
torch::nn::AdaptiveMaxPool1d|Yes|No
|
||||||
torch::nn::AdaptiveMaxPool2d|Yes|No
|
torch::nn::AdaptiveMaxPool2d|Yes|No
|
||||||
torch::nn::AdaptiveMaxPool3d|Yes|No
|
torch::nn::AdaptiveMaxPool3d|Yes|No
|
||||||
@ -173,6 +174,7 @@ F::max_unpool2d|Yes|No
|
|||||||
F::max_unpool3d|Yes|No
|
F::max_unpool3d|Yes|No
|
||||||
F::lp_pool1d|Yes|No
|
F::lp_pool1d|Yes|No
|
||||||
F::lp_pool2d|Yes|No
|
F::lp_pool2d|Yes|No
|
||||||
|
F::lp_pool3d|Yes|No
|
||||||
F::adaptive_max_pool1d|Yes|No
|
F::adaptive_max_pool1d|Yes|No
|
||||||
F::adaptive_max_pool2d|Yes|No
|
F::adaptive_max_pool2d|Yes|No
|
||||||
F::adaptive_max_pool3d|Yes|No
|
F::adaptive_max_pool3d|Yes|No
|
||||||
|
|||||||
@ -4238,6 +4238,7 @@ class TestFunctionalTracing(JitTestCase):
|
|||||||
"max_pool3d": PROXY_ITERABLE,
|
"max_pool3d": PROXY_ITERABLE,
|
||||||
|
|
||||||
"lp_pool2d": PROXY_ITERATED,
|
"lp_pool2d": PROXY_ITERATED,
|
||||||
|
"lp_pool3d": PROXY_ITERATED,
|
||||||
"max_unpool1d": PROXY_ITERATED,
|
"max_unpool1d": PROXY_ITERATED,
|
||||||
"max_unpool2d": PROXY_ITERATED,
|
"max_unpool2d": PROXY_ITERATED,
|
||||||
"max_unpool3d": PROXY_ITERATED,
|
"max_unpool3d": PROXY_ITERATED,
|
||||||
|
|||||||
@ -83,6 +83,7 @@ def build_constructor_arg_db():
|
|||||||
torch.nn.L1Loss: ((), {}),
|
torch.nn.L1Loss: ((), {}),
|
||||||
torch.nn.LPPool1d: ((2, 3), {}),
|
torch.nn.LPPool1d: ((2, 3), {}),
|
||||||
torch.nn.LPPool2d: ((2, 3), {}),
|
torch.nn.LPPool2d: ((2, 3), {}),
|
||||||
|
torch.nn.LPPool3d: ((2, 3), {}),
|
||||||
torch.nn.LSTM: ((5, 10), {}),
|
torch.nn.LSTM: ((5, 10), {}),
|
||||||
torch.nn.LSTMCell: ((5, 10), {}),
|
torch.nn.LSTMCell: ((5, 10), {}),
|
||||||
torch.nn.LayerNorm: ((2,), {}),
|
torch.nn.LayerNorm: ((2,), {}),
|
||||||
|
|||||||
@ -1093,6 +1093,56 @@ inline Tensor lp_pool2d(
|
|||||||
options.ceil_mode());
|
options.ceil_mode());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef DOXYGEN_SHOULD_SKIP_THIS
|
||||||
|
namespace detail {
|
||||||
|
inline Tensor lp_pool3d(
|
||||||
|
const Tensor& input,
|
||||||
|
double norm_type,
|
||||||
|
ExpandingArray<3> kernel_size,
|
||||||
|
ExpandingArray<3> stride,
|
||||||
|
bool ceil_mode) {
|
||||||
|
int kd = (*kernel_size)[0];
|
||||||
|
int kw = (*kernel_size)[1];
|
||||||
|
int kh = (*kernel_size)[2];
|
||||||
|
Tensor out = detail::avg_pool3d(
|
||||||
|
input.pow(norm_type),
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
/*padding=*/0,
|
||||||
|
ceil_mode,
|
||||||
|
/*count_include_pad=*/true,
|
||||||
|
/*divisor_override=*/c10::nullopt);
|
||||||
|
|
||||||
|
return (torch::sign(out) * relu(torch::abs(out)))
|
||||||
|
.mul(kd * kw * kh)
|
||||||
|
.pow(1. / norm_type);
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
|
||||||
|
|
||||||
|
/// See
|
||||||
|
/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.lp_pool3d
|
||||||
|
/// about the exact behavior of this functional.
|
||||||
|
///
|
||||||
|
/// See the documentation for `torch::nn::functional::LPPool3dFuncOptions` class
|
||||||
|
/// to learn what optional arguments are supported for this functional.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// namespace F = torch::nn::functional;
|
||||||
|
/// F::lp_pool3d(x, F::LPPool3dFuncOptions(3, {3, 3, 5}).stride(3));
|
||||||
|
/// ```
|
||||||
|
inline Tensor lp_pool3d(
|
||||||
|
const Tensor& input,
|
||||||
|
const LPPool3dFuncOptions& options) {
|
||||||
|
return detail::lp_pool3d(
|
||||||
|
input,
|
||||||
|
options.norm_type(),
|
||||||
|
options.kernel_size(),
|
||||||
|
options.stride(),
|
||||||
|
options.ceil_mode());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace functional
|
} // namespace functional
|
||||||
} // namespace nn
|
} // namespace nn
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|||||||
@ -747,5 +747,33 @@ class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> {
|
|||||||
/// learn about PyTorch's module storage semantics.
|
/// learn about PyTorch's module storage semantics.
|
||||||
TORCH_MODULE(LPPool2d);
|
TORCH_MODULE(LPPool2d);
|
||||||
|
|
||||||
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
/// Applies the LPPool3d function element-wise.
|
||||||
|
/// See https://pytorch.org/docs/master/nn.html#torch.nn.LPPool3d to learn
|
||||||
|
/// about the exact behavior of this module.
|
||||||
|
///
|
||||||
|
/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what
|
||||||
|
/// constructor arguments are supported for this module.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// LPPool3d model(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5})).stride(
|
||||||
|
/// {5, 6, 7}).ceil_mode(true));
|
||||||
|
/// ```
|
||||||
|
class TORCH_API LPPool3dImpl : public LPPoolImpl<3, LPPool3dImpl> {
|
||||||
|
public:
|
||||||
|
using LPPoolImpl<3, LPPool3dImpl>::LPPoolImpl;
|
||||||
|
|
||||||
|
Tensor forward(const Tensor& input);
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A `ModuleHolder` subclass for `LPPool3dImpl`.
|
||||||
|
/// See the documentation for `LPPool3dImpl` class to learn what methods it
|
||||||
|
/// provides, and examples of how to use `LPPool3d` with
|
||||||
|
/// `torch::nn::LPPool3dOptions`. See the documentation for `ModuleHolder` to
|
||||||
|
/// learn about PyTorch's module storage semantics.
|
||||||
|
TORCH_MODULE(LPPool3d);
|
||||||
|
|
||||||
} // namespace nn
|
} // namespace nn
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|||||||
@ -541,6 +541,15 @@ using LPPool1dOptions = LPPoolOptions<1>;
|
|||||||
/// ```
|
/// ```
|
||||||
using LPPool2dOptions = LPPoolOptions<2>;
|
using LPPool2dOptions = LPPoolOptions<2>;
|
||||||
|
|
||||||
|
/// `LPPoolOptions` specialized for the `LPPool3d` module.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// LPPool3d model(LPPool3dOptions(1, std::vector<int64_t>({3, 4, 5})).stride(
|
||||||
|
/// {5, 6, 7}).ceil_mode(true));
|
||||||
|
/// ```
|
||||||
|
using LPPool3dOptions = LPPoolOptions<3>;
|
||||||
|
|
||||||
namespace functional {
|
namespace functional {
|
||||||
/// Options for `torch::nn::functional::lp_pool1d`.
|
/// Options for `torch::nn::functional::lp_pool1d`.
|
||||||
///
|
///
|
||||||
@ -569,5 +578,19 @@ namespace functional {
|
|||||||
using LPPool2dFuncOptions = LPPool2dOptions;
|
using LPPool2dFuncOptions = LPPool2dOptions;
|
||||||
} // namespace functional
|
} // namespace functional
|
||||||
|
|
||||||
|
namespace functional {
|
||||||
|
/// Options for `torch::nn::functional::lp_pool3d`.
|
||||||
|
///
|
||||||
|
/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what
|
||||||
|
/// arguments are supported.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// namespace F = torch::nn::functional;
|
||||||
|
/// F::lp_pool3d(x, F::LPPool3dFuncOptions(2, {2, 3, 4}).stride(2));
|
||||||
|
/// ```
|
||||||
|
using LPPool3dFuncOptions = LPPool3dOptions;
|
||||||
|
} // namespace functional
|
||||||
|
|
||||||
} // namespace nn
|
} // namespace nn
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|||||||
@ -429,5 +429,16 @@ Tensor LPPool2dImpl::forward(const Tensor& input) {
|
|||||||
|
|
||||||
template class LPPoolImpl<2, LPPool2dImpl>;
|
template class LPPoolImpl<2, LPPool2dImpl>;
|
||||||
|
|
||||||
|
Tensor LPPool3dImpl::forward(const Tensor& input) {
|
||||||
|
return F::detail::lp_pool3d(
|
||||||
|
input,
|
||||||
|
options.norm_type(),
|
||||||
|
options.kernel_size(),
|
||||||
|
options.stride(),
|
||||||
|
options.ceil_mode());
|
||||||
|
}
|
||||||
|
|
||||||
|
template class LPPoolImpl<3, LPPool3dImpl>;
|
||||||
|
|
||||||
} // namespace nn
|
} // namespace nn
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|||||||
@ -25,6 +25,7 @@ template struct MaxUnpoolOptions<3>;
|
|||||||
|
|
||||||
template struct LPPoolOptions<1>;
|
template struct LPPoolOptions<1>;
|
||||||
template struct LPPoolOptions<2>;
|
template struct LPPoolOptions<2>;
|
||||||
|
template struct LPPoolOptions<3>;
|
||||||
|
|
||||||
} // namespace nn
|
} // namespace nn
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|||||||
@ -1022,6 +1022,33 @@ def max_unpool3d(
|
|||||||
return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)
|
return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)
|
||||||
|
|
||||||
|
|
||||||
|
def lp_pool3d(
|
||||||
|
input: Tensor, norm_type: Union[int, float],
|
||||||
|
kernel_size: BroadcastingList3[int],
|
||||||
|
stride: Optional[BroadcastingList3[int]] = None,
|
||||||
|
ceil_mode: bool = False
|
||||||
|
) -> Tensor:
|
||||||
|
r"""
|
||||||
|
Apply a 3D power-average pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
If the sum of all inputs to the power of `p` is
|
||||||
|
zero, the gradient is set to zero as well.
|
||||||
|
|
||||||
|
See :class:`~torch.nn.LPPool3d` for details.
|
||||||
|
"""
|
||||||
|
if has_torch_function_unary(input):
|
||||||
|
return handle_torch_function(
|
||||||
|
lp_pool3d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode
|
||||||
|
)
|
||||||
|
kd, kw, kh = utils._triple(kernel_size)
|
||||||
|
if stride is not None:
|
||||||
|
out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
|
||||||
|
else:
|
||||||
|
out = avg_pool3d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode)
|
||||||
|
|
||||||
|
return (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type)
|
||||||
|
|
||||||
|
|
||||||
def lp_pool2d(
|
def lp_pool2d(
|
||||||
input: Tensor, norm_type: Union[int, float],
|
input: Tensor, norm_type: Union[int, float],
|
||||||
kernel_size: BroadcastingList2[int],
|
kernel_size: BroadcastingList2[int],
|
||||||
|
|||||||
@ -127,6 +127,13 @@ def lp_pool2d(
|
|||||||
stride: Union[Optional[_size], Optional[int]] = ...,
|
stride: Union[Optional[_size], Optional[int]] = ...,
|
||||||
ceil_mode: bool = ...,
|
ceil_mode: bool = ...,
|
||||||
) -> Tensor: ...
|
) -> Tensor: ...
|
||||||
|
def lp_pool3d(
|
||||||
|
input: Tensor,
|
||||||
|
norm_type: float,
|
||||||
|
kernel_size: _size_3_t,
|
||||||
|
stride: Union[Optional[_size], Optional[int]] = ...,
|
||||||
|
ceil_mode: bool = ...,
|
||||||
|
) -> Tensor: ...
|
||||||
def adaptive_max_pool1d_with_indices(
|
def adaptive_max_pool1d_with_indices(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
output_size: _size,
|
output_size: _size,
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLos
|
|||||||
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
|
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
|
||||||
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
|
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
|
||||||
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
|
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
|
||||||
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
|
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, LPPool3d, \
|
||||||
AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
|
AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
|
||||||
from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
|
from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
|
||||||
LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
|
LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
|
||||||
@ -48,8 +48,8 @@ __all__ = [
|
|||||||
'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
|
'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
|
||||||
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
|
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
|
||||||
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
|
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
|
||||||
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
|
'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
|
||||||
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
|
'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
|
||||||
'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
|
'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
|
||||||
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
|
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
|
||||||
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
|
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
|
||||||
|
|||||||
@ -10,8 +10,8 @@ from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t,
|
|||||||
|
|
||||||
__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
|
__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
|
||||||
'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d',
|
'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d',
|
||||||
'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
|
'LPPool2d', 'LPPool3d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
|
||||||
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
|
'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
|
||||||
|
|
||||||
class _MaxPoolNd(Module):
|
class _MaxPoolNd(Module):
|
||||||
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
|
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
|
||||||
@ -953,8 +953,8 @@ class LPPool2d(_LPPoolNd):
|
|||||||
ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
|
ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
- Input: :math:`(N, C, H_{in}, W_{in})`
|
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
|
||||||
- Output: :math:`(N, C, H_{out}, W_{out})`, where
|
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
|
H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
|
||||||
@ -981,6 +981,64 @@ class LPPool2d(_LPPoolNd):
|
|||||||
self.stride, self.ceil_mode)
|
self.stride, self.ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
|
class LPPool3d(_LPPoolNd):
|
||||||
|
r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
|
||||||
|
|
||||||
|
On each window, the function computed is:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
|
||||||
|
|
||||||
|
- At p = :math:`\infty`, one gets Max Pooling
|
||||||
|
- At p = 1, one gets Sum Pooling (which is proportional to average pooling)
|
||||||
|
|
||||||
|
The parameters :attr:`kernel_size`, :attr:`stride` can either be:
|
||||||
|
|
||||||
|
- a single ``int`` -- in which case the same value is used for the height, width and depth dimension
|
||||||
|
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
|
||||||
|
the second `int` for the height dimension and the third `int` for the width dimension
|
||||||
|
|
||||||
|
.. note:: If the sum to the power of `p` is zero, the gradient of this function is
|
||||||
|
not defined. This implementation will set the gradient to zero in this case.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kernel_size: the size of the window
|
||||||
|
stride: the stride of the window. Default value is :attr:`kernel_size`
|
||||||
|
ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
||||||
|
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
|
||||||
|
:math:`(C, D_{out}, H_{out}, W_{out})`, where
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> # power-2 pool of square window of size=3, stride=2
|
||||||
|
>>> m = nn.LPPool3d(2, 3, stride=2)
|
||||||
|
>>> # pool of non-square window of power 1.2
|
||||||
|
>>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
|
||||||
|
>>> input = torch.randn(20, 16, 50, 44, 31)
|
||||||
|
>>> output = m(input)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel_size: _size_3_t
|
||||||
|
stride: _size_3_t
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
return F.lp_pool3d(input, float(self.norm_type), self.kernel_size,
|
||||||
|
self.stride, self.ceil_mode)
|
||||||
|
|
||||||
|
|
||||||
class _AdaptiveMaxPoolNd(Module):
|
class _AdaptiveMaxPoolNd(Module):
|
||||||
__constants__ = ['output_size', 'return_indices']
|
__constants__ = ['output_size', 'return_indices']
|
||||||
return_indices: bool
|
return_indices: bool
|
||||||
|
|||||||
@ -866,6 +866,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||||||
torch.nn.functional.logsigmoid: lambda input: -1,
|
torch.nn.functional.logsigmoid: lambda input: -1,
|
||||||
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||||
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||||
|
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
|
||||||
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
|
||||||
reduce=None, reduction='mean': -1),
|
reduce=None, reduction='mean': -1),
|
||||||
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
|
||||||
|
|||||||
@ -1467,6 +1467,11 @@ def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, t
|
|||||||
ModuleInput(
|
ModuleInput(
|
||||||
constructor_input=FunctionInput(2, 2, 2),
|
constructor_input=FunctionInput(2, 2, 2),
|
||||||
forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
|
forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
|
||||||
|
ModuleInput(
|
||||||
|
constructor_input=FunctionInput(2, 2, 2),
|
||||||
|
forward_input=FunctionInput(make_input((3, 7, 7))),
|
||||||
|
reference_fn=no_batch_dim_reference_fn,
|
||||||
|
desc='no_batch_dim'),
|
||||||
ModuleInput(
|
ModuleInput(
|
||||||
constructor_input=FunctionInput(1.5, 2),
|
constructor_input=FunctionInput(1.5, 2),
|
||||||
forward_input=FunctionInput(make_input((1, 3, 7, 7))),
|
forward_input=FunctionInput(make_input((1, 3, 7, 7))),
|
||||||
@ -1474,6 +1479,25 @@ def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, t
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||||
|
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||||
|
|
||||||
|
return [
|
||||||
|
ModuleInput(
|
||||||
|
constructor_input=FunctionInput(2, 2, 2),
|
||||||
|
forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))),
|
||||||
|
ModuleInput(
|
||||||
|
constructor_input=FunctionInput(2, 2, 2),
|
||||||
|
forward_input=FunctionInput(make_input((3, 7, 7, 7))),
|
||||||
|
reference_fn=no_batch_dim_reference_fn,
|
||||||
|
desc='no_batch_dim'),
|
||||||
|
ModuleInput(
|
||||||
|
constructor_input=FunctionInput(1.5, 2),
|
||||||
|
forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))),
|
||||||
|
desc='norm'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
|
def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
|
||||||
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||||
|
|
||||||
@ -3073,6 +3097,14 @@ module_db: List[ModuleInfo] = [
|
|||||||
device_type='mps',
|
device_type='mps',
|
||||||
),)
|
),)
|
||||||
),
|
),
|
||||||
|
ModuleInfo(torch.nn.LPPool3d,
|
||||||
|
module_inputs_func=module_inputs_torch_nn_LPPool3d,
|
||||||
|
skips=(
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
|
||||||
|
DecorateInfo(skipIfMps),)
|
||||||
|
),
|
||||||
ModuleInfo(torch.nn.MaxPool1d,
|
ModuleInfo(torch.nn.MaxPool1d,
|
||||||
module_inputs_func=module_inputs_torch_nn_MaxPool1d,
|
module_inputs_func=module_inputs_torch_nn_MaxPool1d,
|
||||||
skips=(
|
skips=(
|
||||||
|
|||||||
@ -115,6 +115,7 @@ nn_functional_tests = [
|
|||||||
('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
|
('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
|
||||||
('lp_pool1d', (S, S, S), (2., 3, 2,)),
|
('lp_pool1d', (S, S, S), (2., 3, 2,)),
|
||||||
('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
|
('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
|
||||||
|
('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
|
||||||
('adaptive_max_pool1d', (S, S, S), (5,)),
|
('adaptive_max_pool1d', (S, S, S), (5,)),
|
||||||
('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
|
('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
|
||||||
('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
|
('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
|
||||||
|
|||||||
Reference in New Issue
Block a user