Adds lowering for `aten.searchsorted`. This entails:
1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.
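For reference, here is a minimal usage sketch of the behavior that points 1 and 3 enable, written against the public `torch.searchsorted` API rather than code from this PR; the tensor values are purely illustrative:

```python
import torch

# Multi-dimensional, per-row boundary tensor that is NOT sorted; `sorter`
# supplies the int64 indices that would sort each row.
seq = torch.tensor([[9.0, 1.0, 5.0],
                    [2.0, 8.0, 6.0]])
sorter = seq.argsort(dim=-1)              # indices sorting each row
values = torch.tensor([[3.0], [7.0]])

out = torch.searchsorted(seq, values, sorter=sorter)
# out == tensor([[1], [2]]): insertion positions within each sorted row
```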
Closes #135873
Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
The scheduler searches for fusion opportunities by looking for common memory accesses. Two memory accesses are considered common not only when the buffer names match; they also require that
- the index formulas match
- the var_ranges match
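To make the matching criteria concrete, here is a toy illustration (hypothetical names, not inductor's actual `MemoryDep` class): two accesses to the same buffer with different index formulas do not count as common.

```python
from dataclasses import dataclass

# Toy stand-in for a memory access record: buffer name, symbolic index
# formula, and the iteration variable ranges it was derived under.
@dataclass(frozen=True)
class ToyAccess:
    buffer: str
    index: str                      # e.g. "x0 + 128*x1"
    var_ranges: tuple               # e.g. (("x0", 128), ("x1", 64))

a = ToyAccess("buf0", "x0 + 128*x1", (("x0", 128), ("x1", 64)))
b = ToyAccess("buf0", "x1 + 64*x0",  (("x0", 128), ("x1", 64)))
assert a != b  # same buffer, mismatched index formulas -> not a common access
```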
In this PR, I want to log all the fusion failures caused by mismatched index formulas or var_ranges. I also want to further categorize the failures. So far I have found the following failure categories
- rand_seed: the index for the rand seed access is an integer, and different accesses use different integer offsets
- different numel: this happens for the cat operation
- broadcast: e.g. kernel A writes a buffer which is broadcast and read by kernel B
- different loop orders: the main category we want inductor to be able to fuse
- different offset: happens when using a concatenated linear layer to project Q/K/V and then splitting the result. Each split points to the same buffer with a different offset.
- unknown
My hope is to make sure that, for the models I tested, no fusion failure falls into the unknown category, so that all the failures are well understood and categorized. Right now that holds for BertForMaskedLM (https://gist.github.com/shunting314/6dc2c903629d342fa63ba731a171adc2), DistillGPT2 (https://gist.github.com/shunting314/145176f2e850103c7fad4ad72f0e200e) and llm.c (https://gist.github.com/shunting314/cfc64a326312a889ba55f79bd47b2082).
For BertForMaskedLM, we found 82 instances of fusion failures, and the majority of them are due to different loop orders! Studying the logs a bit more can help us figure out where all these loop order mismatches come from in real models.
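As a rough illustration of the "different loop orders" category (a hedged sketch, not taken from the tested models), two reads of the same buffer through different index formulas can look like this:

```python
import torch

@torch.compile
def f(x):
    a = x.sin()        # reads x in its contiguous order
    b = x.t().cos()    # reads the same buffer through a transposed index formula
    return a, b

f(torch.randn(128, 256))
```

Whether these two kernels end up fused depends on whether the scheduler can reconcile the differing loop orders, which is exactly what the new logging is meant to surface.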
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124986
Approved by: https://github.com/eellison, https://github.com/jansel
dependencies.py tracks reads and writes, which are used to identify dependencies between buffers: i.e., if buffer X reads buffer Y, then X depends on Y. ops.bucketize() reads from an offsets tensor, so we should track that read in dependencies.py to model dependencies correctly. Since bucketize performs a binary search over the offsets tensor, the dependency is marked as a StarDep to indicate that the entire tensor is needed.
Use case: we find that jagged tensor dense_to_jagged ops - which use bucketize() to map jagged indices to dense indices - perform better if the bucketize() kernel is separated from the gather kernel. Previously, because bucketize() wasn't marked as reading anything, it would just get inlined.
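For context, a minimal example of the op in question (standard `torch.bucketize`, not code from this diff): each output element is found by binary-searching the whole boundaries tensor, which is why the read is modeled as a StarDep rather than a per-element dependency.

```python
import torch

boundaries = torch.tensor([0, 4, 8, 16])     # sorted bucket offsets
values = torch.tensor([1, 5, 9, 20])

# Each element of `values` is binary-searched against *all* of `boundaries`.
idx = torch.bucketize(values, boundaries)    # tensor([1, 2, 3, 4])
```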
Differential Revision: [D47422704](https://our.internmc.facebook.com/intern/diff/D47422704)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105102
Approved by: https://github.com/eellison