fix overload ambiguity with functional ops; fix _foreach op grouping (#80556)

This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today, if you call an `OpOverloadPacket` from Python with some arguments, we use the types of those arguments to perform overload resolution. With some functional variants of ops, this resolution can be ambiguous.

Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it would potentially affect e.g. `native_batch_norm` in the future.

Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 outputs, mutates 4 of its inputs in place)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 outputs, mutates none of its inputs)

# We pick the wrong one - no way to know that we should pick the functional one, just from the call site.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# raises an error - tries to call the overload with only 2 returns
return outs[5]
```
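
For contrast, here's what the ordinary, unambiguous case looks like (a minimal sketch using plain `torch.ops` calls; nothing below is specific to this PR):

```
import torch

a = torch.ones(2)

# For most packets, the argument types are enough to resolve the overload:
out = torch.ops.aten.add(a, a)          # the packet resolves to add.Tensor

# Naming the overload explicitly sidesteps resolution entirely:
out = torch.ops.aten.add.Tensor(a, a)
```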

Specifically, functionalization bakes `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it has to strip the overload name (TorchScript can't parse overload names directly, so we drop the name and let it re-infer the right overload from the argument types at runtime - and it picks the wrong one).

The situation is pretty similar to inplace: `ops.aten.add` and `ops.aten.add_` correspond to two different `OverloadPacket` objects. They can't be overloads of the same op, because their schemas would be ambiguous (the alias annotations are different, but that isn't enough to disambiguate).
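
A quick illustration of that point (plain `torch.ops` attribute access; just a sketch):

```
import torch

# Two distinct OverloadPacket objects, not two overloads of one packet:
functional_packet = torch.ops.aten.add   # holds add.Tensor, add.Scalar, add.out, ...
inplace_packet = torch.ops.aten.add_     # holds add_.Tensor, add_.Scalar, ...
assert functional_packet is not inplace_packet
```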

In this PR, I fix the situation in much the same way that we already handle `inplace` in the data model: inplace ops get their own base operator name, but the inplace-ness is just a flag inside `BaseOperatorName`. The functional variant is now modeled the same way, via a `functional_overload` flag.
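
A small sketch of how the new flag round-trips, in the same spirit as inplace (this uses `torchgen.model.BaseOperatorName` as modified by this PR; treat it as illustrative):

```
from torchgen.model import BaseOperatorName

# Inplace-ness is a flag on the base name:
inplace = BaseOperatorName.parse("add_")
assert inplace.base == "add" and inplace.inplace

# The functional variant is now modeled the same way:
functional = BaseOperatorName.parse("_fused_moving_avg_obs_fq_helper_functional")
assert functional.base == "_fused_moving_avg_obs_fq_helper"
assert functional.functional_overload and not functional.inplace
assert str(functional) == "_fused_moving_avg_obs_fq_helper_functional"  # round-trips
```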

Two other important changes that I made as part of this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only needed for operators that **also** have a `SchemaKind.mutable` variant, and `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops, and added a bunch of assertions to enforce this restriction.

I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.

(2) I noticed that we still weren't pairing up a bunch of `_foreach` operators correctly, because their input argument names differed (`self` vs. `tensors`). Since they're private APIs, I went ahead and changed the argument names directly so they get matched up. Before this PR, in a bunch of cases we were generating a separate `_foreach_add` and `_foreach_add.functional` variant that did exactly the same thing (but happened to use a different name for the first argument).
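
To make the pairing issue concrete, here's a tiny sketch against `torchgen.model` (the schema strings are lifted from `native_functions.yaml`; the equality check is just illustrative):

```
from torchgen.model import FunctionSchema

# Argument names are part of the schema, so these two spellings parse to
# schemas that don't compare equal - which is why the functional and in-place
# _foreach variants weren't getting grouped together before the rename.
old = FunctionSchema.parse("_foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]")
new = FunctionSchema.parse("_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]")
assert old != new
```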

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
Author: Brian Hirsh
Date: 2022-07-05 19:45:20 -07:00
Committed by: PyTorch MergeBot
Parent: ce0786add2
Commit: 960758b0b7
13 changed files with 216 additions and 150 deletions


@ -131,7 +131,7 @@ const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet,
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = at::resize_functional(self_, size, memory_format);
tmp_output = at::resize(self_, size, memory_format);
}
auto itemsize = self.dtype().itemsize();


@ -328,6 +328,11 @@ Tensor normal_meta(const Tensor& mean, const Tensor& std, c10::optional<Generato
return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, gen);
}
// functional variant, only used by the functionalization pass.
Tensor normal_functional(const Tensor& self, double mean, double std, c10::optional<at::Generator> generator) {
return self.clone().normal_(mean, std, generator);
}
// ==================================================== Random ========================================================
template<typename RNG>


@ -905,7 +905,7 @@
dispatch:
CPU, CUDA: bernoulli_
MPS: bernoulli_mps_
autogen: bernoulli.Tensor_functional, bernoulli.Tensor_out
autogen: bernoulli.Tensor, bernoulli.Tensor_out
- func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -1902,7 +1902,7 @@
dispatch:
CPU: embedding_renorm_cpu_
CUDA: embedding_renorm_cuda_
autogen: embedding_renorm.functional, embedding_renorm.out
autogen: embedding_renorm, embedding_renorm.out
- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
@ -2027,7 +2027,7 @@
MPS: resize_mps_
QuantizedCPU: quantized_resize_cpu_
SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_
autogen: resize.functional, resize.out
autogen: resize, resize.out
# This is a utility function to enable users to resize out tensor while registering kernels for out variants.
# Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration
@ -2037,7 +2037,7 @@
variants: function
dispatch:
Meta: _resize_output_
autogen: _resize_output.functional, _resize_output.out
autogen: _resize_output, _resize_output.out
- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
category_override: factory
@ -2561,7 +2561,7 @@
dispatch:
CPU, CUDA: _index_put_impl_
QuantizedCPU: _index_put_impl_quantized_cpu_
autogen: _index_put_impl.functional, _index_put_impl.out
autogen: _index_put_impl, _index_put_impl.out
- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
variants: function
@ -5391,7 +5391,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: resize_as_
autogen: resize_as.functional, resize_as.out
autogen: resize_as, resize_as.out
- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
@ -5399,7 +5399,7 @@
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
autogen: resize_as_sparse.functional, resize_as_sparse.out
autogen: resize_as_sparse, resize_as_sparse.out
- func: zero_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -5411,7 +5411,7 @@
SparseCPU, SparseCUDA: zero_sparse_
SparseCsrCPU, SparseCsrCUDA: zero_sparse_csr_
MkldnnCPU: mkldnn_zero_
autogen: zero.functional, zero.out
autogen: zero, zero.out
- func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -5712,14 +5712,14 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_
autogen: sparse_resize.functional, sparse_resize.out
autogen: sparse_resize, sparse_resize.out
- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_and_clear_
autogen: sparse_resize_and_clear.functional, sparse_resize_and_clear.out
autogen: sparse_resize_and_clear, sparse_resize_and_clear.out
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
variants: method
@ -5825,7 +5825,7 @@
SparseCPU, SparseCUDA: _coalesced_sparse_
device_check: NoCheck
device_guard: False
autogen: _coalesced.functional, _coalesced.out
autogen: _coalesced, _coalesced.out
- func: indices(Tensor(a) self) -> Tensor(a)
variants: method
@ -5885,7 +5885,7 @@
variants: function
dispatch:
SparseCPU, SparseCUDA: copy_sparse_
autogen: copy_sparse_to_sparse.functional, copy_sparse_to_sparse.out
autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out
- func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
variants: function, method
@ -6094,7 +6094,7 @@
dispatch:
CPU: fused_moving_avg_obs_fake_quant_cpu
CUDA: fused_moving_avg_obs_fake_quant_cuda
autogen: _fused_moving_avg_obs_fq_helper.functional, _fused_moving_avg_obs_fq_helper.out
autogen: _fused_moving_avg_obs_fq_helper_functional, _fused_moving_avg_obs_fq_helper.out
- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)
variants: function
@ -6288,7 +6288,7 @@
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: set_
autogen: set.source_Storage_functional, set.source_Storage_out
autogen: set.source_Storage, set.source_Storage_out
- func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
variants: method
@ -6299,7 +6299,7 @@
CUDA: set_storage_cuda_
MPS: set_storage_mps_
QuantizedCPU, QuantizedCUDA: set_storage_quantized_
autogen: set.source_Storage_storage_offset_functional, set.source_Storage_storage_offset_out
autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out
- func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
variants: method
@ -6312,7 +6312,7 @@
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: set_tensor_
autogen: set.source_Tensor_functional, set.source_Tensor_out
autogen: set.source_Tensor, set.source_Tensor_out
- func: set_(Tensor(a!) self) -> Tensor(a!)
variants: method
@ -6321,7 +6321,7 @@
CUDA: set_cuda_
Meta: set_meta_
MPS: set_mps_
autogen: set.functional, set.out
autogen: set, set.out
- func: lift(Tensor self) -> Tensor
variants: method
@ -6954,7 +6954,7 @@
CPU, CUDA: random_
Meta: random_meta_
MPS: random_mps_
autogen: random.from_functional, random.from_out
autogen: random.from, random.from_out
- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6963,7 +6963,7 @@
CPU, CUDA: random_
Meta: random_meta_
MPS: random_mps_
autogen: random.to_functional, random.to_out
autogen: random.to, random.to_out
- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6971,7 +6971,7 @@
dispatch:
CPU, CUDA: random_
Meta: random_meta_
autogen: random.functional, random.out
autogen: random, random.out
- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6980,21 +6980,21 @@
CPU, CUDA: uniform_
MPS: uniform_mps_
Meta: uniform_meta_
autogen: uniform.functional, uniform.out
autogen: uniform, uniform.out
- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: cauchy_
autogen: cauchy.functional, cauchy.out
autogen: cauchy, cauchy.out
- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: log_normal_
autogen: log_normal.functional, log_normal.out
autogen: log_normal, log_normal.out
- func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -7002,7 +7002,7 @@
dispatch:
CPU, CUDA: exponential_
MPS: exponential_mps_
autogen: exponential.functional, exponential.out
autogen: exponential, exponential.out
- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -7011,7 +7011,7 @@
CPU, CUDA: geometric_
# wrappers for TH functions
autogen: geometric.functional, geometric.out
autogen: geometric, geometric.out
- func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -8396,7 +8396,15 @@
MPS: normal_mps_
Meta: normal_meta_
SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_
autogen: normal.functional, normal.out
autogen: normal.out
# Only used by the functionalization pass.
# Normally, the codegen would be able to generate a normal() NativeFunction,
# but we can't due to overload ambiguity with normal.Tensor_float.
- func: normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: normal_functional
- func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -8447,13 +8455,13 @@
variants: function
dispatch:
CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_
autogen: _amp_foreach_non_finite_check_and_unscale.functional, _amp_foreach_non_finite_check_and_unscale.out
autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out
- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
variants: function
dispatch:
CUDA: _amp_update_scale_cuda_
autogen: _amp_update_scale.functional, _amp_update_scale.out
autogen: _amp_update_scale, _amp_update_scale.out
#- func: _cat(Tensor[] tensors, int dim=0) -> Tensor
#dispatch:
@ -8468,7 +8476,7 @@
#CUDA: cat_out_cuda
#QuantizedCPU: cat_out_quantized_cpu
- func: _foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
- func: _foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8481,9 +8489,9 @@
dispatch:
CPU: foreach_tensor_add_scalar_kernel_slow_
CUDA: foreach_tensor_add_scalar_kernel_cuda_
autogen: _foreach_add.Scalar_functional, _foreach_add.Scalar_out
autogen: _foreach_add.Scalar_out
- func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8496,9 +8504,9 @@
dispatch:
CPU: foreach_tensor_sub_scalar_kernel_slow_
CUDA: foreach_tensor_sub_scalar_kernel_cuda_
autogen: _foreach_sub.Scalar_functional, _foreach_sub.Scalar_out
autogen: _foreach_sub.Scalar_out
- func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
- func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8511,9 +8519,9 @@
dispatch:
CPU: foreach_tensor_mul_scalar_kernel_slow_
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
autogen: _foreach_mul.Scalar_functional, _foreach_mul.Scalar_out
autogen: _foreach_mul.Scalar_out
- func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8526,9 +8534,9 @@
dispatch:
CPU: foreach_tensor_div_scalar_kernel_slow_
CUDA: foreach_tensor_div_scalar_kernel_cuda_
autogen: _foreach_div.Scalar_functional, _foreach_div.Scalar_out
autogen: _foreach_div.Scalar_out
- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8541,9 +8549,9 @@
dispatch:
CPU: foreach_tensor_add_list_kernel_slow_
CUDA: foreach_tensor_add_list_kernel_cuda_
autogen: _foreach_add.List_functional, _foreach_add.List_out
autogen: _foreach_add.List_out
- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
- func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8556,9 +8564,9 @@
dispatch:
CPU: foreach_tensor_sub_list_kernel_slow_
CUDA: foreach_tensor_sub_list_kernel_cuda_
autogen: _foreach_sub.List_functional, _foreach_sub.List_out
autogen: _foreach_sub.List_out
- func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8571,9 +8579,9 @@
dispatch:
CPU: foreach_tensor_mul_list_kernel_slow_
CUDA: foreach_tensor_mul_list_kernel_cuda_
autogen: _foreach_mul.List_functional, _foreach_mul.List_out
autogen: _foreach_mul.List_out
- func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
- func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8586,9 +8594,9 @@
dispatch:
CPU: foreach_tensor_div_list_kernel_slow_
CUDA: foreach_tensor_div_list_kernel_cuda_
autogen: _foreach_div.List_functional, _foreach_div.List_out
autogen: _foreach_div.List_out
- func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8601,9 +8609,9 @@
dispatch:
CPU: foreach_tensor_add_scalarlist_kernel_slow_
CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
autogen: _foreach_add.ScalarList_functional, _foreach_add.ScalarList_out
autogen: _foreach_add.ScalarList_out
- func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
- func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8616,9 +8624,9 @@
dispatch:
CPU: foreach_tensor_sub_scalarlist_kernel_slow_
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_
autogen: _foreach_sub.ScalarList_functional, _foreach_sub.ScalarList_out
autogen: _foreach_sub.ScalarList_out
- func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8631,9 +8639,9 @@
dispatch:
CPU: foreach_tensor_div_scalarlist_kernel_slow_
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
autogen: _foreach_div.ScalarList_functional, _foreach_div.ScalarList_out
autogen: _foreach_div.ScalarList_out
- func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8646,9 +8654,9 @@
dispatch:
CPU: foreach_tensor_mul_scalarlist_kernel_slow_
CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_
autogen: _foreach_mul.ScalarList_functional, _foreach_mul.ScalarList_out
autogen: _foreach_mul.ScalarList_out
- func: _foreach_exp(Tensor[] tensors) -> Tensor[]
- func: _foreach_exp(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8661,7 +8669,7 @@
dispatch:
CPU: foreach_tensor_zero_slow_
CUDA: foreach_tensor_zero_cuda_
autogen: _foreach_zero.functional, _foreach_zero.out
autogen: _foreach_zero, _foreach_zero.out
- func: _foreach_exp_(Tensor(a!)[] self) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8669,9 +8677,9 @@
dispatch:
CPU: foreach_tensor_exp_slow_
CUDA: foreach_tensor_exp_cuda_
autogen: _foreach_exp.functional, _foreach_exp.out
autogen: _foreach_exp.out
- func: _foreach_sqrt(Tensor[] tensors) -> Tensor[]
- func: _foreach_sqrt(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8684,9 +8692,9 @@
dispatch:
CPU: foreach_tensor_sqrt_slow_
CUDA: foreach_tensor_sqrt_cuda_
autogen: _foreach_sqrt.functional, _foreach_sqrt.out
autogen: _foreach_sqrt.out
- func: _foreach_abs(Tensor[] tensors) -> Tensor[]
- func: _foreach_abs(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8699,9 +8707,9 @@
dispatch:
CPU: foreach_tensor_abs_slow_
CUDA: foreach_tensor_abs_cuda_
autogen: _foreach_abs.functional, _foreach_abs.out
autogen: _foreach_abs.out
- func: _foreach_acos(Tensor[] tensors) -> Tensor[]
- func: _foreach_acos(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8714,9 +8722,9 @@
dispatch:
CPU: foreach_tensor_acos_slow_
CUDA: foreach_tensor_acos_cuda_
autogen: _foreach_acos.functional, _foreach_acos.out
autogen: _foreach_acos.out
- func: _foreach_asin(Tensor[] tensors) -> Tensor[]
- func: _foreach_asin(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8729,9 +8737,9 @@
dispatch:
CPU: foreach_tensor_asin_slow_
CUDA: foreach_tensor_asin_cuda_
autogen: _foreach_asin.functional, _foreach_asin.out
autogen: _foreach_asin.out
- func: _foreach_atan(Tensor[] tensors) -> Tensor[]
- func: _foreach_atan(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8744,9 +8752,9 @@
dispatch:
CPU: foreach_tensor_atan_slow_
CUDA: foreach_tensor_atan_cuda_
autogen: _foreach_atan.functional, _foreach_atan.out
autogen: _foreach_atan.out
- func: _foreach_ceil(Tensor[] tensors) -> Tensor[]
- func: _foreach_ceil(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8759,9 +8767,9 @@
dispatch:
CPU: foreach_tensor_ceil_slow_
CUDA: foreach_tensor_ceil_cuda_
autogen: _foreach_ceil.functional, _foreach_ceil.out
autogen: _foreach_ceil.out
- func: _foreach_cos(Tensor[] tensors) -> Tensor[]
- func: _foreach_cos(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8774,9 +8782,9 @@
dispatch:
CPU: foreach_tensor_cos_slow_
CUDA: foreach_tensor_cos_cuda_
autogen: _foreach_cos.functional, _foreach_cos.out
autogen: _foreach_cos.out
- func: _foreach_cosh(Tensor[] tensors) -> Tensor[]
- func: _foreach_cosh(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8789,9 +8797,9 @@
dispatch:
CPU: foreach_tensor_cosh_slow_
CUDA: foreach_tensor_cosh_cuda_
autogen: _foreach_cosh.functional, _foreach_cosh.out
autogen: _foreach_cosh.out
- func: _foreach_erf(Tensor[] tensors) -> Tensor[]
- func: _foreach_erf(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8804,9 +8812,9 @@
dispatch:
CPU: foreach_tensor_erf_slow_
CUDA: foreach_tensor_erf_cuda_
autogen: _foreach_erf.functional, _foreach_erf.out
autogen: _foreach_erf.out
- func: _foreach_erfc(Tensor[] tensors) -> Tensor[]
- func: _foreach_erfc(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8819,9 +8827,9 @@
dispatch:
CPU: foreach_tensor_erfc_slow_
CUDA: foreach_tensor_erfc_cuda_
autogen: _foreach_erfc.functional, _foreach_erfc.out
autogen: _foreach_erfc.out
- func: _foreach_expm1(Tensor[] tensors) -> Tensor[]
- func: _foreach_expm1(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8834,9 +8842,9 @@
dispatch:
CPU: foreach_tensor_expm1_slow_
CUDA: foreach_tensor_expm1_cuda_
autogen: _foreach_expm1.functional, _foreach_expm1.out
autogen: _foreach_expm1.out
- func: _foreach_floor(Tensor[] tensors) -> Tensor[]
- func: _foreach_floor(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8849,9 +8857,9 @@
dispatch:
CPU: foreach_tensor_floor_slow_
CUDA: foreach_tensor_floor_cuda_
autogen: _foreach_floor.functional, _foreach_floor.out
autogen: _foreach_floor.out
- func: _foreach_log(Tensor[] tensors) -> Tensor[]
- func: _foreach_log(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8864,9 +8872,9 @@
dispatch:
CPU: foreach_tensor_log_slow_
CUDA: foreach_tensor_log_cuda_
autogen: _foreach_log.functional, _foreach_log.out
autogen: _foreach_log.out
- func: _foreach_log10(Tensor[] tensors) -> Tensor[]
- func: _foreach_log10(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8879,9 +8887,9 @@
dispatch:
CPU: foreach_tensor_log10_slow_
CUDA: foreach_tensor_log10_cuda_
autogen: _foreach_log10.functional, _foreach_log10.out
autogen: _foreach_log10.out
- func: _foreach_log1p(Tensor[] tensors) -> Tensor[]
- func: _foreach_log1p(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8894,9 +8902,9 @@
dispatch:
CPU: foreach_tensor_log1p_slow_
CUDA: foreach_tensor_log1p_cuda_
autogen: _foreach_log1p.functional, _foreach_log1p.out
autogen: _foreach_log1p.out
- func: _foreach_log2(Tensor[] tensors) -> Tensor[]
- func: _foreach_log2(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8909,9 +8917,9 @@
dispatch:
CPU: foreach_tensor_log2_slow_
CUDA: foreach_tensor_log2_cuda_
autogen: _foreach_log2.functional, _foreach_log2.out
autogen: _foreach_log2.out
- func: _foreach_neg(Tensor[] tensors) -> Tensor[]
- func: _foreach_neg(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8924,9 +8932,9 @@
dispatch:
CPU: foreach_tensor_neg_slow_
CUDA: foreach_tensor_neg_cuda_
autogen: _foreach_neg.functional, _foreach_neg.out
autogen: _foreach_neg.out
- func: _foreach_tan(Tensor[] tensors) -> Tensor[]
- func: _foreach_tan(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8939,9 +8947,9 @@
dispatch:
CPU: foreach_tensor_tan_slow_
CUDA: foreach_tensor_tan_cuda_
autogen: _foreach_tan.functional, _foreach_tan.out
autogen: _foreach_tan.out
- func: _foreach_tanh(Tensor[] tensors) -> Tensor[]
- func: _foreach_tanh(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8954,9 +8962,9 @@
dispatch:
CPU: foreach_tensor_tanh_slow_
CUDA: foreach_tensor_tanh_cuda_
autogen: _foreach_tanh.functional, _foreach_tanh.out
autogen: _foreach_tanh.out
- func: _foreach_sin(Tensor[] tensors) -> Tensor[]
- func: _foreach_sin(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8969,9 +8977,9 @@
dispatch:
CPU: foreach_tensor_sin_slow_
CUDA: foreach_tensor_sin_cuda_
autogen: _foreach_sin.functional, _foreach_sin.out
autogen: _foreach_sin.out
- func: _foreach_sinh(Tensor[] tensors) -> Tensor[]
- func: _foreach_sinh(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8984,9 +8992,9 @@
dispatch:
CPU: foreach_tensor_sinh_slow_
CUDA: foreach_tensor_sinh_cuda_
autogen: _foreach_sinh.functional, _foreach_sinh.out
autogen: _foreach_sinh.out
- func: _foreach_round(Tensor[] tensors) -> Tensor[]
- func: _foreach_round(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -8999,9 +9007,9 @@
dispatch:
CPU: foreach_tensor_round_slow_
CUDA: foreach_tensor_round_cuda_
autogen: _foreach_round.functional, _foreach_round.out
autogen: _foreach_round.out
- func: _foreach_lgamma(Tensor[] tensors) -> Tensor[]
- func: _foreach_lgamma(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -9014,9 +9022,9 @@
dispatch:
CPU: foreach_tensor_lgamma_slow_
CUDA: foreach_tensor_lgamma_cuda_
autogen: _foreach_lgamma.functional, _foreach_lgamma.out
autogen: _foreach_lgamma.out
- func: _foreach_frac(Tensor[] tensors) -> Tensor[]
- func: _foreach_frac(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -9029,9 +9037,9 @@
dispatch:
CPU: foreach_tensor_frac_slow_
CUDA: foreach_tensor_frac_cuda_
autogen: _foreach_frac.functional, _foreach_frac.out
autogen: _foreach_frac.out
- func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[]
- func: _foreach_reciprocal(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -9044,9 +9052,9 @@
dispatch:
CPU: foreach_tensor_reciprocal_slow_
CUDA: foreach_tensor_reciprocal_cuda_
autogen: _foreach_reciprocal.functional, _foreach_reciprocal.out
autogen: _foreach_reciprocal.out
- func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[]
- func: _foreach_sigmoid(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -9059,9 +9067,9 @@
dispatch:
CPU: foreach_tensor_sigmoid_slow_
CUDA: foreach_tensor_sigmoid_cuda_
autogen: _foreach_sigmoid.functional, _foreach_sigmoid.out
autogen: _foreach_sigmoid.out
- func: _foreach_trunc(Tensor[] tensors) -> Tensor[]
- func: _foreach_trunc(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -9074,7 +9082,7 @@
dispatch:
CPU: foreach_tensor_trunc_slow_
CUDA: foreach_tensor_trunc_cuda_
autogen: _foreach_trunc.functional, _foreach_trunc.out
autogen: _foreach_trunc.out
- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -9082,7 +9090,7 @@
dispatch:
CPU: foreach_tensor_addcdiv_scalar_slow_
CUDA: foreach_tensor_addcdiv_scalar_cuda_
autogen: _foreach_addcdiv.Scalar_functional, _foreach_addcdiv.Scalar_out
autogen: _foreach_addcdiv.Scalar_out
- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -9090,7 +9098,7 @@
dispatch:
CPU: foreach_tensor_addcmul_scalar_slow_
CUDA: foreach_tensor_addcmul_scalar_cuda_
autogen: _foreach_addcmul.Scalar_functional, _foreach_addcmul.Scalar_out
autogen: _foreach_addcmul.Scalar_out
- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -9098,7 +9106,7 @@
dispatch:
CPU: foreach_tensor_addcdiv_scalarlist_slow_
CUDA: foreach_tensor_addcdiv_scalarlist_cuda_
autogen: _foreach_addcdiv.ScalarList_functional, _foreach_addcdiv.ScalarList_out
autogen: _foreach_addcdiv.ScalarList_out
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -9106,51 +9114,51 @@
dispatch:
CPU: foreach_tensor_addcmul_scalarlist_slow_
CUDA: foreach_tensor_addcmul_scalarlist_cuda_
autogen: _foreach_addcmul.ScalarList_functional, _foreach_addcmul.ScalarList_out
autogen: _foreach_addcmul.ScalarList_out
- func: _foreach_addcdiv.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
- func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_addcdiv_scalar_slow
CUDA: foreach_tensor_addcdiv_scalar_cuda
- func: _foreach_addcmul.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
- func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_addcmul_scalar_slow
CUDA: foreach_tensor_addcmul_scalar_cuda
- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
- func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_addcdiv_scalarlist_slow
CUDA: foreach_tensor_addcdiv_scalarlist_cuda
- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_addcmul_scalarlist_slow
CUDA: foreach_tensor_addcmul_scalarlist_cuda
- func: _foreach_maximum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
- func: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_maximum_slow
CUDA: foreach_tensor_maximum_cuda
- func: _foreach_minimum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
- func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
CPU: foreach_tensor_minimum_slow
CUDA: foreach_tensor_minimum_cuda
- func: _foreach_norm.Scalar(Tensor[] tensors, Scalar ord=2) -> Tensor[]
- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
@ -11675,7 +11683,7 @@
dispatch:
CPU: _linalg_inv_out_helper_cpu
CUDA: _linalg_inv_out_helper_cuda
autogen: _linalg_inv_out_helper.functional, _linalg_inv_out_helper.out
autogen: _linalg_inv_out_helper, _linalg_inv_out_helper.out
- func: linalg_inv_ex(Tensor self, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
python_module: linalg


@ -101,9 +101,9 @@ full_codegen:
- norm.ScalarOpt_dim
- pow.Tensor_Scalar
- pow.Tensor_Tensor
- random.functional
- random.from_functional
- random.to_functional
- random
- random.from
- random.to
- reciprocal
- relu
- remainder.Tensor
@ -140,7 +140,7 @@ full_codegen:
- upsample_bilinear2d_backward
- upsample_nearest2d
- upsample_nearest2d_backward
- zero.functional
- zero
- narrow_copy.SymInt
- alias_copy
- as_strided_copy


@ -99,6 +99,8 @@ ALLOW_LIST = [
("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)),
("aten::empty.SymInt", datetime.date(9999, 1, 1)),
("c10d::broadcast", datetime.date(2022, 6, 25)),
("aten::.*functional", datetime.date(2022, 8, 1)),
("aten::_foreach.*", datetime.date(2022, 8, 1)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),
]


@ -203,7 +203,7 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""")
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper_functional.default($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""") # noqa: B950
def test_as_strided(self):
def f(x):
@ -527,7 +527,7 @@ $3 = torch._ops.aten.fill.Scalar($2, 0)""")
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, 1)
$2 = torch._ops.aten.view_copy.default($1, [4, 4])
$3 = torch._ops.aten.resize.functional($2, [3, 3])
$3 = torch._ops.aten.resize.default($2, [3, 3])
$4 = torch._ops.aten.as_strided_copy.default($2, [3, 3], [3, 1])
$5 = torch._ops.aten.view_copy.default($4, [-1])
$6 = torch._ops.aten.add.Tensor($5, 1)
@ -562,7 +562,7 @@ $14 = torch._ops.aten.add.Tensor($13, 1)""")
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, 1)
$2 = torch._ops.aten.resize.functional($1, [5, 5])
$2 = torch._ops.aten.resize.default($1, [5, 5])
$3 = torch._ops.aten.view_copy.default($2, [25])
$4 = torch._ops.aten.fill.Scalar($3, 1)
$5 = torch._ops.aten.view_copy.default($4, [5, 5])


@ -156,6 +156,7 @@ _SKIP_PYTHON_BINDINGS = [
"fill.Tensor", # only used by the functionalization pass
"fill.Scalar", # only used by the functionalization pass
"lift",
"normal_functional", # only used by the functionalization pas
]
SKIP_PYTHON_BINDINGS = list(


@ -36,7 +36,7 @@ _device_not_kwarg_ops = (
aten.to.device,
aten.to.prim_Device,
aten._pin_memory.default,
aten._resize_output.functional,
aten._resize_output.default,
aten._resize_output.out,
)


@ -679,25 +679,25 @@ std::vector<Shape> compute_shape_native_dropout_backward(
return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
}
std::vector<Shape> compute_shape_random_functional(
std::vector<Shape> compute_shape_random(
const at::Tensor& self,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_random_functional(
std::vector<Shape> compute_shape_random(
const at::Tensor& self,
int64_t to,
c10::optional<at::Generator> generator) {
return compute_shape_random_functional(self, generator);
return compute_shape_random(self, generator);
}
std::vector<Shape> compute_shape_random_functional(
std::vector<Shape> compute_shape_random(
const at::Tensor& self,
int64_t from,
c10::optional<int64_t> to,
c10::optional<at::Generator> generator) {
return compute_shape_random_functional(self, generator);
return compute_shape_random(self, generator);
}
std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
@ -725,7 +725,7 @@ std::vector<Shape> compute_shape_sum(
;
}
std::vector<Shape> compute_shape_zero_functional(const at::Tensor& self) {
std::vector<Shape> compute_shape_zero(const at::Tensor& self) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}


@ -66,9 +66,9 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slogdet(const at::Tensor & self);
@ -81,7 +81,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & s
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy_symint(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
// Non-Native ops


@ -65,8 +65,6 @@ from typing import Optional, Sequence, Union, List, Set
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
name = str(func.name.name)
if func.is_functional_fn():
name += "_functional"
if func.is_symint_fn():
name += "_symint"
if func.is_out_fn():


@ -872,6 +872,8 @@ class NativeFunctionsGroup:
if self.mutable is not None:
assert self.mutable.func.kind() == SchemaKind.mutable
assert self.mutable.namespace == self.functional.namespace
# See Note [Overload Ambiguity With Functional Variants]
assert self.functional.func.name.name.functional_overload
if self.structured:
# For now, structured composite kernels are not supported (need some
@ -901,7 +903,7 @@ class NativeFunctionsGroup:
raise RuntimeError(
f"The codegen expects to be able to generate '{generated_fns_str}'."
f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
f" Instead, it found 'autogen: {generated_fns_str}'"
f" Instead, it found 'autogen: {expected_generated_fns_str}'"
)
def signature(self) -> "FunctionSchema":
@ -2135,6 +2137,26 @@ class BaseOperatorName:
base: str
inplace: bool
dunder_method: bool
# Note [Overload Ambiguity With Functional Variants]
# A handful of operators have both a "mutable" and a "functional" variant.
# (native_batch_norm is a good example, although this isn't the case today).
# For those operators, the mutable and functional variant take in the same set of
# arguments, but have different alias annotations.
# this makes it ambiguous when you try to resolve an OverloadPacket into an overload,
# given a set of input arguments.
#
# So instead of making the "functional" variant in this case a real overload, e.g:
# native_batch_norm (mutable variant)
# native_batch_norm.functional (functional variant)
# we make it a new base operator,
# native_batch_norm_functional (functional variant)
#
# In an ideal world, we would probably invert this so the operators were:
# native_batch_norm.mutable (mutable variant)
# native_batch_norm (functional variant)
#
# Doing that is BC-breaking though, so we're stuck with the above modeling.
functional_overload: bool = False
@staticmethod
def parse(op: str) -> "BaseOperatorName":
@ -2165,7 +2187,24 @@ class BaseOperatorName:
base = base[:-1]
else:
inplace = False
r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
# See Note [Overload Ambiguity With Functional Variants]
functional_suffix = "_functional"
if base.endswith(functional_suffix):
functional_overload = True
base = base[: -len(functional_suffix)]
# This seems complicated and unnecessary, so banning dunder methods
# for now on ops that have a functional + mutable variant (like native_batch_norm).
assert not dunder_method and not inplace
else:
functional_overload = False
r = BaseOperatorName(
base=base,
inplace=inplace,
dunder_method=dunder_method,
functional_overload=functional_overload,
)
assert str(r) == op, f"{str(r)} != {op}"
return r
@ -2174,7 +2213,13 @@ class BaseOperatorName:
i = "i" if self.inplace else ""
return f"__{i}{self.base}__"
else:
i = "_" if self.inplace else ""
i = (
"_"
if self.inplace
else "_functional"
if self.functional_overload
else ""
)
return f"{self.base}{i}"


@ -11,6 +11,7 @@ from torchgen.model import (
Argument,
BackendIndex,
BackendMetadata,
BaseOperatorName,
BaseTy,
BaseType,
DEFAULT_KERNEL_NAMESPACE,
@ -193,14 +194,20 @@ def generate_function(
# The new "functional" NativeFunction has:
# - any mutable arguments have been converted into (immutable) returns.
# (if a mutable argument was not also a return, it gets converted to one)
# - a "functional" overload name.
# - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
# See Note [Overload Ambiguity With Functional Variants]
# The default grouping logic in signature() actually already does this,
# so we can piggy-back off it (but we still want return names)
func = f.func.signature(keep_return_names=True).with_name(
f.func.name.remove_inplace().with_overload(
"functional"
if not f.func.name.overload_name
else f"{f.func.name.overload_name}_functional"
OperatorName(
name=BaseOperatorName(
base=f.func.name.name.base,
inplace=False,
dunder_method=f.func.name.name.dunder_method,
# See Note [Overload Ambiguity With Functional Variants]
functional_overload=f.func.kind() == SchemaKind.mutable,
),
overload_name=f.func.name.overload_name,
)
)
elif k == SchemaKind.out: