Compare commits

...

4 Commits

SHA1 Message Date
82e19d9f45 Minor polishing 2025-11-04 08:35:55 -08:00
ed07be351a Fix linter, tests 2025-11-03 13:10:52 -08:00
745606c796 Improve typing of ctx_manager 2025-11-03 10:54:00 -08:00
9820038308 First pass at typing 2025-11-03 08:59:03 -08:00
7 changed files with 386 additions and 228 deletions

View File

@@ -5,6 +5,8 @@ from torch import Tensor
 
 # Defined in torch/csrc/functorch/init.cpp
 
+def set_inplace_requires_grad_allowed(allowed: bool) -> None: ...
+def get_inplace_requires_grad_allowed() -> bool: ...
 def _set_dynamic_layer_keys_included(included: bool) -> None: ...
 def get_unwrapped(tensor: Tensor) -> Tensor: ...
 def is_batchedtensor(tensor: Tensor) -> bool: ...
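These stubs declare signatures for bindings implemented in C++ (torch/csrc/functorch/init.cpp, per the comment), so Python call sites stop being effectively untyped. As a hedged sketch of what that buys (the helper below is hypothetical, not code from this PR), a context manager toggling the flag can now be checked end to end:

from contextlib import contextmanager
from typing import Iterator

import torch._C._functorch as _functorch


@contextmanager
def _inplace_requires_grad_allowed(allowed: bool) -> Iterator[None]:
    # Save, set, and restore the C++-side flag; the stubs give both
    # calls precise bool signatures instead of Any.
    prev = _functorch.get_inplace_requires_grad_allowed()
    _functorch.set_inplace_requires_grad_allowed(allowed)
    try:
        yield
    finally:
        _functorch.set_inplace_requires_grad_allowed(prev)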

View File

@@ -64,7 +64,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
     del __func
     del __name
 
-    @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)  # type: ignore[arg-type]
     def tree_is_leaf(
         tree: PyTree,
         /,
@@ -79,7 +79,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
             return True
         return False
 
-    @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
+    @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)  # type: ignore[arg-type]
     def tree_iter(
         tree: PyTree,
         /,
@@ -110,7 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
 
     __all__ += ["tree_iter"]
 
-    @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)  # type: ignore[arg-type]
     def tree_leaves(
         tree: PyTree,
         /,
@@ -451,7 +451,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
                 dict,
                 metadata,
                 entries,
-                unflatten_func,
+                unflatten_func,  # type: ignore[arg-type]
                 none_is_leaf=none_is_leaf,
                 namespace=namespace,
             )
@@ -507,7 +507,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
             type(node),
             metadata,
             entries,
-            unflatten_func,
+            unflatten_func,  # type: ignore[arg-type]
             none_is_leaf=none_is_leaf,
             namespace=namespace,
         )  # type: ignore[arg-type]
@@ -557,7 +557,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
 
     __all__ += ["tree_unflatten"]
 
-    @substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_map, can_constant_fold_through=True)  # type: ignore[arg-type]
     def tree_map(
         func: Callable[..., Any],
         tree: PyTree,
@@ -578,7 +578,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
 
     __all__ += ["tree_map"]
 
-    @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
+    @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)  # type: ignore[arg-type]
     def tree_map_(
         func: Callable[..., Any],
         tree: PyTree,
@@ -600,7 +600,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
 
     __all__ += ["tree_map_"]
 
-    _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func  # type: ignore[union-attr]
+    _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func  # type: ignore[union-attr, attr-defined]
 
     @substitute_in_graph(  # type: ignore[arg-type]
         _none_unflatten,
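Every hunk in this file applies the same one-line fix: substitute_in_graph expects a callable it can type as Callable[P, R], and several optree entry points (overloaded or otherwise loosely typed) don't match that shape exactly, so mypy reports [arg-type] at each decoration site even though the substitution works at runtime. A self-contained sketch of the pattern, with hypothetical names standing in for the real decorator and targets:

from typing import Any, Callable, ParamSpec, TypeVar, overload

P = ParamSpec("P")
R = TypeVar("R")


def substitute(original: Callable[P, R]) -> Callable[[Callable[P, R]], Callable[P, R]]:
    # Stand-in for substitute_in_graph: accepts a replacement for `original`.
    def wrap(replacement: Callable[P, R]) -> Callable[P, R]:
        return replacement

    return wrap


@overload
def leaves(tree: list[int]) -> list[int]: ...
@overload
def leaves(tree: dict[str, int]) -> list[int]: ...
def leaves(tree: Any) -> list[int]:
    return []


# An overloaded target may not be expressible as a single Callable[P, R],
# so mypy can flag the decoration; the narrow ignore keeps the rest checked.
@substitute(leaves)  # type: ignore[arg-type]
def leaves_polyfill(tree: Any) -> list[int]:
    return []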

View File

@@ -434,12 +434,15 @@ class BlockStackEntry:
         else:
             return ReenterWith(self.stack_index - 1)
 
-    def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None:
+    def exit(
+        self, tx: InstructionTranslatorBase, is_graph_break: bool
+    ) -> VariableTracker | None:
         assert self.with_context is not None
         if (
             is_graph_break and self.with_context.exit_on_graph_break()
         ) or not is_graph_break:
             return self.with_context.exit(tx)  # type: ignore[arg-type]
+        return None
 
 
 class SpeculationLogDivergence(AssertionError):
@@ -3860,7 +3863,7 @@ class InstructionTranslatorBase(
         else:
             self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))
 
-        return ctx.enter(self)
+        return ctx.enter(self)  # type: ignore[arg-type]
 
     @staticmethod
     def unsupported_ctx_graph_break(ctx: VariableTracker) -> NoReturn:
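The widened return type on BlockStackEntry.exit makes the existing control flow explicit: either the wrapped context's exit result is propagated to the caller, or None is returned when a graph break should not trigger the exit. A simplified sketch of that contract (the names below are illustrative, not the real classes):

from typing import Optional


class _Ctx:
    # Stand-in for the with_context object held by a block stack entry.
    def exit_on_graph_break(self) -> bool:
        return True

    def exit(self) -> str:  # the real method would return a VariableTracker
        return "exited"


def block_exit(ctx: _Ctx, is_graph_break: bool) -> Optional[str]:
    # Equivalent condition to the diff: exit unless this is a graph break
    # that the context asked to skip.
    if not is_graph_break or ctx.exit_on_graph_break():
        return ctx.exit()
    return None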

File diff suppressed because it is too large

View File

@@ -116,11 +116,7 @@ class StreamContextVariable(FxTracebackAnnotateVariable):
             **kwargs,
         )
 
-    def __init__(
-        self,
-        stream: Optional["StreamVariable"],
-        **kwargs: dict[str, Any],
-    ) -> None:
+    def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None:
         self.stream = stream
         super().__init__(
             target_values={"stream": self.get_stream().user_object_index},
@@ -129,14 +125,16 @@ class StreamContextVariable(FxTracebackAnnotateVariable):
         )
 
     def enter(
-        self, tx: "InstructionTranslator", *args: tuple[Any]
-    ) -> "VariableTracker":
+        self, tx: "InstructionTranslator", *args: VariableTracker
+    ) -> VariableTracker:
         # to stream, from stream is the order of the arguments
         # we are entering the target, and leaving the initial stream
         tx.symbolic_stream_state.enter_stream(self.get_stream())
         return super().enter(tx)
 
-    def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
+    def exit(
+        self, tx: "InstructionTranslator", *args: VariableTracker
+    ) -> VariableTracker:
         # to stream, from stream is the order of the arguments
         # we are leaving the target, and entering the initial stream
         tx.symbolic_stream_state.exit_stream()
@@ -182,7 +180,7 @@ class StreamVariable(StreamContextVariable):
         name: str,
         args: list[VariableTracker],
         kwargs: dict[str, VariableTracker],
-    ) -> "VariableTracker":
+    ) -> VariableTracker:
         assert hasattr(self.value, name), f"no stream method found named {name}"
 
         from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
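The substantive change in this file is the *args annotation. With star parameters, the annotation names the type of each positional argument, so *args: VariableTracker makes args a tuple[VariableTracker, ...]; the old *args: tuple[Any] instead claimed every argument was itself a 1-tuple. A minimal illustration:

from typing import Any


def fixed(*args: int) -> int:
    # args is tuple[int, ...]; callers pass bare ints.
    return sum(args)


def misleading(*args: tuple[Any]) -> None:
    # args is tuple[tuple[Any], ...]; every argument must be a 1-tuple,
    # which is almost never the intent.
    pass


fixed(1, 2, 3)
misleading((1,), ("x",))  # only 1-tuples type-check here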

View File

@@ -408,6 +408,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             torch.cuda.amp.autocast,
             torch.cpu.amp.autocast,
         ):
+            # pyrefly: ignore [bad-argument-type]
             return AutocastModeVariable.create(self.value, args, kwargs)
         elif self.value in (
             # NOTE any class added here must align with the semantic
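The suppression added above uses Pyrefly's comment syntax, which here sits on the line before the flagged expression, unlike mypy's trailing # type: ignore[...]. A small sketch showing the two styles together (the function and values are hypothetical):

def create(value: str) -> None:
    pass


flag = 0  # an int where a str is expected
# pyrefly: ignore [bad-argument-type]
create(flag)  # type: ignore[arg-type]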

View File

@@ -163,7 +163,8 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
         if value is not None:
             super().__init__(value, **kwargs)
         self.value = value
-        self.cm_obj = value  # needed for BC with calling enter from CM code
+        # needed for BC with calling enter from CM code
+        self.cm_obj = value  # type: ignore[assignment]
         self.source = source  # type: ignore[assignment]
 
     def reconstruct(self, codegen: "PyCodegen") -> None:
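The moved comment and new ignore reflect a declared-type mismatch: cm_obj is presumably declared with a different type on the parent class, so assigning this value needs a targeted [assignment] suppression, and splitting the comment onto its own line keeps the ignore attached to the assignment itself. The same pattern in miniature (hypothetical classes):

class Base:
    cm_obj: int


class Derived(Base):
    def __init__(self, value: str) -> None:
        # keeping the attribute for backward compatibility forces an
        # assignment that mypy flags against the inherited declaration
        self.cm_obj = value  # type: ignore[assignment]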