Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Pyrefly suppressions 4/n (#164615)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1
After: 0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
Commit 4ab847bbc7 (parent 4bd1505f84), committed by PyTorch MergeBot
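Every suppression added in this PR follows the same pattern: a "# pyrefly: ignore # <error-code>" comment placed on the line directly above the line pyrefly reports. As a minimal, self-contained sketch of what step 3 produces, the snippet below reproduces the MyAutogradFunction test fixture that appears in the diff; only the explanatory comments are added here for illustration.

import torch


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, x):
        # Return a copy so the output does not alias the input.
        return x.clone()

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        # The +1 is part of the original test fixture, not a real gradient rule.
        return grad_output + 1

Step 3 also removes suppressions that no longer match a reported error, which is how the run ends at 0 errors (2,753 ignored).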
@@ -29,15 +29,11 @@ project-excludes = [
    "torch/fx/**",
    "torch/distributions/**",
    "torch/onnx/**",
    "torch/_refs/**",
    "torch/_export/**",
    "torch/jit/**",
    "torch/optim/**",
    "torch/_higher_order_ops/**",
    # formatting issues
    "torch/linalg/__init__.py",
    "torch/package/importer.py",
    "torch/package/_package_pickler.py",
    "torch/jit/annotations.py",
    # ====
    "benchmarks/instruction_counts/main.py",
    "benchmarks/instruction_counts/definitions/setup.py",
@@ -1097,6 +1097,7 @@ class TS2FXGraphConverter:
        # Update the value of loop local variables.
        if node.outputsSize() >= 1:
            # pyrefly: ignore # bad-assignment
            for i, outp in enumerate(node.outputs()):
                output_name = outp.debugName()
                self.name_to_node[output_name] = self.fx_graph.call_function(

@@ -1109,6 +1110,7 @@ class TS2FXGraphConverter:
                fx_block_args[i] = self.name_to_node[output_name]

        # Update the value of global variables, whose values are modified inplace.
        # pyrefly: ignore # bad-assignment
        for i, name in enumerate(
            subgraph_converter.name_update_from_subblock_to_parent
        ):
@@ -132,6 +132,7 @@ def _make_export_case(m, name, configs):
        m.__doc__ is not None
    ), f"Could not find description or docstring for export case: {m}"
    configs = {**configs, "description": m.__doc__}
    # pyrefly: ignore # bad-argument-type
    return ExportCase(**{**configs, "model": m, "name": name})
@@ -3,10 +3,12 @@ import torch


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        return grad_output + 1
@@ -39,6 +39,7 @@ def get_class_if_classified_error(e: Exception) -> Optional[str]:
        TorchRuntimeError: None,
    }
    if type(e) in _ALLOW_LIST:
        # pyrefly: ignore # index-error
        attr_name = _ALLOW_LIST[type(e)]
        if attr_name is None:
            return ALWAYS_CLASSIFIED
@@ -101,6 +101,7 @@ class _KeyPathTrie:
            assert len(kp) > 0
            k, *kp = kp  # type: ignore[assignment]
            node = node[k]
        # pyrefly: ignore # bad-return
        return node, kp

@@ -139,6 +140,7 @@ def key_path_to_source(
        source: Source = LocalSource("args")
    else:
        source, kp = sourced_prefixes.get(kp)
    # pyrefly: ignore # bad-assignment
    for k in kp:
        if isinstance(k, SequenceKey):
            source = GetItemSource(source, k.idx)

@@ -354,10 +356,12 @@ def _override_builtin_ops():
    original_min = builtins.min
    original_pow = math.pow

    # pyrefly: ignore # bad-assignment
    builtins.max = functools.partial(
        _tensor_min_max, real_callable=original_max, tensor_callable=torch.maximum
    )

    # pyrefly: ignore # bad-assignment
    builtins.min = functools.partial(
        _tensor_min_max, real_callable=original_min, tensor_callable=torch.minimum
    )

@@ -1083,6 +1087,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
        def run():
            # Run sequence.
            # pyrefly: ignore # index-error
            t = args[0]
            for _method, _args in sequence:
                t = _method(t, *_args)
@@ -188,6 +188,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        # pyrefly: ignore # bad-override
        def placeholder(
            self,
            target: str,  # type: ignore[override]

@@ -439,6 +440,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        # pyrefly: ignore # bad-assignment
        prev_interpreter, self.interpreter = (
            self.interpreter,
            torch.fx.Interpreter(  # type: ignore[assignment]
@@ -32,6 +32,7 @@ def _node_metadata_hook(
    that nodes being added are only call_function nodes, and copies over the
    first argument node's nn_module_stack.
    """
    # pyrefly: ignore # bad-assignment
    fake_mode = fake_mode or contextlib.nullcontext()

    assert node.op == "call_function" and callable(node.target), (

@@ -47,6 +48,7 @@ def _node_metadata_hook(
    fake_args, fake_kwargs = pytree.tree_map_only(
        torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
    )
    # pyrefly: ignore # bad-context-manager
    with fake_mode, enable_python_dispatcher():
        fake_res = node.target(*fake_args, **fake_kwargs)
        node.meta["val"] = fake_res

@@ -81,7 +83,9 @@ def _node_metadata_hook(
    node.meta["torch_fn"] = node.meta.get(
        "torch_fn",
        (
            # pyrefly: ignore # missing-attribute
            f"{node.target.__name__}_0",
            # pyrefly: ignore # missing-attribute
            f"{node.target.__class__.__name__}.{node.target.__name__}",
        ),
    )
@@ -567,6 +567,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
    quantized = False

    last_quantized_node = None
    # pyrefly: ignore # bad-assignment
    for node in gm.graph.nodes:
        if isinstance(node.target, OpOverload):
            with gm.graph.inserting_before(node):

@@ -629,6 +630,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
                    attr_names_to_clean.add(k)
                if k == "_buffers":
                    buffer_name_to_clean = set()
                    # pyrefly: ignore # missing-attribute
                    for b_name, b_value in v.items():
                        if isinstance(b_value, torch.Tensor) and b_value.dtype in [
                            torch.qint8,

@@ -636,6 +638,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
                        ]:
                            buffer_name_to_clean.add(b_name)
                    for b_name in buffer_name_to_clean:
                        # pyrefly: ignore # missing-attribute
                        v.pop(b_name, None)
            for attr_name in attr_names_to_clean:
                delattr(submod, attr_name)
@@ -35,6 +35,7 @@ def _replace_with_hop_helper(
        )
        call_func_node.meta["torch_fn"] = (
            f"{wrap_hoo.__name__}",
            # pyrefly: ignore # missing-attribute
            f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
        )
        if isinstance(output_args, (tuple, list)):
@@ -54,6 +54,7 @@ def _postprocess_serialized_shapes(
        )
        for k, v in sorted(dims.items())
    }
    # pyrefly: ignore # bad-argument-type
    spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
    if to_dict:
        return _dataclass_to_dict(spec)

@@ -183,6 +184,7 @@ def _dump_dynamic_shapes(
    kwargs = kwargs or {}
    if isinstance(dynamic_shapes, dict):
        dynamic_shapes = dynamic_shapes.values()  # type: ignore[assignment]
    # pyrefly: ignore # bad-assignment, bad-argument-type
    dynamic_shapes = tuple(dynamic_shapes)
    combined_args = tuple(args) + tuple(kwargs.values())
@@ -623,7 +623,9 @@ class _Commit:
def update_schema():
    import importlib.resources

    # pyrefly: ignore # bad-argument-type
    if importlib.resources.is_resource(__package__, "schema.yaml"):
        # pyrefly: ignore # bad-argument-type
        content = importlib.resources.read_text(__package__, "schema.yaml")
        match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
        _check(match is not None, "checksum not found in schema.yaml")

@@ -631,7 +633,9 @@ def update_schema():
        checksum_head = match.group(1)

        thrift_content = importlib.resources.read_text(
            __package__, "export_schema.thrift"
            # pyrefly: ignore # bad-argument-type
            __package__,
            "export_schema.thrift",
        )
        match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
        _check(match is not None, "checksum not found in export_schema.thrift")

@@ -654,7 +658,9 @@ def update_schema():
    src, cpp_header, thrift_schema = _staged_schema()
    additions, subtractions = _diff_schema(dst, src)
    # pyrefly: ignore # missing-attribute
    yaml_path = __package__.replace(".", "/") + "/schema.yaml"
    # pyrefly: ignore # missing-attribute
    thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift"
    torch_prefix = "torch/"
    assert yaml_path.startswith(torch_prefix)  # sanity check
@ -383,6 +383,7 @@ def _reconstruct_fake_tensor(
|
||||
fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
|
||||
if is_parameter:
|
||||
fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
|
||||
# pyrefly: ignore # bad-return
|
||||
return fake_tensor
|
||||
|
||||
|
||||
@ -2740,6 +2741,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
serialized_node.metadata
|
||||
)
|
||||
assert arg is not None
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
|
||||
fx_node.meta["val"] = tuple(meta_val)
|
||||
self.serialized_name_to_node[fx_node.name] = fx_node
|
||||
@ -3165,6 +3167,7 @@ def _dict_to_dataclass(cls, data):
|
||||
_value = next(iter(data.values()))
|
||||
assert isinstance(_type, str)
|
||||
field_type = cls.__annotations__[_type]
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
|
||||
elif dataclasses.is_dataclass(cls):
|
||||
fields = {}
|
||||
@ -3471,18 +3474,23 @@ def _canonicalize_graph(
|
||||
n.metadata.clear()
|
||||
|
||||
# Stage 4: Aggregate values.
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
sorted_tensor_values = dict(
|
||||
sorted(graph.tensor_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
sorted_sym_int_values = dict(
|
||||
sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
sorted_sym_float_values = dict(
|
||||
sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
sorted_sym_bool_values = dict(
|
||||
sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
sorted_custom_obj_values = dict(
|
||||
sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
@ -3539,6 +3547,7 @@ def canonicalize(
|
||||
ExportedProgram: The canonicalized exported program.
|
||||
"""
|
||||
ep = copy.deepcopy(ep)
|
||||
# pyrefly: ignore # annotation-mismatch
|
||||
constants: set[str] = constants or set()
|
||||
|
||||
opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))
|
||||
|
@ -34,7 +34,7 @@ from torch.fx._pytree import (
|
||||
_deregister_pytree_flatten_spec,
|
||||
register_pytree_flatten_spec,
|
||||
)
|
||||
from torch.utils._pytree import (
|
||||
from torch.utils._pytree import ( # pyrefly: ignore # deprecated
|
||||
_deregister_pytree_node,
|
||||
_register_pytree_node,
|
||||
Context,
|
||||
@ -470,7 +470,14 @@ def _check_input_constraints_for_graph(
|
||||
)
|
||||
elif isinstance(node_val, torch.SymInt):
|
||||
_check_symint(
|
||||
node_val, arg, range_constraints, unification_map, key_path, None
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
node_val,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
arg,
|
||||
range_constraints,
|
||||
unification_map,
|
||||
key_path,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@ -1115,12 +1122,14 @@ def placeholder_naming_pass(
|
||||
if ( # handle targets for custom objects
|
||||
spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map
|
||||
):
|
||||
# pyrefly: ignore # index-error
|
||||
spec.target = name_map[spec.target][4:] # strip obj_ prefix
|
||||
|
||||
for spec in export_graph_signature.output_specs:
|
||||
if spec.arg.name in name_map:
|
||||
spec.arg.name = name_map[spec.arg.name]
|
||||
if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map:
|
||||
# pyrefly: ignore # index-error
|
||||
spec.target = name_map[spec.target]
|
||||
|
||||
# rename keys in constants dict for custom objects
|
||||
|
@ -96,6 +96,7 @@ class AssociativeScanOp(HigherOrderOperator):
|
||||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(combine_fn, xs, additional_inputs)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def gen_schema(self, combine_fn, xs, additional_inputs):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
@ -648,6 +649,7 @@ class AssociativeScanAutogradOp(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx,
|
||||
combine_fn,
|
||||
|
@ -609,6 +609,7 @@ def do_auto_functionalize_v2(
|
||||
normalized_kwargs = {}
|
||||
|
||||
schema = op._schema
|
||||
# pyrefly: ignore # bad-assignment
|
||||
op = op._op if isinstance(op, HopInstance) else op
|
||||
assert isinstance(op, get_args(_MutableOpType))
|
||||
|
||||
|
@ -170,6 +170,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
||||
out = self(functionalized_subgraph, *unwrapped_operands, **kwargs)
|
||||
return ctx.wrap_tensors(out)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def gen_schema(self, subgraph, *operands, **kwargs):
|
||||
from .schema import HopSchemaGenerator
|
||||
|
||||
@ -214,6 +215,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
||||
|
||||
class BaseHOPFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(ctx, hop, subgraph, kwargs, *operands):
|
||||
ctx.hop = hop
|
||||
ctx.operands = operands
|
||||
|
@ -52,6 +52,7 @@ class CondOp(HigherOrderOperator):
|
||||
validate_subgraph_args_types(operands)
|
||||
return super().__call__(pred, true_fn, false_fn, operands)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def gen_schema(self, pred, true_fn, false_fn, operands):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
@ -284,6 +285,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
|
||||
|
||||
class CondAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx,
|
||||
pred,
|
||||
|
@ -298,4 +298,5 @@ def handle_effects(
|
||||
assert isinstance(wrapped_token, torch.Tensor)
|
||||
tokens[key] = wrapped_token
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return ctx.wrap_tensors(unwrapped_outs)
|
||||
|
@ -354,12 +354,18 @@ def trace_flex_attention(
|
||||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function", flex_attention, proxy_args, {}
|
||||
)
|
||||
return track_tensor_tree(
|
||||
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
example_out,
|
||||
out_proxy,
|
||||
constant=None,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
tracer=proxy_mode.tracer,
|
||||
)
|
||||
|
||||
|
||||
@ -621,6 +627,7 @@ def create_fw_bw_graph(
|
||||
|
||||
class FlexAttentionAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx: Any,
|
||||
query: Tensor,
|
||||
@ -1063,6 +1070,7 @@ def trace_flex_attention_backward(
|
||||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function",
|
||||
@ -1072,7 +1080,12 @@ def trace_flex_attention_backward(
|
||||
name="flex_attention_backward",
|
||||
)
|
||||
return track_tensor_tree(
|
||||
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
example_out,
|
||||
out_proxy,
|
||||
constant=None,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
tracer=proxy_mode.tracer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -86,6 +86,7 @@ class InvokeSubgraphHOP(HigherOrderOperator):
|
||||
|
||||
return super().__call__(subgraph, identifier, *operands)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def gen_schema(self, subgraph, identifier, *operands):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import (
|
||||
@ -401,6 +402,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx,
|
||||
subgraph,
|
||||
@ -477,6 +479,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
||||
for tangent in filtered_grad_outs:
|
||||
metadata = extract_tensor_metadata(tangent)
|
||||
metadata._flatten_into(tangent_metadata, fake_mode, state)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
tangent_metadata = tuple(tangent_metadata)
|
||||
|
||||
# bw_graph is a joint graph with signature (*primals_and_tangents) and
|
||||
|
@ -197,6 +197,7 @@ def create_hop_fw_bw(
|
||||
|
||||
class LocalMapAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx: Any,
|
||||
fw_gm: GraphModule,
|
||||
@ -243,6 +244,7 @@ class LocalMapAutogradOp(torch.autograd.Function):
|
||||
)
|
||||
|
||||
for i, meta in ctx.expected_tangent_metadata.items():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
grads[i] = coerce_to_expected_memory_format(grads[i], meta)
|
||||
|
||||
grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)
|
||||
|
@ -125,6 +125,7 @@ def map(
|
||||
|
||||
class MapAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(ctx, f, num_mapped_args, *flat_args):
|
||||
ctx._f = f
|
||||
ctx._num_mapped_args = num_mapped_args
|
||||
|
@ -241,6 +241,7 @@ class ScanOp(HigherOrderOperator):
|
||||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(combine_fn, init, xs, additional_inputs)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def gen_schema(self, combine_fn, init, xs, additional_inputs):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
@ -448,6 +449,7 @@ class ScanAutogradOp(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx,
|
||||
hop_partitioned_graph,
|
||||
|
@ -292,6 +292,7 @@ def generate_ttir(
|
||||
ordered_args[name] = 2
|
||||
elif (
|
||||
stable_meta := maybe_unpack_tma_stable_metadata(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
tma_descriptor_metadata.get(name, None)
|
||||
)
|
||||
) is not None:
|
||||
@ -425,6 +426,7 @@ def generate_ttir(
|
||||
specialize_value=not kp.do_not_specialize,
|
||||
align=not kp.do_not_specialize_on_alignment,
|
||||
)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
attrvals.append(spec[1])
|
||||
|
||||
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
|
||||
@ -443,6 +445,7 @@ def generate_ttir(
|
||||
def get_signature_value(idx: int, arg: Any) -> str:
|
||||
if kernel.params[idx].is_constexpr:
|
||||
return "constexpr"
|
||||
# pyrefly: ignore # not-callable
|
||||
return mangle_type(arg)
|
||||
|
||||
else:
|
||||
@ -815,6 +818,7 @@ def get_tma_stores(
|
||||
for op in op_list:
|
||||
if op.name == "tt.call":
|
||||
assert op.fn_call_name in functions
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
tma_stores = get_tma_stores(functions, op.fn_call_name)
|
||||
for i, inp in enumerate(op.args):
|
||||
if Param(idx=i) in tma_stores:
|
||||
@ -895,7 +899,11 @@ def analyze_kernel_mutations(
|
||||
if op.name == "tt.call":
|
||||
assert op.fn_call_name in functions
|
||||
mutations = analyze_kernel_mutations(
|
||||
functions, op.fn_call_name, len(op.args)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
functions,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
op.fn_call_name,
|
||||
len(op.args),
|
||||
)
|
||||
stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
|
||||
else:
|
||||
@ -948,6 +956,7 @@ def identify_mutated_tensors(
|
||||
assert functions is not None
|
||||
kernel_name = next(iter(functions.keys()))
|
||||
# Triton codegen modifies the name
|
||||
# pyrefly: ignore # missing-attribute
|
||||
assert kernel.fn.__name__ in kernel_name
|
||||
# Reset the cache between top level invocations
|
||||
# The cache for analyze kernel mutations is mainly used for cycle
|
||||
@ -1051,7 +1060,11 @@ def triton_kernel_wrapper_mutation_dense(
|
||||
grid_fn = grid[0]
|
||||
else:
|
||||
fn_name, code = user_defined_kernel_grid_fn_code(
|
||||
kernel.fn.__name__, kernel.configs, grid
|
||||
# pyrefly: ignore # missing-attribute
|
||||
kernel.fn.__name__,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
kernel.configs,
|
||||
grid,
|
||||
)
|
||||
namespace: dict[str, Any] = {}
|
||||
exec(code, namespace)
|
||||
@ -1100,6 +1113,7 @@ def triton_kernel_wrapper_mutation_dense(
|
||||
# avoid mutating the original inputs
|
||||
kwargs = kwargs.copy()
|
||||
constant_args = constant_args.copy()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for name in kernel.arg_names:
|
||||
if name in kwargs:
|
||||
args.append(kwargs.pop(name))
|
||||
@ -1108,6 +1122,7 @@ def triton_kernel_wrapper_mutation_dense(
|
||||
else:
|
||||
break
|
||||
|
||||
# pyrefly: ignore # index-error
|
||||
kernel[grid_fn](*args, **kwargs, **constant_args)
|
||||
|
||||
|
||||
@ -1513,6 +1528,7 @@ class TritonHOPifier:
|
||||
|
||||
assert kernel_idx is None or variable.kernel_idx == kernel_idx
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
variable.grid = grid
|
||||
|
||||
if isinstance(kernel, Autotuner):
|
||||
@ -2057,6 +2073,7 @@ class TraceableTritonKernelWrapper:
|
||||
return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
|
||||
else:
|
||||
assert self.kernel is not None
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return self.kernel.run(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
|
||||
@ -2068,6 +2085,7 @@ class TraceableTritonKernelWrapper:
|
||||
)
|
||||
else:
|
||||
assert self.kernel is not None
|
||||
# pyrefly: ignore # index-error
|
||||
return self.kernel[self.grid](*args, **kwargs)
|
||||
|
||||
def specialize_symbolic(self, arg: Sequence[Any]) -> Any:
|
||||
|
@ -270,6 +270,7 @@ def _set_compilation_env():
|
||||
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
|
||||
# once we are confident fx tracing works with dynamo.
|
||||
torch.fx._symbolic_trace._is_fx_tracing_flag = False
|
||||
# pyrefly: ignore # bad-assignment
|
||||
torch._dynamo.config.allow_empty_graphs = True
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
yield
|
||||
@ -440,6 +441,7 @@ def unique_graph_name_with_root(
|
||||
) -> tuple[int, str]:
|
||||
next_name = None
|
||||
i = 0
|
||||
# pyrefly: ignore # bad-assignment
|
||||
while not next_name:
|
||||
candidate = f"{prefix}_{i}"
|
||||
if hasattr(root, candidate):
|
||||
@ -795,6 +797,8 @@ def create_bw_fn(
|
||||
"""
|
||||
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
||||
|
||||
# pyrefly: ignore # missing-module-attribute
|
||||
from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
|
||||
|
||||
dummy_aot_config = AOTConfig(
|
||||
@ -939,6 +943,7 @@ def check_input_alias_and_mutation(
|
||||
out_out_alias_map,
|
||||
mutated_inputs,
|
||||
) = check_input_alias_and_mutation_return_outputs(gm)[:-1]
|
||||
# pyrefly: ignore # bad-return
|
||||
return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
|
||||
|
||||
|
||||
|
@ -54,6 +54,7 @@ class WhileLoopOp(HigherOrderOperator):
|
||||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def gen_schema(self, cond_fn, body_fn, carried_inputs, additional_inputs):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
@ -430,6 +431,7 @@ def while_loop_tracing(
|
||||
elif isinstance(x, torch.Tensor):
|
||||
x = x.clone()
|
||||
if hasattr(x, "constant") and x.constant is not None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
x.constant = None
|
||||
return x
|
||||
|
||||
@ -452,6 +454,7 @@ def while_loop_tracing(
|
||||
|
||||
next_name = None
|
||||
i = 0
|
||||
# pyrefly: ignore # bad-assignment
|
||||
while not next_name:
|
||||
candidate = f"while_loop_cond_graph_{i}"
|
||||
if hasattr(proxy_mode.tracer.root, candidate):
|
||||
@ -696,6 +699,7 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
|
||||
|
||||
class WhileLoopAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def forward(
|
||||
ctx,
|
||||
cond_fn,
|
||||
@ -725,6 +729,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
|
||||
ctx.additional_inputs = additional_inputs
|
||||
ctx.fw_outputs = fw_outputs
|
||||
loop_count = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for out in fw_outputs:
|
||||
if isinstance(out, torch.Tensor):
|
||||
if loop_count is not None:
|
||||
@ -878,6 +883,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
|
||||
while_loop_op(
|
||||
cond_gm,
|
||||
body_gm,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
(
|
||||
init_idx,
|
||||
*init_grad_carries,
|
||||
|
@ -880,10 +880,14 @@ def logsumexp(
|
||||
if not isinstance(dim, Iterable):
|
||||
dim = (dim,)
|
||||
if self.numel() == 0:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch.sum(torch.exp(self), dim, keepdim).log()
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
maxes = torch.amax(torch.real(self), dim, keepdim=True)
|
||||
maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
result = torch.sum(torch.exp(self - maxes), dim, keepdim)
|
||||
return result.log().add(maxes_squeezed)
|
||||
|
||||
@ -1241,10 +1245,12 @@ def copysign(
|
||||
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
|
||||
):
|
||||
if isinstance(b, Number) and isinstance(a, Tensor):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
|
||||
elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
|
||||
msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
|
||||
raise RuntimeError(msg)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return where(signbit(b), neg(abs(a)), abs(a))
|
||||
|
||||
|
||||
@ -1330,10 +1336,13 @@ def float_power(
|
||||
|
||||
# Float power has the following contiguous cast behavior to be
|
||||
# consistent with its C++ impl
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
a = _maybe_convert_to_dtype(a, dtype)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
b = _maybe_convert_to_dtype(b, dtype)
|
||||
|
||||
a, b = _maybe_broadcast(a, b)
|
||||
# pyrefly: ignore # bad-return
|
||||
return pow(a, b)
|
||||
|
||||
|
||||
@ -1375,11 +1384,15 @@ def floor_divide(
|
||||
):
|
||||
# Wrap scalars because some references only accept tensor arguments.
|
||||
if isinstance(a, Number) and isinstance(b, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
a = scalar_tensor(a)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
b = scalar_tensor(b)
|
||||
elif isinstance(b, Number) and isinstance(a, Tensor):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
|
||||
elif isinstance(a, Number) and isinstance(b, Tensor):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
|
||||
elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
|
||||
if a.device == torch.device("cpu"):
|
||||
@ -1856,8 +1869,10 @@ def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberT
|
||||
|
||||
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
|
||||
if isinstance(b, TensorLike) and isinstance(a, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
|
||||
elif isinstance(a, TensorLike) and isinstance(b, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
|
||||
|
||||
# mypy: expected "Tensor"
|
||||
@ -2333,6 +2348,7 @@ def all(
|
||||
dim: Optional[DimsType] = None,
|
||||
keepdim: bool = False,
|
||||
) -> TensorLikeType:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
|
||||
|
||||
if a.dtype == torch.uint8:
|
||||
@ -2850,6 +2866,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
|
||||
# SymInts
|
||||
|
||||
example = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for i, t in enumerate(tensors):
|
||||
if example is None:
|
||||
if t.ndim != 1:
|
||||
@ -3228,6 +3245,7 @@ def _normalize(
|
||||
mean (Tensor): mean of the tensor along norm_dims.
|
||||
rstd (Tensor): 1/std of the tensor along norm_dims.
|
||||
"""
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
|
||||
computation_dtype = utils.get_computation_dtype(a.dtype)
|
||||
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
|
||||
@ -3341,6 +3359,7 @@ def native_layer_norm(
|
||||
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
|
||||
# therefore we use tuple(normalized_shape)
|
||||
torch._check(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
weight is None or sym_eq(weight.shape, tuple(normalized_shape)),
|
||||
lambda: "Expected weight to be of same shape as normalized_shape, but got "
|
||||
+ "weight of shape "
|
||||
@ -3349,6 +3368,7 @@ def native_layer_norm(
|
||||
+ str(normalized_shape),
|
||||
)
|
||||
torch._check(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
bias is None or sym_eq(bias.shape, tuple(normalized_shape)),
|
||||
lambda: "Expected bias to be of same shape as normalized_shape, but got "
|
||||
+ "bias of shape "
|
||||
@ -3359,7 +3379,10 @@ def native_layer_norm(
|
||||
torch._check(
|
||||
input.ndim >= normalized_ndim
|
||||
and sym_eq(
|
||||
input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
input.shape[(input.ndim - normalized_ndim) :],
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
tuple(normalized_shape),
|
||||
),
|
||||
lambda: "Given normalized_shape="
|
||||
+ str(normalized_shape)
|
||||
@ -3953,6 +3976,7 @@ def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
|
||||
@out_wrapper()
|
||||
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
|
||||
"""Reference implementation of :func:`torch.roll`."""
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
dims = utils.canonicalize_dims(a.ndim, dims)
|
||||
# ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
|
||||
if not isinstance(shifts, Iterable):
|
||||
@ -3965,12 +3989,16 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
|
||||
# Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
|
||||
return a.clone()
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
if a.dim() == 0 and len(dims) > 0:
|
||||
raise IndexError(
|
||||
# pyrefly: ignore # index-error
|
||||
f"Dimension specified as {dims[0]} but tensor has no dimensions"
|
||||
)
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
len_shifts = len(shifts)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
len_dims = len(dims)
|
||||
if len_shifts != 1 or len_dims != 1:
|
||||
if len_shifts == 0:
|
||||
@ -3978,21 +4006,27 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
|
||||
# Takes care of the case when dims is not specified (default)
|
||||
# By default, the tensor is flattened before shifting, after which the original shape is restored
|
||||
if len_dims == 0 and len_shifts == 1:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
|
||||
if len_shifts != len_dims:
|
||||
raise RuntimeError(
|
||||
f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
|
||||
)
|
||||
assert len_dims > 1
|
||||
# pyrefly: ignore # index-error
|
||||
tail_shifts = shifts[1:]
|
||||
# pyrefly: ignore # index-error
|
||||
tail_dims = dims[1:]
|
||||
# pyrefly: ignore # index-error
|
||||
first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
|
||||
return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
|
||||
|
||||
# This path is taken when only one dimension is rolled
|
||||
# For example to get `first_dim_rolled` above
|
||||
# pyrefly: ignore # index-error
|
||||
dim = dims[0]
|
||||
size = a.shape[dim]
|
||||
# pyrefly: ignore # index-error
|
||||
start = (size - shifts[0]) % size
|
||||
idx = torch.arange(size, device=a.device)
|
||||
return a.index_select(dim, torch.fmod(start + idx, size))
|
||||
@ -4074,7 +4108,9 @@ def softmax(
|
||||
a_max = amax(a_, dim, keepdim=True)
|
||||
a_exp = exp(a_ - a_max)
|
||||
return _maybe_convert_to_dtype(
|
||||
true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
true_divide(a_exp, sum(a_exp, dim, keepdim=True)),
|
||||
result_dtype,
|
||||
) # type: ignore[return-value]
|
||||
|
||||
|
||||
@ -4251,6 +4287,7 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType
|
||||
return prims.squeeze(a, dims) if dims else prims.view_of(a)
|
||||
|
||||
ndim = a.ndim
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
dim = utils.canonicalize_dims(ndim, dim)
|
||||
dims = (dim,) if isinstance(dim, Dim) else dim
|
||||
# Short-circuits if the tensor has no dimensions
|
||||
@ -4391,6 +4428,7 @@ def hsplit(
|
||||
if isinstance(indices_or_sections, IntLike):
|
||||
split_size = indices_or_sections
|
||||
torch._check(
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
(split_size != 0 and a.shape[dim] % split_size == 0),
|
||||
lambda: (
|
||||
"torch.hsplit attempted to split along dimension "
|
||||
@ -4402,6 +4440,7 @@ def hsplit(
|
||||
+ "!"
|
||||
),
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return tensor_split(a, split_size, dim)
|
||||
|
||||
torch._check_type(
|
||||
@ -4432,6 +4471,7 @@ def vsplit(
|
||||
if isinstance(indices_or_sections, IntLike):
|
||||
split_size = indices_or_sections
|
||||
torch._check(
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
(split_size != 0 and a.shape[0] % split_size == 0),
|
||||
lambda: (
|
||||
f"torch.vsplit attempted to split along dimension 0"
|
||||
@ -4442,6 +4482,7 @@ def vsplit(
|
||||
f"!"
|
||||
),
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return tensor_split(a, split_size, 0)
|
||||
|
||||
torch._check_type(
|
||||
@ -4646,6 +4687,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
|
||||
raise RuntimeError(
|
||||
f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
|
||||
)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
|
||||
raise RuntimeError(
|
||||
"torch.dsplit attempted to split along dimension 2, "
|
||||
@ -5419,6 +5461,7 @@ def logspace(
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # inconsistent-overload
|
||||
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
|
||||
pass
|
||||
|
||||
@ -5845,6 +5888,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
|
||||
|
||||
# Since `where` allows type-promotion,
|
||||
# cast value to correct type before passing to `where`
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
value = _maybe_convert_to_dtype(value, a.dtype)
|
||||
r = torch.where(mask, value, a) # type: ignore[arg-type]
|
||||
|
||||
@ -6639,6 +6683,7 @@ def _infer_scalar_type(obj):
|
||||
# double.
|
||||
if length == 0:
|
||||
return torch.get_default_dtype()
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for i in range(length):
|
||||
cur_item = obj[i]
|
||||
# TODO: test this
|
||||
@ -6676,6 +6721,7 @@ def _recursive_build(
|
||||
# torch.Size([1, 2])
|
||||
return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
|
||||
elif isinstance(obj, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return torch.scalar_tensor(obj, dtype=scalarType)
|
||||
|
||||
# seq can be a list of tensors
|
||||
|
@ -106,11 +106,13 @@ def _resize_fft_input(
|
||||
if x_sizes[dims[i]] < sizes[i]:
|
||||
must_copy = True
|
||||
pad_idx = len(pad_amount) - 2 * dims[i] - 1
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
|
||||
|
||||
if x_sizes[dims[i]] > sizes[i]:
|
||||
x = x.narrow(dims[i], 0, sizes[i])
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return torch.constant_pad_nd(x, pad_amount) if must_copy else x
|
||||
|
||||
|
||||
|
@ -216,6 +216,7 @@ def matrix_norm(
|
||||
# shape
|
||||
check_is_matrix(A, "linalg.matrix_norm")
|
||||
# dim
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
dim = utils.canonicalize_dims(A.ndim, dim)
|
||||
if isinstance(dim, Dim):
|
||||
dim = (dim,) # type: ignore[assignment]
|
||||
@ -223,7 +224,9 @@ def matrix_norm(
|
||||
len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
|
||||
)
|
||||
torch._check(
|
||||
# pyrefly: ignore # index-error
|
||||
dim[0] != dim[1],
|
||||
# pyrefly: ignore # index-error
|
||||
lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
|
||||
)
|
||||
# dtype arg
|
||||
@ -245,6 +248,7 @@ def matrix_norm(
|
||||
else: # ord == "nuc"
|
||||
if dtype is not None:
|
||||
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
|
||||
# pyrefly: ignore # index-error
|
||||
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
|
||||
result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
|
||||
if keepdim:
|
||||
@ -268,6 +272,7 @@ def matrix_norm(
|
||||
if abs_ord == 2.0:
|
||||
if dtype is not None:
|
||||
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
|
||||
# pyrefly: ignore # index-error
|
||||
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
|
||||
result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
|
||||
if keepdim:
|
||||
@ -275,6 +280,7 @@ def matrix_norm(
|
||||
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
|
||||
return result
|
||||
else: # 1, -1, inf, -inf
|
||||
# pyrefly: ignore # bad-unpacking
|
||||
dim0, dim1 = dim
|
||||
if abs_ord == float("inf"):
|
||||
dim0, dim1 = dim1, dim0
|
||||
|
@ -142,9 +142,11 @@ def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
# nb. We use the name of the first argument used in the unary references
|
||||
@wraps(fn)
|
||||
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
a = args[0]
|
||||
if "inplace" not in kwargs:
|
||||
kwargs["inplace"] = False
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
if kwargs["inplace"]:
|
||||
torch._check(
|
||||
"out" not in kwargs,
|
||||
@ -625,6 +627,7 @@ def smooth_l1_loss(
|
||||
)
|
||||
else:
|
||||
loss = torch.abs(input - target)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
|
||||
return _apply_loss_reduction(loss, reduction)
|
||||
|
||||
|
@ -155,8 +155,10 @@ def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, Numbe
|
||||
|
||||
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
|
||||
if isinstance(a, TensorLike) and isinstance(b, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
|
||||
elif isinstance(b, TensorLike) and isinstance(a, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)
|
||||
|
||||
# mypy: expected "Tensor"
|
||||
|
@ -130,6 +130,7 @@ def var_decomposition(
|
||||
else:
|
||||
raise RuntimeError("correction must be int or float")
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return sum / max(0, denom)
|
||||
|
||||
|
||||
|
@ -14,6 +14,8 @@ import torch
|
||||
_IS_MONKEYTYPE_INSTALLED = True
|
||||
try:
|
||||
import monkeytype # type: ignore[import]
|
||||
|
||||
# pyrefly: ignore # import-error
|
||||
from monkeytype import trace as monkeytype_trace
|
||||
from monkeytype.config import _startswith, LIB_PATHS # type: ignore[import]
|
||||
from monkeytype.db.base import ( # type: ignore[import]
|
||||
@ -87,6 +89,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
super().__init__(store)
|
||||
|
||||
def log(self, trace: CallTrace) -> None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.traces.append(trace)
|
||||
|
||||
class JitTypeTraceStore(CallTraceStore):
|
||||
@ -148,6 +151,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
|
||||
def trace_logger(self) -> JitTypeTraceStoreLogger:
|
||||
"""Return a JitCallTraceStoreLogger that logs to the configured trace store."""
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
return JitTypeTraceStoreLogger(self.trace_store())
|
||||
|
||||
def trace_store(self) -> CallTraceStore:
|
||||
|
@ -747,6 +747,7 @@ def get_overload_annotations(mod, jit_ignored_properties):
|
||||
if method_overloads is None:
|
||||
continue
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if item.__func__ in method_overloads:
|
||||
raise RuntimeError(
|
||||
_jit_internal.get_overload_no_implementation_error_message(
|
||||
|
@ -545,6 +545,7 @@ if _enabled:
|
||||
#
|
||||
# This ensures that if we use the attr again in `__init__`, it
|
||||
# will look like the actual value, not an instance of Attribute.
|
||||
# pyrefly: ignore # invalid-argument
|
||||
if isinstance(value, Attribute):
|
||||
# NB: Ensure that we set __annotations__ on the specific
|
||||
# class in question, and not on a superclass (which would
|
||||
@ -656,6 +657,7 @@ if _enabled:
|
||||
|
||||
# Finalize the ScriptModule: replace the nn.Module state with our
|
||||
# custom implementations and flip the _initializing bit.
|
||||
# pyrefly: ignore # missing-attribute
|
||||
RecursiveScriptModule._finalize_scriptmodule(script_module)
|
||||
return script_module
|
||||
|
||||
@ -929,6 +931,7 @@ if _enabled:
|
||||
# Don't do anything here, we'll initialize the ScriptModule below
|
||||
return
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return RecursiveScriptModule._construct(
|
||||
self._c._replicate_for_data_parallel(), init_fn
|
||||
)
|
||||
@ -938,6 +941,7 @@ if _enabled:
|
||||
# This is because `super().foo()` does not use
|
||||
# `__getattr__` to look up `foo`. So we need to make each method available on
|
||||
# the ScriptModule manually.
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for name, item in RecursiveScriptModule.__dict__.items():
|
||||
if not callable(item) and not isinstance(item, property):
|
||||
continue
|
||||
@ -1006,6 +1010,7 @@ if _enabled:
|
||||
if name.startswith("__") or name.endswith("_call_impl"):
|
||||
continue
|
||||
if (
|
||||
# pyrefly: ignore # missing-attribute
|
||||
name not in RecursiveScriptModule.__dict__
|
||||
and name not in _compiled_methods_allowlist
|
||||
):
|
||||
@ -1038,6 +1043,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
|
||||
return memo[id(obj)]
|
||||
|
||||
obj = (
|
||||
# pyrefly: ignore # not-callable
|
||||
obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj
|
||||
) # type: ignore[operator]
|
||||
# Record obj in memo to avoid infinite recursion in the case of cycles in the module
|
||||
@ -1135,6 +1141,7 @@ def _script_impl(
|
||||
# the provide example inputs. This logs all the traces in type_trace_db
|
||||
type_trace_db = JitTypeTraceStore()
|
||||
if monkeytype_trace:
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
monkeytype_config = JitTypeTraceConfig(type_trace_db)
|
||||
with monkeytype_trace(monkeytype_config):
|
||||
if isinstance(example_inputs, dict):
|
||||
|
@ -166,11 +166,25 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
|
||||
cu = torch._C.CompilationUnit()
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
cpp_module = torch._C.import_ir_module(
|
||||
cu, os.fspath(f), map_location, _extra_files, _restore_shapes
|
||||
# pyrefly: ignore # no-matching-overload, bad-argument-count
|
||||
cu,
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
os.fspath(f),
|
||||
map_location,
|
||||
_extra_files,
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
_restore_shapes,
|
||||
) # type: ignore[call-arg]
|
||||
else:
|
||||
cpp_module = torch._C.import_ir_module_from_buffer(
|
||||
cu, f.read(), map_location, _extra_files, _restore_shapes
|
||||
# pyrefly: ignore # missing-attribute, bad-argument-count
|
||||
cu,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
f.read(),
|
||||
map_location,
|
||||
_extra_files,
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
_restore_shapes,
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
# TODO: Pretty sure this approach loses ConstSequential status and such
|
||||
@ -196,6 +210,7 @@ def validate_map_location(map_location=None):
|
||||
|
||||
def jit_module_from_flatbuffer(f):
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
f = os.fspath(f)
|
||||
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
|
||||
else:
|
||||
@ -245,6 +260,7 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
|
||||
extra_files = {}
|
||||
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
f = os.fspath(f)
|
||||
torch._C._save_jit_module(m._c, f, extra_files)
|
||||
else:
|
||||
|
@ -561,6 +561,7 @@ def cat(tensors: list[list[int]], dim: int):
|
||||
for i in range(len(tensors)):
|
||||
tensor = tensors[i]
|
||||
if not should_skip(tensor):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
|
||||
cat_dim_size = cat_dim_size + tensor[dim]
|
||||
|
||||
|
@ -169,6 +169,7 @@ def _clone_inputs(args):
|
||||
else:
|
||||
return a.clone(memory_format=torch.preserve_format)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return function._nested_map(
|
||||
lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
|
||||
)(args)
|
||||
@ -335,6 +336,7 @@ def _check_trace(
|
||||
|
||||
if is_trace_module:
|
||||
copied_dict = {}
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for name, data in inputs.items():
|
||||
copied_dict[name] = _clone_inputs(data)
|
||||
check_mod = torch.jit.trace_module(
|
||||
@ -739,6 +741,7 @@ def _trace_impl(
|
||||
example_inputs = (example_inputs,)
|
||||
# done primarily so that weird iterables fail here and not pybind11 code
|
||||
elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
example_inputs = tuple(example_inputs)
|
||||
|
||||
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
|
||||
@ -765,6 +768,7 @@ def _trace_impl(
|
||||
traced = torch._C._create_function_from_trace(
|
||||
name,
|
||||
func,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
example_inputs,
|
||||
var_lookup_fn,
|
||||
strict,
|
||||
|
@ -98,6 +98,7 @@ class EvalEnv:
|
||||
def __init__(self, rcb):
|
||||
self.rcb = rcb
|
||||
if torch.distributed.rpc.is_available():
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
self.env["RRef"] = RRef
|
||||
|
||||
def __getitem__(self, name):
|
||||
|
@ -115,6 +115,7 @@ node_start_tokens = {
|
||||
ast.Continue: "continue",
|
||||
}
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
pretty_node_names.update(
|
||||
{
|
||||
ast.AsyncFunctionDef: "async function definitions",
|
||||
@ -125,6 +126,7 @@ pretty_node_names.update(
|
||||
}
|
||||
)
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
node_start_tokens.update(
|
||||
{
|
||||
ast.AsyncFunctionDef: "async def",
|
||||
@ -135,6 +137,7 @@ node_start_tokens.update(
|
||||
}
|
||||
)
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
pretty_node_names.update(
|
||||
{
|
||||
ast.AnnAssign: "annotated assignments",
|
||||
@ -859,6 +862,7 @@ class ExprBuilder(Builder):
|
||||
ast.RShift: ">>",
|
||||
}
|
||||
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
binop_map[ast.MatMult] = "@"
|
||||
|
||||
unop_map = {
|
||||
@ -1220,6 +1224,7 @@ class ExprBuilder(Builder):
|
||||
s += "{}"
|
||||
args.append(build_expr(ctx, value.value))
|
||||
elif isinstance(value, ast.Constant):
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
s += value.value
|
||||
else:
|
||||
raise NotSupportedError(r, "Unsupported value in JoinedStr")
|
||||
|
@ -44,10 +44,13 @@ def _load_for_lite_interpreter(f, map_location=None):
|
||||
map_location = validate_map_location(map_location)
|
||||
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
|
||||
else:
|
||||
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
|
||||
f.read(), map_location
|
||||
# pyrefly: ignore # missing-attribute
|
||||
f.read(),
|
||||
map_location,
|
||||
)
|
||||
|
||||
return LiteScriptModule(cpp_module)
|
||||
@ -103,8 +106,10 @@ def _get_model_bytecode_version(f_input) -> int:
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch._C._get_model_bytecode_version(os.fspath(f_input))
|
||||
else:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
|
||||
|
||||
|
||||
@ -135,8 +140,10 @@ def _get_mobile_model_contained_types(f_input) -> int:
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
|
||||
else:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
|
||||
|
||||
|
||||
@ -161,11 +168,18 @@ def _backport_for_mobile(f_input, f_output, to_version):
|
||||
isinstance(f_output, (str, os.PathLike))
|
||||
):
|
||||
return torch._C._backport_for_mobile(
|
||||
os.fspath(f_input), os.fspath(f_output), to_version
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
os.fspath(f_input),
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
os.fspath(f_output),
|
||||
to_version,
|
||||
)
|
||||
else:
|
||||
return torch._C._backport_for_mobile_from_buffer(
|
||||
f_input.read(), str(f_output), to_version
|
||||
# pyrefly: ignore # missing-attribute
|
||||
f_input.read(),
|
||||
str(f_output),
|
||||
to_version,
|
||||
)
|
||||
|
||||
|
||||
@ -184,10 +198,13 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
|
||||
else:
|
||||
return torch._C._backport_for_mobile_from_buffer_to_buffer(
|
||||
f_input.read(), to_version
|
||||
# pyrefly: ignore # missing-attribute
|
||||
f_input.read(),
|
||||
to_version,
|
||||
)
|
||||
|
||||
|
||||
@ -227,6 +244,8 @@ def _get_model_ops_and_info(f_input):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch._C._get_model_ops_and_info(os.fspath(f_input))
|
||||
else:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return torch._C._get_model_ops_and_info(f_input.read())
|
||||
|
@ -261,6 +261,7 @@ def _get_global_builtins():
|
||||
|
||||
magic_methods_rows = []
|
||||
for fn, magic_method in magic_methods:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')
|
||||
|
||||
schematized_ops = []
|
||||
@ -279,6 +280,7 @@ def _get_global_builtins():
|
||||
table_row = (
|
||||
f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
schemaless_ops.append(table_row)
|
||||
|
||||
schematized_ops_str = "\n".join(schematized_ops)
|
||||
|
@ -78,6 +78,7 @@ def _adjust_lr(
|
||||
A, B = param_shape[:2]
|
||||
|
||||
if adjust_lr_fn is None or adjust_lr_fn == "original":
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
adjusted_ratio = math.sqrt(max(1, A / B))
|
||||
elif adjust_lr_fn == "match_rms_adamw":
|
||||
adjusted_ratio = 0.2 * math.sqrt(max(A, B))
|
||||
|
@ -415,6 +415,7 @@ def _single_tensor_adam(
|
||||
if weight_decay.requires_grad:
|
||||
grad = grad.addcmul_(param.clone(), weight_decay)
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
@ -444,6 +445,7 @@ def _single_tensor_adam(
|
||||
device_beta1 = beta1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
exp_avg.lerp_(grad, 1 - device_beta1)
|
||||
|
||||
# Nested if is necessary to bypass jitscript rules
|
||||
@ -692,6 +694,7 @@ def _multi_tensor_adam(
|
||||
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
|
||||
)
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
|
||||
# Due to the strictness of the _foreach_addcmul API, we can't have a single
|
||||
|
@ -263,6 +263,7 @@ def _single_tensor_asgd(
|
||||
ax.copy_(param)
|
||||
|
||||
if capturable:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
|
||||
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
|
||||
else:
|
||||
|
@@ -113,9 +113,11 @@ def _strong_wolfe(

        # compute new trial value
        t = _cubic_interpolate(
            # pyrefly: ignore # index-error
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            # pyrefly: ignore # index-error
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],

@@ -151,6 +153,7 @@ def _strong_wolfe(

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            # pyrefly: ignore # unsupported-operation
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]

@@ -160,14 +163,17 @@ def _strong_wolfe(
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            # pyrefly: ignore # index-error
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                # pyrefly: ignore # unsupported-operation
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            # pyrefly: ignore # unsupported-operation
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]

@@ -252,6 +258,7 @@ class LBFGS(Optimizer):

    def _numel(self):
        if self._numel_cache is None:
            # pyrefly: ignore # bad-assignment
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params

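In `_strong_wolfe`, the suppressed lines index and mutate `bracket`, which pyrefly presumably cannot prove to be a mutable two-element list on every path; note that a suppression can also sit directly above an `elif`. A small self-contained sketch of the `unsupported-operation` case follows; `tighten` is an invented helper, and whether a given checker flags this exact code depends on its inference.

from typing import Sequence


def tighten(points: Sequence[float], new_point: float) -> Sequence[float]:
    # Sequence has no __setitem__, so the item assignment below is reported
    # even though callers always pass a list at runtime.
    # pyrefly: ignore # unsupported-operation
    points[1] = new_point
    return points


print(tighten([0.0, 2.0], 1.0))  # [0.0, 1.0]
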
@@ -1665,6 +1665,7 @@ class ReduceLROnPlateau(LRScheduler):
            self.default_min_lr = None
            self.min_lrs = list(min_lr)
        else:
            # pyrefly: ignore # bad-assignment
            self.default_min_lr = min_lr
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

@@ -1724,6 +1725,7 @@ class ReduceLROnPlateau(LRScheduler):
                    "of the `optimizer` param groups."
                )
        else:
            # pyrefly: ignore # bad-assignment
            self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)

        for i, param_group in enumerate(self.optimizer.param_groups):

@@ -1903,10 +1905,13 @@ class CyclicLR(LRScheduler):

        self.max_lrs = _format_param("max_lr", optimizer, max_lr)

        # pyrefly: ignore # bad-assignment
        step_size_up = float(step_size_up)
        step_size_down = (
            # pyrefly: ignore # bad-assignment
            float(step_size_down) if step_size_down is not None else step_size_up
        )
        # pyrefly: ignore # unsupported-operation
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

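The CyclicLR hunk rebinds `step_size_up` and `step_size_down` to `float(...)`, which is what the `bad-assignment` code refers to: as far as the checker is concerned, the variables keep their originally declared types. A hypothetical, self-contained version of the same pattern follows (`cycle_size` and its defaults are made up for illustration).

from typing import Optional


def cycle_size(step_size_up: int = 2000, step_size_down: Optional[int] = None) -> float:
    # Rebinding an int-annotated parameter to a float keeps the runtime
    # behaviour but reads as a bad-assignment to the checker.
    # pyrefly: ignore # bad-assignment
    step_size_up = float(step_size_up)
    step_size_down = (
        # pyrefly: ignore # bad-assignment
        float(step_size_down) if step_size_down is not None else step_size_up
    )
    return step_size_up + step_size_down


print(cycle_size(100, 300))  # 400.0
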
@@ -62,6 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
    def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        import torch._dynamo

        # pyrefly: ignore # unsupported-operation
        self = cast(Optimizer, args[0])  # assume first positional arg is `self`
        prev_grad = torch.is_grad_enabled()
        try:

@@ -135,11 +136,13 @@ def _disable_dynamo_if_unsupported(
        if torch.compiler.is_compiling() and (
            not kwargs.get("capturable", False)
            and has_state_steps
            # pyrefly: ignore # unsupported-operation
            and (arg := args[state_steps_ind])
            and isinstance(arg, Sequence)
            and arg[0].is_cuda
            or (
                "state_steps" in kwargs
                # pyrefly: ignore # unsupported-operation
                and (kwarg := kwargs["state_steps"])
                and isinstance(kwarg, Sequence)
                and kwarg[0].is_cuda

@@ -359,14 +362,18 @@ class Optimizer:

    _optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: dict[int, OptimizerPostHook]
    # pyrefly: ignore # not-a-type
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: (
        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    )
    _optimizer_load_state_dict_pre_hooks: (
        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    )
    _optimizer_load_state_dict_post_hooks: (
        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer"], None]]'
    )

@@ -391,6 +398,7 @@ class Optimizer:
        self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: list[dict[str, Any]] = []

        # pyrefly: ignore # no-matching-overload
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")

@@ -514,6 +522,7 @@ class Optimizer:
                        f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                    )

                # pyrefly: ignore # invalid-param-spec
                out = func(*args, **kwargs)
                self._optimizer_step_code()

@@ -949,7 +958,14 @@ class Optimizer:
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            return Optimizer._process_value_according_to_param_policy(
                param, value, param_id, param_groups, key
                # pyrefly: ignore # bad-argument-type
                param,
                value,
                # pyrefly: ignore # bad-argument-type
                param_id,
                # pyrefly: ignore # bad-argument-type
                param_groups,
                key,
            )
        elif isinstance(value, dict):
            return {

@@ -960,6 +976,7 @@ class Optimizer:
            }
        elif isinstance(value, Iterable):
            return type(value)(
                # pyrefly: ignore # bad-argument-count
                _cast(param, v, param_id=param_id, param_groups=param_groups)
                for v in value
            )  # type: ignore[call-arg]

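The `_process_value_according_to_param_policy` hunk reformats a call so that every argument sits on its own line, which lets each `bad-argument-type` suppression cover only the argument pyrefly complains about rather than the whole call. A toy example of the same refactor follows; `combine`, `value`, and `factor` are invented names.

def combine(name: str, count: int, scale: float) -> str:
    return f"{name}:{count * scale}"


value: object = 3  # the checker only knows these are `object`
factor: object = 2.0

result = combine(
    "lr",
    # pyrefly: ignore # bad-argument-type
    value,
    # pyrefly: ignore # bad-argument-type
    factor,
)
print(result)  # lr:6.0
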
@@ -322,6 +322,7 @@ def _single_tensor_radam(
        rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

        def _compute_rect():
            # pyrefly: ignore # unsupported-operation
            return (
                (rho_t - 4)
                * (rho_t - 2)

@@ -336,6 +337,7 @@ def _single_tensor_radam(
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)

            # pyrefly: ignore # unsupported-operation
            return (bias_correction2**0.5) / exp_avg_sq_sqrt

        # Compute the variance rectification term and update parameters accordingly

@@ -337,6 +337,7 @@ def _single_tensor_sgd(
    if not torch.jit.is_scripting():
        lr = _to_scalar(lr)

    # pyrefly: ignore # bad-assignment
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]

@@ -347,6 +348,7 @@ def _single_tensor_sgd(
                # usually this is the differentiable path, which is why the param.clone() is needed
                grad = grad.addcmul_(param.clone(), weight_decay)
            else:
                # pyrefly: ignore # bad-argument-type
                grad = grad.add(param, alpha=weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

@@ -370,6 +372,7 @@ def _single_tensor_sgd(
                if lr.requires_grad:
                    param.addcmul_(grad, lr, value=-1)
                else:
                    # pyrefly: ignore # bad-argument-type
                    param.add_(grad, alpha=-lr)
            else:
                param.add_(grad, alpha=-lr)

@@ -430,10 +433,12 @@ def _multi_tensor_sgd(

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                # pyrefly: ignore # index-error
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    # pyrefly: ignore # index-error
                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))

            if all_states_with_momentum_buffer:

@@ -441,12 +446,15 @@ def _multi_tensor_sgd(
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                # pyrefly: ignore # bad-assignment
                for i in range(len(device_momentum_buffer_list)):
                    # pyrefly: ignore # index-error
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = device_grads[i].detach().clone()
                    else:
                        # pyrefly: ignore # index-error
                        buf = cast(Tensor, device_momentum_buffer_list[i])
                    buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

@@ -249,11 +249,13 @@ class AveragedModel(Module):
    def update_parameters(self, model: Module):
        """Update model parameters."""
        self_param = (
            # pyrefly: ignore # bad-argument-type
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            # pyrefly: ignore # bad-argument-type
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()

@@ -300,8 +302,11 @@ class AveragedModel(Module):
            for p_averaged, p_model in zip(  # type: ignore[assignment]
                self_param_detached, model_param_detached
            ):
                # pyrefly: ignore # missing-attribute
                n_averaged = self.n_averaged.to(p_averaged.device)
                # pyrefly: ignore # missing-attribute
                p_averaged.detach().copy_(
                    # pyrefly: ignore # missing-attribute, bad-argument-type
                    self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                )

@@ -489,12 +494,14 @@ class SWALR(LRScheduler):
        step = self._step_count - 1
        if self.anneal_epochs == 0:
            step = max(1, step)
        # pyrefly: ignore # no-matching-overload
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [
            self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
            for group in self.optimizer.param_groups
        ]
        # pyrefly: ignore # no-matching-overload
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [

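The AveragedModel hunk also shows that one suppression line can carry several error codes, comma-separated. A small sketch of that form follows; `EmaSlot` and its `object`-typed payload are assumptions, and which codes a checker actually reports for the expression may differ.

import torch


class EmaSlot:
    """Toy container whose payload is only typed as `object`."""

    def __init__(self, payload: object) -> None:
        self.payload = payload


def rounded(slot: EmaSlot) -> float:
    # `.item()` is unknown on `object`, and the value handed to round() is
    # then untyped; every reported code shares the single suppression line.
    # pyrefly: ignore # missing-attribute, bad-argument-type
    return round(slot.payload.item(), 2)


print(rounded(EmaSlot(torch.tensor(3.14159))))  # 3.14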