Mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 10:03:46 +08:00
make torch_native_parallelism examples device agnostic (#3759)
* make torch_native_parallelism examples device agnostic
* xxx
* xxx
* Style + deprecation warning

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
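The gist of the change: instead of hard-coding "cuda", the examples ask PyTorch which accelerator backend is active and resolve the matching per-backend module. A minimal sketch of that pattern, assuming a PyTorch build that ships the torch.accelerator API and that an accelerator is actually present (variable names and the print are only illustrative):

    import torch

    # Ask PyTorch which accelerator backend is active ("cuda", "xpu", ...).
    # current_accelerator() returns None on CPU-only builds, so this assumes
    # an accelerator is available, as the example scripts do.
    device_type = torch.accelerator.current_accelerator().type
    # Resolve the matching per-backend module (torch.cuda, torch.xpu, ...),
    # falling back to torch.cuda when no module of that name exists.
    torch_device_module = getattr(torch, device_type, torch.cuda)
    print(device_type, torch_device_module.is_available())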
@@ -13,6 +13,7 @@
 # limitations under the License.
 import argparse
 
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
 
 from accelerate.utils import ParallelismConfig
@@ -28,7 +29,7 @@ def parse_args():
     parser.add_argument("--checkpoint-frequency", type=int, default=100)
     parser.add_argument("--model-name", type=str, default=MODEL_ID)
     parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
-    parser.add_argument("--device-type", type=str, default="cuda")
+    parser.add_argument("--device-type", type=str, default="auto")
     return parser.parse_args()
 
 
@@ -38,6 +39,9 @@ def main():
     pc = ParallelismConfig()
     args = parse_args()
 
+    if args.device_type == "auto":
+        args.device_type = torch.accelerator.current_accelerator().type
+
     model_kwargs = {}
     if pc.tp_enabled:
         model_kwargs["tp_plan"] = "auto"
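torch.accelerator only exists in recent PyTorch releases, so on older builds the "auto" resolution in the hunk above would raise an AttributeError. A hedged sketch of a guarded variant (resolve_device_type is a hypothetical helper, not part of this commit):

    import torch

    def resolve_device_type(requested: str = "auto") -> str:
        # Honour an explicit --device-type value.
        if requested != "auto":
            return requested
        # Prefer the accelerator API when the running PyTorch exposes it.
        if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
            return torch.accelerator.current_accelerator().type
        # Conservative fallback for older builds.
        return "cuda" if torch.cuda.is_available() else "cpu"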
@@ -66,7 +70,7 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=packed_dataset,
     )
 
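Recent transformers releases deprecate the tokenizer= keyword on Trainer in favour of processing_class, which is the deprecation warning mentioned in the commit message; the hunk above simply renames the keyword. A minimal sketch of the updated call, where model, training_args, tokenizer and packed_dataset stand in for the objects built earlier in the example:

    from transformers import Trainer

    trainer = Trainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,  # replaces the deprecated tokenizer= keyword
        train_dataset=packed_dataset,
    )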
@@ -202,16 +202,18 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
 
 
 def gpu_memory_usage_all(device=0):
-    device = torch.device(f"cuda:{device}")
+    device_type = torch.accelerator.current_accelerator().type
+    device = torch.device(f"{device_type}:{device}")
+    torch_device_module = getattr(torch, device_type, torch.cuda)
     _BYTES_IN_GIB = 1024**3
-    peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
-    peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
-    peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
+    peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
+    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
+    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
     memory_stats = {
         "peak_memory_active": peak_memory_active,
         "peak_memory_alloc": peak_memory_alloc,
         "peak_memory_reserved": peak_memory_reserved,
     }
-    torch.cuda.reset_peak_memory_stats(device)
+    torch_device_module.reset_peak_memory_stats(device)
 
     return memory_stats
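With the getattr dispatch in place, the same helper reads peak-memory statistics from torch.cuda on NVIDIA GPUs and from torch.xpu on Intel GPUs without any branching at the call site; it returns the three peak values in GiB and resets the backend's peak counters. A short usage sketch (the print formatting is illustrative, not part of the example):

    stats = gpu_memory_usage_all(device=0)
    for name, gib in stats.items():
        print(f"{name}: {gib:.2f} GiB")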