Avoid COW materialize in nn.functional forward ops (3) (#122443)

Affected ops:
* repeat
* unfold
* logsigmoid
* pixel_shuffle/unshuffle
* remaining norm ops
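
The change follows the same recipe in every kernel: tensors that are only read are accessed through `const_data_ptr()` (or a const accessor / const iterator input), which does not force a copy-on-write (COW) tensor to materialize, while genuine outputs keep using the mutable `data_ptr()`. A minimal sketch of that kernel-side pattern, using a hypothetical `scale_kernel` rather than any of the ops above:

```cpp
#include <ATen/ATen.h>

// Hypothetical kernel illustrating the COW-friendly access pattern:
// read-only inputs go through const_data_ptr(), outputs through data_ptr().
static void scale_kernel(at::Tensor& out, const at::Tensor& in, float alpha) {
  const float* in_data = in.const_data_ptr<float>();  // read-only: no COW materialization
  float* out_data = out.data_ptr<float>();            // mutable: output only
  for (int64_t i = 0; i < in.numel(); i++) {
    out_data[i] = alpha * in_data[i];
  }
}
```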

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122443
Approved by: https://github.com/ezyang
Author: Kurt Mohler
Date: 2024-03-25 21:59:30 +00:00
Committed by: PyTorch MergeBot
Parent commit: b6982bf2b2
Commit: 5e66bf5f42
14 changed files with 85 additions and 90 deletions

View File

@ -16,8 +16,8 @@
template <typename index_t>
static void compute_cpu(
index_t* repeat_ptr,
int64_t* cumsum_ptr,
const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {

View File

@ -14,7 +14,7 @@ namespace at::native {
template <
typename index_t,
void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)>
void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
static inline Tensor repeat_interleave_common(
const Tensor& repeats,
c10::optional<int64_t> output_size) {
@ -38,8 +38,8 @@ static inline Tensor repeat_interleave_common(
}
Tensor result = at::empty({total}, repeats.options());
index_t* repeat_ptr = repeats_.data_ptr<index_t>();
int64_t* cumsum_ptr = cumsum.data_ptr<int64_t>();
const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
index_t* result_ptr = result.data_ptr<index_t>();
compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
return result;
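
Both hunks exist because constness has to flow through the whole call chain: once `repeats_` and `cumsum` are read via `const_data_ptr()`, the `compute` callback must accept `const index_t*` / `const int64_t*`. A simplified stand-in for the repeat-interleave compute kernel (not the real ATen code) showing the constified signature:

```cpp
#include <cstdint>

// Simplified stand-in for compute_cpu: expand each index i into repeat_ptr[i]
// copies, using the inclusive cumulative sum of the repeats for the offsets.
// The two input buffers are const so the caller can pass pointers obtained
// from const_data_ptr() without a cast.
template <typename index_t>
static void compute_example(
    const index_t* repeat_ptr,
    const int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size) {
  for (int64_t i = 0; i < size; i++) {
    int64_t end = cumsum_ptr[i];
    for (int64_t j = end - repeat_ptr[i]; j < end; j++) {
      result_ptr[j] = static_cast<index_t>(i);
    }
  }
}
```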

View File

@ -6,7 +6,25 @@
namespace at::native {
using unfold2d_fn = void (*)(
using unfold2d_copy_fn = void (*)(
ScalarType dtype,
void *finput,
const void *input,
int64_t kH,
int64_t kW,
int64_t dH,
int64_t dW,
int64_t padH,
int64_t padW,
int64_t n_input_plane,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width,
bool is_channels_last
);
using unfold2d_acc_fn = void (*)(
ScalarType dtype,
void *finput,
void *input,
@ -24,7 +42,7 @@ using unfold2d_fn = void (*)(
bool is_channels_last
);
DECLARE_DISPATCH(unfold2d_fn, unfolded2d_copy_stub);
DECLARE_DISPATCH(unfold2d_fn, unfolded2d_acc_stub);
DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub);
DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub);
} // namespace at::native
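
Only the copy direction treats the input as read-only, so the single `unfold2d_fn` pointer type is split: `unfold2d_copy_fn` takes `const void *input`, while `unfold2d_acc_fn` keeps the mutable pointers, since the accumulate path still writes into that buffer. A reduced sketch of the same split (hypothetical names, far fewer parameters than the real stubs):

```cpp
#include <cstdint>

// Read-only variant takes const void*, the accumulating variant stays mutable.
using copy_fn_example = void (*)(void* dst, const void* src, int64_t n);
using acc_fn_example  = void (*)(void* dst, void* src, int64_t n);

static void copy_impl(void* dst, const void* src, int64_t n) {
  auto* d = static_cast<float*>(dst);
  const auto* s = static_cast<const float*>(src);
  for (int64_t i = 0; i < n; i++) d[i] = s[i];
}

// A function matching the const signature can be stored in the corresponding
// pointer type, mirroring how the dispatch stubs above are registered.
static copy_fn_example copy_stub_example = &copy_impl;
```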

View File

@ -30,7 +30,7 @@ static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const
using Vec = Vectorized<scalar_t>;
scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
int64_t size = end - begin;
int64_t d = 0;
@ -65,7 +65,7 @@ static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const
using Vec = Vectorized<scalar_t>;
scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
int64_t size = end - begin;
int64_t d = 0;
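
Both specializations of the log-sigmoid CPU kernel now read the input through `const_data_ptr()`; the vectorized path needs no other change because `Vectorized<T>::loadu` already accepts const memory. A minimal sketch of a vectorized elementwise kernel reading through a const pointer (hypothetical `negate_kernel`, assuming the usual `at::vec` and `at::parallel_for` utilities):

```cpp
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>

// Hypothetical vectorized kernel: reads via a const pointer, writes the output
// over the parallel range; loadu takes const memory, so no cast is needed.
static void negate_kernel(float* output_data, const float* input_data, int64_t numel) {
  using Vec = at::vec::Vectorized<float>;
  at::parallel_for(0, numel, 2048, [&](int64_t begin, int64_t end) {
    int64_t d = begin;
    for (; d + Vec::size() <= end; d += Vec::size()) {
      Vec v = Vec::loadu(input_data + d);
      (Vec(0.0f) - v).store(output_data + d);
    }
    for (; d < end; d++) {
      output_data[d] = -input_data[d];
    }
  });
}
```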

View File

@ -17,7 +17,7 @@ void cpu_pixel_shuffle(
TensorBase& output,
const TensorBase& input,
int64_t upscale_factor) {
auto input_data = input.data_ptr<scalar_t>();
auto input_data = input.const_data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
// [(B1...Bn), C, H, W] => [N, C, H, W]
@ -59,7 +59,7 @@ void cpu_pixel_shuffle_channels_last(
int64_t upscale_factor) {
TORCH_CHECK(input.ndimension() == 4,
"pixel shuffle with channels last format supports tensors with 4 dims");
auto input_data = input.data_ptr<scalar_t>();
auto input_data = input.const_data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t nbatch = input.size(0);
@ -81,7 +81,7 @@ void cpu_pixel_shuffle_channels_last(
data_index_init(begin, n, nbatch, h, height);
for (const auto i : c10::irange(begin, end)) {
for (const auto w : c10::irange(width)) {
scalar_t* input_ptr = input_data + n * height * width * channels + h * width * channels + w * channels;
const scalar_t* input_ptr = input_data + n * height * width * channels + h * width * channels + w * channels;
// step 1: transpose each channel lane
// from: [c, s1*s2]
@ -115,7 +115,7 @@ void cpu_pixel_unshuffle(
TensorBase& output,
const TensorBase& input,
int64_t downscale_factor) {
auto input_data = input.data_ptr<scalar_t>();
auto input_data = input.const_data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
// [(B1...Bn), C, H, W] => [N, C, H, W]
@ -158,7 +158,7 @@ void cpu_pixel_unshuffle_channels_last(
int64_t downscale_factor) {
TORCH_CHECK(input.ndimension() == 4,
"pixel unshuffle with channels last format supports tensors with 4 dims");
auto input_data = input.data_ptr<scalar_t>();
auto input_data = input.const_data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t nbatch = input.size(0);

View File

@ -228,7 +228,7 @@ void unfolded2d_acc_kernel(
template <typename scalar_t>
static void unfolded2d_copy(
scalar_t* input_data,
const scalar_t* input_data,
scalar_t* finput_data,
int64_t kH,
int64_t kW,
@ -256,7 +256,7 @@ static void unfolded2d_copy(
nip * ((size_t)kH * kW * output_height * output_width) +
kh * ((size_t)kW * output_height * output_width) +
kw * ((size_t)output_height * output_width);
scalar_t* src =
const scalar_t* src =
input_data + nip * ((size_t)input_height * input_width);
if (padW > 0 || padH > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -335,7 +335,7 @@ static void unfolded2d_copy(
template <typename scalar_t>
static void unfolded2d_copy_channels_last(
scalar_t* input_data,
const scalar_t* input_data,
scalar_t* finput_data,
int64_t kH,
int64_t kW,
@ -355,7 +355,7 @@ static void unfolded2d_copy_channels_last(
for (const auto k C10_UNUSED: c10::irange(start, end)) {
scalar_t* dst = finput_data + y * output_width * kH * kW * n_input_plane + x * kH * kW * n_input_plane;
scalar_t* src = input_data;
const scalar_t* src = input_data;
if (padW > 0 || padH > 0) {
for (int64_t kh = 0; kh < kH; kh++) {
@ -393,7 +393,7 @@ static void unfolded2d_copy_channels_last(
void unfolded2d_copy_kernel(
ScalarType dtype,
void *finput_data,
void *input_data,
const void *input_data,
int64_t kH,
int64_t kW,
int64_t dH,
@ -415,7 +415,7 @@ void unfolded2d_copy_kernel(
if (is_channels_last) {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy_channels_last", [&] {
unfolded2d_copy_channels_last(
static_cast<scalar_t*>(input_data),
static_cast<const scalar_t*>(input_data),
static_cast<scalar_t*>(finput_data),
kH, kW,
dH, dW,
@ -429,7 +429,7 @@ void unfolded2d_copy_kernel(
} else {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy", [&] {
unfolded2d_copy(
static_cast<scalar_t*>(input_data),
static_cast<const scalar_t*>(input_data),
static_cast<scalar_t*>(finput_data),
kH, kW,
dH, dW,
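
Because the dispatch boundary is type-erased, the constness also has to survive the `void*` round trip: the kernel now takes `const void *input_data` and the `AT_DISPATCH_*` lambdas cast it back with `static_cast<const scalar_t*>`. A reduced sketch of that shape (hypothetical `copy_example_kernel`, not the real unfold kernel):

```cpp
#include <ATen/Dispatch.h>
#include <c10/core/ScalarType.h>
#include <cstdint>

// Hypothetical type-erased kernel: the const qualifier is carried across the
// void* boundary and restored inside the dtype dispatch.
static void copy_example_kernel(at::ScalarType dtype, void* dst, const void* src, int64_t n) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "copy_example", [&] {
        const scalar_t* s = static_cast<const scalar_t*>(src);
        scalar_t* d = static_cast<scalar_t*>(dst);
        for (int64_t i = 0; i < n; i++) {
          d[i] = s[i];
        }
      });
}
```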

View File

@ -34,13 +34,13 @@ void batch_norm_cpu_collect_linear_and_constant_terms(
const Tensor& save_mean, const Tensor& save_invstd,
const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {
const param_t* weight_data = weight.defined() ? weight.data_ptr<param_t>() : nullptr;
const param_t* bias_data = bias.defined() ? bias.data_ptr<param_t>() : nullptr;
const param_t* weight_data = weight.defined() ? weight.const_data_ptr<param_t>() : nullptr;
const param_t* bias_data = bias.defined() ? bias.const_data_ptr<param_t>() : nullptr;
auto save_mean_a = conditional_accessor_1d<param_t>(save_mean);
auto save_invstd_a = conditional_accessor_1d<param_t>(save_invstd);
auto running_mean_a = conditional_accessor_1d<param_t>(running_mean);
auto running_var_a = conditional_accessor_1d<param_t>(running_var);
auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
auto running_var_a = conditional_accessor_1d<const param_t>(running_var);
/// Collect the linear and constant terms regarding the input.
/// output(n, c, h, w)
@ -91,7 +91,7 @@ batch_norm_cpu_contiguous_impl(Tensor& output, const Tensor& input,
save_mean, save_invstd, running_mean, running_var, train, eps);
scalar_t* output_data = output.data_ptr<scalar_t>();
const scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
// Apply the linear terms to the input,
// output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
@ -143,7 +143,7 @@ batch_norm_cpu_channels_last_impl(Tensor& output, const Tensor& input,
save_mean, save_invstd, running_mean, running_var, train, eps);
scalar_t* output_data = output.data_ptr<scalar_t>();
const scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
// Apply the linear terms to the input,
// output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
@ -185,7 +185,7 @@ batch_norm_cpu_collect_stats_contiguous_impl(
int64_t image_size = input.numel() / n_batch / n_channel;
int64_t N = input.numel() / n_channel;
const scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
scalar_t* mean_data = mean.data_ptr<scalar_t>();
scalar_t* var_sum_data = var_sum.data_ptr<scalar_t>();
@ -229,7 +229,7 @@ batch_norm_cpu_collect_stats_channels_last_impl(
int64_t n_channel = input.size(1);
int64_t N = input.numel() / n_channel;
const scalar_t* input_data = input.data_ptr<scalar_t>();
const scalar_t* input_data = input.const_data_ptr<scalar_t>();
scalar_t* mean_data = mean.data_ptr<scalar_t>();
scalar_t* var_sum_data = var_sum.data_ptr<scalar_t>();
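
Two things change in the CPU batch norm kernel: the optional `weight`/`bias` pointers go through `const_data_ptr()`, and the 1-d accessors for the saved/running statistics are instantiated with a const element type so they also read without materializing. A small sketch of the optional-parameter pattern and of a const accessor (assuming, as this diff does, that `Tensor::accessor` accepts const element types):

```cpp
#include <ATen/ATen.h>

// Hypothetical helper following the batch-norm pattern: an optional parameter
// is read through const_data_ptr() when defined, and statistics are walked
// through an accessor with a const element type.
static double sum_stats(const at::Tensor& weight, const at::Tensor& running_mean) {
  const float* weight_data =
      weight.defined() ? weight.const_data_ptr<float>() : nullptr;
  double total = weight_data != nullptr ? weight_data[0] : 0.0;

  auto mean_a = running_mean.accessor<const float, 1>();  // read-only accessor
  for (int64_t i = 0; i < mean_a.size(0); i++) {
    total += mean_a[i];
  }
  return total;
}
```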

View File

@ -43,9 +43,9 @@ void GroupNormKernelImplInternal(
TORCH_CHECK(!beta.defined() || beta.numel() == C);
const int64_t G = group;
const int64_t D = C / G;
const T* X_data = X.data_ptr<T>();
const PT* gamma_data = gamma.defined() ? gamma.data_ptr<PT>() : nullptr;
const PT* beta_data = beta.defined() ? beta.data_ptr<PT>() : nullptr;
const T* X_data = X.const_data_ptr<T>();
const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
T* Y_data = Y.data_ptr<T>();
PT* mean_data = mean.data_ptr<PT>();
PT* rstd_data = rstd.data_ptr<PT>();
@ -298,9 +298,9 @@ void GroupNormKernelImplChannelsLastInternal(
TORCH_CHECK(!beta.defined() || beta.numel() == C);
const int64_t G = group;
const int64_t D = C / G;
const T* X_data = X.data_ptr<T>();
const PT* gamma_data = gamma.defined() ? gamma.data_ptr<PT>() : nullptr;
const PT* beta_data = beta.defined() ? beta.data_ptr<PT>() : nullptr;
const T* X_data = X.const_data_ptr<T>();
const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
T* Y_data = Y.data_ptr<T>();
PT* mean_data = mean.data_ptr<PT>();
PT* rstd_data = rstd.data_ptr<PT>();

View File

@ -80,7 +80,7 @@ std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_cuda(const Tensor& input, T
// NOTE: buffer is only used by CPU dispatch, we just ignore it here
auto iter = TensorIteratorConfig()
.add_output(result)
.add_input(input)
.add_const_input(input)
.build();
launch_log_sigmoid_forward_kernel(iter);
return std::forward_as_tuple(result, buffer);
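
On the CUDA side the op goes through TensorIterator, so the read-only input is registered with `add_const_input` instead of `add_input`; the iterator then never asks that tensor for a mutable pointer. A minimal sketch of the config change (hypothetical unary op, not the log_sigmoid iterator itself):

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>

// Hypothetical unary-op setup: the output is registered as mutable, the input
// as const, so a COW input is only ever read.
static at::TensorIterator make_unary_iter(at::Tensor& result, const at::Tensor& input) {
  return at::TensorIteratorConfig()
      .add_output(result)
      .add_const_input(input)  // read-only: no materialization of a COW input
      .build();
}
```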

View File

@ -210,12 +210,12 @@ __device__ __forceinline__ void welford_merge_block_vertical(C& count,
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
__global__ void batch_norm_transform_input_kernel(
const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
stat_accscalar_t epsilon) {
index_t plane = blockIdx.x;
@ -267,7 +267,7 @@ struct Var {
template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
const stat_accscalar_t epsilon,
const stat_accscalar_t momentum,
GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
@ -582,7 +582,7 @@ __global__ void batch_norm_backward_elemt_kernel(
template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
const Tensor& t, c10::string_view var_name) {
constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t>::value;
constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
const auto actual_type = t.scalar_type();
TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
" to have type ", expect_type, " but got ", actual_type);
@ -670,7 +670,7 @@ void batch_norm_stats_cuda_template(
resize_output(out_mean, {n_input});
resize_output(out_invstd, {n_input});
auto input = get_packed_accessor<
scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
out_invstd.sizes()[0]);
TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
@ -700,13 +700,13 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
auto input = get_packed_accessor<
input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
auto output = get_packed_accessor<
input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
auto weight = packed_accessor_or_dummy<
stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
auto bias = packed_accessor_or_dummy<
stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
auto mean = packed_accessor_or_dummy<
stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
auto invstd = packed_accessor_or_dummy<
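
The packed accessors for the read-only tensors are now instantiated with `const` element types, which is why the dtype check in `get_packed_accessor` strips the qualifier before the lookup: `CppTypeToScalarType` is only specialized for the unqualified element types. A tiny illustration of that lookup (sketch, using the same c10 trait):

```cpp
#include <c10/core/ScalarType.h>
#include <type_traits>

// The ScalarType lookup must see the unqualified element type; stripping
// const keeps the check working for accessors instantiated with const scalar_t.
template <typename scalar_t>
constexpr c10::ScalarType expected_dtype() {
  return c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
}

static_assert(expected_dtype<float>() == c10::ScalarType::Float, "");
static_assert(expected_dtype<const float>() == c10::ScalarType::Float, "");
```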

View File

@ -12,8 +12,8 @@
template <typename index_t>
__global__ static void compute_cuda_kernel(
index_t* repeat_ptr,
int64_t* cumsum_ptr,
const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
@ -35,8 +35,8 @@ __global__ static void compute_cuda_kernel(
template <typename index_t>
static void compute_cuda(
index_t* repeat_ptr,
int64_t* cumsum_ptr,
const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {

View File

@ -496,11 +496,11 @@ void GroupNorm1dForward(
auto iter = TensorIteratorConfig()
.resize_outputs(false)
.add_owned_output(Y.view({N, G, D}))
.add_owned_input(X.view({N, G, D}))
.add_owned_const_input(X.view({N, G, D}))
.add_owned_input(mean.view({N, G, 1}))
.add_owned_input(rstd.view({N, G, 1}))
.add_owned_input(gamma.view({1, G, D}))
.add_owned_input(beta.view({1, G, D}))
.add_owned_const_input(gamma.view({1, G, D}))
.add_owned_const_input(beta.view({1, G, D}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma, T beta) -> T {
return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
@ -511,10 +511,10 @@ void GroupNorm1dForward(
auto iter = TensorIteratorConfig()
.resize_outputs(false)
.add_owned_output(Y.view({N, G, D}))
.add_owned_input(X.view({N, G, D}))
.add_owned_const_input(X.view({N, G, D}))
.add_owned_input(mean.view({N, G, 1}))
.add_owned_input(rstd.view({N, G, 1}))
.add_owned_input(gamma.view({1, G, D}))
.add_owned_const_input(gamma.view({1, G, D}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma) -> T {
return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
@ -524,10 +524,10 @@ void GroupNorm1dForward(
auto iter = TensorIteratorConfig()
.resize_outputs(false)
.add_owned_output(Y.view({N, G, D}))
.add_owned_input(X.view({N, G, D}))
.add_owned_const_input(X.view({N, G, D}))
.add_owned_input(mean.view({N, G, 1}))
.add_owned_input(rstd.view({N, G, 1}))
.add_owned_input(beta.view({1, G, D}))
.add_owned_const_input(beta.view({1, G, D}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T beta) -> T {
return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
@ -538,7 +538,7 @@ void GroupNorm1dForward(
auto iter = TensorIteratorConfig()
.resize_outputs(false)
.add_owned_output(Y.view({N * G, D}))
.add_owned_input(X.view({N * G, D}))
.add_owned_const_input(X.view({N * G, D}))
.add_owned_input(mean.view({N * G, 1}))
.add_owned_input(rstd.view({N * G, 1}))
.build();
@ -590,7 +590,7 @@ void GroupNormKernelImplInternal(
auto iter = TensorIteratorConfig()
.resize_outputs(false)
.add_owned_output(Y.view({N * G, D * HxW}))
.add_owned_input(X.view({N * G, D * HxW}))
.add_owned_const_input(X.view({N * G, D * HxW}))
.add_owned_input(mean.view({N * G, 1}))
.add_owned_input(rstd.view({N * G, 1}))
.build();
@ -622,7 +622,7 @@ void GroupNormKernelImplInternal(
.check_all_same_dtype(std::is_same<T, T_ACC>::value)
.resize_outputs(false)
.add_owned_output(Y.view({N * C, HxW}))
.add_owned_input(X.view({N * C, HxW}))
.add_owned_const_input(X.view({N * C, HxW}))
.add_owned_input(a.view({N * C, 1}))
.add_owned_input(b.view({N * C, 1}))
.build();
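
For iterator inputs that are temporaries (the `view(...)` results), the owned variants are used: `add_owned_const_input` keeps the temporary alive inside the iterator and marks it read-only. `mean` and `rstd` stay as ordinary inputs, presumably because they are freshly computed internal tensors rather than user-visible (potentially COW) inputs. A reduced sketch of the owned-const pattern (hypothetical shapes):

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>

// Hypothetical normalization-style setup: views are temporaries, so the
// owned add_* variants store them inside the config; X is read-only.
static at::TensorIterator make_norm_iter(
    at::Tensor& Y, const at::Tensor& X, const at::Tensor& mean,
    int64_t N, int64_t G, int64_t D) {
  return at::TensorIteratorConfig()
      .resize_outputs(false)
      .add_owned_output(Y.view({N, G, D}))
      .add_owned_const_input(X.view({N, G, D}))  // read-only, owned temporary
      .add_owned_input(mean.view({N, G, 1}))     // internal buffer, left mutable as in the diff
      .build();
}
```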

View File

@ -151,8 +151,8 @@ static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device, const
}
template <typename index_t>
void computeRepeatIndices(index_t* repeat_ptr,
int64_t* cumsum_ptr,
void computeRepeatIndices(const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {

View File

@ -13070,8 +13070,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
allow_cow_input_materialize=[3, 4],
sample_inputs_func=sample_inputs_native_batch_norm,
skips=(
# NotImplementedError: Could not run
@ -13099,8 +13098,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
allow_cow_input_materialize=[3, 4],
sample_inputs_func=sample_inputs__native_batch_norm_legit,
skips=(
# NotImplementedError: Could not run
@ -13126,8 +13124,7 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
allow_cow_input_materialize=[3, 4],
sample_inputs_func=sample_inputs__batch_norm_with_update,
skips=(
# NotImplementedError: Could not run
@ -13298,8 +13295,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
dtypes=floating_types_and(torch.half, torch.bfloat16),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits,
@ -13617,8 +13612,6 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
error_inputs_func=error_inputs_group_norm,
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
@ -13634,8 +13627,6 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
allow_cow_input_materialize=['running_mean', 'running_var'],
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
@ -14028,8 +14019,6 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
@ -14753,8 +14742,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_gradgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
# autodiff_nonfusible_nodes=["aten::log_sigmoid"],
decorators=[
DecorateInfo(
@ -14947,8 +14934,6 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
allow_cow_input_materialize=[1, 2],
sample_inputs_func=sample_inputs_batch_norm,
skips=(
@ -14972,8 +14957,6 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
allow_cow_input_materialize=[1, 2],
decorators=[onlyCUDA, disablecuDNN],
skips=(
@ -19501,8 +19484,6 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
@ -19538,8 +19519,6 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
@ -19556,8 +19535,6 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
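
On the test side, the blanket `supports_cow_input_no_materialize=False` opt-outs are dropped now that these forward kernels no longer materialize read-only COW inputs. Where an op legitimately mutates some of its inputs, the opt-out is replaced by the narrower `allow_cow_input_materialize` list, which (as the name suggests) names the inputs that are still allowed to materialize, either by sample-argument position (e.g. `[3, 4]` and `[1, 2]` for the running statistics of the batch norm variants) or by keyword (e.g. `['running_mean', 'running_var']`).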