Performance optimizations to proxy tensor (#85049)

- Lazily allocate FX nodes for size/stride accessors on proxy tensors
  (see the sketch below)
- Properly track derived computations on strides/numel/etc.
- Remove the unnecessary tree_map at the end of proxy tensor tracing
  that checked invariants; it is too expensive, so we will just have to
  be careful
- Avoid tree_map in sym proxy tracing
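
A minimal sketch of the lazy-allocation pattern, assuming only what the
proxy_tensor.py hunk below shows: the proxy slot for a SymInt now holds a
zero-argument thunk that creates the FX node on first call and memoizes it,
so an accessor that is never consumed never emits a node.

```python
# Illustrative sketch, not the PyTorch source; make_node stands in for
# something like `lambda: torch.ops.aten.sym_size(proxy, i)`.
def make_lazy_proxy(make_node):
    proxy = None

    def thunk():
        nonlocal proxy
        if proxy is None:
            proxy = make_node()  # FX node is allocated only on first use
        return proxy

    return thunk

# Consumers now call the slot instead of reading it directly:
#   get_proxy_slot(sym, tracer)()   # note the trailing ()
```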

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85049
Approved by: https://github.com/wconstab
Commit: 00ce302c07 (parent: d49943bda8)
Author: Edward Z. Yang
Date: 2022-09-14 22:00:50 -07:00
Committed by: PyTorch MergeBot

7 changed files with 63 additions and 42 deletions

aten/src/ATen/core/TensorBase.h

@@ -302,6 +302,10 @@ class TORCH_API TensorBase {
return impl_->sym_numel();
}
c10::SymInt sym_storage_offset() const {
return impl_->sym_storage_offset();
}
// Length of one array element in bytes. This is the traditional
// Numpy naming.
size_t itemsize() const {

test/test_proxy_tensor.py

@@ -800,9 +800,7 @@ def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
sym_size_1 = torch.ops.aten.sym_size(empty, 0)
detach = torch.ops.aten.detach.default(empty); empty = None
sym_size_2 = torch.ops.aten.sym_size(detach, 0)
return detach""")
def test_symint_to_tensor(self):
@@ -814,7 +812,6 @@ def forward(self, a_1):
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0)
div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
sym_size_1 = torch.ops.aten.sym_size(div, 0)
return div""")
r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
@@ -825,7 +822,6 @@ def forward(self, a_1):
lt = sym_size < 0
eq = sym_size == sym_size; sym_size = None
div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
sym_size_1 = torch.ops.aten.sym_size(div, 0)
return div""")
def test_cat(self):

torch/csrc/autograd/python_variable_methods.cpp

@@ -213,7 +213,7 @@ static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args)
return handle_torch_function(self_, "storage_offset");
}
auto& self = THPVariable_Unpack(self_);
return wrap(self.storage_offset());
return py::cast(self.sym_storage_offset()).release().ptr();
END_HANDLE_TH_ERRORS
}
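
A hedged note on the observable change above: storage_offset() now routes
through sym_storage_offset(), so under symbolic tracing it can yield a
SymInt; ordinary eager tensors still observe a plain integer.

```python
import torch

# Eager behavior is unchanged: a concrete int comes back.
x = torch.empty(8)[2:]       # view starting 2 elements into storage
print(x.storage_offset())    # 2

# Under make_fx(..., tracing_mode="symbolic") the same accessor can
# return a symbolic value backed by the ShapeEnv (inferred from the
# sym_storage_offset plumbing in this commit).
```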

torch/csrc/jit/mobile/promoted_prim_ops.cpp

@@ -67,12 +67,22 @@ void sym_size_int(Stack& stack) {
auto t = pop(stack).toTensor();
push(stack, t.sym_sizes()[dim]);
}
void sym_stride_int(Stack& stack) {
auto dim = pop(stack).toInt();
auto t = pop(stack).toTensor();
push(stack, t.sym_strides()[dim]);
}
void sym_numel(Stack& stack) {
auto t = std::move(pop(stack)).toTensor();
push(stack, t.sym_numel());
}
void sym_storage_offset(Stack& stack) {
auto t = std::move(pop(stack)).toTensor();
push(stack, t.sym_storage_offset());
}
void sym_stride(Stack& stack) {
auto t = std::move(pop(stack)).toTensor();
pack(stack, t.sym_strides().vec());

torch/csrc/jit/mobile/promoted_prim_ops.h

@@ -23,8 +23,12 @@ void sym_size(Stack& stack);
void sym_size_int(Stack& stack);
void sym_stride_int(Stack& stack);
void sym_numel(Stack& stack);
void sym_storage_offset(Stack& stack);
void sym_stride(Stack& stack);
void device(Stack& stack);

torch/csrc/jit/runtime/register_prim_ops.cpp

@@ -420,6 +420,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
"aten::sym_size.int(Tensor self, int dim) -> SymInt"),
sym_size_int,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::sym_stride.int(Tensor self, int dim) -> SymInt"),
sym_stride_int,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::stride(Tensor self) -> int[]"),
[](Stack& stack) {
@@ -431,6 +436,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
TORCH_SELECTIVE_SCHEMA("aten::sym_numel(Tensor self) -> SymInt"),
sym_numel,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::sym_storage_offset(Tensor self) -> SymInt"),
sym_storage_offset,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::sym_stride(Tensor self) -> SymInt[]"),
sym_stride,
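
A hedged usage sketch: these registrations expose the symbolic accessors
through the operator registry, and the tracing code in proxy_tensor.py
below calls them as torch.ops.aten.sym_size / torch.ops.aten.sym_stride.
On an eager tensor the results are concrete:

```python
import torch

t = torch.empty(2, 3)
# Same ops that track_tensor emits into the graph; on a plain eager
# tensor they reduce to ordinary integers (behavior inferred from the
# schemas registered above).
print(torch.ops.aten.sym_size(t, 0))    # 2
print(torch.ops.aten.sym_stride(t, 1))  # 1
```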

torch/fx/experimental/proxy_tensor.py

@@ -21,7 +21,6 @@ import weakref
from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
import torch.fx.experimental.symbolic_shapes as symbolic_shapes
from torch.fx import Proxy
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy"]
@@ -107,23 +106,33 @@ def set_meta(proxy, val):
def track_tensor(tensor, proxy, *, constant, tracer):
def try_set_proxy_slot(outer_s, proxy_callable, *args):
assert callable(proxy_callable)
if isinstance(outer_s, SymInt):
inner_s = outer_s.get_pyobj()
assert isinstance(inner_s, PySymInt)
proxy = None
def thunk():
nonlocal proxy
if proxy is None:
proxy = proxy_callable(inner_s, *args)
return proxy
set_proxy_slot(inner_s, tracer, thunk)
# The basic idea is that we need to associate each tensor/SymInt
# with a Proxy. How do we setup this association? We just store
# the proxy on the proxy slot of the object, keyed on the tracer
# (so that if we have multiple tracers at the same time, they
# don't clobber each other.)
for i, s in enumerate(tensor.shape):
if isinstance(s, SymInt):
inner_s = s.get_pyobj()
assert isinstance(inner_s, PySymInt)
# TODO: improve naming
# TODO: lazily insert this into the graph only on first
# use? Maybe complicated and DCE is a better idea
s_proxy = torch.ops.aten.sym_size(proxy, i)
set_meta(s_proxy, inner_s)
set_proxy_slot(inner_s, tracer, s_proxy)
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i)
# TODO: also do stride/numel
for i, s in enumerate(tensor.stride()):
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i)
try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x))
try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x))
set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
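
One detail worth flagging in the hunk above: try_set_proxy_slot threads the
loop index through as an explicit argument instead of letting the lambdas
close over it. Python closures late-bind loop variables, so every thunk
would otherwise observe the final index. A standalone demonstration of the
pitfall:

```python
# Late binding: all three closures share one `i`, seen after the loop.
thunks = [lambda: i for i in range(3)]
print([t() for t in thunks])        # [2, 2, 2]

# Capturing the current value per iteration, which is what passing `i`
# through try_set_proxy_slot's *args achieves:
thunks = [lambda i=i: i for i in range(3)]
print([t() for t in thunks])        # [0, 1, 2]
```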
@@ -177,7 +186,7 @@ def fetch_sym_proxy(tracer):
return n.constant
else:
# NB: we REQUIRE all symints to be tracked
return get_proxy_slot(n, tracer)
return get_proxy_slot(n, tracer)()
return inner
@@ -204,14 +213,13 @@ def proxy_call(proxy_mode, func, args, kwargs):
# Some of these are not "real" aten ops and will fail if we
# call _dispatch_has_kernel_for_dispatch_key on them.
# This list is probably incomplete
if func not in [torch.ops.aten.size.default]:
if func not in [torch.ops.aten.size.default, torch.ops.aten.sym_storage_offset.default]:
with proxy_mode.restore():
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
tracer = proxy_mode.tracer
f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))
# If there are SymInts, we also should not consider this constant.
@@ -435,24 +443,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
if not self.enable_tracing:
return func(*args, **kwargs)
if symbolic_shapes.is_symbolic_op(func):
with self.restore():
return symbolic_shapes.handle_symbolic_op(func, args, kwargs)
if func in [prim.device.default]:
return func(*args, **kwargs)
out = proxy_call(self, func, args, kwargs)
def assert_proxy_tensor(e):
assert has_proxy_slot(e, self.tracer), \
f"Internal Error: make_fx is incorrectly baking a tensor constant into the graph: {str(e)}"
# When we trace factory functions, we expect that tensor outputs are *always* tracked.
# (Except for torch.tensor() constants handled through lift(), which is handled
# specially further up).
pytree.tree_map_only(torch.Tensor, assert_proxy_tensor, out)
return out
return proxy_call(self, func, args, kwargs)
SymInt = torch.SymIntNode
@@ -480,21 +474,24 @@ class ProxySymDispatchMode(SymDispatchMode):
def __sym_dispatch__(self, func, types, args, kwargs):
if not self.enable_tracing:
return func(*args, **kwargs)
p_args, p_kwargs = pytree.tree_map_only(
(PySymInt, PySymFloat),
lambda s: get_proxy_slot(s, self.tracer) if s.constant is None else s.constant,
(args, kwargs)
# For speed, we assume there are no nested data structures
# (otherwise we could use tree_map)
# We also assume there are no keyword arguments.
assert not kwargs
n_args = tuple(
get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant
if isinstance(a, (PySymInt, PySymFloat)) else a
for a in args
)
# func doesn't have a __torch_function__ that Proxy can interpose, so
# we gotta do it manually
n_args, n_kwargs = pytree.tree_map_only(fx.Proxy, lambda p: p.node, (p_args, p_kwargs))
n_out = self.tracer.create_node("call_function", func, n_args, n_kwargs)
n_out = self.tracer.create_node("call_function", func, n_args, {})
p_out = fx.Proxy(n_out, self.tracer)
out = func(*args, **kwargs)
set_meta(p_out, out)
assert isinstance(out, (PySymInt, PySymFloat)), f"{func}(*{args}, **{kwargs}) = {out}"
set_proxy_slot(out, self.tracer, p_out)
set_proxy_slot(out, self.tracer, lambda: p_out)
return out
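
The final hunk trades pytree.tree_map_only for a plain comprehension. A
minimal model of why that is safe here, under the assumptions the code
itself asserts (flat positional args, no kwargs); convert_args_flat is a
hypothetical name, not a PyTorch API:

```python
def convert_args_flat(args, is_sym, to_node):
    # One pass over a flat tuple. No recursion into nested containers,
    # which is exactly the tree_map overhead being avoided.
    return tuple(to_node(a) if is_sym(a) else a for a in args)
```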