(Part 1) fix: make TP training compatible with new transformers (#3457)

* feat: support new tp refactor for training

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: address @S1ro1 review comment

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: address @S1ro1 review comment - tp_plan flag docstring

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: address @SunMarc review comment on unused flag

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: pick approach 3 as discussed in the PR

see https://github.com/huggingface/accelerate/pull/3457#discussion_r2037909077 for more details

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: styling errors

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: bump minimum transformers version for the tp_size feature

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
This commit is contained in:
Mehant Kammakomati
2025-04-11 22:01:28 +05:30
committed by GitHub
parent ee4cab96ed
commit 67adb473a4
10 changed files with 57 additions and 83 deletions
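
In short, this commit moves to the following flow: the model is tensor-parallelized by transformers itself via `from_pretrained(..., tp_plan="auto", tp_size=...)`, and the `TorchTensorParallelPlugin` is passed to `Accelerator` directly instead of being driven by the `--use_tp` flag / `ACCELERATE_USE_TP` environment variable, which are removed below. A minimal sketch of the new usage, assembled from the diff (the model name and tp_size value are illustrative, and the model must ship a TP plan):

```python
# Minimal sketch of TP training after this change; meant to run under
# `accelerate launch --num_processes <tp_size> script.py`.
# Assumes transformers >= 4.52.0; model name and tp_size are illustrative.
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin
from transformers import AutoModelForSequenceClassification

tp_size = 2  # degree of tensor parallelism, must be > 1

# transformers shards the model up front; accelerate then checks that the
# model carries a matching tp_size attribute.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-cased", tp_plan="auto", tp_size=tp_size
)

# The plugin's tp_size must match the model's tp_size, otherwise
# Accelerator.prepare() raises a ValueError (see the Accelerator hunk below).
accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=tp_size))
model = accelerator.prepare(model)
```

The launcher-side `--use_tp` / `--tp_size` flags and the config prompts go away accordingly, as the following hunks show.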

View File

@ -374,9 +374,7 @@ class Accelerator:
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
torch_tp_plugin, TorchTensorParallelPlugin
):
if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
@ -396,14 +394,8 @@ class Accelerator:
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
if torch_tp_plugin is None:
torch_tp_plugin = (
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
)
else:
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
os.environ["ACCELERATE_USE_TP"] = "true"
if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
@ -1598,15 +1590,17 @@ class Accelerator:
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if not hasattr(model, "tp_size"):
raise NotImplementedError(
"Provided model does not support tensor parallelism. \
Tensor parallelism plan can be added as base_model_tp_plan to model config class \
and _tp_plan attribute to model class."
"Model should undergo tensor parallel before passing it to accelerate."
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
)
if model.tp_size != self.state.torch_tp_plugin.tp_size:
raise ValueError(
f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
)
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
elif self.is_fsdp2:
model = fsdp2_prepare_model(self, model)
@ -2223,8 +2217,7 @@ class Accelerator:
return self.state.torch_tp_plugin.torch_device_mesh
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
return self.state.ds_device_mesh
else:
return None
return None
def _prepare_msamp(self, *args, device_placement):
if not is_msamp_available():

View File

@ -369,7 +369,7 @@ def get_cluster_input():
)
fsdp_config = {}
tp_config = {}
if distributed_type in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
@ -498,21 +498,7 @@ def get_cluster_input():
default=False,
error_message="Please enter yes or no.",
)
if not use_fsdp:
use_tp = _ask_field(
"Do you want to use TensorParallel? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if use_tp:
distributed_type = DistributedType.TP
if distributed_type == DistributedType.TP:
tp_config["tp_size"] = _ask_field(
"What should be your Tensor Parallel degree? [1]: ",
int,
default=1,
)
megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_megatron_lm = _ask_field(
@ -857,7 +843,6 @@ def get_cluster_input():
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
tp_config=tp_config,
megatron_lm_config=megatron_lm_config,
ipex_config=ipex_config,
mpirun_config=mpirun_config,

View File

@ -194,8 +194,6 @@ class ClusterConfig(BaseConfig):
deepspeed_config: dict = None
# args for fsdp
fsdp_config: dict = None
# args for tp
tp_config: dict = None
# args for megatron_lm
megatron_lm_config: dict = None
# args for ipex
@ -223,8 +221,6 @@ class ClusterConfig(BaseConfig):
self.deepspeed_config = {}
if self.fsdp_config is None:
self.fsdp_config = {}
if self.tp_config is None:
self.tp_config = {}
if self.megatron_lm_config is None:
self.megatron_lm_config = {}
if self.ipex_config is None:

View File

@ -75,7 +75,6 @@ options_to_group = {
"tpu": "TPU",
"use_deepspeed": "DeepSpeed Arguments",
"use_fsdp": "FSDP Arguments",
"use_tp": "PyTorch TP Arguments",
"use_megatron_lm": "Megatron-LM Arguments",
"fp8_backend": "FP8 Arguments",
}
@ -264,12 +263,6 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_tp",
default=False,
action="store_true",
help="Whether to use PyTorch TP.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
@ -612,15 +605,6 @@ def launch_command_parser(subparsers=None):
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)
# tp args
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyToch.")
tp_args.add_argument(
"--tp_size",
default=1,
type=int,
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
)
# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
megatron_lm_args.add_argument(
@ -1002,9 +986,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
def _validate_launch_command(args):
# Sanity checks
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
raise ValueError(
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
)
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@ -1021,7 +1005,6 @@ def _validate_launch_command(args):
and not args.tpu_use_cluster
and not args.use_deepspeed
and not args.use_fsdp
and not args.use_tp
and not args.use_megatron_lm
):
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@ -1041,7 +1024,6 @@ def _validate_launch_command(args):
)
args.tpu = defaults.distributed_type == DistributedType.XLA
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_tp = defaults.distributed_type == DistributedType.TP
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
if args.gpu_ids is None:
@ -1195,8 +1177,6 @@ def launch_command(args):
deepspeed_launcher(args)
elif args.use_fsdp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_tp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_megatron_lm and not args.cpu:
multi_gpu_launcher(args)
elif args.multi_gpu and not args.cpu:

View File

@ -971,7 +971,7 @@ class AcceleratorState:
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
if self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():

View File

@ -14,6 +14,7 @@
import argparse
import json
import os
from contextlib import nullcontext
from pathlib import Path
import evaluate
@ -24,7 +25,7 @@ from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
from accelerate import Accelerator, DistributedType
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
@ -80,7 +81,7 @@ def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name:
def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator()
accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=args.tp_size))
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
@ -91,9 +92,10 @@ def training_function(config, args):
set_seed(seed)
train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)
model = AutoModelForSequenceClassification.from_pretrained(
model_name, return_dict=True, tp_plan=args.tp_plan, tp_size=args.tp_size
)
if args.add_pad_token:
if model.config.pad_token_id is None:
@ -150,7 +152,13 @@ def training_function(config, args):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
context = nullcontext
if args.tp_plan is not None:
from torch.distributed._tensor.experimental import implicit_replication
context = implicit_replication
with context():
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@ -213,12 +221,15 @@ def training_function(config, args):
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump(performance_metric, f)
# Finally try saving the model
accelerator.save_model(model, args.output_dir)
# TODO: skip model saving for TP until the feature lands
if args.tp_plan is None:
# Finally try saving the model
accelerator.save_model(model, args.output_dir)
accelerator.wait_for_everyone()
assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
"Model was not saved when calling `Accelerator.save_model`"
)
if args.tp_plan is None:
assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
"Model was not saved when calling `Accelerator.save_model`"
)
accelerator.end_training()
@ -255,6 +266,18 @@ def main():
default=False,
help="To add pad token if not exists.",
)
parser.add_argument(
"--tp_plan",
type=str,
default=None,
help="pass 'auto' to use TP",
)
parser.add_argument(
"--tp_size",
type=int,
default=None,
help="TP size to be used to shard the model",
)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
training_function(config, args)
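
The change above wraps `optimizer.step()` in `implicit_replication()` when a TP plan is active. A plausible reading (our assumption, not stated in the commit) is that this lets plain tensors interoperate with the model's sharded DTensor parameters during the step. A self-contained sketch of the same guard; the helper name is hypothetical:

```python
# Sketch of the optimizer-step guard used in the example script when TP is on.
# implicit_replication() is an experimental PyTorch context manager; the helper
# below is hypothetical and only mirrors the pattern from the diff.
from contextlib import nullcontext


def tp_safe_step(optimizer, tp_enabled: bool):
    context = nullcontext
    if tp_enabled:
        # Treat plain tensors as replicated DTensors while DTensor params update.
        from torch.distributed._tensor.experimental import implicit_replication

        context = implicit_replication
    with context():
        optimizer.step()
```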

View File

@ -49,7 +49,7 @@ ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.47.0"
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
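
The minimum transformers version for TP is bumped from 4.47.0 to 4.52.0, per the commit message to pick up the `tp_size` feature of `from_pretrained`. A small sketch of gating on that constant, assuming `compare_versions` is importable from `accelerate.utils.versions` (the diff uses it without showing the import):

```python
# Sketch: guard TP-only code paths on the bumped transformers requirement.
# Assumes compare_versions is importable from accelerate.utils.versions.
from accelerate.utils.versions import compare_versions

BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"  # mirrors the constant above

if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
    raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
```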

View File

@ -2034,7 +2034,6 @@ class TorchTensorParallelPlugin:
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
def __post_init__(self):
self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
if self.tp_size == 1:
raise ValueError("Provide TP degree > 1.")
@ -2052,6 +2051,8 @@ class TorchTensorParallelPlugin:
mesh_dim_name = "tp"
# device mesh is not used for model sharding
# it is only used for preparing data loader
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
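
`TorchTensorParallelPlugin.__post_init__` no longer reads `TP_SIZE` from the environment; it still builds a 1-D device mesh named "tp", which, per the in-code comment, is used only for preparing the data loader, not for sharding the model. A minimal sketch of the equivalent mesh construction (device type and tp_size are illustrative; a distributed process group must already be initialized):

```python
# Sketch of the 1-D "tp" device mesh the plugin builds in __post_init__.
# Illustrative values; requires torch.distributed to be initialized
# (e.g. under `accelerate launch` / torchrun).
from torch.distributed.device_mesh import init_device_mesh

tp_size = 2        # must be > 1, otherwise the plugin raises ValueError
device = "cuda"    # illustrative device type

# Used only to prepare the data loader across TP ranks, not to shard the model.
tp_mesh = init_device_mesh(device, (tp_size,), mesh_dim_names=("tp",))
tp_submesh = tp_mesh["tp"]  # how the mesh is indexed in the Accelerator hunk above
```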

View File

@ -305,10 +305,6 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
if args.use_tp:
current_env["ACCELERATE_USE_TP"] = "true"
current_env["TP_SIZE"] = str(args.tp_size)
if args.use_megatron_lm:
prefix = "MEGATRON_LM_"
current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"

View File

@ -48,15 +48,15 @@ class TPIntegrationTest(TempDirTestCase):
def test_working_of_tp(self):
self.test_file_path = self.test_scripts_folder / "test_performance.py"
cmd = get_launch_command(
num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
)
cmd = get_launch_command(num_processes=self.test_tp_size, num_machines=1, machine_rank=0)
cmd.extend(
[
self.test_file_path,
f"--output_dir={self.tmpdir}",
f"--model_name_or_path={self.model_name_or_path}",
"--add_pad_token=true",
"--tp_plan=auto",
f"--tp_size={self.test_tp_size}",
]
)
with patch_environment(omp_num_threads=1):