Add tests to check pretty print when padding is a string in C++ API (#153126)

Currently there are no tests to verify the behaviour of pretty print when padding is `torch::kSame` or `torch::kValid`. This PR just adds this tests to check for future regressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153126
Approved by: https://github.com/Skylion007
This commit is contained in:
Alvaro-Kothe
2025-05-08 17:55:20 +00:00
committed by PyTorch MergeBot
parent d36261d2e6
commit e86b6b2a19

View File

@ -4591,6 +4591,15 @@ TEST_F(ModulesTest, PrettyPrintConv) {
ASSERT_EQ(
c10::str(Conv1d(3, 4, 5)),
"torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)");
{
auto options = Conv1dOptions(3, 4, 5);
ASSERT_EQ(
c10::str(Conv1d(options.padding(torch::kSame))),
"torch::nn::Conv1d(3, 4, kernel_size=5, stride=1, padding='same')");
ASSERT_EQ(
c10::str(Conv1d(options.padding(torch::kValid))),
"torch::nn::Conv1d(3, 4, kernel_size=5, stride=1, padding='valid')");
}
ASSERT_EQ(
c10::str(Conv2d(3, 4, 5)),
@ -4605,6 +4614,15 @@ TEST_F(ModulesTest, PrettyPrintConv) {
c10::str(Conv2d(options)),
"torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
}
{
auto options = Conv2dOptions(3, 4, std::vector<int64_t>{5, 6});
ASSERT_EQ(
c10::str(Conv2d(options.padding(torch::kSame))),
"torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 1], padding='same')");
ASSERT_EQ(
c10::str(Conv2d(options.padding(torch::kValid))),
"torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 1], padding='valid')");
}
ASSERT_EQ(
c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})),
@ -4630,6 +4648,15 @@ TEST_F(ModulesTest, PrettyPrintConv) {
"bias=false, "
"padding_mode=kCircular)");
}
{
auto options = Conv3dOptions(3, 4, std::vector<int64_t>{5, 6, 7});
ASSERT_EQ(
c10::str(Conv3d(options.padding(torch::kSame))),
"torch::nn::Conv3d(3, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1], padding='same')");
ASSERT_EQ(
c10::str(Conv3d(options.padding(torch::kValid))),
"torch::nn::Conv3d(3, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1], padding='valid')");
}
}
TEST_F(ModulesTest, PrettyPrintConvTranspose) {