[test] attempt to functionalize ops with mutable positional-only args

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76320

Approved by: https://github.com/ezyang
Brian Hirsh
2022-05-19 09:00:41 -07:00
committed by PyTorch MergeBot
parent b8639cf6e1
commit 0161e9eb00
22 changed files with 1281 additions and 274 deletions

View File

@ -505,6 +505,7 @@
variants: function
dispatch:
CPU: add_relu_
autogen: _add_relu.Scalar_out
# For C++ only, until we have conversion from C++ numbers to Tensor
- func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
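The `autogen: _add_relu.Scalar_out` line above is the first of many such keys this PR adds: each one asks the code generator to synthesize the named out= (and, further down in the file, `.functional`) overloads for an op that mutates its inputs, so the functionalization pass has a non-mutating form to dispatch to. A rough Python-level sketch of what a generated out= overload does, using the public `add` op as a stand-in since `_add_relu` has no public eager binding (illustrative only, not code from this diff):

import torch

def add_scalar_out(self, other, alpha=1, *, out):
    # Compute the functional result, then resize `out` and copy the result into it.
    result = torch.add(self, other, alpha=alpha)
    out.resize_(result.shape)
    out.copy_(result)
    return out

x = torch.ones(3)
out = torch.empty(0)
add_scalar_out(x, 2.0, out=out)
assert out.eq(3.0).all()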
@ -518,6 +519,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: add_
autogen: add.Scalar_out
- func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
structured_delegate: addmv.out
@ -609,6 +611,12 @@
- func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
# Note [arange.start_step schema]
# We want `arange.start_step` to be grouped up with `arange.start_out`,
# but this doesn't happen automatically because the step argument
# is defaultable for .start_out but not for .start_step.
# We should probably just make "step" a defaultable param on arange.start,
# and kill arange.start_step.
- func: arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
@ -892,6 +900,7 @@
dispatch:
CPU, CUDA: bernoulli_
MPS: bernoulli_mps_
autogen: bernoulli.Tensor_functional, bernoulli.Tensor_out
- func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -899,7 +908,10 @@
dispatch:
CPU, CUDA: bernoulli_
MPS: bernoulli_mps_
autogen: bernoulli.float_out
# Note [bernoulli.p schema]
# We should probably just fix the overload ambiguity by appending a _functional to the C++ API name (BC breaking)
# This out-of-place version isn't used explicitly, but needed by jit.
# There is no default value on `p` here because it would introduce ambiguity
# with `bernoulli(Tensor self, *, Generator? generator=None)` declaration.
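A small example of why the ambiguity exists, using only public eager APIs (illustrative, not code from this diff): the in-place `bernoulli_` method already exposes both flavors, and a default on `p` in the out-of-place `bernoulli.p` schema would make a one-argument call indistinguishable from `bernoulli(Tensor self, *, Generator?)`, which interprets the tensor itself as the per-element probabilities.

import torch

probs = torch.rand(4)
t = torch.empty(4)

# bernoulli_.Tensor: per-element probabilities taken from a tensor.
t.bernoulli_(probs)

# bernoulli_.float: a single probability applied to every element.
t.bernoulli_(0.25)

# torch.bernoulli(probs) uses the values of `probs` as probabilities; if bernoulli.p
# gave `p` a default, this call could also be parsed as the bernoulli.p overload.
a = torch.bernoulli(probs)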
@ -1420,6 +1432,7 @@
SparseCPU, SparseCUDA: copy_sparse_wrapper_
CompositeExplicitAutograd: copy_
SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_
autogen: copy.out
- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
dispatch:
@ -1780,6 +1793,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: div_
autogen: div.Scalar_out
- func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
variants: function, method
@ -1790,6 +1804,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: div_
autogen: div.Scalar_mode_out
# divide, alias for div
- func: divide.Tensor(Tensor self, Tensor other) -> Tensor
@ -1880,6 +1895,7 @@
dispatch:
CPU: embedding_renorm_cpu_
CUDA: embedding_renorm_cuda_
autogen: embedding_renorm.functional, embedding_renorm.out
- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
@ -1993,6 +2009,7 @@
MPS: resize_mps_
QuantizedCPU: quantized_resize_cpu_
SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_
autogen: resize.functional, resize.out
# This is a utility function to enable users to resize out tensor while registering kernels for out variants.
# Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration
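The `autogen: resize.functional, resize.out` entry above requests a non-mutating variant of `resize_`. The generated kernel (C++ in practice) follows the pattern sketched here in Python: clone the would-be-mutated argument, run the in-place op on the clone, and return the clone (names are illustrative, not the exact generated signatures):

import torch

def resize_functional(self, size):
    # `self` is left untouched; only the clone is resized and returned.
    out = self.clone()
    out.resize_(size)
    return out

x = torch.arange(6.)
y = resize_functional(x, (2, 3))
assert x.shape == (6,) and y.shape == (2, 3)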
@ -2002,6 +2019,7 @@
variants: function
dispatch:
Meta: _resize_output_
autogen: _resize_output.functional, _resize_output.out
- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
category_override: factory
@ -2201,6 +2219,7 @@
QuantizedCPU, QuantizedCUDA: fill_quantized_
Meta: fill_meta_
SparseCsrCPU, SparseCsrCUDA: fill_sparse_csr_
autogen: fill.Scalar_out
- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -2210,6 +2229,7 @@
MPS: fill_tensor_mps_
QuantizedCPU, QuantizedCUDA: fill_quantized_
Meta: fill_meta_
autogen: fill.Tensor_out
- func: floor(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -2494,6 +2514,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: index_put_
autogen: index_put.out
# NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
# - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Tensor const & rhs)
# - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Scalar v)
@ -2511,6 +2532,7 @@
variants: function
dispatch:
CPU, CUDA: _index_put_impl_
autogen: _index_put_impl.functional, _index_put_impl.out
- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
variants: function
@ -3380,6 +3402,7 @@
dispatch:
CompositeExplicitAutograd: mul_
SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr
autogen: mul.Scalar_out
# multiply, alias for mul
- func: multiply.Tensor(Tensor self, Tensor other) -> Tensor
@ -3918,6 +3941,7 @@
MkldnnCPU: mkldnn_relu_
QuantizedCPU: relu_quantized_cpu_
NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_
autogen: relu.out
- func: relu6(Tensor self) -> Tensor
python_module: nn
@ -4061,6 +4085,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: celu_
autogen: celu.out
- func: silu(Tensor self) -> Tensor
structured_delegate: silu.out
@ -4800,6 +4825,7 @@
device_guard: False
dispatch:
MkldnnCPU: mkldnn_transpose_
autogen: _mkldnn_transpose.out
- func: one_hot(Tensor self, int num_classes=-1) -> Tensor
python_module: nn
@ -5312,6 +5338,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: resize_as_
autogen: resize_as.functional, resize_as.out
- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
@ -5319,6 +5346,7 @@
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
autogen: resize_as_sparse.functional, resize_as_sparse.out
- func: zero_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -5330,6 +5358,7 @@
SparseCPU, SparseCUDA: zero_sparse_
SparseCsrCPU, SparseCsrCUDA: zero_sparse_csr_
MkldnnCPU: mkldnn_zero_
autogen: zero.functional, zero.out
- func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -5367,6 +5396,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: sub_
autogen: sub.Scalar_out
# subtract, alias for sub
- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
@ -5629,12 +5659,14 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_
autogen: sparse_resize.functional, sparse_resize.out
- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_and_clear_
autogen: sparse_resize_and_clear.functional, sparse_resize_and_clear.out
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
variants: method
@ -5740,6 +5772,7 @@
SparseCPU, SparseCUDA: _coalesced_sparse_
device_check: NoCheck
device_guard: False
autogen: _coalesced.functional, _coalesced.out
- func: indices(Tensor(a) self) -> Tensor(a)
variants: method
@ -5799,6 +5832,7 @@
variants: function
dispatch:
SparseCPU, SparseCUDA: copy_sparse_
autogen: copy_sparse_to_sparse.functional, copy_sparse_to_sparse.out
- func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
variants: function, method
@ -6007,7 +6041,7 @@
dispatch:
CPU: fused_moving_avg_obs_fake_quant_cpu
CUDA: fused_moving_avg_obs_fake_quant_cuda
autogen: _fused_moving_avg_obs_fq_helper.functional, _fused_moving_avg_obs_fq_helper.out
- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)
variants: function
@ -6201,6 +6235,7 @@
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: set_
autogen: set.source_Storage_functional, set.source_Storage_out
- func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
variants: method
@ -6211,6 +6246,7 @@
CUDA: set_storage_cuda_
MPS: set_storage_mps_
QuantizedCPU, QuantizedCUDA: set_storage_quantized_
autogen: set.source_Storage_storage_offset_functional, set.source_Storage_storage_offset_out
- func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
variants: method
@ -6223,6 +6259,7 @@
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: set_tensor_
autogen: set.source_Tensor_functional, set.source_Tensor_out
- func: set_(Tensor(a!) self) -> Tensor(a!)
variants: method
@ -6231,6 +6268,7 @@
CUDA: set_cuda_
Meta: set_meta_
MPS: set_mps_
autogen: set.functional, set.out
- func: lift(Tensor self) -> Tensor
variants: method
@ -6253,6 +6291,7 @@
CPU: masked_fill__cpu
CUDA: masked_fill__cuda
MPS: masked_fill__mps
autogen: masked_fill.Scalar_out
- func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
device_check: NoCheck # TensorIterator
@ -6267,6 +6306,7 @@
CPU: masked_fill__cpu
CUDA: masked_fill__cuda
MPS: masked_fill__mps
autogen: masked_fill.Tensor_out
- func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
device_check: NoCheck # TensorIterator
@ -6279,6 +6319,7 @@
dispatch:
CPU: masked_scatter__cpu
CUDA: masked_scatter__cuda
autogen: masked_scatter.out
- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
variants: function, method
@ -6320,6 +6361,7 @@
variants: method
dispatch:
CPU, CUDA, MPS: put_
autogen: put.out
- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
variants: function, method
@ -6367,6 +6409,7 @@
dispatch:
CPU: index_fill_
CUDA: index_fill_
autogen: index_fill.int_Scalar_out
- func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
device_check: NoCheck # TensorIterator
@ -6379,6 +6422,7 @@
variants: method
dispatch:
CPU, CUDA: index_fill_
autogen: index_fill.int_Tensor_out
- func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
device_check: NoCheck # TensorIterator
@ -6695,12 +6739,14 @@
variants: method
dispatch:
CPU, CUDA: __ilshift__
autogen: __lshift__.Scalar_out
- func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: __ilshift__
autogen: __lshift__.Tensor_out
- func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
@ -6760,12 +6806,14 @@
variants: method
dispatch:
CPU, CUDA: __irshift__
autogen: __rshift__.Scalar_out
- func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: __irshift__
autogen: __rshift__.Tensor_out
- func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
@ -6855,6 +6903,7 @@
CPU, CUDA: random_
Meta: random_meta_
MPS: random_mps_
autogen: random.from_functional, random.from_out
- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6863,6 +6912,7 @@
CPU, CUDA: random_
Meta: random_meta_
MPS: random_mps_
autogen: random.to_functional, random.to_out
- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6870,6 +6920,7 @@
dispatch:
CPU, CUDA: random_
Meta: random_meta_
autogen: random.functional, random.out
- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6878,24 +6929,28 @@
CPU, CUDA: uniform_
MPS: uniform_mps_
Meta: uniform_meta_
autogen: uniform.functional, uniform.out
- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: cauchy_
autogen: cauchy.functional, cauchy.out
- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: log_normal_
autogen: log_normal.functional, log_normal.out
- func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: exponential_
autogen: exponential.functional, exponential.out
- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -6904,6 +6959,7 @@
CPU, CUDA: geometric_
# wrappers for TH functions
autogen: geometric.functional, geometric.out
- func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -8306,6 +8362,7 @@
MPS: normal_mps_
Meta: normal_meta_
SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_
autogen: normal.functional, normal.out
- func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -8356,11 +8413,13 @@
variants: function
dispatch:
CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_
autogen: _amp_foreach_non_finite_check_and_unscale.functional, _amp_foreach_non_finite_check_and_unscale.out
- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
variants: function
dispatch:
CUDA: _amp_update_scale_cuda_
autogen: _amp_update_scale.functional, _amp_update_scale.out
#- func: _cat(Tensor[] tensors, int dim=0) -> Tensor
#dispatch:
@ -8388,6 +8447,7 @@
dispatch:
CPU: foreach_tensor_add_scalar_kernel_slow_
CUDA: foreach_tensor_add_scalar_kernel_cuda_
autogen: _foreach_add.Scalar_functional, _foreach_add.Scalar_out
- func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8402,6 +8462,7 @@
dispatch:
CPU: foreach_tensor_sub_scalar_kernel_slow_
CUDA: foreach_tensor_sub_scalar_kernel_cuda_
autogen: _foreach_sub.Scalar_functional, _foreach_sub.Scalar_out
- func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8416,6 +8477,7 @@
dispatch:
CPU: foreach_tensor_mul_scalar_kernel_slow_
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
autogen: _foreach_mul.Scalar_functional, _foreach_mul.Scalar_out
- func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8430,6 +8492,7 @@
dispatch:
CPU: foreach_tensor_div_scalar_kernel_slow_
CUDA: foreach_tensor_div_scalar_kernel_cuda_
autogen: _foreach_div.Scalar_functional, _foreach_div.Scalar_out
- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8444,6 +8507,7 @@
dispatch:
CPU: foreach_tensor_add_list_kernel_slow_
CUDA: foreach_tensor_add_list_kernel_cuda_
autogen: _foreach_add.List_functional, _foreach_add.List_out
- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8458,6 +8522,7 @@
dispatch:
CPU: foreach_tensor_sub_list_kernel_slow_
CUDA: foreach_tensor_sub_list_kernel_cuda_
autogen: _foreach_sub.List_functional, _foreach_sub.List_out
- func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8472,6 +8537,7 @@
dispatch:
CPU: foreach_tensor_mul_list_kernel_slow_
CUDA: foreach_tensor_mul_list_kernel_cuda_
autogen: _foreach_mul.List_functional, _foreach_mul.List_out
- func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8486,6 +8552,7 @@
dispatch:
CPU: foreach_tensor_div_list_kernel_slow_
CUDA: foreach_tensor_div_list_kernel_cuda_
autogen: _foreach_div.List_functional, _foreach_div.List_out
- func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8500,6 +8567,7 @@
dispatch:
CPU: foreach_tensor_add_scalarlist_kernel_slow_
CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
autogen: _foreach_add.ScalarList_functional, _foreach_add.ScalarList_out
- func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8514,6 +8582,7 @@
dispatch:
CPU: foreach_tensor_sub_scalarlist_kernel_slow_
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_
autogen: _foreach_sub.ScalarList_functional, _foreach_sub.ScalarList_out
- func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8528,6 +8597,7 @@
dispatch:
CPU: foreach_tensor_div_scalarlist_kernel_slow_
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
autogen: _foreach_div.ScalarList_functional, _foreach_div.ScalarList_out
- func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8542,6 +8612,7 @@
dispatch:
CPU: foreach_tensor_mul_scalarlist_kernel_slow_
CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_
autogen: _foreach_mul.ScalarList_functional, _foreach_mul.ScalarList_out
- func: _foreach_exp(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8556,6 +8627,7 @@
dispatch:
CPU: foreach_tensor_zero_slow_
CUDA: foreach_tensor_zero_cuda_
autogen: _foreach_zero.functional, _foreach_zero.out
- func: _foreach_exp_(Tensor(a!)[] self) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
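`_foreach_zero_` above is a good example of why these functional variants have to be generated: the in-place foreach op returns `()`, so the `.functional` overload must return the (cloned and zeroed) tensor list instead. A rough Python sketch of that relationship, built on the public in-place API (illustrative; the real generated kernel is C++):

import torch

def foreach_zero_functional(tensors):
    # Clone every input, zero the clones in place, and return them;
    # the original tensors are not modified.
    outs = [t.clone() for t in tensors]
    torch._foreach_zero_(outs)
    return outs

xs = [torch.ones(2), torch.ones(3)]
ys = foreach_zero_functional(xs)
assert all(bool(x.eq(1).all()) for x in xs)
assert all(bool(y.eq(0).all()) for y in ys)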
@ -8563,6 +8635,7 @@
dispatch:
CPU: foreach_tensor_exp_slow_
CUDA: foreach_tensor_exp_cuda_
autogen: _foreach_exp.functional, _foreach_exp.out
- func: _foreach_sqrt(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8577,6 +8650,7 @@
dispatch:
CPU: foreach_tensor_sqrt_slow_
CUDA: foreach_tensor_sqrt_cuda_
autogen: _foreach_sqrt.functional, _foreach_sqrt.out
- func: _foreach_abs(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8591,6 +8665,7 @@
dispatch:
CPU: foreach_tensor_abs_slow_
CUDA: foreach_tensor_abs_cuda_
autogen: _foreach_abs.functional, _foreach_abs.out
- func: _foreach_acos(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8605,6 +8680,7 @@
dispatch:
CPU: foreach_tensor_acos_slow_
CUDA: foreach_tensor_acos_cuda_
autogen: _foreach_acos.functional, _foreach_acos.out
- func: _foreach_asin(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8619,6 +8695,7 @@
dispatch:
CPU: foreach_tensor_asin_slow_
CUDA: foreach_tensor_asin_cuda_
autogen: _foreach_asin.functional, _foreach_asin.out
- func: _foreach_atan(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8633,6 +8710,7 @@
dispatch:
CPU: foreach_tensor_atan_slow_
CUDA: foreach_tensor_atan_cuda_
autogen: _foreach_atan.functional, _foreach_atan.out
- func: _foreach_ceil(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8647,6 +8725,7 @@
dispatch:
CPU: foreach_tensor_ceil_slow_
CUDA: foreach_tensor_ceil_cuda_
autogen: _foreach_ceil.functional, _foreach_ceil.out
- func: _foreach_cos(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8661,6 +8740,7 @@
dispatch:
CPU: foreach_tensor_cos_slow_
CUDA: foreach_tensor_cos_cuda_
autogen: _foreach_cos.functional, _foreach_cos.out
- func: _foreach_cosh(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8675,6 +8755,7 @@
dispatch:
CPU: foreach_tensor_cosh_slow_
CUDA: foreach_tensor_cosh_cuda_
autogen: _foreach_cosh.functional, _foreach_cosh.out
- func: _foreach_erf(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8689,6 +8770,7 @@
dispatch:
CPU: foreach_tensor_erf_slow_
CUDA: foreach_tensor_erf_cuda_
autogen: _foreach_erf.functional, _foreach_erf.out
- func: _foreach_erfc(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8703,6 +8785,7 @@
dispatch:
CPU: foreach_tensor_erfc_slow_
CUDA: foreach_tensor_erfc_cuda_
autogen: _foreach_erfc.functional, _foreach_erfc.out
- func: _foreach_expm1(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8717,6 +8800,7 @@
dispatch:
CPU: foreach_tensor_expm1_slow_
CUDA: foreach_tensor_expm1_cuda_
autogen: _foreach_expm1.functional, _foreach_expm1.out
- func: _foreach_floor(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8731,6 +8815,7 @@
dispatch:
CPU: foreach_tensor_floor_slow_
CUDA: foreach_tensor_floor_cuda_
autogen: _foreach_floor.functional, _foreach_floor.out
- func: _foreach_log(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8745,6 +8830,7 @@
dispatch:
CPU: foreach_tensor_log_slow_
CUDA: foreach_tensor_log_cuda_
autogen: _foreach_log.functional, _foreach_log.out
- func: _foreach_log10(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8759,6 +8845,7 @@
dispatch:
CPU: foreach_tensor_log10_slow_
CUDA: foreach_tensor_log10_cuda_
autogen: _foreach_log10.functional, _foreach_log10.out
- func: _foreach_log1p(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8773,6 +8860,7 @@
dispatch:
CPU: foreach_tensor_log1p_slow_
CUDA: foreach_tensor_log1p_cuda_
autogen: _foreach_log1p.functional, _foreach_log1p.out
- func: _foreach_log2(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8787,6 +8875,7 @@
dispatch:
CPU: foreach_tensor_log2_slow_
CUDA: foreach_tensor_log2_cuda_
autogen: _foreach_log2.functional, _foreach_log2.out
- func: _foreach_neg(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8801,6 +8890,7 @@
dispatch:
CPU: foreach_tensor_neg_slow_
CUDA: foreach_tensor_neg_cuda_
autogen: _foreach_neg.functional, _foreach_neg.out
- func: _foreach_tan(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8815,6 +8905,7 @@
dispatch:
CPU: foreach_tensor_tan_slow_
CUDA: foreach_tensor_tan_cuda_
autogen: _foreach_tan.functional, _foreach_tan.out
- func: _foreach_tanh(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8829,6 +8920,7 @@
dispatch:
CPU: foreach_tensor_tanh_slow_
CUDA: foreach_tensor_tanh_cuda_
autogen: _foreach_tanh.functional, _foreach_tanh.out
- func: _foreach_sin(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8843,6 +8935,7 @@
dispatch:
CPU: foreach_tensor_sin_slow_
CUDA: foreach_tensor_sin_cuda_
autogen: _foreach_sin.functional, _foreach_sin.out
- func: _foreach_sinh(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8857,6 +8950,7 @@
dispatch:
CPU: foreach_tensor_sinh_slow_
CUDA: foreach_tensor_sinh_cuda_
autogen: _foreach_sinh.functional, _foreach_sinh.out
- func: _foreach_round(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8871,6 +8965,7 @@
dispatch:
CPU: foreach_tensor_round_slow_
CUDA: foreach_tensor_round_cuda_
autogen: _foreach_round.functional, _foreach_round.out
- func: _foreach_lgamma(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8885,6 +8980,7 @@
dispatch:
CPU: foreach_tensor_lgamma_slow_
CUDA: foreach_tensor_lgamma_cuda_
autogen: _foreach_lgamma.functional, _foreach_lgamma.out
- func: _foreach_frac(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8899,6 +8995,7 @@
dispatch:
CPU: foreach_tensor_frac_slow_
CUDA: foreach_tensor_frac_cuda_
autogen: _foreach_frac.functional, _foreach_frac.out
- func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8913,6 +9010,7 @@
dispatch:
CPU: foreach_tensor_reciprocal_slow_
CUDA: foreach_tensor_reciprocal_cuda_
autogen: _foreach_reciprocal.functional, _foreach_reciprocal.out
- func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8927,6 +9025,7 @@
dispatch:
CPU: foreach_tensor_sigmoid_slow_
CUDA: foreach_tensor_sigmoid_cuda_
autogen: _foreach_sigmoid.functional, _foreach_sigmoid.out
- func: _foreach_trunc(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8941,6 +9040,7 @@
dispatch:
CPU: foreach_tensor_trunc_slow_
CUDA: foreach_tensor_trunc_cuda_
autogen: _foreach_trunc.functional, _foreach_trunc.out
- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8948,6 +9048,7 @@
dispatch:
CPU: foreach_tensor_addcdiv_scalar_slow_
CUDA: foreach_tensor_addcdiv_scalar_cuda_
autogen: _foreach_addcdiv.Scalar_functional, _foreach_addcdiv.Scalar_out
- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8955,6 +9056,7 @@
dispatch:
CPU: foreach_tensor_addcmul_scalar_slow_
CUDA: foreach_tensor_addcmul_scalar_cuda_
autogen: _foreach_addcmul.Scalar_functional, _foreach_addcmul.Scalar_out
- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8962,6 +9064,7 @@
dispatch:
CPU: foreach_tensor_addcdiv_scalarlist_slow_
CUDA: foreach_tensor_addcdiv_scalarlist_cuda_
autogen: _foreach_addcdiv.ScalarList_functional, _foreach_addcdiv.ScalarList_out
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -8969,6 +9072,7 @@
dispatch:
CPU: foreach_tensor_addcmul_scalarlist_slow_
CUDA: foreach_tensor_addcmul_scalarlist_cuda_
autogen: _foreach_addcmul.ScalarList_functional, _foreach_addcmul.ScalarList_out
- func: _foreach_addcdiv.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -11507,7 +11611,7 @@
python_module: linalg
variants: function
- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
- func: linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_eigvalsh_out
@ -11528,6 +11632,7 @@
dispatch:
CPU: _linalg_inv_out_helper_cpu
CUDA: _linalg_inv_out_helper_cuda
autogen: _linalg_inv_out_helper.functional, _linalg_inv_out_helper.out
- func: linalg_inv_ex(Tensor self, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
python_module: linalg

View File

@ -19,7 +19,7 @@ full_codegen:
- avg_pool2d_backward
- baddbmm
- bernoulli
- bernoulli_.float
- bernoulli.p
- binary_cross_entropy
- binary_cross_entropy_backward
- bitwise_and.Tensor
@ -72,8 +72,8 @@ full_codegen:
- log_sigmoid_forward
- lt.Scalar
- lt.Tensor
- masked_fill_.Scalar
- masked_fill_.Tensor
- masked_fill.Scalar
- masked_fill.Tensor
- max
- max.dim
- max_pool2d_with_indices
@ -101,12 +101,11 @@ full_codegen:
- norm.ScalarOpt_dim
- pow.Tensor_Scalar
- pow.Tensor_Tensor
- random_
- random_.from
- random_.to
- random.functional
- random.from_functional
- random.to_functional
- reciprocal
- relu
- relu_
- remainder.Tensor
- repeat
- rsqrt
@ -141,7 +140,7 @@ full_codegen:
- upsample_bilinear2d_backward
- upsample_nearest2d
- upsample_nearest2d_backward
- zero_
- zero.functional
- narrow_copy.SymInt
supported:
- as_strided

View File

@ -13,8 +13,25 @@ $ops_headers
namespace at {
namespace native {
// This file contains a number of kernels for aten functions that are fully code-generated.
// TODO: rename this file to something more generic.
at::Tensor clone_arg(const at::Tensor& t) {
return t.clone();
}
std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
std::vector<at::Tensor> out(t_list.size());
for (const auto& i : c10::irange(t_list.size())) {
out[i] = t_list[i].clone();
}
return out;
}
${CompositeViewCopyKernel_Definitions}
${GeneratedCompositeFunctional_Definitions}
} // namespace native
} // namespace at

View File

@ -185,7 +185,9 @@ TEST_F(atest, ne_operators) {
TEST_F(atest, add_operators) {
auto exp_tensor = tensor({-10, 1, 0, -1, 10});
run_binary_ops_test(add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
run_binary_ops_test<
at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Scalar&)>(
add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
}
TEST_F(atest, max_operators) {

View File

@ -174,6 +174,17 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
[1., 1.],
[1., 1.]]))""")
# Some ops that are mutable are neither inplace nor out= ops.
# They also need special handling.
def test_mutable_op_not_inplace_or_other(self):
def f(x):
return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""")
def test_tensor_list_composite(self):
def f(x):
# Test an op with TensorList input

View File

@ -3944,8 +3944,6 @@ class TestFunctionalTracing(JitTestCase):
"upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
"upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
"normalize" : MUTABLE,
}
# List of nn.functionals with Tensor inputs but not with type annotation

View File

@ -177,6 +177,9 @@ SKIP_PYTHON_BINDINGS_SIGNATURES = [
@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
# So far, all NativeFunctions that are entirely code-generated do not get python bindings.
if "generated" in f.tags:
return False
name = cpp.name(f.func)
for skip_regex in SKIP_PYTHON_BINDINGS:
if skip_regex.match(name):

View File

@ -432,7 +432,8 @@ def emit_trace_body(f: NativeFunction) -> List[str]:
assign_return_values = (
f"{tie_return_values(f)} = "
if f.func.kind() == SchemaKind.functional and f.func.returns
if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable]
and f.func.returns
else ""
)

View File

@ -349,7 +349,7 @@ GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
# Some operators invalidate the grad_accumulator. Let's reset it.
RESET_GRAD_ACCUMULATOR = {"set", "resize"}
RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}
# NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
#
@ -734,7 +734,7 @@ def gen_variable_type_func(
if (
fn.info is None
and not get_base_name(f) in RESET_GRAD_ACCUMULATOR
and not str(f.func.name.name) in RESET_GRAD_ACCUMULATOR
and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE
and len(gen_differentiable_outputs(fn)) > 0
and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
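The two hunks above work together: the `RESET_GRAD_ACCUMULATOR` keys now carry the trailing underscore, and membership is checked against `str(f.func.name.name)` rather than the stripped base name. Otherwise the newly generated `set.*_functional` / `set.*_out` ops, whose base name is also `set`, would have been swept in. A toy rendition with hypothetical string helpers (the real code works on torchgen's `BaseOperatorName` objects):

def base_name(op):
    # "set_.source_Tensor" -> "set"; "set.source_Tensor_out" -> "set"
    return op.split(".")[0].rstrip("_")

def unambiguous_name(op):
    # "set_.source_Tensor" -> "set_"; "set.source_Tensor_out" -> "set"
    return op.split(".")[0]

RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}

assert unambiguous_name("set_.source_Tensor") in RESET_GRAD_ACCUMULATOR
assert unambiguous_name("set.source_Tensor_out") not in RESET_GRAD_ACCUMULATOR
assert base_name("set.source_Tensor_out") == "set"  # why the old check over-matched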
@ -857,7 +857,14 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
and (len(differentiable_outputs) > 0)
)
if info is not None and info.has_derivatives and not requires_derivative:
if (
info is not None
and info.has_derivatives
and not requires_derivative
# out= ops are allowed to have zero returns which cause requires_derivative to be False
# we shouldn't error out though (out= ops for autograd just redispatch)
and len(f.func.returns) > 0
):
raise RuntimeError(
f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
)
@ -1528,7 +1535,7 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
# Save only after the forward AD has been set up
body.append(emit_save_outputs())
if base_name in RESET_GRAD_ACCUMULATOR:
if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
# `inplace` implies that there is exactly one output named `self`,
# so we can keep the generated code easy. If you need to
# `reset_grad_accumulator` in an operator that's not `inplace`, you can

View File

@ -32,7 +32,10 @@ from torchgen.api.types import (
stringT,
)
from torchgen.api import cpp
from torchgen.gen import parse_native_yaml, get_grouped_by_view_native_functions
from torchgen.gen import (
parse_native_yaml,
get_grouped_by_view_native_functions,
)
from torchgen.context import with_native_function
from torchgen.model import (
FunctionSchema,

View File

@ -81,10 +81,11 @@ std::vector<int64_t> expand_param_if_needed(
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
TORCH_API std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
double size_d = 0;
// shape inference code copied from RangeFactories.cpp arange_out function
// Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct c++ scalar_t type depending on out tensor
AT_DISPATCH_ALL_TYPES_AND(c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() {
// Note: acc_type further defines an accumulation type depending on the scalar_t and whether it's on cuda vs cpu.
using accscalar_t = at::acc_type<scalar_t, false>;
@ -129,7 +130,6 @@ std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::
// If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
// Otherwise, the dtype is inferred to be torch.int64.
// Since out tensor is specified, its dtype should always be used?
return {Shape(out.scalar_type(), {size})};
}
@ -145,7 +145,7 @@ std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optiona
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator) {
std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator) {
return compute_shape_bernoulli(self, generator);
}
@ -224,11 +224,11 @@ std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at:
}
}
std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -380,26 +380,22 @@ std::vector<Shape> compute_shape_native_dropout_backward(const at::Tensor & grad
return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
}
std::vector<Shape> compute_shape_random_(at::Tensor & self, c10::optional<at::Generator> generator) {
std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_random_(at::Tensor & self, int64_t to, c10::optional<at::Generator> generator) {
return compute_shape_random_(self, generator);
std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator) {
return compute_shape_random_functional(self, generator);
}
std::vector<Shape> compute_shape_random_(at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) {
return compute_shape_random_(self, generator);
std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) {
return compute_shape_random_functional(self, generator);
}
std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_relu_(at::Tensor& self) {
return compute_shape_relu(self);
}
std::vector<Shape> compute_shape_bitwise_and(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -417,7 +413,7 @@ std::vector<Shape> compute_shape_sum(
return {Shape(self.scalar_type(), {})};
}
std::vector<Shape> compute_shape_zero_(at::Tensor& self) {
std::vector<Shape> compute_shape_zero_functional(const at::Tensor& self) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}

View File

@ -16,7 +16,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool2d_bac
TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
@ -37,8 +37,8 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_l1_loss_backward(const a
TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_forward(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_logdet(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
@ -50,11 +50,10 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backwa
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu_(at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sort(const at::Tensor & self, int64_t dim, bool descending);
@ -65,7 +64,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & s
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_(at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
} // namespace lazy
} // namespace torch

View File

@ -310,17 +310,41 @@ def match_differentiability_info(
for info in differentiability_infos
if info.func.func.kind() == SchemaKind.functional
}
non_functional_info_by_signature = {
info.func.func.signature(strip_default=True): info
for info in differentiability_infos
if info.func.func.kind() != SchemaKind.functional
}
def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]:
# (1) Check for an exact match
if f.func in info_by_schema:
return info_by_schema[f.func], True
# if there is no exact match look for the out-of-place signature.
# (2) If no exact match, check if the out-of-place variant
# of this operator has a match.
# i.e. mul() for mul_() or mul_out()
return (
functional_info_by_signature.get(f.func.signature(strip_default=True)),
False,
)
f_sig = f.func.signature(strip_default=True)
if f_sig in functional_info_by_signature:
return functional_info_by_signature[f_sig], False
# (3) Some operators have a derivative explicitly defined for the mutable
# variant, but get a code-generated out-of-place variant which does *not*
# come with a derivative formula.
# For the generated out-of-place variant, use the mutable variant's formula
# if it exists.
if "generated" in f.tags and f_sig in non_functional_info_by_signature:
info = non_functional_info_by_signature[f_sig]
# See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
assert not any(
"self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs
), f"""\
Attempted to convert a derivative formula for a mutable operator
to be used automatically by its functional variant ("{str(f.func)}").
This is not currently supported (we'd need to fix up the formula in the codegen)."""
return info, False
return None, False
result: List[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:

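The new case (3) reuses a mutable op's derivative formula for its generated out-of-place variant, guarded by the assertion above: the formula must not save `self`, because `self` may not refer to the same values in the mutable op and in its functional counterpart, and the codegen does not yet rewrite the formula. A stripped-down sketch of that guard (hypothetical helper; the real check inspects `info.all_saved_inputs`):

def can_reuse_mutable_formula(saved_input_names):
    # Reject formulas that save `self`; reusing them for the functional
    # variant would require fixing up the formula in the codegen.
    return not any("self" in name for name in saved_input_names)

assert can_reuse_mutable_formula(["other", "alpha"])
assert not can_reuse_mutable_formula(["self", "mask"])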
View File

@ -64,7 +64,9 @@ from typing import Optional, Sequence, Union, List, Set
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
name = str(func.name.name)
if func.is_out_fn():
if func.is_functional_fn():
name += "_functional"
elif func.is_out_fn():
if faithful_name_for_out_overloads:
name += "_outf"
else:

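A tiny standalone rendition of the naming rule extended above (hypothetical helper, ignoring the faithful-out details): generated functional overloads get a distinct C++ name so they cannot collide with the mutable or out= entry points.

def cpp_name(base, *, is_functional=False, is_out=False):
    # e.g. the generated group around zero_: "zero_functional", "zero_out", "zero_"
    if is_functional:
        return base + "_functional"
    if is_out:
        return base + "_out"
    return base

assert cpp_name("zero", is_functional=True) == "zero_functional"
assert cpp_name("zero", is_out=True) == "zero_out"
assert cpp_name("zero_") == "zero_"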
View File

@ -60,6 +60,8 @@ from torchgen.api.types import (
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
longVec_ctype = VectorCType(BaseCType(longT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
@ -287,20 +289,38 @@ Check this module for more information.
return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
options = direct_solve(options_ctype)
return f"optTypeMetaToScalarType({options}.dtype_opt())"
try:
options = direct_solve(options_ctype)
return f"optTypeMetaToScalarType({options}.dtype_opt())"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.scalar_type()"
elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
options = direct_solve(options_ctype)
return f"{options}.layout_opt()"
try:
options = direct_solve(options_ctype)
return f"{options}.layout_opt()"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.layout()"
elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
options = direct_solve(options_ctype)
return f"{options}.device_opt()"
try:
options = direct_solve(options_ctype)
return f"{options}.device_opt()"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.device()"
elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
options = direct_solve(options_ctype)
return f"{options}.pinned_memory_opt()"
try:
options = direct_solve(options_ctype)
return f"{options}.pinned_memory_opt()"
except UnsatError:
# If we're calling a factory op from its out= variant,
# We don't actually care about the value of pin_memory.
out_tensor = direct_solve(out_tensor_ctype)
return "c10::nullopt"
# We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
elif goal.type == BaseCType(intArrayRefT):
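The try/except fallbacks above let a generated out= wrapper satisfy `dtype`, `layout`, and `device` bindings from the `out` tensor when no `TensorOptions` argument is available (and simply drop `pin_memory`). Behaviorally this is the convention existing out= factory overloads already follow, for example (public API only, not code from this diff):

import torch

out = torch.empty(5, dtype=torch.float64)
torch.arange(5, out=out)           # no dtype passed explicitly
assert out.dtype == torch.float64  # inferred from the out tensor
assert out.tolist() == [0.0, 1.0, 2.0, 3.0, 4.0]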
@ -348,6 +368,15 @@ Check this module for more information.
# With arguments like std::vector<IntArrayRef>.
# If that changes, we'll have to add the translation here.
# We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
# We could probably generalize this to non-tensor types too.
if goal.type == MutRefCType(BaseCType(tensorT)):
const_ref_tensor_ctype = NamedCType(
goal.name, ConstRefCType(BaseCType(tensorT))
)
argname = direct_solve(const_ref_tensor_ctype)
return f"const_cast<Tensor&>({argname})"
unsat(goal)
return [Expr(solve(g, direct=False), g) for g in goal_ctypes]

View File

@ -26,6 +26,7 @@ F = TypeVar(
F2 = TypeVar(
"F2",
NativeFunction,
NativeFunctionsGroup,
Optional[NativeFunction],
bool,
)

View File

@ -616,6 +616,10 @@ check_inplace(out, sizes, options);
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
elif k is SchemaKind.mutable:
raise AssertionError(
"SchemaKind.mutable structured operators are currently not supported"
)
else:
assert_never(k)
@ -631,6 +635,10 @@ resize_out(out, sizes, strides, options);
out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
elif k is SchemaKind.mutable:
raise AssertionError(
"SchemaKind.mutable structured operators are currently not supported"
)
else:
assert_never(k)

View File

@ -33,6 +33,10 @@ from torchgen.model import (
ViewSchemaKind,
BaseOperatorName,
)
from torchgen.native_function_generation import (
pre_group_native_functions,
add_generated_native_functions,
)
from torchgen.api.types import (
Binding,
CppSignatureGroup,
@ -72,6 +76,7 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_composite_view_copy_kernel,
gen_composite_functional_kernel,
)
T = TypeVar("T")
@ -180,6 +185,7 @@ def parse_native_yaml_struct(
index={},
)
)
add_generated_native_functions(rs, bs)
for k, v in bs.items():
# All structured in-tree operators are implemented in terms of their out operator.
indices[k] = BackendIndex(
@ -1278,59 +1284,48 @@ def get_custom_build_selector(
return selector
def pre_group_native_functions(
native_functions: Sequence[NativeFunction],
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = defaultdict(dict)
for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()]
assert f.func.kind() not in d
d[f.func.kind()] = f
return pre_grouped_native_functions
def get_grouped_by_view_native_functions(
native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
def maybe_create_view_group(
d: Dict[ViewSchemaKind, NativeFunction]
d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
if ViewSchemaKind.aliasing not in d:
# Case 1: this op / op group is not aliasing, so we don't create a view group.
# return the original (ungrouped) native functions instead.
for func in d.values():
funcs.append(func)
else:
# Case 2: this op group contains an aliasing op, so we create a ViewGroup for it.
# The handling for out= ops here is unfortunate.
# out= ops don't really make sense for view operators.
# However, we have at least one existing {view}_copy.out operator in native_functions.yaml.
# It shouldn't be part of a view group, so we explicitly don't group it.
# There currently aren't any out= view ops (and there probably shouldn't be).
# We also expect that when we hit this case, the `non_aliasing` op in the dict
# *must* be a view_copy op (this is asserted in the NativeFunctionsViewGroup constructor)
if ViewSchemaKind.out in d:
funcs.append(d[ViewSchemaKind.out])
if ViewSchemaKind.aliasing in d:
view = d.pop(ViewSchemaKind.aliasing)
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
view_copy = d.pop(SchemaKind.functional, None)
funcs.append(
NativeFunctionsViewGroup(
view=d[ViewSchemaKind.aliasing],
view_copy=d.get(ViewSchemaKind.non_aliasing, None),
view_inplace=d.get(ViewSchemaKind.inplace, None),
view=view,
view_copy=view_copy,
view_inplace=view_inplace,
)
)
# Take the remaining functions that weren't part of the view group
# and emit them separately
for func in d.values():
funcs.append(func)
return funcs
grouped_by_views: Dict[
FunctionSchema, Dict[ViewSchemaKind, NativeFunction]
FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
] = defaultdict(dict)
for f in native_functions:
schema = f.func.view_signature()
assert f.view_schema_kind not in grouped_by_views[schema]
grouped_by_views[schema][f.view_schema_kind] = f
view_kind: ViewSchemaKind = f.view_schema_kind
# We need to group up ops relevant to the same "view", consisting of:
# view op (ViewSchemaKind.aliasing)
# view_inplace op (ViewSchemaKind.aliasing_inplace)
# view_copy op (SchemaKind.functional)
if view_kind == ViewSchemaKind.non_aliasing:
kind = f.func.kind()
assert kind not in grouped_by_views[schema]
grouped_by_views[schema][kind] = f
else:
assert view_kind not in grouped_by_views[schema]
grouped_by_views[schema][view_kind] = f
return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
@ -1343,6 +1338,9 @@ def get_grouped_native_functions(
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
r = NativeFunctionsGroup.from_dict(d)
if r is None:
# Invariant: any NativeFunctions that are code-generated
# should have been grouped into NativeFunctionsGroup objects
assert not any("generated" in f.tags for f in d.values())
return list(d.values())
else:
return [r]
@ -1835,9 +1833,7 @@ def gen_source_files(
native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
native_functions_with_view_groups: Sequence[
Union[NativeFunction, NativeFunctionsViewGroup]
],
view_groups: Sequence[NativeFunctionsViewGroup],
selector: SelectiveBuilder,
static_dispatch_idx: List[BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
@ -2118,30 +2114,11 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
},
)
# We need to easily map from [inplace_op_name] -> [functional_op] for the functionalization pass,
# so here I generate a mapping from every operator name to its corresponding functional NativeFunction (if it exists).
pre_grouped_d: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = pre_group_native_functions(native_functions)
to_functional_op: Dict[OperatorName, Optional[NativeFunction]] = {
k: v
for d in [
{
f.func.name: pre_grouped_d[func][SchemaKind.functional]
if SchemaKind.functional in pre_grouped_d[func].keys()
else None
for f in pre_grouped_d[func].values()
}
for func in pre_grouped_d.keys()
]
for k, v in d.items()
}
def functionalization_env_callable(
g: Union[NativeFunction, NativeFunctionsViewGroup]
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> Dict[str, List[str]]:
def gen_op_headers(
g: Union[NativeFunction, NativeFunctionsViewGroup]
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> List[str]:
if isinstance(g, NativeFunctionsViewGroup):
# view ops always get a functionalization kernel
@ -2155,11 +2132,28 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
]
return headers
elif isinstance(g, NativeFunctionsGroup):
headers = [
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
f"#include <ATen/ops/{g.out.root_name}_native.h>",
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
]
if g.inplace is not None:
headers += [
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
]
if g.mutable is not None:
headers += [
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
]
return headers
else:
f = g
return [
f"#include <ATen/ops/{f.root_name}_native.h>",
f"#include <ATen/ops/{f.root_name}_ops.h>",
f"#include <ATen/ops/{g.root_name}_native.h>",
f"#include <ATen/ops/{g.root_name}_ops.h>",
]
return {
@ -2167,11 +2161,6 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
"func_definitions": gen_functionalization_definition(
selector,
g,
# We need to manually map inplace ops to their out-of-place variants
# (we can't do this with NativeFunctionsGroup today because not all inplace ops have out= variants)
None
if isinstance(g, NativeFunctionsViewGroup)
else to_functional_op.get(g.func.name, None),
),
"func_registrations": gen_functionalization_registration(
selector,
@ -2180,9 +2169,30 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
),
}
all_groups: List[
Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
] = list(structured_native_functions) + list(
view_groups # type: ignore[assignment, arg-type, operator]
)
# Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
# The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
structured_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
}
view_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
}
for f in native_functions:
if f.func.name not in structured_map and f.func.name not in view_map:
all_groups.append(f)
cpu_fm.write_sharded(
"RegisterFunctionalization.cpp",
native_functions_with_view_groups,
all_groups,
key_fn=key_func,
env_callable=functionalization_env_callable,
num_shards=4,
@ -2203,11 +2213,7 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
lambda g: gen_functionalization_view_inverse_declaration(
selector, g
),
[
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
],
view_groups,
)
)
},
@ -2239,17 +2245,23 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
[g.view] if g.view_copy is None else [g.view, g.view_copy]
)
)
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
for g in view_groups
]
+ [
"\n".join(
f"#include <ATen/ops/{f.root_name}_ops.h>"
for f in [g.inplace, g.mutable]
if f is not None and "generated" not in f.tags
)
for g in structured_native_functions
],
"CompositeViewCopyKernel_Definitions": list(
mapMaybe(gen_composite_view_copy_kernel, view_groups)
),
"GeneratedCompositeFunctional_Definitions": list(
mapMaybe(
gen_composite_view_copy_kernel,
[
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
],
gen_composite_functional_kernel,
structured_native_functions,
)
),
},
@ -2377,12 +2389,18 @@ def main() -> None:
)
grouped_native_functions = get_grouped_native_functions(native_functions)
structured_native_functions = [
g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
]
native_functions_with_view_groups = get_grouped_by_view_native_functions(
native_functions
)
view_groups = [
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
]
template_dir = os.path.join(options.source_path, "templates")
@ -2451,7 +2469,7 @@ def main() -> None:
native_functions=native_functions,
grouped_native_functions=grouped_native_functions,
structured_native_functions=structured_native_functions,
native_functions_with_view_groups=native_functions_with_view_groups,
view_groups=view_groups,
selector=selector,
static_dispatch_idx=static_dispatch_idx,
backend_indices=backend_indices,

View File

@ -4,6 +4,7 @@ from torchgen.api.types import (
Binding,
FunctionalizationLambda,
ViewInverseSignature,
Expr,
NativeSignature,
CType,
BaseCType,
@ -19,8 +20,9 @@ from torchgen.context import (
)
from torchgen.model import (
Argument,
Return,
NativeFunction,
SchemaKind,
NativeFunctionsGroup,
BackendIndex,
FunctionSchema,
SelfArgument,
@ -30,10 +32,30 @@ from torchgen.model import (
NativeFunctionsViewGroup,
ListType,
)
from torchgen.native_function_generation import (
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.selective_build.selector import SelectiveBuilder
from typing import List, Optional, Union, Tuple, Callable
# Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+ [
# It will be BC-breaking, but we should fix their schemas.
# should be inplace?
"record_stream",
]
)
# This file contains codegen that relates to the functionalization pass.
# It includes:
# - gen_functionalization_definition
@ -91,24 +113,96 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
"""
def modifies_arguments(f: NativeFunction) -> bool:
return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
elif len(rets) == 1:
return f"return {names[0]};"
else:
return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
# This function constructs the return statement for the kernels that contain mutations
# It mostly just needs to special case multi-output returns to wrap the result in a tuple
def return_str(f: NativeFunction) -> str:
# Need to check both # outs and # returns. Why?
# out= ops with a mutable Tensor(a!)[] argument are expected to have a void return type.
if len(f.func.arguments.out) != 0 and len(f.func.returns) != 0:
if len(f.func.arguments.out) > 1:
return_names = ", ".join(a.name for a in f.func.arguments.out)
return f"return {DispatcherSignature.from_schema(f.func).returns_type().cpp_type()}({return_names});"
# Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]:
aliased_rets = func.aliased_return_names()
non_aliased_names = []
is_out_var_a_tuple = len(func.returns) > 1
for (i, r) in enumerate(aliased_rets):
if r is None:
non_aliased_names.append(
f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
)
return non_aliased_names
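# Worked example (sketch): for a mutable op like _fused_moving_avg_obs_fq_helper, whose two
# returns are both fresh (non-aliased) Tensors, this yields
# ["std::get<0>(<out_var>)", "std::get<1>(<out_var>)"]. For an inplace op like add_, whose only
# return aliases self, every entry of aliased_rets is non-None, so this yields [].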
@with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.functional.tags:
return None
# And we always write the kernel for a generated op in terms of a non-generated op.
if g.inplace is not None and "generated" not in g.inplace.tags:
target_f = g.inplace
elif g.mutable is not None and "generated" not in g.mutable.tags:
target_f = g.mutable
else:
# We should be guaranteed to have a valid inplace/mutable variant to call into.
# See Note: [Mutable Ops Not Using Functionalization]
raise AssertionError(str(g.functional.func))
sig = DispatcherSignature(g.functional.func)
target_sig = DispatcherSignature(target_f.func)
context: List[Union[Binding, Expr]] = []
clone_mutable_inputs = []
cloned_return_names = []
# We can't just directly pass all of the arguments from the functional op into the mutating op.
# We need to check for which inputs to the mutating operator are mutable,
# and clone those inputs first.
for a_curr, a_tgt in zip(
dispatcher.jit_arguments(g.functional.func),
dispatcher.jit_arguments(target_f.func),
):
if a_tgt.annotation is not None and a_tgt.annotation.is_write:
clone_mutable_inputs.append(
f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
)
context.append(
Expr(
expr=f"{a_curr.name}_clone",
type=dispatcher.argument_type(a_curr, binds=a_curr.name),
)
)
# Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
cloned_return_names.append(f"{a_curr.name}_clone")
else:
return f"return {f.func.arguments.out[0].name}"
if f.func.arguments.self_arg is not None and len(f.func.returns) != 0:
return f"return {f.func.arguments.self_arg.argument.name}"
return ""
context.append(dispatcher.argument(a_curr))
exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])
out_name = "output"
maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
ret_str = return_str(
g.functional.func.returns, inner_return_names + cloned_return_names
)
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
return f"""
{sig.defn()} {{
{clone_mutable_inputs_str}
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
{ret_str}
}}
"""
def modifies_arguments(f: NativeFunction) -> bool:
return any(
a.annotation is not None and a.annotation.is_write
for a in f.func.arguments.flat_all
)
def wrapper_name(func: FunctionSchema) -> str:
@ -352,11 +446,103 @@ def emit_view_functionalization_body(
"""
def maybe_create_output(f: NativeFunction, var_name: str) -> str:
if len(f.func.returns) == 0:
return ""
return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
return f"{return_type} {var_name} = "
# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
# this returns two lists of names, consisting of:
# - the names of returns corresponding to the original (mutable) inputs of the outer function
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names(
f: NativeFunction, inner_return_var: str
) -> Tuple[List[str], List[str]]:
aliased_returns = []
non_aliased_returns = []
for (i, name) in enumerate(f.func.aliased_return_names()):
if name is not None:
aliased_returns.append(name)
else:
non_aliased_returns.append(
inner_return_var
if len(f.func.returns) == 1
else f"std::get<{i}>({inner_return_var})"
)
return aliased_returns, non_aliased_returns
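# Worked example (sketch): for add_.Tensor, whose single return aliases self, this yields
# (["self"], []); for its functional counterpart, whose returns are all fresh, it yields
# ([], ["<inner_return_var>"]) (with std::get<i>(...) entries instead when there is more than one return).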
# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that:
# - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
# - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
def return_from_mutable_noop_redispatch(
f: NativeFunction, inner_return_var: str
) -> str:
aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
# Just get all of the return names, and immediately return them
return return_str(f.func.returns, aliased + non_aliased)
def wrap_propagate_mutations_and_return(
f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
) -> str:
mutable_arg_names = f.func.arguments.mutable_arg_names()
(
aliased_outer_rets,
non_aliased_outer_rets,
) = get_mutable_redispatch_return_names(f, inner_return_var)
_, non_aliased_inner_rets = get_mutable_redispatch_return_names(
functional_op, inner_return_var
)
# The outer function may have a mix of aliased and non-aliased outputs,
# but the inner functional op that we're transforming to should only have non-aliased outputs
assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
non_aliased_inner_rets
)
# First, take all of the newly created outputs from the inner call and wrap them into functional tensors
updates = []
non_aliased_wrapped_ret_names = []
for (i, inner_ret) in enumerate(
non_aliased_inner_rets[: len(non_aliased_outer_rets)]
):
ret_name = f"output_{i}"
updates.append(
f"""\
auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
)
non_aliased_wrapped_ret_names.append(ret_name)
# Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
# and propagate the mutations
for (outer_arg, inner_ret) in zip(
mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
):
updates.append(
f"""\
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});"""
)
# Finally, we return:
# - Any mutable arguments that also returns
# - Any immutable returns that were created wrapping the output from the inner call
returns_str = return_str(
f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
)
updates_str = "\n".join(updates)
return f"""\
{updates_str}
{returns_str}"""
# Generates the Functionalization kernel for:
# - mutation ops (inplace and out= ops)
@with_native_function_and
def emit_inplace_functionalization_body(
f: NativeFunction, functional_op: Optional[NativeFunction]
f: NativeFunction, g: NativeFunctionsGroup
) -> str:
# mutation case
assert modifies_arguments(f)
@ -400,42 +586,15 @@ def emit_inplace_functionalization_body(
for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
]
if functional_op is None:
# We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
return_type = (
dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
)
warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch."""
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
TORCH_WARN("{warn_str}");
}}
{unwrap_tensor_args_str}
at::AutoDispatchSkipFunctionalize guard;
// Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_str(f)};
}}
"""
else:
# call the out-of-place variant of the op
return_type = (
dispatcher.returns_type(functional_op.func.returns)
.remove_const_ref()
.cpp_type()
)
functional_sig = DispatcherSignature.from_schema(functional_op.func)
functional_exprs = [
e.expr
for e in translate(
unwrapped_args_ctx, functional_sig.arguments(), method=False
)
]
# call the out-of-place variant of the op
return_type = (
dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
)
functional_sig = DispatcherSignature.from_schema(g.functional.func)
functional_exprs = [
e.expr
for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
]
if f.func.is_out_fn():
mutable_input_post_processing = "\n".join(
@ -471,17 +630,16 @@ If this causes problems in your program, consider upstreaming the out-of-place o
}} else {{
// case 2: arguments are not functional tensors, so we no-op and redispatch.
at::AutoDispatchSkipFunctionalize guard;
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_str(f)};
{maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_from_mutable_noop_redispatch(f, 'tmp_output')};
}}
}} else {{
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
}}
{mutable_input_post_processing}
{return_str(f)};
{wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
}}
}}"""
@ -508,7 +666,7 @@ def gen_functionalization_view_inverse_declaration(
def gen_functionalization_registration(
selector: SelectiveBuilder,
g: Union[NativeFunction, NativeFunctionsViewGroup],
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
composite_implicit_autograd_index: BackendIndex,
) -> List[str]:
@with_native_function
@ -542,8 +700,16 @@ def gen_functionalization_registration(
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str
elif isinstance(g, NativeFunctionsGroup):
fns = list(g.functions())
else:
f = g
if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
return []
fns = [g]
registrations = []
for f in fns:
if str(f.func.name) == "lift":
# See Note [Functionalization <> torch.Tensor constructor]
return []
@ -552,14 +718,17 @@ def gen_functionalization_registration(
# We *also* need to directly register CompositeImplicitAutograd kernels
# so that they decompose properly before functionalization.
if modifies_arguments(f) or f.has_composite_implicit_autograd_kernel:
return [emit_registration_helper(f)]
return []
registrations.append(emit_registration_helper(f))
return registrations
def gen_functionalization_definition(
selector: SelectiveBuilder,
g: Union[NativeFunction, NativeFunctionsViewGroup],
functional_op: Optional[NativeFunction],
# Note: Ideally this code should never have to look at NativeFunction
# (and instead only need to operate on grouped NativeFunctions).
# The only reason currently is that we need to emit direct dispatch registrations
# for CompositeImplicitAutograd operators, which are potentially ungrouped.
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
) -> List[str]:
# Don't generate kernels in mobile build
if not selector.include_all_operators:
@ -576,9 +745,24 @@ def gen_functionalization_definition(
if g.view_inplace is not None:
view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
return view_defs
elif isinstance(g, NativeFunction):
# Invariant: all mutable operators that we need to handle in functionalization
# should have been properly grouped up.
# TODO: The below ops all have "problematic" schemas that prevent them from
# getting functionalized. Instead of bending over backwards to get things to work,
# I think we should either:
# (1) fix their schemas (BC-breaking)
# (2) hand-write their functionalization kernels
if str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
return []
else:
# Case 2: emit inplace -> out-of-place kernels for the functionalization pass
f = g
if modifies_arguments(f):
return [emit_inplace_functionalization_body(f, functional_op)]
mutation_defs = []
mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
if g.inplace is not None:
mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
if g.mutable is not None:
mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
return mutation_defs
return []

View File

@ -13,7 +13,6 @@ from typing import (
Callable,
Iterable,
Iterator,
Tuple,
Type,
)
from torchgen.api.types import BaseCppType
@ -27,7 +26,6 @@ from torchgen.gen import (
from torchgen.api.lazy import setValueT
from torchgen.model import (
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
@ -331,39 +329,18 @@ def run_gen_lazy_tensor(
def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]],
xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
*,
codegenInplaceVariant: bool = False,
) -> Iterator[str]:
"""
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
only code-gen additional entries for the inplace variant for the native functions.
Note: If xs is not sorted, there may be an edge case when generating IR classes. Considering relu and relu_, if
# we encounter relu_ before relu, we will then generate an IR class with op = at::aten::relu_ for both relu and
relu_ which will cause problems for relu.
TODO(alanwaketan): Once all ops are grouped properly, we should no longer need this hack.
"""
generated = set()
def gen_key(func: FunctionSchema) -> Tuple[str, str]:
# we want to generate unique entries for overloads of functional variants,
# but not for inplace variants unless explicitly told `codegenInplaceVariant`
return (func.name.name.base, func.name.overload_name)
for x in xs:
f = x.functional if isinstance(x, NativeFunctionsGroup) else x
# For the 'or'd terms:
# 1. codegenInplaceVariant means we can generate the in-place variant corresponding items.
# 2. not f.func.name.name.inplace means the op is not a in-place variant, so we can generate the item.
# 3. f.func.name.name.base not in generated means even for in-place ops we still need to generate the item
# as if they were the functional variants for one time.
if f.func.name in full_codegen and (
codegenInplaceVariant
or not f.func.name.name.inplace
or gen_key(f.func) not in generated
):
generated.add(gen_key(f.func))
for r in func(f):
yield r
fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
for f in fs:
if f.func.name in full_codegen:
for r in func(f):
yield r
selector = SelectiveBuilder.get_nop_selector()
@ -402,7 +379,6 @@ def run_gen_lazy_tensor(
backend_indices[backend_key], tensor_class
),
grouped_native_functions,
codegenInplaceVariant=True,
)
)
@ -495,7 +471,6 @@ def run_gen_lazy_tensor(
get_device_fn,
),
grouped_native_functions,
codegenInplaceVariant=True,
)
),
},

View File

@ -3,6 +3,7 @@ import re
from torchgen.utils import assert_never
from dataclasses import dataclass
import dataclasses
from typing import List, Dict, Optional, Iterator, Tuple, Set, Sequence, Callable, Union
from enum import Enum, auto
import itertools
@ -296,7 +297,9 @@ class DeviceCheckType(Enum):
ExactSame = 1
ViewSchemaKind = Enum("ViewSchemaKind", ("aliasing", "inplace", "out", "non_aliasing"))
ViewSchemaKind = Enum(
"ViewSchemaKind", ("aliasing", "aliasing_inplace", "non_aliasing")
)
# The basic input to the code generation is native_functions.yaml.
# The name "native", BTW, comes from the distinction between native
@ -354,6 +357,14 @@ class NativeFunction:
# defined. This is for conveniently reporting error messages!
loc: "Location"
# A list of operators that are expected to be auto-generated for this NativeFunction.
# Note: This list isn't actually directly used by the codegen to generate anything.
# Instead, the codegen figures out what operators to generate purely based off of
# function schema, and uses the autogen declarations to error check.
# We expect every NativeFunction that gets auto-generated be explicitly called out
# in native_functions.yaml
autogen: List["OperatorName"]
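# For example (hypothetical op): an entry in native_functions.yaml that only has an inplace
# variant might call out the variants it expects the codegen to produce like so:
#   - func: foo_(Tensor(a!) self) -> Tensor(a!)
#     autogen: foo.functional, foo.out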
# If non-empty, this kernel is subject to ufunc codegen.
# Sorted by ufunc_key
ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]
@ -582,6 +593,14 @@ class NativeFunction:
"implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
)
autogen_str = e.pop("autogen", "")
assert isinstance(autogen_str, str)
autogen = (
[]
if autogen_str == ""
else [OperatorName.parse(x) for x in autogen_str.split(", ")]
)
raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
ufunc_inner_loop = {}
if isinstance(raw_ufunc_inner_loop, str):
@ -652,6 +671,7 @@ class NativeFunction:
structured_delegate=structured_delegate,
structured_inherits=structured_inherits,
precomputed=precomputed,
autogen=autogen,
ufunc_inner_loop=ufunc_inner_loop,
manual_kernel_registration=manual_kernel_registration,
manual_cpp_binding=manual_cpp_binding,
@ -754,12 +774,10 @@ class NativeFunction:
@property
def view_schema_kind(self) -> ViewSchemaKind:
# This covers both "ordinary" inplace ops, and inplace_views
if self.func.name.name.inplace:
return ViewSchemaKind.inplace
elif self.func.is_out_fn():
return ViewSchemaKind.out
elif self.is_view_op:
if self.is_view_op and self.func.name.name.inplace:
assert "inplace_view" in self.tags
return ViewSchemaKind.aliasing_inplace
if self.is_view_op:
return ViewSchemaKind.aliasing
else:
return ViewSchemaKind.non_aliasing
@ -769,7 +787,7 @@ class NativeFunction:
return self.func.name.name.base
SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out"))
SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out", "mutable"))
# A structured kernel is guaranteed to have a functional and out variant, and
# optionally an inplace variant.
@ -781,6 +799,7 @@ SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out"))
class NativeFunctionsGroup:
functional: NativeFunction
inplace: Optional[NativeFunction]
mutable: Optional[NativeFunction]
out: NativeFunction
@property
@ -797,10 +816,19 @@ class NativeFunctionsGroup:
f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
)
assert self.functional.func.kind() == SchemaKind.functional
assert not self.functional.is_view_op, (
"View operator shouldn't be grouped into NativeFunctionsGroup objects."
f"This is likely because you tried to add an out= variant for '{f.func.name}', which is an existing view operator."
"out= variants of view operators are not valid. Please reach out to to the core team if you have questions."
)
assert self.out.func.kind() == SchemaKind.out
if self.inplace is not None:
assert self.inplace.func.kind() == SchemaKind.inplace
if self.mutable is not None:
assert self.mutable.func.kind() == SchemaKind.mutable
if self.structured:
# For now, structured composite kernels are not supported (need some
# design work to figure out how to make the composite case work)
@ -813,6 +841,25 @@ class NativeFunctionsGroup:
if self.inplace is not None:
assert self.inplace.structured_delegate == self.out.func.name
generated_fns = [
str(f.func.name) for f in self.functions() if "generated" in f.tags
]
generated_fns_str = ", ".join(str(x) for x in generated_fns)
expected_generated_fns = f.autogen
expected_generated_fns_str = ", ".join(str(x) for x in expected_generated_fns)
if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
raise RuntimeError(
f"The codegen expects to be able to generate '{generated_fns_str}'."
" In order to generate them however, we expect them to be called out explicitly in the yaml."
f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
)
if expected_generated_fns_str != generated_fns_str:
raise RuntimeError(
f"The codegen expects to be able to generate '{generated_fns_str}'."
f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
f" Instead, it found 'autogen: {generated_fns_str}'"
)
def signature(self) -> "FunctionSchema":
return self.out.func.signature()
@ -821,6 +868,8 @@ class NativeFunctionsGroup:
yield self.out
if self.inplace is not None:
yield self.inplace
if self.mutable is not None:
yield self.mutable
@property
def root_name(self) -> str:
@ -836,6 +885,7 @@ class NativeFunctionsGroup:
d = dict(d) # non-destructive updates please
functional = d.pop(SchemaKind.functional, None)
inplace = d.pop(SchemaKind.inplace, None)
mutable = d.pop(SchemaKind.mutable, None)
out = d.pop(SchemaKind.out, None)
assert not d
assert functional is not None
@ -847,6 +897,7 @@ class NativeFunctionsGroup:
return NativeFunctionsGroup(
functional=functional,
inplace=inplace,
mutable=mutable,
out=out,
)
@ -1045,12 +1096,26 @@ class FunctionSchema:
assert str(r) == func, f"{str(r)} != {func}"
return r
def returns_are_aliased(self) -> bool:
# We assert earlier that schemas can't have a mix of aliased and non-aliased returns
return any(
r
for r in self.returns
if r.annotation is not None and r.annotation.is_write
)
def __post_init__(self) -> None:
for arg, ret in zip(self.arguments.out, self.returns):
assert arg.annotation == ret.annotation, (
"Out arguments must have matching return Tensor; furthermore, "
"the ith-argument needs to correspond to the ith return"
)
# We also enforce that if you have any mutable, positional args, then they are not returned.
# This makes it easier to group these functions properly with their functional/out= counterparts.
for a in self.arguments.post_self_positional_mutable:
assert not any(
a.annotation == r.annotation for r in self.returns
), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
# Invariant: we expect out arguments to appear as keyword arguments in the schema.
# This means that all mutable returns should be aliased to a keyword argument
# (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
@ -1063,6 +1128,19 @@ class FunctionSchema:
for ret in self.returns
if ret.annotation is not None and ret.annotation.is_write
]
immutable_returns = [
ret
for ret in self.returns
if ret.annotation is None or not ret.annotation.is_write
]
# Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
# because:
# (1) It's more annoying to handle properly
# (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
# Instead, we expect the (a!) argument to not be returned.
assert (
len(mutable_returns) == 0 or len(immutable_returns) == 0
), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
for ret in mutable_returns:
assert any([ret.annotation == arg.annotation for arg in out_and_self]), (
'All mutable returns must be aliased either to a keyword argument, or to "self". '
@ -1103,6 +1181,22 @@ class FunctionSchema:
# so in all other cases we expect the return type to be none.
assert len(self.returns) == 0
if self.arguments.tensor_options is not None:
assert self.kind() == SchemaKind.functional, (
"Found an operator that is not functional, but has tensor options arguments."
"This is not allowed- tensor options arguments are only allowed for factory functions."
f"schema: {str(self)}"
)
if self.is_functional_fn():
assert self.kind() == SchemaKind.functional, (
"Found an operator that is not functional, but its overload contains the string 'functional'."
"This is a special keyword in the codegen, please use a different overload name."
f"schema: {str(self)}"
)
def is_functional_fn(self) -> bool:
return "functional" in self.name.overload_name
def is_out_fn(self) -> bool:
# Note [is_out_fn]
#
@ -1139,44 +1233,99 @@ class FunctionSchema:
modifies the self argument inplace; an out schema writes
the result into an explicitly provided out argument.
"""
is_inplace = self.name.name.inplace
is_out = bool(self.arguments.out)
assert not (is_inplace and is_out)
is_inplace = self.name.name.inplace
is_mutable = any(
a.annotation is not None and a.annotation.is_write
for a in self.arguments.post_self_positional
)
assert not (is_out and is_inplace)
# out= and inplace schemas can also have post_self_positional mutable args,
# but we give precedence to out= and inplace when deciding the schema kind.
# Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
# to also worry about mutable post_self_positional arguments,
# but it seems like a much bigger lift to classify them as having a new schema kind.
# The number of ops that fit in this strange category is small enough that
# we can probably manually write code for them instead of forcing the codegen to handle them.
if is_inplace:
return SchemaKind.inplace
elif is_out:
return SchemaKind.out
elif is_mutable:
return SchemaKind.mutable
else:
return SchemaKind.functional
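# A minimal runnable sketch (assuming torchgen is importable and that these schema strings
# parse as written) of how kind() classifies the four flavors:
#
#     from torchgen.model import FunctionSchema, SchemaKind
#
#     assert FunctionSchema.parse(
#         "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
#     ).kind() == SchemaKind.functional
#     assert FunctionSchema.parse(
#         "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
#     ).kind() == SchemaKind.inplace
#     assert FunctionSchema.parse(
#         "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
#     ).kind() == SchemaKind.out
#     assert FunctionSchema.parse(
#         "_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()"
#     ).kind() == SchemaKind.mutable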
# For every return:
# - If the return aliases an input, we return the input name
# - Otherwise, we return None.
# If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
def aliased_return_names(self) -> List[Optional[str]]:
outs: List[Optional[str]] = []
for r in self.returns:
aliased_args = [
a
for a in self.arguments.flat_all
if a.annotation is not None and a.annotation == r.annotation
]
if len(aliased_args) == 0:
outs.append(None)
elif len(aliased_args) == 1:
outs.append(aliased_args[0].name)
else:
aliased_names = ", ".join(a.name for a in aliased_args)
raise AssertionError(
f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})"
)
return outs
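# Worked example (sketch): for add_.Tensor, whose return Tensor(a!) aliases self, this yields
# ["self"]; for add.out it yields ["out"]; for the functional add.Tensor it yields [None].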
def signature(
self, *, strip_default: bool = False, strip_view_copy_name: bool = False
self,
*,
strip_default: bool = False,
strip_view_copy_name: bool = False,
keep_return_names: bool = False,
) -> "FunctionSchema":
"""
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
factors these schemas into the "core" functional signature which
is equal across all versions.
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
factors these schemas into the "core" functional signature which
is equal across all versions.
Here is what normalization happens to the schema to convert
it to a signature:
- The overload name is stripped (name is retained, since
it expresses semantic content about what the function does)
- Inplace is set False
- Out arguments are stripped
- Mutability annotations are stripped (this is sound
because you cannot overload on mutability annotation)
- Return names are stripped since they are not overloadable and
some variants have return names but some not
Here is what normalization happens to the schema to convert
it to a signature:
- The overload name is stripped (name is retained, since
it expresses semantic content about what the function does)
- Inplace is set False
- Out arguments are stripped
- Mutable post_self_positional args are converted to returns
- Mutability annotations are stripped (this is sound
because you cannot overload on mutability annotation)
- Return names are stripped since they are not overloadable and
some variants have return names but some not
- TensorOptions are dropped
because out= variants of factory functions don't include them
(and we want to be able to pair up factory functions with their out variants)
Finally, we want to be able to pair up related "view" and their
corresponding "view_copy" operators. We do this by optionally
stripping the trailing "_copy" from the base name.
Finally, we want to be able to pair up related "view" and their
corresponding "view_copy" operators. We do this by optionally
stripping the trailing "_copy" from the base name.
Example of a mutable op before and after:
f.func (Mutable operator):
_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
f.func (Corresponding functional operator):
_fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950
f.func.signature() output:
_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950
"""
def strip_ret_annotation(r: Return) -> Return:
return Return(
name=None,
name=r.name if keep_return_names else None,
type=r.type,
annotation=None,
)
@ -1185,6 +1334,43 @@ class FunctionSchema:
if strip_view_copy_name and base_name.endswith("_copy"):
base_name = base_name.replace("_copy", "")
# find mutable inputs that are not originally returned, and convert them to returns
returns_from_mutable_inputs = tuple(
# When we're grouping functions we strip the return names,
# but when we're generating the actual functional variants then we follow
# a convention for what to name the returns
Return(
name=f"{a.name}_out" if keep_return_names else None,
type=a.type,
annotation=None,
)
for a in itertools.chain(
# Order is important here (otherwise e.g. inplace with mutable args
# and out= with mutable args won't have the same signature)
[self.arguments.self_arg.argument]
if self.arguments.self_arg is not None
else [],
self.arguments.out,
self.arguments.post_self_positional,
)
if a.annotation is not None
and a.annotation.is_write
and not any(a.annotation == r.annotation for r in self.returns)
)
original_returns = tuple(map(strip_ret_annotation, self.returns))
# Ordering is important here. We expect the "mutable input" returns to come last.
returns = original_returns + returns_from_mutable_inputs
args_sig = self.arguments.signature(strip_default=strip_default)
# See Note [arange.start_step schema]
if str(self.name) == "arange.start_step":
args_sig = Arguments.parse(
str(args_sig).replace("Scalar step", "Scalar step=1")
)
# See Note [bernoulli.p schema]
if str(self.name) == "bernoulli.p":
args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
return FunctionSchema(
name=OperatorName(
name=BaseOperatorName(
@ -1194,16 +1380,23 @@ class FunctionSchema:
),
overload_name="", # stripped
),
arguments=self.arguments.signature(strip_default=strip_default),
returns=tuple(map(strip_ret_annotation, self.returns)),
arguments=args_sig,
returns=returns,
)
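# A minimal sketch (assuming torchgen is importable): the related flavors of an op collapse to
# the same signature(), which is exactly what lets them be paired into a NativeFunctionsGroup:
#
#     from torchgen.model import FunctionSchema
#
#     functional = FunctionSchema.parse(
#         "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
#     )
#     inplace = FunctionSchema.parse(
#         "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
#     )
#     out = FunctionSchema.parse(
#         "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
#     )
#     assert functional.signature() == inplace.signature() == out.signature()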
def view_signature(self) -> "FunctionSchema":
return self.signature(strip_view_copy_name=True)
def with_name(self, name: "OperatorName") -> "FunctionSchema":
return FunctionSchema(
name=name,
arguments=self.arguments,
returns=self.returns,
)
@property
def modifies_arguments(self) -> bool:
return self.kind() in [SchemaKind.inplace, SchemaKind.out]
return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
def __str__(self) -> str:
all_arguments_str = str(self.arguments)
@ -1587,6 +1780,10 @@ class Arguments:
ret.extend(self.post_self_positional)
return ret
@property
def post_self_positional_mutable(self) -> Sequence[Argument]:
return [a for a in self.post_self_positional if a.is_write]
# NB: doesn't contain out arguments
@property
def flat_kwarg_only(self) -> Sequence[Argument]:
@ -1640,6 +1837,13 @@ class Arguments:
ret.extend(self.out)
return ret
def mutable_arg_names(self) -> List[str]:
return [
a.name
for a in self.flat_all
if a.annotation is not None and a.annotation.is_write
]
def signature(self, *, strip_default: bool = False) -> "Arguments":
# dataclasses.replace could be used here, but it is less
# type safe so for now I've opted to type everything out
@ -1661,18 +1865,36 @@ class Arguments:
post_self_positional=tuple(
map(strip_arg_annotation, self.post_self_positional)
),
# Since TensorOptions are dropped, the post_tensor_options_kwargs are
# converted to pre_tensor_options_kwargs
pre_tensor_options_kwarg_only=tuple(
map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
),
# NB: tensor_options guaranteed to not have any alias annotations
tensor_options=self.tensor_options,
post_tensor_options_kwarg_only=tuple(
map(strip_arg_annotation, self.post_tensor_options_kwarg_only)
),
)
+ tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
# TensorOptions are dropped in signature,
# so we can pair factory functions with their out= variants.
tensor_options=None,
post_tensor_options_kwarg_only=tuple(),
# out arguments are dropped in signature
out=(),
)
def remove_self_annotation(self) -> "Arguments":
assert self.self_arg is not None
return dataclasses.replace(
self,
self_arg=SelfArgument(
dataclasses.replace(self.self_arg.argument, annotation=None)
),
)
def with_out_args(self, outs: List[Argument]) -> "Arguments":
assert len(self.out) == 0
return dataclasses.replace(
self,
out=tuple(outs),
)
@staticmethod
def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
positional: List[Argument] = []
@ -1806,6 +2028,17 @@ class Arguments:
if self.tensor_options is None:
assert not self.post_tensor_options_kwarg_only
# We don't allow any of the following to have argument annotations,
# to keep things simple.
mutable_pre_self_positionals = [
a
for a in self.pre_self_positional
if a.annotation is not None and a.annotation.is_write
]
assert (
len(mutable_pre_self_positionals) == 0
), "mutable pre_self_positional arguments are not currently supported in the schema"
# Names that validly are __iXXX__ indicating inplace operations.
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
@ -1923,6 +2156,16 @@ class OperatorName:
overload_name=self.overload_name,
)
def with_overload(self, overload: str) -> "OperatorName":
return OperatorName(
name=BaseOperatorName(
base=self.name.base,
inplace=False,
dunder_method=self.name.dunder_method,
),
overload_name=overload,
)
def gets_generated_out_inplace_wrapper(
f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex

View File

@ -0,0 +1,382 @@
from torchgen.model import (
Argument,
DispatchKey,
FunctionSchema,
BaseType,
BaseTy,
Return,
Annotation,
NativeFunction,
OperatorName,
BackendIndex,
BackendMetadata,
DeviceCheckType,
SchemaKind,
Variant,
)
from torchgen.utils import (
concatMap,
)
from typing import List, Tuple, Sequence, Dict
from collections import defaultdict
# See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# This has a functional variant, but it's currently marked private.
# This function should be marked private as well (*_backward ops aren't exposed to python anyway).
"adaptive_avg_pool3d_backward.grad_input",
# There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
# Maybe we can kill this operator in favor of convolution_backward?
"_slow_conv2d_backward.grad_input",
]
# See Note: [Mutable ops that cannot get an out variant]
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
# should be out=?
"_cummax_helper",
# should be out=?
"_cummin_helper",
]
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# polygamma and polygamma.out both exist, but have a
# pre-self arg (while polygamma_ does not)
# We should either fix this schema so it can be grouped properly,
# or allow the codegen to generate new functional/out= NativeFunctions for this op
# (which would require changing its overload name to prevent overload ambiguity).
"polygamma_"
]
# Groups "similar" NativeFunctions together
# example add.Tensor, add_.Tensor, add.out
# "similar" NativeFunctions are all expected to have an identical `signature()`,
# but have differing SchemaKinds.
def pre_group_native_functions(
native_functions: Sequence[NativeFunction],
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = defaultdict(dict)
for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()]
assert f.func.kind() not in d
d[f.func.kind()] = f
return pre_grouped_native_functions
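# For example, add.Tensor, add_.Tensor and add.out all share one signature(), so they end up
# in a single inner dict, roughly:
#   {SchemaKind.functional: <add.Tensor>, SchemaKind.inplace: <add_.Tensor>, SchemaKind.out: <add.out>}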
# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
# Example before:
# _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
# Example after:
# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Generating an out= schema from an inplace schema.
assert func.kind() == SchemaKind.inplace
assert func.arguments.self_arg is not None
# The new out= schema has:
# - a new out argument with the same type as "func" (but with a mutable annotation)
# - The returns (if any) now alias the out= argument instead of "func"
# - an "out" overload name
return FunctionSchema(
name=func.name.remove_inplace().with_overload(
"out" if not func.name.overload_name else f"{func.name.overload_name}_out"
),
arguments=func.arguments.remove_self_annotation().with_out_args(
[
Argument(
name="out",
type=func.arguments.self_arg.argument.type,
default=None,
annotation=func.arguments.self_arg.argument.annotation,
)
]
),
returns=func.returns,
)
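# A minimal sketch (assuming this module and torchgen.model are importable):
#
#     from torchgen.model import FunctionSchema
#
#     inplace = FunctionSchema.parse(
#         "_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"
#     )
#     print(self_to_out_signature(inplace))
#     # prints (roughly):
#     # _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)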
# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
# Example before:
# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
# Example after:
# _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Generating an out= schema from a mutable schema.
assert func.kind() == SchemaKind.mutable
# The new out= schema has:
# - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
# (if the argument is a tensor then we also return it for method chaining,
# otherwise we return nothing)
# - an "out" overload name
#
# Note that:
# (1) This also means that we can *only* generate an out= variant from a mutable schema
# if the mutable schema has at least one tensor-like non-aliasing return.
# (2) The generated out= variant still has mutable positional arguments,
# but if necessary we could probably add another out= variant that also
# functionalizes the mutable arguments (a functional_out variant)
# More of a sanity check - our existing restrictions on schemas should enforce that
# mutable schema kinds never return their mutable arguments.
assert not any(
r.annotation is not None and r.annotation.is_write for r in func.returns
)
tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
assert len(tensorlike_rets) > 0
used_annotations = concatMap(
lambda a: [] if a.annotation is None else a.annotation.alias_set,
func.arguments.flat_all,
)
valid_annotations = [
x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
]
all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
new_out_args: List[Argument] = []
# The end result of new_returns is that:
# - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
# - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
new_returns: List[Return] = []
for (i, r) in enumerate(func.returns):
if r.type.is_tensor_like():
new_out = Argument(
name=f"out{i}",
type=r.type,
default=None,
annotation=Annotation.parse(f"{valid_annotations[i]}!"),
)
new_out_args.append(new_out)
if all_rets_are_tensors:
# The convention for out= schemas is that they only return their out arguments
# if the return is a plain Tensor (or if it's a tuple of plain Tensors)
new_ret = Return(
name=None, type=new_out.type, annotation=new_out.annotation
)
new_returns.append(new_ret)
else:
new_returns.append(r)
return FunctionSchema(
name=func.name.remove_inplace().with_overload(
"out" if not func.name.overload_name else f"{func.name.overload_name}_out"
),
arguments=func.arguments.with_out_args(new_out_args),
returns=tuple(new_returns),
)
# This function, given function of one SchemaKind, as well as a target SchemaKind,
# generates a new NativeFunction with the same properties, but using the target SchemaKind.
# We only actually generate functions for either functional or out= SchemaKinds.
# This function returns a tuple, with:
# - The generated NativeFunction
# - a dictionary of `BackendIndex` objects, describing which dispatch keys
# we will generate kernels for, for the new NativeFunction.
# Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
from torchgen.api import cpp
if k == SchemaKind.functional:
assert f.func.kind() != SchemaKind.functional
gets_composite_kernel = True
# The new "functional" NativeFunction has:
# - any mutable arguments have been converted into (immutable) returns.
# (if a mutable argument was not also a return, it gets converted to one)
# - a "functional" overload name.
# The default grouping logic in signature() actually already does this,
# so we can piggy-back off it (but we still want return names)
func = f.func.signature(keep_return_names=True).with_name(
f.func.name.remove_inplace().with_overload(
"functional"
if not f.func.name.overload_name
else f"{f.func.name.overload_name}_functional"
)
)
elif k == SchemaKind.out:
# We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
# but at least today, there is no good reason to actually use them.
# We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
gets_composite_kernel = False
if f.func.kind() == SchemaKind.inplace:
func = self_to_out_signature(f.func)
elif f.func.kind() == SchemaKind.mutable:
func = mutable_to_out_signature(f.func)
else:
raise AssertionError(
"We only bother generating out= functions from either inplace or mutable variants"
)
else:
raise AssertionError(
"We currently only generate either functional or out= NativeFunctions"
)
if gets_composite_kernel:
backend_metadata = {
DispatchKey.CompositeExplicitAutograd: {
func.name: BackendMetadata(cpp.name(func), structured=False)
}
}
else:
backend_metadata = {}
return (
NativeFunction(
func=func,
use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
# These generated fn's aren't meant to be user friendly- don't generate methods.
variants=set([Variant.function]),
structured=False,
structured_delegate=None,
structured_inherits=None,
precomputed=None,
autogen=[],
ufunc_inner_loop={},
manual_kernel_registration=False,
manual_cpp_binding=False,
python_module=None,
category_override=None,
device_guard=False,
device_check=DeviceCheckType.NoCheck,
loc=f.loc,
cpp_no_default_args=set(),
is_abstract=f.is_abstract,
has_composite_implicit_autograd_kernel=False,
has_composite_explicit_autograd_kernel=gets_composite_kernel,
# Every generated NativeFunction gets a "generated" tag, so it's easy to tell
# which NativeFunction objects did not come directly from native_functions.yaml.
tags=set(["generated"]),
),
backend_metadata,
)
# This function is responsible for adding generated NativeFunctions which don't appear
# explicitly in the codegen.
# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
# (Maybe we should make a friendly API for this)
#
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
rs: List[NativeFunction],
indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
# The main code for generating new NativeFunctions.
# First we group NativeFunctions by schema kind,
# then we detect which ones are missing and generate them.
pre_grouped_native_functions = pre_group_native_functions(rs)
for k, d in pre_grouped_native_functions.items():
has_functional = SchemaKind.functional in d
has_inplace = SchemaKind.inplace in d
has_mutable = SchemaKind.mutable in d
has_out = SchemaKind.out in d
# We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
# (1) If an operator has an inplace/out= variant but no functional variant, we can generate
# a simple functional variant that the functionalization pass can consume.
# (2) If an operator has an inplace and functional but no out= variant, we generate an out=
# variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
# while maintaining the constraint that the out= variant is "required".
#
# For now, we don't bother generating NativeFunctions for existing operators
# that only have a functional variant.
if has_mutable or has_inplace or has_out:
# Don't bother generating function trios for native functions that bypass the dispatcher.
are_manual = all(f.manual_cpp_binding for f in d.values())
# Don't bother generating functional + out= variants for view operators
has_view_ops = (
has_inplace and "inplace_view" in d[SchemaKind.inplace].tags
) or any(f.is_view_op for f in d.values())
# Don't generate the other variants for CompositeImplicitAutograd operators.
# We could probably do this, but the main benefit of generating the function triplets
# is for transforms that need them, and transforms don't need to act directly
# on CompositeImplicitAutograd operators (since we let them decompose).
are_composite_implicit = all(
f.has_composite_implicit_autograd_kernel for f in d.values()
)
if are_manual or has_view_ops or are_composite_implicit:
continue
if has_out and len(d.values()) == 1:
# Note: [Out ops with functional variants that don't get grouped properly]
# In theory we could validly have an out= operator in native_functions.yaml
# that has no other variants.
# But today, all of the operators where that's the case actually do have
# functional variants, that we are just unable to pair up properly.
# I think banning this altogether is probably safer
# (you can always add a functional variant yourself if you want to add a new out= operator).
#
# We should probably fix the existing cases; this check is to prevent us from adding more over time.
if (
str(d[SchemaKind.out].func.name)
not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
):
raise AssertionError(
f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
)
continue
# Some inplace ops that have problematic schemas (that we should fix), which prevent us
# from generating out= and functional variants
if (
has_inplace
and str(d[SchemaKind.inplace].func.name)
in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
):
continue
base_fn = (
d[SchemaKind.inplace]
if has_inplace
else d[SchemaKind.mutable]
if has_mutable
else d[SchemaKind.out]
)
# Note: [Mutable ops that cannot get an out variant]
# We can only generate an out= variant if either:
# - the original function has tensor-like returns (since we can convert them to out kwargs)
# - or it's inplace (since we can convert `self` to an out kwarg)
# There are only two functions that don't fit this criteria today though,
# and they both look like they should be fixed to be out= variants,
# so if feels safer to ban this schema all-together
gets_out_variant = not has_out and (
base_fn.func.kind() == SchemaKind.inplace
or any(r.type.is_tensor_like() for r in base_fn.func.returns)
)
if not has_out and not gets_out_variant:
if (
str(base_fn.func.name)
not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
):
raise AssertionError(
f"""Found a mutable operator that we could not generate an out= variant for: {str(base_fn.func)}.
These operators are problematic, because we can't easily auto-generate functionalization code for them. If you really need
the operator to have the schema mentioned, then add the name of the operator to the allow-list. Otherwise, if possible,
please convert it to an inplace operator"""
)
# Generate an out= variant
if gets_out_variant:
fn, metadata = generate_function(base_fn, SchemaKind.out)
d[SchemaKind.out] = fn
BackendIndex.grow_index(indices, metadata)
rs.append(fn)
# Generate a functional variant, but only do it if the operator got an out= variant
# (Functional variants are only useful if we can group up the variants,
# which we can only do if they have an out= variant)
if not has_functional and (has_out or gets_out_variant):
fn, metadata = generate_function(base_fn, SchemaKind.functional)
d[SchemaKind.functional] = fn
BackendIndex.grow_index(indices, metadata)
rs.append(fn)