Our prevailing strategy for symbolic shapes in C++ is to write only the SymInt version of the code, and pay a slight performance tax from not knowing whether it is symbolic or not. However, there are some fastpath functions where this tax is unacceptable, and we want to specialize for the int case. Sometimes it is easy to template the function; but when the function involves Tensors, it is not, because the functions you may want to call are not templated, e.g., t.view vs t.view_symint.

This PR adds an at::symint:: namespace which contains templated functions for all functions in PyTorch which you can use in this way. To show this works, I refactored sum_to to stop incorrectly reinterpret casting and instead use a template. Instead of t.sizes(), we call at::symint::sizes<T>(t), and so forth.

The template functions are SFINAE'd using a template argument that is not otherwise used. As such, deduction is impossible. Typically, deduction is hard anyway, because many of the constructors are ambiguous (this is why we split foo and foo_symint in the first place). So you must pass a template argument to these functions.

These functions are codegened into Functions.h, so they are subject to per-operator headers. This matters most for methods, which likely didn't include the per-operator header, so you will have to add an include in that case. We never generate method variants for these.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86329
Approved by: https://github.com/bdhirsh, https://github.com/voznesenskym