mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE]: Update flake8 and plugins and fix bugs (#97795)
Update flake8 and flake8-plugins in lintrunner to a modern version. Enables more checks and makes flake8 checks significantly faster. Added a few additional rule ignores that will need to be fixed in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97795 Approved by: https://github.com/alexsio27444, https://github.com/janeyx99, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7282be3d91
commit
597b558c51
4
.flake8
4
.flake8
@ -6,11 +6,13 @@ max-line-length = 120
|
|||||||
# E501 is not flexible enough, we're using B950 instead
|
# E501 is not flexible enough, we're using B950 instead
|
||||||
ignore =
|
ignore =
|
||||||
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
|
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
|
||||||
|
# fix these lints in the future
|
||||||
|
E275,
|
||||||
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
||||||
# to line this up with executable bit
|
# to line this up with executable bit
|
||||||
EXE001,
|
EXE001,
|
||||||
# these ignores are from flake8-bugbear; please fix!
|
# these ignores are from flake8-bugbear; please fix!
|
||||||
B007,B008,
|
B007,B008,B017,B019,B020,B023,B024,B026,B027,B028,B903,B904,B905,B906,B907
|
||||||
# these ignores are from flake8-comprehensions; please fix!
|
# these ignores are from flake8-comprehensions; please fix!
|
||||||
C407,C417
|
C407,C417
|
||||||
# these ignores are from flake8-logging-format; please fix!
|
# these ignores are from flake8-logging-format; please fix!
|
||||||
|
@ -33,15 +33,15 @@ init_command = [
|
|||||||
'python3',
|
'python3',
|
||||||
'tools/linter/adapters/pip_init.py',
|
'tools/linter/adapters/pip_init.py',
|
||||||
'--dry-run={{DRYRUN}}',
|
'--dry-run={{DRYRUN}}',
|
||||||
'flake8==3.8.2',
|
'flake8==6.0.0',
|
||||||
'flake8-bugbear==20.1.4',
|
'flake8-bugbear==23.3.23',
|
||||||
'flake8-comprehensions==3.11.1',
|
'flake8-comprehensions==3.11.1',
|
||||||
'flake8-executable==2.0.4',
|
'flake8-executable==2.1.3',
|
||||||
'flake8-logging-format==0.9.0',
|
'flake8-logging-format==0.9.0',
|
||||||
'flake8-pyi==20.5.0',
|
'flake8-pyi==23.3.1',
|
||||||
'mccabe==0.6.1',
|
'mccabe==0.7.0',
|
||||||
'pycodestyle==2.6.0',
|
'pycodestyle==2.10.0',
|
||||||
'pyflakes==2.2.0',
|
'pyflakes==3.0.1',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ def find_graph_variable(args):
|
|||||||
for arg in var_types.keys():
|
for arg in var_types.keys():
|
||||||
if ',' in args[arg]:
|
if ',' in args[arg]:
|
||||||
if args.get('x_axis_name'):
|
if args.get('x_axis_name'):
|
||||||
raise("Only 1 x axis graph variable allowed")
|
raise ValueError("Only 1 x axis graph variable allowed")
|
||||||
args[arg] = list(map(var_types[arg], args[arg].split(','))) # convert , separated str to list
|
args[arg] = list(map(var_types[arg], args[arg].split(','))) # convert , separated str to list
|
||||||
args['x_axis_name'] = arg
|
args['x_axis_name'] = arg
|
||||||
else:
|
else:
|
||||||
|
@ -397,7 +397,7 @@ class TestMin(TestCase):
|
|||||||
i, j = dims()
|
i, j = dims()
|
||||||
i.size = 3
|
i.size = 3
|
||||||
j.size = 4
|
j.size = 4
|
||||||
(i < j)
|
(i < j) # noqa: B015
|
||||||
|
|
||||||
def test_c(self):
|
def test_c(self):
|
||||||
_test_c()
|
_test_c()
|
||||||
|
@ -5803,7 +5803,7 @@ class CommonTemplate:
|
|||||||
a = a.max(0).values
|
a = a.max(0).values
|
||||||
c = torch.cat((a, b))
|
c = torch.cat((a, b))
|
||||||
c = c.round()
|
c = c.round()
|
||||||
b >= a[0]
|
b >= a[0] # noqa: B015
|
||||||
return c
|
return c
|
||||||
|
|
||||||
some_const = torch.tensor(6324)
|
some_const = torch.tensor(6324)
|
||||||
@ -5811,7 +5811,7 @@ class CommonTemplate:
|
|||||||
def fn2():
|
def fn2():
|
||||||
a = torch.tensor([[0.6324]])
|
a = torch.tensor([[0.6324]])
|
||||||
ret = torch.cat((a, a), dim=0)
|
ret = torch.cat((a, a), dim=0)
|
||||||
some_const >= a[0]
|
some_const >= a[0] # noqa: B015
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
self.common(fn1, (torch.tensor([[4.0]]), torch.tensor([5.0])))
|
self.common(fn1, (torch.tensor([[4.0]]), torch.tensor([5.0])))
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import List, Tuple, Optional, Dict, NamedTuple
|
from typing import List, Tuple, Dict, NamedTuple
|
||||||
|
|
||||||
# Make the helper files in test/ importable
|
# Make the helper files in test/ importable
|
||||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
@ -534,23 +534,23 @@ class TestFxModelReportObserver(QuantizationTestCase):
|
|||||||
|
|
||||||
# quick check that a reset occurred
|
# quick check that a reset occurred
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
getattr(model, "obs1").average_batch_activation_range,
|
model.obs1.average_batch_activation_range,
|
||||||
torch.tensor(float(0)),
|
torch.tensor(float(0)),
|
||||||
)
|
)
|
||||||
self.assertEqual(getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf")))
|
self.assertEqual(model.obs1.epoch_activation_min, torch.tensor(float("inf")))
|
||||||
self.assertEqual(getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf")))
|
self.assertEqual(model.obs1.epoch_activation_max, torch.tensor(float("-inf")))
|
||||||
|
|
||||||
# loop through the batches and run through
|
# loop through the batches and run through
|
||||||
for index, batch in enumerate(split_up_data):
|
for index, batch in enumerate(split_up_data):
|
||||||
|
|
||||||
num_tracked_so_far = getattr(model, "obs1").num_batches_tracked
|
num_tracked_so_far = model.obs1.num_batches_tracked
|
||||||
self.assertEqual(num_tracked_so_far, index)
|
self.assertEqual(num_tracked_so_far, index)
|
||||||
|
|
||||||
# get general info about the batch and the model to use later
|
# get general info about the batch and the model to use later
|
||||||
batch_min, batch_max = torch.aminmax(batch)
|
batch_min, batch_max = torch.aminmax(batch)
|
||||||
current_average_range = getattr(model, "obs1").average_batch_activation_range
|
current_average_range = model.obs1.average_batch_activation_range
|
||||||
current_epoch_min = getattr(model, "obs1").epoch_activation_min
|
current_epoch_min = model.obs1.epoch_activation_min
|
||||||
current_epoch_max = getattr(model, "obs1").epoch_activation_max
|
current_epoch_max = model.obs1.epoch_activation_max
|
||||||
|
|
||||||
# run input through
|
# run input through
|
||||||
model(ex_input)
|
model(ex_input)
|
||||||
@ -560,13 +560,13 @@ class TestFxModelReportObserver(QuantizationTestCase):
|
|||||||
num_tracked_so_far + 1
|
num_tracked_so_far + 1
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
getattr(model, "obs1").average_batch_activation_range,
|
model.obs1.average_batch_activation_range,
|
||||||
correct_updated_value,
|
correct_updated_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_epoch_max - current_epoch_min > 0:
|
if current_epoch_max - current_epoch_min > 0:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
getattr(model, "obs1").get_batch_to_epoch_ratio(),
|
model.obs1.get_batch_to_epoch_ratio(),
|
||||||
correct_updated_value / (current_epoch_max - current_epoch_min),
|
correct_updated_value / (current_epoch_max - current_epoch_min),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -589,13 +589,13 @@ class TestFxModelReportObserver(QuantizationTestCase):
|
|||||||
self.run_model_and_common_checks(model, ex_input, 1, 1)
|
self.run_model_and_common_checks(model, ex_input, 1, 1)
|
||||||
|
|
||||||
# make sure final values are all 0
|
# make sure final values are all 0
|
||||||
self.assertEqual(getattr(model, "obs1").epoch_activation_min, 0)
|
self.assertEqual(model.obs1.epoch_activation_min, 0)
|
||||||
self.assertEqual(getattr(model, "obs1").epoch_activation_max, 0)
|
self.assertEqual(model.obs1.epoch_activation_max, 0)
|
||||||
self.assertEqual(getattr(model, "obs1").average_batch_activation_range, 0)
|
self.assertEqual(model.obs1.average_batch_activation_range, 0)
|
||||||
|
|
||||||
# we should get an error if we try to calculate the ratio
|
# we should get an error if we try to calculate the ratio
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
ratio_val = getattr(model, "obs1").get_batch_to_epoch_ratio()
|
ratio_val = model.obs1.get_batch_to_epoch_ratio()
|
||||||
|
|
||||||
"""Case includes:
|
"""Case includes:
|
||||||
non-zero tensor
|
non-zero tensor
|
||||||
@ -616,13 +616,13 @@ class TestFxModelReportObserver(QuantizationTestCase):
|
|||||||
self.run_model_and_common_checks(model, ex_input, 1, 1)
|
self.run_model_and_common_checks(model, ex_input, 1, 1)
|
||||||
|
|
||||||
# make sure final values are all 0 except for range
|
# make sure final values are all 0 except for range
|
||||||
self.assertEqual(getattr(model, "obs1").epoch_activation_min, 1)
|
self.assertEqual(model.obs1.epoch_activation_min, 1)
|
||||||
self.assertEqual(getattr(model, "obs1").epoch_activation_max, 1)
|
self.assertEqual(model.obs1.epoch_activation_max, 1)
|
||||||
self.assertEqual(getattr(model, "obs1").average_batch_activation_range, 0)
|
self.assertEqual(model.obs1.average_batch_activation_range, 0)
|
||||||
|
|
||||||
# we should get an error if we try to calculate the ratio
|
# we should get an error if we try to calculate the ratio
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
ratio_val = getattr(model, "obs1").get_batch_to_epoch_ratio()
|
ratio_val = model.obs1.get_batch_to_epoch_ratio()
|
||||||
|
|
||||||
"""Case includes:
|
"""Case includes:
|
||||||
non-zero tensor
|
non-zero tensor
|
||||||
|
@ -377,10 +377,10 @@ def min_cut_rematerialization_partition(
|
|||||||
prims = torch.ops.prims
|
prims = torch.ops.prims
|
||||||
|
|
||||||
# compiler == "nvfuser" is the default set of recomputable ops
|
# compiler == "nvfuser" is the default set of recomputable ops
|
||||||
default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy] # noqa: E501
|
default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy] # noqa: E501,B950
|
||||||
view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
|
view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
|
||||||
if compiler == "inductor":
|
if compiler == "inductor":
|
||||||
default_recomputable_ops += [prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum] # noqa: E501
|
default_recomputable_ops += [prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum] # noqa: E501,B950
|
||||||
view_ops += [aten.view, aten.slice, aten.permute, aten.t, prims.broadcast_in_dim, aten.expand, aten.as_strided]
|
view_ops += [aten.view, aten.slice, aten.permute, aten.t, prims.broadcast_in_dim, aten.expand, aten.as_strided]
|
||||||
# Natalia said that we should allow recomputing indexing :)
|
# Natalia said that we should allow recomputing indexing :)
|
||||||
default_recomputable_ops += [aten.index]
|
default_recomputable_ops += [aten.index]
|
||||||
@ -400,7 +400,7 @@ def min_cut_rematerialization_partition(
|
|||||||
recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
|
recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
|
||||||
|
|
||||||
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
|
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
|
||||||
compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501
|
compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501,B950
|
||||||
|
|
||||||
unrecomputable_ops = random_ops + compute_intensive_ops
|
unrecomputable_ops = random_ops + compute_intensive_ops
|
||||||
|
|
||||||
|
@ -4287,8 +4287,6 @@ def empty_like(
|
|||||||
layout = a.layout if layout is None else layout
|
layout = a.layout if layout is None else layout
|
||||||
device = a.device if device is None else device
|
device = a.device if device is None else device
|
||||||
|
|
||||||
strides: Tuple[int, ...]
|
|
||||||
|
|
||||||
if memory_format != torch.preserve_format:
|
if memory_format != torch.preserve_format:
|
||||||
return torch.empty(
|
return torch.empty(
|
||||||
a.shape,
|
a.shape,
|
||||||
|
@ -131,7 +131,7 @@ class TestTrainingAwareCallback(TestCase):
|
|||||||
|
|
||||||
# data sparsifier args are correct
|
# data sparsifier args are correct
|
||||||
for key, value in sparsifier_args.items():
|
for key, value in sparsifier_args.items():
|
||||||
callback.data_sparsifier.defaults[key] == value
|
assert callback.data_sparsifier.defaults[key] == value
|
||||||
|
|
||||||
# data scheduler args are correct
|
# data scheduler args are correct
|
||||||
for key, value in scheduler_args.items():
|
for key, value in scheduler_args.items():
|
||||||
|
@ -50,7 +50,7 @@ class ContinuousBernoulli(ExponentialFamily):
|
|||||||
# validate 'probs' here if necessary as it is later clamped for numerical stability
|
# validate 'probs' here if necessary as it is later clamped for numerical stability
|
||||||
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
|
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
|
||||||
if validate_args is not None:
|
if validate_args is not None:
|
||||||
if not self.arg_constraints['probs'].check(getattr(self, 'probs')).all():
|
if not self.arg_constraints['probs'].check(self.probs).all():
|
||||||
raise ValueError("The parameter {} has invalid values".format('probs'))
|
raise ValueError("The parameter {} has invalid values".format('probs'))
|
||||||
self.probs = clamp_probs(self.probs)
|
self.probs = clamp_probs(self.probs)
|
||||||
else:
|
else:
|
||||||
|
@ -11,7 +11,7 @@ aten = torch.ops.aten
|
|||||||
|
|
||||||
|
|
||||||
# stateful ops are banned from CSE
|
# stateful ops are banned from CSE
|
||||||
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501
|
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950
|
||||||
|
|
||||||
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
|
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
|
||||||
|
|
||||||
|
@ -3677,12 +3677,12 @@ Examples::
|
|||||||
)
|
)
|
||||||
|
|
||||||
@_overload # noqa: F811
|
@_overload # noqa: F811
|
||||||
def upsample(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811
|
def upsample(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811,B950
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@_overload # noqa: F811
|
@_overload # noqa: F811
|
||||||
def upsample(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811
|
def upsample(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811,B950
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -3752,17 +3752,17 @@ if upsample.__doc__:
|
|||||||
|
|
||||||
|
|
||||||
@_overload # noqa: F811
|
@_overload # noqa: F811
|
||||||
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811
|
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@_overload # noqa: F811
|
@_overload # noqa: F811
|
||||||
def interpolate(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811
|
def interpolate(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@_overload # noqa: F811
|
@_overload # noqa: F811
|
||||||
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811
|
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -3778,7 +3778,7 @@ def interpolate( # noqa: F811
|
|||||||
) -> Tensor: # noqa: F811
|
) -> Tensor: # noqa: F811
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811
|
def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950
|
||||||
r"""Down/up samples the input to either the given :attr:`size` or the given
|
r"""Down/up samples the input to either the given :attr:`size` or the given
|
||||||
:attr:`scale_factor`
|
:attr:`scale_factor`
|
||||||
|
|
||||||
|
@ -355,7 +355,7 @@ class LazyIrSchema:
|
|||||||
positional_args: List[LazyArgument] = []
|
positional_args: List[LazyArgument] = []
|
||||||
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
|
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
|
||||||
if arg_field == "self_arg" and func.arguments.self_arg is not None:
|
if arg_field == "self_arg" and func.arguments.self_arg is not None:
|
||||||
arg = getattr(func.arguments, "self_arg").argument
|
arg = func.arguments.self_arg.argument
|
||||||
positional_args.append(
|
positional_args.append(
|
||||||
LazyArgument(arg, self.properties, symint=symint)
|
LazyArgument(arg, self.properties, symint=symint)
|
||||||
)
|
)
|
||||||
@ -382,7 +382,9 @@ class LazyIrSchema:
|
|||||||
assert (
|
assert (
|
||||||
self.generator_arg is None
|
self.generator_arg is None
|
||||||
), "We expect there is only one generator arg"
|
), "We expect there is only one generator arg"
|
||||||
self.generator_arg = NamedCType(arg.name, arg.type)
|
self.generator_arg = NamedCType(
|
||||||
|
arg.name, arg.type # type:ignore[arg-type]
|
||||||
|
)
|
||||||
keyword_args.extend(
|
keyword_args.extend(
|
||||||
LazyArgument(arg, self.properties, symint=symint)
|
LazyArgument(arg, self.properties, symint=symint)
|
||||||
for arg in curr_args
|
for arg in curr_args
|
||||||
|
Reference in New Issue
Block a user