Use torch:: instead of at:: (#8911)

Summary:
This PR is the final step in making `torch::` the only namespace users of the C++ API ever see. Basically, I did:

``` cpp
namespace torch {
using namespace at;
}
```

And then changed `at::` to `torch::` almost everywhere. This worked surprisingly well out of the box. So users can now write `torch::relu`, `torch::log_softmax`, and `torch::conv2d` instead of having to know when to use `at::` and when to use `torch::`. This makes me happy!
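
For illustration only (this snippet is not part of the PR's diff), here is a minimal sketch of what the unified namespace looks like from the user's side, calling the free functions mentioned above through `torch::`:

``` cpp
#include <torch/torch.h>

int main() {
  // These free functions used to require knowing whether they lived in
  // at:: or torch::; after this change, torch:: works for all of them.
  auto input  = torch::randn({1, 3, 8, 8});
  auto weight = torch::randn({4, 3, 3, 3});

  auto conv   = torch::conv2d(input, weight);
  auto act    = torch::relu(conv);
  auto scores = torch::log_softmax(torch::randn({2, 10}), /*dim=*/1);
  return 0;
}
```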

Another thing I did was add `using Dtype = at::ScalarType`, since `Dtype` will be the eventual name anyway.
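
As a sketch of how that alias reads in user code (again, not taken from this PR's diff), `torch::Dtype` can be passed anywhere a scalar type is expected:

``` cpp
#include <torch/torch.h>
#include <type_traits>

int main() {
  // torch::Dtype is just an alias for at::ScalarType.
  static_assert(
      std::is_same<torch::Dtype, at::ScalarType>::value,
      "Dtype is an alias for ScalarType");

  torch::Dtype dtype = torch::kFloat32;
  auto t = torch::ones({2, 3}, dtype); // usable wherever a dtype is expected
  return 0;
}
```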

ebetica ezyang apaszke zdevito
Closes https://github.com/pytorch/pytorch/pull/8911

Reviewed By: ezyang

Differential Revision: D8668230

Pulled By: goldsborough

fbshipit-source-id: a72ccb70fca763c396c4b0997d3c4767c8cf4fd3
Author: Peter Goldsborough
Date: 2018-06-27 14:34:06 -07:00
Committed by: Facebook Github Bot
Parent: 4c5192788b
Commit: fef9a66d08

24 changed files with 359 additions and 343 deletions


``` diff
@@ -39,14 +39,14 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
   auto bs = 16U;
   auto nlen = 5U;
-  const auto backend = cuda ? at::kCUDA : at::kCPU;
-  auto inp = at::rand({nlen, bs, 1}, backend).round().toType(torch::kFloat32);
-  auto lab = inp.sum(0);
+  const auto backend = cuda ? torch::kCUDA : torch::kCPU;
+  auto inputs =
+      torch::rand({nlen, bs, 1}, backend).round().toType(torch::kFloat32);
+  auto labels = inputs.sum(0).detach();
+  inputs.set_requires_grad(true);
-  auto x = torch::autograd::make_variable(inp, /*requires_grad=*/true);
-  auto y = torch::autograd::make_variable(lab);
-  x = forward_op(x);
-  torch::Tensor loss = at::mse_loss(x, y);
+  auto outputs = forward_op(inputs);
+  torch::Tensor loss = torch::mse_loss(outputs, labels);
   optimizer.zero_grad();
   loss.backward();
```
``` diff
@@ -84,7 +84,7 @@ TEST_CASE("rnn") {
   SECTION("lstm") {
     SECTION("sizes") {
       LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
-      auto x = torch::randn({10, 16, 128}, at::requires_grad());
+      auto x = torch::randn({10, 16, 128}, torch::requires_grad());
       auto output = model->forward(x);
       auto y = x.mean();
```
``` diff
@@ -112,7 +112,7 @@ TEST_CASE("rnn") {
    }
  }
-  auto x = torch::empty({3, 4, 2}, at::requires_grad());
+  auto x = torch::empty({3, 4, 2}, torch::requires_grad());
   float size = x.data().numel();
   auto p = static_cast<float*>(x.data().storage()->data());
   for (size_t i = 0; i < size; i++) {
```
``` diff
@@ -188,7 +188,8 @@ TEST_CASE("rnn_cuda", "[cuda]") {
   SECTION("sizes") {
     LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
     model->cuda();
-    auto x = torch::randn({10, 16, 128}, at::requires_grad().device(at::kCUDA));
+    auto x = torch::randn(
+        {10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
     auto output = model->forward(x);
     auto y = x.mean();
```