mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Rewrite calls to Python's built-in `super()`. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods that contain only a `super()` call are removed:

```diff
 class MyModule(nn.Module):
-    def __init__(self):
-        super().__init__()
-
     def forward(self, ...):
         ...
```

Some cases where the rewrite would change semantics are kept unchanged, e.g.: f152a79be9/caffe2/python/net_printer.py (L184-L190)
f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94587 Approved by: https://github.com/ezyang
17 lines
368 B
Python
17 lines
368 B
Python
import torch
|
|
from utils import NUM_LOOP_ITERS
|
|
|
|
def add_tensors_loop(x, y):
    """Return ``x + y`` followed by ``NUM_LOOP_ITERS`` further additions of ``x``.

    Mathematically the result is ``x * (NUM_LOOP_ITERS + 1) + y``. It is built
    as a chain of separate ``torch.add`` calls: one call up front, plus one per
    loop iteration. NOTE(review): presumably the chain is kept so the number
    of dispatched ops scales with ``NUM_LOOP_ITERS``; confirm before replacing
    it with a single fused expression.

    ``NUM_LOOP_ITERS`` is the module-level constant imported from ``utils``.
    """
    total = torch.add(x, y)
    # Each iteration issues one more torch.add of ``x``; the index is unused.
    for _ in range(NUM_LOOP_ITERS):
        total = torch.add(total, x)
    return total
|
|
|
|
class SimpleAddModule(torch.nn.Module):
    """``torch.nn.Module`` whose ``forward`` delegates to a caller-supplied op.

    Args:
        add_op: callable invoked as ``add_op(x, y)`` on every forward pass
            (for example ``torch.add``, or ``add_tensors_loop`` from this
            file). NOTE(review): the full set of ops callers pass in is not
            visible here.
    """

    def __init__(self, add_op):
        # nn.Module must be initialised before any attribute is assigned.
        super().__init__()
        # Kept as a plain attribute; ``forward`` looks it up on every call.
        self.add_op = add_op

    def forward(self, x, y):
        """Apply the wrapped op to ``(x, y)`` and return its result unchanged."""
        result = self.add_op(x, y)
        return result
|