mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Set simdlen based on ATEN_CPU_CAPABILITY (#123514)
This change is part of https://github.com/pytorch/pytorch/issues/123224. When `cpp.simdlen` is not set explicitly, Inductor now chooses the CPU vector ISA from the `ATEN_CPU_CAPABILITY` environment variable, so the vectorization level can be controlled in the same way as in eager mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123514 Approved by: https://github.com/jgong5, https://github.com/peterbell10
This commit is contained in:
@ -4,6 +4,7 @@ import copy
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import unittest
|
||||
@ -66,12 +67,13 @@ aten = torch.ops.aten
|
||||
check_model = test_torchinductor.check_model
|
||||
|
||||
requires_vectorization = unittest.skipUnless(
|
||||
codecache.valid_vec_isa_list(), "Does not support vectorization"
|
||||
codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default",
|
||||
"Does not support vectorization",
|
||||
)
|
||||
|
||||
|
||||
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
|
||||
if codecache.valid_vec_isa_list():
|
||||
if codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default":
|
||||
assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
|
||||
|
||||
|
||||
@ -1580,6 +1582,71 @@ class CPUReproTests(TestCase):
|
||||
metrics.reset()
|
||||
self.common(fn, (value,))
|
||||
|
||||
@unittest.skipIf(
    not codecache.valid_vec_isa_list()
    or "avx2" in [str(vec_isa) for vec_isa in codecache.valid_vec_isa_list()],
    "Does not support vectorization or not s390x/neon machine",
)
@patch("torch.cuda.is_available", lambda: False)
def test_auto_zvec_neon_simd(self):
    """Check ``pick_vec_isa`` on a machine whose valid ISA list has no avx2.

    Covers two ways of steering the choice:

    * an explicit ``cpp.simdlen``: only the exact bit width (256) selects
      the machine's ISA, every other value yields a falsy ISA;
    * ``cpp.simdlen = None``: the ``ATEN_CPU_CAPABILITY`` environment
      variable decides; ``"default"`` disables vectorization and every
      other value checked here resolves to the machine's first valid ISA.
    """
    vec_zvec_neon = codecache.valid_vec_isa_list()[0]
    self.assertEqual(vec_zvec_neon.bit_width(), 256)

    # Explicit simdlen values that do not match the ISA bit width.
    for simdlen in (0, 1, 257):
        with config.patch({"cpp.simdlen": simdlen}):
            self.assertFalse(codecache.pick_vec_isa())

    with config.patch({"cpp.simdlen": 256}):
        self.assertEqual(codecache.pick_vec_isa(), vec_zvec_neon)

    # (ATEN_CPU_CAPABILITY value, whether a vector ISA is expected), in the
    # order the scenarios are exercised. ``None`` means "variable unset".
    scenarios = (
        (None, True),
        ("avx2", True),
        ("avx512", True),
        ("default", False),
        ("neon", True),
        ("zvector", True),
    )

    # patch.dict snapshots os.environ and restores it exactly on exit, even
    # if an assertion fails or the variable was previously set to "".
    # NOTE(review): assumes ``patch`` is unittest.mock.patch, as its use as
    # a decorator above suggests.
    with patch.dict(os.environ):
        for capability, expect_vector in scenarios:
            if capability is None:
                os.environ.pop("ATEN_CPU_CAPABILITY", None)
            else:
                os.environ["ATEN_CPU_CAPABILITY"] = capability

            with config.patch({"cpp.simdlen": None}):
                isa = codecache.pick_vec_isa()

            if expect_vector:
                self.assertEqual(isa, vec_zvec_neon)
            else:
                self.assertFalse(isa)
|
||||
|
||||
@unittest.skipIf(
|
||||
platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
|
||||
"Does not support vectorization or not x86_64 machine",
|
||||
@ -1595,13 +1662,6 @@ class CPUReproTests(TestCase):
|
||||
self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
|
||||
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_avx512 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
else:
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
with config.patch({"cpp.simdlen": 0}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
self.assertFalse(isa)
|
||||
@ -1631,6 +1691,60 @@ class CPUReproTests(TestCase):
|
||||
isa = codecache.pick_vec_isa()
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
pre_var = os.getenv("ATEN_CPU_CAPABILITY")
|
||||
if pre_var:
|
||||
os.environ.pop("ATEN_CPU_CAPABILITY")
|
||||
|
||||
try:
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_avx512 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
else:
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_avx512 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
elif vec_avx2 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_avx512 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
else:
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
os.environ["ATEN_CPU_CAPABILITY"] = "default"
|
||||
isa = codecache.pick_vec_isa()
|
||||
self.assertFalse(isa)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
os.environ["ATEN_CPU_CAPABILITY"] = "neon"
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_avx512 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
else:
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_avx512 in codecache.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
else:
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
finally:
|
||||
if pre_var:
|
||||
os.environ["ATEN_CPU_CAPABILITY"] = pre_var
|
||||
elif os.getenv("ATEN_CPU_CAPABILITY"):
|
||||
os.environ.pop("ATEN_CPU_CAPABILITY")
|
||||
|
||||
@requires_vectorization
|
||||
@patch("torch.cuda.is_available", lambda: False)
|
||||
def test_masked_fill_softmax(self):
|
||||
@ -3371,6 +3485,7 @@ class CPUReproTests(TestCase):
|
||||
self.common(m, (idx, x))
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
@requires_vectorization
|
||||
def test_embedding_vec_bf16(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -3655,7 +3770,7 @@ class CPUReproTests(TestCase):
|
||||
x = torch.randint(0, 100, (819,), dtype=torch.int64)
|
||||
metrics.reset()
|
||||
self.common(fn, (x,))
|
||||
assert metrics.generated_cpp_vec_kernel_count == 1
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
def test_reduction_float_to_int64(self):
|
||||
# https://github.com/pytorch/pytorch/issues/124821
|
||||
@ -3665,7 +3780,7 @@ class CPUReproTests(TestCase):
|
||||
x = torch.randint(0, 100, (22, 51), dtype=torch.int64)
|
||||
metrics.reset()
|
||||
self.common(fn, (x,))
|
||||
assert metrics.generated_cpp_vec_kernel_count == 1
|
||||
check_metrics_vec_kernel_count(1)
|
||||
|
||||
@config.patch({"cpp.dynamic_threads": True})
|
||||
def test_reduction_with_dynamic_threads(self):
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
import torch._dynamo
|
||||
import torch.utils.cpp_extension
|
||||
from torch._C import FileCheck
|
||||
from torch._dynamo.testing import expectedFailureScalar
|
||||
|
||||
try:
|
||||
from extension_backends.cpp.extension_codegen_backend import (
|
||||
@ -103,6 +104,9 @@ class ExtensionBackendTests(TestCase):
|
||||
# return the working directory (see setUp)
|
||||
os.chdir(self.old_working_dir)
|
||||
|
||||
# Fails when testing the scalar version
|
||||
# See https://github.com/pytorch/pytorch/issues/126372.
|
||||
@expectedFailureScalar
|
||||
def test_open_device_registration(self):
|
||||
torch.utils.rename_privateuse1_backend("extension_device")
|
||||
torch._register_device_module("extension_device", self.module)
|
||||
|
@ -34,6 +34,7 @@ from torch._dynamo.debug_utils import aot_graph_input_parser
|
||||
from torch._dynamo.testing import (
|
||||
CompileCounterWithBackend,
|
||||
expectedFailureCodegenDynamic,
|
||||
expectedFailureScalar,
|
||||
rand_strided,
|
||||
same,
|
||||
skipIfPy312,
|
||||
@ -1315,6 +1316,9 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (torch.randn(1024),))
|
||||
|
||||
# Fails when testing the scalar version
|
||||
# See https://github.com/pytorch/pytorch/issues/128029.
|
||||
@expectedFailureScalar
|
||||
@skipIfRocm
|
||||
@config.patch(debug_index_asserts=False)
|
||||
def test_neg_index(self):
|
||||
@ -1577,16 +1581,40 @@ class CommonTemplate:
|
||||
def fn(a):
|
||||
return torch.var(a)
|
||||
|
||||
self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)))
|
||||
self.common(fn, ((torch.rand((14923), dtype=torch.float32),)))
|
||||
atol = None
|
||||
rtol = None
|
||||
if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
|
||||
atol = 1e-4
|
||||
rtol = 1e-4
|
||||
self.common(
|
||||
fn,
|
||||
((torch.rand((10, 3, 352, 352), dtype=torch.float32),)),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
self.common(
|
||||
fn, ((torch.rand((14923), dtype=torch.float32),)), rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
@skipCPUIf(IS_MACOS, "fails on macos")
|
||||
def test_multilayer_var_lowp(self):
|
||||
def fn(a):
|
||||
return torch.var(a)
|
||||
|
||||
self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),))
|
||||
self.common(fn, (torch.rand((14923), dtype=torch.float16),))
|
||||
atol = None
|
||||
rtol = None
|
||||
if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
|
||||
atol = 1e-3
|
||||
rtol = 1e-3
|
||||
self.common(
|
||||
fn,
|
||||
(torch.rand((16, 16, 352, 352), dtype=torch.float16),),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
self.common(
|
||||
fn, (torch.rand((14923), dtype=torch.float16),), rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
def test_split_cumsum(self):
|
||||
def fn(a):
|
||||
@ -8199,7 +8227,7 @@ class CommonTemplate:
|
||||
rand_strided(shape, stride, dtype).requires_grad_(True).add(1)
|
||||
for shape, stride, dtype in args
|
||||
]
|
||||
self.common(forward, args)
|
||||
self.common(forward, args, atol=1e-05, rtol=1e-05)
|
||||
|
||||
@requires_gpu()
|
||||
def test_tmp_not_defined_issue3(self):
|
||||
@ -9281,6 +9309,7 @@ class CommonTemplate:
|
||||
# To support this behavior, we need to allow const-propping tensors that store symint data.
|
||||
# For now, dynamo will explicitly graph break when it encounters user code with this behavior.
|
||||
@expectedFailureCodegenDynamic
|
||||
@expectedFailureScalar
|
||||
def test_AllenaiLongformerBase_repro(self):
|
||||
def fn(query, scores, window_overlap):
|
||||
batch_size, seq_len, num_heads, _ = query.size()
|
||||
@ -9316,6 +9345,9 @@ class CommonTemplate:
|
||||
opt_fn = torch._dynamo.optimize("inductor")(fn)
|
||||
_, code = run_and_get_cpp_code(opt_fn, *args)
|
||||
print(code)
|
||||
# When testing the scalar version, i.e., ATEN_CPU_CAPABILITY=default,
|
||||
# static_cast<int>(256) is not found, but static_cast<int64_t>(256).
|
||||
# See https://github.com/pytorch/pytorch/issues/126262.
|
||||
FileCheck().check_count(
|
||||
"static_cast<int32_t>(256)",
|
||||
1,
|
||||
|
@ -381,6 +381,12 @@ def expectedFailureDynamicWrapper(fn):
|
||||
return fn
|
||||
|
||||
|
||||
def expectedFailureScalar(fn):
    """Mark ``fn`` as an expected failure for scalar-only CPU runs.

    When the ``ATEN_CPU_CAPABILITY`` environment variable equals
    ``"default"`` at decoration time, ``fn`` is passed through
    ``unittest.expectedFailure``; otherwise it is returned untouched.
    """
    scalar_only = os.getenv("ATEN_CPU_CAPABILITY") == "default"
    return unittest.expectedFailure(fn) if scalar_only else fn
|
||||
|
||||
|
||||
def reset_rng_state(use_xla=False):
|
||||
torch.manual_seed(1337)
|
||||
random.seed(1337)
|
||||
|
@ -1465,6 +1465,31 @@ invalid_vec_isa = InvalidVecISA()
|
||||
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
|
||||
|
||||
|
||||
def get_isa_from_cpu_capability(
    capability: str | None, vec_isa_list: List[VecISA], invalid_vec_isa: InvalidVecISA
):
    """Choose a vector ISA from an ``ATEN_CPU_CAPABILITY``-style string.

    Args:
        capability: Requested capability name, or ``None``/empty when no
            preference was given.
        vec_isa_list: Candidate ISAs; an entry matches when ``str(entry)``
            equals the ISA name mapped from ``capability``.
        invalid_vec_isa: Object returned to signal "no vectorization".

    Returns:
        ``invalid_vec_isa`` when ``capability`` is ``"default"``; the
        matching entry of ``vec_isa_list`` when the capability maps to one
        of the candidates; otherwise the first candidate, or
        ``invalid_vec_isa`` if there are no candidates at all.

    A warning is emitted whenever a non-empty ``capability`` could not be
    honoured (unknown name, or an ISA that is not among the candidates).
    """
    # VSX is not supported in inductor
    capability_to_isa_str = {
        "default": "INVALID_VEC_ISA",
        "neon": "asimd",
        "zvector": "zvector",
        "avx2": "avx2",
        "avx512": "avx512",
    }

    # Single lookup; None/unknown capabilities simply yield None here.
    isa_str = capability_to_isa_str.get(capability)
    if isa_str == "INVALID_VEC_ISA":
        return invalid_vec_isa

    if isa_str is not None:
        for vec_isa in vec_isa_list:
            if isa_str == str(vec_isa):
                return vec_isa

    if capability:
        warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}")

    # Fall back to the first candidate; with no candidates there is nothing
    # to vectorize with, so report the invalid ISA instead of raising.
    return vec_isa_list[0] if vec_isa_list else invalid_vec_isa
|
||||
|
||||
|
||||
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
|
||||
# might have too much redundant content that is useless for ISA check. Hence,
|
||||
# we only cache some key isa information.
|
||||
@ -1507,10 +1532,13 @@ def pick_vec_isa() -> VecISA:
|
||||
if not _valid_vec_isa_list:
|
||||
return invalid_vec_isa
|
||||
|
||||
# If the simdlen is None, it indicates determine the vectorization length automatically
|
||||
# If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY
|
||||
# to control CPU vec ISA
|
||||
|
||||
if config.cpp.simdlen is None:
|
||||
assert _valid_vec_isa_list
|
||||
return _valid_vec_isa_list[0]
|
||||
return get_isa_from_cpu_capability(
|
||||
os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa
|
||||
)
|
||||
|
||||
for isa in _valid_vec_isa_list:
|
||||
if config.cpp.simdlen == isa.bit_width():
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include <c10/util/generic_math.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/TypeCast.h>
|
||||
#include <ATen/native/Math.h>
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
|
||||
#define INDUCTOR_USE_VECTOR_TYPES() 1
|
||||
|
Reference in New Issue
Block a user