Add C shim fallback for fill_ (#156245)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156245
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 208ec60e72
Commit: 4b6cbf528b
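For orientation, a minimal sketch of the behavior this commit enables. The Fill module and driver code below are illustrative, not part of the diff; torch.export.export, torch._inductor.aoti_compile_and_package, and torch._inductor.aoti_load_package are the standard AOTInductor entry points assumed to be available in a recent build:

import torch

class Fill(torch.nn.Module):
    def forward(self, inp: torch.Tensor):
        # In-place scalar fill; with this commit the AOTInductor-generated
        # wrapper can dispatch it to the new aoti_torch_*_fill__Scalar shim.
        torch.ops.aten.fill_(inp, 0.5)
        return inp

ep = torch.export.export(Fill(), (torch.rand(3, 3, 4, 2),))
pkg = torch._inductor.aoti_compile_and_package(ep)
runner = torch._inductor.aoti_load_package(pkg)
print(runner(torch.rand(3, 3, 4, 2)))  # a tensor filled with 0.5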
@@ -3662,6 +3662,18 @@ class AOTInductorTestsTemplate:
 
         self.check_model(Model(), inputs)
 
+    def test_fill__fallback(self):
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, inp: torch.Tensor, scalar: float):
+                torch.ops.aten.fill_(inp, scalar)
+                return inp
+
+        inputs = (torch.rand((3, 3, 4, 2), device=self.device), 0.5)
+        self.check_model(Model(), inputs)
+
     @common_utils.parametrize("embed_kernel_binary", [False, True])
     def test_repeated_user_defined_triton_kernel(self, embed_kernel_binary):
         if self.device != GPU_TYPE:
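A note on the test's shape: fill_ mutates its argument and returns it, which is why forward() can return inp directly. A stock-PyTorch sketch of that contract (independent of this commit):

import torch

t = torch.rand(3, 3, 4, 2)
out = torch.ops.aten.fill_(t, 0.5)      # mutates t in place and returns it
assert torch.all(t == 0.5)              # the input itself was overwritten
assert out.data_ptr() == t.data_ptr()   # the result aliases t's storage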
@@ -73,6 +73,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, in
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@@ -80,6 +80,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, i
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
@@ -49,6 +49,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummin(AtenTensorHandle self, in
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
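The three new declarations share one shape worth noting: unlike the functional ops around them (cumsum, exponential, ...), fill_ mutates self in place, so the shim takes only the tensor handle plus the scalar, passed as a plain double in the usual C-shim convention for Scalar arguments, and needs no ret0 out-parameter.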
@@ -91,6 +91,7 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
     "aten.cumprod.default": {},
     "aten.cumsum.default": {},
     "aten.exponential.default": {},
+    "aten.fill_.Scalar": {},
     "aten.fractional_max_pool2d_backward.default": {},
     "aten.fractional_max_pool2d.default": {},
     "aten.fractional_max_pool3d_backward.default": {},
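This table entry is what drives the header changes above: the C-shim generator in torchgen walks inductor_fallback_ops and emits one aoti_torch_<device>_* declaration per entry. A quick check, assuming the table still lives at torchgen.aoti.fallback_ops:

from torchgen.aoti.fallback_ops import inductor_fallback_ops

# An empty inner dict is the common case; non-empty entries carry
# per-op metadata for the generator.
assert "aten.fill_.Scalar" in inductor_fallback_ops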