Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Fix fsdp for generic-task models #40191
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from abc import ABC
 from functools import partial
 from typing import Optional
 
@@ -95,7 +94,7 @@ class GradientCheckpointingLayer(nn.Module):
 
 
 @auto_docstring
-class GenericForSequenceClassification(ABC):
+class GenericForSequenceClassification(object):
     base_model_prefix = "model"
 
     def __init__(self, config):
@@ -170,7 +169,7 @@ class GenericForSequenceClassification(ABC):
 
 
 @auto_docstring
-class GenericForQuestionAnswering(ABC):
+class GenericForQuestionAnswering(object):
     base_model_prefix = "model"
 
     def __init__(self, config):
@@ -231,7 +230,7 @@ class GenericForQuestionAnswering(ABC):
 
 
 @auto_docstring
-class GenericForTokenClassification(ABC):
+class GenericForTokenClassification(object):
     base_model_prefix = "model"
 
     def __init__(self, config):
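Aside, for context (not part of the diff): the three Generic* classes above are task-head mixins layered onto a concrete backbone, and the commit title suggests the abc.ABC base was what tripped up FSDP sharding; as far as this diff shows they declare no abstract methods, so dropping ABC leaves their behavior as mixins unchanged. A rough sketch of the composition pattern, assuming the mixins live in transformers.modeling_layers (the GradientCheckpointingLayer hunk context points at that module) and using Qwen2 purely as an illustrative backbone:

    # Sketch only: how a generic task head composes with a concrete backbone.
    # Import path and backbone choice are assumptions, not taken from this diff.
    from transformers.modeling_layers import GenericForTokenClassification
    from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel


    class SketchQwen2ForTokenClassification(GenericForTokenClassification, Qwen2PreTrainedModel):
        # The mixin supplies __init__ and the token-classification head; its
        # base_model_prefix = "model" says where the backbone is attached.
        pass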
@@ -3473,3 +3473,23 @@ class Expectations(UserDict[PackedDeviceProperties, Any]):
 
     def __repr__(self):
         return f"{self.data}"
+
+
+def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None):
+    """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
+    with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
+        tmp.write(script)
+        tmp.flush()
+        tmp.seek(0)
+        if is_torchrun:
+            cmd = (
+                f"torchrun --nproc_per_node {nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+            ).split()
+        else:
+            cmd = ["python3", tmp.name]
+
+        # Note that the subprocess will be waited for here, and raise an error if not successful
+        try:
+            _ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True)
+        except subprocess.CalledProcessError as e:
+            raise Exception(f"The following error was captured: {e.stderr}")
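Aside (illustration, not part of the diff): with the helper above exposed from transformers.testing_utils, a distributed test can hand torchrun a self-contained script as a string. A minimal usage sketch; the test class and the script body are made up for illustration:

    import textwrap

    from transformers.testing_utils import TestCasePlus, torchrun


    class SketchDistributedTest(TestCasePlus):  # hypothetical test, for illustration only
        def test_two_process_smoke(self):
            script = textwrap.dedent(
                """
                import torch
                torch.distributed.init_process_group(
                    backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
                )
                assert torch.distributed.get_world_size() == 2
                torch.distributed.destroy_process_group()
                """
            )
            # Launches the script under `torchrun --nproc_per_node 2`; raises if the subprocess fails.
            torchrun(script, nproc_per_node=2, env=self.get_env())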
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import argparse
+import textwrap
 from typing import Any, Callable
 
 from transformers import is_torch_available, is_torch_xpu_available
@@ -24,6 +25,7 @@ from transformers.testing_utils import (
     get_torch_dist_unique_port,
     require_torch_multi_accelerator,
     torch_device,
+    torchrun,
 )
 from transformers.utils import is_ccl_available, is_ipex_available
 
@@ -141,6 +143,33 @@ class TestFSDPGeneration(TestCasePlus):
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
+class TestFSDPGenericTaskModel(TestCasePlus):
+    nproc_per_node = 2
+
+    def test_generic_task_model_can_be_sharded(self):
+        script_to_run = textwrap.dedent(
+            """
+            import torch
+            from torch.distributed.fsdp import fully_shard
+            from transformers import AutoModelForTokenClassification
+
+            torch.distributed.init_process_group(
+                backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
+            )
+            rank = torch.distributed.get_rank()
+            if torch.cuda.is_available():
+                torch.cuda.set_device(rank)
+
+            # Make sure it works
+            model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B")
+            module = fully_shard(model)
+
+            torch.distributed.destroy_process_group()
+            """
+        )
+        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
+
+
 if __name__ == "__main__":
     # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
     #
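Aside (sketch, not from the diff): the new test boils down to checking that torch's FSDP2 entry point, fully_shard, accepts a generic-task model out of the box. A slightly expanded sketch of the same idea with the usual per-layer sharding pass; the .model.layers attribute path assumes the Qwen2 backbone used in the test:

    from torch.distributed.fsdp import fully_shard

    from transformers import AutoModelForTokenClassification


    def shard_token_classifier():
        # Assumes torch.distributed is already initialized, e.g. under torchrun as in the test above.
        model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B")
        # Shard each decoder layer first, then the root module - a common FSDP2 pattern.
        for layer in model.model.layers:
            fully_shard(layer)
        fully_shard(model)
        return model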
@@ -15,7 +15,6 @@
 # Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
 
 import os
-import subprocess
 import tempfile
 import textwrap
 
@@ -24,10 +23,10 @@ from transformers.integrations.tensor_parallel import get_packed_weights, repack
 from transformers.testing_utils import (
     TestCasePlus,
     backend_device_count,
-    get_torch_dist_unique_port,
     require_huggingface_hub_greater_or_equal,
     require_torch_multi_accelerator,
     torch_device,
+    torchrun,
 )
 
 
@@ -67,25 +66,6 @@ class TestTensorParallelUtils(TestCasePlus):
 class TestTensorParallel(TestCasePlus):
     nproc_per_node = 2
 
-    def torchrun(self, script: str, is_torchrun: bool = True):
-        """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
-        with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
-            tmp.write(script)
-            tmp.flush()
-            tmp.seek(0)
-            if is_torchrun:
-                cmd = (
-                    f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
-                ).split()
-            else:
-                cmd = ["python3", tmp.name]
-
-            # Note that the subprocess will be waited for here, and raise an error if not successful
-            try:
-                _ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True)
-            except subprocess.CalledProcessError as e:
-                raise Exception(f"The following error was captured: {e.stderr}")
-
     def test_model_forward(self):
         script_to_run = textwrap.dedent(
             """
@@ -124,7 +104,33 @@ class TestTensorParallel(TestCasePlus):
             torch.distributed.destroy_process_group()
             """
         )
-        self.torchrun(script_to_run)
+        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
+
+    def test_model_backward_pass(self):
+        script_to_run = textwrap.dedent(
+            """
+            import torch
+            import os
+            from transformers import AutoModelForCausalLM
+            from torch import nn
+
+            model_id = "JackFram/llama-68m"
+
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, tp_plan="auto")
+            torch.distributed.barrier()
+
+            # Dummy forward and backward pass
+            # Note that loss.backward() will fail if there is a bug in the TP implementation
+            inputs = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
+            labels = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
+            loss = model(inputs, labels=labels).loss
+            loss.backward()
+
+            torch.distributed.barrier()
+            torch.distributed.destroy_process_group()
+            """
+        )
+        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
 
     def test_model_generate(self):
         script_to_run = textwrap.dedent(
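Aside (not part of the diff): the backward-pass test relies on loss.backward() failing loudly if the tensor-parallel plan is broken. A more direct signal would be to count sharded parameters, sketched below; this assumes tp_plan="auto" materializes sharded weights as DTensors and that the public torch.distributed.tensor module is available (torch >= 2.5):

    from torch.distributed.tensor import DTensor


    def count_sharded_params(model) -> int:
        # Counts parameters that the TP plan actually converted to DTensors.
        return sum(1 for p in model.parameters() if isinstance(p, DTensor))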
@@ -164,7 +170,7 @@ class TestTensorParallel(TestCasePlus):
             torch.distributed.destroy_process_group()
             """
         )
-        self.torchrun(script_to_run)
+        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
 
     @require_huggingface_hub_greater_or_equal("0.31.4")
     def test_model_save(self):
@@ -191,7 +197,7 @@ class TestTensorParallel(TestCasePlus):
                     model.save_pretrained(result_dir)
                     """
                 )
-                self.torchrun(script_to_run, is_torchrun=is_torchrun)
+                torchrun(script_to_run, self.nproc_per_node, is_torchrun=is_torchrun, env=self.get_env())
 
             non_tp_model_path = os.path.join(tmp_dir, "nontp")
             tp_model_path = os.path.join(tmp_dir, "tp")
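Aside (speculative sketch, not from the diff): the non_tp_model_path / tp_model_path directories above presumably feed a later comparison between the TP and non-TP checkpoints in test_model_save. A hypothetical helper for that kind of comparison; the helper itself and the single-shard model.safetensors file name are assumptions:

    import os

    import torch
    from safetensors.torch import load_file


    def checkpoints_match(dir_a: str, dir_b: str) -> bool:
        # Hypothetical helper: compare two single-shard safetensors checkpoints tensor by tensor.
        a = load_file(os.path.join(dir_a, "model.safetensors"))
        b = load_file(os.path.join(dir_b, "model.safetensors"))
        return a.keys() == b.keys() and all(torch.allclose(a[k], b[k]) for k in a)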