[examples] upgrade code for seed setting (#3387)

* replace set_seed

* update import
This commit is contained in:
Fanli Lin
2025-02-11 23:31:41 +08:00
committed by GitHub
parent 5cc99e6e02
commit 24f8d0276c

View File

@ -24,6 +24,7 @@ from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
from accelerate import Accelerator
from accelerate.utils import set_seed
########################################################################
@ -93,10 +94,7 @@ def training_function(config, args):
label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}
# Set the seed before splitting the data.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(seed)
# Split our filenames between train and validation
random_perm = np.random.permutation(len(file_names))
cut = int(0.8 * len(file_names))