# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union

import torch
import torch.fx as fx
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import _disabled_torch_function_impl
from torch.fx import Tracer, GraphModule
from torch.fx.node import map_aggregate

from .decompositions import decomposition_table
from .nnc_compile import nnc_compile


# Module-level flag flipped by pythonkey_decompose(); when True,
# __torch_dispatch__ routes ops through decomposition_table.
USE_DECOMPOSE = False


@contextmanager
def pythonkey_decompose():
    """Context manager that enables decompositions while tracing."""
    global USE_DECOMPOSE
    USE_DECOMPOSE = True
    try:
        yield USE_DECOMPOSE
    finally:
        USE_DECOMPOSE = False
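
# Usage sketch (assumes the traced function hits ops that have entries in
# `decomposition_table`; ops without entries trace through unchanged):
#
#   with pythonkey_decompose():
#       gm = make_fx(f)(*example_inputs)   # graph records decomposed ops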


class PythonTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem', 'proxy']

    @staticmethod
    def __new__(cls, elem, proxy):
        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
        # doesn't hold any memory (a meta tensor is generally the preferred
        # type of tensor to make a subclass from)...
        meta = elem.new_empty((0,))
        meta.set_(meta.storage(), 0, elem.size(), elem.stride())
        r = torch.Tensor._make_subclass(cls, meta, elem.requires_grad)

        # ...the real tensor is held as an element on the tensor.
        r.elem = elem
        r.proxy = proxy
        return r

    def __repr__(self):
        return f"PythonTensor({self.elem})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func in decomposition_table and USE_DECOMPOSE:
            return decomposition_table[func](*args, **kwargs)

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, PythonTensor) else e

        def unwrap_tensor(e):
            return e.elem if isinstance(e, PythonTensor) else e

        # Run the op twice: once on the fx proxies (recording a node into the
        # traced graph) and once on the real tensors (computing actual values).
        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
        proxy_out = func(*proxy_args, **proxy_kwargs)
        real_out = func(*pytree.tree_map(unwrap_tensor, args),
                        **pytree.tree_map(unwrap_tensor, kwargs))

        def wrap_with_proxy(e, idx):
            # Some ops (like native_batch_norm_backward) return undefined
            # tensors that get converted into None in Python. As the function
            # signature expects tensors, directly returning these Nones back
            # to C++ would error, so wrap a dummy empty tensor instead.
            if e is None:
                return PythonTensor(torch.empty(()), proxy_out[idx])
            return PythonTensor(e, proxy_out[idx]) if type(e) == torch.Tensor else e

        if isinstance(real_out, tuple):
            return tuple(wrap_with_proxy(e, idx) for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return [wrap_with_proxy(e, idx) for idx, e in enumerate(real_out)]
        elif isinstance(real_out, torch.Tensor):
            return PythonTensor(real_out, proxy_out)
        else:
            return real_out
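
# Sketch of the mechanism: wrapping a real tensor together with an fx proxy
# makes every subsequent aten op record itself into the fx graph while still
# producing real values, e.g. (hypothetical names, for illustration):
#
#   t = PythonTensor(torch.randn(3), some_fx_proxy)
#   y = t.sin()   # __torch_dispatch__ fires: the graph gains a sin node,
#   y.elem        # ...and y holds the eagerly computed result.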


class PythonKeyTracer(Tracer):
    def __init__(self):
        super().__init__()

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
                    args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        # Always inline modules instead of emitting call_module nodes.
        return forward(*args, **kwargs)

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        proxy = self.create_proxy('get_attr', n, (), {})
                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
                    return parameter_proxy_cache[n]
            return attr_val
        return attr_val

    # We need to do this so that parameters entering the `make_fx` context
    # are recorded as `get_attr` nodes (and have requires_grad set on them
    # correctly). I'm not actually sure if this is the right thing to do ...
    def create_arg(self, a: Any):
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f'_param_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        return super().create_arg(a)


def pythonkey_trace(root: Union[torch.nn.Module, Callable],
                    concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
    tracer = PythonKeyTracer()
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)
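
# Usage sketch: like fx.symbolic_trace, but with PythonKeyTracer semantics
# (submodules are inlined rather than kept as call_module nodes):
#
#   gm = pythonkey_trace(my_module)   # my_module: any torch.nn.Module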


def wrap_key(f, inps):
    flat_inps, inp_spec = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert len(flat_args) == len(flat_inps)
        # Pair each flattened fx proxy with the corresponding real input,
        # wrapping the two together in a PythonTensor.
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = PythonTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        # Unwrap outputs back to their proxies so the tracer records them.
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], PythonTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped


def make_fx(f):
    @functools.wraps(f)
    def wrapped(*args):
        # Trace with fx placeholders standing in for every argument; the real
        # values are threaded through via wrap_key.
        phs = pytree.tree_map(lambda x: fx.PH, args)
        t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs))
        return t

    return wrapped
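
# Usage sketch:
#
#   def f(x):
#       return torch.sin(x) + x
#   gm = make_fx(f)(torch.randn(3))   # an fx.GraphModule of aten-level ops
#   print(gm.graph)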


@dataclass(eq=True, frozen=True)
class TensorSpec:
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: torch.dtype
    device: torch.device


@dataclass(eq=True, frozen=True)
class ConcreteValueSpec:
    value: Any


@dataclass(eq=True, frozen=True)
class SpecializationKey:
    func: Callable
    specs: Tuple[Union[TensorSpec, ConcreteValueSpec], ...]


def get_spec(arg):
    if isinstance(arg, torch.Tensor):
        return TensorSpec(
            tuple(arg.shape),
            tuple(arg.stride()),
            arg.dtype,
            arg.device)
    return ConcreteValueSpec(arg)


def construct_specialization_key(f, args):
    flat_args, _ = pytree.tree_flatten(args)
    return SpecializationKey(f, tuple(get_spec(arg) for arg in flat_args))
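
# Because the spec dataclasses are frozen (hashable) with value equality, two
# calls whose inputs match on shape/stride/dtype/device yield equal keys:
#
#   k1 = construct_specialization_key(f, (torch.randn(2, 3),))
#   k2 = construct_specialization_key(f, (torch.randn(2, 3),))
#   assert k1 == k2   # same specialization: would be a cache hit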


nnc_jit_cache: Dict[Callable, Dict[SpecializationKey, Callable]] = {}


class RetrievalStatus(Enum):
    Success = 0
    UnknownFunc = 1
    UnknownSpecialization = 2


def retrieve_from_cache(f, key):
    if f not in nnc_jit_cache:
        return RetrievalStatus.UnknownFunc, None
    cache_for_f = nnc_jit_cache[f]
    if key not in cache_for_f:
        return RetrievalStatus.UnknownSpecialization, None
    return RetrievalStatus.Success, cache_for_f[key]


def add_to_cache(f, key, compiled_f):
    if f not in nnc_jit_cache:
        nnc_jit_cache[f] = {key: compiled_f}
    else:
        nnc_jit_cache[f][key] = compiled_f


def nnc_jit(f, static_argnums=None, skip_specialization=False):
    local_cache = None

    @functools.wraps(f)
    def compiled(*args):
        nonlocal local_cache, static_argnums
        if local_cache is not None and skip_specialization:
            return local_cache(*args)
        key = construct_specialization_key(f, args)
        status, compiled_f = retrieve_from_cache(f, key)
        if status is RetrievalStatus.Success:
            return compiled_f(*args)
        if status is RetrievalStatus.UnknownSpecialization:
            warnings.warn(
                f'Recompiling kernel for {f} due to new specialization. '
                'We recompile when we see inputs with new sizes/strides/'
                'dtype/device. Frequent recompilations can be bad for '
                'performance.',
                stacklevel=2)

        fx_model = make_fx(f)(*args)
        fx_model.graph.lint()
        if static_argnums is None:
            static_argnums = []
        if isinstance(static_argnums, int):
            static_argnums = [static_argnums]
        # Swap static arguments for dummy scalar tensors before handing the
        # args to nnc_compile.
        args = list(args)
        for idx in range(len(args)):
            if idx in static_argnums:
                args[idx] = torch.empty(())
        args = tuple(args)
        compiled_f = nnc_compile(fx_model, args)
        local_cache = compiled_f
        add_to_cache(f, key, compiled_f)
        return compiled_f(*args)

    return compiled
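
# Usage sketch:
#
#   def f(x):
#       return torch.sin(x) + torch.cos(x)
#   jitted = nnc_jit(f)
#   jitted(torch.randn(100))   # first call: traces f and compiles via NNC
#   jitted(torch.randn(100))   # same specialization key: cache hit
#   jitted(torch.randn(50))    # new shape: warns, then recompiles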


def make_nnc(f):
    @functools.wraps(f)
    def wrapped(*args):
        fx_model = make_fx(f)(*args)
        fx_model.graph.lint()
        compiled_f = nnc_compile(fx_model, args, get_loopnest=True)
        return compiled_f

    return wrapped
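
# Sketch: unlike nnc_jit, this skips the specialization cache and passes
# get_loopnest=True to nnc_compile, which (assuming that flag returns the NNC
# loop nest) is useful for inspecting the generated IR rather than running it.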