import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Sequence

import torch

import torch._prims

import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides

from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)

    # Support remapping torch.Tensor.foo to _refs.foo
    for s in dir(torch.Tensor):
        if s in torch._refs.__all__:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)
    return r

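# Illustrative lookups (a sketch, not part of the module's API; assumes the
# corresponding refs exist in this build). The map covers both module-level
# functions and Tensor methods:
#
#   refs = torch_to_refs_map()
#   assert refs[torch.add] is torch._refs.add
#   assert refs[torch.Tensor.add] is torch._refs.add
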
@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.ops.prims.* functions to
    use nvFuser's prims in torch.ops.nvprims.*

    >>> # xdoctest: +SKIP
    >>> with NvfuserPrimsMode():
    ...     torch.ops.prims.add(x, y)  # calls torch.ops.nvprims.add(x, y)

    By default, this context manager falls back on torch.ops.prims.* if the
    corresponding nvprim does not exist.
    """

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}
        if isinstance(orig_func, torch._ops.OpOverload) or isinstance(
            orig_func, torch._ops.OpOverloadPacket
        ):
            namespace = str(orig_func).split(".")[0]
            name = str(orig_func).split(".")[1]
            if namespace == "prims":
                nvfunc = getattr(torch.ops.nvprims, name, None)
                if nvfunc is not None:
                    return nvfunc(*args, **kwargs)
        return orig_func(*args, **kwargs)

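# Usage sketch (illustrative only; assumes a CUDA build where torch.ops.nvprims
# is populated, and `a`/`b` are hypothetical tensors, not names defined here):
#
#   a = torch.randn(3, device="cuda")
#   b = torch.randn(3, device="cuda")
#   with NvfuserPrimsMode():
#       torch.ops.prims.add(a, b)  # rerouted to torch.ops.nvprims.add
#       # Calls whose prim has no nvprim counterpart run via torch.ops.prims.*
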
class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager will fall back on the torch.* if the
    ref does not exist; set strict=True to error if this occurs.
    Even when a ref exists, falling back on torch.* is sometimes desirable;
    this behavior can be customized by passing a function to should_fallback_fn.
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as is without interception,
        # unless we are in prims_mode, in which case we want to use nvprims
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)
        if func is not None:
            # If the ref exists, query whether we should use it or not
            if self.should_fallback_fn(self, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)

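# Usage sketch (illustrative only; `x` and `y` are hypothetical tensors):
#
#   x = torch.randn(3)
#   y = torch.randn(3)
#   with TorchRefsMode():
#       torch.add(x, y)  # executed via torch._refs.add
#   with TorchRefsMode(strict=True):
#       torch.add(x, y)  # same, but any torch.* call without a ref raises RuntimeError
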
def _is_node_supported_nvfuser(node):
    # An FX node is nvFuser-executable if its target prim carries an nvFuser
    # implementation (impl_nvfuser).
    return (
        node.op == "call_function"
        and getattr(node.target, "impl_nvfuser", None) is not None
    )


def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
    # Trace func into an isolated FX graph under the given torch function mode
    # (so torch.* calls decompose into prims) and report whether any call in
    # that graph lacks an nvFuser implementation.
    with torch.overrides.enable_torch_function_mode(
        torch_function_mode, replace=torch_function_mode.inner
    ):
        gm = get_isolated_graphmodule(func, args, kwargs)

    call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
    any_unsupported = any(
        not _is_node_supported_nvfuser(node) for node in call_function_nodes
    )
    return any_unsupported


TorchRefsNvfuserCapabilityMode = functools.partial(
    TorchRefsMode,
    should_fallback_fn=_is_func_unsupported_nvfuser,
    prims_mode_cls=NvfuserPrimsMode,
)
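# Usage sketch (illustrative only; assumes a CUDA build with nvFuser and a
# hypothetical tensor `x`). The partial composes TorchRefsMode so that refs
# are used only when nvFuser can execute every call in the traced prims graph;
# otherwise the original torch.* call runs unchanged.
#
#   x = torch.randn(3, device="cuda")
#   with TorchRefsNvfuserCapabilityMode():
#       torch.sigmoid(x)  # uses refs/nvprims only if fully supported by nvFuser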