[BE][Easy][4/19] enforce style for empty lines in import segments in functorch/ (#129755)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.
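
For context, the enforced layout looks roughly like the sketch below. This is a hypothetical module, not a file from this diff, sketching the intended style: imports are grouped into segments (standard library first, then torch/functorch code), segments are divided by a blank line, and the import block is set off from the first module-level statement.

```python
# Hypothetical module (not from this PR) illustrating the import-segment style:
# a blank line between the standard-library segment and the torch/functorch
# segment, and the import block separated from module-level code.
import logging
import os

import torch
from functorch import make_fx

logging.basicConfig(level=logging.INFO)
torch.manual_seed(int(os.environ.get("SEED", "0")))
graph_module = make_fx(torch.sin)(torch.randn(3))  # trace a tiny function
logging.info("traced graph:\n%s", graph_module.graph)
```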

You can review these PRs via the command below; `--ignore-all-space` and `--ignore-blank-lines` hide the whitespace-only changes, so any remaining diff is substantive:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129755
Approved by: https://github.com/zou3519
ghstack dependencies: #129752
Authored by Xuehai Pan on 2024-07-16 11:40:08 +08:00, committed by PyTorch MergeBot
parent a085acd7d6
commit 740fb22966
36 changed files with 48 additions and 27 deletions

View File

@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch._functorch.deprecated import (
combine_state_for_ensemble,
functionalize,
@@ -26,13 +25,15 @@ from torch._functorch.make_functional import (
FunctionalModuleWithBuffers,
)
# Was never documented
from torch._functorch.python_key import make_fx
# Top-level APIs. Please think carefully before adding something to the
# top-level namespace:
# - private helper functions should go into torch._functorch
# - very experimental things should go into functorch.experimental
# - compilation related things should go into functorch.compile
# Was never documented
from torch._functorch.python_key import make_fx
__version__ = torch.__version__

View File

@@ -1,13 +1,13 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import pandas as pd
from torch._functorch.benchmark_utils import compute_utilization
# process the chrome traces output by the pytorch profiler
# require the json input file's name to be in format {model_name}_chrome_trace_*.json
# the runtimes file should have format (model_name, runtime)

View File

@@ -1,8 +1,6 @@
import torch
import torch.fx as fx
from functorch import make_fx
from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity

View File

@@ -5,9 +5,9 @@ import numpy as np
import pandas as pd
import torch
from functorch.compile import pointwise_operator
WRITE_CSV = False
CUDA = False
SIZES = [1, 512, 8192]

View File

@@ -6,9 +6,9 @@ from opacus.utils.module_modification import convert_batchnorm_modules
import torch
import torch.nn as nn
from functorch import grad, make_functional, vmap
device = "cuda"
batch_size = 128
torch.manual_seed(0)

View File

@@ -4,9 +4,9 @@ import sys
import time
import torch
from functorch import pointwise_operator
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)

View File

@@ -1,6 +1,7 @@
import matplotlib.pyplot as plt
import pandas
df = pandas.read_csv("perf.csv")
ops = pandas.unique(df["operator"])

View File

@@ -3,12 +3,13 @@ import inspect
from typing import Sequence, Union
import functorch._C
import torch
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
@@ -23,6 +24,7 @@ class DimensionBindError(Exception):
from . import op_properties
# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)

View File

@@ -7,6 +7,7 @@ from contextlib import contextmanager
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
_enabled = False

View File

@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union
from . import DimList
_vmap_levels = []

View File

@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import torch
# pointwise operators can go through a faster pathway
tensor_magic_methods = ["add", ""]

View File

@@ -6,12 +6,13 @@
# reference python implementations for C ops
import torch
from functorch._C import dim as _C
from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map
DimList = _C.DimList
import operator
from functools import reduce
@@ -407,7 +408,6 @@ def t__getitem__(self, input):
# (keep track of whether we have to call super)
# * call super if needed
# * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
# this handles bool indexing handling, as well as some other simple cases.
is_simple = (

View File

@@ -6,6 +6,7 @@
from functorch._C import dim
tree_flatten = dim.tree_flatten

View File

@@ -14,6 +14,7 @@ from types import (
from functorch._C import dim as _C
_wrap_method = _C._wrap_method
FUNC_TYPES = (

View File

@@ -16,6 +16,7 @@ import os
import functorch
# import sys
# source code directory, relative to this file, for sphinx-autobuild
@@ -27,6 +28,7 @@ RELEASE = os.environ.get("RELEASE", False)
import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------
# Required version of sphinx is set from docs/requirements.txt
@@ -274,11 +276,11 @@ import sphinx.ext.doctest
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx import addnodes
from sphinx.util.docfields import TypedField
# Without this, doctest adds any example with a `>>>` as a test
doctest_test_doctest_blocks = ""
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS

View File

@@ -1,3 +1,4 @@
from .rearrange import rearrange
__all__ = ["rearrange"]

View File

@@ -28,6 +28,7 @@ import keyword
import warnings
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated

View File

@@ -4,8 +4,8 @@ import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union
import torch
from functorch._C import dim as _C
from ._parsing import (
_ellipsis,
AnonymousAxis,
@@ -14,6 +14,7 @@ from ._parsing import (
validate_rearrange_expressions,
)
__all__ = ["rearrange"]
dims = _C.dims

View File

@@ -2,9 +2,9 @@ import time
import torch
import torch.utils
from functorch.compile import aot_function, tvm_compile
a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4)

View File

@@ -2,7 +2,6 @@ import timeit
import torch
import torch.nn as nn
from functorch.compile import compiled_module, tvm_compile

View File

@@ -8,10 +8,10 @@ import time
import torch
import torch.nn as nn
from functorch import make_functional
from functorch.compile import nnc_jit
torch._C._jit_override_can_fuse_on_cpu(True)

View File

@@ -7,7 +7,6 @@
import time
import torch
from functorch import grad, make_fx
from functorch.compile import nnc_jit

View File

@@ -21,9 +21,9 @@ import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.func import functional_call, grad_and_value, vmap
logging.basicConfig(
format="%(asctime)s:%(levelname)s:%(message)s",
datefmt="%m/%d/%Y %H:%M:%S",

View File

@@ -6,6 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap
# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#

View File

@@ -7,6 +7,7 @@ from torch import nn
from torch.func import jacrev, vmap
from torch.nn.functional import mse_loss
sigma = 0.5
epsilon = 4.0

View File

@@ -34,7 +34,6 @@ import higher
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot
@@ -43,6 +42,7 @@ import torch.nn.functional as F
import torch.optim as optim
from torch import nn
mpl.use("Agg")
plt.style.use("bmh")

View File

@@ -33,17 +33,16 @@ import time
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot
import torch
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers
from torch import nn
mpl.use("Agg")
plt.style.use("bmh")

View File

@@ -34,7 +34,6 @@ import time
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot
@@ -44,6 +43,7 @@ import torch.optim as optim
from torch import nn
from torch.func import functional_call, grad, vmap
mpl.use("Agg")
plt.style.use("bmh")

View File

@@ -11,6 +11,7 @@ import numpy as np
import torch
from torch.nn import functional as F
mpl.use("Agg")

View File

@@ -12,6 +12,7 @@ import torch
from torch.func import grad, vmap
from torch.nn import functional as F
mpl.use("Agg")

View File

@@ -9,11 +9,11 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from functorch import grad, make_functional, vmap
from torch import nn
from torch.nn import functional as F
mpl.use("Agg")

View File

@@ -1,6 +1,5 @@
from torch import cond # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
from torch._higher_order_ops.map import ( # noqa: F401
_stack_pytree,
_unstack_pytree,

View File

@@ -20,6 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
@@ -85,6 +86,7 @@ predictions2 = [model(minibatch) for model in models]
# stateless version of the model (fmodel) and stacked parameters and buffers.
from functorch import combine_state_for_ensemble
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]
@@ -97,6 +99,7 @@ print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(
predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6

View File

@@ -13,6 +13,7 @@ from functools import partial
import torch
import torch.nn.functional as F
torch.manual_seed(0)
@@ -54,6 +55,7 @@ jacobian = compute_jac(xp)
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vjp, vmap
_, vjp_fn = vjp(partial(predict, weight, bias), x)
(ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)
@@ -69,6 +71,7 @@ assert torch.allclose(ft_jacobian, jacobian)
# respect to.
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)
@@ -78,6 +81,7 @@ assert torch.allclose(ft_jacobian, jacobian)
# eliminate overhead and give better utilization of your hardware.
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(without_vmap.timeit(500))
@@ -108,6 +112,7 @@ ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
# it column-by-column. The Jacobian matrix has M rows and N columns.
from functorch import jacfwd, jacrev
# Benchmark with more inputs than outputs
Din = 32
Dout = 2048
@@ -144,6 +149,7 @@ print(f"jacrev time: {using_bwd.timeit(500)}")
# ``jacrev(jacrev(f))`` instead to compute hessians.
from functorch import hessian
# # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
# hess0 = hessian(predict, argnums=2)(weight, bias, x)
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)

View File

@@ -13,6 +13,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
@@ -94,6 +95,7 @@ print(per_sample_grads[0].shape)
# ``functorch.make_functional_with_buffers``.
from functorch import grad, make_functional_with_buffers, vmap
fmodel, params, buffers = make_functional_with_buffers(model)

View File

@@ -33,7 +33,6 @@ ISORT_SKIPLIST = re.compile(
# .github/**
# benchmarks/**
# functorch/**
"functorch/**",
# tools/**
# torchgen/**
# test/**
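
The final hunk above drops `"functorch/**"` from the lint adapter's `ISORT_SKIPLIST`, which is what enables the import-style check for `functorch/`. Below is a rough, hypothetical sketch of how a glob skiplist like this can be compiled into one regex and consulted; the helper and the example entries are invented here, and the real adapter may be built differently.

```python
import fnmatch
import re

# Hypothetical reconstruction of a glob-based skiplist: each fnmatch-style
# pattern is translated to a regex, and the alternation is compiled once.
SKIP_GLOBS = [
    "caffe2/**",  # illustrative entries only; the real list lives in the adapter
    "test/**",
    # "functorch/**" used to be listed; removing it enables the check for functorch/
]
ISORT_SKIPLIST = re.compile("|".join(fnmatch.translate(glob) for glob in SKIP_GLOBS))


def should_skip(path: str) -> bool:
    """Return True if the file path matches any skiplist pattern."""
    return bool(ISORT_SKIPLIST.match(path))


print(should_skip("test/test_ops.py"))           # True with this illustrative list
print(should_skip("functorch/dim/__init__.py"))  # False: functorch/** is not listed
```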