mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replace AT_CHECK with TORCH_CHECK [shard 9/10]
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20435 Reviewed By: jerryzh168 Differential Revision: D15318877 fbshipit-source-id: 4d83571187ea14a604fef83ac355d328b46d93e1
This commit is contained in:
committed by
Facebook Github Bot
parent
365fc26571
commit
73a97387c1
@ -630,7 +630,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
template <typename T>
|
||||
inline T * data() const {
|
||||
AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
|
||||
AT_CHECK(has_storage(),
|
||||
TORCH_CHECK(has_storage(),
|
||||
"Cannot access data pointer of Tensor that doesn't have storage");
|
||||
AT_ASSERTM(
|
||||
storage_initialized(),
|
||||
@ -663,7 +663,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
*/
|
||||
inline void* data() const {
|
||||
AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
|
||||
AT_CHECK(has_storage(),
|
||||
TORCH_CHECK(has_storage(),
|
||||
"Cannot access data pointer of Tensor that doesn't have storage");
|
||||
AT_ASSERT(dtype_initialized());
|
||||
return static_cast<void*>(
|
||||
@ -740,7 +740,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* which is harder to misuse.
|
||||
*/
|
||||
virtual void resize_dim(int64_t ndim) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "resize_dim is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "resize_dim is not allowed on Tensor created from .data or .detach()");
|
||||
sizes_.resize(ndim, 0);
|
||||
strides_.resize(ndim, 0);
|
||||
refresh_numel();
|
||||
@ -756,7 +756,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* which is harder to misuse.
|
||||
*/
|
||||
virtual void set_size(int64_t dim, int64_t new_size) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "set_size is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "set_size is not allowed on Tensor created from .data or .detach()");
|
||||
sizes_.at(dim) = new_size;
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
@ -769,7 +769,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* which is harder to misuse.
|
||||
*/
|
||||
virtual void set_stride(int64_t dim, int64_t new_stride) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "set_stride is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "set_stride is not allowed on Tensor created from .data or .detach()");
|
||||
strides_[dim] = new_stride;
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
@ -783,7 +783,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* (and resizing if necessary.)
|
||||
*/
|
||||
virtual void set_storage_offset(int64_t storage_offset) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "set_storage_offset is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage_offset is not allowed on Tensor created from .data or .detach()");
|
||||
storage_offset_ = storage_offset;
|
||||
}
|
||||
|
||||
@ -798,7 +798,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* See Note [We regret making Variable hold a Tensor]
|
||||
*/
|
||||
void set_sizes_contiguous(IntArrayRef new_size) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()");
|
||||
AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
|
||||
auto old_dim = sizes_.size();
|
||||
auto new_dim = new_size.size();
|
||||
@ -823,9 +823,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* See Note [We regret making Variable hold a Tensor]
|
||||
*/
|
||||
void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides is not allowed on Tensor created from .data or .detach()");
|
||||
AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
new_size.size() == new_stride.size(),
|
||||
"dimensionality of sizes (",
|
||||
new_size.size(),
|
||||
@ -1342,7 +1342,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
}
|
||||
|
||||
void set_storage(at::Storage storage) {
|
||||
AT_CHECK(allow_tensor_metadata_change(), "set_storage is not allowed on Tensor created from .data or .detach()");
|
||||
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage is not allowed on Tensor created from .data or .detach()");
|
||||
storage_ = std::move(storage);
|
||||
data_type_ = storage_.dtype();
|
||||
device_opt_ = storage_.device();
|
||||
|
@ -68,7 +68,7 @@ public:
|
||||
/// Construct a CUDAStream from a Stream. This construction is checked,
|
||||
/// and will raise an error if the Stream is not, in fact, a CUDA stream.
|
||||
explicit CUDAStream(Stream stream) : stream_(stream) {
|
||||
AT_CHECK(stream_.device_type() == DeviceType::CUDA);
|
||||
TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
|
||||
}
|
||||
|
||||
/// Construct a CUDAStream from a Stream with no error checking.
|
||||
|
@ -145,13 +145,13 @@ class ArrayRef final {
|
||||
|
||||
/// front - Get the first element.
|
||||
AT_CPP14_CONSTEXPR const T& front() const {
|
||||
AT_CHECK(!empty(), "ArrayRef: attempted to access front() of empty list");
|
||||
TORCH_CHECK(!empty(), "ArrayRef: attempted to access front() of empty list");
|
||||
return Data[0];
|
||||
}
|
||||
|
||||
/// back - Get the last element.
|
||||
AT_CPP14_CONSTEXPR const T& back() const {
|
||||
AT_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
|
||||
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
|
||||
return Data[Length - 1];
|
||||
}
|
||||
|
||||
@ -163,7 +163,7 @@ class ArrayRef final {
|
||||
/// slice(n, m) - Chop off the first N elements of the array, and keep M
|
||||
/// elements in the array.
|
||||
AT_CPP14_CONSTEXPR ArrayRef<T> slice(size_t N, size_t M) const {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
N + M <= size(),
|
||||
"ArrayRef: invalid slice, N = ",
|
||||
N,
|
||||
@ -188,7 +188,7 @@ class ArrayRef final {
|
||||
|
||||
/// Vector compatibility
|
||||
AT_CPP14_CONSTEXPR const T& at(size_t Index) const {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
Index < Length,
|
||||
"ArrayRef: invalid index Index = ",
|
||||
Index,
|
||||
|
@ -26,8 +26,8 @@ std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
|
||||
}
|
||||
|
||||
size_t ReplaceAll(std::string& s, const char* from, const char* to) {
|
||||
AT_CHECK(from && *from, "");
|
||||
AT_CHECK(to, "");
|
||||
TORCH_CHECK(from && *from, "");
|
||||
TORCH_CHECK(to, "");
|
||||
|
||||
size_t numReplaced = 0;
|
||||
std::string::size_type lenFrom = std::strlen(from);
|
||||
|
@ -27,7 +27,7 @@ void NUMABind(int numa_node_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
numa_node_id <= numa_max_node(),
|
||||
"NUMA node id ",
|
||||
numa_node_id,
|
||||
@ -46,7 +46,7 @@ int GetNUMANode(const void* ptr) {
|
||||
AT_ASSERT(ptr);
|
||||
|
||||
int numa_node = -1;
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
get_mempolicy(
|
||||
&numa_node,
|
||||
NULL,
|
||||
@ -83,7 +83,7 @@ void NUMAMove(void* ptr, size_t size, int numa_node_id) {
|
||||
numa_node_id >= 0 &&
|
||||
static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
|
||||
unsigned long mask = 1UL << numa_node_id;
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
mbind(
|
||||
reinterpret_cast<void*>(page_start_ptr),
|
||||
size + offset,
|
||||
|
@ -182,7 +182,7 @@ class C10OperatorWrapper final : public Operator<Context> {
|
||||
if (default_value.has_value()) {
|
||||
return this->template GetSingleArgument<T>(name, default_value->to<T>());
|
||||
} else {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
this->template HasSingleArgumentOfType<T>(name),
|
||||
"Error in caffe2->c10 wrapper: Expected argument '",
|
||||
name,
|
||||
|
@ -422,7 +422,7 @@ void addObjectMethods(py::module& m) {
|
||||
.def("_wrap_tensor_impl", [](Blob* blob, void* ptr) {
|
||||
auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
|
||||
unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
|
||||
AT_CHECK(p.defined(), "Can't wrap undefined tensor");
|
||||
TORCH_CHECK(p.defined(), "Can't wrap undefined tensor");
|
||||
auto at_tensor = at::Tensor::wrap_tensor_impl(std::move(p));
|
||||
BlobSetTensor(blob, Tensor(std::move(at_tensor)));
|
||||
});
|
||||
|
@ -197,7 +197,7 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
|
||||
std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
|
||||
return node->kind() == prim::FusionGroup;
|
||||
});
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
maybe_fusion_group != nodes.end(),
|
||||
"testRegisterFusionCachesKernel: could not create FusionGroup");
|
||||
return *maybe_fusion_group;
|
||||
|
@ -642,22 +642,22 @@ void checkTracedInputs(const TracedTestInputs& inputs) {
|
||||
const auto& sizes = std::get<1>(input);
|
||||
if (fn == "test") {
|
||||
found_test = true;
|
||||
AT_CHECK(sizes.size() == 1);
|
||||
AT_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
TORCH_CHECK(sizes.size() == 1);
|
||||
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
} else if (fn == "test::pow") {
|
||||
found_pow = true;
|
||||
AT_CHECK(sizes.size() == 2);
|
||||
AT_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
AT_CHECK(sizes[1].empty());
|
||||
TORCH_CHECK(sizes.size() == 2);
|
||||
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
TORCH_CHECK(sizes[1].empty());
|
||||
} else if (fn.find("::mul") != std::string::npos) {
|
||||
found_mul = true;
|
||||
AT_CHECK(sizes.size() > 1);
|
||||
AT_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
TORCH_CHECK(sizes.size() > 1);
|
||||
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
}
|
||||
}
|
||||
AT_CHECK(found_test);
|
||||
AT_CHECK(found_pow);
|
||||
AT_CHECK(found_mul);
|
||||
TORCH_CHECK(found_test);
|
||||
TORCH_CHECK(found_pow);
|
||||
TORCH_CHECK(found_mul);
|
||||
}
|
||||
|
||||
std::string getFullName(const autograd::profiler::RecordFunction* fn_ptr) {
|
||||
@ -736,7 +736,7 @@ void testAutogradProfiler() {
|
||||
for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
|
||||
count++, pos++) {
|
||||
}
|
||||
AT_CHECK(count == 200);
|
||||
TORCH_CHECK(count == 200);
|
||||
}
|
||||
|
||||
void testNoneSchemaMatch() {
|
||||
|
@ -40,7 +40,7 @@ struct ComplexCPUType : public at::CPUTypeDefault {
|
||||
AT_ASSERT(options.device().is_cpu());
|
||||
|
||||
for (auto x: size) {
|
||||
AT_CHECK(x >= 0, "Trying to create tensor using size with negative dimension: ", size);
|
||||
TORCH_CHECK(x >= 0, "Trying to create tensor using size with negative dimension: ", size);
|
||||
}
|
||||
auto* allocator = at::getCPUAllocator();
|
||||
int64_t nelements = at::prod_intlist(size);
|
||||
|
@ -6,8 +6,8 @@
|
||||
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
|
||||
|
||||
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||
AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
|
||||
AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
|
||||
TORCH_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
|
||||
TORCH_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
|
||||
auto output = torch::zeros_like(x);
|
||||
sigmoid_add_cuda(
|
||||
x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
|
||||
|
@ -1581,7 +1581,7 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
|
||||
// This makes no assumption on the signs of sigma.
|
||||
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
|
||||
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
|
||||
AT_CHECK(compute_uv,
|
||||
TORCH_CHECK(compute_uv,
|
||||
"svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
|
||||
"and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
|
||||
|
||||
@ -1664,7 +1664,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
|
||||
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
|
||||
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
|
||||
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
|
||||
AT_CHECK(eigenvectors,
|
||||
TORCH_CHECK(eigenvectors,
|
||||
"symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
|
||||
"and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");
|
||||
|
||||
|
@ -121,7 +121,7 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
|
||||
.pinned_memory(r.toBool(5));
|
||||
return wrap(dispatch_arange(end, options));
|
||||
} else {
|
||||
AT_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible");
|
||||
TORCH_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible");
|
||||
check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3), r.isNone(3),
|
||||
r.device(4), r.isNone(4));
|
||||
return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(6)));
|
||||
@ -141,7 +141,7 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
|
||||
.pinned_memory(r.toBool(7));
|
||||
return wrap(dispatch_arange(start, end, step, options));
|
||||
} else {
|
||||
AT_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible");
|
||||
TORCH_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible");
|
||||
check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5), r.isNone(5),
|
||||
r.device(6), r.isNone(6));
|
||||
return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(8)));
|
||||
|
@ -321,7 +321,7 @@ static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwa
|
||||
ParsedArgs<2> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
auto device = r.isNone(0) ? at::Device(at::DeviceType::CUDA) : r.device(0);
|
||||
AT_CHECK(device.is_cuda(), "Invalid device, must be cuda device");
|
||||
TORCH_CHECK(device.is_cuda(), "Invalid device, must be cuda device");
|
||||
torch::utils::cuda_lazy_init();
|
||||
return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false));
|
||||
END_HANDLE_TH_ERRORS
|
||||
|
@ -67,7 +67,7 @@ PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
|
||||
device_index = r.toInt64(1);
|
||||
// -1 is allowed in ATen/C++, to mean the default device, but not in
|
||||
// Python.
|
||||
AT_CHECK(device_index >= 0, "Device index must not be negative");
|
||||
TORCH_CHECK(device_index >= 0, "Device index must not be negative");
|
||||
}
|
||||
at::Device device(as_device.type(), device_index);
|
||||
return THPDevice_New(device);
|
||||
|
@ -55,7 +55,7 @@ PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
|
||||
|
||||
torch::ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
AT_CHECK(r.idx < 2, "Not a type");
|
||||
TORCH_CHECK(r.idx < 2, "Not a type");
|
||||
at::ScalarType scalar_type;
|
||||
if (r.idx == 1) {
|
||||
scalar_type = torch::tensors::get_default_scalar_type();
|
||||
@ -81,7 +81,7 @@ PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
|
||||
});
|
||||
torch::ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
AT_CHECK(r.idx == 0, "Not a type");
|
||||
TORCH_CHECK(r.idx == 0, "Not a type");
|
||||
|
||||
at::ScalarType scalar_type = r.scalartype(0);
|
||||
if (!at::isIntegralType(scalar_type)) {
|
||||
|
@ -38,7 +38,7 @@ make_data_loader(
|
||||
Dataset dataset,
|
||||
DataLoaderOptions options = DataLoaderOptions()) {
|
||||
const optional<size_t> size = dataset.size();
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
size.has_value(),
|
||||
"Expected the dataset to be sized in "
|
||||
"order to construct the Sampler");
|
||||
|
@ -55,7 +55,7 @@ class DataLoaderBase {
|
||||
/// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
|
||||
/// output_iterator)` are supported too.
|
||||
Iterator<Batch> begin() {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
shuttle_.in_flight_jobs() == 0,
|
||||
"Attempted to get a new DataLoader iterator "
|
||||
"while another iterator is not yet exhausted");
|
||||
|
@ -112,7 +112,7 @@ class BatchDataBuffer {
|
||||
batch_example_indices.value().size() == example_count)
|
||||
BatchRequestType& indices = batch_example_indices.value();
|
||||
for (size_t i : indices) {
|
||||
AT_CHECK(i < data_size, "Index out of range");
|
||||
TORCH_CHECK(i < data_size, "Index out of range");
|
||||
batch.emplace_back(std::move(data[i]));
|
||||
}
|
||||
remaining_size -= example_count;
|
||||
@ -249,16 +249,16 @@ struct ChunkDatasetOptions {
|
||||
: preloader_count_(preloader_count),
|
||||
batch_size_(batch_size),
|
||||
cache_size_(cache_size) {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
preloader_count_ > 0,
|
||||
"Preloader count is 0. At least one preloader needs to be specified.");
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
batch_size_ > 0,
|
||||
"Batch size is 0. A positive batch size needs to be specified.");
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
cache_size_ > 0,
|
||||
"Cache size is 0. A positive cache size needs to be specified.");
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
cache_size_ >= batch_size_,
|
||||
"Cache size is less than batch size. Cache needs to be large enough to "
|
||||
"hold at least one batch.");
|
||||
@ -323,11 +323,11 @@ class ChunkDataset final
|
||||
/// is dataset agnostic and does not need overriding in different chunk
|
||||
/// datasets.
|
||||
BatchType get_batch(size_t batch_size) override {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
batch_buffer_ != nullptr,
|
||||
"Dataset needs to call reset() before calling get_batch().");
|
||||
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
batch_size == options_.batch_size_,
|
||||
"The requested batch size does not match with the initialized batch size.\n"
|
||||
" The requested batch size is ", batch_size,
|
||||
|
@ -50,7 +50,7 @@ struct ValidIterator : public IteratorImpl<Batch> {
|
||||
void next() override {
|
||||
// If we didn't get the very first batch yet, get it now.
|
||||
lazy_initialize();
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
batch_.has_value(), "Attempted to increment iterator past the end");
|
||||
// Increment to the next batch.
|
||||
batch_ = next_batch_();
|
||||
@ -62,7 +62,7 @@ struct ValidIterator : public IteratorImpl<Batch> {
|
||||
Batch& get() override {
|
||||
// If we didn't get the very first batch yet, get it now.
|
||||
lazy_initialize();
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
batch_.has_value(),
|
||||
"Attempted to dereference iterator that was past the end");
|
||||
return batch_.value();
|
||||
|
@ -31,7 +31,7 @@ class ExpandingArray {
|
||||
/// at runtime.
|
||||
/*implicit*/ ExpandingArray(at::ArrayRef<T> values) {
|
||||
// clang-format off
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
values.size() == D,
|
||||
"Expected ", D, " values, but instead got ", values.size());
|
||||
// clang-format on
|
||||
|
@ -41,7 +41,7 @@ class Cloneable : public virtual Module {
|
||||
copy->buffers_.clear();
|
||||
copy->children_.clear();
|
||||
copy->reset();
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
copy->parameters_.size() == parameters_.size(),
|
||||
"The cloned module does not have the same number of "
|
||||
"parameters as the original module after calling reset(). "
|
||||
@ -52,7 +52,7 @@ class Cloneable : public virtual Module {
|
||||
copy->parameters_[parameter.key()].set_data(
|
||||
device ? data.to(*device) : data);
|
||||
}
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
copy->buffers_.size() == buffers_.size(),
|
||||
"The cloned module does not have the same number of "
|
||||
"buffers as the original module after calling reset(). "
|
||||
@ -62,7 +62,7 @@ class Cloneable : public virtual Module {
|
||||
auto data = autograd::Variable(*buffer).data().clone();
|
||||
copy->buffers_[buffer.key()].set_data(device ? data.to(*device) : data);
|
||||
}
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
copy->children_.size() == children_.size(),
|
||||
"The cloned module does not have the same number of "
|
||||
"child modules as the original module after calling reset(). "
|
||||
@ -80,7 +80,7 @@ class Cloneable : public virtual Module {
|
||||
// was registered under the same name as `this`), but you never know what
|
||||
// crazy things `reset()` does, so `dynamic_cast` just to be safe.
|
||||
auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
clone != nullptr,
|
||||
"Attempted to clone submodule, but it is of a "
|
||||
"different type than the submodule it was to be cloned into");
|
||||
|
@ -566,8 +566,8 @@ template <typename ModuleType>
|
||||
std::shared_ptr<ModuleType> Module::register_module(
|
||||
std::string name,
|
||||
std::shared_ptr<ModuleType> module) {
|
||||
AT_CHECK(!name.empty(), "Submodule name must not be empty");
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
|
||||
TORCH_CHECK(
|
||||
name.find('.') == std::string::npos,
|
||||
"Submodule name must not contain a dot (got '",
|
||||
name,
|
||||
|
@ -383,7 +383,7 @@ struct AnyModule::Holder : public AnyModule::Placeholder {
|
||||
/// Calls `forward()` on the underlying module, casting each `Value` in the
|
||||
/// argument vector to a concrete value.
|
||||
Value forward(std::vector<Value>&& arguments) override {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
arguments.size() == sizeof...(ArgumentTypes),
|
||||
c10::demangle(type_info.name()),
|
||||
"'s forward() method expects ",
|
||||
@ -466,7 +466,7 @@ AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
|
||||
|
||||
template <typename... ArgumentTypes>
|
||||
AnyModule::Value AnyModule::any_forward(ArgumentTypes&&... arguments) {
|
||||
AT_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
|
||||
std::vector<Value> values;
|
||||
values.reserve(sizeof...(ArgumentTypes));
|
||||
torch::apply(
|
||||
@ -483,13 +483,13 @@ ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
|
||||
|
||||
template <typename T, typename>
|
||||
T& AnyModule::get() {
|
||||
AT_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
|
||||
return get_<T>();
|
||||
}
|
||||
|
||||
template <typename T, typename>
|
||||
const T& AnyModule::get() const {
|
||||
AT_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
|
||||
return get_<T>();
|
||||
}
|
||||
|
||||
@ -499,20 +499,20 @@ T AnyModule::get() const {
|
||||
}
|
||||
|
||||
inline std::shared_ptr<Module> AnyModule::ptr() const {
|
||||
AT_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
|
||||
return content_->ptr();
|
||||
}
|
||||
|
||||
template <typename T, typename>
|
||||
std::shared_ptr<T> AnyModule::ptr() const {
|
||||
AT_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
|
||||
// Call get() but discard the value, just to do the type checking.
|
||||
get_<T>();
|
||||
return std::dynamic_pointer_cast<T>(ptr());
|
||||
}
|
||||
|
||||
inline const std::type_info& AnyModule::type_info() const {
|
||||
AT_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
|
||||
return content_->type_info;
|
||||
}
|
||||
|
||||
|
@ -161,7 +161,7 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
|
||||
/// \endrst
|
||||
template <typename ReturnType = Tensor, typename... InputTypes>
|
||||
ReturnType forward(InputTypes&&... inputs) {
|
||||
AT_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
|
||||
TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
|
||||
|
||||
auto iterator = modules_.begin();
|
||||
auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
|
||||
@ -263,7 +263,7 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
|
||||
static_assert(
|
||||
torch::detail::is_module<T>::value,
|
||||
"Can only call Sequential::at with an nn::Module type");
|
||||
AT_CHECK(index < size(), "Index out of range");
|
||||
TORCH_CHECK(index < size(), "Index out of range");
|
||||
return modules_[index].get<T>();
|
||||
}
|
||||
|
||||
@ -275,7 +275,7 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
|
||||
static_assert(
|
||||
torch::detail::is_module<T>::value,
|
||||
"Can only call Sequential::at with an nn::Module type");
|
||||
AT_CHECK(index < size(), "Index out of range");
|
||||
TORCH_CHECK(index < size(), "Index out of range");
|
||||
return modules_[index].get<T>();
|
||||
}
|
||||
|
||||
@ -283,7 +283,7 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
|
||||
/// underlying module at the given index. Throws an exception if the index is
|
||||
/// out of bounds.
|
||||
std::shared_ptr<Module> ptr(size_t index) const {
|
||||
AT_CHECK(index < size(), "Index out of range");
|
||||
TORCH_CHECK(index < size(), "Index out of range");
|
||||
return modules_[index].ptr();
|
||||
}
|
||||
|
||||
@ -295,7 +295,7 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
|
||||
static_assert(
|
||||
torch::detail::is_module<T>::value,
|
||||
"Can only call Sequential::ptr with an nn::Module type");
|
||||
AT_CHECK(index < size(), "Index out of range");
|
||||
TORCH_CHECK(index < size(), "Index out of range");
|
||||
return modules_[index].ptr<T>();
|
||||
}
|
||||
|
||||
|
@ -73,10 +73,10 @@ std::vector<Tensor> parallel_apply(
|
||||
std::vector<ModuleType>& modules,
|
||||
const std::vector<Tensor>& inputs,
|
||||
const optional<std::vector<Device>>& devices = nullopt) {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
modules.size() == inputs.size(), "Must have as many inputs as modules");
|
||||
if (devices) {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
modules.size() == devices->size(),
|
||||
"Must have as many devices as modules");
|
||||
}
|
||||
@ -140,7 +140,7 @@ Tensor data_parallel(
|
||||
int64_t dim = 0) {
|
||||
if (!devices) {
|
||||
const auto device_count = torch::cuda::device_count();
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
device_count > 0, "Expected at least one CUDA device to be available");
|
||||
devices = std::vector<Device>();
|
||||
devices->reserve(device_count);
|
||||
|
@ -98,19 +98,19 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
|
||||
|
||||
/// Returns a shared pointer to the underlying module.
|
||||
const std::shared_ptr<Contained>& ptr() const {
|
||||
AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
|
||||
TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
|
||||
return impl_;
|
||||
}
|
||||
|
||||
/// Returns a pointer to the underlying module.
|
||||
Contained* get() {
|
||||
AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
|
||||
TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
|
||||
return impl_.get();
|
||||
}
|
||||
|
||||
/// Returns a const pointer to the underlying module.
|
||||
const Contained* get() const {
|
||||
AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
|
||||
TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
|
||||
return impl_.get();
|
||||
}
|
||||
|
||||
|
@ -295,41 +295,41 @@ typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::end()
|
||||
|
||||
template <typename Key, typename Value>
|
||||
typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front() {
|
||||
AT_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
|
||||
TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
|
||||
return items_.front();
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front()
|
||||
const {
|
||||
AT_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
|
||||
TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
|
||||
return items_.front();
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back() {
|
||||
AT_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
|
||||
TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
|
||||
return items_.back();
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back()
|
||||
const {
|
||||
AT_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
|
||||
TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
|
||||
return items_.back();
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::operator[](
|
||||
size_t index) {
|
||||
AT_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
|
||||
TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
|
||||
return items_[index];
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
const typename OrderedDict<Key, Value>::
|
||||
Item& OrderedDict<Key, Value>::operator[](size_t index) const {
|
||||
AT_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
|
||||
TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
|
||||
return items_[index];
|
||||
}
|
||||
|
||||
@ -352,7 +352,7 @@ const Value& OrderedDict<Key, Value>::operator[](const Key& key) const {
|
||||
template <typename Key, typename Value>
|
||||
template <typename K, typename V>
|
||||
Value& OrderedDict<Key, Value>::insert(K&& key, V&& value) {
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
index_.count(key) == 0, key_description_, " '", key, "' already defined");
|
||||
// Copy `key` here and move it into the index.
|
||||
items_.emplace_back(key, std::forward<V>(value));
|
||||
|
@ -45,7 +45,7 @@ uint32_t read_int32(std::ifstream& stream) {
|
||||
uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
|
||||
const auto value = read_int32(stream);
|
||||
// clang-format off
|
||||
AT_CHECK(value == expected,
|
||||
TORCH_CHECK(value == expected,
|
||||
"Expected to read number ", expected, " but found ", value, " instead");
|
||||
// clang-format on
|
||||
return value;
|
||||
@ -63,7 +63,7 @@ Tensor read_images(const std::string& root, bool train) {
|
||||
const auto path =
|
||||
join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
|
||||
std::ifstream images(path, std::ios::binary);
|
||||
AT_CHECK(images, "Error opening images file at ", path);
|
||||
TORCH_CHECK(images, "Error opening images file at ", path);
|
||||
|
||||
const auto count = train ? kTrainSize : kTestSize;
|
||||
|
||||
@ -83,7 +83,7 @@ Tensor read_targets(const std::string& root, bool train) {
|
||||
const auto path =
|
||||
join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
|
||||
std::ifstream targets(path, std::ios::binary);
|
||||
AT_CHECK(targets, "Error opening targets file at ", path);
|
||||
TORCH_CHECK(targets, "Error opening targets file at ", path);
|
||||
|
||||
const auto count = train ? kTrainSize : kTestSize;
|
||||
|
||||
|
@ -18,7 +18,7 @@ namespace {
|
||||
struct Fan {
|
||||
explicit Fan(Tensor& tensor) {
|
||||
const auto dimensions = tensor.ndimension();
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
dimensions >= 2,
|
||||
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");
|
||||
|
||||
@ -73,7 +73,7 @@ Tensor constant_(Tensor tensor, Scalar value) {
|
||||
Tensor dirac_(Tensor tensor) {
|
||||
NoGradGuard guard;
|
||||
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
tensor.ndimension() >= 3 && tensor.ndimension() <= 5,
|
||||
"Only tensors with 3, 4, or 5 dimensions are supported");
|
||||
|
||||
@ -100,7 +100,7 @@ Tensor dirac_(Tensor tensor) {
|
||||
|
||||
Tensor eye_(Tensor matrix) {
|
||||
NoGradGuard guard;
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
matrix.ndimension() == 2, "Only tensors with 2 dimensions are supported");
|
||||
return torch::eye_out(matrix, matrix.size(0), matrix.size(1));
|
||||
}
|
||||
@ -118,7 +118,7 @@ Tensor ones_(Tensor tensor) {
|
||||
Tensor orthogonal_(Tensor tensor, double gain) {
|
||||
NoGradGuard guard;
|
||||
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
tensor.ndimension() >= 2,
|
||||
"Only tensors with 2 or more dimensions are supported");
|
||||
|
||||
@ -151,7 +151,7 @@ Tensor orthogonal_(Tensor tensor, double gain) {
|
||||
Tensor sparse_(Tensor tensor, double sparsity, double std) {
|
||||
NoGradGuard guard;
|
||||
|
||||
AT_CHECK(
|
||||
TORCH_CHECK(
|
||||
tensor.ndimension() == 2, "Only tensors with 2 dimensions are supported");
|
||||
|
||||
const auto rows = tensor.size(0);
|
||||
|
Reference in New Issue
Block a user