minor proxy_tensor reorg (#165266)

Moving some code around in proxy_tensor in preparation for the next PR. There
are no functional changes (other than simple relabelings such as `self.tracer`
-> `tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object` (see the
  sketch after this list).

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.
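For context, a minimal standalone sketch of what the structured-type change buys. `_PySymProxyType` is replaced with a plain stand-in here, since the real alias (and the real tracker classes) live in `proxy_tensor.py`:

```python
from dataclasses import dataclass

import sympy

# Simplified stand-in for the real _PySymProxyType alias in proxy_tensor.py.
_PySymProxyType = object


@dataclass
class _SympyExprTrackerValue:
    # Structured wrapper for what used to be stored as a bare `object`;
    # reads are now explicit (`.proxy`) and statically typed.
    proxy: _PySymProxyType


tracker: dict[sympy.Symbol, _SympyExprTrackerValue] = {}
s0 = sympy.Symbol("s0")
tracker[s0] = _SympyExprTrackerValue(proxy="proxy-for-s0")
assert tracker[s0].proxy == "proxy-for-s0"
```

Wrapping the value in a dataclass also leaves room to attach more fields to tracked sympy expressions later without touching the dict type again.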

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165266
Approved by: https://github.com/ezyang, https://github.com/mlazos
Author:    Aaron Orenstein
Date:      2025-10-12 08:36:52 -07:00
Committer: PyTorch MergeBot
Parent:    2cd5fd1588
Commit:    e86942f422

@@ -286,7 +286,8 @@ def set_proxy_slot(  # type: ignore[no-redef]
         # is derivable from a primal that we use that.
         assert isinstance(obj, py_sym_types), type(obj)
         if obj not in tracer.symnode_tracker:
-            tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
+            proxy = typing.cast(_PySymProxyType, proxy)
+            tracer.symnode_tracker[obj] = proxy
 
             # WAR: python test/dynamo/test_subclasses.py
             # TestNestedTensor.test_basic_autograd
@@ -303,7 +304,7 @@ def set_proxy_slot(  # type: ignore[no-redef]
             import sympy
 
             if isinstance(obj.node.expr, sympy.Symbol):
-                tracer.sympy_expr_tracker[obj.node.expr] = proxy
+                tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue(proxy)
 
 
 def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
@@ -409,7 +410,7 @@ def get_proxy_slot(
     if obj not in tracker:
         # Last ditch
         if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
-            value = tracer.sympy_expr_tracker[obj.node.expr]
+            value = tracer.sympy_expr_tracker[obj.node.expr].proxy
         else:
             if isinstance(default, _NoDefault):
                 raise RuntimeError(
@@ -1108,10 +1109,15 @@ class _SymNodeDict:
         return len(self.sym_node_dict)
 
 
+@dataclass
+class _SympyExprTrackerValue:
+    proxy: _PySymProxyType
+
+
 class PythonKeyTracer(Tracer):
     script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
     symnode_tracker: _SymNodeDict
-    sympy_expr_tracker: dict[sympy.Symbol, object]
+    sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue]
     tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
     torch_fn_counts: dict[OpOverload, int]
     enable_thunkify: bool = False
@@ -1123,7 +1129,7 @@ class PythonKeyTracer(Tracer):
         self.script_object_tracker = WeakIdKeyDictionary(
             dict=None, ref_type=_WeakHashRef
         )
-        self.sympy_expr_tracker = dict()
+        self.sympy_expr_tracker = {}
 
         # Stores the torch function that was called during tracing
         self.torch_fn_metadata = None
@@ -1578,39 +1584,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     def is_infra_mode(cls) -> bool:
         return True
 
-    def _compute_proxy(
-        self, func: OpOverload, args: tuple[object, ...], out: PySymType
-    ) -> Proxy:
-        # Handle torch.sym_sum
-        n_args: tuple[object, ...]
-        if len(args) == 1 and isinstance(args[0], (list, tuple)):
-            n_args = (
-                tuple(
-                    (
-                        get_proxy_slot(a, self.tracer).force().node
-                        if isinstance(a, py_sym_types)
-                        else a
-                    )
-                    for a in args[0]
-                ),
-            )
-        else:
-            n_args = tuple(
-                (
-                    get_proxy_slot(a, self.tracer).force().node
-                    if isinstance(a, py_sym_types)
-                    else a
-                )
-                for a in args
-            )
-
-        # func doesn't have a __torch_function__ that Proxy can interpose, so
-        # we gotta do it manually
-        n_out = self.tracer.create_node("call_function", func, n_args, {})  # type: ignore[arg-type]
-        p_out = fx.Proxy(n_out, self.tracer)
-        set_meta(p_out, out)
-        return p_out
-
     def __sym_dispatch__(
         self,
         func: OpOverload,
@@ -1631,25 +1604,63 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         # We also assume there are no keyword arguments.
         assert not kwargs
         out = func(*args, **kwargs)
-
-        # If func returned a constant, we don't need to trace; we have
-        # determined that the result is constant (no matter if the inputs
-        # were symbolic) and it is no longer necessary to trace the
-        # computation.  This could occur if func triggered some guards.
-        if isinstance(out, py_sym_types):
-            p_out_thunk = thunkify(
-                self.tracer, self._compute_proxy, func=func, args=args, out=out
-            )
-            set_proxy_slot(out, self.tracer, p_out_thunk)
-
+        _sym_register(self.tracer, func, args, out)
         return out
 
 
+def _sym_register(
+    tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: object
+) -> None:
+    # If func returned a constant, we don't need to trace; we have
+    # determined that the result is constant (no matter if the inputs
+    # were symbolic) and it is no longer necessary to trace the
+    # computation.  This could occur if func triggered some guards.
+    if isinstance(out, py_sym_types):
+        p_out_thunk = thunkify(
+            tracer, _compute_proxy, tracer, func=func, args=args, out=out
+        )
+        set_proxy_slot(out, tracer, p_out_thunk)
+
+
+def _compute_proxy(
+    tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: PySymType
+) -> Proxy:
+    # Handle torch.sym_sum
+    n_args: tuple[object, ...]
+    if len(args) == 1 and isinstance(args[0], (list, tuple)):
+        n_args = (
+            tuple(
+                (
+                    get_proxy_slot(a, tracer).force().node
+                    if isinstance(a, py_sym_types)
+                    else a
+                )
+                for a in args[0]
+            ),
+        )
+    else:
+        n_args = tuple(
+            (
+                get_proxy_slot(a, tracer).force().node
+                if isinstance(a, py_sym_types)
+                else a
+            )
+            for a in args
+        )
+
+    # func doesn't have a __torch_function__ that Proxy can interpose, so
+    # we gotta do it manually
+    n_out = tracer.create_node("call_function", func, n_args, {})  # type: ignore[arg-type]
+    p_out = fx.Proxy(n_out, tracer)
+    set_meta(p_out, out)
+    return p_out
+
+
 class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
     script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
     symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
     tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
-    sympy_expr_tracker: dict[sympy.Symbol, object]
+    sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue]
     torch_fn_metadata: Optional[OpOverload]
     torch_fn_counts: dict[OpOverload, int]
     enable_thunkify: bool = False
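
Side note on the `thunkify(...)` call that `_sym_register` issues: it defers the (potentially expensive) `_compute_proxy` call until some consumer actually forces the proxy. The real helper lives in `proxy_tensor.py` and threads the tracer through; below is only a minimal standalone sketch of the deferred-evaluation pattern it implements:

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class Thunk(Generic[T]):
    """Defer a computation until force() is called, then cache the result."""

    def __init__(self, fn: Callable[[], T]) -> None:
        self._fn: Optional[Callable[[], T]] = fn
        self._value: Optional[T] = None

    def force(self) -> T:
        if self._fn is not None:
            self._value = self._fn()
            self._fn = None  # drop the closure so captured args can be freed
        return self._value  # type: ignore[return-value]


# The proxy is only materialized if something actually forces the thunk,
# which is the behavior _sym_register gets by wrapping _compute_proxy.
t = Thunk(lambda: "expensive proxy")
assert t.force() is t.force()  # computed once, then cached
```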