[BE][PYFMT] migrate PYFMT for torch/[a-c]*/ to ruff format (#144554)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144554
Approved by: https://github.com/soulitzer

commit 3fd84a8592
parent d56f11a1f2
committed by PyTorch MergeBot
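
The hunks below show two kinds of mechanical rewrites: blank-line and stub normalization inside docstring code examples, and re-wrapping of long assert messages and annotated assignments. As a rough sketch of the assert change (a standalone illustrative function that reuses the names from the CacheArtifactFactory.register hunk further down, not the real method):

# Illustrative sketch only -- a standalone function using the same names as the
# CacheArtifactFactory.register hunk below, not the actual method.
def register_sketch(artifact_cls, registry):
    artifact_type_key = artifact_cls.type()

    # Before (Black-style wrapping): the condition is parenthesized and split
    # across lines, with the message dangling after the closing parenthesis.
    assert (
        artifact_cls.type() not in registry
    ), f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"

    # After (ruff format): the condition stays on one line and the long message
    # is parenthesized instead. Behavior is identical; only the layout changes.
    assert artifact_cls.type() not in registry, (
        f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
    )

    registry[artifact_type_key] = artifact_cls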
@@ -123,6 +123,7 @@ def allow_in_graph(fn):

         torch.compiler.allow_in_graph(my_custom_function)

+
         @torch.compile(...)
         def fn(x):
             x = torch.add(x, 1)
@@ -130,6 +131,7 @@ def allow_in_graph(fn):
             x = torch.add(x, 1)
             return x

+
         fn(...)

     Will capture a single graph containing ``my_custom_function()``.
@@ -260,14 +262,15 @@ def set_stance(
     .. code-block:: python

         @torch.compile
-        def foo(x):
-            ...
+        def foo(x): ...

+
         @torch.compiler.set_stance("force_eager")
         def bar():
             # will not be compiled
             foo(...)

+
         bar()

         with torch.compiler.set_stance("force_eager"):
@@ -375,6 +378,7 @@ def cudagraph_mark_step_begin():
         def rand_foo():
             return torch.rand([4], device="cuda")

+
         for _ in range(5):
             torch.compiler.cudagraph_mark_step_begin()
             rand_foo() + rand_foo()
@@ -72,9 +72,9 @@ class CacheArtifactFactory:
     @classmethod
     def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]:
         artifact_type_key = artifact_cls.type()
-        assert (
-            artifact_cls.type() not in cls._artifact_types
-        ), f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
+        assert artifact_cls.type() not in cls._artifact_types, (
+            f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
+        )
         cls._artifact_types[artifact_type_key] = artifact_cls
         setattr(
             CacheInfo,
@@ -85,9 +85,9 @@ class CacheArtifactFactory:

     @classmethod
     def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]:
-        assert (
-            artifact_type_key in cls._artifact_types
-        ), f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory"
+        assert artifact_type_key in cls._artifact_types, (
+            f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory"
+        )
         return cls._artifact_types[artifact_type_key]

     @classmethod
@@ -194,9 +194,9 @@ class CacheArtifactManager:
     # When serialize() is called, artifacts are transferred from _cache_artifacts to
     # internal data structure of the _serializer
     # This allows us to only pay the cost of serialization if serialize() is called
-    _serializer: AppendingByteSerializer[
-        tuple[str, list[CacheArtifact]]
-    ] = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
+    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
+        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
+    )
     _cache_info: CacheInfo = CacheInfo()

     @classmethod
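
The last hunk applies the same idea to an annotated assignment: instead of splitting the subscripted annotation across lines, the formatter keeps the annotation whole and parenthesizes the right-hand side. A self-contained sketch with a stand-in class (ByteSerializerSketch is invented here purely so the snippet runs; the real AppendingByteSerializer lives elsewhere in the caching code):

from typing import Generic, TypeVar

T = TypeVar("T")


class ByteSerializerSketch(Generic[T]):
    """Stand-in for AppendingByteSerializer; only here to make the sketch runnable."""

    def __init__(self, serialize_fn=None) -> None:
        self.serialize_fn = serialize_fn


# After ruff format: the annotation stays on one line and the value expression
# is wrapped in parentheses instead of the annotation being split.
_serializer_sketch: ByteSerializerSketch[tuple[str, list[bytes]]] = (
    ByteSerializerSketch(serialize_fn=None)
)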