fix incorrect c10::SymFloat::sqrt (#141728)

Fixes the silent correctness for SDPA in https://github.com/pytorch/pytorch/issues/141710

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141728
Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/drisspg
ghstack dependencies: #141725
This commit is contained in:
Brian Hirsh
2024-12-03 10:49:40 -08:00
committed by PyTorch MergeBot
parent af3e7389ef
commit 20912ba582
2 changed files with 2 additions and 2 deletions

View File

@ -146,7 +146,7 @@ SymFloat SymFloat::sqrt() const {
if (!is_symbolic()) { if (!is_symbolic()) {
return SymFloat(std::sqrt(data_)); return SymFloat(std::sqrt(data_));
} }
auto other = SymFloat(-0.5); auto other = SymFloat(0.5);
auto res = normalize_symfloats(*this, other); auto res = normalize_symfloats(*this, other);
return SymFloat(res[0]->pow(res[1])); return SymFloat(res[0]->pow(res[1]));
} }

View File

@ -6421,7 +6421,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
with torch._dynamo.config.patch(assume_static_by_default=False): with torch._dynamo.config.patch(assume_static_by_default=False):
out_ref = f(x_ref, s0, s1, s2) out_ref = f(x_ref, s0, s1, s2)
out = f_compiled(x, s0, s1, s2) out = f_compiled(x, s0, s1, s2)
self.assertFalse(torch.any(torch.isnan(out))) self.assertEqual(out_ref, out)
def test_bitwise_op_guard(self): def test_bitwise_op_guard(self):
# attempt evaluating a guard with BitwiseFn_bitwise_[and/or] # attempt evaluating a guard with BitwiseFn_bitwise_[and/or]