Functionalize and compute joint simultaneously. (#88063)

This also comes with some bug fixes that were uncovered in the
process:

- Forward device calls to the inner tensor in FunctionalTensorWrapper
  (exercised in the sketch after this list)

- Make legacyExtractDispatchKey exclude Functionalize, so that
  it can get at the real device type key.  This is noncontroversial.

- Stop stripping Dense from the key set.  The reason for this is that
  FunctionalTensorWrapper may be used in contexts where people
  query whether it is dense or not.  If it doesn't report this correctly
  (from the dispatch key), it will cause errors.  This caused some
  torchbench models to fail when I did one-pass tracing.

- Save and restore reapply views TLS correctly
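
Not part of the diff below, just an illustration: a minimal sketch of the
one-pass idea, using the private hooks that the new AOTAutograd helper
exercises (f, x, and y are hypothetical stand-ins):

    import torch

    def f(x):
        # observationally pure w.r.t. its input: mutates only an intermediate
        y = x.cos()
        y.mul_(2)
        return y

    x = torch.randn(3, requires_grad=True)
    f_x = torch._to_functional_tensor(x)   # wrap in a FunctionalTensorWrapper
    f_x.requires_grad = x.requires_grad
    # With the device fix above, queries on the wrapper forward to the inner tensor.
    assert f_x.device == x.device

    torch._enable_functionalization(reapply_views=True)
    try:
        out = f(f_x)
    finally:
        torch._disable_functionalization()

    torch._sync(out)                        # flush any pending view/inplace effects
    result = torch._from_functional_tensor(out)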

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88063
Approved by: https://github.com/bdhirsh
Authored by Edward Z. Yang on 2022-11-04 12:31:51 -07:00; committed by PyTorch MergeBot
parent 957a9b63c5
commit 0e3031f7e7
9 changed files with 148 additions and 55 deletions


@@ -37,22 +37,10 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
// For better error handling,
// we also don't want our wrapper tensor to be able to dispatch directly
// to a backend kernel.
// Dispatching directly to e.g. a CPU kernel would always segfault,
// because wrapper tensors don't have any real data.
// (This should never happen because we should always hit a functionalization kernel,
// but can help make bugs less nasty).
// Here, we defensively remove any backend keys from the wrapper's keyset.
// We don't want to remove actual backend bits though (say we're redispatching to autograd;
// we need to know if we're dispatching to AutogradCPU or AutogradXLA).
// Instead, it's sufficient to remove the `Dense` dispatch key,
// which prevents us from accidentally trying to directly run a CPU/CUDA kernel.
key_set_ = key_set_.remove(c10::DispatchKey::Dense);
// We override a bunch of _custom(), so make sure they get called
// TODO: metadata copying may not actually be necessary then
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
set_custom_device(true);
}
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
@@ -335,6 +323,9 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
std::move(version_counter), allow_tensor_metadata_change);
}
c10::Device FunctionalTensorWrapper::device_custom() const {
return value_.unsafeGetTensorImpl()->device();
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
return value_.unsafeGetTensorImpl()->sizes();
}


@@ -148,6 +148,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
c10::SymInt sym_size_custom(int64_t d) const override;
c10::SymIntArrayRef sym_strides_custom() const override;
c10::SymInt sym_storage_offset_custom() const override;
c10::Device device_custom() const override;
private:
const char* tensorimpl_type_name() const override;


@@ -6,6 +6,7 @@
#include <ATen/record_function.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/FunctionalTensorWrapper.h>
namespace at {
@@ -15,7 +16,8 @@ ThreadLocalState::ThreadLocalState()
functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
autograd_tls_(c10::AutogradState::get_tls_state()),
python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()),
functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) {
rf_tls_ = at::get_record_function_tls_();
saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state();
@@ -53,6 +55,8 @@ void ThreadLocalState::setThreadLocalState(
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
functorch::setFuncTorchTLS(state.functorch_tls_);
at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_);
}
} // namespace at


@@ -74,6 +74,8 @@ class TORCH_API ThreadLocalState {
// TLS for saved tensors default hooks
at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
bool functionalization_reapply_views_state_;
friend class ThreadLocalStateGuard;
};


@@ -876,7 +876,10 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
// treatment;
return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
autocast_dispatch_keyset -
DispatchKeySet({DispatchKey::PythonTLSSnapshot, DispatchKey::Python}))
DispatchKeySet(
{DispatchKey::Functionalize,
DispatchKey::PythonTLSSnapshot,
DispatchKey::Python}))
.highestPriorityTypeId();
}


@@ -1307,7 +1307,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* It can be expanded as needed in the future, e.g sparse Tensor.
*/
inline bool support_as_strided() const {
return is_nested() ? false : device().supports_as_strided();
if (is_nested()) {
return false;
}
if (key_set_.has(DispatchKey::Functionalize)) {
return false;
}
return device().supports_as_strided();
}
// ~~~~~ Autograd API ~~~~~


@@ -18,7 +18,6 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.nn.utils import stateless
from functorch import make_fx
from functorch.experimental import functionalize
from torch._dispatch.python import enable_python_dispatcher
from . import config
from .named_members_polyfill import _named_buffers, _named_parameters
@@ -133,7 +132,99 @@ def setup_stacktrace_preservation_hooks(roots: List):
node.register_hook(get_posthook(special_stack))
def create_joint_forward_backward(fn):
# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case. It might be generally applicable though
# (if so, move it out of this file), so I've tried to give it a name that
# describes what it does.
#
# Given a function f, it produces a new function g that:
#
# - Detaches all inputs before running f; the inner function
# does not directly participate in any pre-existing autograd.
# preserve_requires_grad is provided as a convenience to set the
# requires_grad on the new detached leaves in sync with the originals.
# (NB: In principle, you could backward through the pure operations
# produced by functionalization; this is not used for AOTAutograd
# and we have not tested it.)
#
# - Functionalizes all operations on f, under the assumption that the passed
# in function f must be "observationally pure"; that is, it cannot perform any
# mutations (inplace data or view operations) on the passed in inputs, nor is
# it allowed to directly close over tensors that aren't passed via its
# arguments. See
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
# for a discussion of how to implement the more complicated case.
#
# Unlike functorch's variant, this doesn't use the functorch level system;
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key. In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
# outer tensor inner tensor
#
# TODO: Provide a faster version of this that assumes flat arguments
# (so no pytree necessary)
def detach_and_functionalize_pure(f, preserve_requires_grad=True):
    @wraps(f)
    def inner(*args, **kwargs):
        def to_fun(t):
            if isinstance(t, Tensor):
                r = torch._to_functional_tensor(t)
                # NB: r is a leaf; it has no grad_fn relating
                # it to t.  If t has autograd metadata, that
                # metadata was preserved *inside* the r wrapper
                if preserve_requires_grad:
                    r.requires_grad = t.requires_grad
                return r
            else:
                return t
        f_args, f_kwargs = pytree.tree_map(to_fun, (args, kwargs))
        torch._enable_functionalization(reapply_views=True)
        try:
            outs = f(*f_args, **f_kwargs)
        finally:
            torch._disable_functionalization()
        # Detect input mutations (ideally we would error on them; see the note below)
        flat_args, _ = pytree.tree_flatten((args, kwargs))
        flat_f_args, _ = pytree.tree_flatten((f_args, f_kwargs))
        # This is just for sanity checking, can be skipped
        for arg, f_arg in zip(flat_args, flat_f_args):
            if not isinstance(arg, Tensor):
                continue
            torch._sync(f_arg)
            new_arg = torch._from_functional_tensor(f_arg)
            # I want to do this assert, but it is annoying because
            # we have operator tests that have mutating inputs.  So
            # I do something unsound instead
            # assert arg is new_arg, "input argument was mutated, this is not valid"
            if arg is not new_arg:
                assert arg.shape == new_arg.shape
                arg.copy_(new_arg)
        def from_fun(t):
            if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
                return t
            torch._sync(t)
            return torch._from_functional_tensor(t)
        return pytree.tree_map(from_fun, outs)
    return inner
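
# Illustration only (not part of this change): a hypothetical minimal use of the
# helper above; toy_fn is made up.  The real call site is in aot_dispatch_autograd
# below, which traces the joint forward-backward function this way.
#
#   def toy_fn(x):
#       y = x * 2
#       y.add_(1)          # in-place op on an intermediate gets functionalized away
#       return y.view(-1)
#
#   gm = make_fx(detach_and_functionalize_pure(toy_fn))(torch.randn(2, 2))
#   # gm.graph should contain only functional ATen ops (no add_ mutation).
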
# This creates a joint forwards-backwards function given both
# the primals (to run forwards) and tangents (to run backwards).
#
# It has a precondition, which is that the passed-in function
# must be observationally pure; it is not permitted to mutate
# the primals or tangents.
def create_joint_forward_backward_pure(fn):
    def joint_forward_backward(
        primals: List[Any], tangents: List[Any]
    ) -> Tuple[List[Any], List[Any]]:
@@ -366,7 +457,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
deduped_flat_args = remove_dupe_args(flat_args)
joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))
joint_forward_backward = create_joint_forward_backward_pure(lambda *args: flat_fn(*add_dupe_args(args)))
out = flat_fn(*flat_args)
# Collect info on which output tensors require gradients,
@@ -392,27 +483,13 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
if config.use_functionalize:
with enable_python_dispatcher():
# Trace once without decompositions, into a graph of ATen ops.
# NB: tracing_mode is real, as it's assumed the calling context has set up
# fake tensor mode / symbolic shapes if that is needed
fx_g = make_fx(joint_forward_backward)(*joint_inputs)
context = disable_autocast_manager if disable_amp else nullcontext
def fake_fn(primals, tangents):
with torch.fx.traceback.override_stack_trace():
return torch.fx.Interpreter(fx_g).run(primals, tangents)
# Trace a second time, running functionalization, and THEN running decompositions.
# functionalization only acts on ATen today, and doesn't currently handle
# view and inplace ops that come from primtorch.
# Eventually, functionalization should support primtorch view/inplace ops,
# which will make it ok to run decompositions before functionalization.
with context():
fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
fx_g = make_fx(
detach_and_functionalize_pure(joint_forward_backward), aot_config.decompositions
)(*joint_inputs)
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
else:
warnings.warn("graph partitioning without functionalization is not sound, we may introduce errors")
fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)
if config.debug_joint:


@@ -416,6 +416,21 @@ class TestAOTAutograd(AOTTestCase):
        self.assertEqual(ref_out, test_out)

    def test_custom_autograd(self):
        class CustomFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output + 1

        def f(x):
            return CustomFn.apply(x)

        self.verify_aot_autograd(f, [torch.randn(3)])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_autocast_disable_guard(self):
        guard = torch._C._DisableAutocast()
@@ -1099,12 +1114,10 @@ symbolic_aot_autograd_failures = {
xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to...
xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
xfail('max', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
xfail('max', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('median', ''), # could not find kernel
xfail('meshgrid', 'list_of_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('meshgrid', 'variadic_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('min', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('msort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1116,7 +1129,6 @@ symbolic_aot_autograd_failures = {
# Deleting this in a followup
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.poisson_nll_loss', ''),
xfail('nn.functional._scaled_dot_product_attention', ''), # Cannot call sizes() on tensor with symbolic ...


@@ -141,26 +141,23 @@ class TestFunctionalization(TestCase):
def forward(self, a_1):
view_copy = torch.ops.aten.view_copy.default(a_1, [1, 1024, 128, 128]); a_1 = None
clone = torch.ops.aten.clone.default(view_copy); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
sum_1 = torch.ops.aten.sum.default(relu)
ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
new_zeros = torch.ops.aten.new_zeros.default(expand_copy, [16777216])
as_strided_copy = torch.ops.aten.as_strided_copy.default(new_zeros, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(new_zeros, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0)
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(new_zeros, expand_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0); new_zeros = expand_copy = None
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0); as_strided_scatter = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(as_strided_copy_2, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
as_strided_copy_4 = torch.ops.aten.as_strided_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
clone_1 = torch.ops.aten.clone.default(as_strided_copy_4, memory_format = torch.contiguous_format); as_strided_copy_4 = None
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format); view_copy_4 = None
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(as_strided_copy_2, [16, 64, 128, 128], [1048576, 16384, 128, 1])
detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy); _reshape_alias_copy = None
as_strided_scatter_1 = torch.ops.aten.as_strided_scatter.default(as_strided_copy_2, threshold_backward, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0); as_strided_copy_2 = threshold_backward = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(as_strided_scatter_1, [16, 64, 128, 128], [1048576, 16384, 128, 1]); as_strided_scatter_1 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]); _reshape_alias_copy = None
detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None
view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]); view_copy_5 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2); _reshape_alias_copy_2 = None
return detach_copy_1
""") # noqa: B950