mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86950
Approved by: https://github.com/Chillee

149 lines · 5.5 KiB · Python
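
"""Shape propagation that also tracks tensor aliasing and in-place mutation.

Runs ``ShapeProp`` over an FX graph while annotating each node with alias
and mutation metadata; ``has_mutation`` below uses this to tell whether a
graph module mutates any tensor (or, optionally, just its inputs)."""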
import copy
import functools
import itertools
import operator

import torch
from torch.fx.node import map_aggregate
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._pytree import tree_map

from .. import config
from ..utils import clone_inputs, fake_tensors_available

if fake_tensors_available:
    from torch._subclasses import FakeTensorMode  # noqa: F401

    from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor

class ShapeAliasingAndMutationProp(ShapeProp):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias groups that contain at least one graph input.
        self.input_alias_groups = set()
        # Maps a tensor's storage to the integer id of its alias group.
        self.storage_to_alias_group = dict()
        # Source of fresh alias-group ids.
        self.make_alias_group = itertools.count(1)

    def tensor_alias_group(self, value: torch.Tensor):
        """Assign a unique identifier to the storage of a given tensor"""
        storage = StorageWeakRef(value.storage())
        alias_group = self.storage_to_alias_group.get(storage)
        if alias_group is None:
            alias_group = next(self.make_alias_group)
            self.storage_to_alias_group[storage] = alias_group
        return alias_group

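    # Example (sketch, `prop` being an instance of this class): tensors that
    # share storage land in the same alias group, e.g. a view made by slicing:
    #   base = torch.randn(4)
    #   view = base[1:]                       # shares base's storage
    #   prop.tensor_alias_group(base) == prop.tensor_alias_group(view)  # True
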
    def placeholder(self, target, args, kwargs):
        value = super().placeholder(target, args, kwargs)
        assert isinstance(value, torch.Tensor)
        self.input_alias_groups.add(self.tensor_alias_group(value))
        return value

    def run_node(self, n: torch.fx.Node):
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        tensor_args = self.extract_tensors((args, kwargs))

        # Tensor._version is bumped by in-place ops, so comparing snapshots
        # taken before and after executing the node detects mutation.
        input_versions1 = [obj._version for obj in tensor_args]
        result = getattr(self, n.op)(n.target, args, kwargs)
        input_versions2 = [obj._version for obj in tensor_args]

        n.meta["type"] = type(result)
        n.meta["alias_groups"] = {
            self.tensor_alias_group(obj) for obj in self.extract_tensors(result)
        }

        # operator.setitem returns None, so attribute the write to the alias
        # group of the mutated target (the first tensor argument).
        if (
            not n.meta["alias_groups"]
            and n.op == "call_function"
            and n.target == operator.setitem
        ):
            n.meta["alias_groups"] = {self.tensor_alias_group(tensor_args[0])}

        n.meta["mutates_alias_groups"] = {
            self.tensor_alias_group(tensor)
            for tensor, v1, v2 in zip(tensor_args, input_versions1, input_versions2)
            if v1 != v2
        }
        # Partial mutation refers to mutation through a getitem result, which
        # can potentially change only a slice of the original tensor.
        n.meta["partial_mutation"] = False

        def visit_arg(arg: torch.fx.Node):
            if (
                arg.op == "call_function" and arg.target == operator.getitem
            ) or arg.meta["partial_mutation"]:
                if bool(n.meta["mutates_alias_groups"] & arg.meta["alias_groups"]):
                    n.meta["partial_mutation"] = True

        torch.fx.map_arg((n.args, n.kwargs), visit_arg)
        n.meta["is_input_alias"] = bool(
            self.input_alias_groups & n.meta["alias_groups"]
        )
        n.meta["is_input_mutation"] = bool(
            self.input_alias_groups & n.meta["mutates_alias_groups"]
        )
        n.meta["is_mutation"] = bool(n.meta["mutates_alias_groups"])
        tensors = self.extract_tensors(result)
        n.meta["tensor_metas"] = [_extract_tensor_metadata(obj) for obj in tensors]
        if tensors:
            n.meta["device"] = tensors[0].device
            n.meta["dtype"] = tensors[0].dtype

        return result

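    # Sketch of the versioning signal used above (plain PyTorch, no FX):
    #   t = torch.ones(3)
    #   v0 = t._version
    #   t.add_(1)               # in-place op bumps the version counter
    #   assert t._version > v0
    # `_version` is a private counter, but it is the same mechanism autograd
    # uses to detect in-place modification of saved tensors.
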
    @staticmethod
    def extract_tensors(result):
        """Return a flat list of tensors found in some nested data structure"""
        seen = set()
        tensors = []

        def visit(obj):
            if isinstance(obj, torch.Tensor) and id(obj) not in seen:
                seen.add(id(obj))
                tensors.append(obj)

        map_aggregate(result, visit)
        return tensors

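    # Example (sketch): nested containers are flattened and repeated
    # occurrences of the same tensor object are deduplicated:
    #   t = torch.zeros(2)
    #   ShapeAliasingAndMutationProp.extract_tensors(((t, t), [t, 1]))  # -> [t]
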
    def run(self, *args):
        try:
            super().run(*args)
        finally:
            # Cleanup: drop the interpreter's node -> value environment so
            # tensors produced during propagation can be garbage collected.
            self.env.clear()


def has_mutation(gm, example_inputs, inputs_only=False):
    """Check if the graph module has any form of mutation. If inputs_only is
    true, we only check for mutation of inputs."""
    # TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way.

    # Clone the inputs so that intermediate (non-leaf) tensors with
    # requires_grad=True become requires_grad=False, avoiding runtime errors
    # like "leaf variable that requires grad is inplace modified".
    example_inputs = clone_inputs(example_inputs)
    if fake_tensors_available and config.fake_tensor_propagation:
        # Enter and immediately exit the context manager; this is just a way
        # to construct a FakeTensorMode instance to reuse below.
        with FakeTensorMode() as fake_mode:
            pass
        fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=fake_mode)
        example_inputs = tree_map(fake_wrapper, example_inputs)
        new_gm = deepcopy_to_fake_tensor(gm, fake_mode)
        with fake_mode.restore() if hasattr(fake_mode, "restore") else fake_mode:
            ShapeAliasingAndMutationProp(new_gm).run(*example_inputs)
    else:
        # Without fake tensors, run the propagation on deep copies so the
        # original module and inputs are not mutated.
        new_gm = copy.deepcopy(gm)
        example_inputs = copy.deepcopy(example_inputs)
        ShapeAliasingAndMutationProp(new_gm).run(*example_inputs)

    for node in new_gm.graph.nodes:
        if node.meta["is_mutation"] or node.meta["is_input_mutation"]:
            if inputs_only:
                if node.meta["is_input_alias"]:
                    return True
            else:
                return True
    return False
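
# Usage sketch (hypothetical module, shown for illustration only):
#
#   class M(torch.nn.Module):
#       def forward(self, x):
#           x.add_(1)  # in-place mutation of a graph input
#           return x
#
#   gm = torch.fx.symbolic_trace(M())
#   has_mutation(gm, [torch.randn(3)])                      # True
#   has_mutation(gm, [torch.randn(3)], inputs_only=True)    # True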