This pr uses the coalescing information in generating a tiling. The previous tiling heuristic would have each dependency generate a tiling. Then, we sum up the score for each generated tiling, preferring any 2d tiling over the default. The new tiling heuristics scores each tiling by its global coalesced memory. This gives both a potentially better tiling (especially for more complicated, 3d patterns) as well as information we can use in generating block sizes.
In triton heuristics, for generating 3d tiled reductions, we take the same total block size that the 2d reduction would use, then distribute the block according to whichever block coalesces the most memory.
The motivating kernel is in https://github.com/pytorch/pytorch/issues/149982 which is a 32 element reduction. A smaller version of it is [here](https://gist.github.com/eellison/0fa9396f5479eb4dba09756e3bf6ff2a). We need to run this kernel once in the forward per linear layer on a contiguous tensor, and once in the backward on a transposed tensor.
While the contiguous kernel has coalesced accesses, and is performant on master, the transposed version accesses uncoalesced memory on main and is ~2.8x slower. See, this [full log](https://gist.github.com/eellison/fa644bfd9d0ae11dadb62e17a5d48a83) from the above repro. Now, with this PR, it is only ~1.15x slower. See the [updated log](https://gist.github.com/eellison/0b2b653309494d28cf7b48929a022075).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153751
Approved by: https://github.com/jansel
ghstack dependencies: #153723, #153730, #153748
Find variables that coalesce the reads and writes and score the total size. If uncoalesced memory expressions are found, look for additional tiling of variables which will coalesce memory accesses.
For instance - for the following expression: `(32*p0) // 2048`, tiling p0 by 64 will make this expression coalesced.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153748
Approved by: https://github.com/jansel
ghstack dependencies: #153723, #153730
In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153723
Approved by: https://github.com/jansel