#define PY_SSIZE_T_CLEAN
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/extension.h>
#include <sstream>

namespace {

struct LocalState {
  // TLS state that changes operators
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  bool grad_mode_enabled;

  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
    return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
  }

  LocalState()
      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
        grad_mode_enabled(at::GradMode::is_enabled()) {}
};
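
// The apply() above folds the thread-local included/excluded dispatch keys
// into a tensor's static key set, so the dispatch_key_ recorded by a guard
// reflects the effective dispatch state at compile time (for example, keys
// added by an active mode), not only the keys stored on the tensor itself.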

class TensorCheck {
 public:
  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
      : pytype(pt),
        dispatch_key_(state.apply(v.key_set()).raw_repr()),
        dtype_(v.dtype().toScalarType()),
        device_index_(v.device().index()),
        requires_grad_(v.requires_grad()),
        sizes_(std::move(dynamic_dims_sizes)),
        strides_(std::move(dynamic_dims_strides)),
        dim_(static_cast<int64_t>(sizes_.size())) {
    // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
    // we just treat this as optional?
  }

  // See note in guards.py [Note - On Export Tensor Guards]
  // Logic parallel to here must be maintained in python
  bool check(const LocalState& state, const at::Tensor& v) {
    if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
        dtype_ != v.dtype().toScalarType() ||
        device_index_ != v.device().index() ||
        requires_grad_ != v.requires_grad()) {
      return false;
    }
    auto ndim = v.ndimension();
    if (ndim != dim_) {
      return false;
    }
    const auto& sizes = v.sym_sizes();
    const auto& strides = v.sym_strides();
    for (auto i : c10::irange(ndim)) {
      auto known_size = sizes_[i];
      auto known_stride = strides_[i];
      if (known_size.has_value()) {
        if (known_size.value() != sizes[i]) {
          return false;
        }
      }
      if (known_stride.has_value()) {
        if (known_stride.value() != strides[i]) {
          return false;
        }
      }
    }
    return true;
  }

  std::string check_verbose(
      const LocalState& state,
      const at::Tensor& v,
      const std::string& tensor_name) {
    std::stringstream fail_reason;
    fail_reason << "tensor '" << tensor_name << "' ";
    if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
      // return fmt::format("tensor dispatch key mismatch. expected {}, actual
      // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
      fail_reason << "dispatch key set mismatch. expected "
                  << c10::DispatchKeySet(
                         c10::DispatchKeySet::RAW, dispatch_key_)
                  << ", actual " << state.apply(v.key_set());
      return fail_reason.str();
    } else if (dtype_ != v.dtype().toScalarType()) {
      // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
      // dtype_, v.dtype().toScalarType());
      fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
                  << v.dtype().toScalarType();
      return fail_reason.str();
    } else if (device_index_ != v.device().index()) {
      fail_reason
          << "Tensor device index mismatch. Expected device index to be "
          << device_index_ << ", actual " << v.device().index();
      return fail_reason.str();
    } else if (requires_grad_ != v.requires_grad()) {
      // return fmt::format("tensor requires_grad mismatch. expected {}",
      // requires_grad_);
      fail_reason << "requires_grad mismatch. expected requires_grad="
                  << requires_grad_;
      return fail_reason.str();
    }
    auto ndim = v.ndimension();
    if (ndim != dim_) {
      // return fmt::format("tensor rank mismatch. expected {}, actual {}",
      // sizes_.size(), ndim);
      fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
                  << ndim;
      return fail_reason.str();
    }
    const auto& sizes = v.sym_sizes();
    const auto& strides = v.sym_strides();
    for (auto i : c10::irange(ndim)) {
      auto known_size = sizes_[i];
      auto known_stride = strides_[i];
      if (known_size.has_value() && (known_size.value() != sizes[i])) {
        fail_reason << "size mismatch at index " << i << ". expected "
                    << known_size.value() << ", actual " << sizes[i];
        return fail_reason.str();
      }
      if (known_stride.has_value() && known_stride.value() != strides[i]) {
        fail_reason << "stride mismatch at index " << i << ". expected "
                    << known_stride.value() << ", actual " << strides[i];
        return fail_reason.str();
      }
    }
    return "";
  }

  PyTypeObject* pytype;

 private:
  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
  at::ScalarType dtype_;
  // Note(voz): While dispatch_key_ is sufficiently representative of a device
  // (keys are more granular and device specific), it does not necessarily
  // capture device indices correctly, hence the separate field below.
  at::DeviceIndex device_index_;
  bool requires_grad_;
  // NB: These entries are unset (std::nullopt) if dynamic shapes is enabled.
  std::vector<std::optional<c10::SymInt>> sizes_;
  std::vector<std::optional<c10::SymInt>> strides_;
  // Not strictly required for dense tensors, but nested tensors need it.
  int64_t dim_;
};
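
// Illustrative example of the sizes_/strides_ encoding above (values made up
// for illustration): a std::nullopt entry marks a dimension treated as
// dynamic, so check()/check_verbose() skip it. A guard built from a tensor of
// shape (8, 16) with dim 0 dynamic stores sizes_ == {nullopt, 16}; a later
// input of shape (32, 16) with matching dtype, device index, requires_grad
// and dispatch keys still passes.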

typedef std::vector<TensorCheck> ChecksList;

typedef struct {
  PyObject_HEAD;
  ChecksList* checks;
} TensorGuards;

static void TensorGuards_dealloc(TensorGuards* self) {
  if (self->checks != nullptr) {
    delete self->checks;
    self->checks = nullptr;
  }
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyObject* TensorGuards_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
  if (self != nullptr) {
    self->checks = new ChecksList();
  }
  return (PyObject*)self;
}

static std::vector<std::optional<c10::SymInt>> wrapIntegersInOptional(
    const c10::SymIntArrayRef& intArray) {
  std::vector<std::optional<c10::SymInt>> optVec(intArray.size());
  std::transform(
      intArray.begin(),
      intArray.end(),
      optVec.begin(),
      [](const c10::SymInt& value) { return std::make_optional(value); });
  return optVec;
}

static std::vector<std::optional<c10::SymInt>> pyListToVecOptInt(
    PyObject* pyList) {
  std::vector<std::optional<c10::SymInt>> vec;
  Py_ssize_t size = PyList_Size(pyList);
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* item = PyList_GetItem(pyList, i);
    auto handle = py::handle(item);
    if (item == Py_None) {
      vec.emplace_back(std::nullopt);
    } else if (torch::is_symint(handle)) {
      vec.emplace_back(py::cast<c10::SymInt>(handle));
    } else {
      int64_t value = PyLong_AsLongLong(item);
      if (value == -1 && PyErr_Occurred()) {
        PyErr_SetString(
            PyExc_TypeError,
            "Size or stride list item is not a valid integer.");
        TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
      }
      vec.emplace_back(c10::SymInt(value));
    }
  }
  return vec;
}

static std::vector<std::vector<std::optional<c10::SymInt>>> get_dynamic_dims(
    PyObject* dynamic_dims_py) {
  std::vector<std::vector<std::optional<c10::SymInt>>> per_tensor_dynamic_dims;
  if (dynamic_dims_py != Py_None) {
    Py_ssize_t size = PyList_Size(dynamic_dims_py);
    for (Py_ssize_t i = 0; i < size; i++) {
      PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
      std::vector<std::optional<c10::SymInt>> vec = pyListToVecOptInt(py_list);
      per_tensor_dynamic_dims.push_back(std::move(vec));
    }
  }
  return per_tensor_dynamic_dims;
}

static int TensorGuards_init(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwds) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return -1;
  }
  // Top level structure is List[List[Union[int, None]]]
  // Guard against kwds being absent entirely before doing dict lookups.
  PyObject* dynamic_dims_sizes_py =
      kwds ? PyDict_GetItemString(kwds, "dynamic_dims_sizes") : nullptr;
  if (dynamic_dims_sizes_py == nullptr) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
    return -1;
  }
  PyObject* dynamic_dims_strides_py =
      PyDict_GetItemString(kwds, "dynamic_dims_strides");
  if (dynamic_dims_strides_py == nullptr) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
    return -1;
  }

  // dynamic_dims_sizes_py/strides_py is None when dynamic_shapes=False - this
  // is an optimization to avoid invoking .size()/.stride() in python needlessly
  std::vector<std::vector<std::optional<c10::SymInt>>>
      per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
  std::vector<std::vector<std::optional<c10::SymInt>>>
      per_tensor_dynamic_dims_strides =
          get_dynamic_dims(dynamic_dims_strides_py);

  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);
  checks.reserve(len);
  LocalState state;

  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
      PyErr_SetString(PyExc_TypeError, "expected Tensor()");
      return -1;
    }
    auto tensor = THPVariable_Unpack(item);
    std::vector<std::optional<c10::SymInt>> tensor_dims_size =
        per_tensor_dynamic_dims_sizes.empty()
        ? wrapIntegersInOptional(tensor.sym_sizes())
        : per_tensor_dynamic_dims_sizes[i];
    std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
        per_tensor_dynamic_dims_strides.empty()
        ? wrapIntegersInOptional(tensor.sym_strides())
        : per_tensor_dynamic_dims_strides[i];

    checks.emplace_back(
        state,
        Py_TYPE(item),
        std::move(tensor),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }
  return 0;
}
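
// A rough sketch of how this initializer is driven from Python (variable
// names here are hypothetical; the real call sites live in Dynamo's guard
// builder):
//
//   guards = TensorGuards(
//       x, y,
//       dynamic_dims_sizes=[[None, 16], [4, 4]],  # None marks a dynamic dim
//       dynamic_dims_strides=None)  # None: record the tensors' own strides
//
// Both keyword arguments are required; passing None for either one makes the
// guard record the corresponding fully-static sizes or strides of each tensor.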

PyObject* TensorGuards_check(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return nullptr;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);

  // kwargs is just ignored here

  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return nullptr;
  }

  LocalState state;
  // Note - all the tensors that make it to guards must be unique. Dynamo
  // builder handles guarding for positive aliases (X is Y). However, we do not
  // create guards for negative alias (X is not Y) as that is an N^2
  // relationship. Instead, we rely on the uniqueness upstream to verify, at
  // check_fn time (this function).
  ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);

    if (Py_TYPE(item) != checks[i].pytype) {
      Py_RETURN_FALSE;
    }
    auto insertion = unique_tensors.insert({item, nullptr});
    if (!insertion.second) {
      // Violates uniqueness
      Py_RETURN_FALSE;
    }
    if (!checks[i].check(state, THPVariable_Unpack(item))) {
      Py_RETURN_FALSE;
    }
  }

  Py_RETURN_TRUE;
}
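
// Illustrative usage (hypothetical names): given the `guards` object sketched
// above, `guards.check(x2, y2)` returns True only if every positional tensor
// still has the recorded Python type, dispatch key set, dtype, device index,
// requires_grad flag and non-dynamic sizes/strides, and no two arguments are
// the same tensor object.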

PyObject* TensorGuards_check_verbose(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return nullptr;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);

  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return nullptr;
  }

  PyObject* tensor_check_names_py =
      kwargs ? PyDict_GetItemString(kwargs, "tensor_check_names") : nullptr;
  if (tensor_check_names_py == nullptr) {
    PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
    return nullptr;
  }

  if (!PyList_Check(tensor_check_names_py)) {
    PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
    return nullptr;
  }

  auto names_size = PyList_Size(tensor_check_names_py);
  if (names_size != static_cast<decltype(names_size)>(checks.size())) {
    PyErr_SetString(
        PyExc_TypeError,
        "tensor_check_names should be the same size as # tensors");
    return nullptr;
  }

  std::vector<std::string> tensor_check_names;
  tensor_check_names.reserve(names_size);
  for (auto i : c10::irange(names_size)) {
    PyObject* value = PyList_GetItem(tensor_check_names_py, i);
    if (!PyUnicode_Check(value)) {
      PyErr_SetString(
          PyExc_TypeError, "tensor_check_names must only contain strings");
      return nullptr;
    }
    tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
  }

  LocalState state;
  ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (Py_TYPE(item) != checks[i].pytype) {
      std::stringstream fail_reason;
      PyObject* itemtype = PyObject_Type(item);
      PyObject* type_str = itemtype ? PyObject_Str(itemtype) : nullptr;
      fail_reason << "expected type of '" << tensor_check_names[i]
                  << "' to be a tensor type, ";
      if (!type_str) {
        fail_reason << "but found a different type";
      } else {
        fail_reason << "but found " << PyUnicode_AsUTF8(type_str);
      }
      // Release the temporaries created for the error message.
      Py_XDECREF(type_str);
      Py_XDECREF(itemtype);
      return Py_BuildValue("s", fail_reason.str().c_str());
    }

    auto insertion = unique_tensors.insert({item, nullptr});
    if (!insertion.second) {
      std::stringstream fail_reason;
      fail_reason << "Duplicate tensor found where not expected! ";
      fail_reason << tensor_check_names[i]
                  << " should not alias to anything, but is aliased";
      return Py_BuildValue("s", fail_reason.str().c_str());
    }
    std::string fail_reason = checks[i].check_verbose(
        state, THPVariable_Unpack(item), tensor_check_names[i]);
    if (fail_reason.length() > 0) {
      return Py_BuildValue("s", fail_reason.c_str());
    }
  }

  Py_RETURN_TRUE;
}
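
// Illustrative usage (hypothetical names): on a failing guard,
// `guards.check_verbose(x2, y2, tensor_check_names=["x", "y"])` returns a
// message such as "tensor 'x' size mismatch at index 0. expected 8, actual 4"
// rather than False, which callers can surface in recompilation diagnostics.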

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef TensorGuards_methods[] = {
    {"check",
     (PyCFunction)(void*)TensorGuards_check,
     METH_VARARGS | METH_KEYWORDS,
     ""},
    {"check_verbose",
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {nullptr} /* Sentinel */
};

static PyTypeObject TensorGuardsType = {PyVarObject_HEAD_INIT(nullptr, 0)};

struct GlobalStateGuard {
  PyObject_HEAD;

  inline void init() {
    auto& ctx = at::globalContext();
    _grad_mode = at::GradMode::is_enabled();
    _torch_function = torch::torch_function_enabled();
    _deterministic_algorithms = ctx.deterministicAlgorithms();
    _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
    _allow_tf32 = ctx.allowTF32CuBLAS();
    _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
    _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
    _num_threads = at::get_num_threads();
    _default_dtype = at::get_default_dtype();
  }

  inline bool check() {
    auto& ctx = at::globalContext();
    return _grad_mode == at::GradMode::is_enabled() &&
        _torch_function == torch::torch_function_enabled() &&
        _deterministic_algorithms == ctx.deterministicAlgorithms() &&
        _deterministic_algorithms_warn_only ==
            ctx.deterministicAlgorithmsWarnOnly() &&
        _allow_tf32 == ctx.allowTF32CuBLAS() &&
        _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
        _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
        _num_threads == at::get_num_threads() &&
        _default_dtype == at::get_default_dtype();
  }

  bool _grad_mode;
  bool _torch_function;
  bool _deterministic_algorithms;
  bool _deterministic_algorithms_warn_only;
  bool _allow_tf32;
  bool _allow_fp16_reduce;
  bool _allow_bf16_reduce;
  int _num_threads;
  caffe2::TypeMeta _default_dtype;
  // TODO(jansel): we should guard on more state as inductor starts using it
};
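
// Illustrative usage (hypothetical name): `g = GlobalStateGuard()` snapshots
// the flags above at compile time, and `g.check()` later returns True only if
// none of them (grad mode, torch_function, deterministic algorithms, cuBLAS
// TF32/FP16/BF16 settings, thread count, default dtype) have changed since.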

int GlobalStateGuard_init(
    GlobalStateGuard* self,
    PyObject* args,
    PyObject* kwargs) {
  self->init();
  return 0;
}

PyObject* GlobalStateGuard_check(
    GlobalStateGuard* self,
    PyObject* args,
    PyObject* kwargs) {
  if (self->check()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

static PyMethodDef GlobalStateGuard_methods[] = {
    {"check",
     (PyCFunction)(void*)GlobalStateGuard_check,
     METH_NOARGS,
     "Return true if global state was the same as at creation time"},
    {nullptr}};

static PyTypeObject GlobalStateGuardType = {PyVarObject_HEAD_INIT(nullptr, 0)};

static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
  // faster `lambda obj, expected: id(type(obj)) == expected`
  PyObject* obj = nullptr;
  unsigned long long expected = 0;
  if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
    return nullptr;
  }
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  if (Py_TYPE(obj) == (void*)expected) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
  // faster `lambda obj, expected: id(obj) == expected`
  PyObject* obj = nullptr;
  unsigned long long expected = 0;
  if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
    return nullptr;
  }
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  if (obj == (void*)expected) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}
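
// Illustrative usage (hypothetical names): `check_type_id(obj, expected)` is
// a faster equivalent of `id(type(obj)) == expected`, and
// `check_obj_id(obj, expected)` of `id(obj) == expected`; the expected id is
// typically baked into the generated guard expression as a constant.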

static PyObject* dict_version(PyObject* dummy, PyObject* args) {
  // Retrieves the version of a dictionary.
  PyObject* obj = nullptr;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return nullptr;
  }
  if (!PyDict_Check(obj)) {
    PyErr_SetString(PyExc_TypeError, "expected dict");
    return nullptr;
  }
  return THPUtils_packUInt64(((PyDictObject*)obj)->ma_version_tag);
}
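
// The version tag (PEP 509's ma_version_tag) changes whenever a dict is
// mutated, so comparing a stored dict_version() result against the current
// one detects whether a guarded dict was modified since compile time. Note
// that ma_version_tag is a CPython implementation detail.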

static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
  /*
  Assert that a given tensor has a given size/stride, but ignore strides
  of size==1 dimensions. Implemented in C++ as this is on the hot path.
  */
  PyObject* item = nullptr;
  PyObject* size = nullptr;
  PyObject* stride = nullptr;
  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
    return nullptr;
  }
  if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
    return nullptr;
  }
  if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return nullptr;
  }
  at::Tensor tensor = THPVariable_Unpack(item);
  int64_t ndim = tensor.ndimension();
  if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
    return nullptr;
  }
  for (auto i : c10::irange(ndim)) {
    int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
    int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
    int64_t actual_size = tensor.size(i);
    int64_t actual_stride = tensor.stride(i);
    if (want_size != actual_size ||
        // ignore stride differences when size is 1
        (want_stride != actual_stride && actual_size > 1)) {
      std::stringstream msg;
      msg << "expected size " << actual_size << "==" << want_size << ", stride "
          << actual_stride << "==" << want_stride << " at dim=" << i;
      PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
      return nullptr;
    }
  }
  Py_RETURN_TRUE;
}
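
// Illustrative usage (hypothetical values):
//   assert_size_stride(t, (2, 3), (3, 1))
// raises AssertionError if t's sizes or strides disagree with the given
// tuples, except that stride mismatches on dimensions of size 1 (or 0) are
// ignored.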

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef _methods[] = {
    {"check_type_id", check_type_id, METH_VARARGS, nullptr},
    {"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
    {"dict_version", dict_version, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.guards",
    "Module containing checks on tensors",
    -1,
    _methods};

} // namespace

PyObject* torch_c_dynamo_guards_init() {
  // initialize TensorGuardsType
  TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
  TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
  TensorGuardsType.tp_itemsize = 0;
  TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
  TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
  TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
  TensorGuardsType.tp_methods = TensorGuards_methods;
  TensorGuardsType.tp_init = (initproc)TensorGuards_init;
  TensorGuardsType.tp_new = TensorGuards_new;

  if (PyType_Ready(&TensorGuardsType) < 0)
    return nullptr;

  GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
  GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
  GlobalStateGuardType.tp_itemsize = 0;
  GlobalStateGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
  GlobalStateGuardType.tp_doc = "Guard on PyTorch global flags such as no_grad";
  GlobalStateGuardType.tp_methods = GlobalStateGuard_methods;
  GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init;
  GlobalStateGuardType.tp_new = PyType_GenericNew;

  if (PyType_Ready(&GlobalStateGuardType) < 0)
    return nullptr;

  auto m = PyModule_Create(&_module);
  if (m == nullptr)
    return nullptr;

  Py_INCREF(&TensorGuardsType);
  if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
    Py_DECREF(&TensorGuardsType);
    Py_DECREF(m);
    return nullptr;
  }

  Py_INCREF(&GlobalStateGuardType);
  if (PyModule_AddObject(
          m, "GlobalStateGuard", (PyObject*)&GlobalStateGuardType) < 0) {
    Py_DECREF(&GlobalStateGuardType);
    Py_DECREF(m);
    return nullptr;
  }

  return m;
}