Fix: https://github.com/pytorch/xla/issues/8755

This PR introduces the `TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE` environment variable. When it is set, the functionalization kernels skip running the meta reference, which is otherwise used to propagate the expected sizes and strides. Currently, PyTorch/XLA doesn't actually propagate the correct strides to its tensors, and calling these meta functions was shown to incur significant overhead.

Running the minimal reproducer from the issue, we see a speedup of roughly 4.3x over the baseline:

- Baseline: 0.0747s
- `XLA_DISABLE_FUNCTIONALIZATION=1`: 0.0159s
- `TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1`: 0.0175s

In summary, this PR:

- Creates the `disable_meta_reference()` function, which checks whether the environment variable is set
- Modifies the codegen for functionalization kernels, adding a call to `disable_meta_reference()` to the appropriate conditions
- Creates a new bash function that runs `lazy/test_ts_opinfo.py` with the environment variable set

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148822
Approved by: https://github.com/bdhirsh
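The gating described above can be illustrated with a small Python sketch. It is not the PR's code: the real check and the generated kernels live in PyTorch's C++ and codegen sources and may differ. The helper `maybe_run_meta_reference()` and the `transpose_meta` function are hypothetical stand-ins for a generated functionalization kernel and its meta reference, and the rule that the variable counts as "set" only when equal to `"1"` is an assumption. The sketch also recomputes the speedup ratios from the numbers reported above.

```python
import os
from typing import Callable, Mapping, Optional, TypeVar

T = TypeVar("T")

# Environment variable name introduced by the PR.
ENV_VAR = "TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE"


def disable_meta_reference(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when the meta reference should be skipped.

    Assumption: the variable counts as set only when its value is "1".
    An explicit mapping can be passed in place of os.environ for testing.
    """
    env = os.environ if env is None else env
    return env.get(ENV_VAR) == "1"


def maybe_run_meta_reference(
    meta_fn: Callable[..., T], *args, env: Optional[Mapping[str, str]] = None
) -> Optional[T]:
    """Hypothetical stand-in for the check added to a functionalization kernel.

    Runs the meta reference (which computes expected sizes/strides) unless
    it has been disabled; returns None when the call is skipped.
    """
    if disable_meta_reference(env):
        return None
    return meta_fn(*args)


def speedup(baseline_s: float, candidate_s: float) -> float:
    """Ratio of baseline runtime to candidate runtime."""
    return baseline_s / candidate_s


if __name__ == "__main__":
    # Hypothetical meta function: output shape of a 2-D transpose.
    def transpose_meta(shape):
        return tuple(reversed(shape))

    print(maybe_run_meta_reference(transpose_meta, (2, 3), env={}))  # (3, 2)
    print(maybe_run_meta_reference(transpose_meta, (2, 3), env={ENV_VAR: "1"}))  # None
    print(round(speedup(0.0747, 0.0175), 1))  # 4.3
    print(round(speedup(0.0747, 0.0159), 1))  # 4.7
```

Passing the environment explicitly keeps the check testable without mutating the process environment; a production implementation would more likely read the variable once and cache the result.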