C++ API parity: Linear

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27382

Test Plan: Imported from OSS

Differential Revision: D17766735

Pulled By: pbelevich

fbshipit-source-id: c7a66daeb17550eb9a5d26944427723d4ebdc6c8
This commit is contained in:
Pavel Belevich
2019-10-24 07:09:33 -07:00
committed by Facebook Github Bot
parent 59402f51cf
commit dd277e9086
12 changed files with 156 additions and 41 deletions

View File

@ -393,7 +393,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
ASSERT_EQ(
c10::str(sequential),
"torch::nn::Sequential(\n"
" (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
" (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
" (2): torch::nn::Dropout(rate=0.5)\n"
" (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
@ -412,7 +412,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
ASSERT_EQ(
c10::str(sequential_named),
"torch::nn::Sequential(\n"
" (linear): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
" (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
" (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
" (dropout): torch::nn::Dropout(rate=0.5)\n"
" (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"