mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add skip_dtype_check_in_meta_registrations config to torch/fx/experimental/_config (#153513)
Helion relies on torch/fx/experimental 's fake_tensor tracing but does its own dtype checking, which conflicts with some meta kernel's existing dtype checking. This PR adds a config so that we skip those dtype checking in meta kernels and rely on the calling system to do the dtype checking. Currently it only applies to `baddbmm`, but I expect that similar changes will need to be done to other meta kernels in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153513 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
4015166e5d
commit
0139ce9303
@ -2174,10 +2174,11 @@ def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
|
|||||||
self = self.expand((dim1, dim2, dim3))
|
self = self.expand((dim1, dim2, dim3))
|
||||||
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
|
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
|
||||||
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
|
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
|
||||||
torch._check(
|
if not exp_config.skip_dtype_check_in_meta_registrations:
|
||||||
self.dtype == batch1.dtype == batch2.dtype,
|
torch._check(
|
||||||
lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
|
self.dtype == batch1.dtype == batch2.dtype,
|
||||||
)
|
lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
|
||||||
|
)
|
||||||
batch1_sizes = batch1.shape
|
batch1_sizes = batch1.shape
|
||||||
batch2_sizes = batch2.shape
|
batch2_sizes = batch2.shape
|
||||||
bs = batch1_sizes[0]
|
bs = batch1_sizes[0]
|
||||||
|
@ -97,6 +97,9 @@ meta_nonzero_assume_all_nonzero = False
|
|||||||
# Currently an experimental option for export.
|
# Currently an experimental option for export.
|
||||||
backed_size_oblivious = False
|
backed_size_oblivious = False
|
||||||
|
|
||||||
|
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
|
||||||
|
skip_dtype_check_in_meta_registrations = False
|
||||||
|
|
||||||
from torch.utils._config_module import install_config_module
|
from torch.utils._config_module import install_config_module
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user