Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
(Part 1) fix: make TP training compatible with new transformers (#3457)
* feat: support new tp refactor for training
* fix: @S1ro1 review cmt
* fix: @S1ro1 review cmt - tp_plan flag docstr
* fix: @SunMarc review cmt on unused flag
* fix: pick approach 3 as discussed in the PR
  (see https://github.com/huggingface/accelerate/pull/3457#discussion_r2037909077 for more details)
* fix: styling errors
* fix: bump up transformers for tp_size feature
---------
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Committed by GitHub. Parent: ee4cab96ed. Commit: 67adb473a4.
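A minimal sketch of the workflow this commit lands, assuming a 2-way TP run (the model id and TP degree are illustrative, not taken from the diff): transformers shards the model at load time via tp_plan, and Accelerate only validates that the plugin's tp_size matches the model's.

from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin
from transformers import AutoModelForSequenceClassification

tp_size = 2  # assumed TP degree; must equal the tp_size the model is loaded with
accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=tp_size))

# transformers performs the sharding; accelerate no longer calls
# model.tensor_parallel() itself (see the Accelerator hunks below)
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-cased", return_dict=True, tp_plan="auto", tp_size=tp_size
)
model = accelerator.prepare(model)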
@@ -374,9 +374,7 @@ class Accelerator:
             if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
                 raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
 
-        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
-            torch_tp_plugin, TorchTensorParallelPlugin
-        ):
+        if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
             if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
                 raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
@@ -396,14 +394,8 @@ class Accelerator:
             if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
                 raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
 
-        if torch_tp_plugin is None:
-            torch_tp_plugin = (
-                TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
-            )
-        else:
-            if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
-                raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
-            os.environ["ACCELERATE_USE_TP"] = "true"
+        if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
+            raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
 
         if megatron_lm_plugin is None:  # init from env variables
             megatron_lm_plugin = (
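With the ACCELERATE_USE_TP environment-variable path gone, the constructor accepts TP only through an explicit plugin object, and anything else of the wrong type fails fast. A minimal sketch of the new validation, assuming the check fires before any process-group setup:

from accelerate import Accelerator

try:
    Accelerator(torch_tp_plugin="not-a-plugin")  # deliberately the wrong type
except TypeError as err:
    print(err)  # `torch_tp_plugin` must be a TorchTensorParallelPlugin object.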
@@ -1598,15 +1590,17 @@ class Accelerator:
             if self.ddp_handler is not None:
                 self.ddp_handler.register_comm_hook(model)
         elif self.distributed_type == DistributedType.TP:
-            if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
-                if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
-                    raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
+            if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
+                raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
+            if not hasattr(model, "tp_size"):
                 raise NotImplementedError(
-                    "Provided model does not support tensor parallelism. \
-                    Tensor parallelism plan can be added as base_model_tp_plan to model config class \
-                    and _tp_plan attribute to model class."
+                    "Model should undergo tensor parallel before passing it to accelerate."
+                    "You can use .from_pretrained(..., tp_plan='auto') if the model supports"
                 )
-            model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
+            if model.tp_size != self.state.torch_tp_plugin.tp_size:
+                raise ValueError(
+                    f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
+                )
         elif self.is_fsdp2:
             model = fsdp2_prepare_model(self, model)
 
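prepare() now assumes the model arrives already sharded, so the checks reduce to three: a new-enough transformers, a tp_size attribute on the model, and agreement between model and plugin. A sketch of a model that passes them, assuming the accelerator above was built with tp_size=2 (model id illustrative):

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-cased", tp_plan="auto", tp_size=2  # sharded by transformers at load time
)
assert hasattr(model, "tp_size") and model.tp_size == 2
model = accelerator.prepare(model)  # passes all three checks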
@@ -2223,8 +2217,7 @@ class Accelerator:
             return self.state.torch_tp_plugin.torch_device_mesh
         elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
             return self.state.ds_device_mesh
-        else:
-            return None
+        return None
 
     def _prepare_msamp(self, *args, device_placement):
         if not is_msamp_available():
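The mesh accessor keeps its behavior and only loses the redundant else branch. A hedged example of consuming it, assuming the surrounding property is the public torch_device_mesh accessor:

mesh = accelerator.torch_device_mesh  # DeviceMesh with a single "tp" dim, or None
if mesh is not None:
    print(mesh["tp"].size())  # TP degree of the current run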
@@ -369,7 +369,7 @@ def get_cluster_input():
         )
 
     fsdp_config = {}
-    tp_config = {}
+
     if distributed_type in [
         DistributedType.MULTI_GPU,
         DistributedType.MULTI_NPU,
@@ -498,21 +498,7 @@ def get_cluster_input():
             default=False,
             error_message="Please enter yes or no.",
         )
-        if not use_fsdp:
-            use_tp = _ask_field(
-                "Do you want to use TensorParallel? [yes/NO]: ",
-                _convert_yes_no_to_bool,
-                default=False,
-                error_message="Please enter yes or no.",
-            )
-            if use_tp:
-                distributed_type = DistributedType.TP
-    if distributed_type == DistributedType.TP:
-        tp_config["tp_size"] = _ask_field(
-            "What should be your Tensor Parallel degree? [1]: ",
-            int,
-            default=1,
-        )
 
     megatron_lm_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
         use_megatron_lm = _ask_field(
@@ -857,7 +843,6 @@ def get_cluster_input():
         fp8_config=fp8_config,
         deepspeed_config=deepspeed_config,
         fsdp_config=fsdp_config,
-        tp_config=tp_config,
         megatron_lm_config=megatron_lm_config,
         ipex_config=ipex_config,
         mpirun_config=mpirun_config,
@@ -194,8 +194,6 @@ class ClusterConfig(BaseConfig):
     deepspeed_config: dict = None
     # args for fsdp
     fsdp_config: dict = None
-    # args for tp
-    tp_config: dict = None
     # args for megatron_lm
     megatron_lm_config: dict = None
     # args for ipex
@@ -223,8 +221,6 @@ class ClusterConfig(BaseConfig):
             self.deepspeed_config = {}
         if self.fsdp_config is None:
             self.fsdp_config = {}
-        if self.tp_config is None:
-            self.tp_config = {}
         if self.megatron_lm_config is None:
             self.megatron_lm_config = {}
         if self.ipex_config is None:
@@ -75,7 +75,6 @@ options_to_group = {
     "tpu": "TPU",
     "use_deepspeed": "DeepSpeed Arguments",
     "use_fsdp": "FSDP Arguments",
-    "use_tp": "PyTorch TP Arguments",
     "use_megatron_lm": "Megatron-LM Arguments",
    "fp8_backend": "FP8 Arguments",
 }
@@ -264,12 +263,6 @@ def launch_command_parser(subparsers=None):
         action="store_true",
         help="Whether to use fsdp.",
     )
-    paradigm_args.add_argument(
-        "--use_tp",
-        default=False,
-        action="store_true",
-        help="Whether to use PyTorch TP.",
-    )
     paradigm_args.add_argument(
         "--use_megatron_lm",
         default=False,
@@ -612,15 +605,6 @@ def launch_command_parser(subparsers=None):
         help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
     )
 
-    # tp args
-    tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyToch.")
-    tp_args.add_argument(
-        "--tp_size",
-        default=1,
-        type=int,
-        help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
-    )
-
     # megatron_lm args
     megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
     megatron_lm_args.add_argument(
@@ -1002,9 +986,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
 
 def _validate_launch_command(args):
     # Sanity checks
-    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
+    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
         raise ValueError(
-            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
+            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
         )
     if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
         raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@@ -1021,7 +1005,6 @@ def _validate_launch_command(args):
             and not args.tpu_use_cluster
             and not args.use_deepspeed
             and not args.use_fsdp
-            and not args.use_tp
             and not args.use_megatron_lm
         ):
             args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@@ -1041,7 +1024,6 @@ def _validate_launch_command(args):
                 )
             args.tpu = defaults.distributed_type == DistributedType.XLA
             args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
-            args.use_tp = defaults.distributed_type == DistributedType.TP
             args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
             args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
         if args.gpu_ids is None:
@@ -1195,8 +1177,6 @@ def launch_command(args):
         deepspeed_launcher(args)
     elif args.use_fsdp and not args.cpu:
         multi_gpu_launcher(args)
-    elif args.use_tp and not args.cpu:
-        multi_gpu_launcher(args)
     elif args.use_megatron_lm and not args.cpu:
         multi_gpu_launcher(args)
     elif args.multi_gpu and not args.cpu:
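Taken together, the launcher changes demote TP from a launch-level paradigm: --use_tp and --tp_size disappear from accelerate launch, and the TP degree now travels with the training script. A command that previously read accelerate launch --use_tp --tp_size 2 script.py would now look roughly like accelerate launch --num_processes 2 script.py --tp_plan auto --tp_size 2, where the trailing flags belong to the example script updated below (the exact invocation is an assumption based on those hunks).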
@@ -971,7 +971,7 @@ class AcceleratorState:
                 self.distributed_type = DistributedType.MEGATRON_LM
                 megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                 self.megatron_lm_plugin = megatron_lm_plugin
-            if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
+            if self.torch_tp_plugin is not None:
                 self.distributed_type = DistributedType.TP
         elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
             if is_ipex_available():
@@ -14,6 +14,7 @@
 import argparse
 import json
 import os
+from contextlib import nullcontext
 from pathlib import Path
 
 import evaluate
@@ -24,7 +25,7 @@ from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
 
 from accelerate import Accelerator, DistributedType
-from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
+from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
 from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
 
 
@@ -80,7 +81,7 @@ def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name:
 
 def training_function(config, args):
     # Initialize accelerator
-    accelerator = Accelerator()
+    accelerator = Accelerator(torch_tp_plugin=TorchTensorParallelPlugin(tp_size=args.tp_size))
 
     # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
     lr = config["lr"]
@@ -91,9 +92,10 @@ def training_function(config, args):
 
     set_seed(seed)
     train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)
 
     # Instantiate the model (we build the model here so that the seed also control new weights initialization)
-    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        model_name, return_dict=True, tp_plan=args.tp_plan, tp_size=args.tp_size
+    )
 
     if args.add_pad_token:
         if model.config.pad_token_id is None:
@@ -150,7 +152,13 @@ def training_function(config, args):
             outputs = model(**batch)
             loss = outputs.loss
             accelerator.backward(loss)
-            optimizer.step()
+            context = nullcontext
+            if args.tp_plan is not None:
+                from torch.distributed._tensor.experimental import implicit_replication
+
+                context = implicit_replication
+            with context():
+                optimizer.step()
             lr_scheduler.step()
             optimizer.zero_grad()
 
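The context switch exists because, after load-time sharding, the model's parameters are DTensors while optimizer.step() still mixes in plain tensors such as scalar step counts; torch's experimental implicit_replication treats those plain tensors as replicated so the mixed operations go through. A condensed sketch of the same pattern, with use_tp standing in for the script's args.tp_plan check:

from contextlib import nullcontext

from torch.distributed._tensor.experimental import implicit_replication

step_context = implicit_replication if use_tp else nullcontext  # use_tp: assumed flag
with step_context():
    optimizer.step()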
@@ -213,12 +221,15 @@ def training_function(config, args):
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump(performance_metric, f)
 
-        # Finally try saving the model
-        accelerator.save_model(model, args.output_dir)
+        # TODO: skip saving of the model test for TP until the feature lands
+        if args.tp_plan is None:
+            # Finally try saving the model
+            accelerator.save_model(model, args.output_dir)
         accelerator.wait_for_everyone()
-        assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
-            "Model was not saved when calling `Accelerator.save_model`"
-        )
+        if args.tp_plan is None:
+            assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
+                "Model was not saved when calling `Accelerator.save_model`"
+            )
     accelerator.end_training()
 
 
@@ -255,6 +266,18 @@ def main():
         default=False,
         help="To add pad token if not exists.",
     )
+    parser.add_argument(
+        "--tp_plan",
+        type=str,
+        default=None,
+        help="pass 'auto' to use TP",
+    )
+    parser.add_argument(
+        "--tp_size",
+        type=int,
+        default=None,
+        help="TP size to be used to shard the model",
+    )
     args = parser.parse_args()
     config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
     training_function(config, args)
@@ -49,7 +49,7 @@ ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
 MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
 BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
-BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.47.0"
+BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
 
 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
 
@@ -2034,7 +2034,6 @@ class TorchTensorParallelPlugin:
     torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
 
     def __post_init__(self):
-        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
         if self.tp_size == 1:
             raise ValueError("Provide TP degree > 1.")
 
@@ -2052,6 +2051,8 @@ class TorchTensorParallelPlugin:
 
         mesh_dim_name = "tp"
 
+        # device mesh is not used for model sharding
+        # it is only used for preparing data loader
         self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
 
 
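A hedged example of the plugin after these changes, assuming it runs inside a 2-process distributed job: tp_size must now be passed explicitly (the TP_SIZE environment variable is no longer consulted), and the resulting one-dimensional mesh exists for dataloader preparation rather than model sharding.

from accelerate.utils import TorchTensorParallelPlugin

plugin = TorchTensorParallelPlugin(tp_size=2)  # tp_size=1 raises ValueError
print(plugin.torch_device_mesh)  # 1-D DeviceMesh with mesh_dim_names=("tp",)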
@@ -305,10 +305,6 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
         current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
         current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
 
-    if args.use_tp:
-        current_env["ACCELERATE_USE_TP"] = "true"
-        current_env["TP_SIZE"] = str(args.tp_size)
-
     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"
         current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"
@@ -48,15 +48,15 @@ class TPIntegrationTest(TempDirTestCase):
 
     def test_working_of_tp(self):
         self.test_file_path = self.test_scripts_folder / "test_performance.py"
-        cmd = get_launch_command(
-            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
-        )
+        cmd = get_launch_command(num_processes=self.test_tp_size, num_machines=1, machine_rank=0)
         cmd.extend(
             [
                 self.test_file_path,
                 f"--output_dir={self.tmpdir}",
                 f"--model_name_or_path={self.model_name_or_path}",
                 "--add_pad_token=true",
+                "--tp_plan=auto",
+                f"--tp_size={self.test_tp_size}",
             ]
         )
         with patch_environment(omp_num_threads=1):
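The test now drives TP entirely through the script's own flags rather than launcher flags. Outside the test harness the equivalent invocation would be roughly: accelerate launch --num_processes <tp_size> --num_machines 1 --machine_rank 0 test_performance.py --output_dir <dir> --model_name_or_path <model> --add_pad_token=true --tp_plan=auto --tp_size=<tp_size> (placeholders assumed).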