mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283. Almost there! Test plan: `dmypy restart && python3 scripts/lintrunner.py -a`; `pyrefly check`. Step 1: delete lines in the pyrefly.toml file from the project-excludes field. Step 2: run pyrefly check. Step 3: add suppressions and clean up unused suppressions. Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199. After: INFO 0 errors (5,064 ignored). Only four directories left to enable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877. Approved by: https://github.com/oulgen
38 lines
884 B
Python
38 lines
884 B
Python
"""This module converts objects into numpy array."""
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
def make_np(x: object) -> np.ndarray:
    """
    Convert an object into a numpy array.

    Args:
        x: A numpy array, a scalar (anything ``np.isscalar`` accepts), or a
            torch tensor.

    Returns:
        numpy.ndarray: ``x`` itself if it is already a numpy array; a
        one-element array wrapping ``x`` if it is a scalar; a one-element
        random placeholder for tensors on the ``meta`` device; otherwise the
        tensor's values as produced by ``_prepare_pytorch``.

    Raises:
        NotImplementedError: If ``x`` is none of the supported types.
    """
    if isinstance(x, np.ndarray):
        # Already a numpy array: returned as-is (no copy).
        return x
    if np.isscalar(x):
        return np.array([x])
    if isinstance(x, torch.Tensor):
        if x.device.type == "meta":
            # Meta tensors have no backing data to read, so a random
            # one-element array is returned as a placeholder.
            return np.random.randn(1)
        return _prepare_pytorch(x)
    raise NotImplementedError(
        f"Got {type(x)}, but numpy array or torch tensor are expected."
    )
|
|
|
|
|
|
def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
    """
    Convert a torch tensor into a numpy array.

    Args:
        x: The tensor to convert.

    Returns:
        numpy.ndarray: The tensor's values, detached from autograd and moved
        to CPU. ``bfloat16`` tensors are returned as ``float16`` arrays. For a
        non-bfloat16 tensor already on CPU, the result may share memory with
        ``x``.
    """
    if x.dtype == torch.bfloat16:
        # numpy has no bfloat16 dtype, so cast before calling .numpy().
        # NOTE(review): float16 has a far smaller range than bfloat16
        # (max ~65504), so large magnitudes become inf; float32 would be
        # lossless, but the float16 output dtype is kept for compatibility.
        x = x.to(torch.float16)
    # detach(): .numpy() rejects tensors that require grad.
    # cpu(): .numpy() only works on CPU tensors.
    # Returned directly instead of rebinding ``x`` so the annotated types
    # stay consistent without type-checker suppressions.
    return x.detach().cpu().numpy()
|