Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 21:49:24 +08:00)
Replaces 78 assert statements across 10 files in torch.autograd with explicit if-checks that raise AssertionError, so the checks cannot be silently disabled by running Python with the -O flag. This ensures error checking remains active in optimized builds. Partially fixes #164878. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165627. Approved by: https://github.com/albanD
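As an illustration of the pattern this change applies, here is a minimal sketch; the helper name check_grad_numel and its arguments are hypothetical, not code taken from the PR (Resize.backward in the file below shows the same if-check in context):

import torch


def check_grad_numel(grad_output: torch.Tensor, expected_numel: int) -> None:
    # Before the change (stripped out under `python -O`):
    #     assert grad_output.numel() == expected_numel, "element count mismatch"
    # After the change: an explicit check that raises the same exception type
    # and keeps working when assertions are disabled.
    if grad_output.numel() != expected_numel:
        raise AssertionError(
            f"Expected grad_output to have {expected_numel} elements, "
            f"but got {grad_output.numel()}"
        )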
73 lines · 2.4 KiB · Python
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing_extensions import deprecated

import torch
import torch._utils
from torch.autograd.function import Function


class Type(Function):
    @staticmethod
    @deprecated(
        "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, "
        "please use `torch.tensor.to(dtype=dtype)` instead.",
        category=FutureWarning,
    )
    # pyrefly: ignore # bad-override
    def forward(ctx, i, dest_type):
        ctx.input_type = type(i)
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
        return i.type(dest_type)

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        if ctx.input_device == -1:
            return grad_output.type(ctx.input_type), None
        else:
            with torch.accelerator.device_index(ctx.input_device):
                return grad_output.type(ctx.input_type), None


# TODO: deprecate this
class Resize(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, tensor, sizes):
        ctx.sizes = sizes
        ctx.numel = reduce(operator.mul, sizes, 1)
        if tensor.numel() != ctx.numel:
            raise RuntimeError(
                (
                    "requested resize to {} ({} elements in total), "
                    "but the given tensor has a size of {} ({} elements). "
                    "autograd's resize can only change the shape of a given "
                    "tensor, while preserving the number of elements. "
                ).format(
                    "x".join(map(str, sizes)),
                    ctx.numel,
                    "x".join(map(str, tensor.size())),
                    tensor.numel(),
                )
            )
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if tensor.is_contiguous():
            result = tensor.new(tensor).contiguous().view(*sizes)
            return result
        else:
            return tensor.contiguous().view(*sizes)

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        if grad_output.numel() != ctx.numel:
            raise AssertionError(
                f"Expected grad_output to have {ctx.numel} elements, but got {grad_output.numel()}"
            )
        return grad_output.contiguous().view(ctx.input_sizes), None