Add C shim fallback for fill_ (#156245)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156245
Approved by: https://github.com/desertfire
Author: Jane Xu
Date: 2025-06-20 08:10:32 -07:00
Committed by: PyTorch MergeBot
Parent: 208ec60e72
Commit: 4b6cbf528b

5 changed files with 16 additions and 0 deletions
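For context, a minimal end-to-end sketch of what the new fallback enables (illustrative only, not part of the diff): in ABI-compatible mode, AOTInductor-generated code may only reach ATen through stable C shim entry points, so an op without a shim entry cannot be compiled. With `aten.fill_.Scalar` registered as a fallback below, a model that calls the in-place op can go through the standard export-and-compile flow. `torch.export.export` and `torch._inductor.aoti_compile_and_package` are the usual AOTInductor entry points; the toy model mirrors the new test.

```python
import torch

class Fill(torch.nn.Module):
    def forward(self, inp: torch.Tensor, scalar: float):
        # In-place fill; AOTInductor now lowers this to an extern-kernel
        # call through aoti_torch_{cpu,cuda,mps}_fill__Scalar.
        torch.ops.aten.fill_(inp, scalar)
        return inp

ep = torch.export.export(Fill(), (torch.rand(3, 3, 4, 2), 0.5))
pkg_path = torch._inductor.aoti_compile_and_package(ep)
```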

@@ -3662,6 +3662,18 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), inputs)
 
+    def test_fill__fallback(self):
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, inp: torch.Tensor, scalar: float):
+                torch.ops.aten.fill_(inp, scalar)
+                return inp
+
+        inputs = (torch.rand((3, 3, 4, 2), device=self.device), 0.5)
+        self.check_model(Model(), inputs)
+
     @common_utils.parametrize("embed_kernel_binary", [False, True])
     def test_repeated_user_defined_triton_kernel(self, embed_kernel_binary):
         if self.device != GPU_TYPE:
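Continuing the sketch above (again an illustration, assuming the packaging APIs of this vintage): loading the compiled artifact and checking that the shim-backed `fill_` matches eager semantics, which is essentially what `check_model` does inside the test.

```python
compiled = torch._inductor.aoti_load_package(pkg_path)

x = torch.rand(3, 3, 4, 2)
out = compiled(x.clone(), 0.5)
# fill_ overwrites every element with the scalar, so the result should
# equal an eager torch.full of the same shape and dtype.
assert torch.equal(out, torch.full((3, 3, 4, 2), 0.5))
```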

@@ -73,6 +73,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, in
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

@@ -80,6 +80,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, i
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

@@ -49,6 +49,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummin(AtenTensorHandle self, in
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

@@ -91,6 +91,7 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
     "aten.cumprod.default": {},
     "aten.cumsum.default": {},
     "aten.exponential.default": {},
+    "aten.fill_.Scalar": {},
     "aten.fractional_max_pool2d_backward.default": {},
     "aten.fractional_max_pool2d.default": {},
     "aten.fractional_max_pool3d_backward.default": {},