Replace AT_CHECK with TORCH_CHECK [shard 6/10]

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20430

Reviewed By: jerryzh168

Differential Revision: D15318250

fbshipit-source-id: eaee93447d757124a0c9fb5dcde503ae6a065912
This commit is contained in:
Edward Yang
2019-05-14 15:45:44 -07:00
committed by Facebook Github Bot
parent 5b45355431
commit 358fb51e77
16 changed files with 103 additions and 103 deletions
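The change is mechanical: each AT_CHECK(cond, msg...) call site becomes TORCH_CHECK(cond, msg...) with identical arguments, and on a failed condition the macro throws a c10::Error built from the trailing message arguments. Below is a minimal standalone sketch of the new spelling, assuming only that the c10 headers are on the include path; the check_steps helper is hypothetical and simply mirrors the linspace/arange call sites in this shard.

// Minimal sketch, not part of this diff: check_steps() is a hypothetical
// helper that mirrors the call sites changed below.
#include <c10/util/Exception.h>  // defines TORCH_CHECK and c10::Error
#include <cstdint>
#include <iostream>

static void check_steps(int64_t steps) {
  // Old spelling: AT_CHECK(steps >= 0, "number of steps must be non-negative");
  // New spelling, same arguments; extra arguments are folded into the message:
  TORCH_CHECK(steps >= 0, "number of steps must be non-negative, got ", steps);
}

int main() {
  check_steps(3);        // condition holds, no effect
  try {
    check_steps(-1);     // throws c10::Error with the composed message
  } catch (const c10::Error& e) {
    std::cout << e.what() << std::endl;
  }
  return 0;
}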

View File

@@ -76,7 +76,7 @@ Tensor& lerp_cuda_tensor_(Tensor& self, const Tensor& end, const Tensor& weight)
Tensor& lerp_cuda_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
Tensor b_self, b_end;
std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cuda");
AT_CHECK(b_self.sizes() == self.sizes(),
TORCH_CHECK(b_self.sizes() == self.sizes(),
"output with shape ", self.sizes(),
" doesn't match the broadcast shape ", b_self.sizes());
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cuda", [&]{
@@ -87,7 +87,7 @@ Tensor& lerp_cuda_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
Tensor lerp_cuda_tensor(const Tensor& self, const Tensor& end, const Tensor& weight) {
Tensor b_self, b_end, b_weight;
AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
TORCH_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
"weight should be of dimension max(self.dim(), end.dim()) or lesser");
std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_cuda");
Tensor result = at::empty_like(b_self);

View File

@@ -179,9 +179,9 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2);
AT_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
AT_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
AT_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size");
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
TORCH_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size");
int64_t lp_input_stride = log_probs.stride(0);
int64_t lp_char_stride = log_probs.stride(2);
@@ -211,13 +211,13 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
}
tg_target_stride = targets.stride(1);
checkSize(c, targets_arg, 0, batch_size);
AT_CHECK(targets.size(1) >= max_target_length,
TORCH_CHECK(targets.size(1) >= max_target_length,
"Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg,
" (while checking arguments for ", c, ")");
}
int64_t max_input_length = log_probs.size(0);
for (int64_t b = 0; b < batch_size; b++) {
AT_CHECK(input_lengths[b] <= max_input_length,
TORCH_CHECK(input_lengths[b] <= max_input_length,
"Expected tensor to have size at least ", max_input_length, " at dimension 1, but got size ", targets.size(0), " for ", targets_arg,
" (while checking arguments for ", c, ")");
}

View File

@@ -86,7 +86,7 @@ std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward", [&] {
auto mean_st = running_mean.dtype();
auto var_st = running_var.dtype();
AT_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
// <sigh> Some workloads depend on passing in half input and float stats, which is
// usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this
// one so it needs to support the same case, or people start to complain.

View File

@@ -38,7 +38,7 @@ struct LogspaceOp {
};
Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t steps) {
AT_CHECK(steps >= 0, "number of steps must be non-negative");
TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
if (result.numel() != steps) {
result.resize_({steps});
@@ -68,7 +68,7 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
}
Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t steps, double base) {
AT_CHECK(steps >= 0, "number of steps must be non-negative");
TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
if (result.numel() != steps) {
result.resize_({steps});
@@ -105,11 +105,11 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
auto xend = end.to<accscalar_t>();
auto xstep = step.to<accscalar_t>();
AT_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
AT_CHECK(std::isfinite(static_cast<double>(xstart)) &&
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
std::isfinite(static_cast<double>(xend)),
"unsupported range: ", xstart, " -> ", xend);
AT_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
"upper bound and larger bound inconsistent with step sign");
int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
if (result.numel() != size) {
@@ -152,14 +152,14 @@ Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
/ step.to<double>());
}
AT_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
AT_CHECK(std::isfinite(static_cast<double>(xstart)) &&
TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
std::isfinite(static_cast<double>(xend)),
"unsupported range: ", xstart, " -> ", xend);
AT_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
"upper bound and larger bound inconsistent with step sign");
AT_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
"invalid size, possible overflow?");
int64_t size = static_cast<int64_t>(size_d);

View File

@@ -148,14 +148,14 @@ __global__ void reflection_pad2d_backward_out_kernel(
void reflection_pad1d_out_template(
Tensor &output, const Tensor &input_, IntArrayRef padding) {
AT_CHECK(canUse32BitIndexMath(input_),
TORCH_CHECK(canUse32BitIndexMath(input_),
"input tensor must fit into 32-bit index math");
int64_t dim_plane = 0;
int64_t dim_w = 1;
int64_t nbatch = 1;
AT_CHECK(input_.numel() > 0 &&
TORCH_CHECK(input_.numel() > 0 &&
(input_.ndimension() == 2 || input_.ndimension() == 3), "non-empty 2D "
"or 3D (batch mode) tensor expected for input, but got: ", input_);
@@ -172,11 +172,11 @@ void reflection_pad1d_out_template(
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
AT_CHECK(pad_l < input_w && pad_r < input_w, "Padding size should be less "
TORCH_CHECK(pad_l < input_w && pad_r < input_w, "Padding size should be less "
"than the corresponding input dimension, but got: padding (", pad_l, ", ",
pad_r, ") at dimension ", dim_w, " of input ", input_);
AT_CHECK(output_w >= 1,
TORCH_CHECK(output_w >= 1,
"input (W: ", input_w, ")is too small. Calculated output W: ", output_w);
if (input_.ndimension() == 2) {
@@ -206,10 +206,10 @@ void reflection_pad1d_backward_out_template(
Tensor & grad_input, const Tensor & grad_output_,
const Tensor & input, IntArrayRef padding) {
AT_CHECK(canUse32BitIndexMath(input),
TORCH_CHECK(canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
AT_CHECK(canUse32BitIndexMath(grad_output_),
TORCH_CHECK(canUse32BitIndexMath(grad_output_),
"input tensor must fit into 32-bit index math");
int64_t dim_plane = 0;
@@ -231,7 +231,7 @@ void reflection_pad1d_backward_out_template(
Tensor grad_output = grad_output_.contiguous();
AT_CHECK(output_w == grad_output.size(dim_w),
TORCH_CHECK(output_w == grad_output.size(dim_w),
"gradOutput width unexpected. Expected: ", output_w, ", Got: ",
grad_output.size(dim_w));
@@ -252,7 +252,7 @@ void reflection_pad1d_backward_out_template(
void reflection_pad2d_out_template(
Tensor &output, const Tensor &input_, IntArrayRef padding) {
AT_CHECK(canUse32BitIndexMath(input_),
TORCH_CHECK(canUse32BitIndexMath(input_),
"input tensor must fit into 32-bit index math");
int plane_dim = 0;
@@ -260,7 +260,7 @@ void reflection_pad2d_out_template(
int dim_w = 2;
int nbatch = 1;
AT_CHECK(input_.numel() > 0 &&
TORCH_CHECK(input_.numel() > 0 &&
(input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or "
"4D (batch mode) tensor expected for input, but got: ", input_);
@@ -280,12 +280,12 @@ void reflection_pad2d_out_template(
int input_h = input_.size(dim_h);
int input_w = input_.size(dim_w);
AT_CHECK(pad_l < input_w && pad_r < input_w,
TORCH_CHECK(pad_l < input_w && pad_r < input_w,
"Padding size should be less than the corresponding input dimension, but "
"got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
" of input ", input_.sizes());
AT_CHECK(pad_t < input_h && pad_b < input_h,
TORCH_CHECK(pad_t < input_h && pad_b < input_h,
"Padding size should be less than the corresponding input dimension, but "
"got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
" of input ", input_.sizes());
@@ -293,7 +293,7 @@ void reflection_pad2d_out_template(
int output_h = input_h + pad_t + pad_b;
int output_w = input_w + pad_l + pad_r;
AT_CHECK(output_w >= 1 || output_h >= 1,
TORCH_CHECK(output_w >= 1 || output_h >= 1,
"input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
"output H: ", output_h, " W: ", output_w);
@@ -326,9 +326,9 @@ void reflection_pad2d_out_template(
void reflection_pad2d_backward_out_template(
Tensor &grad_input, const Tensor &grad_output_,
const Tensor &input, IntArrayRef padding) {
AT_CHECK(canUse32BitIndexMath(input),
TORCH_CHECK(canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
AT_CHECK(canUse32BitIndexMath(grad_output_),
TORCH_CHECK(canUse32BitIndexMath(grad_output_),
"output gradient tensor must fit into 32-bit index math");
int plane_dim = 0;
@@ -355,9 +355,9 @@ void reflection_pad2d_backward_out_template(
int output_h = input_h + pad_t + pad_b;
int output_w = input_w + pad_l + pad_r;
AT_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
TORCH_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
"unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
AT_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
TORCH_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
"unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
Tensor grad_output = grad_output_.contiguous();

View File

@@ -205,9 +205,9 @@ void replication_pad1d_out_cuda_template(
const Tensor& input,
IntArrayRef paddingSize)
{
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
AT_CHECK(paddingSize.size() == 2, "padding Size is expected to be 2");
TORCH_CHECK(paddingSize.size() == 2, "padding Size is expected to be 2");
int padL = paddingSize[0];
int padR = paddingSize[1];
@@ -216,7 +216,7 @@ void replication_pad1d_out_cuda_template(
int numBatch = 1;
int numInputDims = input.ndimension();
AT_CHECK(input.numel() > 0 && (numInputDims == 2 || numInputDims == 3),
TORCH_CHECK(input.numel() > 0 && (numInputDims == 2 || numInputDims == 3),
"2D or 3D (batch mode) tensor expected for input")
if (numInputDims == 3) {
@@ -229,7 +229,7 @@ void replication_pad1d_out_cuda_template(
int inputW = input.size(dimw);
int outputW = inputW + padL + padR;
AT_CHECK(outputW >= 1,
TORCH_CHECK(outputW >= 1,
"input (W: ", inputW, ")is too small."
" Calculated output W: ", outputW);
@@ -279,11 +279,11 @@ void replication_pad1d_backward_out_cuda_template(
IntArrayRef paddingSize)
{
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
"output gradient tensor must fit into 32-bit index math");
AT_CHECK(paddingSize.size() == 2, "padding Size is expected to be 2");
TORCH_CHECK(paddingSize.size() == 2, "padding Size is expected to be 2");
int padL = paddingSize[0];
int padR = paddingSize[1];
@@ -298,7 +298,7 @@ void replication_pad1d_backward_out_cuda_template(
int iwidth = input.size(dimw);
int owidth = iwidth + padL + padR;
AT_CHECK(owidth == gradOutput.size(dimw),
TORCH_CHECK(owidth == gradOutput.size(dimw),
"gradOutput width unexpected. Expected: ", owidth, ", Got: ",
gradOutput.size(dimw));
@@ -336,9 +336,9 @@ void replication_pad2d_out_cuda_template(
const Tensor& input,
IntArrayRef paddingSize)
{
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
AT_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");
TORCH_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");
int padL = paddingSize[0];
int padR = paddingSize[1];
@@ -350,7 +350,7 @@ void replication_pad2d_out_cuda_template(
int numBatch = 1;
int numInputDims = input.dim();
AT_CHECK(input.numel() && (numInputDims == 3 || numInputDims == 4),
TORCH_CHECK(input.numel() && (numInputDims == 3 || numInputDims == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input, but got: ",
input)
@@ -367,7 +367,7 @@ void replication_pad2d_out_cuda_template(
int outputH = inputH + padT + padB;
int outputW = inputW + padL + padR;
AT_CHECK(outputW >= 1 || outputH >= 1,
TORCH_CHECK(outputW >= 1 || outputH >= 1,
"input (H: ", inputH, ", W: ", inputW, ") is too small."
" Calculated output H: ", outputH, " W: ", outputW);
@@ -418,11 +418,11 @@ void replication_pad2d_backward_out_cuda_template(
IntArrayRef paddingSize)
{
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
"output gradient tensor must fit into 32-bit index math");
AT_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");
TORCH_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");
int padL = paddingSize[0];
int padR = paddingSize[1];
@@ -443,10 +443,10 @@ void replication_pad2d_backward_out_cuda_template(
int oheight = iheight + padT + padB;
int owidth = iwidth + padL + padR;
AT_CHECK(owidth == gradOutput.size(dimw),
TORCH_CHECK(owidth == gradOutput.size(dimw),
"gradOutput width unexpected. Expected: ", owidth, ", Got: ",
gradOutput.size(dimw));
AT_CHECK(oheight == gradOutput.size(dimh),
TORCH_CHECK(oheight == gradOutput.size(dimh),
"gradOutput height unexpected. Expected: ", oheight, ", Got: ",
gradOutput.size(dimh));
@@ -483,11 +483,11 @@ static inline void shapeCheck3d(
int pleft, int pright,
int ptop, int pbottom,
int pfront, int pback) {
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
int numInputDims = input.dim();
AT_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5),
TORCH_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5),
"non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input);
int planeDim = 0;
@@ -508,7 +508,7 @@ static inline void shapeCheck3d(
int odepth = idepth + pfront + pback;
int oheight = iheight + ptop + pbottom;
int owidth = iwidth + pleft + pright;
AT_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
TORCH_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
"input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
") is too small."
" Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);
@@ -521,11 +521,11 @@ static inline void shapeAndGradOutputCheck3d(
int pleft, int pright,
int ptop, int pbottom,
int pfront, int pback) {
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
int numInputDims = input.dim();
AT_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5),
TORCH_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5),
"non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input);
int planeDim = 0;
@@ -546,24 +546,24 @@ static inline void shapeAndGradOutputCheck3d(
int odepth = idepth + pfront + pback;
int oheight = iheight + ptop + pbottom;
int owidth = iwidth + pleft + pright;
AT_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
TORCH_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
"input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
") is too small."
" Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);
AT_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
"output gradient tensor must fit into 32-bit index math");
AT_CHECK(numPlanes == gradOutput.size(planeDim),
TORCH_CHECK(numPlanes == gradOutput.size(planeDim),
"gradOutput width unexpected. Expected: ", numPlanes, ", Got: ",
gradOutput.size(planeDim));
AT_CHECK(owidth == gradOutput.size(dimw),
TORCH_CHECK(owidth == gradOutput.size(dimw),
"gradOutput width unexpected. Expected: ", owidth, ", Got: ",
gradOutput.size(dimw));
AT_CHECK(oheight == gradOutput.size(dimh),
TORCH_CHECK(oheight == gradOutput.size(dimh),
"gradOutput height unexpected. Expected: ", oheight, ", Got: ",
gradOutput.size(dimh));
AT_CHECK(odepth == gradOutput.size(dimd),
TORCH_CHECK(odepth == gradOutput.size(dimd),
"gradOutput depth unexpected. Expected: ", odepth, ", Got: ",
gradOutput.size(dimd));
}
@@ -573,7 +573,7 @@ void replication_pad3d_out_cuda_template(
const Tensor& input,
IntArrayRef paddingSize)
{
AT_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
TORCH_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
int pleft = paddingSize[0];
int pright = paddingSize[1];
int ptop = paddingSize[2];
@@ -654,7 +654,7 @@ void replication_pad3d_backward_out_cuda_template(
const Tensor& input,
IntArrayRef paddingSize)
{
AT_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
TORCH_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
int pleft = paddingSize[0];
int pright = paddingSize[1];
int ptop = paddingSize[2];

View File

@@ -482,7 +482,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
if (input.dim() == 0) input = input.view(1);
int64_t dim = maybe_wrap_dim(dim_, input.dim());
AT_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions");
TORCH_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions");
int64_t outer_size = 1;
int64_t dim_size = input.size(dim);
@@ -557,7 +557,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t
auto grad = grad_.contiguous();
static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
if (grad.dim() == 0) grad = grad.view(1);
AT_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
auto output = output_.contiguous();
if (output.dim() == 0) output = output.view(1);
int64_t outer_size = 1;

View File

@@ -145,11 +145,11 @@ void kthvalue_cuda_template(
// FIXME: This seems bogus, I only do this because it was the old behaviour.
// The reductions are fine, as long as the axis being reduced along
// isn't of 0 elements (and the output has elements).
AT_CHECK(
TORCH_CHECK(
self.numel() > 0,
"cannot perform reduction function kthvalue",
" on tensor with no elements because the operation does not have an identity");
AT_CHECK(k >= 1 && k <= slicesize, "selected number k out of range");
TORCH_CHECK(k >= 1 && k <= slicesize, "selected number k out of range");
_reduction_with_indices_allocate_or_resize_output(
values, indices, self, dim, keepdim);
@@ -159,7 +159,7 @@ void kthvalue_cuda_template(
return;
}
AT_CHECK(
TORCH_CHECK(
self.dim() <= MAX_TENSORINFO_DIMS,
"cannot operate on more than ",
MAX_TENSORINFO_DIMS,
@@ -188,14 +188,14 @@ void kthvalue_cuda_template(
// this does not reduce to median with dim beause we don't want to copy twice
template <typename scalar_t>
Tensor median_cuda_template(const Tensor& self) {
AT_CHECK(self.numel() > 0, "median cannot be called with empty tensor");
TORCH_CHECK(self.numel() > 0, "median cannot be called with empty tensor");
if (self.dim() == 0 && self.numel() == 1) {
return self.clone();
}
auto self_copy = self.clone().view(-1);
auto values = at::empty({1}, self.options());
auto indices = at::empty({1}, self.options().dtype(kLong));
AT_CHECK(
TORCH_CHECK(
self.dim() <= MAX_TENSORINFO_DIMS,
"cannot operate on more than ",
MAX_TENSORINFO_DIMS,

View File

@@ -286,7 +286,7 @@ CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
namespace detail {
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);
@@ -294,7 +294,7 @@ int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
}
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) {
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);
@@ -302,7 +302,7 @@ void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size)
}
int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_get_plan_cache_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);
@@ -310,7 +310,7 @@ int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
}
void cufft_clear_plan_cache_impl(int64_t device_index) {
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_clear_plan_cache: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
device_index);

View File

@@ -26,7 +26,7 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n) {
}
Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
AT_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
if(m < 0) {
m = n;
@@ -46,7 +46,7 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
Tensor empty_cuda(IntArrayRef size, const TensorOptions& options) {
AT_ASSERT(options.backend() == at::Backend::CUDA);
AT_ASSERT(!options.is_variable()); // is_variable should have been 'unpacked' // TODO: remove this when Variable and Tensor are merged
AT_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
check_size_nonnegative(size);
auto* allocator = at::cuda::getCUDADeviceAllocator();
@@ -74,8 +74,8 @@ Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOpti
}
Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
AT_CHECK(n >= 0, "n must be non-negative, got", n);
AT_CHECK(at::scalar_tensor(n, result.options()).defined(),
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(at::scalar_tensor(n, result.options()).defined(),
"n is too large for result tensor type: '", result.type().toString(), "'");
result.resize_({n});
@@ -322,7 +322,7 @@ Tensor tril_indices_cuda(
dim3 dim_grid;
// using tril_size instead of tensor.numel(), as each thread takes care of
// two elements in the tensor.
AT_CHECK(
TORCH_CHECK(
cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
"unable to get dim grid");
@@ -398,7 +398,7 @@ Tensor triu_indices_cuda(
// using triu_size instead of tensor.numel(), as each thread takes care of
// two elements in the tensor.
AT_CHECK(
TORCH_CHECK(
cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
"unable to get dim grid");

View File

@@ -179,7 +179,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
dim3 dim_block = cuda::getApplyBlock();
dim3 dim_grid;
AT_CHECK(cuda::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid");
TORCH_CHECK(cuda::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid");
auto total_dims = in_tensor.dim();

View File

@@ -38,7 +38,7 @@ std::tuple<Tensor, Tensor, int64_t> compute_unique(
if (!return_inverse) {
inverse_indices = at::empty({0}, options);
} else {
AT_CHECK(sorted_indices.defined(),
TORCH_CHECK(sorted_indices.defined(),
"return_inverse is set to true, but sorted_indices is undefined. Send a bug report!");
const int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
Tensor inv_loc = at::empty({num_inp}, options);

View File

@@ -415,10 +415,10 @@ std::tuple<Tensor, Tensor> weight_norm_cuda_backward
{
// These checks should always succeed, because weight_norm_fused_backward should only
// ever be recorded in the autograd graph via weight_norm, which passes contiguous v and g.
AT_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
AT_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
AT_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
AT_CHECK(dim == 0 || dim == saved_v.dim() - 1, "fused kernels can only be applied for first or last dim")
TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
TORCH_CHECK(dim == 0 || dim == saved_v.dim() - 1, "fused kernels can only be applied for first or last dim")
auto grad_v = at::empty_like(saved_v);
auto grad_g = at::empty_like(saved_g);

View File

@@ -217,10 +217,10 @@ Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) {
// Used on pad, stride and dilation
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
AT_CHECK(args.size() <= expected_size,
TORCH_CHECK(args.size() <= expected_size,
"Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
expected_size, " (while checking arguments for ", c, ")");
AT_CHECK(args.size() >= expected_size,
TORCH_CHECK(args.size() >= expected_size,
"Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
expected_size, " (while checking arguments for ", c, ")");

View File

@@ -46,14 +46,14 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tens
checkBackend(c, {*log_probs}, Backend::CUDA);
checkBackend(c, {*targets}, Backend::CPU);
int64_t batch_size = log_probs->size(1);
AT_CHECK(input_lengths_.size() == batch_size, "input_lengths needs to have size to match batch_size");
AT_CHECK(target_lengths_.size() == batch_size, "target_lengths needs to have size to match batch_size");
TORCH_CHECK(input_lengths_.size() == batch_size, "input_lengths needs to have size to match batch_size");
TORCH_CHECK(target_lengths_.size() == batch_size, "target_lengths needs to have size to match batch_size");
std::vector<int> input_lengths(input_lengths_.begin(), input_lengths_.end());
std::vector<int> target_lengths(target_lengths_.begin(), target_lengths_.end());
setCuDNNStreamToCurrent();
AT_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
TORCH_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
// checked in dispatch:
// assert other conditions for cudnnCTCLoss: all label lengths <= 256
// all input lengths = logprob.size(0)

View File

@@ -627,7 +627,7 @@ Tensor _cudnn_rnn_flatten_weight(
bool fn_bidirectional
) {
AT_CHECK(weight_arr.size() > 0,
TORCH_CHECK(weight_arr.size() > 0,
"_cudnn_rnn_flatten_weight_: cannot flatten empty weight list");
auto any_param = weight_arr[0];
@@ -701,7 +701,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
// TODO: Set device to input
if (fn.rnn.mode != CUDNN_LSTM) {
AT_CHECK(!cx.defined(),
TORCH_CHECK(!cx.defined(),
"rnn: illegal defined cx for non-LSTM RNN");
}
@@ -714,9 +714,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
auto output_size = _output_size(fn.rnn, fn.tensors);
AT_CHECK(hx.is_contiguous(),
TORCH_CHECK(hx.is_contiguous(),
"rnn: hx is not contiguous");
AT_CHECK(!cx.defined() || cx.is_contiguous(),
TORCH_CHECK(!cx.defined() || cx.is_contiguous(),
"rnn: cx is not contiguous");
auto x = input.contiguous();
@@ -750,7 +750,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
w_desc.set(weight_buf, 3);
}
AT_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
"Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());
size_t workspace_size;
@@ -842,7 +842,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
auto handle = getCudnnHandle();
if (fn.rnn.mode != CUDNN_LSTM) {
AT_CHECK(!cx.defined(),
TORCH_CHECK(!cx.defined(),
"rnn: illegal defined cx for non-LSTM RNN");
}
@@ -857,9 +857,9 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
auto output_size = _output_size(fn.rnn, fn.tensors);
AT_CHECK(hx.is_contiguous(),
TORCH_CHECK(hx.is_contiguous(),
"rnn: hx is not contiguous");
AT_CHECK(!cx.defined() || cx.is_contiguous(),
TORCH_CHECK(!cx.defined() || cx.is_contiguous(),
"rnn: cx is not contiguous");
auto x = input.contiguous();
@@ -873,15 +873,15 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
AT_ASSERTM(cx.defined() || !output_mask[2], "illegally required grad of cx for non-LSTM RNN");
auto dcx = cx.defined() ? at::empty(hidden_size, cx.options()) : Tensor();
AT_CHECK(fn_train,
TORCH_CHECK(fn_train,
"cudnn RNN backward can only be called in training mode");
AT_CHECK(input.sizes().equals(input_size),
TORCH_CHECK(input.sizes().equals(input_size),
"Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
AT_CHECK(output.sizes().equals(output_size),
TORCH_CHECK(output.sizes().equals(output_size),
"Expected output size ", IntArrayRef{output_size}, ", got ", output.sizes());
AT_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
"Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());
AT_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
"Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());