Re-enable GPT-J unit tests and refactor inference tests (#3618)

Michael Wyatt
2023-06-28 12:33:55 -07:00
committed by GitHub
parent 7726fc8d54
commit 78b7693591
5 changed files with 102 additions and 119 deletions

.flake8 (new file)

@@ -0,0 +1,5 @@
+[flake8]
+ignore = E,F403,F405,F541,F841,W
+select = E9,F,W6
+per-file-ignores =
+    __init__.py:F401
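For reference, the narrowed lint scope keeps CI focused on genuine errors while the ignored E/W codes cover pure style. A minimal illustration with a hypothetical file (not part of this commit):

# example.py, hypothetical module used only to illustrate the .flake8 settings above
def example():
    # F821 (undefined name) falls under the selected "F" family and is not ignored,
    # so flake8 still reports it; whitespace and style E/W codes are suppressed.
    return undefined_name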


@@ -67,7 +67,7 @@ repos:
   rev: 4.0.1
   hooks:
   - id: flake8
-    args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
+    args: ['--config=.flake8']
 - repo: local
   hooks:


@@ -6,3 +6,4 @@ markers =
     inference_ops:Individual inference operator tests
     seq_inference:Inference model tests to run sequentially
     nightly:Tests that should be run nightly
+    world_size:Change world size of individual tests in a class
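A usage sketch for the new marker (hypothetical test; it assumes DistributedTest honors a per-test world_size marker, which is what the marker description implies):

import pytest

from unit.common import DistributedTest


class TestExample(DistributedTest):
    world_size = 1  # class-level default

    @pytest.mark.world_size(2)  # assumed per-test override enabled by the new marker
    def test_two_ranks(self):
        pass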


@@ -49,75 +49,57 @@ _gpt_models = [
     "gpt2",
     "distilgpt2",
     "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
-    "EleutherAI/gpt-j-6B",  # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m-deduped",
     "bigscience/bloom-560m",
 ]
 _opt_models = [
     "facebook/opt-125m",  # 125m, 1.7B, ..., 175B variants have the same model architecture.
     "facebook/opt-350m",  # 350m applies layer norm after attention layer which is different than other variants.
 ]
-_all_models = HfApi().list_models()
-
-test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
-test_tasks = [
+_test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
+_test_tasks = [
     "fill-mask", "question-answering", "text-classification", "token-classification", "text-generation",
     "text2text-generation", "summarization", "translation"
 ]
-pytest.all_models = {task: [m.modelId for m in _all_models if m.pipeline_tag == task] for task in test_tasks}
-
-_model_w_tasks = itertools.product(*[test_models, test_tasks])
-
-
-def _valid_model_task(model_task):
-    m, t = model_task
-    return m in pytest.all_models[t]
-
-
-pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
-pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
-"""
-These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph
-"""
-
-
-@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
-def model_w_task(request):
-    return request.param
-
-
-@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
-def dtype(request):
-    return request.param
-
-
-@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
-def enable_cuda_graph(request):
-    return request.param
-
-
-@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
-def enable_triton(request):
-    return request.param
-
-
-"""
-This fixture will validate the configuration
-"""
+
+# Get a list of all models and mapping from task to supported models
+_hf_models = HfApi().list_models()
+_hf_model_names = [m.modelId for m in _hf_models]
+_hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}
+
+# Get all combinations of task:model to test
+_model_w_tasks = [(m, t) for m, t in itertools.product(*[_test_models, _test_tasks]) if m in _hf_task_to_models[t]]
+
+# Assign to pytest variables for testing
+pytest.model_w_tasks = _model_w_tasks
+pytest.mt_names = [f"{m}-{t}" for m, t in pytest.model_w_tasks]
+
+
+@pytest.fixture(scope="module", autouse=True)
+def verify_models():
+    # Verify all test models are registered in HF
+    _test_models_not_found = [m for m in _test_models if m not in _hf_model_names]
+    if _test_models_not_found:
+        pytest.fail(f"Model(s) not found in HuggingFace: {_test_models_not_found}")
+
+    # Verify all models are assigned to at least one task
+    _models_to_be_tested = set(m for m, t in _model_w_tasks)
+    _missing_task_models = _models_to_be_tested.difference(_test_models)
+    if _missing_task_models:
+        pytest.fail(f"Model(s) do not have an assigned task: {_missing_task_models}")
 
 
+# Fixture to add skips for certain configurations
 @pytest.fixture()
-def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
+def invalid_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
     model, task = model_w_task
     msg = ""
-    if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
-        msg = "DS inference injection doesn't work well on older torch versions"
-    elif model not in pytest.all_models[task]:
-        msg = f"Not a valid model / task combination: {model} / {task}"
-    elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
+    if enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
         msg = "CUDA not detected, cannot use CUDA Graph"
     elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
         msg = "CUDA Graph is only available in torch versions >= 1.10"
-    elif "gpt-j-6B" in model:
+    elif "gpt-j-6b" in model:
         if dtype != torch.half:
             msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
         elif enable_cuda_graph:
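The autouse verify_models fixture above fails the whole module as soon as a test model id drifts from its canonical Hub name, the kind of mismatch that can otherwise silently drop a model from the parametrized matrix. A self-contained illustration with hypothetical values:

import pytest

_hf_model_names = ["EleutherAI/gpt-j-6b", "gpt2"]  # canonical ids, as returned by HfApi().list_models()
_test_models = {"EleutherAI/gpt-j-6B", "gpt2"}  # stale capitalization in the test list

_not_found = [m for m in _test_models if m not in _hf_model_names]
if _not_found:
    pytest.fail(f"Model(s) not found in HuggingFace: {_not_found}")  # aborts the module up front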
@@ -139,10 +121,30 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
     return msg
 
 
-"""
-These fixtures can be used to customize the query, inference args, and assert
-statement for each combination of model /task
-"""
+""" Fixtures for inference config """
+
+
+@pytest.fixture(params=pytest.model_w_tasks, ids=pytest.mt_names)
+def model_w_task(request):
+    return request.param
+
+
+@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
+def dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
+def enable_cuda_graph(request):
+    return request.param
+
+
+@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
+def enable_triton(request):
+    return request.param
+
+
+""" Fixtures for running query """
 
 
 @pytest.fixture
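Any test that requests these fixtures is expanded over every registered model/task pair plus the dtype, CUDA Graph, and Triton toggles. A hypothetical consumer, with ids shaped like "gpt2-text-generation-fp16-CG-noTriton":

import pytest


def test_example(model_w_task, dtype, enable_cuda_graph, enable_triton, invalid_test):
    # Hypothetical test body: skip combinations flagged by the invalid_test fixture.
    if invalid_test:
        pytest.skip(invalid_test)
    model, task = model_w_task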
@@ -178,7 +180,7 @@ def query(model_w_task):
 def inf_kwargs(model_w_task):
     model, task = model_w_task
     if task == "text-generation":
-        if model == "EleutherAI/gpt-j-6B":
+        if model == "EleutherAI/gpt-j-6b":
             # This model on V100 is hitting memory problems that limit the number of output tokens
             return {"do_sample": False, "max_length": 12}
         return {"do_sample": False, "max_length": 20}
@@ -186,6 +188,9 @@ def inf_kwargs(model_w_task):
         return {}
 
 
+""" Assertion fixture for verifying model outputs """
+
+
 def fill_mask_assert(x, y):
     return set(res["token_str"] for res in x) == set(res["token_str"] for res in y)
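The fill-mask assertion compares only the set of predicted tokens, since candidate ordering and scores can differ slightly between the baseline pipeline and the kernel-injected model. A small example of outputs it treats as equivalent (values are made up):

baseline = [{"token_str": "good", "score": 0.31}, {"token_str": "bad", "score": 0.27}]
injected = [{"token_str": "bad", "score": 0.28}, {"token_str": "good", "score": 0.30}]  # reordered, scores shifted
assert set(r["token_str"] for r in baseline) == set(r["token_str"] for r in injected)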
@@ -237,6 +242,7 @@ def assert_fn(model_w_task):
     return assert_fn
 
 
+# Used to verify DeepSpeed kernel injection worked with a model
 def check_injection(model):
 
     def verify_injection(module):
@@ -251,27 +257,24 @@ def check_injection(model):
     verify_injection(model)
 
 
-"""
-Tests
-"""
-
-
 @pytest.mark.inference
 class TestModelTask(DistributedTest):
     world_size = 1
 
-    def test(self,
-             model_w_task,
-             dtype,
-             enable_cuda_graph,
-             enable_triton,
-             query,
-             inf_kwargs,
-             assert_fn,
-             invalid_model_task_config,
-             perf_meas=True):
-        if invalid_model_task_config:
-            pytest.skip(invalid_model_task_config)
+    def test(
+        self,
+        model_w_task,
+        dtype,
+        enable_cuda_graph,
+        enable_triton,
+        query,
+        inf_kwargs,
+        assert_fn,
+        invalid_test,
+        perf_meas=True,
+    ):
+        if invalid_test:
+            pytest.skip(invalid_test)
 
         model, task = model_w_task
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
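The remainder of TestModelTask.test (outside this hunk) follows the usual baseline-versus-DeepSpeed pattern. A simplified, hedged sketch assuming the standard transformers pipeline and deepspeed.init_inference APIs rather than quoting the exact test body (the real test also calls check_injection, defined above, to confirm kernels were swapped in):

import deepspeed
import torch
from transformers import pipeline

# Run the stock HF pipeline as a baseline, then the same model with DeepSpeed kernels injected.
pipe = pipeline("text-generation", model="gpt2", device=0, framework="pt")
baseline = pipe("DeepSpeed is", do_sample=False, max_length=20)

pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=1,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True)
ds_output = pipe("DeepSpeed is", do_sample=False, max_length=20)

# Stand-in for the task-specific assert_fn fixture (text-generation compares generated text).
assert baseline[0]["generated_text"] == ds_output[0]["generated_text"]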
@@ -338,10 +341,10 @@ class TestModelTask(DistributedTest):
 @pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
                                           ("EleutherAI/gpt-neox-20b", "text-generation"),
                                           ("bigscience/bloom-3b", "text-generation"),
-                                          ("EleutherAI/gpt-j-6B", "text-generation")],
+                                          ("EleutherAI/gpt-j-6b", "text-generation")],
                          ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"])
 class TestMPSize(DistributedTest):
-    world_size = 4
+    world_size = 2
 
     def test(
         self,
@@ -350,10 +353,10 @@ class TestMPSize(DistributedTest):
         query,
         inf_kwargs,
         assert_fn,
-        invalid_model_task_config,
+        invalid_test,
     ):
-        if invalid_model_task_config:
-            pytest.skip(invalid_model_task_config)
+        if invalid_test:
+            pytest.skip(invalid_test)
 
         model, task = model_w_task
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
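With world_size reduced from 4 to 2, the tensor-parallel tests shard each model across two ranks. A hedged per-rank sketch, assuming the mp_size-style deepspeed.init_inference API (device handling simplified):

import os

import deepspeed
import torch
from transformers import pipeline

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "2"))

# Both ranks build the pipeline; DeepSpeed then splits the weights across the
# two ranks via mp_size (tensor/model parallelism).
pipe = pipeline("text-generation", model="EleutherAI/gpt-j-6b", device=local_rank, framework="pt")
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=world_size,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True)
output = pipe("DeepSpeed is", do_sample=False, max_length=12)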
@@ -402,12 +405,12 @@ class TestInjectionPolicy(DistributedTest):
         query,
         inf_kwargs,
         assert_fn,
-        invalid_model_task_config,
+        invalid_test,
         dtype,
         enable_cuda_graph,
     ):
-        if invalid_model_task_config:
-            pytest.skip(invalid_model_task_config)
+        if invalid_test:
+            pytest.skip(invalid_test)
 
         model, task = model_w_task
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -452,12 +455,12 @@ class TestAutoTensorParallelism(DistributedTest):
         query,
         inf_kwargs,
         assert_fn,
-        invalid_model_task_config,
+        invalid_test,
         dtype,
         enable_cuda_graph,
     ):
-        if invalid_model_task_config:
-            pytest.skip(invalid_model_task_config)
+        if invalid_test:
+            pytest.skip(invalid_test)
 
         model, task = model_w_task
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -483,7 +486,7 @@ class TestAutoTensorParallelism(DistributedTest):
     "model_family, model_name",
     (
         ["gpt2", "EleutherAI/gpt-neo-2.7B"],
-        ["gpt2", "EleutherAI/gpt-j-6B"],
+        ["gpt2", "EleutherAI/gpt-j-6b"],
         ["gpt2", "gpt2-xl"],
     ),
 )
@@ -503,7 +506,7 @@ class TestLMCorrectness(DistributedTest):
         dtype = torch.float
         task_dict = lm_eval.tasks.get_task_dict([task])
-        if 'gpt-j-6B' in model_name:
+        if 'gpt-j-6b' in model_name:
             dtype = torch.half
         lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
                                                                            {"device": "cpu"})


@@ -13,43 +13,17 @@ from unit.common import DistributedTest
 from deepspeed.accelerator import get_accelerator
 
 
-@pytest.fixture
-def query(model, task):
-    if task == "text-generation":
-        return "DeepSpeed is"
-    elif task == "fill-mask":
-        if "roberta" in model:
-            return "I am a <mask> model"
-        else:
-            return "I am a [MASK] model"
-    else:
-        raise NotImplementedError
-
-
-@pytest.fixture
-def inf_kwargs(task):
-    if task == "text-generation":
-        return {"do_sample": False, "min_length": 50, "max_length": 50}
-    else:
-        return {}
-
-
 @pytest.mark.inference
-@pytest.mark.parametrize("model,task", [
-    ("bert-base-cased", "fill-mask"),
-    ("roberta-base", "fill-mask"),
-    ("gpt2", "text-generation"),
-    ("facebook/opt-125m", "text-generation"),
-    ("bigscience/bloom-560m", "text-generation"),
-])
-@pytest.mark.parametrize("cuda_graphs", [True, False])
 @pytest.mark.parametrize("use_cuda_events", [True, False])
+@pytest.mark.parametrize("enable_cuda_graph", [True, False])
 class TestModelProfiling(DistributedTest):
     world_size = 1
 
-    def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dtype=torch.float16):
-        if cuda_graphs and "bert" not in model:
-            pytest.skip(f"CUDA Graph not supported for {model}")
+    def test(self, enable_cuda_graph, use_cuda_events):
+        task = "fill-mask"
+        model = "bert-base-cased"
+        dtype = torch.float16
+        query = "I am a [MASK] model"
 
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
         world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -59,7 +33,7 @@ class TestModelProfiling(DistributedTest):
                              dtype=dtype,
                              mp_size=world_size,
                              replace_with_kernel_inject=True,
-                             enable_cuda_graph=cuda_graphs)
+                             enable_cuda_graph=enable_cuda_graph)
         pipe.model.profile_model_time(use_cuda_events=use_cuda_events)
 
         e2e_times = []
@@ -68,7 +42,7 @@ class TestModelProfiling(DistributedTest):
             get_accelerator().synchronize()
             start = time.perf_counter_ns()
 
-            r = pipe(query, **inf_kwargs)
+            r = pipe(query)
 
             get_accelerator().synchronize()
             end = time.perf_counter_ns()
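For context, the timing loop above runs against a kernel-injected pipeline whose model-time profiling was switched on earlier in the test. A hedged end-to-end sketch using the same calls (dtype and enable_cuda_graph values here are illustrative):

import time

import deepspeed
import torch
from transformers import pipeline

pipe = pipeline("fill-mask", model="bert-base-cased", device=0, framework="pt")
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      mp_size=1,
                                      replace_with_kernel_inject=True,
                                      enable_cuda_graph=False)
pipe.model.profile_model_time(use_cuda_events=True)  # as enabled in the test above

start = time.perf_counter_ns()
result = pipe("I am a [MASK] model")
e2e_ms = (time.perf_counter_ns() - start) / 1e6  # end-to-end latency in milliseconds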