Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[test] attempt to functionalize ops with mutable positional-only args

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76320
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: b8639cf6e1
Commit: 0161e9eb00
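Most of the changes below add `autogen:` entries to native_functions.yaml so that mutable operators — in-place ops, and ops such as _fused_moving_avg_obs_fq_helper that mutate positional-only tensor arguments without being inplace or out= variants — get code-generated functional and out= counterparts that the functionalization pass can target. As a rough illustration of the idea (ordinary PyTorch ops, not code from this PR), a functional variant returns a new tensor instead of mutating its input:

    import torch

    x = torch.zeros(3)
    y = torch.add(x, 1)  # functional: x is untouched, the result is a new tensor
    x.add_(1)            # mutable: x itself is updated in place

The autogen'd `.functional` overloads play the role of torch.add here for ops that today exist only in mutable form.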
@@ -505,6 +505,7 @@
variants: function
dispatch:
CPU: add_relu_
autogen: _add_relu.Scalar_out

# For C++ only, until we have conversion from C++ numbers to Tensor
- func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
@@ -518,6 +519,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: add_
autogen: add.Scalar_out

- func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
structured_delegate: addmv.out
@@ -609,6 +611,12 @@

- func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

# Note [arange.start_step schema]
# We want `arange.start_step` to be grouped up with `arange.start_out`,
# But this doesn't happen automatically because the step argument
# is defaultable for .start_out but not for .start_step.
# We should probably just make "step" a defaultable param on arange.start,
# and kill arange.start_step.
- func: arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
@@ -892,6 +900,7 @@
dispatch:
CPU, CUDA: bernoulli_
MPS: bernoulli_mps_
autogen: bernoulli.Tensor_functional, bernoulli.Tensor_out

- func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -899,7 +908,10 @@
dispatch:
CPU, CUDA: bernoulli_
MPS: bernoulli_mps_
autogen: bernoulli.float_out

# Note [bernoulli.p schema]
# We should probably just fix the overload ambiguity by appending a _functional to the C++ API name (BC breaking)
# This out-of-place version isn't used explicitly, but needed by jit.
# There is no default valid on `p` here because it would introduce ambiguity
# with `bernoulli(Tensor self, *, Generator? generator=None)` declaration.
@@ -1420,6 +1432,7 @@
SparseCPU, SparseCUDA: copy_sparse_wrapper_
CompositeExplicitAutograd: copy_
SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_
autogen: copy.out

- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
dispatch:
@@ -1780,6 +1793,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: div_
autogen: div.Scalar_out

- func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
variants: function, method
@@ -1790,6 +1804,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: div_
autogen: div.Scalar_mode_out

# divide, alias for div
- func: divide.Tensor(Tensor self, Tensor other) -> Tensor
@@ -1880,6 +1895,7 @@
dispatch:
CPU: embedding_renorm_cpu_
CUDA: embedding_renorm_cuda_
autogen: embedding_renorm.functional, embedding_renorm.out

- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor

@@ -1993,6 +2009,7 @@
MPS: resize_mps_
QuantizedCPU: quantized_resize_cpu_
SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_
autogen: resize.functional, resize.out

# This is a utility function to enable users to resize out tensor while registering kernels for out variants.
# Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration
@@ -2002,6 +2019,7 @@
variants: function
dispatch:
Meta: _resize_output_
autogen: _resize_output.functional, _resize_output.out

- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
category_override: factory
@@ -2201,6 +2219,7 @@
QuantizedCPU, QuantizedCUDA: fill_quantized_
Meta: fill_meta_
SparseCsrCPU, SparseCsrCUDA: fill_sparse_csr_
autogen: fill.Scalar_out

- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -2210,6 +2229,7 @@
MPS: fill_tensor_mps_
QuantizedCPU, QuantizedCUDA: fill_quantized_
Meta: fill_meta_
autogen: fill.Tensor_out

- func: floor(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@@ -2494,6 +2514,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: index_put_
autogen: index_put.out
# NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
# - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Tensor const & rhs)
# - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Scalar v)
@@ -2511,6 +2532,7 @@
variants: function
dispatch:
CPU, CUDA: _index_put_impl_
autogen: _index_put_impl.functional, _index_put_impl.out

- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
variants: function
@@ -3380,6 +3402,7 @@
dispatch:
CompositeExplicitAutograd: mul_
SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr
autogen: mul.Scalar_out

# multiply, alias for mul
- func: multiply.Tensor(Tensor self, Tensor other) -> Tensor
@@ -3918,6 +3941,7 @@
MkldnnCPU: mkldnn_relu_
QuantizedCPU: relu_quantized_cpu_
NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_
autogen: relu.out

- func: relu6(Tensor self) -> Tensor
python_module: nn
@@ -4061,6 +4085,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: celu_
autogen: celu.out

- func: silu(Tensor self) -> Tensor
structured_delegate: silu.out
@@ -4800,6 +4825,7 @@
device_guard: False
dispatch:
MkldnnCPU: mkldnn_transpose_
autogen: _mkldnn_transpose.out

- func: one_hot(Tensor self, int num_classes=-1) -> Tensor
python_module: nn
@@ -5312,6 +5338,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: resize_as_
autogen: resize_as.functional, resize_as.out

- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
@@ -5319,6 +5346,7 @@
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
autogen: resize_as_sparse.functional, resize_as_sparse.out

- func: zero_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -5330,6 +5358,7 @@
SparseCPU, SparseCUDA: zero_sparse_
SparseCsrCPU, SparseCsrCUDA: zero_sparse_csr_
MkldnnCPU: mkldnn_zero_
autogen: zero.functional, zero.out

- func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -5367,6 +5396,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: sub_
autogen: sub.Scalar_out

# subtract, alias for sub
- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
@@ -5629,12 +5659,14 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_
autogen: sparse_resize.functional, sparse_resize.out

- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_and_clear_
autogen: sparse_resize_and_clear.functional, sparse_resize_and_clear.out

- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
variants: method
@@ -5740,6 +5772,7 @@
SparseCPU, SparseCUDA: _coalesced_sparse_
device_check: NoCheck
device_guard: False
autogen: _coalesced.functional, _coalesced.out

- func: indices(Tensor(a) self) -> Tensor(a)
variants: method
@@ -5799,6 +5832,7 @@
variants: function
dispatch:
SparseCPU, SparseCUDA: copy_sparse_
autogen: copy_sparse_to_sparse.functional, copy_sparse_to_sparse.out

- func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
variants: function, method
@@ -6007,7 +6041,7 @@
dispatch:
CPU: fused_moving_avg_obs_fake_quant_cpu
CUDA: fused_moving_avg_obs_fake_quant_cuda

autogen: _fused_moving_avg_obs_fq_helper.functional, _fused_moving_avg_obs_fq_helper.out

- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)
variants: function
@@ -6201,6 +6235,7 @@
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: set_
autogen: set.source_Storage_functional, set.source_Storage_out

- func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
variants: method
@@ -6211,6 +6246,7 @@
CUDA: set_storage_cuda_
MPS: set_storage_mps_
QuantizedCPU, QuantizedCUDA: set_storage_quantized_
autogen: set.source_Storage_storage_offset_functional, set.source_Storage_storage_offset_out

- func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
variants: method
@@ -6223,6 +6259,7 @@
device_guard: False
dispatch:
CPU, CUDA, Meta, MPS: set_tensor_
autogen: set.source_Tensor_functional, set.source_Tensor_out

- func: set_(Tensor(a!) self) -> Tensor(a!)
variants: method
@@ -6231,6 +6268,7 @@
CUDA: set_cuda_
Meta: set_meta_
MPS: set_mps_
autogen: set.functional, set.out

- func: lift(Tensor self) -> Tensor
variants: method
@@ -6253,6 +6291,7 @@
CPU: masked_fill__cpu
CUDA: masked_fill__cuda
MPS: masked_fill__mps
autogen: masked_fill.Scalar_out

- func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6267,6 +6306,7 @@
CPU: masked_fill__cpu
CUDA: masked_fill__cuda
MPS: masked_fill__mps
autogen: masked_fill.Tensor_out

- func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6279,6 +6319,7 @@
dispatch:
CPU: masked_scatter__cpu
CUDA: masked_scatter__cuda
autogen: masked_scatter.out

- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
variants: function, method
@@ -6320,6 +6361,7 @@
variants: method
dispatch:
CPU, CUDA, MPS: put_
autogen: put.out

- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
variants: function, method
@@ -6367,6 +6409,7 @@
dispatch:
CPU: index_fill_
CUDA: index_fill_
autogen: index_fill.int_Scalar_out

- func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6379,6 +6422,7 @@
variants: method
dispatch:
CPU, CUDA: index_fill_
autogen: index_fill.int_Tensor_out

- func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6695,12 +6739,14 @@
variants: method
dispatch:
CPU, CUDA: __ilshift__
autogen: __lshift__.Scalar_out

- func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: __ilshift__
autogen: __lshift__.Tensor_out

- func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6760,12 +6806,14 @@
variants: method
dispatch:
CPU, CUDA: __irshift__
autogen: __rshift__.Scalar_out

- func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: __irshift__
autogen: __rshift__.Tensor_out

- func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6855,6 +6903,7 @@
CPU, CUDA: random_
Meta: random_meta_
MPS: random_mps_
autogen: random.from_functional, random.from_out

- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -6863,6 +6912,7 @@
CPU, CUDA: random_
Meta: random_meta_
MPS: random_mps_
autogen: random.to_functional, random.to_out

- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -6870,6 +6920,7 @@
dispatch:
CPU, CUDA: random_
Meta: random_meta_
autogen: random.functional, random.out

- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -6878,24 +6929,28 @@
CPU, CUDA: uniform_
MPS: uniform_mps_
Meta: uniform_meta_
autogen: uniform.functional, uniform.out

- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: cauchy_
autogen: cauchy.functional, cauchy.out

- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: log_normal_
autogen: log_normal.functional, log_normal.out

- func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA: exponential_
autogen: exponential.functional, exponential.out

- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -6904,6 +6959,7 @@
CPU, CUDA: geometric_

# wrappers for TH functions
autogen: geometric.functional, geometric.out

- func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@@ -8306,6 +8362,7 @@
MPS: normal_mps_
Meta: normal_meta_
SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_
autogen: normal.functional, normal.out

- func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
@@ -8356,11 +8413,13 @@
variants: function
dispatch:
CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_
autogen: _amp_foreach_non_finite_check_and_unscale.functional, _amp_foreach_non_finite_check_and_unscale.out

- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
variants: function
dispatch:
CUDA: _amp_update_scale_cuda_
autogen: _amp_update_scale.functional, _amp_update_scale.out

#- func: _cat(Tensor[] tensors, int dim=0) -> Tensor
#dispatch:
@@ -8388,6 +8447,7 @@
dispatch:
CPU: foreach_tensor_add_scalar_kernel_slow_
CUDA: foreach_tensor_add_scalar_kernel_cuda_
autogen: _foreach_add.Scalar_functional, _foreach_add.Scalar_out

- func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8402,6 +8462,7 @@
dispatch:
CPU: foreach_tensor_sub_scalar_kernel_slow_
CUDA: foreach_tensor_sub_scalar_kernel_cuda_
autogen: _foreach_sub.Scalar_functional, _foreach_sub.Scalar_out

- func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8416,6 +8477,7 @@
dispatch:
CPU: foreach_tensor_mul_scalar_kernel_slow_
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
autogen: _foreach_mul.Scalar_functional, _foreach_mul.Scalar_out

- func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8430,6 +8492,7 @@
dispatch:
CPU: foreach_tensor_div_scalar_kernel_slow_
CUDA: foreach_tensor_div_scalar_kernel_cuda_
autogen: _foreach_div.Scalar_functional, _foreach_div.Scalar_out

- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8444,6 +8507,7 @@
dispatch:
CPU: foreach_tensor_add_list_kernel_slow_
CUDA: foreach_tensor_add_list_kernel_cuda_
autogen: _foreach_add.List_functional, _foreach_add.List_out

- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8458,6 +8522,7 @@
dispatch:
CPU: foreach_tensor_sub_list_kernel_slow_
CUDA: foreach_tensor_sub_list_kernel_cuda_
autogen: _foreach_sub.List_functional, _foreach_sub.List_out

- func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8472,6 +8537,7 @@
dispatch:
CPU: foreach_tensor_mul_list_kernel_slow_
CUDA: foreach_tensor_mul_list_kernel_cuda_
autogen: _foreach_mul.List_functional, _foreach_mul.List_out

- func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8486,6 +8552,7 @@
dispatch:
CPU: foreach_tensor_div_list_kernel_slow_
CUDA: foreach_tensor_div_list_kernel_cuda_
autogen: _foreach_div.List_functional, _foreach_div.List_out

- func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8500,6 +8567,7 @@
dispatch:
CPU: foreach_tensor_add_scalarlist_kernel_slow_
CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
autogen: _foreach_add.ScalarList_functional, _foreach_add.ScalarList_out

- func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8514,6 +8582,7 @@
dispatch:
CPU: foreach_tensor_sub_scalarlist_kernel_slow_
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_
autogen: _foreach_sub.ScalarList_functional, _foreach_sub.ScalarList_out

- func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8528,6 +8597,7 @@
dispatch:
CPU: foreach_tensor_div_scalarlist_kernel_slow_
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
autogen: _foreach_div.ScalarList_functional, _foreach_div.ScalarList_out

- func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8542,6 +8612,7 @@
dispatch:
CPU: foreach_tensor_mul_scalarlist_kernel_slow_
CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_
autogen: _foreach_mul.ScalarList_functional, _foreach_mul.ScalarList_out

- func: _foreach_exp(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8556,6 +8627,7 @@
dispatch:
CPU: foreach_tensor_zero_slow_
CUDA: foreach_tensor_zero_cuda_
autogen: _foreach_zero.functional, _foreach_zero.out

- func: _foreach_exp_(Tensor(a!)[] self) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8563,6 +8635,7 @@
dispatch:
CPU: foreach_tensor_exp_slow_
CUDA: foreach_tensor_exp_cuda_
autogen: _foreach_exp.functional, _foreach_exp.out

- func: _foreach_sqrt(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8577,6 +8650,7 @@
dispatch:
CPU: foreach_tensor_sqrt_slow_
CUDA: foreach_tensor_sqrt_cuda_
autogen: _foreach_sqrt.functional, _foreach_sqrt.out

- func: _foreach_abs(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8591,6 +8665,7 @@
dispatch:
CPU: foreach_tensor_abs_slow_
CUDA: foreach_tensor_abs_cuda_
autogen: _foreach_abs.functional, _foreach_abs.out

- func: _foreach_acos(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8605,6 +8680,7 @@
dispatch:
CPU: foreach_tensor_acos_slow_
CUDA: foreach_tensor_acos_cuda_
autogen: _foreach_acos.functional, _foreach_acos.out

- func: _foreach_asin(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8619,6 +8695,7 @@
dispatch:
CPU: foreach_tensor_asin_slow_
CUDA: foreach_tensor_asin_cuda_
autogen: _foreach_asin.functional, _foreach_asin.out

- func: _foreach_atan(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8633,6 +8710,7 @@
dispatch:
CPU: foreach_tensor_atan_slow_
CUDA: foreach_tensor_atan_cuda_
autogen: _foreach_atan.functional, _foreach_atan.out

- func: _foreach_ceil(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8647,6 +8725,7 @@
dispatch:
CPU: foreach_tensor_ceil_slow_
CUDA: foreach_tensor_ceil_cuda_
autogen: _foreach_ceil.functional, _foreach_ceil.out

- func: _foreach_cos(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8661,6 +8740,7 @@
dispatch:
CPU: foreach_tensor_cos_slow_
CUDA: foreach_tensor_cos_cuda_
autogen: _foreach_cos.functional, _foreach_cos.out

- func: _foreach_cosh(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8675,6 +8755,7 @@
dispatch:
CPU: foreach_tensor_cosh_slow_
CUDA: foreach_tensor_cosh_cuda_
autogen: _foreach_cosh.functional, _foreach_cosh.out

- func: _foreach_erf(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8689,6 +8770,7 @@
dispatch:
CPU: foreach_tensor_erf_slow_
CUDA: foreach_tensor_erf_cuda_
autogen: _foreach_erf.functional, _foreach_erf.out

- func: _foreach_erfc(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8703,6 +8785,7 @@
dispatch:
CPU: foreach_tensor_erfc_slow_
CUDA: foreach_tensor_erfc_cuda_
autogen: _foreach_erfc.functional, _foreach_erfc.out

- func: _foreach_expm1(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8717,6 +8800,7 @@
dispatch:
CPU: foreach_tensor_expm1_slow_
CUDA: foreach_tensor_expm1_cuda_
autogen: _foreach_expm1.functional, _foreach_expm1.out

- func: _foreach_floor(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8731,6 +8815,7 @@
dispatch:
CPU: foreach_tensor_floor_slow_
CUDA: foreach_tensor_floor_cuda_
autogen: _foreach_floor.functional, _foreach_floor.out

- func: _foreach_log(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8745,6 +8830,7 @@
dispatch:
CPU: foreach_tensor_log_slow_
CUDA: foreach_tensor_log_cuda_
autogen: _foreach_log.functional, _foreach_log.out

- func: _foreach_log10(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8759,6 +8845,7 @@
dispatch:
CPU: foreach_tensor_log10_slow_
CUDA: foreach_tensor_log10_cuda_
autogen: _foreach_log10.functional, _foreach_log10.out

- func: _foreach_log1p(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8773,6 +8860,7 @@
dispatch:
CPU: foreach_tensor_log1p_slow_
CUDA: foreach_tensor_log1p_cuda_
autogen: _foreach_log1p.functional, _foreach_log1p.out

- func: _foreach_log2(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8787,6 +8875,7 @@
dispatch:
CPU: foreach_tensor_log2_slow_
CUDA: foreach_tensor_log2_cuda_
autogen: _foreach_log2.functional, _foreach_log2.out

- func: _foreach_neg(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8801,6 +8890,7 @@
dispatch:
CPU: foreach_tensor_neg_slow_
CUDA: foreach_tensor_neg_cuda_
autogen: _foreach_neg.functional, _foreach_neg.out

- func: _foreach_tan(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8815,6 +8905,7 @@
dispatch:
CPU: foreach_tensor_tan_slow_
CUDA: foreach_tensor_tan_cuda_
autogen: _foreach_tan.functional, _foreach_tan.out

- func: _foreach_tanh(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8829,6 +8920,7 @@
dispatch:
CPU: foreach_tensor_tanh_slow_
CUDA: foreach_tensor_tanh_cuda_
autogen: _foreach_tanh.functional, _foreach_tanh.out

- func: _foreach_sin(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8843,6 +8935,7 @@
dispatch:
CPU: foreach_tensor_sin_slow_
CUDA: foreach_tensor_sin_cuda_
autogen: _foreach_sin.functional, _foreach_sin.out

- func: _foreach_sinh(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8857,6 +8950,7 @@
dispatch:
CPU: foreach_tensor_sinh_slow_
CUDA: foreach_tensor_sinh_cuda_
autogen: _foreach_sinh.functional, _foreach_sinh.out

- func: _foreach_round(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8871,6 +8965,7 @@
dispatch:
CPU: foreach_tensor_round_slow_
CUDA: foreach_tensor_round_cuda_
autogen: _foreach_round.functional, _foreach_round.out

- func: _foreach_lgamma(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8885,6 +8980,7 @@
dispatch:
CPU: foreach_tensor_lgamma_slow_
CUDA: foreach_tensor_lgamma_cuda_
autogen: _foreach_lgamma.functional, _foreach_lgamma.out

- func: _foreach_frac(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8899,6 +8995,7 @@
dispatch:
CPU: foreach_tensor_frac_slow_
CUDA: foreach_tensor_frac_cuda_
autogen: _foreach_frac.functional, _foreach_frac.out

- func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8913,6 +9010,7 @@
dispatch:
CPU: foreach_tensor_reciprocal_slow_
CUDA: foreach_tensor_reciprocal_cuda_
autogen: _foreach_reciprocal.functional, _foreach_reciprocal.out

- func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8927,6 +9025,7 @@
dispatch:
CPU: foreach_tensor_sigmoid_slow_
CUDA: foreach_tensor_sigmoid_cuda_
autogen: _foreach_sigmoid.functional, _foreach_sigmoid.out

- func: _foreach_trunc(Tensor[] tensors) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8941,6 +9040,7 @@
dispatch:
CPU: foreach_tensor_trunc_slow_
CUDA: foreach_tensor_trunc_cuda_
autogen: _foreach_trunc.functional, _foreach_trunc.out

- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8948,6 +9048,7 @@
dispatch:
CPU: foreach_tensor_addcdiv_scalar_slow_
CUDA: foreach_tensor_addcdiv_scalar_cuda_
autogen: _foreach_addcdiv.Scalar_functional, _foreach_addcdiv.Scalar_out

- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8955,6 +9056,7 @@
dispatch:
CPU: foreach_tensor_addcmul_scalar_slow_
CUDA: foreach_tensor_addcmul_scalar_cuda_
autogen: _foreach_addcmul.Scalar_functional, _foreach_addcmul.Scalar_out

- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8962,6 +9064,7 @@
dispatch:
CPU: foreach_tensor_addcdiv_scalarlist_slow_
CUDA: foreach_tensor_addcdiv_scalarlist_cuda_
autogen: _foreach_addcdiv.ScalarList_functional, _foreach_addcdiv.ScalarList_out

- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -8969,6 +9072,7 @@
dispatch:
CPU: foreach_tensor_addcmul_scalarlist_slow_
CUDA: foreach_tensor_addcmul_scalarlist_cuda_
autogen: _foreach_addcmul.ScalarList_functional, _foreach_addcmul.ScalarList_out

- func: _foreach_addcdiv.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@@ -11507,7 +11611,7 @@
python_module: linalg
variants: function

- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
- func: linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_eigvalsh_out
@@ -11528,6 +11632,7 @@
dispatch:
CPU: _linalg_inv_out_helper_cpu
CUDA: _linalg_inv_out_helper_cuda
autogen: _linalg_inv_out_helper.functional, _linalg_inv_out_helper.out

- func: linalg_inv_ex(Tensor self, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
python_module: linalg
@@ -19,7 +19,7 @@ full_codegen:
- avg_pool2d_backward
- baddbmm
- bernoulli
- bernoulli_.float
- bernoulli.p
- binary_cross_entropy
- binary_cross_entropy_backward
- bitwise_and.Tensor
@@ -72,8 +72,8 @@ full_codegen:
- log_sigmoid_forward
- lt.Scalar
- lt.Tensor
- masked_fill_.Scalar
- masked_fill_.Tensor
- masked_fill.Scalar
- masked_fill.Tensor
- max
- max.dim
- max_pool2d_with_indices
@@ -101,12 +101,11 @@ full_codegen:
- norm.ScalarOpt_dim
- pow.Tensor_Scalar
- pow.Tensor_Tensor
- random_
- random_.from
- random_.to
- random.functional
- random.from_functional
- random.to_functional
- reciprocal
- relu
- relu_
- remainder.Tensor
- repeat
- rsqrt
@@ -141,7 +140,7 @@ full_codegen:
- upsample_bilinear2d_backward
- upsample_nearest2d
- upsample_nearest2d_backward
- zero_
- zero.functional
- narrow_copy.SymInt
supported:
- as_strided
@@ -13,8 +13,25 @@ $ops_headers
namespace at {
namespace native {

// This file contains a number of kernels for aten functions that are fully code-generated.
// TODO: rename this file to something more generic.

at::Tensor clone_arg(const at::Tensor& t) {
  return t.clone();
}

std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
  std::vector<at::Tensor> out(t_list.size());
  for (const auto& i : c10::irange(t_list.size())) {
    out[i] = t_list[i].clone();
  }
  return out;
}

${CompositeViewCopyKernel_Definitions}

${GeneratedCompositeFunctional_Definitions}

} // namespace native
} // namespace at
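The clone_arg helpers above feed the generated kernels spliced in via ${GeneratedCompositeFunctional_Definitions}. Conceptually — this is a hedged Python sketch of the semantics, not the generated C++ — a functional counterpart of a mutable op clones the arguments that would be mutated, runs the existing mutable kernel on the clones, and returns the clones. Using zero_ / zero.functional from the yaml above as the example:

    import torch

    def zero_functional_sketch(self: torch.Tensor) -> torch.Tensor:
        # Clone the would-be-mutated argument, apply the in-place kernel to the
        # clone, and return the clone instead of mutating the caller's tensor.
        out = self.clone()
        out.zero_()
        return out

    x = torch.ones(2)
    y = zero_functional_sketch(x)  # y is all zeros, x is unchanged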
@@ -185,7 +185,9 @@ TEST_F(atest, ne_operators) {

TEST_F(atest, add_operators) {
  auto exp_tensor = tensor({-10, 1, 0, -1, 10});
  run_binary_ops_test(add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
  run_binary_ops_test<
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Scalar&)>(
      add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
}

TEST_F(atest, max_operators) {
@@ -174,6 +174,17 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
[1., 1.],
[1., 1.]]))""")

# Some ops that are mutable are neither inplace nor out= ops.
# They also need special handling.
def test_mutable_op_not_inplace_or_other(self):
    def f(x):
        return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)

    logs = self.get_logs(f, torch.ones(1))
    self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""")

def test_tensor_list_composite(self):
    def f(x):
        # Test an op with TensorList input
@@ -3944,8 +3944,6 @@ class TestFunctionalTracing(JitTestCase):

    "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
    "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,

    "normalize" : MUTABLE,
}

# List of nn.functionals with Tensor inputs but not with type annotation
@@ -177,6 +177,9 @@ SKIP_PYTHON_BINDINGS_SIGNATURES = [

@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    # So far, all NativeFunctions that are entirely code-generated do not get python bindings.
    if "generated" in f.tags:
        return False
    name = cpp.name(f.func)
    for skip_regex in SKIP_PYTHON_BINDINGS:
        if skip_regex.match(name):
@@ -432,7 +432,8 @@ def emit_trace_body(f: NativeFunction) -> List[str]:

    assign_return_values = (
        f"{tie_return_values(f)} = "
        if f.func.kind() == SchemaKind.functional and f.func.returns
        if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable]
        and f.func.returns
        else ""
    )

@@ -349,7 +349,7 @@ GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)

# Some operators invalidate the grad_accumulator. Let's reset it.
RESET_GRAD_ACCUMULATOR = {"set", "resize"}
RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}

# NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
#
@@ -734,7 +734,7 @@ def gen_variable_type_func(

    if (
        fn.info is None
        and not get_base_name(f) in RESET_GRAD_ACCUMULATOR
        and not str(f.func.name.name) in RESET_GRAD_ACCUMULATOR
        and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE
        and len(gen_differentiable_outputs(fn)) > 0
        and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
@@ -857,7 +857,14 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
        and (len(differentiable_outputs) > 0)
    )

    if info is not None and info.has_derivatives and not requires_derivative:
    if (
        info is not None
        and info.has_derivatives
        and not requires_derivative
        # out= ops are allowed to have zero returns which cause requires_derivative to be False
        # we shouldn't error out though (out= ops for autograd just redispatch)
        and len(f.func.returns) > 0
    ):
        raise RuntimeError(
            f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
        )
@@ -1528,7 +1535,7 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    # Save only after the forward AD has been set up
    body.append(emit_save_outputs())

    if base_name in RESET_GRAD_ACCUMULATOR:
    if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
        # `inplace` implies that there is exactly one output named `self`,
        # so we can keep the generated code easy. If you need to
        # `reset_grad_accumulator` in an operator that's not `inplace`, you can
@@ -32,7 +32,10 @@ from torchgen.api.types import (
    stringT,
)
from torchgen.api import cpp
from torchgen.gen import parse_native_yaml, get_grouped_by_view_native_functions
from torchgen.gen import (
    parse_native_yaml,
    get_grouped_by_view_native_functions,
)
from torchgen.context import with_native_function
from torchgen.model import (
    FunctionSchema,
@@ -81,10 +81,11 @@ std::vector<int64_t> expand_param_if_needed(
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"

std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
TORCH_API std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
  double size_d = 0;
  // shape inference code copied from RangeFactories.cpp arange_out function
  // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct c++ scalar_t type depending on out tensor

  AT_DISPATCH_ALL_TYPES_AND(c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() {
    // Note: acc_type further defines an accumulataion type depending on the scalar_t and whether its on cuda vs cpu.
    using accscalar_t = at::acc_type<scalar_t, false>;
@@ -129,7 +130,6 @@ std::vector<Shape> compute_shape_arange_out(const at::Scalar & start, const at::
  // If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
  // Otherwise, the dtype is inferred to be torch.int64.

  // Since out tensor is specified, its dtype should always be used?
  return {Shape(out.scalar_type(), {size})};
}

@@ -145,7 +145,7 @@ std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optiona
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator) {
std::vector<Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator) {
  return compute_shape_bernoulli(self, generator);
}

@@ -224,11 +224,11 @@ std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at:
  }
}

std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
std::vector<Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

@@ -380,26 +380,22 @@ std::vector<Shape> compute_shape_native_dropout_backward(const at::Tensor & grad
  return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
}

std::vector<Shape> compute_shape_random_(at::Tensor & self, c10::optional<at::Generator> generator) {
std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_random_(at::Tensor & self, int64_t to, c10::optional<at::Generator> generator) {
  return compute_shape_random_(self, generator);
std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator) {
  return compute_shape_random_functional(self, generator);
}

std::vector<Shape> compute_shape_random_(at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) {
  return compute_shape_random_(self, generator);
std::vector<Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) {
  return compute_shape_random_functional(self, generator);
}

std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_relu_(at::Tensor& self) {
  return compute_shape_relu(self);
}

std::vector<Shape> compute_shape_bitwise_and(const at::Tensor& self, const at::Scalar& other) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}
@@ -417,7 +413,7 @@ std::vector<Shape> compute_shape_sum(
  return {Shape(self.scalar_type(), {})};;
}

std::vector<Shape> compute_shape_zero_(at::Tensor& self) {
std::vector<Shape> compute_shape_zero_functional(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

@@ -16,7 +16,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool2d_bac
TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
@@ -37,8 +37,8 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_l1_loss_backward(const a
TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_forward(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_logdet(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
@@ -50,11 +50,10 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backwa
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_(at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu_(at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sort(const at::Tensor & self, int64_t dim, bool descending);
@@ -65,7 +64,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & s
TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_(at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
} // namespace lazy
} // namespace torch
@@ -310,17 +310,41 @@ def match_differentiability_info(
        for info in differentiability_infos
        if info.func.func.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        info.func.func.signature(strip_default=True): info
        for info in differentiability_infos
        if info.func.func.kind() != SchemaKind.functional
    }

    def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]:
        # (1) Check for an exact match
        if f.func in info_by_schema:
            return info_by_schema[f.func], True

        # if there is no exact match look for the out-of-place signature.
        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e mul() for mul_() or mul_out()
        return (
            functional_info_by_signature.get(f.func.signature(strip_default=True)),
            False,
        )
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature:
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                "self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
to be used by automatically by its functional variant ("{str(f.func)}").
this is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info, False

        return None, False

    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
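The matching logic in find_info above boils down to a three-step lookup: (1) an exact schema match, (2) the out-of-place signature's formula, and (3), only for code-generated ops, the mutable variant's formula. A schematic, self-contained sketch of that order (plain dictionaries and hypothetical names, not the real torchgen types):

    from typing import Optional, Tuple

    def find_info_sketch(
        schema: str,
        signature: str,
        is_generated: bool,
        info_by_schema: dict,
        functional_info_by_signature: dict,
        non_functional_info_by_signature: dict,
    ) -> Tuple[Optional[object], bool]:
        if schema in info_by_schema:                   # (1) exact match
            return info_by_schema[schema], True
        if signature in functional_info_by_signature:  # (2) out-of-place variant's formula
            return functional_info_by_signature[signature], False
        if is_generated and signature in non_functional_info_by_signature:
            # (3) fall back to the mutable variant's formula for generated ops
            return non_functional_info_by_signature[signature], False
        return None, False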
@@ -64,7 +64,9 @@ from typing import Optional, Sequence, Union, List, Set

def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    name = str(func.name.name)
    if func.is_out_fn():
    if func.is_functional_fn():
        name += "_functional"
    elif func.is_out_fn():
        if faithful_name_for_out_overloads:
            name += "_outf"
        else:
@ -60,6 +60,8 @@ from torchgen.api.types import (
|
||||
|
||||
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
|
||||
|
||||
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
|
||||
|
||||
longVec_ctype = VectorCType(BaseCType(longT))
|
||||
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
|
||||
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
|
||||
@ -287,20 +289,38 @@ Check this module for more information.
|
||||
return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
|
||||
|
||||
elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
|
||||
options = direct_solve(options_ctype)
|
||||
return f"optTypeMetaToScalarType({options}.dtype_opt())"
|
||||
try:
|
||||
options = direct_solve(options_ctype)
|
||||
return f"optTypeMetaToScalarType({options}.dtype_opt())"
|
||||
except UnsatError:
|
||||
out_tensor = direct_solve(out_tensor_ctype)
|
||||
return f"{out_tensor}.scalar_type()"
|
||||
|
||||
elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
|
||||
options = direct_solve(options_ctype)
|
||||
return f"{options}.layout_opt()"
|
||||
try:
|
||||
options = direct_solve(options_ctype)
|
||||
return f"{options}.layout_opt()"
|
||||
except UnsatError:
|
||||
out_tensor = direct_solve(out_tensor_ctype)
|
||||
return f"{out_tensor}.layout()"
|
||||
|
||||
elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
|
||||
options = direct_solve(options_ctype)
|
||||
return f"{options}.device_opt()"
|
||||
try:
|
||||
options = direct_solve(options_ctype)
|
||||
return f"{options}.device_opt()"
|
||||
except UnsatError:
|
||||
out_tensor = direct_solve(out_tensor_ctype)
|
||||
return f"{out_tensor}.device()"
|
||||
|
||||
elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
|
||||
options = direct_solve(options_ctype)
|
||||
return f"{options}.pinned_memory_opt()"
|
||||
try:
|
||||
options = direct_solve(options_ctype)
|
||||
return f"{options}.pinned_memory_opt()"
|
||||
except UnsatError:
|
||||
# If we're calling a factory op from its out= variant,
|
||||
# We don't actually care about the value of pin_memory.
|
||||
out_tensor = direct_solve(out_tensor_ctype)
|
||||
return "c10::nullopt"
|
||||
|
||||
# We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
|
||||
elif goal.type == BaseCType(intArrayRefT):
|
||||
@ -348,6 +368,15 @@ Check this module for more information.
|
||||
# With arguments like std::vector<IntArrayRef>.
|
||||
# If that changes, we'll have to add the translation here.
|
||||
|
||||
# We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
|
||||
# We could probably generalize this to non-tensor types too.
|
||||
if goal.type == MutRefCType(BaseCType(tensorT)):
|
||||
const_ref_tensor_ctype = NamedCType(
|
||||
goal.name, ConstRefCType(BaseCType(tensorT))
|
||||
)
|
||||
argname = direct_solve(const_ref_tensor_ctype)
|
||||
return f"const_cast<Tensor&>({argname})"
|
||||
|
||||
unsat(goal)
|
||||
|
||||
return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
|
||||
|
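
A small standalone sketch of the fallback order the hunks above introduce (assumed helper names, not the real torchgen translate machinery): prefer a TensorOptions binding, and only when that is unsatisfiable derive the value from the out tensor.

    def resolve_dtype_expr(bindings: dict) -> str:
        # Try the TensorOptions route first; fall back to the out tensor if it's absent.
        if "options" in bindings:
            return f"optTypeMetaToScalarType({bindings['options']}.dtype_opt())"
        return f"{bindings['out']}.scalar_type()"

    print(resolve_dtype_expr({"options": "options"}))  # optTypeMetaToScalarType(options.dtype_opt())
    print(resolve_dtype_expr({"out": "out"}))          # out.scalar_type()
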
@ -26,6 +26,7 @@ F = TypeVar(
F2 = TypeVar(
"F2",
NativeFunction,
NativeFunctionsGroup,
Optional[NativeFunction],
bool,
)

@ -616,6 +616,10 @@ check_inplace(out, sizes, options);
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
elif k is SchemaKind.mutable:
raise AssertionError(
"SchemaKind.mutable structured operators are currently not supported"
)
else:
assert_never(k)

@ -631,6 +635,10 @@ resize_out(out, sizes, strides, options);
out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
elif k is SchemaKind.mutable:
raise AssertionError(
"SchemaKind.mutable structured operators are currently not supported"
)
else:
assert_never(k)

184 torchgen/gen.py
@ -33,6 +33,10 @@ from torchgen.model import (
ViewSchemaKind,
BaseOperatorName,
)
from torchgen.native_function_generation import (
pre_group_native_functions,
add_generated_native_functions,
)
from torchgen.api.types import (
Binding,
CppSignatureGroup,
@ -72,6 +76,7 @@ from torchgen.gen_functionalization_type import (
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_composite_view_copy_kernel,
gen_composite_functional_kernel,
)

T = TypeVar("T")
@ -180,6 +185,7 @@ def parse_native_yaml_struct(
index={},
)
)
add_generated_native_functions(rs, bs)
for k, v in bs.items():
# All structured in-tree operators are implemented in terms of their out operator.
indices[k] = BackendIndex(
@ -1278,59 +1284,48 @@ def get_custom_build_selector(
return selector


def pre_group_native_functions(
native_functions: Sequence[NativeFunction],
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = defaultdict(dict)
for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()]
assert f.func.kind() not in d
d[f.func.kind()] = f
return pre_grouped_native_functions


def get_grouped_by_view_native_functions(
native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
def maybe_create_view_group(
d: Dict[ViewSchemaKind, NativeFunction]
d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
if ViewSchemaKind.aliasing not in d:
# Case 1: this op / op group is not aliasing, so we don't create a view group.
# return the original (ungrouped) native functions instead.
for func in d.values():
funcs.append(func)
else:
# Case 2: this op group contains an aliasing op, so we create a ViewGroup for it.
# The handling for out= ops here is unfortunate.
# out= ops don't really make sense for view operators.
# However, we have at least one existing {view}_copy.out operator in native_functions.yaml.
# It shouldn't be part of a view group, so we explicitly don't group it.
# There currently aren't any out= view ops (and there probably shouldn't be).
# We also expect that when we hit this case, the `non_aliasing` op in the dict
# *must* be a view_copy op (this is asserted in the NativeFunctionsViewGroup constructor)
if ViewSchemaKind.out in d:
funcs.append(d[ViewSchemaKind.out])
if ViewSchemaKind.aliasing in d:
view = d.pop(ViewSchemaKind.aliasing)
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
view_copy = d.pop(SchemaKind.functional, None)

funcs.append(
NativeFunctionsViewGroup(
view=d[ViewSchemaKind.aliasing],
view_copy=d.get(ViewSchemaKind.non_aliasing, None),
view_inplace=d.get(ViewSchemaKind.inplace, None),
view=view,
view_copy=view_copy,
view_inplace=view_inplace,
)
)
# Take the remaining functions that weren't part of the view group
# and emit them separately
for func in d.values():
funcs.append(func)
return funcs

grouped_by_views: Dict[
FunctionSchema, Dict[ViewSchemaKind, NativeFunction]
FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
] = defaultdict(dict)
for f in native_functions:
schema = f.func.view_signature()
assert f.view_schema_kind not in grouped_by_views[schema]
grouped_by_views[schema][f.view_schema_kind] = f
view_kind: ViewSchemaKind = f.view_schema_kind
# We need to group up ops relevant to the same "view", consisting of:
# view op (ViewSchemaKind.aliasing)
# view_inplace op (ViewSchemaKind.aliasing_inplace)
# view_copy op (SchemaKind.functional)
if view_kind == ViewSchemaKind.non_aliasing:
kind = f.func.kind()
assert kind not in grouped_by_views[schema]
grouped_by_views[schema][kind] = f
else:
assert view_kind not in grouped_by_views[schema]
grouped_by_views[schema][view_kind] = f

return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
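
To make the grouping above concrete, here is a toy version with plain dicts in place of the torchgen types; `transpose` is just an example of an op family with aliasing, aliasing-inplace, and view_copy variants.

    from collections import defaultdict

    ops = [("transpose", "aliasing"), ("transpose_", "aliasing_inplace"), ("transpose_copy", "functional")]
    grouped = defaultdict(dict)
    for name, kind in ops:
        grouped["shared-view-signature"][kind] = name  # keyed by the op family's view signature

    d = grouped["shared-view-signature"]
    view_group = {
        "view": d.pop("aliasing"),
        "view_inplace": d.pop("aliasing_inplace", None),
        "view_copy": d.pop("functional", None),
    }
    print(view_group)  # anything left over in d would be emitted ungrouped
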
@ -1343,6 +1338,9 @@ def get_grouped_native_functions(
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
r = NativeFunctionsGroup.from_dict(d)
if r is None:
# Invariant: any NativeFunctions that are code-generated
# should have been grouped into NativeFunctionsGroup objects
assert not any("generated" in f.tags for f in d.values())
return list(d.values())
else:
return [r]
@ -1835,9 +1833,7 @@ def gen_source_files(
native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
native_functions_with_view_groups: Sequence[
Union[NativeFunction, NativeFunctionsViewGroup]
],
view_groups: Sequence[NativeFunctionsViewGroup],
selector: SelectiveBuilder,
static_dispatch_idx: List[BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
@ -2118,30 +2114,11 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
},
)

# We need to easily map from [inplace_op_name] -> [functional_op] for the functionalization pass,
# so here I generate a mapping from every operator name to its corresponding functional NativeFunction (if it exists).
pre_grouped_d: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = pre_group_native_functions(native_functions)
to_functional_op: Dict[OperatorName, Optional[NativeFunction]] = {
k: v
for d in [
{
f.func.name: pre_grouped_d[func][SchemaKind.functional]
if SchemaKind.functional in pre_grouped_d[func].keys()
else None
for f in pre_grouped_d[func].values()
}
for func in pre_grouped_d.keys()
]
for k, v in d.items()
}

def functionalization_env_callable(
g: Union[NativeFunction, NativeFunctionsViewGroup]
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> Dict[str, List[str]]:
def gen_op_headers(
g: Union[NativeFunction, NativeFunctionsViewGroup]
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> List[str]:
if isinstance(g, NativeFunctionsViewGroup):
# view ops always get a functionalization kernel
@ -2155,11 +2132,28 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
]
return headers
elif isinstance(g, NativeFunctionsGroup):
headers = [
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
f"#include <ATen/ops/{g.out.root_name}_native.h>",
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
]
if g.inplace is not None:
headers += [
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
]
if g.mutable is not None:
headers += [
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
]
return headers
else:
f = g
return [
f"#include <ATen/ops/{f.root_name}_native.h>",
f"#include <ATen/ops/{f.root_name}_ops.h>",
f"#include <ATen/ops/{g.root_name}_native.h>",
f"#include <ATen/ops/{g.root_name}_ops.h>",
]

return {
@ -2167,11 +2161,6 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
"func_definitions": gen_functionalization_definition(
selector,
g,
# We need to manually map inplace ops to their out-of-place variants
# (we can't do this with NativeFunctionsGroup today because not all inplace ops have out= variants)
None
if isinstance(g, NativeFunctionsViewGroup)
else to_functional_op.get(g.func.name, None),
),
"func_registrations": gen_functionalization_registration(
selector,
@ -2180,9 +2169,30 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
),
}

all_groups: List[
Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
] = list(structured_native_functions) + list(
view_groups  # type: ignore[assignment, arg-type, operator]
)
# Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
# The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
structured_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
}
view_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
}
for f in native_functions:
if f.func.name not in structured_map and f.func.name not in view_map:
all_groups.append(f)

cpu_fm.write_sharded(
"RegisterFunctionalization.cpp",
native_functions_with_view_groups,
all_groups,
key_fn=key_func,
env_callable=functionalization_env_callable,
num_shards=4,
@ -2203,11 +2213,7 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
lambda g: gen_functionalization_view_inverse_declaration(
selector, g
),
[
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
],
view_groups,
)
)
},
@ -2239,17 +2245,23 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
[g.view] if g.view_copy is None else [g.view, g.view_copy]
)
)
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
for g in view_groups
]
+ [
"\n".join(
f"#include <ATen/ops/{f.root_name}_ops.h>"
for f in [g.inplace, g.mutable]
if f is not None and "generated" not in f.tags
)
for g in structured_native_functions
],
"CompositeViewCopyKernel_Definitions": list(
mapMaybe(gen_composite_view_copy_kernel, view_groups)
),
"GeneratedCompositeFunctional_Definitions": list(
mapMaybe(
gen_composite_view_copy_kernel,
[
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
],
gen_composite_functional_kernel,
structured_native_functions,
)
),
},
@ -2377,12 +2389,18 @@ def main() -> None:
)

grouped_native_functions = get_grouped_native_functions(native_functions)

structured_native_functions = [
g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
]
native_functions_with_view_groups = get_grouped_by_view_native_functions(
native_functions
)
view_groups = [
g
for g in native_functions_with_view_groups
if isinstance(g, NativeFunctionsViewGroup)
]

template_dir = os.path.join(options.source_path, "templates")

@ -2451,7 +2469,7 @@ def main() -> None:
native_functions=native_functions,
grouped_native_functions=grouped_native_functions,
structured_native_functions=structured_native_functions,
native_functions_with_view_groups=native_functions_with_view_groups,
view_groups=view_groups,
selector=selector,
static_dispatch_idx=static_dispatch_idx,
backend_indices=backend_indices,
@ -4,6 +4,7 @@ from torchgen.api.types import (
|
||||
Binding,
|
||||
FunctionalizationLambda,
|
||||
ViewInverseSignature,
|
||||
Expr,
|
||||
NativeSignature,
|
||||
CType,
|
||||
BaseCType,
|
||||
@ -19,8 +20,9 @@ from torchgen.context import (
|
||||
)
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
Return,
|
||||
NativeFunction,
|
||||
SchemaKind,
|
||||
NativeFunctionsGroup,
|
||||
BackendIndex,
|
||||
FunctionSchema,
|
||||
SelfArgument,
|
||||
@ -30,10 +32,30 @@ from torchgen.model import (
|
||||
NativeFunctionsViewGroup,
|
||||
ListType,
|
||||
)
|
||||
from torchgen.native_function_generation import (
|
||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
|
||||
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
|
||||
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
|
||||
)
|
||||
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
from typing import List, Optional, Union, Tuple, Callable
|
||||
|
||||
|
||||
# Note: [Mutable Ops Not Using Functionalization]
|
||||
# Ops in this list currently do not work with functionalization and should be fixed.
|
||||
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
|
||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
|
||||
+ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
|
||||
+ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
|
||||
+ [
|
||||
# It will be BC-breaking, but we should fix their schemas.
|
||||
# should be inplace?
|
||||
"record_stream",
|
||||
]
|
||||
)
|
||||
|
||||
# This file contains codegen that relates to the functionalization pass.
|
||||
# It includes:
|
||||
# - gen_functionalization_definition
|
||||
@ -91,24 +113,96 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
|
||||
"""
|
||||
|
||||
|
||||
def modifies_arguments(f: NativeFunction) -> bool:
|
||||
return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
|
||||
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
|
||||
assert len(rets) == len(names)
|
||||
if len(rets) == 0:
|
||||
return ""
|
||||
elif len(rets) == 1:
|
||||
return f"return {names[0]};"
|
||||
else:
|
||||
return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
|
||||
|
||||
|
||||
# This function constructs the return statement for the kernels that contain mutations
|
||||
# It mostly just needs to special case multi-output returns to wrap the result in a tuple
|
||||
def return_str(f: NativeFunction) -> str:
|
||||
# Need to check both # outs and # returns. Why?
|
||||
# out= ops with a mutable Tensor(a!)[] argument are expected to have a void return type.
|
||||
if len(f.func.arguments.out) != 0 and len(f.func.returns) != 0:
|
||||
if len(f.func.arguments.out) > 1:
|
||||
return_names = ", ".join(a.name for a in f.func.arguments.out)
|
||||
return f"return {DispatcherSignature.from_schema(f.func).returns_type().cpp_type()}({return_names});"
|
||||
# Given a function, and the name of a variable correponding to the output of that function,
|
||||
# gather up all of the individual returns that are not aliased
|
||||
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]:
|
||||
aliased_rets = func.aliased_return_names()
|
||||
non_aliased_names = []
|
||||
is_out_var_a_tuple = len(func.returns) > 1
|
||||
for (i, r) in enumerate(aliased_rets):
|
||||
if r is None:
|
||||
non_aliased_names.append(
|
||||
f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
|
||||
)
|
||||
return non_aliased_names
|
||||
|
||||
|
||||
@with_native_function
|
||||
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
|
||||
# We should only be generating these for code-generated NativeFunctions
|
||||
if "generated" not in g.functional.tags:
|
||||
return None
|
||||
# And we always write the kernel for a generated op in terms of a non-generated op.
|
||||
if g.inplace is not None and "generated" not in g.inplace.tags:
|
||||
target_f = g.inplace
|
||||
elif g.mutable is not None and "generated" not in g.mutable.tags:
|
||||
target_f = g.mutable
|
||||
else:
|
||||
# We should be guaranteed to have a valid inplace/mutable variant to call into.
|
||||
# See Note: [Mutable Ops Not Using Functionalization]
|
||||
raise AssertionError(str(g.functional.func))
|
||||
|
||||
sig = DispatcherSignature(g.functional.func)
|
||||
target_sig = DispatcherSignature(target_f.func)
|
||||
|
||||
context: List[Union[Binding, Expr]] = []
|
||||
clone_mutable_inputs = []
|
||||
cloned_return_names = []
|
||||
# We can't just directly pass all of the arguments from the functional op into the mutating op.
|
||||
# We need to check for which inputs to the mutating operator are mutable,
|
||||
# and clone those inputs first.
|
||||
for a_curr, a_tgt in zip(
|
||||
dispatcher.jit_arguments(g.functional.func),
|
||||
dispatcher.jit_arguments(target_f.func),
|
||||
):
|
||||
if a_tgt.annotation is not None and a_tgt.annotation.is_write:
|
||||
clone_mutable_inputs.append(
|
||||
f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
|
||||
)
|
||||
context.append(
|
||||
Expr(
|
||||
expr=f"{a_curr.name}_clone",
|
||||
type=dispatcher.argument_type(a_curr, binds=a_curr.name),
|
||||
)
|
||||
)
|
||||
# Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
|
||||
cloned_return_names.append(f"{a_curr.name}_clone")
|
||||
else:
|
||||
return f"return {f.func.arguments.out[0].name}"
|
||||
if f.func.arguments.self_arg is not None and len(f.func.returns) != 0:
|
||||
return f"return {f.func.arguments.self_arg.argument.name}"
|
||||
return ""
|
||||
context.append(dispatcher.argument(a_curr))
|
||||
exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])
|
||||
|
||||
out_name = "output"
|
||||
maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
|
||||
inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
|
||||
ret_str = return_str(
|
||||
g.functional.func.returns, inner_return_names + cloned_return_names
|
||||
)
|
||||
|
||||
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
|
||||
return f"""
|
||||
{sig.defn()} {{
|
||||
{clone_mutable_inputs_str}
|
||||
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
|
||||
{ret_str}
|
||||
}}
|
||||
"""
|
||||
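
The generated kernel above clones every mutable input, redispatches to the existing inplace/mutable op on the clones, and returns the clones alongside any fresh outputs. A rough Python analogue of that behaviour, on lists instead of tensors and with a hypothetical in-place add:

    def add_inplace(self, other):          # stands in for the existing mutable op
        for i, x in enumerate(other):
            self[i] += x
        return self

    def add_functional(self, other):       # what the generated functional wrapper does
        self_clone = list(self)            # clone_arg(...) for each mutable input
        add_inplace(self_clone, other)     # call the mutable op on the clone
        return self_clone                  # the clone is returned; the input is untouched

    a, b = [1, 2], [10, 20]
    print(add_functional(a, b), a)         # [11, 22] [1, 2]
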

def modifies_arguments(f: NativeFunction) -> bool:
return any(
a.annotation is not None and a.annotation.is_write
for a in f.func.arguments.flat_all
)


def wrapper_name(func: FunctionSchema) -> str:
@ -352,11 +446,103 @@ def emit_view_functionalization_body(
"""


def maybe_create_output(f: NativeFunction, var_name: str) -> str:
if len(f.func.returns) == 0:
return ""
return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
return f"{return_type} {var_name} = "


# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
# this returns two lists of names, consisting of:
# - the names of returns corresponding to the original (mutable) inputs of the outer function
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names(
f: NativeFunction, inner_return_var: str
) -> Tuple[List[str], List[str]]:
aliased_returns = []
non_aliased_returns = []
for (i, name) in enumerate(f.func.aliased_return_names()):
if name is not None:
aliased_returns.append(name)
else:
non_aliased_returns.append(
inner_return_var
if len(f.func.returns) == 1
else f"std::get<{i}>({inner_return_var})"
)
return aliased_returns, non_aliased_returns


# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that:
# - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
# - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
def return_from_mutable_noop_redispatch(
f: NativeFunction, inner_return_var: str
) -> str:
aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
# Just get all of the return names, and immediately return them
return return_str(f.func.returns, aliased + non_aliased)


def wrap_propagate_mutations_and_return(
f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
) -> str:
mutable_arg_names = f.func.arguments.mutable_arg_names()
(
aliased_outer_rets,
non_aliased_outer_rets,
) = get_mutable_redispatch_return_names(f, inner_return_var)
_, non_aliased_inner_rets = get_mutable_redispatch_return_names(
functional_op, inner_return_var
)
# The outer function may have a mix of aliased and non-aliased outputs,
# But the inner functional op that we're transforming to should only have non-aliased outputs
assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
non_aliased_inner_rets
)

# First, take all of the newly created outputs from the inner call and wrap them into functional tensors
updates = []
non_aliased_wrapped_ret_names = []
for (i, inner_ret) in enumerate(
non_aliased_inner_rets[: len(non_aliased_outer_rets)]
):
ret_name = f"output_{i}"
updates.append(
f"""\
auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
)
non_aliased_wrapped_ret_names.append(ret_name)

# Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
# and propagate the mutations
for (outer_arg, inner_ret) in zip(
mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
):
updates.append(
f"""\
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});"""
)

# Finally, we return:
# - Any mutable arguments that are also returned
# - Any immutable returns that were created wrapping the output from the inner call
returns_str = return_str(
f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
)
updates_str = "\n".join(updates)
return f"""\
{updates_str}
{returns_str}"""

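
Walking the bookkeeping above for a plain inplace op (assume something shaped like `add_(Tensor(a!) self, Tensor other) -> Tensor(a!)`): the outer op's single return aliases `self`, the inner functional op returns one fresh tensor, and that fresh tensor is what gets propagated back into `self`.

    mutable_arg_names = ["self"]               # mutated inputs of the outer op
    aliased_outer_rets = ["self"]              # outer returns that alias an input
    non_aliased_outer_rets = []                # outer returns that are fresh
    non_aliased_inner_rets = ["tmp_output"]    # the functional op only has fresh returns

    # the invariant asserted above: fresh outer returns + mutated inputs == fresh inner returns
    assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(non_aliased_inner_rets)

    # mutations are propagated from the trailing inner returns back into the mutated inputs
    updates = [f"replace_({arg}, {ret}); commit_update({arg});"
               for arg, ret in zip(mutable_arg_names,
                                   non_aliased_inner_rets[len(non_aliased_outer_rets):])]
    print(updates)                          # ['replace_(self, tmp_output); commit_update(self);']
    print("return", aliased_outer_rets[0])  # the kernel finally returns `self`
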
# Generates the Functionalization kernel for:
# - mutation ops (inplace and out= ops)
@with_native_function_and
def emit_inplace_functionalization_body(
f: NativeFunction, functional_op: Optional[NativeFunction]
f: NativeFunction, g: NativeFunctionsGroup
) -> str:
# mutation case
assert modifies_arguments(f)
@ -400,42 +586,15 @@ def emit_inplace_functionalization_body(
for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
]

if functional_op is None:
# We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
return_type = (
dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
)
warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch."""

return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
TORCH_WARN("{warn_str}");
}}
{unwrap_tensor_args_str}
at::AutoDispatchSkipFunctionalize guard;
// Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_str(f)};
}}
"""
else:
# call the out-of-place variant of the op
return_type = (
dispatcher.returns_type(functional_op.func.returns)
.remove_const_ref()
.cpp_type()
)
functional_sig = DispatcherSignature.from_schema(functional_op.func)
functional_exprs = [
e.expr
for e in translate(
unwrapped_args_ctx, functional_sig.arguments(), method=False
)
]
# call the out-of-place variant of the op
return_type = (
dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
)
functional_sig = DispatcherSignature.from_schema(g.functional.func)
functional_exprs = [
e.expr
for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
]

if f.func.is_out_fn():
mutable_input_post_processing = "\n".join(
@ -471,17 +630,16 @@ If this causes problems in your program, consider upstreaming the out-of-place o
}} else {{
// case 2: arguments are not functional tensors, so we no-op and redispatch.
at::AutoDispatchSkipFunctionalize guard;
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_str(f)};
{maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_from_mutable_noop_redispatch(f, 'tmp_output')};
}}
}} else {{
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
}}
{mutable_input_post_processing}
{return_str(f)};
{wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
}}
}}"""

@ -508,7 +666,7 @@ def gen_functionalization_view_inverse_declaration(

def gen_functionalization_registration(
selector: SelectiveBuilder,
g: Union[NativeFunction, NativeFunctionsViewGroup],
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
composite_implicit_autograd_index: BackendIndex,
) -> List[str]:
@with_native_function
@ -542,8 +700,16 @@ def gen_functionalization_registration(
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str

elif isinstance(g, NativeFunctionsGroup):
fns = list(g.functions())
else:
f = g
if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
return []
fns = [g]

registrations = []
for f in fns:
if str(f.func.name) == "lift":
# See Note [Functionalization <> torch.Tensor constructor]
return []
@ -552,14 +718,17 @@ def gen_functionalization_registration(
# We *also* need to directly register CompositeImplicitAutograd kernels
# so that they decompose properly before functionalization.
if modifies_arguments(f) or f.has_composite_implicit_autograd_kernel:
return [emit_registration_helper(f)]
return []
registrations.append(emit_registration_helper(f))
return registrations


def gen_functionalization_definition(
selector: SelectiveBuilder,
g: Union[NativeFunction, NativeFunctionsViewGroup],
functional_op: Optional[NativeFunction],
# Note: Ideally this code should never have to look at NativeFunction
# (and instead only need to operate on grouped NativeFunctions).
# The only reason currently is because we need to emit direct dispatch registrations
# for CompositeImplicitAutograd operators, which are potentially ungrouped.
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
) -> List[str]:
# Don't generate kernels in mobile build
if not selector.include_all_operators:
@ -576,9 +745,24 @@ def gen_functionalization_definition(
if g.view_inplace is not None:
view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
return view_defs
elif isinstance(g, NativeFunction):
# Invariant: all mutable operators that we need to handle in functionalization
# should have been properly grouped up.
# TODO: The below ops all have "problematic" schemas that prevent them from
# getting functionalized. Instead of bending over backwards to get things to work,
# I think we should either:
# (1) fix their schemas (BC-breaking)
# (2) hand-write their functionalization kernels
if str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
return []
else:
# Case 2: emit inplace -> out-of-place kernels for the functionalization pass
f = g
if modifies_arguments(f):
return [emit_inplace_functionalization_body(f, functional_op)]
mutation_defs = []
mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
if g.inplace is not None:
mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
if g.mutable is not None:
mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
return mutation_defs
return []
@ -13,7 +13,6 @@ from typing import (
Callable,
Iterable,
Iterator,
Tuple,
Type,
)
from torchgen.api.types import BaseCppType
@ -27,7 +26,6 @@ from torchgen.gen import (
from torchgen.api.lazy import setValueT

from torchgen.model import (
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
@ -331,39 +329,18 @@ def run_gen_lazy_tensor(
def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]],
xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
*,
codegenInplaceVariant: bool = False,
) -> Iterator[str]:
"""
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
only code-gen additional entries for the inplace variant for the native functions.
Note: If xs is not sorted, there may be an edge case when generating IR classes. Considering relu and relu_, if
we encounter relu_ before relu, we will then generate an IR class with op = at::aten::relu_ for both relu and
relu_ which will cause problems for relu.
TODO(alanwaketan): Once all ops are grouped properly, we should no longer need this hack.
"""
generated = set()

def gen_key(func: FunctionSchema) -> Tuple[str, str]:
# we want to generate unique entries for overloads of functional variants,
# but not for inplace variants unless explicitly told `codegenInplaceVariant`
return (func.name.name.base, func.name.overload_name)

for x in xs:
f = x.functional if isinstance(x, NativeFunctionsGroup) else x
# For the 'or'd terms:
# 1. codegenInplaceVariant means we can generate the in-place variant corresponding items.
# 2. not f.func.name.name.inplace means the op is not an in-place variant, so we can generate the item.
# 3. f.func.name.name.base not in generated means even for in-place ops we still need to generate the item
# as if they were the functional variants for one time.
if f.func.name in full_codegen and (
codegenInplaceVariant
or not f.func.name.name.inplace
or gen_key(f.func) not in generated
):
generated.add(gen_key(f.func))
for r in func(f):
yield r
fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
for f in fs:
if f.func.name in full_codegen:
for r in func(f):
yield r

selector = SelectiveBuilder.get_nop_selector()

@ -402,7 +379,6 @@ def run_gen_lazy_tensor(
backend_indices[backend_key], tensor_class
),
grouped_native_functions,
codegenInplaceVariant=True,
)
)

@ -495,7 +471,6 @@ def run_gen_lazy_tensor(
get_device_fn,
),
grouped_native_functions,
codegenInplaceVariant=True,
)
),
},

@ -3,6 +3,7 @@ import re
from torchgen.utils import assert_never

from dataclasses import dataclass
import dataclasses
from typing import List, Dict, Optional, Iterator, Tuple, Set, Sequence, Callable, Union
from enum import Enum, auto
import itertools
@ -296,7 +297,9 @@ class DeviceCheckType(Enum):
ExactSame = 1


ViewSchemaKind = Enum("ViewSchemaKind", ("aliasing", "inplace", "out", "non_aliasing"))
ViewSchemaKind = Enum(
"ViewSchemaKind", ("aliasing", "aliasing_inplace", "non_aliasing")
)

# The basic input to the code generation is native_functions.yaml.
# The name "native", BTW, comes from the distinction between native
@ -354,6 +357,14 @@ class NativeFunction:
# defined. This is for conveniently reporting error messages!
loc: "Location"

# A list of operators that are expected to be auto-generated for this NativeFunction.
# Note: This list isn't actually directly used by the codegen to generate anything.
# Instead, the codegen figures out what operators to generate purely based off of
# function schema, and uses the autogen declarations to error check.
# We expect every NativeFunction that gets auto-generated to be explicitly called out
# in native_functions.yaml
autogen: List["OperatorName"]

# If non-empty, this kernel is subject to ufunc codegen.
# Sorted by ufunc_key
ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]
@ -582,6 +593,14 @@ class NativeFunction:
"implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
)

autogen_str = e.pop("autogen", "")
assert isinstance(autogen_str, str)
autogen = (
[]
if autogen_str == ""
else [OperatorName.parse(x) for x in autogen_str.split(", ")]
)

raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
ufunc_inner_loop = {}
if isinstance(raw_ufunc_inner_loop, str):
@ -652,6 +671,7 @@ class NativeFunction:
structured_delegate=structured_delegate,
structured_inherits=structured_inherits,
precomputed=precomputed,
autogen=autogen,
ufunc_inner_loop=ufunc_inner_loop,
manual_kernel_registration=manual_kernel_registration,
manual_cpp_binding=manual_cpp_binding,
@ -754,12 +774,10 @@ class NativeFunction:

@property
def view_schema_kind(self) -> ViewSchemaKind:
# This covers both "ordinary" inplace ops, and inplace_views
if self.func.name.name.inplace:
return ViewSchemaKind.inplace
elif self.func.is_out_fn():
return ViewSchemaKind.out
elif self.is_view_op:
if self.is_view_op and self.func.name.name.inplace:
assert "inplace_view" in self.tags
return ViewSchemaKind.aliasing_inplace
if self.is_view_op:
return ViewSchemaKind.aliasing
else:
return ViewSchemaKind.non_aliasing
@ -769,7 +787,7 @@ class NativeFunction:
return self.func.name.name.base


SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out"))
SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out", "mutable"))

# A structured kernel is guaranteed to have a functional and out variant, and
# optionally an inplace variant.
@ -781,6 +799,7 @@ SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out"))
class NativeFunctionsGroup:
functional: NativeFunction
inplace: Optional[NativeFunction]
mutable: Optional[NativeFunction]
out: NativeFunction

@property
@ -797,10 +816,19 @@ class NativeFunctionsGroup:
f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
)
assert self.functional.func.kind() == SchemaKind.functional
assert not self.functional.is_view_op, (
"View operator shouldn't be grouped into NativeFunctionsGroup objects."
f"This is likely because you tried to add an out= variant for '{f.func.name}', which is an existing view operator."
"out= variants of view operators are not valid. Please reach out to the core team if you have questions."
)
assert self.out.func.kind() == SchemaKind.out

if self.inplace is not None:
assert self.inplace.func.kind() == SchemaKind.inplace

if self.mutable is not None:
assert self.mutable.func.kind() == SchemaKind.mutable

if self.structured:
# For now, structured composite kernels are not supported (need some
# design work to figure out how to make the composite case work)
@ -813,6 +841,25 @@ class NativeFunctionsGroup:
if self.inplace is not None:
assert self.inplace.structured_delegate == self.out.func.name

generated_fns = [
str(f.func.name) for f in self.functions() if "generated" in f.tags
]
generated_fns_str = ", ".join(str(x) for x in generated_fns)
expected_generated_fns = f.autogen
expected_generated_fns_str = ", ".join(str(x) for x in expected_generated_fns)
if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
raise RuntimeError(
f"The codegen expects to be able to generate '{generated_fns_str}'."
" In order to generate them however, we expect them to be called out explicitly in the yaml."
f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
)
if expected_generated_fns_str != generated_fns_str:
raise RuntimeError(
f"The codegen expects to be able to generate '{generated_fns_str}'."
f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
f" Instead, it found 'autogen: {expected_generated_fns_str}'"
)

def signature(self) -> "FunctionSchema":
return self.out.func.signature()

@ -821,6 +868,8 @@ class NativeFunctionsGroup:
yield self.out
if self.inplace is not None:
yield self.inplace
if self.mutable is not None:
yield self.mutable

@property
def root_name(self) -> str:
@ -836,6 +885,7 @@ class NativeFunctionsGroup:
d = dict(d)  # non-destructive updates please
functional = d.pop(SchemaKind.functional, None)
inplace = d.pop(SchemaKind.inplace, None)
mutable = d.pop(SchemaKind.mutable, None)
out = d.pop(SchemaKind.out, None)
assert not d
assert functional is not None
@ -847,6 +897,7 @@ class NativeFunctionsGroup:
return NativeFunctionsGroup(
functional=functional,
inplace=inplace,
mutable=mutable,
out=out,
)

@ -1045,12 +1096,26 @@ class FunctionSchema:
assert str(r) == func, f"{str(r)} != {func}"
return r

def returns_are_aliased(self) -> bool:
# We assert earlier that schemas can't have a mix of aliased and non-aliased returns
return any(
r
for r in self.returns
if r.annotation is not None and r.annotation.is_write
)

def __post_init__(self) -> None:
for arg, ret in zip(self.arguments.out, self.returns):
assert arg.annotation == ret.annotation, (
"Out arguments must have matching return Tensor; furthermore, "
"the ith-argument needs to correspond to the ith return"
)
# We also enforce that if you have any mutable, positional args, then they are not returned.
# This makes it easier to group these functions properly with their functional/out= counterparts.
for a in self.arguments.post_self_positional_mutable:
assert not any(
a.annotation == r.annotation for r in self.returns
), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
# Invariant: we expect out arguments to appear as keyword arguments in the schema.
# This means that all mutable returns should be aliased to a keyword argument
# (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
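
A hypothetical illustration of the new invariant (schema strings as plain data and a deliberately crude textual check): a mutable post-self positional argument may be written to, but it must not also show up among the returns.

    ok_schema  = "foo(Tensor self, Tensor(a!) buffer) -> Tensor"       # buffer mutated, not returned
    bad_schema = "foo(Tensor self, Tensor(a!) buffer) -> Tensor(a!)"   # buffer mutated *and* returned

    def violates_invariant(schema: str) -> bool:
        args, rets = schema.split(" -> ")
        return "(a!)" in args and "(a!)" in rets   # the same alias set appears on both sides

    print(violates_invariant(ok_schema), violates_invariant(bad_schema))   # False True
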
@ -1063,6 +1128,19 @@ class FunctionSchema:
for ret in self.returns
if ret.annotation is not None and ret.annotation.is_write
]
immutable_returns = [
ret
for ret in self.returns
if ret.annotation is None or not ret.annotation.is_write
]
# Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
# because:
# (1) It's more annoying to handle properly
# (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
# Instead, we expect the (a!) argument to not be returned.
assert (
len(mutable_returns) == 0 or len(immutable_returns) == 0
), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
for ret in mutable_returns:
assert any([ret.annotation == arg.annotation for arg in out_and_self]), (
'All mutable returns must be aliased either to a keyword argument, or to "self". '
@ -1103,6 +1181,22 @@ class FunctionSchema:
# so in all other cases we expect the return type to be none.
assert len(self.returns) == 0

if self.arguments.tensor_options is not None:
assert self.kind() == SchemaKind.functional, (
"Found an operator that is not functional, but has tensor options arguments."
"This is not allowed- tensor options arguments are only allowed for factory functions."
f"schema: {str(self)}"
)
if self.is_functional_fn():
assert self.kind() == SchemaKind.functional, (
"Found an operator that is not functional, but its overload contains the string 'functional'."
"This is a special keyword in the codegen, please use a different overload name."
f"schema: {str(self)}"
)

def is_functional_fn(self) -> bool:
return "functional" in self.name.overload_name

def is_out_fn(self) -> bool:
# Note [is_out_fn]
#
@ -1139,44 +1233,99 @@ class FunctionSchema:
modifies the self argument inplace; an out schema writes
the result into an explicitly provided out argument.
"""
is_inplace = self.name.name.inplace
is_out = bool(self.arguments.out)
assert not (is_inplace and is_out)
is_inplace = self.name.name.inplace
is_mutable = any(
a.annotation is not None and a.annotation.is_write
for a in self.arguments.post_self_positional
)
assert not (is_out and is_inplace)
# out= and inplace schemas can also have post_self_positional mutable args,
# but we give precedence to out= and inplace when deciding the schema kind.
# Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
# to also worry about mutable post_self_positional arguments,
# but it seems like a much bigger lift to classify them as having a new schema kind.
# The number of ops that fit in this strange category is small enough that
# we can probably manually write code for them instead of forcing the codegen to handle them.
if is_inplace:
return SchemaKind.inplace
elif is_out:
return SchemaKind.out
elif is_mutable:
return SchemaKind.mutable
else:
return SchemaKind.functional
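
The precedence above, restated as a tiny standalone function with a few representative (simplified) inputs; the op names in the comments are just the obvious PyTorch examples of each kind.

    def schema_kind(is_inplace: bool, has_out: bool, has_mutable_positional: bool) -> str:
        assert not (has_out and is_inplace)
        if is_inplace:
            return "inplace"
        elif has_out:
            return "out"
        elif has_mutable_positional:
            return "mutable"
        return "functional"

    print(schema_kind(True, False, False))    # add_.Tensor -> inplace
    print(schema_kind(False, True, False))    # add.out -> out
    print(schema_kind(False, False, True))    # _fused_moving_avg_obs_fq_helper -> mutable
    print(schema_kind(False, False, False))   # add.Tensor -> functional
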
# For every return:
# - If the return aliases an input, we return the input name
# - Otherwise, we return None.
# If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
def aliased_return_names(self) -> List[Optional[str]]:
outs: List[Optional[str]] = []
for r in self.returns:
aliased_args = [
a
for a in self.arguments.flat_all
if a.annotation is not None and a.annotation == r.annotation
]
if len(aliased_args) == 0:
outs.append(None)
elif len(aliased_args) == 1:
outs.append(aliased_args[0].name)
else:
aliased_names = ", ".join(a.name for a in aliased_args)
raise AssertionError(
f"Found a return ({r.name}) that aliases multiple inputs ({aliased_names})"
)
return outs

def signature(
self, *, strip_default: bool = False, strip_view_copy_name: bool = False
self,
*,
strip_default: bool = False,
strip_view_copy_name: bool = False,
keep_return_names: bool = False,
) -> "FunctionSchema":
"""
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
factors these schemas into the "core" functional signature which
is equal across all versions.
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
factors these schemas into the "core" functional signature which
is equal across all versions.

Here is the normalization that happens to the schema to convert
it to a signature:
- The overload name is stripped (name is retained, since
it expresses semantic content about what the function does)
- Inplace is set False
- Out arguments are stripped
- Mutability annotations are stripped (this is sound
because you cannot overload on mutability annotation)
- Return names are stripped since they are not overloadable and
some variants have return names but some not
Here is the normalization that happens to the schema to convert
it to a signature:
- The overload name is stripped (name is retained, since
it expresses semantic content about what the function does)
- Inplace is set False
- Out arguments are stripped
- Mutable post_self_positional args are converted to returns
- Mutability annotations are stripped (this is sound
because you cannot overload on mutability annotation)
- Return names are stripped since they are not overloadable and
some variants have return names but some not
- TensorOptions are dropped
because out= variants of factory functions don't include them
(and we want to be able to pair up factory functions with their out variants)

Finally, we want to be able to pair up related "view" and their
corresponding "view_copy" operators. We do this by optionally
stripping the trailing "_copy" from the base name.
Finally, we want to be able to pair up related "view" and their
corresponding "view_copy" operators. We do this by optionally
stripping the trailing "_copy" from the base name.

Example of a mutable op before and after:

f.func (Mutable operator):
_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950

f.func (Corresponding functional operator):
_fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950

f.func.signature() output:
_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950
"""

def strip_ret_annotation(r: Return) -> Return:
return Return(
name=None,
name=r.name if keep_return_names else None,
type=r.type,
annotation=None,
)
@ -1185,6 +1334,43 @@ class FunctionSchema:
if strip_view_copy_name and base_name.endswith("_copy"):
base_name = base_name.replace("_copy", "")

# find mutable inputs that are not originally returned, and convert them to returns
returns_from_mutable_inputs = tuple(
# When we're grouping functions we strip the return names,
# but when we're generating the actual functional variants then we follow
# a convention for what to name the returns
Return(
name=f"{a.name}_out" if keep_return_names else None,
type=a.type,
annotation=None,
)
for a in itertools.chain(
# Order is important here (otherwise e.g. inplace with mutable args
# and out= with mutable args won't have the same signature)
[self.arguments.self_arg.argument]
if self.arguments.self_arg is not None
else [],
self.arguments.out,
self.arguments.post_self_positional,
)
if a.annotation is not None
and a.annotation.is_write
and not any(a.annotation == r.annotation for r in self.returns)
)
original_returns = tuple(map(strip_ret_annotation, self.returns))
# Ordering is important here. We expect the "mutable input" returns to come last.
returns = original_returns + returns_from_mutable_inputs

args_sig = self.arguments.signature(strip_default=strip_default)
# See Note [arange.start_step schema]
if str(self.name) == "arange.start_step":
args_sig = Arguments.parse(
str(args_sig).replace("Scalar step", "Scalar step=1")
)
# See Note [bernoulli.p schema]
if str(self.name) == "bernoulli.p":
args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
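
Before the schema is reassembled below, the net effect of the conversion above on names can be sketched with plain strings (the argument and return names come from the `_fused_moving_avg_obs_fq_helper` example in the docstring; the real code works on `Argument`/`Return` objects):

    mutable_inputs_not_returned = ["running_min", "running_max", "scale", "zero_point"]
    original_returns = ["output", "mask"]
    keep_return_names = True

    extra_returns = [f"{a}_out" if keep_return_names else None for a in mutable_inputs_not_returned]
    print(original_returns + extra_returns)
    # ['output', 'mask', 'running_min_out', 'running_max_out', 'scale_out', 'zero_point_out']
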
return FunctionSchema(
|
||||
name=OperatorName(
|
||||
name=BaseOperatorName(
|
||||
@ -1194,16 +1380,23 @@ class FunctionSchema:
|
||||
),
|
||||
overload_name="", # stripped
|
||||
),
|
||||
arguments=self.arguments.signature(strip_default=strip_default),
|
||||
returns=tuple(map(strip_ret_annotation, self.returns)),
|
||||
arguments=args_sig,
|
||||
returns=returns,
|
||||
)
|
||||
|
||||
def view_signature(self) -> "FunctionSchema":
|
||||
return self.signature(strip_view_copy_name=True)
|
||||
|
||||
def with_name(self, name: "OperatorName") -> "FunctionSchema":
|
||||
return FunctionSchema(
|
||||
name=name,
|
||||
arguments=self.arguments,
|
||||
returns=self.returns,
|
||||
)
|
||||
|
||||
@property
|
||||
def modifies_arguments(self) -> bool:
|
||||
return self.kind() in [SchemaKind.inplace, SchemaKind.out]
|
||||
return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
|
||||
|
||||
def __str__(self) -> str:
|
||||
all_arguments_str = str(self.arguments)
|
||||
@ -1587,6 +1780,10 @@ class Arguments:
|
||||
ret.extend(self.post_self_positional)
|
||||
return ret
|
||||
|
||||
@property
|
||||
def post_self_positional_mutable(self) -> Sequence[Argument]:
|
||||
return [a for a in self.post_self_positional if a.is_write]
|
||||
|
||||
# NB: doesn't contain out arguments
|
||||
@property
|
||||
def flat_kwarg_only(self) -> Sequence[Argument]:
|
||||
@ -1640,6 +1837,13 @@ class Arguments:
        ret.extend(self.out)
        return ret

    def mutable_arg_names(self) -> List[str]:
        return [
            a.name
            for a in self.flat_all
            if a.annotation is not None and a.annotation.is_write
        ]

    def signature(self, *, strip_default: bool = False) -> "Arguments":
        # dataclasses.replace could be used here, but it is less
        # type safe so for now I've opted to type everything out
@ -1661,18 +1865,36 @@ class Arguments:
            post_self_positional=tuple(
                map(strip_arg_annotation, self.post_self_positional)
            ),
            # Since TensorOptions are dropped, the post_tensor_options_kwargs are
            # converted to pre_tensor_options_kwargs
            pre_tensor_options_kwarg_only=tuple(
                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
            ),
            # NB: tensor_options guaranteed to not have any alias annotations
            tensor_options=self.tensor_options,
            post_tensor_options_kwarg_only=tuple(
                map(strip_arg_annotation, self.post_tensor_options_kwarg_only)
            ),
            )
            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
            # TensorOptions are dropped in signature,
            # so we can pair factory functions with their out= variants.
            tensor_options=None,
            post_tensor_options_kwarg_only=tuple(),
            # out arguments are dropped in signature
            out=(),
        )

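Dropping TensorOptions and out= arguments is what lets a factory op pair up with its out= overload. A minimal sketch (the arange schema strings below are copied by hand and only illustrative; the match between the two stripped schemas is what the code above is expected to give):

from torchgen.model import FunctionSchema

functional = FunctionSchema.parse(
    "arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, "
    "Device? device=None, bool? pin_memory=None) -> Tensor"
)
out = FunctionSchema.parse("arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)")
# Both should collapse to the same stripped schema (roughly "arange(Scalar end) -> Tensor"),
# which is what the pre-grouping step keys on.
print(functional.signature())
print(out.signature())
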
    def remove_self_annotation(self) -> "Arguments":
        assert self.self_arg is not None
        return dataclasses.replace(
            self,
            self_arg=SelfArgument(
                dataclasses.replace(self.self_arg.argument, annotation=None)
            ),
        )

    def with_out_args(self, outs: List[Argument]) -> "Arguments":
        assert len(self.out) == 0
        return dataclasses.replace(
            self,
            out=tuple(outs),
        )

    @staticmethod
    def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
        positional: List[Argument] = []
@ -1806,6 +2028,17 @@ class Arguments:
        if self.tensor_options is None:
            assert not self.post_tensor_options_kwarg_only

        # We don't allow any of the following to have argument annotations,
        # to keep things simple.
        mutable_pre_self_positionals = [
            a
            for a in self.pre_self_positional
            if a.annotation is not None and a.annotation.is_write
        ]
        assert (
            len(mutable_pre_self_positionals) == 0
        ), "mutable pre_self_positional arguments are not currently supported in the schema"


# Names that validly are __iXXX__ indicating inplace operations.
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
@ -1923,6 +2156,16 @@ class OperatorName:
            overload_name=self.overload_name,
        )

    def with_overload(self, overload: str) -> "OperatorName":
        return OperatorName(
            name=BaseOperatorName(
                base=self.name.base,
                inplace=False,
                dunder_method=self.name.dunder_method,
            ),
            overload_name=overload,
        )


def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
382 torchgen/native_function_generation.py Normal file
@ -0,0 +1,382 @@
from torchgen.model import (
    Argument,
    DispatchKey,
    FunctionSchema,
    BaseType,
    BaseTy,
    Return,
    Annotation,
    NativeFunction,
    OperatorName,
    BackendIndex,
    BackendMetadata,
    DeviceCheckType,
    SchemaKind,
    Variant,
)
from torchgen.utils import (
    concatMap,
)


from typing import List, Tuple, Sequence, Dict
from collections import defaultdict

# See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
    # This has a functional variant, but it's currently marked private.
    # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
    "adaptive_avg_pool3d_backward.grad_input",
    # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
    # Maybe we can kill this operator in favor of convolution_backward?
    "_slow_conv2d_backward.grad_input",
]


# See Note: [Mutable ops that cannot get an out variant]
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
    # should be out=?
    "_cummax_helper",
    # should be out=?
    "_cummin_helper",
]

INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
    # polygamma and polygamma.out both exist, but have a
    # pre-self arg (while polygamma_ does not)
    # We should either fix this schema so it can be grouped properly,
    # or allow the codegen to generate new functional/out= NativeFunctions for this op
    # (which would require changing its overload name to prevent overload ambiguity).
    "polygamma_"
]

# Groups "similar" NativeFunctions together,
# e.g. add.Tensor, add_.Tensor, add.out.
# "Similar" NativeFunctions are all expected to have an identical `signature()`,
# but have differing SchemaKinds.
def pre_group_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
    pre_grouped_native_functions: Dict[
        FunctionSchema, Dict[SchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
        assert f.func.kind() not in d
        d[f.func.kind()] = f
    return pre_grouped_native_functions


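A small usage sketch (hedged: the parse_native_yaml entry point and yaml paths follow the comment further down in this file and may differ by checkout) showing that the pre-grouping dict is keyed by the stripped signature and indexed by SchemaKind:

from torchgen.gen import parse_native_yaml
from torchgen.native_function_generation import pre_group_native_functions

parsed = parse_native_yaml(
    "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml"
)
grouped = pre_group_native_functions(parsed.native_functions)
for sig, by_kind in list(grouped.items())[:3]:
    # e.g. add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor  ->  ['functional', 'inplace', 'out']
    print(sig, sorted(k.name for k in by_kind))
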
# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
# Example before:
# _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
# Example after:
# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from an inplace schema.
    assert func.kind() == SchemaKind.inplace
    assert func.arguments.self_arg is not None
    # The new out= schema has:
    # - a new out argument with the same type as "func" (but with a mutable annotation)
    # - The returns (if any) now alias the out= argument instead of "func"
    # - an "out" overload name
    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
        ),
        arguments=func.arguments.remove_self_annotation().with_out_args(
            [
                Argument(
                    name="out",
                    type=func.arguments.self_arg.argument.type,
                    default=None,
                    annotation=func.arguments.self_arg.argument.annotation,
                )
            ]
        ),
        returns=func.returns,
    )


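For example, feeding the inplace schema from the comment above through this helper should round-trip to the out= form (a hedged sketch; FunctionSchema.parse and the exact printed string are assumptions):

from torchgen.model import FunctionSchema
from torchgen.native_function_generation import self_to_out_signature

inplace = FunctionSchema.parse(
    "_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"
)
# Expected, per the example above:
# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
print(self_to_out_signature(inplace))
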
# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
# Example before:
# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
# Example after:
# _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from a mutable schema.
    assert func.kind() == SchemaKind.mutable
    # The new out= schema has:
    # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
    #   (if the argument is a tensor then we also return it for method chaining,
    #   otherwise we return nothing)
    # - an "out" overload name
    #
    # Note that:
    # (1) This also means that we can *only* generate an out= variant from a mutable schema
    #     if the mutable schema has at least one tensor-like non-aliasing return.
    # (2) The generated out= variant still has mutable positional arguments,
    #     but if necessary we could probably add another out= variant that also
    #     functionalizes the mutable arguments (a functional_out variant)

    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    used_annotations = concatMap(
        lambda a: [] if a.annotation is None else a.annotation.alias_set,
        func.arguments.flat_all,
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for (i, r) in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                name=f"out{i}",
                type=r.type,
                default=None,
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(
                    name=None, type=new_out.type, annotation=new_out.annotation
                )
                new_returns.append(new_ret)
        else:
            new_returns.append(r)

    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
        ),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )


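The alias-letter bookkeeping above can be seen in isolation (plain Python, values illustrative): letters already claimed by the schema's arguments are skipped, so the generated out arguments pick up the next free annotations.

# e.g. running_min, running_max, scale, zero_point already use (a!)..(d!)
used_annotations = {"a", "b", "c", "d"}
valid_annotations = [x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations]
print(valid_annotations[:2])  # ['e', 'f'] -> out0 becomes Tensor(e!), out1 becomes Tensor(f!)
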
# This function, given a function of one SchemaKind, as well as a target SchemaKind,
# generates a new NativeFunction with the same properties, but using the target SchemaKind.
# We only actually generate functions for either functional or out= SchemaKinds.
# This function returns a tuple, with:
# - The generated NativeFunction
# - a dictionary of `BackendIndex` objects, describing which dispatch keys
#   we will generate kernels for, for the new NativeFunction.
# Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        gets_composite_kernel = True
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - a "functional" overload name.
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        func = f.func.signature(keep_return_names=True).with_name(
            f.func.name.remove_inplace().with_overload(
                "functional"
                if not f.func.name.overload_name
                else f"{f.func.name.overload_name}_functional"
            )
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        gets_composite_kernel = False
        if f.func.kind() == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    if gets_composite_kernel:
        backend_metadata = {
            DispatchKey.CompositeExplicitAutograd: {
                func.name: BackendMetadata(cpp.name(func), structured=False)
            }
        }
    else:
        backend_metadata = {}

    return (
        NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
            # These generated fn's aren't meant to be user friendly - don't generate methods.
            variants=set([Variant.function]),
            structured=False,
            structured_delegate=None,
            structured_inherits=None,
            precomputed=None,
            autogen=[],
            ufunc_inner_loop={},
            manual_kernel_registration=False,
            manual_cpp_binding=False,
            python_module=None,
            category_override=None,
            device_guard=False,
            device_check=DeviceCheckType.NoCheck,
            loc=f.loc,
            cpp_no_default_args=set(),
            is_abstract=f.is_abstract,
            has_composite_implicit_autograd_kernel=False,
            has_composite_explicit_autograd_kernel=gets_composite_kernel,
            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
            # which NativeFunction objects did not come directly from native_functions.yaml.
            tags=set(["generated"]),
        ),
        backend_metadata,
    )

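A hedged usage sketch (the parse_native_yaml import path is an assumption; any inplace or mutable NativeFunction works as input) of generating the out= counterpart for a single op:

from torchgen.gen import parse_native_yaml
from torchgen.model import SchemaKind
from torchgen.native_function_generation import generate_function

parsed = parse_native_yaml(
    "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml"
)
f = next(nf for nf in parsed.native_functions if nf.func.kind() == SchemaKind.inplace)
out_fn, metadata = generate_function(f, SchemaKind.out)
print(out_fn.func)  # the auto-generated out= schema for that op
print(metadata)     # empty here, since generated out= variants get no composite kernel
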

# This function is responsible for adding generated NativeFunctions which don't appear
# explicitly in the codegen.
# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
# (Maybe we should make a friendly API for this)
#
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
    rs: List[NativeFunction],
    indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
    # The main code for generating new NativeFunctions.
    # First we group NativeFunctions by schema kind,
    # then we detect which ones are missing and generate them.
    pre_grouped_native_functions = pre_group_native_functions(rs)
    for k, d in pre_grouped_native_functions.items():
        has_functional = SchemaKind.functional in d
        has_inplace = SchemaKind.inplace in d
        has_mutable = SchemaKind.mutable in d
        has_out = SchemaKind.out in d

        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
        #     a simple functional variant that the functionalization pass can consume.
        # (2) If an operator has an inplace and functional but no out= variant, we generate an out=
        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
        #     while maintaining the constraint that the out= variant is "required".
        #
        # For now, we don't bother generating NativeFunctions for existing operators
        # that only have a functional variant.
        if has_mutable or has_inplace or has_out:

            # Don't bother generating function trios for native functions that bypass the dispatcher.
            are_manual = all(f.manual_cpp_binding for f in d.values())
            # Don't bother generating functional + out= variants for view operators
            has_view_ops = (
                has_inplace and "inplace_view" in d[SchemaKind.inplace].tags
            ) or any(f.is_view_op for f in d.values())
            # Don't generate the other variants for CompositeImplicitAutograd operators.
            # We could probably do this, but the main benefit of generating the function triplets
            # is for transforms that need them, and transforms don't need to act directly
            # on CompositeImplicitAutograd operators (since we let them decompose).
            are_composite_implicit = all(
                f.has_composite_implicit_autograd_kernel for f in d.values()
            )
            if are_manual or has_view_ops or are_composite_implicit:
                continue
            if has_out and len(d.values()) == 1:
                # Note: [Out ops with functional variants that don't get grouped properly]
                # In theory we could validly have an out= operator in native_functions.yaml
                # that has no other variants.
                # But today, all of the operators where that's the case actually do have
                # functional variants, that we are just unable to pair up properly.
                # I think banning this altogether is probably safer
                # (you can always add a functional variant yourself if you want to add a new out= operator).
                #
                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
                if (
                    str(d[SchemaKind.out].func.name)
                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
                ):
                    raise AssertionError(
                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
                    )
                continue

            # Some inplace ops have problematic schemas (that we should fix), which prevent us
            # from generating out= and functional variants
            if (
                has_inplace
                and str(d[SchemaKind.inplace].func.name)
                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
            ):
                continue

            base_fn = (
                d[SchemaKind.inplace]
                if has_inplace
                else d[SchemaKind.mutable]
                if has_mutable
                else d[SchemaKind.out]
            )

            # Note: [Mutable ops that cannot get an out variant]
            # We can only generate an out= variant if either:
            # - the original function has tensor-like returns (since we can convert them to out kwargs)
            # - or it's inplace (since we can convert `self` to an out kwarg)
            # There are only two functions that don't fit these criteria today though,
            # and they both look like they should be fixed to be out= variants,
            # so it feels safer to ban this schema altogether
            gets_out_variant = not has_out and (
                base_fn.func.kind() == SchemaKind.inplace
                or any(r.type.is_tensor_like() for r in base_fn.func.returns)
            )
            if not has_out and not gets_out_variant:
                if (
                    str(base_fn.func.name)
                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                ):
                    raise AssertionError(
                        f"""Found a mutable operator that we could not generate an out= variant for: {str(base_fn.func)}.
These operators are problematic, because we can't easily auto-generate functionalization code for them. If you really need
the operator to have this schema, then add the name of the operator to the allow-list. Otherwise, if possible,
please convert it to an inplace operator"""
                    )

            # Generate an out= variant
            if gets_out_variant:
                fn, metadata = generate_function(base_fn, SchemaKind.out)
                d[SchemaKind.out] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)

            # Generate a functional variant, but only do it if the operator got an out= variant
            # (Functional variants are only useful if we can group up the variants,
            # which we can only do if they have an out= variant)
            if not has_functional and (has_out or gets_out_variant):
                fn, metadata = generate_function(base_fn, SchemaKind.functional)
                d[SchemaKind.functional] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)
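Per the comment at the top of this function, one rough way to see the end result (hedged: the exact import path of parse_native_yaml may differ, and this assumes the generated ops are wired into the regular parsing path) is to filter on the "generated" tag that generate_function attaches:

from torchgen.gen import parse_native_yaml

parsed = parse_native_yaml(
    "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml"
)
autogenerated = [f for f in parsed.native_functions if "generated" in f.tags]
print(len(autogenerated), "auto-generated NativeFunctions")
print(sorted(str(f.func.name) for f in autogenerated)[:10])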