mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Performance optimizations to proxy tensor (#85049)
- Lazily allocate FX nodes for size/stride accessors on proxy tensor - Properly track derived computations on strides/numel/etc - Remove unnecessary tree_map at end of proxy tensor trace checking invariants; we will just have to be smart (it's too expensive) - Avoid tree_map in sym proxy tracing Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/85049 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
d49943bda8
commit
00ce302c07
@ -302,6 +302,10 @@ class TORCH_API TensorBase {
|
||||
return impl_->sym_numel();
|
||||
}
|
||||
|
||||
c10::SymInt sym_storage_offset() const {
|
||||
return impl_->sym_storage_offset();
|
||||
}
|
||||
|
||||
// Length of one array element in bytes. This is the traditional
|
||||
// Numpy naming.
|
||||
size_t itemsize() const {
|
||||
|
@ -800,9 +800,7 @@ def forward(self, a_1):
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
|
||||
mul = sym_size * 2; sym_size = None
|
||||
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
|
||||
sym_size_1 = torch.ops.aten.sym_size(empty, 0)
|
||||
detach = torch.ops.aten.detach.default(empty); empty = None
|
||||
sym_size_2 = torch.ops.aten.sym_size(detach, 0)
|
||||
return detach""")
|
||||
|
||||
def test_symint_to_tensor(self):
|
||||
@ -814,7 +812,6 @@ def forward(self, a_1):
|
||||
def forward(self, a_1):
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0)
|
||||
div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
|
||||
sym_size_1 = torch.ops.aten.sym_size(div, 0)
|
||||
return div""")
|
||||
|
||||
r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
|
||||
@ -825,7 +822,6 @@ def forward(self, a_1):
|
||||
lt = sym_size < 0
|
||||
eq = sym_size == sym_size; sym_size = None
|
||||
div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
|
||||
sym_size_1 = torch.ops.aten.sym_size(div, 0)
|
||||
return div""")
|
||||
|
||||
def test_cat(self):
|
||||
|
@ -213,7 +213,7 @@ static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args)
|
||||
return handle_torch_function(self_, "storage_offset");
|
||||
}
|
||||
auto& self = THPVariable_Unpack(self_);
|
||||
return wrap(self.storage_offset());
|
||||
return py::cast(self.sym_storage_offset()).release().ptr();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
@ -67,12 +67,22 @@ void sym_size_int(Stack& stack) {
|
||||
auto t = pop(stack).toTensor();
|
||||
push(stack, t.sym_sizes()[dim]);
|
||||
}
|
||||
void sym_stride_int(Stack& stack) {
|
||||
auto dim = pop(stack).toInt();
|
||||
auto t = pop(stack).toTensor();
|
||||
push(stack, t.sym_strides()[dim]);
|
||||
}
|
||||
|
||||
void sym_numel(Stack& stack) {
|
||||
auto t = std::move(pop(stack)).toTensor();
|
||||
push(stack, t.sym_numel());
|
||||
}
|
||||
|
||||
void sym_storage_offset(Stack& stack) {
|
||||
auto t = std::move(pop(stack)).toTensor();
|
||||
push(stack, t.sym_storage_offset());
|
||||
}
|
||||
|
||||
void sym_stride(Stack& stack) {
|
||||
auto t = std::move(pop(stack)).toTensor();
|
||||
pack(stack, t.sym_strides().vec());
|
||||
|
@ -23,8 +23,12 @@ void sym_size(Stack& stack);
|
||||
|
||||
void sym_size_int(Stack& stack);
|
||||
|
||||
void sym_stride_int(Stack& stack);
|
||||
|
||||
void sym_numel(Stack& stack);
|
||||
|
||||
void sym_storage_offset(Stack& stack);
|
||||
|
||||
void sym_stride(Stack& stack);
|
||||
|
||||
void device(Stack& stack);
|
||||
|
@ -420,6 +420,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
"aten::sym_size.int(Tensor self, int dim) -> SymInt"),
|
||||
sym_size_int,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::sym_stride.int(Tensor self, int dim) -> SymInt"),
|
||||
sym_stride_int,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::stride(Tensor self) -> int[]"),
|
||||
[](Stack& stack) {
|
||||
@ -431,6 +436,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
TORCH_SELECTIVE_SCHEMA("aten::sym_numel(Tensor self) -> SymInt"),
|
||||
sym_numel,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::sym_storage_offset(Tensor self) -> SymInt"),
|
||||
sym_storage_offset,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::sym_stride(Tensor self) -> SymInt[]"),
|
||||
sym_stride,
|
||||
|
@ -21,7 +21,6 @@ import weakref
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
|
||||
from torch._subclasses import FakeTensor
|
||||
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
|
||||
import torch.fx.experimental.symbolic_shapes as symbolic_shapes
|
||||
from torch.fx import Proxy
|
||||
|
||||
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy"]
|
||||
@ -107,23 +106,33 @@ def set_meta(proxy, val):
|
||||
|
||||
|
||||
def track_tensor(tensor, proxy, *, constant, tracer):
|
||||
def try_set_proxy_slot(outer_s, proxy_callable, *args):
|
||||
assert callable(proxy_callable)
|
||||
if isinstance(outer_s, SymInt):
|
||||
inner_s = outer_s.get_pyobj()
|
||||
assert isinstance(inner_s, PySymInt)
|
||||
proxy = None
|
||||
|
||||
def thunk():
|
||||
nonlocal proxy
|
||||
if proxy is None:
|
||||
proxy = proxy_callable(inner_s, *args)
|
||||
return proxy
|
||||
set_proxy_slot(inner_s, tracer, thunk)
|
||||
|
||||
# The basic idea is that we need to associate each tensor/SymInt
|
||||
# with a Proxy. How do we setup this association? We just store
|
||||
# the proxy on the proxy slot of the object, keyed on the tracer
|
||||
# (so that if we have multiple tracers at the same time, they
|
||||
# don't clobber each other.)
|
||||
for i, s in enumerate(tensor.shape):
|
||||
if isinstance(s, SymInt):
|
||||
inner_s = s.get_pyobj()
|
||||
assert isinstance(inner_s, PySymInt)
|
||||
# TODO: improve naming
|
||||
# TODO: lazily insert this into the graph only on first
|
||||
# use? Maybe complicated and DCE is a better idea
|
||||
s_proxy = torch.ops.aten.sym_size(proxy, i)
|
||||
set_meta(s_proxy, inner_s)
|
||||
set_proxy_slot(inner_s, tracer, s_proxy)
|
||||
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i)
|
||||
|
||||
# TODO: also do stride/numel
|
||||
for i, s in enumerate(tensor.stride()):
|
||||
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i)
|
||||
|
||||
try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x))
|
||||
try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x))
|
||||
set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
|
||||
|
||||
def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
|
||||
@ -177,7 +186,7 @@ def fetch_sym_proxy(tracer):
|
||||
return n.constant
|
||||
else:
|
||||
# NB: we REQUIRE all symints to be tracked
|
||||
return get_proxy_slot(n, tracer)
|
||||
return get_proxy_slot(n, tracer)()
|
||||
return inner
|
||||
|
||||
|
||||
@ -204,14 +213,13 @@ def proxy_call(proxy_mode, func, args, kwargs):
|
||||
# Some of these are not "real" aten ops and will fail if we
|
||||
# call _dispatch_has_kernel_for_dispatch_key on them.
|
||||
# This list is probably incomplete
|
||||
if func not in [torch.ops.aten.size.default]:
|
||||
if func not in [torch.ops.aten.size.default, torch.ops.aten.sym_storage_offset.default]:
|
||||
with proxy_mode.restore():
|
||||
r = func.decompose(*args, **kwargs)
|
||||
if r is not NotImplemented:
|
||||
return r
|
||||
|
||||
tracer = proxy_mode.tracer
|
||||
|
||||
f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))
|
||||
|
||||
# If there are SymInts, we also should not consider this constant.
|
||||
@ -435,24 +443,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
if not self.enable_tracing:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if symbolic_shapes.is_symbolic_op(func):
|
||||
with self.restore():
|
||||
return symbolic_shapes.handle_symbolic_op(func, args, kwargs)
|
||||
|
||||
if func in [prim.device.default]:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
out = proxy_call(self, func, args, kwargs)
|
||||
|
||||
def assert_proxy_tensor(e):
|
||||
assert has_proxy_slot(e, self.tracer), \
|
||||
f"Internal Error: make_fx is incorrectly baking a tensor constant into the graph: {str(e)}"
|
||||
|
||||
# When we trace factory functions, we expect that tensor outputs are *always* tracked.
|
||||
# (Except for torch.tensor() constants handled through lift(), which is handled
|
||||
# specially further up).
|
||||
pytree.tree_map_only(torch.Tensor, assert_proxy_tensor, out)
|
||||
return out
|
||||
return proxy_call(self, func, args, kwargs)
|
||||
|
||||
|
||||
SymInt = torch.SymIntNode
|
||||
@ -480,21 +474,24 @@ class ProxySymDispatchMode(SymDispatchMode):
|
||||
def __sym_dispatch__(self, func, types, args, kwargs):
|
||||
if not self.enable_tracing:
|
||||
return func(*args, **kwargs)
|
||||
p_args, p_kwargs = pytree.tree_map_only(
|
||||
(PySymInt, PySymFloat),
|
||||
lambda s: get_proxy_slot(s, self.tracer) if s.constant is None else s.constant,
|
||||
(args, kwargs)
|
||||
# For speed, we assume there are no nested data structures
|
||||
# (otherwise we could use tree_map)
|
||||
# We also assume there are no keyword arguments.
|
||||
assert not kwargs
|
||||
n_args = tuple(
|
||||
get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant
|
||||
if isinstance(a, (PySymInt, PySymFloat)) else a
|
||||
for a in args
|
||||
)
|
||||
|
||||
# func doesn't have a __torch_function__ that Proxy can interpose, so
|
||||
# we gotta do it manually
|
||||
n_args, n_kwargs = pytree.tree_map_only(fx.Proxy, lambda p: p.node, (p_args, p_kwargs))
|
||||
|
||||
n_out = self.tracer.create_node("call_function", func, n_args, n_kwargs)
|
||||
n_out = self.tracer.create_node("call_function", func, n_args, {})
|
||||
p_out = fx.Proxy(n_out, self.tracer)
|
||||
out = func(*args, **kwargs)
|
||||
set_meta(p_out, out)
|
||||
assert isinstance(out, (PySymInt, PySymFloat)), f"{func}(*{args}, **{kwargs}) = {out}"
|
||||
set_proxy_slot(out, self.tracer, p_out)
|
||||
set_proxy_slot(out, self.tracer, lambda: p_out)
|
||||
return out
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user