Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[1/N] Use "is" in python type comparison (#165037)
It is generally recommended to use `is`/`is not` when comparing types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
Committed by: PyTorch MergeBot
Parent: 960b0d5f0d
Commit: 70925bdf82
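For background on the recommendation: `type(x) == T` dispatches to `__eq__`, which a metaclass can overload, while `type(x) is T` is an identity check that cannot be intercepted, and it is the form the pycodestyle E721 lint expects for exact type checks. A minimal standalone sketch of the difference (illustrative class names, not code from this PR):

```python
class LyingMeta(type):
    # A metaclass may overload `==`, so equality is not a reliable type check.
    def __eq__(cls, other):
        return True

    def __hash__(cls):
        return id(cls)


class Weird(metaclass=LyingMeta):
    pass


print(type(Weird()) == int)  # True: answered by LyingMeta.__eq__
print(type(Weird()) is int)  # False: identity cannot be overloaded
```

Note that both forms test for the exact type; neither matches subclasses, which is what separates them from `isinstance()`.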
@@ -2278,7 +2278,7 @@ class GuardBuilder(GuardBuilderBase):
         # don't support this in serialization because it uses unsupported FUNCTION_MATCH
         val = self.get(guard.name)
         # Strictly only want user-defined functions
-        if type(val) == types.FunctionType and hasattr(val, "__code__"):
+        if type(val) is types.FunctionType and hasattr(val, "__code__"):
             self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR)  # type: ignore[arg-type]
             self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH)  # type: ignore[arg-type]
         else:
@@ -197,7 +197,7 @@ class ConstDictVariable(VariableTracker):
     @staticmethod
     def _eq_impl(a, b):
         # TODO: Put this in utils and share it between variables/builtin.py and here
-        if type(a) != type(b):
+        if type(a) is not type(b):
             return False
         elif isinstance(a, tuple):
             Hashable = ConstDictVariable._HashableTracker
@@ -3532,7 +3532,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
         from torch.distributed.tensor.experimental._func_map import _local_map_wrapped

         # check is important to avoid subclass dispatch
-        if type(value) != type(_local_map_wrapped):
+        if type(value) is not type(_local_map_wrapped):
             return False

         return value == _local_map_wrapped and cls._enabled
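The `# check is important to avoid subclass dispatch` comment in the hunk above is also why such call sites use an exact type test rather than `isinstance`: identity comparison rejects subclasses outright. A standalone sketch (hypothetical `Base`/`Sub` names, not from this PR):

```python
class Base:
    pass


class Sub(Base):
    pass


obj = Sub()
print(isinstance(obj, Base))  # True: isinstance also accepts subclasses
print(type(obj) is Base)      # False: identity matches the exact type only
print(type(obj) is Sub)       # True
```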
@@ -200,7 +200,7 @@ class BaseListVariable(VariableTracker):
         if kwargs or len(args) != 1:
             raise_args_mismatch(tx, name)

-        if type(self) != type(args[0]):
+        if type(self) is not type(args[0]):
             tp_name = self.python_type_name()
             other = args[0].python_type_name()
             msg = ConstantVariable.create(
@@ -464,7 +464,7 @@ def _check_input_constraints_for_graph(
             )

         elif isinstance(node_val, (int, float, str)):
-            if type(arg) != type(node_val) or arg != node_val:
+            if type(arg) is not type(node_val) or arg != node_val:
                 raise RuntimeError(
                     f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
                 )
@@ -361,7 +361,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
             input_meta += get_input_meta(args[1])
             return input_meta
         for arg in args:
-            if type(arg) == int or type(arg) == float:
+            if type(arg) is int or type(arg) is float:
                 input_meta.append((type(arg),))
             else:
                 input_meta.append(
@@ -802,7 +802,7 @@ def _insert_fused_matmul_reduce_scatter(
     scatter_dim_after_reshape: int,  # only used for reshape -> scaled_mm -> reshape pattern
     output_shape: list[int],  # only used for reshape -> scaled_mm -> reshape pattern
 ) -> torch.fx.Node:
-    if type(matmul) == _Matmul:
+    if type(matmul) is _Matmul:
         return graph.call_function(
             torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
             args=(
@@ -813,7 +813,7 @@ def _insert_fused_matmul_reduce_scatter(
                 group_name,
             ),
         )
-    elif type(matmul) == _ScaledMatmul:
+    elif type(matmul) is _ScaledMatmul:
         return graph.call_function(
             torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
             args=(
@@ -98,7 +98,7 @@ class FakeTensorUpdater:
             return statically_known_true(sym_eq(new, old))

         def is_fake_tensor_same(new, old, *, node):
-            if type(new) != type(old):
+            if type(new) is not type(old):
                 return False
             if isinstance(new, (list, tuple)):
                 if len(new) != len(old):
@@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):

         def parse_metadata(e):
             if isinstance(e, torch.Tensor):
-                if type(e) != torch.Tensor:
+                if type(e) is not torch.Tensor:
                     try:
                         current = e.elem
                         return (
@@ -37,7 +37,7 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):

     if isinstance(dtype, str):
         dtype = _import_dotted_name(dtype)
-    if dtype == type(self):
+    if dtype is type(self):
         return self
     if self.is_sparse:
         if not dtype.is_sparse:
@@ -80,7 +80,7 @@ class ConvReLU1d(nnq.Conv1d):

     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
-        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
+        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
@@ -161,7 +161,7 @@ class ConvReLU2d(nnq.Conv2d):

     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
-        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
+        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
@@ -244,7 +244,7 @@ class ConvReLU3d(nnq.Conv3d):

     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
-        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
+        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
@@ -117,7 +117,7 @@ class Linear(nnq.Linear):
                 + str([float_mod.__name__ for float_mod in float_modules])
             )
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-        if type(mod) == nni.LinearReLU:
+        if type(mod) is nni.LinearReLU:
             mod = mod[0]
         # pyrefly: ignore # missing-attribute
         if mod.qconfig is not None and mod.qconfig.weight is not None:
@@ -1064,15 +1064,15 @@ class RNNCellBase(torch.nn.Module):

         qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

-        if type(mod) == torch.nn.LSTMCell:
+        if type(mod) is torch.nn.LSTMCell:
             qRNNCellBase = LSTMCell(
                 mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
             )
-        elif type(mod) == torch.nn.GRUCell:
+        elif type(mod) is torch.nn.GRUCell:
             qRNNCellBase = GRUCell(
                 mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
             )
-        elif type(mod) == torch.nn.RNNCell:
+        elif type(mod) is torch.nn.RNNCell:
             qRNNCellBase = RNNCell(
                 mod.input_size,
                 mod.hidden_size,
@@ -20,7 +20,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
     @staticmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):
         activation_post_process = mod.activation_post_process
-        if type(mod) == cls._NNI_BN_RELU_MODULE:
+        if type(mod) is cls._NNI_BN_RELU_MODULE:
             mod = mod[0]
         scale, zero_point = activation_post_process.calculate_qparams()
         new_mod = cls(mod.num_features, mod.eps)
@@ -280,7 +280,7 @@ class _ConvNd(WeightedQuantizedModule):
         if hasattr(mod, "weight_fake_quant"):
             # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
             # ".from_float only works for " + cls.__QAT_MODULE.__name__
-            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
+            if type(mod) is cls._NNIQAT_CONV_BN_MODULE:
                 mod.weight, mod.bias = fuse_conv_bn_weights(
                     mod.weight,
                     mod.bias,
@@ -149,7 +149,7 @@ class Linear(torch.nn.Module):
         # TODO: Need to add options to qconfig to avoid the calibration.
         # TODO: Add calibration for the sparsity
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-        if type(mod) == nni.LinearReLU:
+        if type(mod) is nni.LinearReLU:
             mod = mod[0]
         # pyrefly: ignore # missing-attribute
         if mod.qconfig is not None and mod.qconfig.weight is not None:
@@ -268,7 +268,7 @@ def extract_weight_from_node(
         mod = getattr_from_fqn(gm, node.target)
         module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
         for target_mod_type, weight_extraction_fn in module_mapping.items():
-            if type(mod) == target_mod_type:
+            if type(mod) is target_mod_type:
                 weight = weight_extraction_fn(mod)
                 return {
                     "type": res_type,
@@ -792,7 +792,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
     for i in range(1, len(path)):
         prev_key = path[i - 1]
         key = path[i]
-        def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else []
+        def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) is str else []

         if isinstance(cur_container, Mapping):
             cur_container = cast(
@@ -806,7 +806,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
         cur_container = cur_container[prev_key]

     key = path[-1]
-    if type(key) == int:
+    if type(key) is int:
         extend_list(cast(list[Any], cur_container), key)

     cur_container[key] = value
@@ -121,7 +121,7 @@ def set_element(
     for i in range(1, len(path)):
         prev_key = path[i - 1]
         key = path[i]
-        def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
+        def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])

         if isinstance(cur_container, Mapping):
             cur_container = cast(
@@ -135,7 +135,7 @@ def set_element(
         cur_container = cur_container[prev_key]

     key = path[-1]
-    if type(key) == int:
+    if type(key) is int:
         extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)

     cur_container[key] = value
@@ -2203,7 +2203,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
     # alive until all works and hooks are done. The current implementation does the
     # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
     # for the pending hooks to finish.
-    if type(pg) == ProcessGroup and pg._has_hooks():
+    if type(pg) is ProcessGroup and pg._has_hooks():
         pg._wait_for_pending_works()

     if group is None or group == GroupMember.WORLD:
@@ -2783,7 +2783,7 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
         key = "group_dst" if op.op == isend else "group_src"
         return {key: op.group_peer}

-    if type(group) == ProcessGroup and group._get_backend(device).supports_coalescing:
+    if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing:
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
@@ -149,9 +149,9 @@ class EtcdStore(Store):
     # In case of `str`, utf-8 encoding is assumed.
     #
     def _encode(self, value) -> str:
-        if type(value) == bytes:
+        if type(value) is bytes:
             return b64encode(value).decode()
-        elif type(value) == str:
+        elif type(value) is str:
             return b64encode(value.encode()).decode()
         raise ValueError("Value must be of type str or bytes")

@@ -160,9 +160,9 @@ class EtcdStore(Store):
     # Return type is `bytes`, which is more convenient with the Store interface.
     #
     def _decode(self, value) -> bytes:
-        if type(value) == bytes:
+        if type(value) is bytes:
             return b64decode(value)
-        elif type(value) == str:
+        elif type(value) is str:
             return b64decode(value.encode())
         raise ValueError("Value must be of type str or bytes")

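Taken together, the two `EtcdStore` helpers above form a simple base64 round-trip. A self-contained sketch with the updated `is` comparisons (rendered as free functions for illustration; the bodies match the diff):

```python
from base64 import b64decode, b64encode


def _encode(value) -> str:
    # `is` accepts exactly bytes/str, not their subclasses.
    if type(value) is bytes:
        return b64encode(value).decode()
    elif type(value) is str:
        return b64encode(value.encode()).decode()
    raise ValueError("Value must be of type str or bytes")


def _decode(value) -> bytes:
    if type(value) is bytes:
        return b64decode(value)
    elif type(value) is str:
        return b64decode(value.encode())
    raise ValueError("Value must be of type str or bytes")


assert _decode(_encode("hello")) == b"hello"
```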
@@ -110,7 +110,7 @@ class DownSampling(nn.Module):

 @torch.no_grad()
 def init_weights(m):
-    if type(m) == nn.Conv2d or type(m) == nn.Linear:
+    if type(m) is nn.Conv2d or type(m) is nn.Linear:
         nn.init.ones_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
@@ -280,7 +280,7 @@ def _kl_exponential_exponential(p, q):

 @register_kl(ExponentialFamily, ExponentialFamily)
 def _kl_expfamily_expfamily(p, q):
-    if not type(p) == type(q):
+    if type(p) is not type(q):
         raise NotImplementedError(
             "The cross KL-divergence between different exponential families cannot \
             be computed using Bregman divergences"
@@ -872,7 +872,7 @@ class AdditionalInputs:
         ]

         def _mark_dynamism(v, *other_vs):
-            if not all(type(v) == type(other) for other in other_vs):
+            if not all(type(v) is type(other) for other in other_vs):
                 raise ValueError(
                     "The following inputs were found to have differing types, "
                     f"so they cannot be marked as dynamic: {(v,) + other_vs}."
@@ -927,7 +927,7 @@ class Tracer(TracerBase):

             return out
         # Union[int, bool] == bool in Python <= 3.6
-        if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
+        if type(x) is bool or type(x) in base_types and type(x) != torch.Tensor:
             torch._assert(
                 out == x,
                 f"{name} has been specialized to have value {x} but got another value",
@@ -112,7 +112,7 @@ def unify_object(u, v, s):
     >>> unify_object(f, g, {})
     {~x: 2}
     """
-    if type(u) != type(v):
+    if type(u) is not type(v):
         return False
     if hasattr(u, "__slots__"):
         return unify(
@@ -311,7 +311,7 @@ class ScriptMeta(type):
                 original_init(self, *args, **kwargs)
                 added_methods_in_init = len(cls._methods) > num_methods

-                if type(self) == cls:
+                if type(self) is cls:

                     def make_stubs(module):
                         cls = type(module)
@@ -804,7 +804,7 @@ if _enabled:

         @property
         def original_name(self):
-            if type(self) == str(self._c._type().name()):
+            if type(self) is str(self._c._type().name()):
                 return ""
             return str(self._c._type().name())

@@ -651,7 +651,7 @@ def analyze_ts_result_with_export_result(export, trace):
         # mkldnn is not supported for torch.allclose
         if orig.layout == torch._mkldnn:  # type: ignore[attr-defined]
             return True
-        if type(orig) != type(loaded):
+        if type(orig) is not type(loaded):
             return False

         if isinstance(orig, torch._subclasses.FakeTensor):
@@ -198,7 +198,7 @@ class MaskedTensor(torch.Tensor):
     def _validate_members(self):
         data = self._masked_data
         mask = self.get_mask()
-        if type(data) != type(mask):
+        if type(data) is not type(mask):
             raise TypeError(
                 f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
             )
@@ -1049,7 +1049,7 @@ class Module:
             >>> @torch.no_grad()
             >>> def init_weights(m):
             >>>     print(m)
-            >>>     if type(m) == nn.Linear:
+            >>>     if type(m) is nn.Linear:
             >>>         m.weight.fill_(1.0)
             >>>         print(m.weight)
             >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
@@ -10,7 +10,7 @@ def is_from_package(obj: Any) -> bool:

     Note: packaged objects from externed modules will return ``False``.
     """
-    if type(obj) == ModuleType:
+    if type(obj) is ModuleType:
         return is_mangled(obj.__name__)
     else:
         return is_mangled(type(obj).__module__)
@@ -891,7 +891,7 @@ class TypedStorage:
     def _new_wrapped_storage(self, untyped_storage) -> Self:
         assert type(untyped_storage) == torch.UntypedStorage

-        if type(self) == TypedStorage:
+        if type(self) is TypedStorage:
             return cast(
                 Self,
                 TypedStorage(
@@ -913,7 +913,7 @@ class TypedStorage:
             return 0

         else:
-            if type(idx) != int:
+            if type(idx) is not int:
                 raise TypeError(f"can't index a {type(self)} with {type(idx)}")
             if is_stop:
                 if (idx > self._size()) or (idx < -self._size()):
@@ -1513,7 +1513,7 @@ class _LegacyStorageMeta(type):
     dtype: torch.dtype

     def __instancecheck__(cls, instance):
-        if type(instance) == TypedStorage:
+        if type(instance) is TypedStorage:
             cls_device = _get_device_from_module(cls.__module__)
             return (cls_device == instance.device.type) and (
                 cls.dtype == instance.dtype
@@ -1432,7 +1432,7 @@ class MultiThreadedTestCase(TestCase):
                 logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
                 error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
             elif isinstance(exc, SystemExit):
-                if type(exc.code) == int and skip_code < 0:
+                if type(exc.code) is int and skip_code < 0:
                     skip_code = exc.code

         # check exceptions
@@ -1247,7 +1247,7 @@ class QuantizationTestCase(TestCase):
         }
         """
         # TODO: make img_data a single example instead of a list
-        if type(inputs) == list:
+        if type(inputs) is list:
             inputs = inputs[0]

         if quant_type == QuantType.QAT:
@@ -365,7 +365,7 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
         for mask in _generate_masked_op_mask(
             sample_input.input.shape, device, **kwargs
         ):
-            if type(mask) != torch.Tensor:
+            if type(mask) is not torch.Tensor:
                 continue
             sample_input_args, sample_input_kwargs = (
                 sample_input.args,
@@ -70,7 +70,7 @@ class WrapperSubclass(torch.Tensor):
     def __coerce_same_metadata_as_tangent__(
         self, expected_metadata: Any, expected_type: Optional[type] = None
     ):
-        if expected_type == type(self.a):
+        if expected_type is type(self.a):
             return self.a
         elif expected_type is TwoTensor:
             return TwoTensor(self.a, self.a.clone())
@@ -187,7 +187,7 @@ class GraphPy:
         )

         for key, node in self.nodes_io.items():
-            if type(node) == NodeBase:
+            if type(node) is NodeBase:
                 # pyrefly: ignore # unsupported-operation
                 self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
             if hasattr(node, "input_or_output"):