From 0139ce93036c4be05bf215fef184607994b6b110 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 14 May 2025 05:44:31 +0000 Subject: [PATCH] Add skip_dtype_check_in_meta_registrations config to torch/fx/experimental/_config (#153513) Helion relies on torch/fx/experimental's fake_tensor tracing but does its own dtype checking, which conflicts with some meta kernels' existing dtype checking. This PR adds a config so that we skip those dtype checks in meta kernels and rely on the calling system to do the dtype checking. Currently it only applies to `baddbmm`, but I expect that similar changes will need to be done to other meta kernels in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153513 Approved by: https://github.com/jansel --- torch/_meta_registrations.py | 9 +++++---- torch/fx/experimental/_config.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 261a88a9b59a..e968204a3d70 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2174,10 +2174,11 @@ def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): self = self.expand((dim1, dim2, dim3)) torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") - torch._check( - self.dtype == batch1.dtype == batch2.dtype, - lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", - ) + if not exp_config.skip_dtype_check_in_meta_registrations: + torch._check( + self.dtype == batch1.dtype == batch2.dtype, + lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", + ) batch1_sizes = batch1.shape batch2_sizes = batch2.shape bs = batch1_sizes[0] diff --git a/torch/fx/experimental/_config.py b/torch/fx/experimental/_config.py index 58859607eee2..ce4296b6410c 100644 --- 
a/torch/fx/experimental/_config.py +++ b/torch/fx/experimental/_config.py @@ -97,6 +97,9 @@ meta_nonzero_assume_all_nonzero = False # Currently an experimental option for export. backed_size_oblivious = False +# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. +skip_dtype_check_in_meta_registrations = False + from torch.utils._config_module import install_config_module