diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 747c5ca2fe28..dbc71e026151 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -505,6 +505,7 @@ variants: function dispatch: CPU: add_relu_ + autogen: _add_relu.Scalar_out # For C++ only, until we have conversion from C++ numbers to Tensor - func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor @@ -518,6 +519,7 @@ variants: method dispatch: CompositeExplicitAutograd: add_ + autogen: add.Scalar_out - func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor structured_delegate: addmv.out @@ -609,6 +611,12 @@ - func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +# Note [arange.start_step schema] +# We want `arange.start_step` to be grouped up with `arange.start_out`, +# But this doesn't happen automatically because the step argument +# is defaultable for .start_out but not for .start_step. +# We should probably just make "step" a defaultable param on arange.start, +# and kill arange.start_step. - func: arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) @@ -892,6 +900,7 @@ dispatch: CPU, CUDA: bernoulli_ MPS: bernoulli_mps_ + autogen: bernoulli.Tensor_functional, bernoulli.Tensor_out - func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -899,7 +908,10 @@ dispatch: CPU, CUDA: bernoulli_ MPS: bernoulli_mps_ + autogen: bernoulli.float_out +# Note [bernoulli.p schema] +# We should probably just fix the overload ambiguity by appending a _functional to the C++ API name (BC breaking) # This out-of-place version isn't used explicitly, but needed by jit. # There is no default valid on `p` here because it would introduce ambiguity # with `bernoulli(Tensor self, *, Generator? generator=None)` declaration. @@ -1420,6 +1432,7 @@ SparseCPU, SparseCUDA: copy_sparse_wrapper_ CompositeExplicitAutograd: copy_ SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_ + autogen: copy.out - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor dispatch: @@ -1780,6 +1793,7 @@ variants: method dispatch: CompositeExplicitAutograd: div_ + autogen: div.Scalar_out - func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor variants: function, method @@ -1790,6 +1804,7 @@ variants: method dispatch: CompositeExplicitAutograd: div_ + autogen: div.Scalar_mode_out # divide, alias for div - func: divide.Tensor(Tensor self, Tensor other) -> Tensor @@ -1880,6 +1895,7 @@ dispatch: CPU: embedding_renorm_cpu_ CUDA: embedding_renorm_cuda_ + autogen: embedding_renorm.functional, embedding_renorm.out - func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor @@ -1993,6 +2009,7 @@ MPS: resize_mps_ QuantizedCPU: quantized_resize_cpu_ SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_ + autogen: resize.functional, resize.out # This is a utility function to enable users to resize out tensor while registering kernels for out variants. 
# Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration @@ -2002,6 +2019,7 @@ variants: function dispatch: Meta: _resize_output_ + autogen: _resize_output.functional, _resize_output.out - func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor category_override: factory @@ -2201,6 +2219,7 @@ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ SparseCsrCPU, SparseCsrCUDA: fill_sparse_csr_ + autogen: fill.Scalar_out - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2210,6 +2229,7 @@ MPS: fill_tensor_mps_ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ + autogen: fill.Tensor_out - func: floor(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -2494,6 +2514,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: index_put_ + autogen: index_put.out # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) # - Tensor & Tensor::index_put_(ArrayRef indices, Scalar v) @@ -2511,6 +2532,7 @@ variants: function dispatch: CPU, CUDA: _index_put_impl_ + autogen: _index_put_impl.functional, _index_put_impl.out - func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor variants: function @@ -3380,6 +3402,7 @@ dispatch: CompositeExplicitAutograd: mul_ SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr + autogen: mul.Scalar_out # multiply, alias for mul - func: multiply.Tensor(Tensor self, Tensor other) -> Tensor @@ -3918,6 +3941,7 @@ MkldnnCPU: mkldnn_relu_ QuantizedCPU: relu_quantized_cpu_ NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_ + autogen: relu.out - func: relu6(Tensor self) -> Tensor python_module: nn @@ -4061,6 +4085,7 @@ device_check: NoCheck # TensorIterator dispatch: CompositeExplicitAutograd: celu_ + autogen: celu.out - func: silu(Tensor self) -> Tensor structured_delegate: silu.out @@ -4800,6 +4825,7 @@ device_guard: False dispatch: MkldnnCPU: mkldnn_transpose_ + autogen: _mkldnn_transpose.out - func: one_hot(Tensor self, int num_classes=-1) -> Tensor python_module: nn @@ -5312,6 +5338,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: resize_as_ + autogen: resize_as.functional, resize_as.out - func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) use_const_ref_for_mutable_tensors: True @@ -5319,6 +5346,7 @@ dispatch: SparseCPU, SparseCUDA: resize_as_sparse_ SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_ + autogen: resize_as_sparse.functional, resize_as_sparse.out - func: zero_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5330,6 +5358,7 @@ SparseCPU, SparseCUDA: zero_sparse_ SparseCsrCPU, SparseCsrCUDA: zero_sparse_csr_ MkldnnCPU: mkldnn_zero_ + autogen: zero.functional, zero.out - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5367,6 +5396,7 @@ variants: method dispatch: CompositeExplicitAutograd: sub_ + autogen: sub.Scalar_out # subtract, alias for sub - func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
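The `autogen:` entries being added throughout this file ask the code generator to synthesize the variants an in-place-only operator is missing: ops that already have a functional counterpart (e.g. `add_.Scalar`, `masked_fill_.Scalar`) only get an out= overload, while ops with no functional form at all (e.g. `zero_`, `resize_`, `random_.from`) get both a `.functional` and a `.out` variant. A minimal sketch of the naming convention visible in these entries, written as a standalone illustration rather than torchgen's actual implementation:

    def autogen_variants(inplace_op: str, has_functional: bool = True) -> list:
        """Given an in-place op name like 'zero_' or 'add_.Scalar', return the
        variant names the codegen is asked to synthesize (illustration only)."""
        name, _, overload = inplace_op.partition(".")
        assert name.endswith("_"), "expected an in-place op such as 'zero_' or 'add_.Scalar'"
        base = name[:-1]                                   # 'add_' -> 'add'
        prefix = overload + "_" if overload else ""
        out = f"{base}.{prefix}out"
        if has_functional:
            return [out]                                   # only the out= variant is missing
        return [f"{base}.{prefix}functional", out]

    print(autogen_variants("add_.Scalar"))                      # ['add.Scalar_out']
    print(autogen_variants("zero_", has_functional=False))      # ['zero.functional', 'zero.out']
    print(autogen_variants("random_.from", has_functional=False))
    # ['random.from_functional', 'random.from_out']
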
@@ -5629,12 +5659,14 @@ variants: method dispatch: SparseCPU, SparseCUDA: sparse_resize_ + autogen: sparse_resize.functional, sparse_resize.out - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: SparseCPU, SparseCUDA: sparse_resize_and_clear_ + autogen: sparse_resize_and_clear.functional, sparse_resize_and_clear.out - func: sparse_mask(Tensor self, Tensor mask) -> Tensor variants: method @@ -5740,6 +5772,7 @@ SparseCPU, SparseCUDA: _coalesced_sparse_ device_check: NoCheck device_guard: False + autogen: _coalesced.functional, _coalesced.out - func: indices(Tensor(a) self) -> Tensor(a) variants: method @@ -5799,6 +5832,7 @@ variants: function dispatch: SparseCPU, SparseCUDA: copy_sparse_ + autogen: copy_sparse_to_sparse.functional, copy_sparse_to_sparse.out - func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] variants: function, method @@ -6007,7 +6041,7 @@ dispatch: CPU: fused_moving_avg_obs_fake_quant_cpu CUDA: fused_moving_avg_obs_fake_quant_cuda - + autogen: _fused_moving_avg_obs_fq_helper.functional, _fused_moving_avg_obs_fq_helper.out - func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) variants: function @@ -6201,6 +6235,7 @@ device_guard: False dispatch: CPU, CUDA, Meta, MPS: set_ + autogen: set.source_Storage_functional, set.source_Storage_out - func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) variants: method @@ -6211,6 +6246,7 @@ CUDA: set_storage_cuda_ MPS: set_storage_mps_ QuantizedCPU, QuantizedCUDA: set_storage_quantized_ + autogen: set.source_Storage_storage_offset_functional, set.source_Storage_storage_offset_out - func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) variants: method @@ -6223,6 +6259,7 @@ device_guard: False dispatch: CPU, CUDA, Meta, MPS: set_tensor_ + autogen: set.source_Tensor_functional, set.source_Tensor_out - func: set_(Tensor(a!) self) -> Tensor(a!) 
variants: method @@ -6231,6 +6268,7 @@ CUDA: set_cuda_ Meta: set_meta_ MPS: set_mps_ + autogen: set.functional, set.out - func: lift(Tensor self) -> Tensor variants: method @@ -6253,6 +6291,7 @@ CPU: masked_fill__cpu CUDA: masked_fill__cuda MPS: masked_fill__mps + autogen: masked_fill.Scalar_out - func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor device_check: NoCheck # TensorIterator @@ -6267,6 +6306,7 @@ CPU: masked_fill__cpu CUDA: masked_fill__cuda MPS: masked_fill__mps + autogen: masked_fill.Tensor_out - func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor device_check: NoCheck # TensorIterator @@ -6279,6 +6319,7 @@ dispatch: CPU: masked_scatter__cpu CUDA: masked_scatter__cuda + autogen: masked_scatter.out - func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor variants: function, method @@ -6320,6 +6361,7 @@ variants: method dispatch: CPU, CUDA, MPS: put_ + autogen: put.out - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor variants: function, method @@ -6367,6 +6409,7 @@ dispatch: CPU: index_fill_ CUDA: index_fill_ + autogen: index_fill.int_Scalar_out - func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor device_check: NoCheck # TensorIterator @@ -6379,6 +6422,7 @@ variants: method dispatch: CPU, CUDA: index_fill_ + autogen: index_fill.int_Tensor_out - func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor device_check: NoCheck # TensorIterator @@ -6695,12 +6739,14 @@ variants: method dispatch: CPU, CUDA: __ilshift__ + autogen: __lshift__.Scalar_out - func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: __ilshift__ + autogen: __lshift__.Tensor_out - func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -6760,12 +6806,14 @@ variants: method dispatch: CPU, CUDA: __irshift__ + autogen: __rshift__.Scalar_out - func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: __irshift__ + autogen: __rshift__.Tensor_out - func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -6855,6 +6903,7 @@ CPU, CUDA: random_ Meta: random_meta_ MPS: random_mps_ + autogen: random.from_functional, random.from_out - func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6863,6 +6912,7 @@ CPU, CUDA: random_ Meta: random_meta_ MPS: random_mps_ + autogen: random.to_functional, random.to_out - func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6870,6 +6920,7 @@ dispatch: CPU, CUDA: random_ Meta: random_meta_ + autogen: random.functional, random.out - func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6878,24 +6929,28 @@ CPU, CUDA: uniform_ MPS: uniform_mps_ Meta: uniform_meta_ + autogen: uniform.functional, uniform.out - func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: cauchy_ + autogen: cauchy.functional, cauchy.out - func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? 
generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: log_normal_ + autogen: log_normal.functional, log_normal.out - func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: exponential_ + autogen: exponential.functional, exponential.out - func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6904,6 +6959,7 @@ CPU, CUDA: geometric_ # wrappers for TH functions + autogen: geometric.functional, geometric.out - func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -8306,6 +8362,7 @@ MPS: normal_mps_ Meta: normal_meta_ SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_ + autogen: normal.functional, normal.out - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -8356,11 +8413,13 @@ variants: function dispatch: CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ + autogen: _amp_foreach_non_finite_check_and_unscale.functional, _amp_foreach_non_finite_check_and_unscale.out - func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) variants: function dispatch: CUDA: _amp_update_scale_cuda_ + autogen: _amp_update_scale.functional, _amp_update_scale.out #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor #dispatch: @@ -8388,6 +8447,7 @@ dispatch: CPU: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ + autogen: _foreach_add.Scalar_functional, _foreach_add.Scalar_out - func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8402,6 +8462,7 @@ dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ + autogen: _foreach_sub.Scalar_functional, _foreach_sub.Scalar_out - func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8416,6 +8477,7 @@ dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ + autogen: _foreach_mul.Scalar_functional, _foreach_mul.Scalar_out - func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8430,6 +8492,7 @@ dispatch: CPU: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ + autogen: _foreach_div.Scalar_functional, _foreach_div.Scalar_out - func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8444,6 +8507,7 @@ dispatch: CPU: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ + autogen: _foreach_add.List_functional, _foreach_add.List_out - func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8458,6 +8522,7 @@ dispatch: CPU: foreach_tensor_sub_list_kernel_slow_ CUDA: 
foreach_tensor_sub_list_kernel_cuda_ + autogen: _foreach_sub.List_functional, _foreach_sub.List_out - func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8472,6 +8537,7 @@ dispatch: CPU: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ + autogen: _foreach_mul.List_functional, _foreach_mul.List_out - func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8486,6 +8552,7 @@ dispatch: CPU: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ + autogen: _foreach_div.List_functional, _foreach_div.List_out - func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8500,6 +8567,7 @@ dispatch: CPU: foreach_tensor_add_scalarlist_kernel_slow_ CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ + autogen: _foreach_add.ScalarList_functional, _foreach_add.ScalarList_out - func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8514,6 +8582,7 @@ dispatch: CPU: foreach_tensor_sub_scalarlist_kernel_slow_ CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ + autogen: _foreach_sub.ScalarList_functional, _foreach_sub.ScalarList_out - func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8528,6 +8597,7 @@ dispatch: CPU: foreach_tensor_div_scalarlist_kernel_slow_ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ + autogen: _foreach_div.ScalarList_functional, _foreach_div.ScalarList_out - func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8542,6 +8612,7 @@ dispatch: CPU: foreach_tensor_mul_scalarlist_kernel_slow_ CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ + autogen: _foreach_mul.ScalarList_functional, _foreach_mul.ScalarList_out - func: _foreach_exp(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8556,6 +8627,7 @@ dispatch: CPU: foreach_tensor_zero_slow_ CUDA: foreach_tensor_zero_cuda_ + autogen: _foreach_zero.functional, _foreach_zero.out - func: _foreach_exp_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8563,6 +8635,7 @@ dispatch: CPU: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ + autogen: _foreach_exp.functional, _foreach_exp.out - func: _foreach_sqrt(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8577,6 +8650,7 @@ dispatch: CPU: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ + autogen: _foreach_sqrt.functional, _foreach_sqrt.out - func: _foreach_abs(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8591,6 +8665,7 @@ dispatch: CPU: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ + autogen: _foreach_abs.functional, _foreach_abs.out - func: 
_foreach_acos(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8605,6 +8680,7 @@ dispatch: CPU: foreach_tensor_acos_slow_ CUDA: foreach_tensor_acos_cuda_ + autogen: _foreach_acos.functional, _foreach_acos.out - func: _foreach_asin(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8619,6 +8695,7 @@ dispatch: CPU: foreach_tensor_asin_slow_ CUDA: foreach_tensor_asin_cuda_ + autogen: _foreach_asin.functional, _foreach_asin.out - func: _foreach_atan(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8633,6 +8710,7 @@ dispatch: CPU: foreach_tensor_atan_slow_ CUDA: foreach_tensor_atan_cuda_ + autogen: _foreach_atan.functional, _foreach_atan.out - func: _foreach_ceil(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8647,6 +8725,7 @@ dispatch: CPU: foreach_tensor_ceil_slow_ CUDA: foreach_tensor_ceil_cuda_ + autogen: _foreach_ceil.functional, _foreach_ceil.out - func: _foreach_cos(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8661,6 +8740,7 @@ dispatch: CPU: foreach_tensor_cos_slow_ CUDA: foreach_tensor_cos_cuda_ + autogen: _foreach_cos.functional, _foreach_cos.out - func: _foreach_cosh(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8675,6 +8755,7 @@ dispatch: CPU: foreach_tensor_cosh_slow_ CUDA: foreach_tensor_cosh_cuda_ + autogen: _foreach_cosh.functional, _foreach_cosh.out - func: _foreach_erf(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8689,6 +8770,7 @@ dispatch: CPU: foreach_tensor_erf_slow_ CUDA: foreach_tensor_erf_cuda_ + autogen: _foreach_erf.functional, _foreach_erf.out - func: _foreach_erfc(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8703,6 +8785,7 @@ dispatch: CPU: foreach_tensor_erfc_slow_ CUDA: foreach_tensor_erfc_cuda_ + autogen: _foreach_erfc.functional, _foreach_erfc.out - func: _foreach_expm1(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8717,6 +8800,7 @@ dispatch: CPU: foreach_tensor_expm1_slow_ CUDA: foreach_tensor_expm1_cuda_ + autogen: _foreach_expm1.functional, _foreach_expm1.out - func: _foreach_floor(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8731,6 +8815,7 @@ dispatch: CPU: foreach_tensor_floor_slow_ CUDA: foreach_tensor_floor_cuda_ + autogen: _foreach_floor.functional, _foreach_floor.out - func: _foreach_log(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8745,6 +8830,7 @@ dispatch: CPU: foreach_tensor_log_slow_ CUDA: foreach_tensor_log_cuda_ + autogen: _foreach_log.functional, _foreach_log.out - func: _foreach_log10(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8759,6 +8845,7 @@ dispatch: CPU: foreach_tensor_log10_slow_ CUDA: 
foreach_tensor_log10_cuda_ + autogen: _foreach_log10.functional, _foreach_log10.out - func: _foreach_log1p(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8773,6 +8860,7 @@ dispatch: CPU: foreach_tensor_log1p_slow_ CUDA: foreach_tensor_log1p_cuda_ + autogen: _foreach_log1p.functional, _foreach_log1p.out - func: _foreach_log2(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8787,6 +8875,7 @@ dispatch: CPU: foreach_tensor_log2_slow_ CUDA: foreach_tensor_log2_cuda_ + autogen: _foreach_log2.functional, _foreach_log2.out - func: _foreach_neg(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8801,6 +8890,7 @@ dispatch: CPU: foreach_tensor_neg_slow_ CUDA: foreach_tensor_neg_cuda_ + autogen: _foreach_neg.functional, _foreach_neg.out - func: _foreach_tan(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8815,6 +8905,7 @@ dispatch: CPU: foreach_tensor_tan_slow_ CUDA: foreach_tensor_tan_cuda_ + autogen: _foreach_tan.functional, _foreach_tan.out - func: _foreach_tanh(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8829,6 +8920,7 @@ dispatch: CPU: foreach_tensor_tanh_slow_ CUDA: foreach_tensor_tanh_cuda_ + autogen: _foreach_tanh.functional, _foreach_tanh.out - func: _foreach_sin(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8843,6 +8935,7 @@ dispatch: CPU: foreach_tensor_sin_slow_ CUDA: foreach_tensor_sin_cuda_ + autogen: _foreach_sin.functional, _foreach_sin.out - func: _foreach_sinh(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8857,6 +8950,7 @@ dispatch: CPU: foreach_tensor_sinh_slow_ CUDA: foreach_tensor_sinh_cuda_ + autogen: _foreach_sinh.functional, _foreach_sinh.out - func: _foreach_round(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8871,6 +8965,7 @@ dispatch: CPU: foreach_tensor_round_slow_ CUDA: foreach_tensor_round_cuda_ + autogen: _foreach_round.functional, _foreach_round.out - func: _foreach_lgamma(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8885,6 +8980,7 @@ dispatch: CPU: foreach_tensor_lgamma_slow_ CUDA: foreach_tensor_lgamma_cuda_ + autogen: _foreach_lgamma.functional, _foreach_lgamma.out - func: _foreach_frac(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8899,6 +8995,7 @@ dispatch: CPU: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ + autogen: _foreach_frac.functional, _foreach_frac.out - func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8913,6 +9010,7 @@ dispatch: CPU: foreach_tensor_reciprocal_slow_ CUDA: foreach_tensor_reciprocal_cuda_ + autogen: _foreach_reciprocal.functional, _foreach_reciprocal.out - func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow 
path when tensor are on different devices @@ -8927,6 +9025,7 @@ dispatch: CPU: foreach_tensor_sigmoid_slow_ CUDA: foreach_tensor_sigmoid_cuda_ + autogen: _foreach_sigmoid.functional, _foreach_sigmoid.out - func: _foreach_trunc(Tensor[] tensors) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8941,6 +9040,7 @@ dispatch: CPU: foreach_tensor_trunc_slow_ CUDA: foreach_tensor_trunc_cuda_ + autogen: _foreach_trunc.functional, _foreach_trunc.out - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8948,6 +9048,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalar_slow_ CUDA: foreach_tensor_addcdiv_scalar_cuda_ + autogen: _foreach_addcdiv.Scalar_functional, _foreach_addcdiv.Scalar_out - func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8955,6 +9056,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ + autogen: _foreach_addcmul.Scalar_functional, _foreach_addcmul.Scalar_out - func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8962,6 +9064,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow_ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ + autogen: _foreach_addcdiv.ScalarList_functional, _foreach_addcdiv.ScalarList_out - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8969,6 +9072,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow_ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ + autogen: _foreach_addcmul.ScalarList_functional, _foreach_addcmul.ScalarList_out - func: _foreach_addcdiv.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11507,7 +11611,7 @@ python_module: linalg variants: function -- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!) +- func: linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!) 
   python_module: linalg
   dispatch:
     CPU, CUDA: linalg_eigvalsh_out
@@ -11528,6 +11632,7 @@
   dispatch:
     CPU: _linalg_inv_out_helper_cpu
     CUDA: _linalg_inv_out_helper_cuda
+  autogen: _linalg_inv_out_helper.functional, _linalg_inv_out_helper.out

 - func: linalg_inv_ex(Tensor self, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
   python_module: linalg
diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml
index 80febbd039fc..ba05aca4248e 100644
--- a/aten/src/ATen/native/ts_native_functions.yaml
+++ b/aten/src/ATen/native/ts_native_functions.yaml
@@ -19,7 +19,7 @@ full_codegen:
   - avg_pool2d_backward
   - baddbmm
   - bernoulli
-  - bernoulli_.float
+  - bernoulli.p
   - binary_cross_entropy
   - binary_cross_entropy_backward
   - bitwise_and.Tensor
@@ -72,8 +72,8 @@ full_codegen:
   - log_sigmoid_forward
   - lt.Scalar
   - lt.Tensor
-  - masked_fill_.Scalar
-  - masked_fill_.Tensor
+  - masked_fill.Scalar
+  - masked_fill.Tensor
   - max
   - max.dim
   - max_pool2d_with_indices
@@ -101,12 +101,11 @@ full_codegen:
   - norm.ScalarOpt_dim
   - pow.Tensor_Scalar
   - pow.Tensor_Tensor
-  - random_
-  - random_.from
-  - random_.to
+  - random.functional
+  - random.from_functional
+  - random.to_functional
   - reciprocal
   - relu
-  - relu_
   - remainder.Tensor
   - repeat
   - rsqrt
@@ -141,7 +140,7 @@ full_codegen:
   - upsample_bilinear2d_backward
   - upsample_nearest2d
   - upsample_nearest2d_backward
-  - zero_
+  - zero.functional
   - narrow_copy.SymInt
 supported:
   - as_strided
diff --git a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp
index 558802a7b7e8..30c2ac2c44d7 100644
--- a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp
+++ b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp
@@ -13,8 +13,25 @@ $ops_headers

 namespace at {
 namespace native {

+// This file contains a number of kernels for aten functions that are fully code-generated.
+// TODO: rename this file to something more generic.
+
+at::Tensor clone_arg(const at::Tensor& t) {
+  return t.clone();
+}
+
+std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
+  std::vector<at::Tensor> out(t_list.size());
+  for (const auto& i : c10::irange(t_list.size())) {
+    out[i] = t_list[i].clone();
+  }
+  return out;
+}
+
 ${CompositeViewCopyKernel_Definitions}

+${GeneratedCompositeFunctional_Definitions}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/test/atest.cpp b/aten/src/ATen/test/atest.cpp
index 6ea874fdb4ad..122fa122a58a 100644
--- a/aten/src/ATen/test/atest.cpp
+++ b/aten/src/ATen/test/atest.cpp
@@ -185,7 +185,9 @@ TEST_F(atest, ne_operators) {

 TEST_F(atest, add_operators) {
   auto exp_tensor = tensor({-10, 1, 0, -1, 10});
-  run_binary_ops_test(add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
+  run_binary_ops_test<
+      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Scalar&)>(
+      add_out, x_tensor, y_tensor, exp_tensor, INTBOOL, 2);
 }

 TEST_F(atest, max_operators) {
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index 31220b9f2d5a..4a0e7f91cdf9 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -174,6 +174,17 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
         [1., 1.],
         [1., 1.]]))""")

+    # Some ops that are mutable are neither inplace nor out= ops.
+    # They also need special handling.
+ def test_mutable_op_not_inplace_or_other(self): + def f(x): + return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0) + + logs = self.get_logs(f, torch.ones(1)) + self.assertExpectedInline('\n'.join(logs), """\ +$0 = input('input') +$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""") + def test_tensor_list_composite(self): def f(x): # Test an op with TensorList input diff --git a/test/test_fx.py b/test/test_fx.py index 56b28371456e..d3bc09fd2ee4 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3944,8 +3944,6 @@ class TestFunctionalTracing(JitTestCase): "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, - - "normalize" : MUTABLE, } # List of nn.functionals with Tensor inputs but not with type annotation diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index ab592764e5bd..1b6d69f0d571 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -177,6 +177,9 @@ SKIP_PYTHON_BINDINGS_SIGNATURES = [ @with_native_function def should_generate_py_binding(f: NativeFunction) -> bool: + # So far, all NativeFunctions that are entirely code-generated do not get python bindings. + if "generated" in f.tags: + return False name = cpp.name(f.func) for skip_regex in SKIP_PYTHON_BINDINGS: if skip_regex.match(name): diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index 8072c6cad2d9..46c3baf3b1e2 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -432,7 +432,8 @@ def emit_trace_body(f: NativeFunction) -> List[str]: assign_return_values = ( f"{tie_return_values(f)} = " - if f.func.kind() == SchemaKind.functional and f.func.returns + if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable] + and f.func.returns else "" ) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 78e8e4edce13..b54ca547e93b 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -349,7 +349,7 @@ GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX) # Some operators invalidate the grad_accumulator. Let's reset it. 
-RESET_GRAD_ACCUMULATOR = {"set", "resize"} +RESET_GRAD_ACCUMULATOR = {"set_", "resize_"} # NOTE [ TensorImpl and Storage Pointer Sanity Checks ] # @@ -734,7 +734,7 @@ def gen_variable_type_func( if ( fn.info is None - and not get_base_name(f) in RESET_GRAD_ACCUMULATOR + and not str(f.func.name.name) in RESET_GRAD_ACCUMULATOR and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE and len(gen_differentiable_outputs(fn)) > 0 and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE @@ -857,7 +857,14 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]: and (len(differentiable_outputs) > 0) ) - if info is not None and info.has_derivatives and not requires_derivative: + if ( + info is not None + and info.has_derivatives + and not requires_derivative + # out= ops are allowed to have zero returns which cause requires_derivative to be False + # we shouldn't error out though (out= ops for autograd just redispatch) + and len(f.func.returns) > 0 + ): raise RuntimeError( f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative" ) @@ -1528,7 +1535,7 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]: # Save only after the forward AD has been set up body.append(emit_save_outputs()) - if base_name in RESET_GRAD_ACCUMULATOR: + if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR: # `inplace` implies that there is exactly one output named `self`, # so we can keep the generated code easy. If you need to # `reset_grad_accumulator` in an operator that's not `inplace`, you can diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 185a4cdcef49..feeecefc11a0 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -32,7 +32,10 @@ from torchgen.api.types import ( stringT, ) from torchgen.api import cpp -from torchgen.gen import parse_native_yaml, get_grouped_by_view_native_functions +from torchgen.gen import ( + parse_native_yaml, + get_grouped_by_view_native_functions, +) from torchgen.context import with_native_function from torchgen.model import ( FunctionSchema, diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index 6c6462886ed1..25a64f4d9491 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -81,10 +81,11 @@ std::vector expand_param_if_needed( #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" -std::vector compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { +TORCH_API std::vector compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { double size_d = 0; // shape inference code copied from RangeFactories.cpp arange_out function // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct c++ scalar_t type depending on out tensor + AT_DISPATCH_ALL_TYPES_AND(c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() { // Note: acc_type further defines an accumulataion type depending on the scalar_t and whether its on cuda vs cpu. using accscalar_t = at::acc_type; @@ -129,7 +130,6 @@ std::vector compute_shape_arange_out(const at::Scalar & start, const at:: // If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype(). // Otherwise, the dtype is inferred to be torch.int64. 
- // Since out tensor is specified, its dtype should always be used? return {Shape(out.scalar_type(), {size})}; } @@ -145,7 +145,7 @@ std::vector compute_shape_bernoulli(const at::Tensor & self, c10::optiona return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional generator) { +std::vector compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional generator) { return compute_shape_bernoulli(self, generator); } @@ -224,11 +224,11 @@ std::vector compute_shape_convolution(const at::Tensor & input, const at: } } -std::vector compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { +std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { +std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -380,26 +380,22 @@ std::vector compute_shape_native_dropout_backward(const at::Tensor & grad return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())}; } -std::vector compute_shape_random_(at::Tensor & self, c10::optional generator) { +std::vector compute_shape_random_functional(const at::Tensor & self, c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_random_(at::Tensor & self, int64_t to, c10::optional generator) { - return compute_shape_random_(self, generator); +std::vector compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional generator) { + return compute_shape_random_functional(self, generator); } -std::vector compute_shape_random_(at::Tensor & self, int64_t from, c10::optional to, c10::optional generator) { - return compute_shape_random_(self, generator); +std::vector compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional to, c10::optional generator) { + return compute_shape_random_functional(self, generator); } std::vector compute_shape_relu(const at::Tensor& self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_relu_(at::Tensor& self) { - return compute_shape_relu(self); -} - std::vector compute_shape_bitwise_and(const at::Tensor& self, const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -417,7 +413,7 @@ std::vector compute_shape_sum( return {Shape(self.scalar_type(), {})};; } -std::vector compute_shape_zero_(at::Tensor& self) { +std::vector compute_shape_zero_functional(const at::Tensor& self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index 4b1815e98a1d..94746bb4df70 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -16,7 +16,7 @@ TORCH_API std::vector compute_shape__adaptive_avg_pool2d_bac TORCH_API std::vector compute_shape_abs(const at::Tensor & self); TORCH_API std::vector compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); TORCH_API std::vector compute_shape_bernoulli(const at::Tensor & self, c10::optional generator); -TORCH_API std::vector compute_shape_bernoulli_(at::Tensor & self, 
double p, c10::optional generator); +TORCH_API std::vector compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional generator); TORCH_API std::vector compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction); TORCH_API std::vector compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction); TORCH_API std::vector compute_shape_cat(at::TensorList tensors, int64_t dim); @@ -37,8 +37,8 @@ TORCH_API std::vector compute_shape_l1_loss_backward(const a TORCH_API std::vector compute_shape_log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer); TORCH_API std::vector compute_shape_log_sigmoid_forward(const at::Tensor & self); TORCH_API std::vector compute_shape_logdet(const at::Tensor & self); -TORCH_API std::vector compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); -TORCH_API std::vector compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); +TORCH_API std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); +TORCH_API std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); TORCH_API std::vector compute_shape_max(const at::Tensor & self); TORCH_API std::vector compute_shape_mean(const at::Tensor & self, c10::optional dtype); TORCH_API std::vector compute_shape_min(const at::Tensor & self); @@ -50,11 +50,10 @@ TORCH_API std::vector compute_shape_native_layer_norm_backwa TORCH_API std::vector compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight); TORCH_API std::vector compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction, int64_t ignore_index); TORCH_API std::vector compute_shape_nonzero(const at::Tensor & self); -TORCH_API std::vector compute_shape_random_(at::Tensor & self, c10::optional generator); -TORCH_API std::vector compute_shape_random_(at::Tensor & self, int64_t to, c10::optional generator); -TORCH_API std::vector compute_shape_random_(at::Tensor & self, int64_t from, c10::optional to, c10::optional generator); +TORCH_API std::vector compute_shape_random_functional(const at::Tensor & self, c10::optional generator); +TORCH_API std::vector compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional generator); +TORCH_API std::vector compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional to, c10::optional generator); TORCH_API std::vector compute_shape_relu(const at::Tensor & self); -TORCH_API std::vector compute_shape_relu_(at::Tensor & self); TORCH_API std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats); TORCH_API std::vector compute_shape_smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); TORCH_API std::vector compute_shape_sort(const at::Tensor & self, int64_t dim, bool descending); @@ -65,7 +64,7 @@ TORCH_API std::vector compute_shape_std(const at::Tensor & s TORCH_API std::vector compute_shape_sum(const at::Tensor & 
self, c10::optional dtype); TORCH_API std::vector compute_shape__to_copy(const at::Tensor & self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format); TORCH_API std::vector compute_shape_trace(const at::Tensor & self); -TORCH_API std::vector compute_shape_zero_(at::Tensor & self); +TORCH_API std::vector compute_shape_zero_functional(const at::Tensor & self); TORCH_API std::vector compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length); } // namespace lazy } // namespace torch diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 01875dcb006c..11dd831bacd3 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -310,17 +310,41 @@ def match_differentiability_info( for info in differentiability_infos if info.func.func.kind() == SchemaKind.functional } + non_functional_info_by_signature = { + info.func.func.signature(strip_default=True): info + for info in differentiability_infos + if info.func.func.kind() != SchemaKind.functional + } def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]: + # (1) Check for an exact match if f.func in info_by_schema: return info_by_schema[f.func], True - # if there is no exact match look for the out-of-place signature. + # (2) If no exact match, check if the out-of-place variant + # of this operator has a match. # i.e mul() for mul_() or mul_out() - return ( - functional_info_by_signature.get(f.func.signature(strip_default=True)), - False, - ) + f_sig = f.func.signature(strip_default=True) + if f_sig in functional_info_by_signature: + return functional_info_by_signature[f_sig], False + + # (3) Some operators have a derivative explicitly defined for the mutable + # variant, but get a code-generated out-of-place variant which does *not* + # come with a derivative formula. + # For the generated out-of-place variant, use the mutable variant's formula + # if it exists. + if "generated" in f.tags and f_sig in non_functional_info_by_signature: + info = non_functional_info_by_signature[f_sig] + # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 + assert not any( + "self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs + ), f"""\ +Attempted to convert a derivative formula for a mutable operator + to be used by automatically by its functional variant ("{str(f.func)}"). 
+ this is not currently supported (we'd need to fix up the formula in the codegen).""" + return info, False + + return None, False result: List[NativeFunctionWithDifferentiabilityInfo] = [] for f in native_functions: diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 12e4b5733e60..39c3c8684fa1 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -64,7 +64,9 @@ from typing import Optional, Sequence, Union, List, Set def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str: name = str(func.name.name) - if func.is_out_fn(): + if func.is_functional_fn(): + name += "_functional" + elif func.is_out_fn(): if faithful_name_for_out_overloads: name += "_outf" else: diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index 3d05c531734c..372350cea58a 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -60,6 +60,8 @@ from torchgen.api.types import ( options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) +out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) + longVec_ctype = VectorCType(BaseCType(longT)) optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) @@ -287,20 +289,38 @@ Check this module for more information. return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})" elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))): - options = direct_solve(options_ctype) - return f"optTypeMetaToScalarType({options}.dtype_opt())" + try: + options = direct_solve(options_ctype) + return f"optTypeMetaToScalarType({options}.dtype_opt())" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.scalar_type()" elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))): - options = direct_solve(options_ctype) - return f"{options}.layout_opt()" + try: + options = direct_solve(options_ctype) + return f"{options}.layout_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.layout()" elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))): - options = direct_solve(options_ctype) - return f"{options}.device_opt()" + try: + options = direct_solve(options_ctype) + return f"{options}.device_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.device()" elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))): - options = direct_solve(options_ctype) - return f"{options}.pinned_memory_opt()" + try: + options = direct_solve(options_ctype) + return f"{options}.pinned_memory_opt()" + except UnsatError: + # If we're calling a factory op from its out= variant, + # We don't actually care about the value of pin_memory. + out_tensor = direct_solve(out_tensor_ctype) + return "c10::nullopt" # We can always do translations from value types to reference types, like vector -> IntArrayRef elif goal.type == BaseCType(intArrayRefT): @@ -348,6 +368,15 @@ Check this module for more information. # With arguments like std::vector. # If that changes, we'll have to add the translation here. + # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor. + # We could probably generalize this to non-tensor types too. 
+ if goal.type == MutRefCType(BaseCType(tensorT)): + const_ref_tensor_ctype = NamedCType( + goal.name, ConstRefCType(BaseCType(tensorT)) + ) + argname = direct_solve(const_ref_tensor_ctype) + return f"const_cast({argname})" + unsat(goal) return [Expr(solve(g, direct=False), g) for g in goal_ctypes] diff --git a/torchgen/context.py b/torchgen/context.py index ab0b90dcb732..f65e3daaa8d9 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -26,6 +26,7 @@ F = TypeVar( F2 = TypeVar( "F2", NativeFunction, + NativeFunctionsGroup, Optional[NativeFunction], bool, ) diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index e6ca469f54f4..3844ee0bd1a5 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -616,6 +616,10 @@ check_inplace(out, sizes, options); const auto& out = outputs_[output_idx].get(); resize_out(out, sizes, strides, options); {create_proxy}""" + elif k is SchemaKind.mutable: + raise AssertionError( + "SchemaKind.mutable structured operators are currently not supported" + ) else: assert_never(k) @@ -631,6 +635,10 @@ resize_out(out, sizes, strides, options); out_args = ", ".join(f"Tensor& out{i}" for i in range(returns)) out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns)) return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}" + elif k is SchemaKind.mutable: + raise AssertionError( + "SchemaKind.mutable structured operators are currently not supported" + ) else: assert_never(k) diff --git a/torchgen/gen.py b/torchgen/gen.py index 8d6576181141..90051ce2679a 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -33,6 +33,10 @@ from torchgen.model import ( ViewSchemaKind, BaseOperatorName, ) +from torchgen.native_function_generation import ( + pre_group_native_functions, + add_generated_native_functions, +) from torchgen.api.types import ( Binding, CppSignatureGroup, @@ -72,6 +76,7 @@ from torchgen.gen_functionalization_type import ( gen_functionalization_registration, gen_functionalization_view_inverse_declaration, gen_composite_view_copy_kernel, + gen_composite_functional_kernel, ) T = TypeVar("T") @@ -180,6 +185,7 @@ def parse_native_yaml_struct( index={}, ) ) + add_generated_native_functions(rs, bs) for k, v in bs.items(): # All structured in-tree operators are implemented in terms of their out operator. indices[k] = BackendIndex( @@ -1278,59 +1284,48 @@ def get_custom_build_selector( return selector -def pre_group_native_functions( - native_functions: Sequence[NativeFunction], -) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]: - pre_grouped_native_functions: Dict[ - FunctionSchema, Dict[SchemaKind, NativeFunction] - ] = defaultdict(dict) - for f in native_functions: - d = pre_grouped_native_functions[f.func.signature()] - assert f.func.kind() not in d - d[f.func.kind()] = f - return pre_grouped_native_functions - - def get_grouped_by_view_native_functions( native_functions: Sequence[NativeFunction], ) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]: def maybe_create_view_group( - d: Dict[ViewSchemaKind, NativeFunction] + d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction] ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]: funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = [] - if ViewSchemaKind.aliasing not in d: - # Case 1: this op / op group is not aliasing, so we don't create a view group. - # return the original (ungrouped) native functions instead. 
- for func in d.values(): - funcs.append(func) - else: - # Case 2: this op group contains an aliasing op, so we create a ViewGroup for it. - # The handling for out= ops here is unfortunate. - # out= ops don't really make sense for view operators. - # However, we have at least one existing {view}_copy.out operator in native_functions.yaml. - # It shouldn't be part of a view group, so we explicitly don't group it. - # There currently aren't any out= view ops (and there probably shouldn't be). - # We also expect that when we hit this case, the `non_aliasing` op in the dict - # *must* be a view_copy op (this is asserted in the NativeFunctionsViewGroup constructor) - if ViewSchemaKind.out in d: - funcs.append(d[ViewSchemaKind.out]) + if ViewSchemaKind.aliasing in d: + view = d.pop(ViewSchemaKind.aliasing) + view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None) + view_copy = d.pop(SchemaKind.functional, None) funcs.append( NativeFunctionsViewGroup( - view=d[ViewSchemaKind.aliasing], - view_copy=d.get(ViewSchemaKind.non_aliasing, None), - view_inplace=d.get(ViewSchemaKind.inplace, None), + view=view, + view_copy=view_copy, + view_inplace=view_inplace, ) ) + # Take the remaining functions that weren't part of the view group + # and emit them separately + for func in d.values(): + funcs.append(func) return funcs grouped_by_views: Dict[ - FunctionSchema, Dict[ViewSchemaKind, NativeFunction] + FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction] ] = defaultdict(dict) for f in native_functions: schema = f.func.view_signature() - assert f.view_schema_kind not in grouped_by_views[schema] - grouped_by_views[schema][f.view_schema_kind] = f + view_kind: ViewSchemaKind = f.view_schema_kind + # We need to group up ops relevant to the same "view", consisting of: + # view op (ViewSchemaKind.aliasing) + # view_inplace op (ViewSchemaKind.aliasing_inplace) + # view_copy op (SchemaKind.functional) + if view_kind == ViewSchemaKind.non_aliasing: + kind = f.func.kind() + assert kind not in grouped_by_views[schema] + grouped_by_views[schema][kind] = f + else: + assert view_kind not in grouped_by_views[schema] + grouped_by_views[schema][view_kind] = f return list(concatMap(maybe_create_view_group, grouped_by_views.values())) @@ -1343,6 +1338,9 @@ def get_grouped_native_functions( ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: r = NativeFunctionsGroup.from_dict(d) if r is None: + # Invariant: any NativeFunctions that are code-generated + # should have been grouped into NativeFunctionsGroup objects + assert not any("generated" in f.tags for f in d.values()) return list(d.values()) else: return [r] @@ -1835,9 +1833,7 @@ def gen_source_files( native_functions: Sequence[NativeFunction], grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], structured_native_functions: Sequence[NativeFunctionsGroup], - native_functions_with_view_groups: Sequence[ - Union[NativeFunction, NativeFunctionsViewGroup] - ], + view_groups: Sequence[NativeFunctionsViewGroup], selector: SelectiveBuilder, static_dispatch_idx: List[BackendIndex], backend_indices: Dict[DispatchKey, BackendIndex], @@ -2118,30 +2114,11 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { }, ) - # We need to easily map from [inplace_op_name] -> [functional_op] for the functionalization pass, - # so here I generate a mapping from every operator name to its corresponding functional NativeFunction (if it exist). 
- pre_grouped_d: Dict[ - FunctionSchema, Dict[SchemaKind, NativeFunction] - ] = pre_group_native_functions(native_functions) - to_functional_op: Dict[OperatorName, Optional[NativeFunction]] = { - k: v - for d in [ - { - f.func.name: pre_grouped_d[func][SchemaKind.functional] - if SchemaKind.functional in pre_grouped_d[func].keys() - else None - for f in pre_grouped_d[func].values() - } - for func in pre_grouped_d.keys() - ] - for k, v in d.items() - } - def functionalization_env_callable( - g: Union[NativeFunction, NativeFunctionsViewGroup] + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] ) -> Dict[str, List[str]]: def gen_op_headers( - g: Union[NativeFunction, NativeFunctionsViewGroup] + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] ) -> List[str]: if isinstance(g, NativeFunctionsViewGroup): # view ops always get a functionalization kernel @@ -2155,11 +2132,28 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { f"#include ", ] return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers else: - f = g return [ - f"#include ", - f"#include ", + f"#include ", + f"#include ", ] return { @@ -2167,11 +2161,6 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { "func_definitions": gen_functionalization_definition( selector, g, - # We need to manually map inplace ops to their out-of-place variants - # (we can't do this with NativeFunctionsGroup today because not all inplace ops have out= variants) - None - if isinstance(g, NativeFunctionsViewGroup) - else to_functional_op.get(g.func.name, None), ), "func_registrations": gen_functionalization_registration( selector, @@ -2180,9 +2169,30 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { ), } + all_groups: List[ + Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] + ] = list(structured_native_functions) + list( + view_groups # type: ignore[assignment, arg-type, operator] + ) + # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly. + # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because: + # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) + # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. + # Although this could go away long-term if we add a dedicated dispatch key for decompositions. 
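As an illustration of the grouping that the comment above relies on: all variants of an operator normalize to the same FunctionSchema.signature(), which is the key used when pre-grouping NativeFunctions and assembling them into a NativeFunctionsGroup. A minimal sketch, assuming the torchgen package from this patch is importable (the schema strings mirror the add overloads in native_functions.yaml):

from torchgen.model import FunctionSchema, SchemaKind

variants = [
    "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
    "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
    "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)",
]
schemas = [FunctionSchema.parse(s) for s in variants]
# Each variant reports a different SchemaKind...
assert [s.kind() for s in schemas] == [
    SchemaKind.functional,
    SchemaKind.inplace,
    SchemaKind.out,
]
# ...but all of them collapse to one signature(), so they land in the same
# pre-grouped bucket and can be assembled into a NativeFunctionsGroup.
assert len({str(s.signature()) for s in schemas}) == 1

This is only an illustration; the codegen itself performs the grouping in pre_group_native_functions and get_grouped_native_functions.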
+ structured_map: Dict[OperatorName, NativeFunction] = { + f.func.name: f + for f in concatMap(lambda g: list(g.functions()), structured_native_functions) + } + view_map: Dict[OperatorName, NativeFunction] = { + f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups) + } + for f in native_functions: + if f.func.name not in structured_map and f.func.name not in view_map: + all_groups.append(f) + cpu_fm.write_sharded( "RegisterFunctionalization.cpp", - native_functions_with_view_groups, + all_groups, key_fn=key_func, env_callable=functionalization_env_callable, num_shards=4, @@ -2203,11 +2213,7 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { lambda g: gen_functionalization_view_inverse_declaration( selector, g ), - [ - g - for g in native_functions_with_view_groups - if isinstance(g, NativeFunctionsViewGroup) - ], + view_groups, ) ) }, @@ -2239,17 +2245,23 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { [g.view] if g.view_copy is None else [g.view, g.view_copy] ) ) - for g in native_functions_with_view_groups - if isinstance(g, NativeFunctionsViewGroup) + for g in view_groups + ] + + [ + "\n".join( + f"#include " + for f in [g.inplace, g.mutable] + if f is not None and "generated" not in f.tags + ) + for g in structured_native_functions ], "CompositeViewCopyKernel_Definitions": list( + mapMaybe(gen_composite_view_copy_kernel, view_groups) + ), + "GeneratedCompositeFunctional_Definitions": list( mapMaybe( - gen_composite_view_copy_kernel, - [ - g - for g in native_functions_with_view_groups - if isinstance(g, NativeFunctionsViewGroup) - ], + gen_composite_functional_kernel, + structured_native_functions, ) ), }, @@ -2377,12 +2389,18 @@ def main() -> None: ) grouped_native_functions = get_grouped_native_functions(native_functions) + structured_native_functions = [ g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup) ] native_functions_with_view_groups = get_grouped_by_view_native_functions( native_functions ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] template_dir = os.path.join(options.source_path, "templates") @@ -2451,7 +2469,7 @@ def main() -> None: native_functions=native_functions, grouped_native_functions=grouped_native_functions, structured_native_functions=structured_native_functions, - native_functions_with_view_groups=native_functions_with_view_groups, + view_groups=view_groups, selector=selector, static_dispatch_idx=static_dispatch_idx, backend_indices=backend_indices, diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index c6cf76744f95..a28a3d0e3809 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -4,6 +4,7 @@ from torchgen.api.types import ( Binding, FunctionalizationLambda, ViewInverseSignature, + Expr, NativeSignature, CType, BaseCType, @@ -19,8 +20,9 @@ from torchgen.context import ( ) from torchgen.model import ( Argument, + Return, NativeFunction, - SchemaKind, + NativeFunctionsGroup, BackendIndex, FunctionSchema, SelfArgument, @@ -30,10 +32,30 @@ from torchgen.model import ( NativeFunctionsViewGroup, ListType, ) +from torchgen.native_function_generation import ( + OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY, +) + from torchgen.selective_build.selector import SelectiveBuilder from typing import List, Optional, Union, Tuple, Callable + +# Note: [Mutable Ops Not Using 
Functionalization] +# Ops in this list currently do not work with functionalization and should be fixed. +MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = ( + OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY + + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT + + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY + + [ + # It will be BC-breaking, but we should fix their schemas. + # should be inplace? + "record_stream", + ] +) + # This file contains codegen that relates to the functionalization pass. # It includes: # - gen_functionalization_definition @@ -91,24 +113,96 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str] """ -def modifies_arguments(f: NativeFunction) -> bool: - return f.func.kind() in [SchemaKind.inplace, SchemaKind.out] +def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: + assert len(rets) == len(names) + if len(rets) == 0: + return "" + elif len(rets) == 1: + return f"return {names[0]};" + else: + return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});" -# This function constructs the return statement for the kernels that contain mutations -# It mostly just needs to special case multi-output returns to wrap the result in a tuple -def return_str(f: NativeFunction) -> str: - # Need to check both # outs and # returns. Why? - # out= ops with a mutable Tensor(a!)[] argument are expected to have a void return type. - if len(f.func.arguments.out) != 0 and len(f.func.returns) != 0: - if len(f.func.arguments.out) > 1: - return_names = ", ".join(a.name for a in f.func.arguments.out) - return f"return {DispatcherSignature.from_schema(f.func).returns_type().cpp_type()}({return_names});" +# Given a function, and the name of a variable correponding to the output of that function, +# gather up all of the individual returns that are not aliased +def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]: + aliased_rets = func.aliased_return_names() + non_aliased_names = [] + is_out_var_a_tuple = len(func.returns) > 1 + for (i, r) in enumerate(aliased_rets): + if r is None: + non_aliased_names.append( + f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var + ) + return non_aliased_names + + +@with_native_function +def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]: + # We should only be generating these for code-generated NativeFunctions + if "generated" not in g.functional.tags: + return None + # And we always write the kernel for a generated op in terms of a non-generated op. + if g.inplace is not None and "generated" not in g.inplace.tags: + target_f = g.inplace + elif g.mutable is not None and "generated" not in g.mutable.tags: + target_f = g.mutable + else: + # We should be guaranteed to have a valid inplace/mutable variant to call into. + # See Note: [Mutable Ops Not Using Functionalization] + raise AssertionError(str(g.functional.func)) + + sig = DispatcherSignature(g.functional.func) + target_sig = DispatcherSignature(target_f.func) + + context: List[Union[Binding, Expr]] = [] + clone_mutable_inputs = [] + cloned_return_names = [] + # We can't just directly pass all of the arguments from the functional op into the mutating op. + # We need to check for which inputs to the mutating operator are mutable, + # and clone those inputs first. 
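To make the comment above concrete: the kernel emitted by this function clones every mutable input, redispatches to the existing mutable op, and returns the clones as extra outputs. Below is a minimal runtime sketch of that behavior, written in plain Python against the public torch API rather than the generated C++ (the name zero_functional is purely illustrative, not the generated kernel itself):

import torch

def zero_functional(self: torch.Tensor) -> torch.Tensor:
    self_clone = self.clone()  # clone the mutable input (clone_arg in the generated C++)
    self_clone.zero_()         # redispatch to the existing inplace/mutable op
    return self_clone          # cloned mutable inputs become the (extra) returns

x = torch.ones(3)
y = zero_functional(x)
assert x.sum().item() == 3.0 and y.sum().item() == 0.0  # the original input is untouched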
+ for a_curr, a_tgt in zip( + dispatcher.jit_arguments(g.functional.func), + dispatcher.jit_arguments(target_f.func), + ): + if a_tgt.annotation is not None and a_tgt.annotation.is_write: + clone_mutable_inputs.append( + f"auto {a_curr.name}_clone = clone_arg({a_curr.name});" + ) + context.append( + Expr( + expr=f"{a_curr.name}_clone", + type=dispatcher.argument_type(a_curr, binds=a_curr.name), + ) + ) + # Invariant: mutable arguments on the inner mutable op are always returns on the functional op. + cloned_return_names.append(f"{a_curr.name}_clone") else: - return f"return {f.func.arguments.out[0].name}" - if f.func.arguments.self_arg is not None and len(f.func.returns) != 0: - return f"return {f.func.arguments.self_arg.argument.name}" - return "" + context.append(dispatcher.argument(a_curr)) + exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())]) + + out_name = "output" + maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else "" + inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name) + ret_str = return_str( + g.functional.func.returns, inner_return_names + cloned_return_names + ) + + clone_mutable_inputs_str = "\n".join(clone_mutable_inputs) + return f""" +{sig.defn()} {{ + {clone_mutable_inputs_str} + {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs}); + {ret_str} +}} +""" + + +def modifies_arguments(f: NativeFunction) -> bool: + return any( + a.annotation is not None and a.annotation.is_write + for a in f.func.arguments.flat_all + ) def wrapper_name(func: FunctionSchema) -> str: @@ -352,11 +446,103 @@ def emit_view_functionalization_body( """ +def maybe_create_output(f: NativeFunction, var_name: str) -> str: + if len(f.func.returns) == 0: + return "" + return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type() + return f"{return_type} {var_name} = " + + +# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function, +# this returns two lists of names, consisting of: +# - the names of returns corresponding to the original (mutable) inputs of the outer function +# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function +def get_mutable_redispatch_return_names( + f: NativeFunction, inner_return_var: str +) -> Tuple[List[str], List[str]]: + aliased_returns = [] + non_aliased_returns = [] + for (i, name) in enumerate(f.func.aliased_return_names()): + if name is not None: + aliased_returns.append(name) + else: + non_aliased_returns.append( + inner_return_var + if len(f.func.returns) == 1 + else f"std::get<{i}>({inner_return_var})" + ) + return aliased_returns, non_aliased_returns + + +# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that: +# - For fresh outputs, we return the result of the redispatch (without wrapping outputs) +# - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped) +def return_from_mutable_noop_redispatch( + f: NativeFunction, inner_return_var: str +) -> str: + aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var) + # Just get all of the return names, and immediately return them + return return_str(f.func.returns, aliased + non_aliased) + + +def wrap_propagate_mutations_and_return( + f: NativeFunction, functional_op: NativeFunction, inner_return_var: str +) -> str: + mutable_arg_names = f.func.arguments.mutable_arg_names() + ( + 
aliased_outer_rets, + non_aliased_outer_rets, + ) = get_mutable_redispatch_return_names(f, inner_return_var) + _, non_aliased_inner_rets = get_mutable_redispatch_return_names( + functional_op, inner_return_var + ) + # The outer function may have a mix of aliased and non-aliased outputs, + # But the inner functional op that we're transforming to should only have non-aliased outputs + assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len( + non_aliased_inner_rets + ) + + # First, take all of the newly created outputs from the inner call and wrap them into functional tensors + updates = [] + non_aliased_wrapped_ret_names = [] + for (i, inner_ret) in enumerate( + non_aliased_inner_rets[: len(non_aliased_outer_rets)] + ): + ret_name = f"output_{i}" + updates.append( + f"""\ + auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});""" + ) + non_aliased_wrapped_ret_names.append(ret_name) + + # Next, take all of the mutated outputs from the inner call corresponding to mutated inputs, + # and propogate the mutations + for (outer_arg, inner_ret) in zip( + mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :] + ): + updates.append( + f"""\ + at::functionalization::impl::replace_({outer_arg}, {inner_ret}); + at::functionalization::impl::commit_update({outer_arg});""" + ) + + # Finally, we return: + # - Any mutable arguments that also returns + # - Any immutable returns that were created wrapping the output from the inner call + returns_str = return_str( + f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names + ) + updates_str = "\n".join(updates) + return f"""\ +{updates_str} + {returns_str}""" + + # Generates the Functionalization kernel for: # - mutation ops (inplace and out= ops) @with_native_function_and def emit_inplace_functionalization_body( - f: NativeFunction, functional_op: Optional[NativeFunction] + f: NativeFunction, g: NativeFunctionsGroup ) -> str: # mutation case assert modifies_arguments(f) @@ -400,42 +586,15 @@ def emit_inplace_functionalization_body( for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False) ] - if functional_op is None: - # We can't functionalize this inplace op, since we don't know what the corresponding functional op is. - return_type = ( - dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type() - ) - warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \ -functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \ -Instead, it's calling the inplace/view operator directly. \ -If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.""" - - return f""" - {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ - if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{ - TORCH_WARN("{warn_str}"); - }} - {unwrap_tensor_args_str} - at::AutoDispatchSkipFunctionalize guard; - // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops. 
- at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)}); - {return_str(f)}; - }} -""" - else: - # call the out-of-place variant of the op - return_type = ( - dispatcher.returns_type(functional_op.func.returns) - .remove_const_ref() - .cpp_type() - ) - functional_sig = DispatcherSignature.from_schema(functional_op.func) - functional_exprs = [ - e.expr - for e in translate( - unwrapped_args_ctx, functional_sig.arguments(), method=False - ) - ] + # call the out-of-place variant of the op + return_type = ( + dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type() + ) + functional_sig = DispatcherSignature.from_schema(g.functional.func) + functional_exprs = [ + e.expr + for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False) + ] if f.func.is_out_fn(): mutable_input_post_processing = "\n".join( @@ -471,17 +630,16 @@ If this causes problems in your program, consider upstreaming the out-of-place o }} else {{ // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; - at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)}); - {return_str(f)}; + {maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)}); + {return_from_mutable_noop_redispatch(f, 'tmp_output')}; }} }} else {{ {return_type} tmp_output; {{ at::AutoDispatchSkipFunctionalize guard; - tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)}); + tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)}); }} - {mutable_input_post_processing} - {return_str(f)}; + {wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')} }} }}""" @@ -508,7 +666,7 @@ def gen_functionalization_view_inverse_declaration( def gen_functionalization_registration( selector: SelectiveBuilder, - g: Union[NativeFunction, NativeFunctionsViewGroup], + g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup], composite_implicit_autograd_index: BackendIndex, ) -> List[str]: @with_native_function @@ -542,8 +700,16 @@ def gen_functionalization_registration( assert g.view_inplace.is_view_op view_str.append(emit_registration_helper(g.view_inplace)) return view_str + + elif isinstance(g, NativeFunctionsGroup): + fns = list(g.functions()) else: - f = g + if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION: + return [] + fns = [g] + + registrations = [] + for f in fns: if str(f.func.name) == "lift": # See Note [Functionalization <> torch.Tensor constructor] return [] @@ -552,14 +718,17 @@ def gen_functionalization_registration( # We *also* need to directly register CompositeImplicitAUtograd kernels # so that they decompose properly before functioanlization. if modifies_arguments(f) or f.has_composite_implicit_autograd_kernel: - return [emit_registration_helper(f)] - return [] + registrations.append(emit_registration_helper(f)) + return registrations def gen_functionalization_definition( selector: SelectiveBuilder, - g: Union[NativeFunction, NativeFunctionsViewGroup], - functional_op: Optional[NativeFunction], + # Note: Ideally this code should never have to look at NativeFunction + # (and instead only need to operate on grouped NativeFunctions). + # The only reason currently is because we need to emit direct dispatch registrations + # For CompositeImplicitAutograd operators, which are potentially ungrouped. 
+ g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup], ) -> List[str]: # Don't generate kernels in mobile build if not selector.include_all_operators: @@ -576,9 +745,24 @@ def gen_functionalization_definition( if g.view_inplace is not None: view_defs.append(emit_view_functionalization_body(g, view_inplace=True)) return view_defs + elif isinstance(g, NativeFunction): + # Invariant: all mutable operators that we need to handle in functionalization + # should have been properly grouped up. + # TODO: The below ops all have "problematic" schemas that prevent them from + # getting functionalized. Instead of bending over backwards to get things to work, + # I think we should either: + # (1) fix their schemas (BC-breaking) + # (2) hand-write their functionalization kernels + if str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION: + assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g) + return [] else: # Case 2: emit inplace -> out-of-place kernels for the functionalization pass - f = g - if modifies_arguments(f): - return [emit_inplace_functionalization_body(f, functional_op)] + mutation_defs = [] + mutation_defs.append(emit_inplace_functionalization_body(g.out, g)) + if g.inplace is not None: + mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g)) + if g.mutable is not None: + mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g)) + return mutation_defs return [] diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index d160ef2ef486..6b2f2e5aaceb 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -13,7 +13,6 @@ from typing import ( Callable, Iterable, Iterator, - Tuple, Type, ) from torchgen.api.types import BaseCppType @@ -27,7 +26,6 @@ from torchgen.gen import ( from torchgen.api.lazy import setValueT from torchgen.model import ( - FunctionSchema, NativeFunction, NativeFunctionsGroup, OperatorName, @@ -331,39 +329,18 @@ def run_gen_lazy_tensor( def concat_map_codegen( func: Callable[[NativeFunction], Sequence[str]], xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]], - *, - codegenInplaceVariant: bool = False, ) -> Iterator[str]: """ We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we only code-gen additional entries for the inplace variant for the native functions. - Note: If xs is not sorted, there may be an edge case when generating IR classes. Considering relu and relu_, if - we encounter relu_ before relu. we will then generate an IR class with op = at::aten::relu_ for both relu and - relu_ which will cause problems for relu. - TODO(alanwaketan): Once all ops are grouped properly, we should no longer need this hack. """ - generated = set() - - def gen_key(func: FunctionSchema) -> Tuple[str, str]: - # we want to generate unique entries for overloads of functional variants, - # but not for inplace variants unless explicitly told `codegenInplaceVariant` - return (func.name.name.base, func.name.overload_name) for x in xs: - f = x.functional if isinstance(x, NativeFunctionsGroup) else x - # For the 'or'd terms: - # 1. codegenInplaceVariant means we can generate the in-place variant corresponding items. - # 2. not f.func.name.name.inplace means the op is not a in-place variant, so we can generate the item. - # 3. f.func.name.name.base not in generated means even for in-place ops we still need to generate the item - # as if they were the functional variants for one time. 
- if f.func.name in full_codegen and ( - codegenInplaceVariant - or not f.func.name.name.inplace - or gen_key(f.func) not in generated - ): - generated.add(gen_key(f.func)) - for r in func(f): - yield r + fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x] + for f in fs: + if f.func.name in full_codegen: + for r in func(f): + yield r selector = SelectiveBuilder.get_nop_selector() @@ -402,7 +379,6 @@ def run_gen_lazy_tensor( backend_indices[backend_key], tensor_class ), grouped_native_functions, - codegenInplaceVariant=True, ) ) @@ -495,7 +471,6 @@ def run_gen_lazy_tensor( get_device_fn, ), grouped_native_functions, - codegenInplaceVariant=True, ) ), }, diff --git a/torchgen/model.py b/torchgen/model.py index e0888344a825..aec30e216abc 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -3,6 +3,7 @@ import re from torchgen.utils import assert_never from dataclasses import dataclass +import dataclasses from typing import List, Dict, Optional, Iterator, Tuple, Set, Sequence, Callable, Union from enum import Enum, auto import itertools @@ -296,7 +297,9 @@ class DeviceCheckType(Enum): ExactSame = 1 -ViewSchemaKind = Enum("ViewSchemaKind", ("aliasing", "inplace", "out", "non_aliasing")) +ViewSchemaKind = Enum( + "ViewSchemaKind", ("aliasing", "aliasing_inplace", "non_aliasing") +) # The basic input to the code generation is native_functions.yaml. # The name "native", BTW, comes from the distinction between native @@ -354,6 +357,14 @@ class NativeFunction: # defined. This is for conveniently reporting error messages! loc: "Location" + # A list of operators that are expected to be auto-generated for this NativeFunction. + # Note: This list isn't actually directly used by the codegen to generate anything. + # Instead, the codegen figures out what operators to generate purely based off of + # function schema, and uses the autogen declarations to error check. + # We expect every NativeFunction that gets auto-generated be explicitly called out + # in native_functions.yaml + autogen: List["OperatorName"] + # If non-empty, this kernel is subject to ufunc codegen. 
# Sorted by ufunc_key ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"] @@ -582,6 +593,14 @@ class NativeFunction: "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only" ) + autogen_str = e.pop("autogen", "") + assert isinstance(autogen_str, str) + autogen = ( + [] + if autogen_str == "" + else [OperatorName.parse(x) for x in autogen_str.split(", ")] + ) + raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {}) ufunc_inner_loop = {} if isinstance(raw_ufunc_inner_loop, str): @@ -652,6 +671,7 @@ class NativeFunction: structured_delegate=structured_delegate, structured_inherits=structured_inherits, precomputed=precomputed, + autogen=autogen, ufunc_inner_loop=ufunc_inner_loop, manual_kernel_registration=manual_kernel_registration, manual_cpp_binding=manual_cpp_binding, @@ -754,12 +774,10 @@ class NativeFunction: @property def view_schema_kind(self) -> ViewSchemaKind: - # This covers both "ordinary" inplace ops, and inplace_views - if self.func.name.name.inplace: - return ViewSchemaKind.inplace - elif self.func.is_out_fn(): - return ViewSchemaKind.out - elif self.is_view_op: + if self.is_view_op and self.func.name.name.inplace: + assert "inplace_view" in self.tags + return ViewSchemaKind.aliasing_inplace + if self.is_view_op: return ViewSchemaKind.aliasing else: return ViewSchemaKind.non_aliasing @@ -769,7 +787,7 @@ class NativeFunction: return self.func.name.name.base -SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out")) +SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out", "mutable")) # A structured kernel is guaranteed to have a functional and out variant, and # optionally an inplace variant. @@ -781,6 +799,7 @@ SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out")) class NativeFunctionsGroup: functional: NativeFunction inplace: Optional[NativeFunction] + mutable: Optional[NativeFunction] out: NativeFunction @property @@ -797,10 +816,19 @@ class NativeFunctionsGroup: f"that don't have matching signatures: {test_sig} != {f.func.signature()}" ) assert self.functional.func.kind() == SchemaKind.functional + assert not self.functional.is_view_op, ( + "View operator shouldn't be grouped into NativeFunctionsGroup objects." + f"This is likely because you tried to add an out= variant for '{f.func.name}', which is an existing view operator." + "out= variants of view operators are not valid. Please reach out to to the core team if you have questions." + ) assert self.out.func.kind() == SchemaKind.out + if self.inplace is not None: assert self.inplace.func.kind() == SchemaKind.inplace + if self.mutable is not None: + assert self.mutable.func.kind() == SchemaKind.mutable + if self.structured: # For now, structured composite kernels are not supported (need some # design work to figure out how to make the composite case work) @@ -813,6 +841,25 @@ class NativeFunctionsGroup: if self.inplace is not None: assert self.inplace.structured_delegate == self.out.func.name + generated_fns = [ + str(f.func.name) for f in self.functions() if "generated" in f.tags + ] + generated_fns_str = ", ".join(str(x) for x in generated_fns) + expected_generated_fns = f.autogen + expected_generated_fns_str = ", ".join(str(x) for x in expected_generated_fns) + if len(expected_generated_fns) == 0 and len(generated_fns) > 0: + raise RuntimeError( + f"The codegen expects to be able to generate '{generated_fns_str}'." + " In order to generate them however, we expect them to be called out explicitly in the yaml." 
+                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
+            )
+        if expected_generated_fns_str != generated_fns_str:
+            raise RuntimeError(
+                f"The codegen expects to be able to generate '{generated_fns_str}'."
+                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
+                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
+            )
+
     def signature(self) -> "FunctionSchema":
         return self.out.func.signature()

@@ -821,6 +868,8 @@
         yield self.out
         if self.inplace is not None:
             yield self.inplace
+        if self.mutable is not None:
+            yield self.mutable

     @property
     def root_name(self) -> str:
@@ -836,6 +885,7 @@
         d = dict(d)  # non-destructive updates please
         functional = d.pop(SchemaKind.functional, None)
         inplace = d.pop(SchemaKind.inplace, None)
+        mutable = d.pop(SchemaKind.mutable, None)
         out = d.pop(SchemaKind.out, None)
         assert not d
         assert functional is not None
@@ -847,6 +897,7 @@
         return NativeFunctionsGroup(
             functional=functional,
             inplace=inplace,
+            mutable=mutable,
             out=out,
         )

@@ -1045,12 +1096,26 @@ class FunctionSchema:
         assert str(r) == func, f"{str(r)} != {func}"
         return r

+    def returns_are_aliased(self) -> bool:
+        # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
+        return any(
+            r
+            for r in self.returns
+            if r.annotation is not None and r.annotation.is_write
+        )
+
     def __post_init__(self) -> None:
         for arg, ret in zip(self.arguments.out, self.returns):
             assert arg.annotation == ret.annotation, (
                 "Out arguments must have matching return Tensor; furthermore, "
                 "the ith-argument needs to correspond to the ith return"
             )
+        # We also enforce that if you have any mutable, positional args, then they are not returned.
+        # This makes it easier to group these functions properly with their functional/out= counterparts.
+        for a in self.arguments.post_self_positional_mutable:
+            assert not any(
+                a.annotation == r.annotation for r in self.returns
+            ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
         # Invariant: we expect out arguments to appear as keyword arguments in the schema.
         # This means that all mutable returns should be aliased to a keyword argument
         # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
@@ -1063,6 +1128,19 @@
             for ret in self.returns
             if ret.annotation is not None and ret.annotation.is_write
         ]
+        immutable_returns = [
+            ret
+            for ret in self.returns
+            if ret.annotation is None or not ret.annotation.is_write
+        ]
+        # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
+        # because:
+        # (1) It's more annoying to handle properly
+        # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
+        # Instead, we expect the (a!) argument to not be returned.
+        assert (
+            len(mutable_returns) == 0 or len(immutable_returns) == 0
+        ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
         for ret in mutable_returns:
             assert any([ret.annotation == arg.annotation for arg in out_and_self]), (
                 'All mutable returns must be aliased either to a keyword argument, or to "self". '
@@ -1103,6 +1181,22 @@
                 # so in all other cases we expect the return type to be none.
assert len(self.returns) == 0 + if self.arguments.tensor_options is not None: + assert self.kind() == SchemaKind.functional, ( + "Found an operator that is not functional, but has tensor options arguments." + "This is not allowed- tensor options arguments are only allowed for factory functions." + f"schema: {str(self)}" + ) + if self.is_functional_fn(): + assert self.kind() == SchemaKind.functional, ( + "Found an operator that is not functional, but its overload contains the string 'functional'." + "This is a special keyword in the codegen, please use a different overload name." + f"schema: {str(self)}" + ) + + def is_functional_fn(self) -> bool: + return "functional" in self.name.overload_name + def is_out_fn(self) -> bool: # Note [is_out_fn] # @@ -1139,44 +1233,99 @@ class FunctionSchema: modifies the self argument inplace; an out schema writes the result into an explicitly provided out argument. """ - is_inplace = self.name.name.inplace is_out = bool(self.arguments.out) - assert not (is_inplace and is_out) + is_inplace = self.name.name.inplace + is_mutable = any( + a.annotation is not None and a.annotation.is_write + for a in self.arguments.post_self_positional + ) + assert not (is_out and is_inplace) + # out= and inplace schemas can also have post_self_positional mutable args, + # but we give precedence to out= and inplace when deciding the schema kind. + # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops + # to also worry about mutable post_self_positional arguments, + # but it seems like a much bigger lift to classify them has having a new schema kind. + # The number of ops that fit in this strange category is small enough that + # we can probably manually write code for them instead of forcing the codegen to handle them. if is_inplace: return SchemaKind.inplace elif is_out: return SchemaKind.out + elif is_mutable: + return SchemaKind.mutable else: return SchemaKind.functional + # For every return: + # - If the return aliases an input, we return the input name + # - Otherwise, we return None. + # If return names were enforced to be consistent with aliasing information, then we wouldn't need this. + def aliased_return_names(self) -> List[Optional[str]]: + outs: List[Optional[str]] = [] + for r in self.returns: + aliased_args = [ + a + for a in self.arguments.flat_all + if a.annotation is not None and a.annotation == r.annotation + ] + if len(aliased_args) == 0: + outs.append(None) + elif len(aliased_args) == 1: + outs.append(aliased_args[0].name) + else: + aliased_names = ", ".join(a.name for a in aliased_args) + raise AssertionError( + f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})" + ) + return outs + def signature( - self, *, strip_default: bool = False, strip_view_copy_name: bool = False + self, + *, + strip_default: bool = False, + strip_view_copy_name: bool = False, + keep_return_names: bool = False, ) -> "FunctionSchema": """ - Certain schemas are 'related', in that they are simply - inplace/out/functional versions of the same function. This method - factors these schemas into the "core" functional signature which - is equal across all versions. + Certain schemas are 'related', in that they are simply + inplace/out/functional versions of the same function. This method + factors these schemas into the "core" functional signature which + is equal across all versions. 
- Here is what normalization happens to the schema to convert - it to a signature: - - The overload name is stripped (name is retained, since - it expresses semantic content about what the function does) - - Inplace is set False - - Out arguments are stripped - - Mutability annotations are stripped (this is sound - because you cannot overload on mutability annotation) - - Return names are stripped since they are not overloadable and - some variants have return names but some not + Here is what normalization happens to the schema to convert + it to a signature: + - The overload name is stripped (name is retained, since + it expresses semantic content about what the function does) + - Inplace is set False + - Out arguments are stripped + - Mutable post_self_positional args are converted to returns + - Mutability annotations are stripped (this is sound + because you cannot overload on mutability annotation) + - Return names are stripped since they are not overloadable and + some variants have return names but some not + - TensorOptions are dropped + because out= variants of factory functions don't include them + (and we want to be able to pair up factory functions with their out variants) - Finally, we want to be able to pair up related "view" and their - corresponding "view_copy" operators. We do this by optionally - stripping the trailing "_copy" from the base name. + Finally, we want to be able to pair up related "view" and their + corresponding "view_copy" operators. We do this by optionally + stripping the trailing "_copy" from the base name. + + Example of a mutable op before and after: + + f.func (Mutable operator): + _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) 
zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 + + f.func (Corresponding functional operator): + _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950 + + f.func.signature() output: + _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950 """ def strip_ret_annotation(r: Return) -> Return: return Return( - name=None, + name=r.name if keep_return_names else None, type=r.type, annotation=None, ) @@ -1185,6 +1334,43 @@ class FunctionSchema: if strip_view_copy_name and base_name.endswith("_copy"): base_name = base_name.replace("_copy", "") + # find mutable inputs that are not originally returned, and convert them to returns + returns_from_mutable_inputs = tuple( + # When we're grouping functions we strip the return names, + # but when we're generating the actual functional variants then we follow + # a convention for what to name the returns + Return( + name=f"{a.name}_out" if keep_return_names else None, + type=a.type, + annotation=None, + ) + for a in itertools.chain( + # Order is important here (otherwise e.g. inplace with mutable args + # and out= with mutable args won't have the same signature) + [self.arguments.self_arg.argument] + if self.arguments.self_arg is not None + else [], + self.arguments.out, + self.arguments.post_self_positional, + ) + if a.annotation is not None + and a.annotation.is_write + and not any(a.annotation == r.annotation for r in self.returns) + ) + original_returns = tuple(map(strip_ret_annotation, self.returns)) + # Ordering is important here. We expect the "mutable input" returns to come last. 
+ returns = original_returns + returns_from_mutable_inputs + + args_sig = self.arguments.signature(strip_default=strip_default) + # See Note [arange.start_step schema] + if str(self.name) == "arange.start_step": + args_sig = Arguments.parse( + str(args_sig).replace("Scalar step", "Scalar step=1") + ) + # See Note [bernoulli.p schema] + if str(self.name) == "bernoulli.p": + args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5")) + return FunctionSchema( name=OperatorName( name=BaseOperatorName( @@ -1194,16 +1380,23 @@ class FunctionSchema: ), overload_name="", # stripped ), - arguments=self.arguments.signature(strip_default=strip_default), - returns=tuple(map(strip_ret_annotation, self.returns)), + arguments=args_sig, + returns=returns, ) def view_signature(self) -> "FunctionSchema": return self.signature(strip_view_copy_name=True) + def with_name(self, name: "OperatorName") -> "FunctionSchema": + return FunctionSchema( + name=name, + arguments=self.arguments, + returns=self.returns, + ) + @property def modifies_arguments(self) -> bool: - return self.kind() in [SchemaKind.inplace, SchemaKind.out] + return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable] def __str__(self) -> str: all_arguments_str = str(self.arguments) @@ -1587,6 +1780,10 @@ class Arguments: ret.extend(self.post_self_positional) return ret + @property + def post_self_positional_mutable(self) -> Sequence[Argument]: + return [a for a in self.post_self_positional if a.is_write] + # NB: doesn't contain out arguments @property def flat_kwarg_only(self) -> Sequence[Argument]: @@ -1640,6 +1837,13 @@ class Arguments: ret.extend(self.out) return ret + def mutable_arg_names(self) -> List[str]: + return [ + a.name + for a in self.flat_all + if a.annotation is not None and a.annotation.is_write + ] + def signature(self, *, strip_default: bool = False) -> "Arguments": # dataclasses.replace could be used here, but it is less # type safe so for now I've opted to type everything out @@ -1661,18 +1865,36 @@ class Arguments: post_self_positional=tuple( map(strip_arg_annotation, self.post_self_positional) ), + # Since TensorOptions are droped, the post_tensor_options_kwargs are + # converted to pre_tensor_options_kwargs pre_tensor_options_kwarg_only=tuple( map(strip_arg_annotation, self.pre_tensor_options_kwarg_only) - ), - # NB: tensor_options guaranteed to not have any alias annotations - tensor_options=self.tensor_options, - post_tensor_options_kwarg_only=tuple( - map(strip_arg_annotation, self.post_tensor_options_kwarg_only) - ), + ) + + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)), + # TensorOptions are dropped in signature, + # so we can pair factory functions with their out= variants. 
+ tensor_options=None, + post_tensor_options_kwarg_only=tuple(), # out arguments are dropped in signature out=(), ) + def remove_self_annotation(self) -> "Arguments": + assert self.self_arg is not None + return dataclasses.replace( + self, + self_arg=SelfArgument( + dataclasses.replace(self.self_arg.argument, annotation=None) + ), + ) + + def with_out_args(self, outs: List[Argument]) -> "Arguments": + assert len(self.out) == 0 + return dataclasses.replace( + self, + out=tuple(outs), + ) + @staticmethod def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]: positional: List[Argument] = [] @@ -1806,6 +2028,17 @@ class Arguments: if self.tensor_options is None: assert not self.post_tensor_options_kwarg_only + # We don't allow any of the following to have argument annotations, + # to keep things simple. + mutable_pre_self_positionals = [ + a + for a in self.pre_self_positional + if a.annotation is not None and a.annotation.is_write + ] + assert ( + len(mutable_pre_self_positionals) == 0 + ), "mutable pre_self_positional arguments are not currently supported in the schema" + # Names that validly are __iXXX__ indicating inplace operations. # Taken from https://www.python.org/dev/peps/pep-0203/#new-methods @@ -1923,6 +2156,16 @@ class OperatorName: overload_name=self.overload_name, ) + def with_overload(self, overload: str) -> "OperatorName": + return OperatorName( + name=BaseOperatorName( + base=self.name.base, + inplace=False, + dunder_method=self.name.dunder_method, + ), + overload_name=overload, + ) + def gets_generated_out_inplace_wrapper( f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py new file mode 100644 index 000000000000..0fdbb3b0ec50 --- /dev/null +++ b/torchgen/native_function_generation.py @@ -0,0 +1,382 @@ +from torchgen.model import ( + Argument, + DispatchKey, + FunctionSchema, + BaseType, + BaseTy, + Return, + Annotation, + NativeFunction, + OperatorName, + BackendIndex, + BackendMetadata, + DeviceCheckType, + SchemaKind, + Variant, +) +from torchgen.utils import ( + concatMap, +) + + +from typing import List, Tuple, Sequence, Dict +from collections import defaultdict + +# See Note: [Out ops with functional variants that don't get grouped properly] +OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ + # This has a functional variant, but it's currently marked private. + # This function should be marked private as well (*_backward ops aren't exposed to python anyway). + "adaptive_avg_pool3d_backward.grad_input", + # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly. + # Maybe we can kill this operator in favor of convolution_backward? + "_slow_conv2d_backward.grad_input", +] + + +# See Note: [Mutable ops that cannot get an out variant] +MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ + # should be out=? + "_cummax_helper", + # should be out=? + "_cummin_helper", +] + +INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ + # polygamma and polygamma.out both exist, but have a + # pre-self arg (while polygamma_ does not) + # We should either fix this schema so it can be grouped properly, + # or allow the codegen to generate new functional/out= NativeFunctions for this op + # (which would require changing its overload name to prevent overload ambiguity). 
+ "polygamma_" +] + +# Groups "similar" NativeFunctions together +# example add.Tensor, add_.Tensor, add.out +# "similar" NativeFunctions are all expected to have an identical `signature()`, +# But have differing SchemaKinds. +def pre_group_native_functions( + native_functions: Sequence[NativeFunction], +) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]: + pre_grouped_native_functions: Dict[ + FunctionSchema, Dict[SchemaKind, NativeFunction] + ] = defaultdict(dict) + for f in native_functions: + d = pre_grouped_native_functions[f.func.signature()] + assert f.func.kind() not in d + d[f.func.kind()] = f + return pre_grouped_native_functions + + +# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant +# Example before: +# _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) +# Example after: +# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) +def self_to_out_signature(func: FunctionSchema) -> FunctionSchema: + # Generating an out= schema from an inplace schema. + assert func.kind() == SchemaKind.inplace + assert func.arguments.self_arg is not None + # The new out= schema has: + # - a new out argument with the same type as "func" (but with a mutable annotation) + # - The returns (if any) now alias the out= argument instead of "func" + # - an "out" overload name + return FunctionSchema( + name=func.name.remove_inplace().with_overload( + "out" if not func.name.overload_name else f"{func.name.overload_name}_out" + ), + arguments=func.arguments.remove_self_annotation().with_out_args( + [ + Argument( + name="out", + type=func.arguments.self_arg.argument.type, + default=None, + annotation=func.arguments.self_arg.argument.annotation, + ) + ] + ), + returns=func.returns, + ) + + +# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant +# Example before: +# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 +# Example after: +# _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950 +def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema: + # Generating an out= schema from a mutable schema. + assert func.kind() == SchemaKind.mutable + # The new out= schema has: + # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments + # (if the argument is a tensor then we also return it for method chaining, + # otherwise we return nothing) + # - an "out" overload name + # + # Note that: + # (1) This also means that we can *only* generate an out= variant from a mutable schema + # if the mutable schema has at least one tensor-like non-aliasing return. 
+ # (2) The generated out= variant still has mutable positional arguments, + # but if necessary we could probably add another out= variant that also + # functionalizes the mutable arguments (a functional_out variant) + + # More of a sanity check - our existing restrictions on schemas should enforce that + # mutable schema kinds never return their mutable arguments. + assert not any( + r.annotation is not None and r.annotation.is_write for r in func.returns + ) + + tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()] + assert len(tensorlike_rets) > 0 + + used_annotations = concatMap( + lambda a: [] if a.annotation is None else a.annotation.alias_set, + func.arguments.flat_all, + ) + valid_annotations = [ + x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations + ] + + all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns) + + new_out_args: List[Argument] = [] + # The end result of new_returns is that: + # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added. + # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any). + new_returns: List[Return] = [] + for (i, r) in enumerate(func.returns): + if r.type.is_tensor_like(): + new_out = Argument( + name=f"out{i}", + type=r.type, + default=None, + annotation=Annotation.parse(f"{valid_annotations[i]}!"), + ) + new_out_args.append(new_out) + if all_rets_are_tensors: + # The convention for out= schemas is that they only return their out arguments + # if the return is a plain Tensor (or if it's a tuple of plain Tensors) + new_ret = Return( + name=None, type=new_out.type, annotation=new_out.annotation + ) + new_returns.append(new_ret) + else: + new_returns.append(r) + + return FunctionSchema( + name=func.name.remove_inplace().with_overload( + "out" if not func.name.overload_name else f"{func.name.overload_name}_out" + ), + arguments=func.arguments.with_out_args(new_out_args), + returns=tuple(new_returns), + ) + + +# This function, given function of one SchemaKind, as well as a target SchemaKind, +# generates a new NativeFunction with the same properties, but using the target SchemaKind. +# We only actually generate functions for either functional or out= SchemaKinds. +# This function returns a tuple, with: +# - The generated NativeFunction +# - a dictionary of `BackendIndex` objects, describing which dispatch keys +# we will generate kernels for, for the new NativeFunction. +# Details are in the function, but we only generate composite kernels (in some cases) today. +def generate_function( + f: NativeFunction, k: SchemaKind +) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]: + from torchgen.api import cpp + + if k == SchemaKind.functional: + assert f.func.kind() != SchemaKind.functional + gets_composite_kernel = True + # The new "functional" NativeFunction has: + # - any mutable arguments have been converted into (immutable) returns. + # (if a mutable argument was not also a return, it gets converted to one) + # - a "functional" overload name. 
+ # The default grouping logic in signature() actually already does this, + # so we can piggy-back off it (but we still want return names) + func = f.func.signature(keep_return_names=True).with_name( + f.func.name.remove_inplace().with_overload( + "functional" + if not f.func.name.overload_name + else f"{f.func.name.overload_name}_functional" + ) + ) + elif k == SchemaKind.out: + # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily, + # but at least today, there is no good reason to actually use them. + # we'll generate a dispatcher entry for them, but won't actually register any kernels for them. + gets_composite_kernel = False + if f.func.kind() == SchemaKind.inplace: + func = self_to_out_signature(f.func) + elif f.func.kind() == SchemaKind.mutable: + func = mutable_to_out_signature(f.func) + else: + raise AssertionError( + "We only bother generating out= functions from either inplace or mutable variants" + ) + else: + raise AssertionError( + "We currently only generate either functional or out= NativeFunctions" + ) + + if gets_composite_kernel: + backend_metadata = { + DispatchKey.CompositeExplicitAutograd: { + func.name: BackendMetadata(cpp.name(func), structured=False) + } + } + else: + backend_metadata = {} + + return ( + NativeFunction( + func=func, + use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, + # These generated fn's aren't meant to be user friendly- don't generate methods. + variants=set([Variant.function]), + structured=False, + structured_delegate=None, + structured_inherits=None, + precomputed=None, + autogen=[], + ufunc_inner_loop={}, + manual_kernel_registration=False, + manual_cpp_binding=False, + python_module=None, + category_override=None, + device_guard=False, + device_check=DeviceCheckType.NoCheck, + loc=f.loc, + cpp_no_default_args=set(), + is_abstract=f.is_abstract, + has_composite_implicit_autograd_kernel=False, + has_composite_explicit_autograd_kernel=gets_composite_kernel, + # Every generated NativeFunction gets a "generated" tag, so it's easy to tell + # which NativeFunction objects did not come directly from native_functions.yaml. + tags=set(["generated"]), + ), + backend_metadata, + ) + + +# This function is responsible for adding generated NativeFunctions which don't appear +# explicitly in the codegen. +# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running +# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml") +# (Maybe we should make a friendly API for this) +# +# Note: this function *mutates* its two inputs, +# adding the new NativeFunctions / BackendMetadata to them +def add_generated_native_functions( + rs: List[NativeFunction], + indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], +) -> None: + # The main code for gnerating new NativeFunctions + # First we group of NaitveFunctions by schema kind, + # then we detect which ones are missing and generate them. 
+    pre_grouped_native_functions = pre_group_native_functions(rs)
+    for k, d in pre_grouped_native_functions.items():
+        has_functional = SchemaKind.functional in d
+        has_inplace = SchemaKind.inplace in d
+        has_mutable = SchemaKind.mutable in d
+        has_out = SchemaKind.out in d
+
+        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
+        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
+        #     a simple functional variant that the functionalization pass can consume.
+        # (2) If an operator has an inplace and functional but no out= variant, we generate an out=
+        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
+        #     while maintaining the constraint that the out= variant is "required".
+        #
+        # For now, we don't bother generating NativeFunctions for existing operators
+        # that only have a functional variant.
+        if has_mutable or has_inplace or has_out:
+
+            # Don't bother generating function trios for native functions that bypass the dispatcher.
+            are_manual = all(f.manual_cpp_binding for f in d.values())
+            # Don't bother generating functional + out= variants for view operators
+            has_view_ops = (
+                has_inplace and "inplace_view" in d[SchemaKind.inplace].tags
+            ) or any(f.is_view_op for f in d.values())
+            # Don't generate the other variants for CompositeImplicitAutograd operators.
+            # We could probably do this, but the main benefit of generating the function triplets
+            # is for transforms that need them, and transforms don't need to act directly
+            # on CompositeImplicitAutograd operators (since we let them decompose).
+            are_composite_implicit = all(
+                f.has_composite_implicit_autograd_kernel for f in d.values()
+            )
+            if are_manual or has_view_ops or are_composite_implicit:
+                continue
+            if has_out and len(d.values()) == 1:
+                # Note: [Out ops with functional variants that don't get grouped properly]
+                # In theory we could validly have an out= operator in native_functions.yaml
+                # that has no other variants.
+                # But today, all of the operators where that's the case actually do have
+                # functional variants that we are just unable to pair up properly.
+                # I think banning this altogether is probably safer
+                # (you can always add a functional variant yourself if you want to add a new out= operator).
+                #
+                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
+                if (
+                    str(d[SchemaKind.out].func.name)
+                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+                ):
+                    raise AssertionError(
+                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
+                    )
+                continue
+
+            # Some inplace ops that have problematic schemas (that we should fix), which prevent us
+            # from generating out= and functional variants
+            if (
+                has_inplace
+                and str(d[SchemaKind.inplace].func.name)
+                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+            ):
+                continue
+
+            base_fn = (
+                d[SchemaKind.inplace]
+                if has_inplace
+                else d[SchemaKind.mutable]
+                if has_mutable
+                else d[SchemaKind.out]
+            )
+
+            # Note: [Mutable ops that cannot get an out variant]
+            # We can only generate an out= variant if either:
+            # - the original function has tensor-like returns (since we can convert them to out kwargs)
+            # - or it's inplace (since we can convert `self` to an out kwarg)
+            # There are only two functions that don't fit these criteria today, though,
+            # and they both look like they should be fixed to be out= variants,
+            # so it feels safer to ban this schema altogether.
+            gets_out_variant = not has_out and (
+                base_fn.func.kind() == SchemaKind.inplace
+                or any(r.type.is_tensor_like() for r in base_fn.func.returns)
+            )
+            if not has_out and not gets_out_variant:
+                if (
+                    str(base_fn.func.name)
+                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+                ):
+                    raise AssertionError(
+                        f"""Found a mutable operator that we could not generate an out= variant for: {str(base_fn.func)}.
+These operators are problematic, because we can't easily auto-generate functionalization code for them. If you really need
+the operator to have this schema, add the name of the operator to the allow-list. Otherwise, if possible,
+please convert it to an inplace operator."""
+                    )
+
+            # Generate an out= variant
+            if gets_out_variant:
+                fn, metadata = generate_function(base_fn, SchemaKind.out)
+                d[SchemaKind.out] = fn
+                BackendIndex.grow_index(indices, metadata)
+                rs.append(fn)
+
+            # Generate a functional variant, but only do it if the operator got an out= variant
+            # (Functional variants are only useful if we can group up the variants,
+            # which we can only do if they have an out= variant)
+            if not has_functional and (has_out or gets_out_variant):
+                fn, metadata = generate_function(base_fn, SchemaKind.functional)
+                d[SchemaKind.functional] = fn
+                BackendIndex.grow_index(indices, metadata)
+                rs.append(fn)
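To see the end-to-end effect of this module, the sketch below feeds a hypothetical inplace-only schema through the helpers defined above. It assumes torchgen with this patch applied; the operator name my_norm_ does not exist and is used purely for illustration:

from torchgen.model import FunctionSchema
from torchgen.native_function_generation import self_to_out_signature

inplace = FunctionSchema.parse("my_norm_(Tensor(a!) self, Scalar alpha=1) -> Tensor(a!)")

# Generated out= counterpart: `self` loses its annotation and a mutable out kwarg is added,
# roughly: my_norm.out(Tensor self, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
print(self_to_out_signature(inplace))

# Generated functional counterpart (registered under a "functional" overload by
# generate_function), roughly: my_norm(Tensor self, Scalar alpha=1) -> Tensor
print(inplace.signature(keep_return_names=True))

add_generated_native_functions then appends both variants as NativeFunctions tagged "generated", so they can be grouped into a NativeFunctionsGroup and consumed by the functionalization pass.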