mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
minor proxy_tensor reorg (#165266)
Moving some code around in proxy_tensor in preparation for the next PR. There we no actual changes (other than simple relabeling such as `self.tracer` -> `tracer`): - Move _compute_proxy() out of ProxyTorchDispatchMode. - Give `sympy_expr_tracker` a structured type instead of `object`. - Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so it can be reused. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165266 Approved by: https://github.com/ezyang, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
2cd5fd1588
commit
e86942f422
@ -286,7 +286,8 @@ def set_proxy_slot( # type: ignore[no-redef]
|
||||
# is derivable from a primal that we use that.
|
||||
assert isinstance(obj, py_sym_types), type(obj)
|
||||
if obj not in tracer.symnode_tracker:
|
||||
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
|
||||
proxy = typing.cast(_PySymProxyType, proxy)
|
||||
tracer.symnode_tracker[obj] = proxy
|
||||
|
||||
# WAR: python test/dynamo/test_subclasses.py
|
||||
# TestNestedTensor.test_basic_autograd
|
||||
@ -303,7 +304,7 @@ def set_proxy_slot( # type: ignore[no-redef]
|
||||
import sympy
|
||||
|
||||
if isinstance(obj.node.expr, sympy.Symbol):
|
||||
tracer.sympy_expr_tracker[obj.node.expr] = proxy
|
||||
tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(proxy)
|
||||
|
||||
|
||||
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
|
||||
@ -409,7 +410,7 @@ def get_proxy_slot(
|
||||
if obj not in tracker:
|
||||
# Last ditch
|
||||
if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
|
||||
value = tracer.sympy_expr_tracker[obj.node.expr]
|
||||
value = tracer.sympy_expr_tracker[obj.node.expr].proxy
|
||||
else:
|
||||
if isinstance(default, _NoDefault):
|
||||
raise RuntimeError(
|
||||
@ -1108,10 +1109,15 @@ class _SymNodeDict:
|
||||
return len(self.sym_node_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SympyExprTrackerValue:
|
||||
proxy: _PySymProxyType
|
||||
|
||||
|
||||
class PythonKeyTracer(Tracer):
|
||||
script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
|
||||
symnode_tracker: _SymNodeDict
|
||||
sympy_expr_tracker: dict[sympy.Symbol, object]
|
||||
sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue]
|
||||
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
|
||||
torch_fn_counts: dict[OpOverload, int]
|
||||
enable_thunkify: bool = False
|
||||
@ -1123,7 +1129,7 @@ class PythonKeyTracer(Tracer):
|
||||
self.script_object_tracker = WeakIdKeyDictionary(
|
||||
dict=None, ref_type=_WeakHashRef
|
||||
)
|
||||
self.sympy_expr_tracker = dict()
|
||||
self.sympy_expr_tracker = {}
|
||||
|
||||
# Stores the torch function that was called during tracing
|
||||
self.torch_fn_metadata = None
|
||||
@ -1578,39 +1584,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
def is_infra_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
def _compute_proxy(
|
||||
self, func: OpOverload, args: tuple[object, ...], out: PySymType
|
||||
) -> Proxy:
|
||||
# Handle torch.sym_sum
|
||||
n_args: tuple[object, ...]
|
||||
if len(args) == 1 and isinstance(args[0], (list, tuple)):
|
||||
n_args = (
|
||||
tuple(
|
||||
(
|
||||
get_proxy_slot(a, self.tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
)
|
||||
for a in args[0]
|
||||
),
|
||||
)
|
||||
else:
|
||||
n_args = tuple(
|
||||
(
|
||||
get_proxy_slot(a, self.tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
)
|
||||
for a in args
|
||||
)
|
||||
|
||||
# func doesn't have a __torch_function__ that Proxy can interpose, so
|
||||
# we gotta do it manually
|
||||
n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type]
|
||||
p_out = fx.Proxy(n_out, self.tracer)
|
||||
set_meta(p_out, out)
|
||||
return p_out
|
||||
|
||||
def __sym_dispatch__(
|
||||
self,
|
||||
func: OpOverload,
|
||||
@ -1631,25 +1604,63 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
# We also assume there are no keyword arguments.
|
||||
assert not kwargs
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# If func returned a constant, we don't need to trace; we have
|
||||
# determined that the result is constant (no matter if the inputs
|
||||
# were symbolic) and it is no longer necessary to trace the
|
||||
# computation. This could occur if func triggered some guards.
|
||||
if isinstance(out, py_sym_types):
|
||||
p_out_thunk = thunkify(
|
||||
self.tracer, self._compute_proxy, func=func, args=args, out=out
|
||||
)
|
||||
set_proxy_slot(out, self.tracer, p_out_thunk)
|
||||
|
||||
_sym_register(self.tracer, func, args, out)
|
||||
return out
|
||||
|
||||
|
||||
def _sym_register(
|
||||
tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: object
|
||||
) -> None:
|
||||
# If func returned a constant, we don't need to trace; we have
|
||||
# determined that the result is constant (no matter if the inputs
|
||||
# were symbolic) and it is no longer necessary to trace the
|
||||
# computation. This could occur if func triggered some guards.
|
||||
if isinstance(out, py_sym_types):
|
||||
p_out_thunk = thunkify(
|
||||
tracer, _compute_proxy, tracer, func=func, args=args, out=out
|
||||
)
|
||||
set_proxy_slot(out, tracer, p_out_thunk)
|
||||
|
||||
|
||||
def _compute_proxy(
|
||||
tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: PySymType
|
||||
) -> Proxy:
|
||||
# Handle torch.sym_sum
|
||||
n_args: tuple[object, ...]
|
||||
if len(args) == 1 and isinstance(args[0], (list, tuple)):
|
||||
n_args = (
|
||||
tuple(
|
||||
(
|
||||
get_proxy_slot(a, tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
)
|
||||
for a in args[0]
|
||||
),
|
||||
)
|
||||
else:
|
||||
n_args = tuple(
|
||||
(
|
||||
get_proxy_slot(a, tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
)
|
||||
for a in args
|
||||
)
|
||||
|
||||
# func doesn't have a __torch_function__ that Proxy can interpose, so
|
||||
# we gotta do it manually
|
||||
n_out = tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type]
|
||||
p_out = fx.Proxy(n_out, tracer)
|
||||
set_meta(p_out, out)
|
||||
return p_out
|
||||
|
||||
|
||||
class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
|
||||
script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
|
||||
symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
|
||||
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
|
||||
sympy_expr_tracker: dict[sympy.Symbol, object]
|
||||
sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue]
|
||||
torch_fn_metadata: Optional[OpOverload]
|
||||
torch_fn_counts: dict[OpOverload, int]
|
||||
enable_thunkify: bool = False
|
||||
|
Reference in New Issue
Block a user