Apply TorchFix TOR203 fixes (#143691)

Codemodded via `torchfix . --select=TOR203 --fix`.
This is a step to unblock https://github.com/pytorch/pytorch/pull/141076
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143691
Approved by: https://github.com/malfet
This commit is contained in:
Sergii Dymchenko
2024-12-23 18:21:03 +00:00
committed by PyTorch MergeBot
parent c042c8a475
commit 727ee853b4
6 changed files with 6 additions and 8 deletions

View File

@ -1,4 +1,4 @@
import torchvision.models as models from torchvision import models
import torch import torch
import torch.autograd.profiler as profiler import torch.autograd.profiler as profiler

View File

@ -1,8 +1,8 @@
import time import time
import torchvision.models as models
from opacus import PrivacyEngine from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules from opacus.utils.module_modification import convert_batchnorm_modules
from torchvision import models
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -12,9 +12,8 @@ import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
import numpy as np import numpy as np
import torchvision.transforms as transforms
from opacus import PrivacyEngine from opacus import PrivacyEngine
from torchvision import models from torchvision import models, transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from tqdm import tqdm from tqdm import tqdm

View File

@ -12,8 +12,7 @@ import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
import numpy as np import numpy as np
import torchvision.transforms as transforms from torchvision import models, transforms
from torchvision import models
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from tqdm import tqdm from tqdm import tqdm

View File

@ -22,8 +22,8 @@ import os
import os.path import os.path
import numpy as np import numpy as np
import torchvision.transforms as transforms
from PIL import Image from PIL import Image
from torchvision import transforms
import torch import torch
import torch.utils.data as data import torch.utils.data as data

View File

@ -4541,7 +4541,7 @@ class TestExamplesCorrectness(TestCase):
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
@parametrize("mechanism", ["make_functional", "functional_call"]) @parametrize("mechanism", ["make_functional", "functional_call"])
def test_resnet18_per_sample_grads(self, device, mechanism): def test_resnet18_per_sample_grads(self, device, mechanism):
import torchvision.models as models from torchvision import models
model = models.__dict__["resnet18"]( model = models.__dict__["resnet18"](
pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c)) pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c))