Using the same repro from the issue (but with BatchNorm2D)
Rectifies native_batch_norm schema by splitting the schema into 2:
1. one will have NON-optional alias-able running_mean and running_var inputs
2. the other will just not have those parameters at all (no_stats variation)
**Calling for name suggestions!**
## test plan
I've added tests in test_functionalization.py as well as an entry in common_method_invocations.py for `native_batch_norm_legit`
CI should pass.
## next steps
Because of bc/fc reasons, we reroute native_batch_norm to call our new schemas ONLY through the python dispatcher, but in 2 weeks or so, we should make `native_batch_norm_legit` the official batch_norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88697
Approved by: https://github.com/albanD
The assert check are moved to top and the function now returns out. This is needed by the downstream torch-mlir project to correctly determine the output type.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85801
Approved by: https://github.com/eellison
The current implementation of the `sum_mean_dim` shape function
takes `dim=[]` and `dim=None` to mean "no reduction". However, in the
ops `torch.sum` and `torch.mean`, both `dim=[]` and `dim=None` are
equivalent to "reduce along all dimensions". This commit fixes the
handling of `dim` in the `sum_mean_dim` shape function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83357
Approved by: https://github.com/Gamrix
I think the person who edited `mean.dim` edited the python and the associated Torchscript version manually in two different ways. This diff fixes that all up. It also fixes the inconsistencies in the `torch.nonzero` generated file
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83092
Approved by: https://github.com/eellison
Moves jit shape function registration to python. Like jit decompositions, a script must be run after adding new definitions which serializes them in a c++ file.
This was a request so that torch-mlir could define functions in python and upstream their shape functions. cc @silvasean @makslevental
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75546
Approved by: https://github.com/davidberard98