mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
This is the initial foreach map HOP for pointwise ops, which will be extended in the future to support grouped GEMMs and other ops. This PR uses the PrimHOPBase class to represent foreach_map as a HOP with a single subgraph.

The user API `foreach_map` takes a single pointwise torch op and internally calls a polyfill with the same semantics as a foreach op (i.e., it iterates over lists of operands, applying the op elementwise). The higher-order op is passed through the stack down to Inductor, where a lowering in essence inlines the subgraph into the main graph. It does this by interpreting the subgraph with a pointwise subgraph lowering, grouping the outputs by device, and registering the output buffers as foreach groups where applicable.

For testing, I was able to reuse the existing foreach tests by creating a wrapper function that matches the foreach op interfaces those tests expect, and then running all of the existing foreach tests on foreach_map.

TODO before landing:
* Add tests for general functions
* Test the warning for the case where an unsupported op will block fusion

Followups:
* Add tests for backwards (this will be a followup PR because backwards will require other work as well)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142098
Approved by: https://github.com/eellison
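
The foreach semantics described in this message (iterate over lists of operands and apply one op elementwise across them) can be illustrated with a minimal pure-Python sketch. The name `foreach_map_sketch` and the length check are assumptions made for illustration; this is not the polyfill from the PR, and it uses plain numbers rather than tensors so that it runs without PyTorch.

```python
import operator
from typing import Any, Callable, List, Sequence


def foreach_map_sketch(op: Callable[..., Any], *operand_lists: Sequence[Any]) -> List[Any]:
    """Hypothetical stand-in for a foreach-style polyfill.

    Applies `op` to the i-th element of every operand list and collects
    the results in a new list, one result per position.
    """
    if not operand_lists:
        raise ValueError("expected at least one operand list")

    # Assumption for this sketch: all operand lists have the same length.
    length = len(operand_lists[0])
    if any(len(operands) != length for operands in operand_lists):
        raise ValueError("all operand lists must have the same length")

    # zip(*operand_lists) yields one tuple of operands per list position.
    return [op(*operands) for operands in zip(*operand_lists)]


# Usage: a binary op over two lists, and a unary op over one list.
xs = [1.0, 2.0, 3.0]
ys = [10.0, 20.0, 30.0]
summed = foreach_map_sketch(operator.add, xs, ys)
magnitudes = foreach_map_sketch(abs, [-1, 2, -3])
```

Each output position depends only on the operands at that same position, which is the property that lets a compiler treat the per-element computations as independent work.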