Commit Graph

9 Commits

Author SHA1 Message Date
2460dced8f Add torch.nn.GELU for GELU activation (#28944)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28944

Add torch.nn.GELU for GELU activation

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GELU"

Reviewed By: hl475, houseroad

Differential Revision: D18240946

fbshipit-source-id: 6284b30def9bd4c12bf7fb2ed08b1b2f0310bb78
2019-11-03 21:55:05 -08:00
b5d15315d8 Improve C++ maxpool and avgpool (#26521)
Summary:
This PR makes the following improvements:
1. Add `forward_with_indices` method to all C++ MaxPool modules, to return the max indices along with the outputs. (We can't make two `forward` methods that return different types based on input, because that will break the type deduction of `torch::detail::return_type_of_forward_t`)
2. Add `max_poolNd_with_indices` to `torch::nn::functional`, to be used when indices of the max values are needed. (We can't merge this with `torch::nn::functional::max_poolNd` because the return type of `max_poolNd` has to be defined statically).
3. Improve `pretty_print` of C++ MaxPoolNd and AvgPoolNd modules to match the Python `extra_repr`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26521

Differential Revision: D17507358

Pulled By: yf225

fbshipit-source-id: b6c0e2b27b38378cdc0c75f4bfc797b3c6b17cd9
2019-09-25 13:52:58 -07:00
28a2dafc15 C++ Average Pool Module (#25800)
Summary:
This PR adds Average Pool module to C++ front-end.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25800

Differential Revision: D17318094

Pulled By: yf225

fbshipit-source-id: c914c0e802bbe5f1d1f0a21a669c28bc956899db
2019-09-11 16:39:56 -07:00
ba9fda14a7 C++ MaxPool Module
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24860

Differential Revision: D17260361

Pulled By: yf225

fbshipit-source-id: 4b8c894d3bdf675cfeb9fc84934fe0339a048c1e
2019-09-11 08:56:57 -07:00
e04836004d L1Loss module (#25902)
Summary:
yf225 This is L1Loss module. I don't think that ```_Loss``` and ```_WeightedLoss``` as base Python classes do anything. First one sets reduction type and also takes in ```reduce``` parameter which is deprecated. The second one only registers ```weight``` parameter. I don't think that we should keep this structure. What do you think?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25902

Differential Revision: D17307045

Pulled By: yf225

fbshipit-source-id: ad3eda2ee8dcf4465054b376c1be89b39d11532f
2019-09-11 07:18:17 -07:00
3680cef44e C++ Fold nn module
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24160

Differential Revision: D17260740

Pulled By: yf225

fbshipit-source-id: f0c7769316bed330289ca3d948f2e39c72ec928b
2019-09-10 13:19:37 -07:00
2fe8341aac Map module options between Python and C++ in API parity test (#25784)
Summary:
`torch.nn` modules in Python save their kwarg options directly as module object attributes, while `torch::nn` modules in C++ save their options inside the `options` field of the module object. This PR tries to map between these two (by using the newly added `options_args` list to discover options arguments in Python module), to make sure options equivalence is properly checked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25784

Differential Revision: D17238609

Pulled By: yf225

fbshipit-source-id: 2febd277ddcbe3ab458ac3feaaf93e4c94bb5b98
2019-09-06 15:30:36 -07:00
ef6ea545e8 Add Python/C++ API parity tracker for torch.nn (#25289)
Summary:
This PR adds Python/C++ API parity tracker at `test/cpp_api_parity/parity-tracker.md`, which currently shows parity status for `torch.nn` modules.

A good amount of line changes here is moving `new_criterion_tests` from `test_nn.py` to `common_nn.py`, so that it can be used in `test_cpp_api_parity.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25289

Differential Revision: D17188085

Pulled By: yf225

fbshipit-source-id: 33d12fb1a4de2d9147ed09380973f361a3981fdf
2019-09-04 19:46:33 -07:00
1bf1970fe2 Add Python/C++ torch.nn API parity test harness (#23852)
Summary:
This PR adds test harness for checking Python / C++ API parity for `torch.nn.Module` subclasses. Under the hood, we use JIT tracing to transfer `nn.Module` state from Python to C++, so that we can test initialization / forward / backward on Python / C++ modules with the same parameters and buffers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23852

Differential Revision: D16830204

Pulled By: yf225

fbshipit-source-id: 9b5298c0e8cd30e341a9f026e6f05604a82d6002
2019-08-26 08:02:25 -07:00