C++ API quick code walkthrough

PyTorch, but in C++

#include <torch/torch.h>

torch::Tensor a = torch::ones({2, 2}, torch::requires_grad());
torch::Tensor b = torch::randn({2, 2});
auto c = a + b;
c.sum().backward(); // a.grad() now holds the gradient of c.sum() w.r.t. a.

Operators

Operators come straight from the at:: (ATen) namespace; the torch:: namespace does a using namespace at, so the same functions are visible under both names.

E.g., at::add and torch::add are the same thing, as the snippet below illustrates.
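A minimal sketch (not from the original page) of that equivalence; the tensor shape is an arbitrary choice:

#include <iostream>
#include <torch/torch.h>

int main() {
  torch::Tensor x = torch::ones({2, 2});
  // Both calls dispatch to the same underlying ATen operator.
  torch::Tensor y1 = at::add(x, x);
  torch::Tensor y2 = torch::add(x, x);
  std::cout << y1.equal(y2) << std::endl; // prints 1
}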

Modules

MNIST example: https://pytorch.org/cppdocs/frontend.html#end-to-end-example

C++ modules are not implemented the same way as they are in Python, but we try to reproduce their behavior and APIs as closely as possible.
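As a rough sketch, a custom module in the C++ frontend follows the pattern of the MNIST example linked above (the layer sizes here are illustrative, not prescribed):

#include <torch/torch.h>

struct Net : torch::nn::Module {
  Net() {
    // register_module makes the submodules visible to parameters(), to(), etc.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    return torch::log_softmax(fc2->forward(x), /*dim=*/1);
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr};
};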

Optimizers

Optimizer interface, with SGD as the example (sketch below).
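A minimal sketch of a training loop against the optimizer interface, using torch::optim::SGD (the model, data, and learning rate are made-up placeholders):

#include <torch/torch.h>

int main() {
  torch::nn::Linear model(4, 1);
  torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.01);

  auto input = torch::randn({8, 4});
  auto target = torch::randn({8, 1});

  for (int epoch = 0; epoch < 10; ++epoch) {
    optimizer.zero_grad();                              // clear old gradients
    auto loss = torch::mse_loss(model(input), target);  // forward pass + loss
    loss.backward();                                    // populate parameter .grad()s
    optimizer.step();                                   // apply the SGD update
  }
}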

Other utilities exist...

DataLoader: 5d82311f0d/torch/csrc/api/include/torch/data/dataloader.h (L19-L56). I'm not sure how different this is from the Python DataLoader, though.
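As a sketch of the data API, following the cppdocs end-to-end example (this assumes the MNIST files have already been downloaded to ./data):

#include <iostream>
#include <torch/torch.h>

int main() {
  // Dataset + transform, wrapped in a data loader that yields stacked batches.
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::MNIST("./data").map(
          torch::data::transforms::Stack<>()),
      /*batch_size=*/64);

  for (auto& batch : *data_loader) {
    // batch.data is roughly [64, 1, 28, 28]; batch.target is [64].
    std::cout << batch.data.sizes() << std::endl;
    break;
  }
}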

C++ Extensions

Read through: https://pytorch.org/tutorials/advanced/cpp_extension.html

Why?

  • Let's say you want to write a custom CPU or CUDA kernel for some operation in C++ and hook it up to the PyTorch frontend.
  • You can write your own setuptools Python extension, or you can use the PyTorch C++ extensions API.

There are really two ways to build an extension: ahead of time with setuptools (torch.utils.cpp_extension.CppExtension plus BuildExtension) or just in time with torch.utils.cpp_extension.load().

Things like TorchVision use C++ extensions to add new kernels in their packages.
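A minimal sketch of the C++ side of such an extension, following the pattern from the tutorial linked above (d_sigmoid is just an illustrative op):

#include <torch/extension.h>

// Derivative of the sigmoid, written with ATen ops so it works for any dtype/device.
torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

// TORCH_EXTENSION_NAME is supplied by the build (setuptools or torch.utils.cpp_extension.load).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("d_sigmoid", &d_sigmoid, "derivative of the sigmoid");
}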

Next

Unit 5: torch.nn - Modules Lab