Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add skip_dtype_check_in_meta_registrations config to torch/fx/experimental/_config (#153513)
Helion relies on torch/fx/experimental's fake_tensor tracing but does its own dtype checking, which conflicts with the existing dtype checks in some meta kernels. This PR adds a config so that we can skip those dtype checks in meta kernels and rely on the calling system to do the dtype checking. Currently it only applies to `baddbmm`, but I expect similar changes will be needed for other meta kernels in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153513
Approved by: https://github.com/jansel
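As a rough illustration of that division of labor (this is not code from the PR), here is a minimal sketch in which the calling system sets the new flag and then traces `baddbmm` under `FakeTensorMode` while taking responsibility for dtype validation itself. The tensor shapes and the mixed-dtype combination are assumptions made for the example:

```python
import torch
import torch.fx.experimental._config as exp_config
from torch._subclasses.fake_tensor import FakeTensorMode

# The calling system owns dtype validation, so it opts out of the meta
# kernel's check (the flag defaults to False).
exp_config.skip_dtype_check_in_meta_registrations = True

with FakeTensorMode():
    inp = torch.empty(2, 3, 5, dtype=torch.float32)
    b1 = torch.empty(2, 3, 4, dtype=torch.float16)  # dtype differs from inp
    b2 = torch.empty(2, 4, 5, dtype=torch.float16)
    # With the flag set, meta_baddbmm skips its dtype-equality check and only
    # validates shapes; the caller is expected to have checked dtypes already.
    out = torch.baddbmm(inp, b1, b2)

print(out.shape)  # expected: torch.Size([2, 3, 5])
```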
Committed by: PyTorch MergeBot
Parent: 4015166e5d
Commit: 0139ce9303
In the `meta_baddbmm` meta registration, the existing dtype check is wrapped in the new config guard:

```diff
@@ -2174,10 +2174,11 @@ def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
     self = self.expand((dim1, dim2, dim3))
     torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
     torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
-    torch._check(
-        self.dtype == batch1.dtype == batch2.dtype,
-        lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
-    )
+    if not exp_config.skip_dtype_check_in_meta_registrations:
+        torch._check(
+            self.dtype == batch1.dtype == batch2.dtype,
+            lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
+        )
     batch1_sizes = batch1.shape
     batch2_sizes = batch2.shape
     bs = batch1_sizes[0]
```
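For readers unfamiliar with the pattern above: `torch._check(cond, lambda: msg)` raises a `RuntimeError` when the condition is false, and the lambda defers building the f-string until a failure actually happens. A small standalone sketch with made-up tensors:

```python
import torch

inp = torch.empty(2, 3, 5, dtype=torch.float32)
b1 = torch.empty(2, 3, 4, dtype=torch.float16)  # deliberately mismatched
b2 = torch.empty(2, 4, 5, dtype=torch.float16)

# Mirrors the guarded check above, with the flag hard-coded for illustration.
skip_dtype_check = False
if not skip_dtype_check:
    torch._check(
        inp.dtype == b1.dtype == b2.dtype,
        lambda: f"Input dtypes must be the same, got: input: {inp.dtype}, "
        f"batch1: {b1.dtype}, batch2: {b2.dtype}",
    )
# -> RuntimeError: Input dtypes must be the same, got: input: torch.float32, ...
```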
In `torch/fx/experimental/_config.py`, the new flag defaults to `False`:

```diff
@@ -97,6 +97,9 @@ meta_nonzero_assume_all_nonzero = False
 # Currently an experimental option for export.
 backed_size_oblivious = False
 
+# Skip dtype checks in meta registrations. Only used by systems that do their own dtype checking.
+skip_dtype_check_in_meta_registrations = False
+
 from torch.utils._config_module import install_config_module
```
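Because the module is installed via `torch.utils._config_module`, the new flag behaves like other PyTorch config knobs. A hedged sketch of toggling it; the `.patch()` context manager is an assumption based on how `install_config_module` behaves for config modules such as `torch._dynamo.config`:

```python
import torch.fx.experimental._config as exp_config

# Option 1: flip the flag for the whole process (it defaults to False).
exp_config.skip_dtype_check_in_meta_registrations = True
exp_config.skip_dtype_check_in_meta_registrations = False  # restore default

# Option 2 (assumed API): scope the override with the .patch() context manager
# that install_config_module provides for PyTorch config modules.
with exp_config.patch(skip_dtype_check_in_meta_registrations=True):
    pass  # fake-tensor tracing that relies on caller-side dtype checks goes here
# Outside the block the flag is back to its default, False.
```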