Set simdlen based on ATEN_CPU_CAPABILITY (#123514)

It is part of https://github.com/pytorch/pytorch/issues/123224. Set simdlen based on the environment ATEN_CPU_CAPABILITY to control CPU vec ISA like eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123514
Approved by: https://github.com/jgong5, https://github.com/peterbell10
This commit is contained in:
CaoE
2024-06-06 07:52:26 -07:00
committed by PyTorch MergeBot
parent df43d5843e
commit b66e3f0957
6 changed files with 205 additions and 19 deletions

View File

@ -4,6 +4,7 @@ import copy
import functools
import itertools
import math
import os
import platform
import sys
import unittest
@ -66,12 +67,13 @@ aten = torch.ops.aten
check_model = test_torchinductor.check_model
requires_vectorization = unittest.skipUnless(
codecache.valid_vec_isa_list(), "Does not support vectorization"
codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default",
"Does not support vectorization",
)
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
if codecache.valid_vec_isa_list():
if codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default":
assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
@ -1580,6 +1582,71 @@ class CPUReproTests(TestCase):
metrics.reset()
self.common(fn, (value,))
@unittest.skipIf(
not codecache.valid_vec_isa_list()
or "avx2" in [str(vec_isa) for vec_isa in codecache.valid_vec_isa_list()],
"Does not support vectorization or not s390x/neon machine",
)
@patch("torch.cuda.is_available", lambda: False)
def test_auto_zvec_neon_simd(self):
vec_zvec_neon = codecache.valid_vec_isa_list()[0]
self.assertTrue(vec_zvec_neon.bit_width() == 256)
with config.patch({"cpp.simdlen": 0}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 1}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 257}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 256}):
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_zvec_neon)
pre_var = os.getenv("ATEN_CPU_CAPABILITY")
if pre_var:
os.environ.pop("ATEN_CPU_CAPABILITY")
try:
with config.patch({"cpp.simdlen": None}):
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_zvec_neon)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_zvec_neon)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_zvec_neon)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "default"
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "neon"
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_zvec_neon)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_zvec_neon)
finally:
if pre_var:
os.environ["ATEN_CPU_CAPABILITY"] = pre_var
elif os.getenv("ATEN_CPU_CAPABILITY"):
os.environ.pop("ATEN_CPU_CAPABILITY")
@unittest.skipIf(
platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
"Does not support vectorization or not x86_64 machine",
@ -1595,13 +1662,6 @@ class CPUReproTests(TestCase):
self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
with config.patch({"cpp.simdlen": None}):
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
self.assertTrue(isa == vec_avx2)
with config.patch({"cpp.simdlen": 0}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
@ -1631,6 +1691,60 @@ class CPUReproTests(TestCase):
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_avx2)
pre_var = os.getenv("ATEN_CPU_CAPABILITY")
if pre_var:
os.environ.pop("ATEN_CPU_CAPABILITY")
try:
with config.patch({"cpp.simdlen": None}):
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
self.assertTrue(isa == vec_avx2)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx2)
elif vec_avx2 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx2)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
self.assertTrue(isa == vec_avx2)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "default"
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "neon"
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
self.assertTrue(isa == vec_avx2)
with config.patch({"cpp.simdlen": None}):
os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
self.assertTrue(isa == vec_avx2)
finally:
if pre_var:
os.environ["ATEN_CPU_CAPABILITY"] = pre_var
elif os.getenv("ATEN_CPU_CAPABILITY"):
os.environ.pop("ATEN_CPU_CAPABILITY")
@requires_vectorization
@patch("torch.cuda.is_available", lambda: False)
def test_masked_fill_softmax(self):
@ -3371,6 +3485,7 @@ class CPUReproTests(TestCase):
self.common(m, (idx, x))
check_metrics_vec_kernel_count(1)
@requires_vectorization
def test_embedding_vec_bf16(self):
class M(torch.nn.Module):
def __init__(self):
@ -3655,7 +3770,7 @@ class CPUReproTests(TestCase):
x = torch.randint(0, 100, (819,), dtype=torch.int64)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
check_metrics_vec_kernel_count(1)
def test_reduction_float_to_int64(self):
# https://github.com/pytorch/pytorch/issues/124821
@ -3665,7 +3780,7 @@ class CPUReproTests(TestCase):
x = torch.randint(0, 100, (22, 51), dtype=torch.int64)
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
check_metrics_vec_kernel_count(1)
@config.patch({"cpp.dynamic_threads": True})
def test_reduction_with_dynamic_threads(self):

View File

@ -8,6 +8,7 @@ import torch
import torch._dynamo
import torch.utils.cpp_extension
from torch._C import FileCheck
from torch._dynamo.testing import expectedFailureScalar
try:
from extension_backends.cpp.extension_codegen_backend import (
@ -103,6 +104,9 @@ class ExtensionBackendTests(TestCase):
# return the working directory (see setUp)
os.chdir(self.old_working_dir)
# Fails when testing the scalar version
# See https://github.com/pytorch/pytorch/issues/126372.
@expectedFailureScalar
def test_open_device_registration(self):
torch.utils.rename_privateuse1_backend("extension_device")
torch._register_device_module("extension_device", self.module)

View File

@ -34,6 +34,7 @@ from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.testing import (
CompileCounterWithBackend,
expectedFailureCodegenDynamic,
expectedFailureScalar,
rand_strided,
same,
skipIfPy312,
@ -1315,6 +1316,9 @@ class CommonTemplate:
self.common(fn, (torch.randn(1024),))
# Fails when testing the scalar version
# See https://github.com/pytorch/pytorch/issues/128029.
@expectedFailureScalar
@skipIfRocm
@config.patch(debug_index_asserts=False)
def test_neg_index(self):
@ -1577,16 +1581,40 @@ class CommonTemplate:
def fn(a):
return torch.var(a)
self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)))
self.common(fn, ((torch.rand((14923), dtype=torch.float32),)))
atol = None
rtol = None
if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
atol = 1e-4
rtol = 1e-4
self.common(
fn,
((torch.rand((10, 3, 352, 352), dtype=torch.float32),)),
rtol=rtol,
atol=atol,
)
self.common(
fn, ((torch.rand((14923), dtype=torch.float32),)), rtol=rtol, atol=atol
)
@skipCPUIf(IS_MACOS, "fails on macos")
def test_multilayer_var_lowp(self):
def fn(a):
return torch.var(a)
self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),))
self.common(fn, (torch.rand((14923), dtype=torch.float16),))
atol = None
rtol = None
if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
atol = 1e-3
rtol = 1e-3
self.common(
fn,
(torch.rand((16, 16, 352, 352), dtype=torch.float16),),
rtol=rtol,
atol=atol,
)
self.common(
fn, (torch.rand((14923), dtype=torch.float16),), rtol=rtol, atol=atol
)
def test_split_cumsum(self):
def fn(a):
@ -8199,7 +8227,7 @@ class CommonTemplate:
rand_strided(shape, stride, dtype).requires_grad_(True).add(1)
for shape, stride, dtype in args
]
self.common(forward, args)
self.common(forward, args, atol=1e-05, rtol=1e-05)
@requires_gpu()
def test_tmp_not_defined_issue3(self):
@ -9281,6 +9309,7 @@ class CommonTemplate:
# To support this behavior, we need to allow const-propping tensors that store symint data.
# For now, dynamo will explicitly graph break when it encounters user code with this behavior.
@expectedFailureCodegenDynamic
@expectedFailureScalar
def test_AllenaiLongformerBase_repro(self):
def fn(query, scores, window_overlap):
batch_size, seq_len, num_heads, _ = query.size()
@ -9316,6 +9345,9 @@ class CommonTemplate:
opt_fn = torch._dynamo.optimize("inductor")(fn)
_, code = run_and_get_cpp_code(opt_fn, *args)
print(code)
# When testing the scalar version, i.e., ATEN_CPU_CAPABILITY=default,
# static_cast<int>(256) is not found, but static_cast<int64_t>(256).
# See https://github.com/pytorch/pytorch/issues/126262.
FileCheck().check_count(
"static_cast<int32_t>(256)",
1,

View File

@ -381,6 +381,12 @@ def expectedFailureDynamicWrapper(fn):
return fn
def expectedFailureScalar(fn):
if os.getenv("ATEN_CPU_CAPABILITY") == "default":
return unittest.expectedFailure(fn)
return fn
def reset_rng_state(use_xla=False):
torch.manual_seed(1337)
random.seed(1337)

View File

@ -1465,6 +1465,31 @@ invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
def get_isa_from_cpu_capability(
capability: str | None, vec_isa_list: List[VecISA], invalid_vec_isa: InvalidVecISA
):
# VSX is not supported in inductor
capability_to_isa_str = {
"default": "INVALID_VEC_ISA",
"neon": "asimd",
"zvector": "zvector",
"avx2": "avx2",
"avx512": "avx512",
}
if capability in capability_to_isa_str.keys():
isa_str = capability_to_isa_str[capability]
if isa_str == "INVALID_VEC_ISA":
return invalid_vec_isa
for vec_isa in vec_isa_list:
if isa_str == str(vec_isa):
return vec_isa
if capability:
warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}")
return vec_isa_list[0]
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@ -1507,10 +1532,13 @@ def pick_vec_isa() -> VecISA:
if not _valid_vec_isa_list:
return invalid_vec_isa
# If the simdlen is None, it indicates determine the vectorization length automatically
# If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY
# to control CPU vec ISA
if config.cpp.simdlen is None:
assert _valid_vec_isa_list
return _valid_vec_isa_list[0]
return get_isa_from_cpu_capability(
os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa
)
for isa in _valid_vec_isa_list:
if config.cpp.simdlen == isa.bit_width():

View File

@ -24,6 +24,7 @@
#include <c10/util/generic_math.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <ATen/native/Math.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
#define INDUCTOR_USE_VECTOR_TYPES() 1