From 216b6a952c4c4457e820235e9f987b1e063d6e88 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 11 Nov 2024 21:01:40 -0300 Subject: [PATCH] `triangular_solve`: fix meta function output argument dtype check. (#140286) Tracking issue: #138399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140286 Approved by: https://github.com/ezyang ghstack dependencies: #140186 --- test/test_ops.py | 1 - torch/_meta_registrations.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 7808222ae9d7..4035e3b13a30 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -222,7 +222,6 @@ meta_consistency_out_dtype_mismatch_xfails = { xfail("take"), xfail("transpose_copy"), xfail("tril"), - xfail("triangular_solve"), xfail("triu"), xfail("trunc"), xfail("unfold_copy"), diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index b14b57ba4d1d..9a15a6bd71aa 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1495,7 +1495,7 @@ def linalg_solve_triangular_meta( @register_meta(aten.triangular_solve) -@out_wrapper("X", "M") +@out_wrapper("X", "M", exact_dtype=True) def triangular_solve_meta( self: Tensor, A: Tensor,