pytorch/torch/csrc/autograd/python_legacy_variable.cpp
Edward Yang e0aebe241d Refactor tensor_new.cpp to use TensorOptions instead of DispatchKey (#54034)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54034

Fixes #53544

I had to touch a bunch of lines but the refactoring was fairly
mechanical.  Here's how it works.

The basic concept behind this PR is that tensor_new.cpp was previously
abusing DispatchKey when it actually meant TensorOptions.  The provided
DispatchKey argument to most of the constructor functions typically
comes from torch::tensors::get_default_dispatch_key();  it doesn't
really make sense for people to set the default dispatch key, but
this got grandfathered in due to the old API set_default_tensor_type
(where the "Type" concept got refactored into "DispatchKey" concept
over time).  See also #53124.  But the upshot is that, semantically,
what we refer to as the default dispatch key really is more like
torch.set_default_tensor_type(torch.Tensor) versus
torch.set_default_tensor_type(torch.cuda.Tensor): clearly the user
wants to do something about *construction* of the tensor, and
TensorOptions captures that exactly.

So, how exactly to translate from one to the other?
- Sources (things that used to PRODUCE DispatchKey)
  - Most top level functions take a DispatchKey as their argument.  I
    use the new function dispatchKeyToTensorOptions to convert it into
    a TensorOptions
  - typeIdWithDefault now produces a TensorOptions (probably could do
    with a rename, though I didn't)
- Sinks (things that used to CONSUME DispatchKey)
  - Previously, the function options() was typically used to convert the
    DispatchKey into a TensorOptions.  Now its replacement build_options
    just takes a TensorOptions and sets some extra fields on it.
    Irritatingly, I can't just replace
    `build_options(options, scalar_type, device)` with
    `options.dtype(scalar_type).device(device)` because the semantics
    are slightly different: if device is nullopt, we should preserve
    the usage of the device specified in options (what options.device()
    does is overwrite the device unconditionally; i.e., if device is
    nullopt, it unsets the device from options)
  - The other major sink for DispatchKey was `internal_new_from_data`,
    but it turns out it only really extracts the device type from
    the dispatch key.  Now it just pulls out the device from
    TensorOptions.
- To actually do the translation of DispatchKey to TensorOptions, I
  introduce new functions dispatchKeyToLayout (replicating
  layout_from_backend--there are still a few uses of this function
  so I couldn't delete it) and dispatchKeyToDeviceType (replacing
  computeDeviceType)
- In all internal functions, whenever DispatchKey is taken as an argument,
  I instead take TensorOptions as an argument, and pass it along.
- Anywhere `legacyExtractDispatchKey(other.key_set())` equality was
  previously used, I now do `other.options().type_equal()`, which
  is the intended BC for doing "backend to backend" comparisons
- There are a few places in the sparse constructors where we allocated
  a tensor for values, and then read out the dispatch key from the
  result to allocate the keys.  As best as I can tell, this is totally
  equivalent to just passing in the options to both values and indices
  (the only difference is dtype, which is captured via a separate
  argument)
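
The build_options point above (preserve the device already in options when no device is passed) can be sketched as follows. This is a minimal illustration, not the code from tensor_new.cpp: `Options`, `DeviceType`, `ScalarType`, `with_device` and `with_dtype` are hypothetical stand-ins for TensorOptions and its setters, and this `build_options` only models the semantics described in the bullet.

```cpp
#include <cassert>
#include <optional>

// Hypothetical, simplified stand-ins (assumptions for illustration only).
enum class DeviceType { CPU, CUDA };
enum class ScalarType { Float, Double, Long };

struct Options {
  std::optional<ScalarType> dtype;
  std::optional<DeviceType> device;

  // Models an unconditional overwrite: passing nullopt clears the device.
  Options with_device(std::optional<DeviceType> d) const {
    Options r = *this;
    r.device = d;
    return r;
  }

  Options with_dtype(ScalarType t) const {
    Options r = *this;
    r.dtype = t;
    return r;
  }
};

// Sketch of the described semantics: always set the dtype, but only touch
// the device when the caller actually supplied one.
Options build_options(Options options,
                      ScalarType scalar_type,
                      std::optional<DeviceType> device = std::nullopt) {
  options = options.with_dtype(scalar_type);
  if (device.has_value()) {
    options = options.with_device(device);
  }
  return options;
}
```

With a base whose device is CUDA, `build_options(base, ScalarType::Long)` keeps CUDA, whereas `base.with_device(std::nullopt)` drops it. That difference is why the plain chained setters are not a drop-in replacement.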

This refactor doesn't really go far enough: for example, there are now
functions that take both TensorOptions and ScalarType, when really
the TensorOptions can capture all of this.  I kept it to just
s/DispatchKey/TensorOptions/ to reduce the number of possible bugs;
also, a lot of this will be mooted by a proper fix to #53124.

Even with this limited refactor, the payoff is sweet.  I can delete:

- backendToCPU
- backendToXPU
- backendToCUDA
- backendToHIP
- backendToBackendOfDeviceType

I can do this because I can simply overwrite the device in TensorOptions
to do the conversion, rather than having to type out each backend case
explicitly.
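
To illustrate why the per-backend helpers become unnecessary, here is a rough sketch that contrasts an enumerating helper with a single field overwrite. All names (`Backend`, `Options`, `backend_to_cpu`, `with_device`) are hypothetical, the backend set is reduced to four entries, and the sketch assumes the conversion amounts to changing the device while leaving the layout alone.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical reduced enums (assumptions for illustration only).
enum class Backend { CPU, CUDA, SparseCPU, SparseCUDA };
enum class Device { CPU, CUDA };
enum class Layout { Strided, Sparse };

// Old style: one helper per target, each spelling out every backend case.
Backend backend_to_cpu(Backend b) {
  switch (b) {
    case Backend::CPU:
    case Backend::CUDA:
      return Backend::CPU;
    case Backend::SparseCPU:
    case Backend::SparseCUDA:
      return Backend::SparseCPU;
  }
  throw std::invalid_argument("unknown backend");
}

// New style: device and layout are independent fields, so moving to another
// device is a single overwrite that leaves the layout untouched.
struct Options {
  Device device;
  Layout layout;
};

Options with_device(Options o, Device d) {
  o.device = d;
  return o;
}
```

Adding a new device type to the old style means another helper plus new cases in every existing one. In the field-based style, `with_device` stays unchanged.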

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D27109509

Pulled By: ezyang

fbshipit-source-id: 91d16cfbc390127770362ac04fb43f7e070077e9
2021-03-19 09:08:32 -07:00

144 lines
5.7 KiB
C++

#include <torch/csrc/autograd/python_legacy_variable.h>
#include <ATen/ATen.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/jit/frontend/tracer.h>

using namespace at;

namespace torch { namespace autograd {

static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject *kwds) {
  HANDLE_TH_ERRORS
  THPObjectPtr _data;
  PyObject *data = nullptr;
  PyObject *grad_fn = nullptr;
  char is_volatile = 0;
  char requires_grad = 0;
  const char* name = nullptr;

  const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args,
      &data, &requires_grad, &is_volatile, &grad_fn, &name))
    return nullptr;

  if (grad_fn == Py_None)
    grad_fn = nullptr;

  if (is_volatile) {
    auto r = PyErr_WarnEx(PyExc_UserWarning,
        "volatile was removed and now has no effect. Use `with torch.no_grad():` "
        "instead.", 1);
    if (r != 0) throw python_error();
  }

  if (is_volatile && requires_grad) {
    throw ValueError("Variable can't be volatile and require_grad at the same time!");
  }
  if (grad_fn && !THPFunction_Check(grad_fn)) {
    throw TypeError("_grad_fn has to be a Function object or None, but got %s",
        Py_TYPE(grad_fn)->tp_name);
  }
  Variable var;
  if (!data || data == Py_None) {
    // For legacy serialization code, create an empty tensor. This is also used
    // by nn.Parameter() with no arguments.
    auto dispatch_key = torch::tensors::get_default_dispatch_key();
    auto scalar_type = torch::tensors::get_default_scalar_type();
    auto options = TensorOptions(scalar_type)
        .device(dispatchKeyToDeviceType(dispatch_key))
        .layout(dispatchKeyToLayout(dispatch_key));
    var = at::empty({0}, options);
  } else if (THPVariable_Check(data)) {
    var = ((THPVariable*)data)->cdata.detach();
  } else {
    throw torch::TypeError("Variable data has to be a tensor, but got %s",
        Py_TYPE(data)->tp_name);
  }
  // We set `var`'s `allow_tensor_metadata_change` to true here, because we want to
  // allow the following use case for backward compatibility:
  //
  // ```python
  // var = Variable(torch.randn(2, 3))
  // var.resize_(4, 5)
  // ```
  var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);

  TORCH_CHECK(!grad_fn,
      "_grad_fn argument to legacy Variable constructor is no longer supported.  "
      "Instead, please invoke your _grad_fn to produce a variable with it as the "
      "_grad_fn.");
  var.set_requires_grad(requires_grad);

  if (name) {
    impl::set_name(var, name);
  }

  if (jit::tracer::isTracing() && data && data != Py_None && THPVariable_Check(data)) {
    if (auto *v = jit::tracer::getValueTrace(((THPVariable*)data)->cdata)) {
      jit::tracer::setValueTrace(var, v);
    }
  }

  return THPVariable_Wrap(std::move(var));
  END_HANDLE_TH_ERRORS
}

PyTypeObject THPLegacyVariableType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._LegacyVariableBase",        /* tp_name */
  0,                                     /* tp_basicsize */
  0,                                     /* tp_itemsize */
  nullptr,                               /* tp_dealloc */
  0,                                     /* tp_vectorcall_offset */
  nullptr,                               /* tp_getattr */
  nullptr,                               /* tp_setattr */
  nullptr,                               /* tp_reserved */
  nullptr,                               /* tp_repr */
  nullptr,                               /* tp_as_number */
  nullptr,                               /* tp_as_sequence */
  nullptr,                               /* tp_as_mapping */
  nullptr,                               /* tp_hash */
  nullptr,                               /* tp_call */
  nullptr,                               /* tp_str */
  nullptr,                               /* tp_getattro */
  nullptr,                               /* tp_setattro */
  nullptr,                               /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
  nullptr,                               /* tp_doc */
  nullptr,                               /* tp_traverse */
  nullptr,                               /* tp_clear */
  nullptr,                               /* tp_richcompare */
  0,                                     /* tp_weaklistoffset */
  nullptr,                               /* tp_iter */
  nullptr,                               /* tp_iternext */
  nullptr,                               /* tp_methods */
  nullptr,                               /* tp_members */
  nullptr,                               /* tp_getset */
  nullptr,                               /* tp_base */
  nullptr,                               /* tp_dict */
  nullptr,                               /* tp_descr_get */
  nullptr,                               /* tp_descr_set */
  0,                                     /* tp_dictoffset */
  nullptr,                               /* tp_init */
  nullptr,                               /* tp_alloc */
  THPVariable_pynew                      /* tp_new */
};

void init_legacy_variable(PyObject *module) {
  if (PyType_Ready(&THPLegacyVariableType) < 0) {
    throw python_error();
  }
  auto obj = (PyObject*)&THPLegacyVariableType;
  Py_INCREF(obj);
  if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
    throw python_error();
  }
}

}} // namespace torch::autograd
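
The empty-tensor branch of THPVariable_pynew builds its TensorOptions by splitting the default dispatch key into a device type and a layout. The sketch below shows that decomposition in isolation. It is only an illustration under stated assumptions: `Key`, `Device`, `Layout`, `Options` and the three functions are hypothetical names, the key set is reduced to four entries, and the real dispatchKeyToDeviceType and dispatchKeyToLayout may handle more cases than this.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical reduced key set (assumption for illustration only).
enum class Key { CPU, CUDA, SparseCPU, SparseCUDA };
enum class Device { CPU, CUDA };
enum class Layout { Strided, Sparse };

// Extract the "where does the tensor live" half of a key.
Device key_to_device(Key k) {
  switch (k) {
    case Key::CPU:
    case Key::SparseCPU:
      return Device::CPU;
    case Key::CUDA:
    case Key::SparseCUDA:
      return Device::CUDA;
  }
  throw std::invalid_argument("unknown key");
}

// Extract the "how is the tensor stored" half of a key.
Layout key_to_layout(Key k) {
  switch (k) {
    case Key::CPU:
    case Key::CUDA:
      return Layout::Strided;
    case Key::SparseCPU:
    case Key::SparseCUDA:
      return Layout::Sparse;
  }
  throw std::invalid_argument("unknown key");
}

struct Options {
  Device device;
  Layout layout;
};

// Combine both halves into one construction-options value.
Options key_to_options(Key k) {
  return Options{key_to_device(k), key_to_layout(k)};
}
```

Once the key is split this way, later code can read or override the device and the layout independently instead of switching over every combined key.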