mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
This will add pyrefly to lintrunner as a warning only, and allow us to collect feedback about the tool before switching to pyrefly as the main type checker. References the steps outlined here: https://github.com/pytorch/pytorch/issues/163283
Test plan: run `lintrunner init`, then `lintrunner`, and confirm the output when pyrefly errors are present. Results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
212 lines
7.0 KiB
Python
from typing import Any

import torch
from torch.fx import Proxy

from .. import graph_break_hints
from ..exc import TYPE_CHECKING, unimplemented_v2
from .base import VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

    from ..codegen import PyCodegen

from torch._library.custom_ops import custom_op


Tensor = torch.Tensor

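
# Note: streams::fork and streams::join are deliberate no-ops with matching
# fake implementations registered below; they appear intended as traceable
# graph markers for stream fork/join points rather than ops with runtime
# effects.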
@custom_op("streams::fork", mutates_args=())
|
|
def fork_stream(
|
|
from_index: int,
|
|
from_device: torch.device,
|
|
to_index: int,
|
|
to_device: torch.device,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
@fork_stream.register_fake
|
|
def _(
|
|
from_index: int,
|
|
from_device: torch.device,
|
|
to_index: int,
|
|
to_device: torch.device,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
@custom_op("streams::join", mutates_args=())
|
|
def join_stream(
|
|
from_index: int,
|
|
from_device: torch.device,
|
|
to_index: int,
|
|
to_device: torch.device,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
@join_stream.register_fake
|
|
def _(
|
|
from_index: int,
|
|
from_device: torch.device,
|
|
to_index: int,
|
|
to_device: torch.device,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
class StreamVariable(VariableTracker):
    def __init__(
        self,
        proxy: Proxy,
        value: torch.Stream,
        device: torch.device,
        **kwargs: Any,
    ) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert value.device.type == device.type, (
            "stream value is not equal to the passed device"
        )
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        # pyrefly: ignore  # read-only
        self.device = device

    def python_type(self) -> type:
        return torch.Stream

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            from ..guards import GuardBuilder, install_guard

            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))

            # NB: Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return ConstantVariable.create(NotImplemented)

            if other.source:
                install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type]
            )

        return super().call_method(tx, name, args, kwargs)

    def as_proxy(self) -> Proxy:
        return self.proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Since we just proved that: for other such structures, like lists and dicts, reconstruction
        # is fine and sound according to dynamo's principles for treating collections. However,
        # streams are special in that we want to preserve the identity of the stream exactly as it was in the graph.
        # Normally we would do this via codegen mapping the proxy to an output, but we cannot do that yet, as we do
        # not yet have a plan for handling the case where the stream is used as an input or an output. Pending that
        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from
        # there.
        prefix = f"_stream_{self.device}"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))
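
    # Rough illustration of the effect (exact naming is an internal detail of
    # install_global_by_id): the live torch.Stream object is stashed in the
    # compiled frame's globals under a unique name built from the
    # f"_stream_{self.device}" prefix, and the emitted bytecode simply loads
    # that global, so the stream the user sees after reconstruction is the
    # identical object the graph traced.
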
class EventVariable(VariableTracker):
    def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            method_name = (
                f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
            )
            unimplemented_v2(
                gb_type="Unsupported event method",
                context=str(name),
                explanation=f"Dynamo doesn't support tracing the {method_name} method. "
                f"We currently support wait, record, synchronize, and query.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

    def as_proxy(self) -> Proxy:
        return self.proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # If we got here, this event is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
        prefix = "_event"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))
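

# ---------------------------------------------------------------------------
# Illustrative sketch only (not executed here; assumes a CUDA device so that
# torch.cuda.Stream and its events are the concrete torch.Stream/torch.Event
# types involved): the kind of user code whose stream/event methods the
# trackers above intercept while Dynamo traces it.
#
#     @torch.compile
#     def step(x: torch.Tensor) -> torch.Tensor:
#         s = torch.cuda.Stream()
#         ev = s.record_event()   # routed through StreamVariable.call_method("record_event")
#         ev.synchronize()        # routed through EventVariable.call_method("synchronize")
#         return x + 1
# ---------------------------------------------------------------------------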