Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
# Issue

This PR cleans up an edge case that wasn't handled by https://github.com/pytorch/pytorch/pull/137243. The existing tiling code assumes that `node.get_ranges()` is a reliable source of pointwise and reduction numels. This is true for pointwise kernels, but the situation is more complicated with reductions. Since reductions change the number of elements in a tensor, not all ops within a reduction kernel will have the same number of iterations. For example, `var_mean` fuses pointwise division with the output of reduction sum, and the division lacks the corresponding reduction ranges.

# Fix

Instead of getting numels from `node.get_ranges()`, explicitly pass the global pointwise and reduction numels to the relevant tiling functions. In `SIMDKernel.complete_partial_tiling`, we solve for the missing numel by dividing the global numel by the partial tiling's numel. This ensures all tilings have the correct global numel.

Also, in `SIMDKernel.is_compatible`, add the global reduction numel to node ranges that are missing it. For example, `{"x": 8, "r0_": 8}` is compatible with a node of ranges `([8], [])` when we have `reduction_numel=8`.

Finally, this PR generalizes some of the existing codegen to handle multiple reduction dims. We already had code to ignore reduction splits for pointwise kernels, but it only worked for 1D reductions. Now it can handle ND.

# Test plan

This PR parametrizes the existing CI test for `var_mean` to also run with tiled reductions. It also adds a new test checking that `var_mean` generates 2D tilings (with tiled reduction enabled). These new tests would fail on the current main branch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144041
Approved by: https://github.com/jansel
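As a rough illustration of the numel bookkeeping described under "Fix", here is a minimal standalone sketch. It is not the Inductor implementation. The two functions below are hypothetical stand-ins that borrow the names `complete_partial_tiling` and `is_compatible`. They assume plain integer sizes, tilings stored as ordered dicts, reduction dims whose names start with `r`, and an exact-match compatibility check with no dimension splitting.

```python
from math import prod


def complete_partial_tiling(partial, numel, reduction_numel):
    """Hypothetical sketch: fill in the one missing side of a partial tiling.

    `partial` holds either only pointwise dims (e.g. {"x": 8}) or only
    reduction dims (e.g. {"r0_": 8}). The missing numel is the global numel
    divided by the numel the partial tiling already covers.
    """
    total = numel * reduction_numel
    covered = prod(partial.values())
    if total % covered != 0:
        raise ValueError("partial tiling does not divide the global numel")
    missing = total // covered

    # Assumption: reduction dims are the ones whose names start with "r".
    has_reduction_dims = all(name.startswith("r") for name in partial)
    if has_reduction_dims:
        return {"x": missing, **partial}
    return {**partial, "r0_": missing}


def is_compatible(tiling, node_ranges, reduction_numel):
    """Hypothetical sketch: exact-match check against a node's ranges.

    A node fused into a reduction kernel may lack reduction ranges, so the
    global reduction numel is appended before comparing.
    """
    pointwise, reduction = node_ranges
    if not reduction and reduction_numel != 1:
        reduction = [reduction_numel]
    return list(tiling.values()) == [*pointwise, *reduction]


tiling = complete_partial_tiling({"x": 8}, numel=8, reduction_numel=8)
compatible = is_compatible(tiling, ([8], []), reduction_numel=8)
```

With these inputs, `tiling` is `{"x": 8, "r0_": 8}` and the pointwise-only node with ranges `([8], [])` is reported as compatible, which mirrors the example given in the description.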
Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers. These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs. However, there are a few places
in torch/csrc where we violate this abstraction. They are marked with
a pointer to this note. Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.