Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00).
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129754 Approved by: https://github.com/ezyang
35 lines · 805 B · Python
import torch
|
|
|
|
|
|
# Repetition counts exported to the benchmark scripts. Going by the names,
# these are presumably the inner (calls per measurement) and outer
# (measurements taken) loop counts -- confirm against the callers, which
# are not part of this module.
NUM_REPEATS = 1_000
NUM_REPEAT_OF_REPEATS = 1_000
|
|
|
|
|
|
class SubTensor(torch.Tensor):
    """A ``torch.Tensor`` subclass that adds nothing of its own.

    It defines no methods, so every operation goes through the inherited
    ``torch.Tensor.__torch_function__``.
    """
|
|
|
|
|
|
class WithTorchFunction:
    """Tensor-like wrapper (not a ``torch.Tensor`` subclass) that hooks ``__torch_function__``.

    The wrapped tensor is kept in ``self._tensor``. The ``__torch_function__``
    handler is hard-coded: whatever ``func`` is, it returns a new
    ``WithTorchFunction`` holding ``args[0]._tensor + args[1]._tensor``.
    NOTE(review): this looks like a deliberately minimal handler for measuring
    dispatch overhead on a binary op such as ``torch.add``. It is not a
    general-purpose wrapper; confirm against the benchmark callers.
    """

    def __init__(self, data, requires_grad=False):
        # An existing tensor is stored as-is, and ``requires_grad`` has no
        # effect on that path. Anything else is converted with torch.tensor.
        self._tensor = (
            data
            if isinstance(data, torch.Tensor)
            else torch.tensor(data, requires_grad=requires_grad)
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # kwargs is normalized as the protocol expects, but is not used after that.
        kwargs = {} if kwargs is None else kwargs
        # ``func`` is ignored on purpose: the result is always the sum of the
        # first two wrapped operands.
        lhs, rhs = args[0], args[1]
        return WithTorchFunction(lhs._tensor + rhs._tensor)
|
|
|
|
|
|
class SubWithTorchFunction(torch.Tensor):
    """``torch.Tensor`` subclass that overrides ``__torch_function__`` only to defer to the default.

    The override replaces a ``None`` ``kwargs`` with an empty dict. Everything
    else is handled by ``torch.Tensor.__torch_function__``.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Zero-argument super() inside a classmethod binds to this class, so
        # the inherited default handler runs with ``cls`` unchanged.
        return super().__torch_function__(
            func, types, args, {} if kwargs is None else kwargs
        )
|