Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)
I spotted 31 places that construct dictionaries from fixed key/value listings. This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
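For context, here is a minimal sketch of the rewrite pattern this PR applies. It is a generic illustration, not code from the PR; the values are made up, though the key names echo the `SparseAdam` hunk further below.

```python
# Before: keyword-argument construction. Keys must be valid Python
# identifiers and are turned into the strings "lr", "eps", "maximize".
defaults = dict(lr=0.01, eps=1e-8, maximize=False)

# After: the same mapping written as a literal. No dict() name lookup or
# call is needed, and keys are not restricted to identifier-like names.
defaults = {"lr": 0.01, "eps": 1e-8, "maximize": False}

# The two forms produce equal dictionaries.
assert dict(lr=0.01, eps=1e-8, maximize=False) == {"lr": 0.01, "eps": 1e-8, "maximize": False}
```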
Committed by: PyTorch MergeBot
Parent: 0f31445139
Commit: 4f5be56612
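The diff follows. Note that it deliberately leaves non-literal constructions such as `dict.fromkeys(...)` and `dict(zip(...))` untouched (see the `act_mapping` and `options.update(...)` lines below), since those build mappings programmatically rather than from a fixed set of keys. A small self-contained sketch of the distinction, with illustrative variable names that are not taken from the PR:

```python
num_gemm = 3
x = "activation"
extra_names = ["swizzle", "stages"]
extra_inputs = [4, 2]

# These stay as dict() calls: the keys are computed at runtime,
# not spelled out as a fixed listing.
act_mapping = dict.fromkeys(range(num_gemm), x)  # {0: 'activation', 1: 'activation', 2: 'activation'}
extras = dict(zip(extra_names, extra_inputs))    # {'swizzle': 4, 'stages': 2}

# Only fixed key/value listings are rewritten to literal syntax.
options = {"alpha": 1.0, "beta": 0.0}
options.update(extras)
```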
@@ -211,12 +211,12 @@ inline void {{kernel_name}}(
     ) -> str:
         buffer_size = " * ".join(map(str, size_args))
         return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
-            dict(
-                buffer_name=buffer_name,
-                buffer_dtype=buffer_dtype,
-                buffer_size=buffer_size,
-                is_msvc_compiler=cpp_builder.is_msvc_cl(),
-            )
+            {
+                "buffer_name": buffer_name,
+                "buffer_dtype": buffer_dtype,
+                "buffer_size": buffer_size,
+                "is_msvc_compiler": cpp_builder.is_msvc_cl(),
+            }
         )

     def is_woq_int4(self):
@@ -1179,27 +1179,27 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):

         instance_definition, instance_type = self._define_gemm_instance(op, evt_name)

-        options = dict(
-            alpha=self.alpha,
-            beta=self.beta,
-            X=X,
-            W=W,
-            Y=Y,
-            kernel_call_signature=kernel_call_signature,
-            Bias=Bias,
-            epilogue_template=epilogue_template,
-            argument_template=argument_template,
-            should_swap_xw=should_swap_xw,
-            template=self,
-            kernel=kernel,
-            instance_definition=instance_definition,
-            instance_type=instance_type,
-            input_reorder=self.input_reorder,
-            epilogue_args=evt_args,
-            test_call_statement=test_call_statement,
-            op_conf_name=op.configuration_name(),
-            epilogue_visitor_tree=evt_code,
-        )
+        options = {
+            "alpha": self.alpha,
+            "beta": self.beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "kernel_call_signature": kernel_call_signature,
+            "Bias": Bias,
+            "epilogue_template": epilogue_template,
+            "argument_template": argument_template,
+            "should_swap_xw": should_swap_xw,
+            "template": self,
+            "kernel": kernel,
+            "instance_definition": instance_definition,
+            "instance_type": instance_type,
+            "input_reorder": self.input_reorder,
+            "epilogue_args": evt_args,
+            "test_call_statement": test_call_statement,
+            "op_conf_name": op.configuration_name(),
+            "epilogue_visitor_tree": evt_code,
+        }
         options.update(dict(zip(extra_names, extra_inputs)))
         res = self._template_from_string(self._get_template()).render(**options)
         if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
@@ -1587,19 +1587,19 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
         before the function call.
         """
-        options = dict(
-            alpha=alpha,
-            beta=beta,
-            X=X,
-            W=W,
-            Y=Y,
-            Bias=Bias,
-            template=self,
-            kernel=kernel,
-            M="M",
-            N="N",
-            epilogue_args=epilogue_args,
-        )
+        options = {
+            "alpha": alpha,
+            "beta": beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "Bias": Bias,
+            "template": self,
+            "kernel": kernel,
+            "M": "M",
+            "N": "N",
+            "epilogue_args": epilogue_args,
+        }
         assert epilogue_template is not None

         if should_swap_xw:
@@ -1878,21 +1878,21 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
         tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
         before the function call.
         """
-        options = dict(
-            instance_type=instance_type,
-            alpha=alpha,
-            beta=beta,
-            X=X,
-            W=W,
-            Y=Y,
-            Bias=Bias,
-            Meta=Meta,
-            template=self,
-            kernel=kernel,
-            M="M",
-            N="N",
-            epilogue_args=epilogue_args,
-        )
+        options = {
+            "instance_type": instance_type,
+            "alpha": alpha,
+            "beta": beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "Bias": Bias,
+            "Meta": Meta,
+            "template": self,
+            "kernel": kernel,
+            "M": "M",
+            "N": "N",
+            "epilogue_args": epilogue_args,
+        }

         if epilogue_template is None:
             arguments = self._template_from_string(argument_template).render(
@@ -86,12 +86,12 @@ def tma_options() -> dict[str, Any]:


 def persistent_mm_options(mat1, mat2):
-    res = dict(
-        A_ROW_MAJOR=not mat1.layout.is_transposed(),
-        B_ROW_MAJOR=not mat2.layout.is_transposed(),
-        NUM_SMS=get_num_sms(),
-        TMA_SIZE=TMA_DESCRIPTOR_SIZE,
-    )
+    res = {
+        "A_ROW_MAJOR": not mat1.layout.is_transposed(),
+        "B_ROW_MAJOR": not mat2.layout.is_transposed(),
+        "NUM_SMS": get_num_sms(),
+        "TMA_SIZE": TMA_DESCRIPTOR_SIZE,
+    }
     res.update(tma_options())
     return res

@@ -147,12 +147,12 @@ def grouped_gemm_lowering(
     choices: list[ChoiceCaller] = []
     *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)

-    kwargs = dict(
-        has_bias=[bias is not None for bias in b],
-        trans_w=True,
-        epilogue_creator=None,
-        act_mapping=dict.fromkeys(range(num_gemm), x),
-    )
+    kwargs = {
+        "has_bias": [bias is not None for bias in b],
+        "trans_w": True,
+        "epilogue_creator": None,
+        "act_mapping": dict.fromkeys(range(num_gemm), x),
+    }

     input_nodes = [x, *w]
     input_nodes.extend([bias for bias in b if bias is not None])
@@ -353,11 +353,13 @@ def register_onednn_fusion_ops():
                     buf, attr, scalars=scalars, algorithm=algorithm
                 )

-            kwargs = dict(
-                has_bias=b is not None,
-                trans_w=True,
-                epilogue_creator=None if attr == "none" else epilogue_creator,
-            )
+            kwargs = {
+                "has_bias": b is not None,
+                "trans_w": True,
+                "epilogue_creator": (
+                    None if attr == "none" else epilogue_creator
+                ),
+            }
             if b is not None:
                 kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
             CppGemmTemplate.add_choices(
@@ -416,11 +418,12 @@ def register_onednn_fusion_ops():
             def epilogue_creator(buf):
                 return create_epilogue_with_attr(buf, attr, other=y)

-            kwargs = dict(
-                has_bias=b is not None,
-                trans_w=True,
-                epilogue_creator=epilogue_creator,
-            )
+            kwargs = {
+                "has_bias": b is not None,
+                "trans_w": True,
+                "epilogue_creator": epilogue_creator,
+            }
+
             kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
             CppGemmTemplate.add_choices(
                 choices,
@@ -166,123 +166,123 @@ Example::
 """,
     )

-    args_and_kwargs = dict(
+    args_and_kwargs = {
         # argument name sufficies separated by double underscore will
         # be removed in the final documentation string.
-        sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-        prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-        cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
-        cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
-        amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-        amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-        argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-        argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-        mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-        median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-        norm=(
+        "sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+        "prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+        "cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
+        "cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
+        "amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+        "amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+        "argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+        "argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+        "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+        "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+        "norm": (
             (
                 "ord",
                 "dim",
             ),
             ("keepdim=False", "dtype=None", "mask=None"),
         ),
-        var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
-        std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
-        logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-        softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
-        log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
-        softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
-        normalize=(
+        "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
+        "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
+        "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+        "softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
+        "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
+        "softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
+        "normalize": (
             (
                 "ord__required",
                 "dim__as_int",
             ),
             ("eps=1e-12", "dtype=None", "mask=None"),
         ),
-    )
+    }

-    argument_declarations = dict(
-        dim="""\
-dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
-  Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
-        dim__as_int="""\
-dim (int): the dimension along which {operation name} is computed.""",
-        ord="""\
-ord (int, float, optional): the order of vector norm. Default: 2.
-  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
-        ord__required="""\
-ord (int, float): the order of vector norm. Default: 2.
-  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
-        unbiased="""\
-unbiased (bool): when True, use Bessel's correction, otherwise, compute
-  the uncorrected sample variance.""",
-        eps="""\
-eps (float, optional): small value to avoid division by zero. Default: {default}.""",
-        keepdim="""\
-keepdim (bool, optional): whether the output tensor has
-  :attr:`dim` retained or not. Default: {default}.""",
-        dtype="""\
-dtype (:class:`torch.dtype`, optional): the desired data type
-  of returned tensor. If specified, the input tensor is
-  casted to :attr:`dtype` before the operation is
-  performed. Default: {default}.""",
-        mask="""\
-mask (:class:`torch.Tensor`, optional): the boolean tensor
-  containing the binary mask of validity of input tensor
-  elements.
-  Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
-    )
+    argument_declarations = {
+        "dim": """\
+dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
+  Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
+        "dim__as_int": """\
+dim (int): the dimension along which {operation name} is computed.""",
+        "ord": """\
+ord (int, float, optional): the order of vector norm. Default: 2.
+  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
+        "ord__required": """\
+ord (int, float): the order of vector norm. Default: 2.
+  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
+        "unbiased": """\
+unbiased (bool): when True, use Bessel's correction, otherwise, compute
+  the uncorrected sample variance.""",
+        "eps": """\
+eps (float, optional): small value to avoid division by zero. Default: {default}.""",
+        "keepdim": """\
+keepdim (bool, optional): whether the output tensor has
+  :attr:`dim` retained or not. Default: {default}.""",
+        "dtype": """\
+dtype (:class:`torch.dtype`, optional): the desired data type
+  of returned tensor. If specified, the input tensor is
+  casted to :attr:`dtype` before the operation is
+  performed. Default: {default}.""",
+        "mask": """\
+mask (:class:`torch.Tensor`, optional): the boolean tensor
+  containing the binary mask of validity of input tensor
+  elements.
+  Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
+    }

-    definitions = dict(
-        softmax="""\
-Let ``x`` be a sequence of unmasked elements of one-dimensional slice
-of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
-defined as ``exp(x[i])/sum(exp(x))``.""",
-        log_softmax="""\
-Let ``x`` be a sequence of unmasked elements of one-dimensional slice
-of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
-defined as ``log(exp(x[i])/sum(exp(x)))``.""",
-        softmin="""\
-Let ``x`` be a sequence of unmasked elements of one-dimensional slice
-of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
-defined as ``exp(-x[i])/sum(exp(-x))``.""",
-        normalize="""\
-Let ``x`` be a sequence of unmasked elements of one-dimensional slice
-of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
-defined as ``x[i]/max(norm(x, p), eps)``.""",
-        cumsum="""\
-Let ``x`` be a sequence of unmasked elements of one-dimensional slice
-of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
-defined as ``sum(x[:i])``.""",
-        cumprod="""\
-Let ``x`` be a sequence of unmasked elements of one-dimensional slice
-of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
-defined as ``prod(x[:i])``.""",
-    )
+    definitions = {
+        "softmax": """\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
+defined as ``exp(x[i])/sum(exp(x))``.""",
+        "log_softmax": """\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
+defined as ``log(exp(x[i])/sum(exp(x)))``.""",
+        "softmin": """\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
+defined as ``exp(-x[i])/sum(exp(-x))``.""",
+        "normalize": """\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
+defined as ``x[i]/max(norm(x, p), eps)``.""",
+        "cumsum": """\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
+defined as ``sum(x[:i])``.""",
+        "cumprod": """\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
+defined as ``prod(x[:i])``.""",
+    }

-    reduction_names = dict(
-        sum="sum",
-        prod="product",
-        amax="maximum",
-        amin="minimum",
-        argmax="argmax",
-        argmin="argmin",
-        mean="mean",
-        median="median",
-        norm="norm",
-        var="variance",
-        std="standard_deviation",
-        logsumexp="logsumexp",
-    )
+    reduction_names = {
+        "sum": "sum",
+        "prod": "product",
+        "amax": "maximum",
+        "amin": "minimum",
+        "argmax": "argmax",
+        "argmin": "argmin",
+        "mean": "mean",
+        "median": "median",
+        "norm": "norm",
+        "var": "variance",
+        "std": "standard_deviation",
+        "logsumexp": "logsumexp",
+    }

-    normalization_names = dict(
-        softmax="softmax",
-        log_softmax="log_softmax",
-        softmin="softmin",
-        normalize="normalize",
-        cumsum="cumulative_sum",
-        cumprod="cumulative_prod",
-    )
+    normalization_names = {
+        "softmax": "softmax",
+        "log_softmax": "log_softmax",
+        "softmin": "softmin",
+        "normalize": "normalize",
+        "cumsum": "cumulative_sum",
+        "cumprod": "cumulative_prod",
+    }

     operation_names = {}
     operation_names.update(reduction_names)
@@ -47,15 +47,15 @@ class Adafactor(Optimizer):
             raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            beta2_decay=beta2_decay,
-            eps=eps,
-            d=d,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-        )
+        defaults = {
+            "lr": lr,
+            "beta2_decay": beta2_decay,
+            "eps": eps,
+            "d": d,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -50,16 +50,16 @@ class Adadelta(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            rho=rho,
-            eps=eps,
-            weight_decay=weight_decay,
-            maximize=maximize,
-            capturable=capturable,
-            foreach=foreach,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "rho": rho,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "maximize": maximize,
+            "capturable": capturable,
+            "foreach": foreach,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -54,17 +54,17 @@ class Adagrad(Optimizer):
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps}")

-        defaults = dict(
-            lr=lr,
-            lr_decay=lr_decay,
-            eps=eps,
-            weight_decay=weight_decay,
-            initial_accumulator_value=initial_accumulator_value,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            fused=fused,
-        )
+        defaults = {
+            "lr": lr,
+            "lr_decay": lr_decay,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "initial_accumulator_value": initial_accumulator_value,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "fused": fused,
+        }
         super().__init__(params, defaults)

         if fused:
@@ -85,19 +85,19 @@ class Adam(Optimizer):
         if betas[1].numel() != 1:
             raise ValueError("Tensor betas[1] must be 1-element")

-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsgrad=amsgrad,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            differentiable=differentiable,
-            fused=fused,
-            decoupled_weight_decay=decoupled_weight_decay,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "amsgrad": amsgrad,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "differentiable": differentiable,
+            "fused": fused,
+            "decoupled_weight_decay": decoupled_weight_decay,
+        }
         super().__init__(params, defaults)

         if fused:
@@ -53,16 +53,16 @@ class Adamax(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -47,17 +47,17 @@ class ASGD(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            lambd=lambd,
-            alpha=alpha,
-            t0=t0,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "lambd": lambd,
+            "alpha": alpha,
+            "t0": t0,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -231,15 +231,15 @@ class LBFGS(Optimizer):
             raise ValueError(f"Invalid learning rate: {lr}")
         if max_eval is None:
             max_eval = max_iter * 5 // 4
-        defaults = dict(
-            lr=lr,
-            max_iter=max_iter,
-            max_eval=max_eval,
-            tolerance_grad=tolerance_grad,
-            tolerance_change=tolerance_change,
-            history_size=history_size,
-            line_search_fn=line_search_fn,
-        )
+        defaults = {
+            "lr": lr,
+            "max_iter": max_iter,
+            "max_eval": max_eval,
+            "tolerance_grad": tolerance_grad,
+            "tolerance_change": tolerance_change,
+            "history_size": history_size,
+            "line_search_fn": line_search_fn,
+        }
         super().__init__(params, defaults)

         if len(self.param_groups) != 1:
@@ -59,18 +59,18 @@ class NAdam(Optimizer): # noqa: D101
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if not 0.0 <= momentum_decay:
             raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            momentum_decay=momentum_decay,
-            decoupled_weight_decay=decoupled_weight_decay,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "momentum_decay": momentum_decay,
+            "decoupled_weight_decay": decoupled_weight_decay,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -56,17 +56,17 @@ class RAdam(Optimizer): # noqa: D101
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            decoupled_weight_decay=decoupled_weight_decay,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "decoupled_weight_decay": decoupled_weight_decay,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -55,18 +55,18 @@ class RMSprop(Optimizer): # noqa: D101
         if not 0.0 <= alpha:
             raise ValueError(f"Invalid alpha value: {alpha}")

-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            alpha=alpha,
-            eps=eps,
-            centered=centered,
-            weight_decay=weight_decay,
-            capturable=capturable,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "alpha": alpha,
+            "eps": eps,
+            "centered": centered,
+            "weight_decay": weight_decay,
+            "capturable": capturable,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -47,15 +47,15 @@ class Rprop(Optimizer): # noqa: D101
         if not 0.0 < etas[0] < 1.0 < etas[1]:
             raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

-        defaults = dict(
-            lr=lr,
-            etas=etas,
-            step_sizes=step_sizes,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "etas": etas,
+            "step_sizes": step_sizes,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -49,17 +49,17 @@ class SGD(Optimizer): # noqa: D101
         if weight_decay < 0.0:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            maximize=maximize,
-            foreach=foreach,
-            differentiable=differentiable,
-            fused=fused,
-        )
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "dampening": dampening,
+            "weight_decay": weight_decay,
+            "nesterov": nesterov,
+            "maximize": maximize,
+            "foreach": foreach,
+            "differentiable": differentiable,
+            "fused": fused,
+        }
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)
@@ -31,7 +31,12 @@ class SparseAdam(Optimizer):
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

-        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "maximize": maximize,
+        }
         super().__init__(params, defaults)

         sparse_params = []
@@ -1109,15 +1109,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
             return res
         return None

-    sys_info = dict(
-        protocol_version=PROTOCOL_VERSION,
-        little_endian=sys.byteorder == "little",
-        type_sizes=dict(
-            short=SHORT_SIZE,
-            int=INT_SIZE,
-            long=LONG_SIZE,
-        ),
-    )
+    sys_info = {
+        "protocol_version": PROTOCOL_VERSION,
+        "little_endian": sys.byteorder == "little",
+        "type_sizes": {
+            "short": SHORT_SIZE,
+            "int": INT_SIZE,
+            "long": LONG_SIZE,
+        },
+    }

     pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
     pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
@@ -598,7 +598,10 @@ def as_sparse_gradcheck(gradcheck):
                    and obj.requires_grad
                    and obj.layout in sparse_layouts
                ):
-                    d = dict(layout=obj.layout, shape=obj.shape)
+                    d = {
+                        "layout": obj.layout,
+                        "shape": obj.shape,
+                    }
                     if not masked:
                         # Materialize unspecified elements with zero values
                         batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
@@ -213,13 +213,14 @@ def get_model_info(
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
-            zip_files.append(dict(
-                filename=zi.filename,
-                compression=zi.compress_type,
-                compressed_size=zi.compress_size,
-                file_size=zi.file_size,
-            ))
+            zip_files.append(
+                {
+                    "filename": zi.filename,
+                    "compression": zi.compress_type,
+                    "compressed_size": zi.compress_size,
+                    "file_size": zi.file_size,
+                }
+            )

        assert path_prefix is not None
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()
@@ -332,18 +333,20 @@ def get_model_info(
                    continue
                extra_pickles[zi.filename] = contents

-    return {"model": dict(
-        title=title,
-        file_size=file_size,
-        version=version,
-        zip_files=zip_files,
-        interned_strings=list(interned_strings),
-        code_files=code_files,
-        model_data=model_data,
-        constants=constants,
-        extra_files_jsons=extra_files_jsons,
-        extra_pickles=extra_pickles,
-    )}
+    return {
+        "model": {
+            "title": title,
+            "file_size": file_size,
+            "version": version,
+            "zip_files": zip_files,
+            "interned_strings": list(interned_strings),
+            "code_files": code_files,
+            "model_data": model_data,
+            "constants": constants,
+            "extra_files_jsons": extra_files_jsons,
+            "extra_pickles": extra_pickles,
+        }
+    }


 def get_inline_skeleton():