This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy (https://github.com/pytorch/pytorch/pull/87722/). This PR takes a different approach.

The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. SymNode is type-erased; we no longer know statically in C++ whether we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock-on effects:

- We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have SymInt/SymFloat classes defined entirely in Python, each of which holds a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python and is wrapped into a C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++ and then bound to Python with pybind11 (we have this code, although it is commented out); however, I did not implement this, as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done simply by importing torch and getting the attributes manually.

- Because SymInt/SymFloat are simple Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat rather than SymNode, bringing it in line with how __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization, as it means rvalue references work automatically.

- We uniformly use the caster for c10::SymInt/SymFloat rather than going the long way via SymIntNode/SymFloatNode.

- Removed some unnecessary toSymInt/toSymFloat calls in the normalize_* functions; I'm pretty sure these calls didn't do anything.

- guard_int is now a free function, since when guarding on an int you cannot assume the method exists; a free function can handle both int and SymInt inputs.

- We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods. This is to help avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
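To make the design concrete, here is a minimal, hypothetical sketch (not the actual torch implementation) of the relationship the message describes: a type-erased SymNode with plain methods and runtime is_int()/is_float() queries, wrapped by a user-facing SymInt that carries the magic methods.

# Hypothetical sketch only; names and signatures are illustrative,
# not the real torch.SymInt/SymNode API.

class SymNode:
    """Type-erased node: int vs. float is only known at runtime."""

    def __init__(self, value):
        self._value = value

    # Plain methods only; no magic methods live on SymNode.
    def is_int(self):
        return isinstance(self._value, int)

    def is_float(self):
        return isinstance(self._value, float)

    def add(self, other):
        return SymNode(self._value + other._value)


class SymInt:
    """User-facing wrapper: holds a SymNode and defines the magic methods."""

    def __init__(self, node):
        self.node = node

    def __add__(self, other):
        # Delegate to the plain method on the underlying node.
        return SymInt(self.node.add(other.node))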
122 lines · 4.6 KiB · Python
import torch
import torch.utils._pytree as pytree

from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._ops import PyOperator
from torch.utils._pytree import tree_flatten
from torch.utils._python_dispatch import _get_current_dispatch_mode, _pop_mode_temporarily
from torch.fx.experimental.proxy_tensor import (
    ProxyTorchDispatchMode,
    get_isolated_graphmodule,
    get_proxy_slot,
    track_tensor_tree,
)

"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""

cond = PyOperator('cond')


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    def _unwrap_proxy(e):
        if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
            return e
        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)

    assert isinstance(operands, list), "Cond operands must be a list of tensors"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"

    # Trace each branch into its own isolated GraphModule.
    true_graph = get_isolated_graphmodule(true_fn, operands, {})
    false_graph = get_isolated_graphmodule(false_fn, operands, {})

    # Collect the output nodes of each traced branch.
    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == 'output':
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == 'output':
            false_outs.extend(node.args)

    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_outs)
    assert len(flat_true_outs) == len(flat_false_outs)

    # Both branches must produce outputs with matching tensor metadata,
    # since only one of them actually runs at execution time.
    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']

    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, [operands])

    proxy_args = pytree.tree_map(_unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                               name="conditional")

    # At this point, we're *guaranteed* that whether an output came from the
    # true or false branch is indistinguishable. So, as this is just for tracing
    # purposes, choose the true branch.

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU key"
    # Eager CPU semantics: just run the selected branch.
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
    assert all(not f.requires_grad for f in flat_operands
               if isinstance(f, torch.Tensor))

    # Keep the guard alive until we return, so the redispatch below skips
    # AutogradCPU and falls through to the next implementation.
    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)


@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    # Pop the mode while tracing so that trace_cond does not re-enter it.
    with _pop_mode_temporarily() as mode:
        res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
    return res


# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)
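

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). Assuming CPU
# tensors dispatch through the AutogradCPU and CPU implementations above,
# `cond` is called with a predicate, two callables, and a list of operands.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(3)
    pred = torch.tensor(True)

    # Eagerly runs the true branch, since pred is truthy: same as x.sin().
    out = cond(pred, lambda t: t.sin(), lambda t: t.cos(), [x])
    assert torch.allclose(out, x.sin())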