Add skip_dtype_check_in_meta_registrations config to torch/fx/experimental/_config (#153513)

Helion relies on torch/fx/experimental's fake_tensor tracing but does its own dtype checking, which conflicts with some meta kernels' existing dtype checking. This PR adds a config so that we skip dtype checking in those meta kernels and rely on the calling system to do the dtype checking.

Currently it only applies to `baddbmm`, but I expect that similar changes will need to be done to other meta kernels in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153513
Approved by: https://github.com/jansel
This commit is contained in:
Will Feng
2025-05-14 05:44:31 +00:00
committed by PyTorch MergeBot
parent 4015166e5d
commit 0139ce9303
2 changed files with 8 additions and 4 deletions

View File

@@ -2174,10 +2174,11 @@ def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
self = self.expand((dim1, dim2, dim3)) self = self.expand((dim1, dim2, dim3))
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
torch._check( if not exp_config.skip_dtype_check_in_meta_registrations:
self.dtype == batch1.dtype == batch2.dtype, torch._check(
lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", self.dtype == batch1.dtype == batch2.dtype,
) lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
)
batch1_sizes = batch1.shape batch1_sizes = batch1.shape
batch2_sizes = batch2.shape batch2_sizes = batch2.shape
bs = batch1_sizes[0] bs = batch1_sizes[0]

View File

@@ -97,6 +97,9 @@ meta_nonzero_assume_all_nonzero = False
# Currently an experimental option for export. # Currently an experimental option for export.
backed_size_oblivious = False backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that do their own dtype checking.
skip_dtype_check_in_meta_registrations = False
from torch.utils._config_module import install_config_module from torch.utils._config_module import install_config_module