mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT SSA] Allow updating shape functions without recompilation (#83629)
In order to avoid extra round trips, and avoid confusion in places such as this to manually pull in the latest copy of the shape_functions.py file This also fixes the cases where people pull in the wrong version of the file. This can happen in cases such as when developers run `python setup.py install` instead of `python setup.py develop` to generate their current copy of Pytorch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83629 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
53cda905be
commit
eff28d61c9
@ -1,12 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
from torch.jit._shape_functions import (
|
||||
bounded_compute_graph_mapping,
|
||||
shape_compute_graph_mapping,
|
||||
)
|
||||
|
||||
# Manually importing the shape function module based on current directory
|
||||
# instead of torch imports to avoid needing to recompile Pytorch before
|
||||
# running the script
|
||||
|
||||
file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
|
||||
module_name = "torch.jit._shape_functions"
|
||||
|
||||
err_msg = """Could not find shape functions file, please make sure
|
||||
you are in the root directory of the Pytorch git repo"""
|
||||
if not file_path.exists():
|
||||
raise Exception(err_msg)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
assert spec is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
assert spec.loader is not None
|
||||
assert module is not None
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
|
||||
shape_compute_graph_mapping = module.shape_compute_graph_mapping
|
||||
|
||||
|
||||
SHAPE_HEADER = r"""
|
||||
/**
|
||||
|
Reference in New Issue
Block a user