Fix fsdp for generic-task models #40191

Cyril Vallez
2025-08-15 12:28:16 +02:00
committed by Arthur
parent e75d67ec39
commit c7bd5350f0
4 changed files with 82 additions and 28 deletions

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from functools import partial
from typing import Optional
@@ -95,7 +94,7 @@ class GradientCheckpointingLayer(nn.Module):
@auto_docstring
class GenericForSequenceClassification(ABC):
class GenericForSequenceClassification(object):
base_model_prefix = "model"
def __init__(self, config):
@@ -170,7 +169,7 @@ class GenericForSequenceClassification(ABC):
@auto_docstring
class GenericForQuestionAnswering(ABC):
class GenericForQuestionAnswering(object):
base_model_prefix = "model"
def __init__(self, config):
@@ -231,7 +230,7 @@ class GenericForQuestionAnswering(ABC):
@auto_docstring
class GenericForTokenClassification(ABC):
class GenericForTokenClassification(object):
base_model_prefix = "model"
def __init__(self, config):
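For context (not part of this diff): the Generic* classes are mixins that concrete model classes combine with a pretrained backbone, and the main observable effect of swapping ABC for object in their bases is the metaclass of those concrete classes: with ABC in the bases they become instances of abc.ABCMeta, while with object they stay plain instances of type. The diff does not show the exact FSDP failure this avoids, but the sketch below (illustrative names, not transformers code) demonstrates that difference:

import abc

import torch.nn as nn


class TaskMixinWithABC(abc.ABC):
    """Old style: task-head mixin deriving from ABC."""


class TaskMixinPlain(object):
    """New style: task-head mixin deriving from object."""


class Backbone(nn.Module):
    # Stand-in for a PreTrainedModel subclass
    pass


class ModelOld(TaskMixinWithABC, Backbone):
    pass


class ModelNew(TaskMixinPlain, Backbone):
    pass


print(type(ModelOld))  # <class 'abc.ABCMeta'>
print(type(ModelNew))  # <class 'type'>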

View File

@@ -3473,3 +3473,23 @@ class Expectations(UserDict[PackedDeviceProperties, Any]):
def __repr__(self):
return f"{self.data}"
def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script)
tmp.flush()
tmp.seek(0)
if is_torchrun:
cmd = (
f"torchrun --nproc_per_node {nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
else:
cmd = ["python3", tmp.name]
# Note that the subprocess will be waited for here, and raise an error if not successful
try:
_ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True)
except subprocess.CalledProcessError as e:
raise Exception(f"The following error was captured: {e.stderr}")
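For illustration (not part of this diff), a hypothetical caller of the new shared helper, mirroring how the tests updated by this commit use it; the test class, method name and script body below are made up:

import textwrap

from transformers.testing_utils import TestCasePlus, torchrun


class MyDistributedTest(TestCasePlus):
    nproc_per_node = 2

    def test_ranks_can_talk(self):
        script = textwrap.dedent(
            """
            import torch
            torch.distributed.init_process_group(backend="gloo", init_method="env://")
            print(torch.distributed.get_rank())
            torch.distributed.destroy_process_group()
            """
        )
        # Launches the script with `torchrun --nproc_per_node 2` and raises if it fails
        torchrun(script, self.nproc_per_node, env=self.get_env())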

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import argparse
import textwrap
from typing import Any, Callable
from transformers import is_torch_available, is_torch_xpu_available
@@ -24,6 +25,7 @@ from transformers.testing_utils import (
get_torch_dist_unique_port,
require_torch_multi_accelerator,
torch_device,
torchrun,
)
from transformers.utils import is_ccl_available, is_ipex_available
@@ -141,6 +143,33 @@ class TestFSDPGeneration(TestCasePlus):
# successful return here == success - any errors would have caused an error in the sub-call
class TestFSDPGenericTaskModel(TestCasePlus):
nproc_per_node = 2
def test_generic_task_model_can_be_sharded(self):
script_to_run = textwrap.dedent(
"""
import torch
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForTokenClassification
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
)
rank = torch.distributed.get_rank()
if torch.cuda.is_available():
torch.cuda.set_device(rank)
# Make sure it works
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B")
module = fully_shard(model)
torch.distributed.destroy_process_group()
"""
)
torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
#
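Not part of this diff: if the test should assert more than that fully_shard does not crash, a possible extension of the script above (assuming a recent torch where FSDP2 exposes sharded parameters as DTensors via torch.distributed.tensor) is to check the parameter types after sharding:

import torch
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor

from transformers import AutoModelForTokenClassification

torch.distributed.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
)
if torch.cuda.is_available():
    torch.cuda.set_device(torch.distributed.get_rank())

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B")
fully_shard(model)

# After sharding, FSDP2 should represent every parameter as a DTensor
assert all(isinstance(p, DTensor) for p in model.parameters())

torch.distributed.destroy_process_group()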

View File

@@ -15,7 +15,6 @@
# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
import os
import subprocess
import tempfile
import textwrap
@@ -24,10 +23,10 @@ from transformers.integrations.tensor_parallel import get_packed_weights, repack
from transformers.testing_utils import (
TestCasePlus,
backend_device_count,
get_torch_dist_unique_port,
require_huggingface_hub_greater_or_equal,
require_torch_multi_accelerator,
torch_device,
torchrun,
)
@@ -67,25 +66,6 @@ class TestTensorParallelUtils(TestCasePlus):
class TestTensorParallel(TestCasePlus):
nproc_per_node = 2
def torchrun(self, script: str, is_torchrun: bool = True):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script)
tmp.flush()
tmp.seek(0)
if is_torchrun:
cmd = (
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()
else:
cmd = ["python3", tmp.name]
# Note that the subprocess will be waited for here, and raise an error if not successful
try:
_ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True)
except subprocess.CalledProcessError as e:
raise Exception(f"The following error was captured: {e.stderr}")
def test_model_forward(self):
script_to_run = textwrap.dedent(
"""
@@ -124,7 +104,33 @@ class TestTensorParallel(TestCasePlus):
torch.distributed.destroy_process_group()
"""
)
self.torchrun(script_to_run)
torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
def test_model_backward_pass(self):
script_to_run = textwrap.dedent(
"""
import torch
import os
from transformers import AutoModelForCausalLM
from torch import nn
model_id = "JackFram/llama-68m"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, tp_plan="auto")
torch.distributed.barrier()
# Dummy forward and backward pass
# Note that loss.backward() will fail if there is a bug in the TP implementation
inputs = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
labels = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
loss = model(inputs, labels=labels).loss
loss.backward()
torch.distributed.barrier()
torch.distributed.destroy_process_group()
"""
)
torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
def test_model_generate(self):
script_to_run = textwrap.dedent(
@@ -164,7 +170,7 @@ class TestTensorParallel(TestCasePlus):
torch.distributed.destroy_process_group()
"""
)
self.torchrun(script_to_run)
torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
@require_huggingface_hub_greater_or_equal("0.31.4")
def test_model_save(self):
@@ -191,7 +197,7 @@ class TestTensorParallel(TestCasePlus):
model.save_pretrained(result_dir)
"""
)
self.torchrun(script_to_run, is_torchrun=is_torchrun)
torchrun(script_to_run, self.nproc_per_node, is_torchrun=is_torchrun, env=self.get_env())
non_tp_model_path = os.path.join(tmp_dir, "nontp")
tp_model_path = os.path.join(tmp_dir, "tp")