mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 10:03:46 +08:00
[examples] upgrade code for seed setting (#3387)
* replace set_seed * update import
This commit is contained in:
@ -24,6 +24,7 @@ from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -93,10 +94,7 @@ def training_function(config, args):
|
||||
label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}
|
||||
|
||||
# Set the seed before splitting the data.
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
set_seed(seed)
|
||||
# Split our filenames between train and validation
|
||||
random_perm = np.random.permutation(len(file_names))
|
||||
cut = int(0.8 * len(file_names))
|
||||
|
Reference in New Issue
Block a user