[JIT SSA] Allow updating shape functions without recompilation (#83629)

In order to avoid extra round trips, and avoid confusion in places such as
this to manually pull in the latest copy of the shape_functions.py file

This also fixes the cases where people pull in the wrong version of the file. This can happen in cases such as when developers run `python setup.py install` instead of `python setup.py develop` to generate their current copy of Pytorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83629
Approved by: https://github.com/davidberard98
This commit is contained in:
John Clow
2022-08-17 15:07:57 -07:00
committed by PyTorch MergeBot
parent 53cda905be
commit eff28d61c9

View File

@ -1,12 +1,34 @@
#!/usr/bin/env python3
import importlib.util
import os
import sys
from itertools import chain
from pathlib import Path
from torch.jit._shape_functions import (
bounded_compute_graph_mapping,
shape_compute_graph_mapping,
)
# Manually importing the shape function module based on current directory
# instead of torch imports to avoid needing to recompile Pytorch before
# running the script
file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
module_name = "torch.jit._shape_functions"
err_msg = """Could not find shape functions file, please make sure
you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
raise Exception(err_msg)
spec = importlib.util.spec_from_file_location(module_name, file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None
spec.loader.exec_module(module)
bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
shape_compute_graph_mapping = module.shape_compute_graph_mapping
SHAPE_HEADER = r"""
/**