In code in TH, you will see references to scalar_t. In fact, all code in TH is compiled multiple times, once per supported scalar type, by repeatedly #including the generic source files with a different macro definition of scalar_t each time.
This is obviously horrible, and in ATen we do things more normally using templates. However, this doesn't mean you should go out and template everything. Instead, there is usually a critical loop inside the kernel which actually needs to be templated, and everything else can be compiled once, generically. At the point where this transition occurs, use `AT_DISPATCH_FLOATING_TYPES_AND_HALF` to dispatch from code which doesn't know at compile time what scalar_t is into code that does. For usage guidance, search the codebase for existing examples.
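As a rough illustration, here is a minimal sketch of what such a dispatch looks like (the function name `add_one_` is made up for this example; inside the lambda, the macro defines `scalar_t` to the concrete type selected at runtime):

```
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Outside the macro we don't know the element type; inside the lambda,
// scalar_t refers to the concrete type chosen by the runtime dispatch.
void add_one_(at::Tensor& self) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "add_one_", [&] {
    scalar_t* ptr = self.data_ptr<scalar_t>();
    for (int64_t i = 0; i < self.numel(); ++i) {
      ptr[i] = ptr[i] + scalar_t(1);
    }
  });
}
```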
### Dispatch and OpenMP
Many functions in THNN are parallelized in the following way:
```
int64_t batch;
#pragma omp parallel for private(batch)
for (batch = 0; batch < numBatch; ++batch) {
  THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
      input->data<scalar_t>() + batch * numPlanes * inputH * inputW,
      output->data<scalar_t>() + batch * numPlanes * outputH * outputW,
      THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
      randomSamples->data<scalar_t>() + batch * numPlanes * 2,
      numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
}
```
This snippet of code is particularly delicate to port, for a few reasons:
1. You need to do a dispatch before you can retrieve the appropriately typed scalar_t pointers, implying use of the AT_DISPATCH_ macro. However...
2. `#pragma omp` directives don't work inside lambdas...
3. You do NOT want to do a dynamic dispatch every time around the batch loop; you should dispatch once, before entering the tight loop.
The recommendation is to make a new helper function that contains the OpenMP pragma, and call it from inside the AT_DISPATCH_ macro. Extract the pointers once, outside of the loop (since they're the same on every iteration). Also, keep separate code paths for batch == 1 and batch > 1, because spawning parallelism for a single batch is pure overhead and hurts performance. A sketch of this structure is given below.
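A minimal sketch of that structure, assuming a made-up operator (`my_op`) and simplified shapes rather than any real THNN kernel:

```
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

namespace {

// Templated helper that holds the OpenMP pragma; scalar_t is already fixed
// here, so there is no dispatch inside the batch loop.
template <typename scalar_t>
void my_op_frame(
    const scalar_t* input_p,
    scalar_t* output_p,
    int64_t numBatch,
    int64_t planeSize) {
  if (numBatch == 1) {
    // Separate single-batch path: no point spinning up threads for one batch.
    for (int64_t i = 0; i < planeSize; ++i) {
      output_p[i] = input_p[i];
    }
    return;
  }
  int64_t batch;
#pragma omp parallel for private(batch)
  for (batch = 0; batch < numBatch; ++batch) {
    const scalar_t* in = input_p + batch * planeSize;
    scalar_t* out = output_p + batch * planeSize;
    for (int64_t i = 0; i < planeSize; ++i) {
      out[i] = in[i];
    }
  }
}

} // namespace

at::Tensor my_op(const at::Tensor& input) {
  auto output = at::empty_like(input);
  int64_t numBatch = input.size(0);
  int64_t planeSize = input.numel() / numBatch;
  // Dispatch once; pointers are also extracted once, outside the loop.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "my_op", [&] {
    my_op_frame<scalar_t>(
        input.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        numBatch,
        planeSize);
  });
  return output;
}
```

The key point is that both the dispatch and the pointer extraction happen exactly once, and the OpenMP pragma lives in an ordinary templated function rather than inside the lambda.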
## C++ torch API guidance
For the most part, the at::Tensor API matches the Python torch.Tensor API directly, so you can check the official docs to find out what is callable: https://pytorch.org/docs/stable/tensors.html
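For example (these particular calls are just illustrations picked here, not taken from the guide), method calls on `at::Tensor` read much like their Python counterparts:

```
#include <ATen/ATen.h>

// Rough C++ equivalents of x.sum(dim=1), x.unsqueeze(0), and
// x.to(torch.float64) from the Python API.
void example(const at::Tensor& x) {
  at::Tensor s = x.sum(1);
  at::Tensor u = x.unsqueeze(0);
  at::Tensor d = x.to(at::kDouble);
}
```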