We have known for a while that we should in principle support SymBool as a separate concept from SymInt and SymFloat (in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., https://github.com/pytorch/pytorch/pull/90985 has made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors branches on the passed-in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal: we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor. However, because we store a bool (not a SymBool, which prior to this PR did not exist) on TensorImpl, we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not.)

This PR implements SymBool and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock-on effects, which I now discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or, and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (which would create a guard in this case). This does not matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool. The current implementation of the comparison functions unfortunately makes it easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of the equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term, making the equality operators return SymBool is probably the correct fix. (A sketch of the intended non-short-circuiting semantics appears after this description.)
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from https://github.com/pytorch/pytorch/pull/85936. In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field, the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By and large, the contiguity calculations in the implementations now need to be duplicated (as the boolean version can short-circuit, but the SymBool version cannot); you should carefully review the duplicated new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduced a new symbolic implementation that is branch-free (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special; see the next bullet.~~ This appears to cause performance problems, so I am leaving this to an update PR.
* (Update: the Python-side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we could still make it go through by using a data-oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically it will be obvious that a tensor is non-overlapping and dense, because the tensor is contiguous). So we take a different approach: instead of trying to trace through the logical computation of non-overlapping-and-dense, we introduce a new opaque operator, IsNonOverlappingAndDenseIndicator, which represents all of the compute that would have been done here. This function returns the integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and the integer 1 otherwise; it returns an integer rather than a boolean for technical reasons (Sympy does not easily allow defining custom functions that return booleans). The function itself only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations; more investigation is necessary. As a technical note, because this operator takes a pair of lists of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single list because I don't know if Sympy operations can validly take lists of Sympy expressions as inputs. See, for example, `_make_node_sizes_strides`.
* On the Python side, we also introduce a SymBool class and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for exactly this reason). Additionally, unlike C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, it doesn't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule is simple in C++: only use the bitwise-spelled logical operations if the types are statically known to be SymBool.) Alas, logical negation is not overridable, so we have to introduce `sym_not`, which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__`, which might imply that `operator.__not__` is acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators, `&` and `|`, do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and support fewer operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand`, which only calls expand on operations that are known to be expandable.

TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
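To make the non-short-circuiting SymBool semantics and the branch-free contiguity style described above concrete, here is a minimal, self-contained sketch. `ToySymBool` and `ToySymInt` are toy stand-ins invented for illustration (not the real `c10::SymBool`/`c10::SymInt`), and the real computation also special-cases `numel() == 0`; only the shape of the logic is meant to match the description above.

```cpp
#include <cstdint>
#include <vector>

// Toy stand-ins for illustration only; the real classes wrap a SymNode expression.
struct ToySymBool {
  bool v;
  // Bitwise spellings for the logical ops: both sides are always evaluated,
  // so no guard is introduced while the expression is being built.
  ToySymBool operator&(ToySymBool o) const { return {v && o.v}; }
  ToySymBool operator|(ToySymBool o) const { return {v || o.v}; }
  ToySymBool operator~() const { return {!v}; } // logical negation, not bit-twiddling
  bool guard_bool() const { return v; } // the only point at which a guard would appear
};

struct ToySymInt {
  int64_t v;
  ToySymBool sym_eq(ToySymInt o) const { return {v == o.v}; } // cf. SymInt::sym_eq
};

// Boolean version: early exits are fine because the answer is needed immediately.
bool compute_contiguous(const std::vector<int64_t>& sizes,
                        const std::vector<int64_t>& strides) {
  int64_t expected = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    if (sizes[i] == 1) continue;              // branch on a concrete value: OK
    if (strides[i] != expected) return false; // short-circuit
    expected *= sizes[i];
  }
  return true;
}

// Symbolic version: no data-dependent branches; every per-dimension condition is
// folded into a single SymBool with `&`, and nothing is guarded on here.
ToySymBool compute_contiguous_sym(const std::vector<ToySymInt>& sizes,
                                  const std::vector<ToySymInt>& strides) {
  ToySymBool is_contiguous{true};
  ToySymInt expected{1};
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    is_contiguous = is_contiguous &
        (sizes[i].sym_eq(ToySymInt{1}) | strides[i].sym_eq(expected));
    expected = ToySymInt{expected.v * sizes[i].v}; // multiplying by a size of 1 is a no-op
  }
  return is_contiguous; // forced to a concrete bool only if a caller later asks
}
```

The second function mirrors the TensorImpl change: the returned SymBool can sit in ExtraMeta unevaluated, and a guard appears only if someone later forces it to a concrete value (here via `guard_bool()`), analogous to calling `is_contiguous()`.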
417 lines
12 KiB
C++
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/python_tuples.h>

#include <torch/csrc/Export.h>

#include <algorithm>
#include <cstdarg>
#include <iterator>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

int THPUtils_getCallable(PyObject* arg, PyObject** result) {
  if (!PyCallable_Check(arg))
    return 0;
  *result = arg;
  return 1;
}

std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg) {
  bool tuple = PyTuple_Check(arg);
  bool list = PyList_Check(arg);
  if (tuple || list) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    const auto nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
    std::vector<int64_t> sizes(nDim);
    for (int i = 0; i != nDim; ++i) {
      PyObject* item =
          tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
      if (!THPUtils_checkLong(item)) {
        std::ostringstream oss;
        oss << "expected int at position " << i
            << ", but got: " << THPUtils_typename(item);
        throw std::runtime_error(oss.str());
      }
      sizes[i] = THPUtils_unpackLong(item);
    }
    return sizes;
  }
  throw std::runtime_error("Expected tuple or list");
}

bool THPUtils_checkIntTuple(PyObject* arg) {
  if (!PyTuple_Check(arg)) {
    return false;
  }
  for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
    if (!THPUtils_checkLong(PyTuple_GET_ITEM(arg, i))) {
      return false;
    }
  }
  return true;
}

std::vector<int> THPUtils_unpackIntTuple(PyObject* arg) {
  if (!THPUtils_checkIntTuple(arg)) {
    throw std::runtime_error("Couldn't unpack int tuple");
  }
  std::vector<int> values(PyTuple_GET_SIZE(arg));
  for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
    values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
  }
  return values;
}

void THPUtils_setError(const char* format, ...) {
  static const size_t ERROR_BUFFER_SIZE = 1000;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char buffer[ERROR_BUFFER_SIZE];
  va_list fmt_args;

  va_start(fmt_args, format);
  vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args);
  va_end(fmt_args);
  PyErr_SetString(PyExc_RuntimeError, buffer);
}

void THPUtils_addPyMethodDefs(
    std::vector<PyMethodDef>& vector,
    PyMethodDef* methods) {
  if (!vector.empty()) {
    // remove nullptr terminator
    vector.pop_back();
  }
  while (true) {
    vector.push_back(*methods);
    if (!methods->ml_name) {
      break;
    }
    methods++;
  }
}

static const char* classOrTypename(PyObject* obj) {
  if (PyType_Check(obj)) {
    return ((PyTypeObject*)obj)->tp_name;
  }
  return Py_TYPE(obj)->tp_name;
}

PyObject* THPUtils_dispatchStateless(
    PyObject* tensor,
    const char* name,
    PyObject* args,
    PyObject* kwargs) {
  THPObjectPtr methods(
      PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME));
  if (!methods) {
    return PyErr_Format(
        PyExc_TypeError,
        "Type %s doesn't implement stateless methods",
        classOrTypename(tensor));
  }
  THPObjectPtr method(PyObject_GetAttrString(methods, name));
  if (!method) {
    return PyErr_Format(
        PyExc_TypeError,
        "Type %s doesn't implement stateless method %s",
        classOrTypename(tensor),
        name);
  }
  return PyObject_Call(method.get(), args, kwargs);
}

void THPUtils_invalidArguments(
    PyObject* given_args,
    PyObject* given_kwargs,
    const char* function_name,
    size_t num_options,
    ...) {
  std::vector<std::string> option_strings;
  va_list option_list;
  va_start(option_list, num_options);
  std::generate_n(
      std::back_inserter(option_strings), num_options, [&option_list] {
        return va_arg(option_list, const char*);
      });
  va_end(option_list);

  PyErr_SetString(
      PyExc_TypeError,
      torch::format_invalid_args(
          given_args, given_kwargs, function_name, option_strings)
          .c_str());
}

template <>
void THPPointer<THPGenerator>::free() {
  if (ptr)
    Py_DECREF(ptr);
}

template class THPPointer<THPGenerator>;

static bool backCompatBroadcastWarn = false;

void setBackCompatBroadcastWarn(bool warn) {
  backCompatBroadcastWarn = warn;
}

bool getBackCompatBroadcastWarn() {
  return backCompatBroadcastWarn;
}

static bool backCompatKeepdimWarn = false;

void setBackCompatKeepdimWarn(bool warn) {
  backCompatKeepdimWarn = warn;
}

bool getBackCompatKeepdimWarn() {
  return backCompatKeepdimWarn;
}

bool maybeThrowBackCompatKeepdimWarn(char* func) {
  if (getBackCompatKeepdimWarn()) {
    std::ostringstream ss;
    ss << "backwards compatibility: call to \"" << func
       << "\" uses default value for keepdim which has changed default to False. Consider passing as kwarg.";
    PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1);
  }
  return true;
}

template <>
void THPPointer<THPStorage>::free() {
  if (ptr)
    Py_DECREF(ptr);
}

void storage_copy(at::Storage dst, at::Storage src, bool non_blocking) {
  auto dst_options = c10::TensorOptions().device(dst.device()).dtype(at::kByte);
  auto dst_t = at::empty({0}, {}, dst_options).set_(dst);

  auto src_options = c10::TensorOptions().device(src.device()).dtype(at::kByte);
  auto src_t = at::empty({0}, {}, src_options).set_(src);
  dst_t.copy_(src_t, non_blocking);
}

void storage_fill(at::Storage self, uint8_t value) {
  auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
  auto self_t = at::empty({0}, {}, options).set_(self);
  self_t.fill_(value);
}

void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value) {
  TORCH_CHECK(
      (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())),
      "out of bounds");
  auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
  auto self_t = at::empty({0}, {}, options).set_(self);
  self_t[idx].fill_(value);
}

uint8_t storage_get(at::Storage self, ptrdiff_t idx) {
  TORCH_CHECK(
      (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())),
      "out of bounds");
  auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
  auto self_t = at::empty({0}, {}, options).set_(self);
  return self_t[idx].item<uint8_t>();
}

template class THPPointer<THPStorage>;

namespace torch {
namespace gdb {
/* ~~~ misc debugging utilities ~~~
 *
 * torch::gdb::* functions are NOT meant to be called by general pytorch code,
 * but only from within a gdb session. As such, utils.h does not contain any
 * declaration for those.
 */

// This is a helper needed by the torch-tensor-repr gdb command.
// Return a human-readable representation of the given Tensor. The resulting
// string is stored into a malloc()ed buffer. The caller is responsible to
// free() it. We use malloc() instead of new[] because it's much easier to
// call free than delete[] from within gdb.
// Currently the code for computing the repr of a tensor is written in Python,
// so we need to wrap the Tensor into a Python object first.
char* tensor_repr(at::Tensor tensor) {
  PyGILState_STATE gil = PyGILState_Ensure();
  PyObject* pytensor = nullptr;
  PyObject* repr = nullptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  Py_ssize_t bufsize;
  const char* buf = nullptr;
  char* result = nullptr;

  pytensor = THPVariable_Wrap(at::Tensor(tensor));
  if (!pytensor)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  repr = PyObject_Repr(pytensor);
  if (!repr)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  buf = PyUnicode_AsUTF8AndSize(repr, &bufsize);
  if (!buf)
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  result =
      static_cast<char*>(malloc(bufsize + 1)); // account for the trailing \0
  if (!result) {
    fprintf(stderr, "cannot allocate memory for the result\n");
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    goto error;
  }
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.strcpy)
  strcpy(result, buf);
  Py_XDECREF(pytensor);
  Py_XDECREF(repr);
  PyGILState_Release(gil);
  return result;

error:
  fprintf(stderr, "torch::gdb::tensor_repr: unexpected error\n");
  if (PyErr_Occurred())
    PyErr_Print();
  Py_XDECREF(pytensor);
  Py_XDECREF(repr);
  // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
  free(result);
  PyGILState_Release(gil);
  return nullptr;
}
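
// Example usage (sketch): the torch-tensor-repr gdb command is expected to
// consume the helper above roughly as follows. The buffer is malloc()ed, so the
// caller must free() it, and a nullptr return means an error was already printed.
//
//   char* s = torch::gdb::tensor_repr(t);
//   if (s) {
//     puts(s);
//     free(s);
//   }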

} // namespace gdb
} // namespace torch

namespace pybind11 {
namespace detail {

bool type_caster<at::Tensor>::load(handle src, bool) {
  PyObject* obj = src.ptr();
  if (THPVariable_Check(obj)) {
    value = THPVariable_Unpack(obj);
    return true;
  }
  return false;
}

handle type_caster<at::Tensor>::cast(
    const at::Tensor& src,
    return_value_policy /* policy */,
    handle /* parent */) {
  return handle(THPVariable_Wrap(src));
}

bool type_caster<at::IntArrayRef>::load(handle src, bool) {
  PyObject* source = src.ptr();
  auto tuple = PyTuple_Check(source);
  if (tuple || PyList_Check(source)) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    const auto size =
        tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
    v_value.resize(size);
    for (const auto idx : c10::irange(size)) {
      PyObject* obj =
          tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
      if (THPVariable_Check(obj)) {
        v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
      } else if (PyLong_Check(obj)) {
        // use THPUtils_unpackLong after it is safe to include
        // python_numbers.h
        v_value[idx] = THPUtils_unpackLong(obj);
      } else {
        return false;
      }
    }
    value = v_value;
    return true;
  }
  return false;
}
handle type_caster<at::IntArrayRef>::cast(
    at::IntArrayRef src,
    return_value_policy /* policy */,
    handle /* parent */) {
  return handle(THPUtils_packInt64Array(src.size(), src.data()));
}

bool type_caster<at::SymIntArrayRef>::load(handle src, bool) {
  PyObject* source = src.ptr();

  auto tuple = PyTuple_Check(source);
  if (tuple || PyList_Check(source)) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    const auto size =
        tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
    v_value.resize(size);
    for (const auto idx : c10::irange(size)) {
      PyObject* obj =
          tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);

      if (THPVariable_Check(obj)) {
        // TODO: this is for consistency with IntArrayRef but arguably
        // we shouldn't really allow this on pybind11 casters
        v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
      } else if (torch::is_symint(py::handle(obj))) {
        v_value[idx] = py::handle(obj).cast<c10::SymInt>();
      } else if (PyLong_Check(obj)) {
        v_value[idx] = c10::SymInt(THPUtils_unpackIndex(obj));
      } else {
        return false;
      }
    }
    value = v_value;
    return true;
  }
  return false;
}
handle type_caster<at::SymIntArrayRef>::cast(
    at::SymIntArrayRef src,
    return_value_policy /* policy */,
    handle /* parent */) {
  py::list t(src.size());
  for (const auto i : c10::irange(src.size())) {
    t[i] = py::cast(src[i]);
  }
  return t.release();
}

bool type_caster<at::ArrayRef<c10::SymNode>>::load(handle src, bool) {
  TORCH_INTERNAL_ASSERT(0, "NYI");
}
handle type_caster<at::ArrayRef<c10::SymNode>>::cast(
    at::ArrayRef<c10::SymNode> src,
    return_value_policy /* policy */,
    handle /* parent */) {
  py::list t(src.size());
  for (const auto i : c10::irange(src.size())) {
    // TODO: this is terrible but I don't know how to override when
    // the SymNode is also explicitly cast by py::cast
    auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(src[i].get());
    if (py_node) {
      // Return the Python object directly (unwrap)
      t[i] = py_node->getPyObj();
    } else {
      t[i] = py::cast(src[i]);
    }
  }
  return t.release();
}

} // namespace detail
} // namespace pybind11