Specialize symfloats during equality checks (#140830)

Fixes `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=0 python
    test/inductor/test_torchinductor_opinfo.py
    TestInductorOpInfoCPU.test_comprehensive_nn_functional_local_response_norm_cpu_float32`
    when `specialize_float=False`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140830
Approved by: https://github.com/ezyang
This commit is contained in:
Bob Ren
2024-11-16 18:56:23 -08:00
committed by PyTorch MergeBot
parent 6094f17ada
commit 602ae9cbcf

View File

@ -255,8 +255,7 @@ class C10_API Scalar {
auto val = v.z;
return (val.real() == num) && (val.imag() == T());
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
return v.d == num;
return toDouble() == num;
} else if (tag == Tag::HAS_i) {
if (overflows<T>(v.i, /* strict_unsigned */ true)) {
return false;
@ -288,8 +287,7 @@ class C10_API Scalar {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return v.z == num;
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
return (v.d == num.real()) && (num.imag() == T());
return (toDouble() == num.real()) && (num.imag() == T());
} else if (tag == Tag::HAS_i) {
if (overflows<T>(v.i, /* strict_unsigned */ true)) {
return false;