mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE][Easy][4/19] enforce style for empty lines in import segments in functorch/
(#129755)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129755 Approved by: https://github.com/zou3519 ghstack dependencies: #129752
This commit is contained in:
committed by
PyTorch MergeBot
parent
a085acd7d6
commit
740fb22966
@ -4,7 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
|
||||
from torch._functorch.deprecated import (
|
||||
combine_state_for_ensemble,
|
||||
functionalize,
|
||||
@ -26,13 +25,15 @@ from torch._functorch.make_functional import (
|
||||
FunctionalModuleWithBuffers,
|
||||
)
|
||||
|
||||
# Was never documented
|
||||
from torch._functorch.python_key import make_fx
|
||||
|
||||
|
||||
# Top-level APIs. Please think carefully before adding something to the
|
||||
# top-level namespace:
|
||||
# - private helper functions should go into torch._functorch
|
||||
# - very experimental things should go into functorch.experimental
|
||||
# - compilation related things should go into functorch.compile
|
||||
|
||||
# Was never documented
|
||||
from torch._functorch.python_key import make_fx
|
||||
|
||||
__version__ = torch.__version__
|
||||
|
@ -1,13 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from torch._functorch.benchmark_utils import compute_utilization
|
||||
|
||||
|
||||
# process the chrome traces output by the pytorch profiler
|
||||
# require the json input file's name to be in format {model_name}_chrome_trace_*.json
|
||||
# the runtimes file should have format (model_name, runtime)
|
||||
|
@ -1,8 +1,6 @@
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
from functorch import make_fx
|
||||
|
||||
from torch._functorch.compile_utils import fx_graph_cse
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
|
@ -5,9 +5,9 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
|
||||
from functorch.compile import pointwise_operator
|
||||
|
||||
|
||||
WRITE_CSV = False
|
||||
CUDA = False
|
||||
SIZES = [1, 512, 8192]
|
||||
|
@ -6,9 +6,9 @@ from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from functorch import grad, make_functional, vmap
|
||||
|
||||
|
||||
device = "cuda"
|
||||
batch_size = 128
|
||||
torch.manual_seed(0)
|
||||
|
@ -4,9 +4,9 @@ import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from functorch import pointwise_operator
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch._C._debug_set_fusion_group_inlining(False)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas
|
||||
|
||||
|
||||
df = pandas.read_csv("perf.csv")
|
||||
|
||||
ops = pandas.unique(df["operator"])
|
||||
|
@ -3,12 +3,13 @@ import inspect
|
||||
from typing import Sequence, Union
|
||||
|
||||
import functorch._C
|
||||
|
||||
import torch
|
||||
from functorch._C import dim as _C
|
||||
|
||||
from .tree_map import tree_flatten, tree_map
|
||||
from .wrap_type import wrap_type
|
||||
|
||||
|
||||
_C._patch_tensor_class()
|
||||
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
|
||||
|
||||
@ -23,6 +24,7 @@ class DimensionBindError(Exception):
|
||||
|
||||
from . import op_properties
|
||||
|
||||
|
||||
# use dict to avoid writing C++ bindings for set
|
||||
pointwise = dict.fromkeys(op_properties.pointwise, True)
|
||||
|
||||
|
@ -7,6 +7,7 @@ from contextlib import contextmanager
|
||||
|
||||
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
|
||||
|
||||
|
||||
_enabled = False
|
||||
|
||||
|
||||
|
@ -5,12 +5,12 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import dis
|
||||
import inspect
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
from . import DimList
|
||||
|
||||
|
||||
_vmap_levels = []
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
|
||||
|
||||
# pointwise operators can go through a faster pathway
|
||||
|
||||
tensor_magic_methods = ["add", ""]
|
||||
|
@ -6,12 +6,13 @@
|
||||
|
||||
# reference python implementations for C ops
|
||||
import torch
|
||||
|
||||
from functorch._C import dim as _C
|
||||
|
||||
from . import op_properties
|
||||
from .batch_tensor import _enable_layers
|
||||
from .tree_map import tree_flatten, tree_map
|
||||
|
||||
|
||||
DimList = _C.DimList
|
||||
import operator
|
||||
from functools import reduce
|
||||
@ -407,7 +408,6 @@ def t__getitem__(self, input):
|
||||
# (keep track of whether we have to call super)
|
||||
# * call super if needed
|
||||
# * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
|
||||
|
||||
# this handles bool indexing handling, as well as some other simple cases.
|
||||
|
||||
is_simple = (
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
from functorch._C import dim
|
||||
|
||||
|
||||
tree_flatten = dim.tree_flatten
|
||||
|
||||
|
||||
|
@ -14,6 +14,7 @@ from types import (
|
||||
|
||||
from functorch._C import dim as _C
|
||||
|
||||
|
||||
_wrap_method = _C._wrap_method
|
||||
|
||||
FUNC_TYPES = (
|
||||
|
@ -16,6 +16,7 @@ import os
|
||||
|
||||
import functorch
|
||||
|
||||
|
||||
# import sys
|
||||
|
||||
# source code directory, relative to this file, for sphinx-autobuild
|
||||
@ -27,6 +28,7 @@ RELEASE = os.environ.get("RELEASE", False)
|
||||
|
||||
import pytorch_sphinx_theme
|
||||
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
# Required version of sphinx is set from docs/requirements.txt
|
||||
@ -274,11 +276,11 @@ import sphinx.ext.doctest
|
||||
|
||||
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
|
||||
# See http://stackoverflow.com/a/41184353/3343043
|
||||
|
||||
from docutils import nodes
|
||||
from sphinx import addnodes
|
||||
from sphinx.util.docfields import TypedField
|
||||
|
||||
|
||||
# Without this, doctest adds any example with a `>>>` as a test
|
||||
doctest_test_doctest_blocks = ""
|
||||
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .rearrange import rearrange
|
||||
|
||||
|
||||
__all__ = ["rearrange"]
|
||||
|
@ -28,6 +28,7 @@ import keyword
|
||||
import warnings
|
||||
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
|
||||
_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
|
||||
|
||||
|
||||
|
@ -4,8 +4,8 @@ import functools
|
||||
from typing import Callable, Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from functorch._C import dim as _C
|
||||
|
||||
from ._parsing import (
|
||||
_ellipsis,
|
||||
AnonymousAxis,
|
||||
@ -14,6 +14,7 @@ from ._parsing import (
|
||||
validate_rearrange_expressions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["rearrange"]
|
||||
|
||||
dims = _C.dims
|
||||
|
@ -2,9 +2,9 @@ import time
|
||||
|
||||
import torch
|
||||
import torch.utils
|
||||
|
||||
from functorch.compile import aot_function, tvm_compile
|
||||
|
||||
|
||||
a = torch.randn(2000, 1, 4, requires_grad=True)
|
||||
b = torch.randn(1, 2000, 4)
|
||||
|
||||
|
@ -2,7 +2,6 @@ import timeit
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from functorch.compile import compiled_module, tvm_compile
|
||||
|
||||
|
||||
|
@ -8,10 +8,10 @@ import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from functorch import make_functional
|
||||
from functorch.compile import nnc_jit
|
||||
|
||||
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
|
||||
|
||||
|
@ -7,7 +7,6 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from functorch import grad, make_fx
|
||||
from functorch.compile import nnc_jit
|
||||
|
||||
|
@ -21,9 +21,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torch.func import functional_call, grad_and_value, vmap
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s:%(levelname)s:%(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
|
@ -6,6 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.func import functional_call, grad_and_value, stack_module_state, vmap
|
||||
|
||||
|
||||
# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
|
||||
# tutorial on Model Ensembling with JAX by Will Whitney.
|
||||
#
|
||||
|
@ -7,6 +7,7 @@ from torch import nn
|
||||
from torch.func import jacrev, vmap
|
||||
from torch.nn.functional import mse_loss
|
||||
|
||||
|
||||
sigma = 0.5
|
||||
epsilon = 4.0
|
||||
|
||||
|
@ -34,7 +34,6 @@ import higher
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
|
||||
@ -43,6 +42,7 @@ import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch import nn
|
||||
|
||||
|
||||
mpl.use("Agg")
|
||||
plt.style.use("bmh")
|
||||
|
||||
|
@ -33,17 +33,16 @@ import time
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
from functorch import make_functional_with_buffers
|
||||
from torch import nn
|
||||
|
||||
|
||||
mpl.use("Agg")
|
||||
plt.style.use("bmh")
|
||||
|
||||
|
@ -34,7 +34,6 @@ import time
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
from support.omniglot_loaders import OmniglotNShot
|
||||
|
||||
@ -44,6 +43,7 @@ import torch.optim as optim
|
||||
from torch import nn
|
||||
from torch.func import functional_call, grad, vmap
|
||||
|
||||
|
||||
mpl.use("Agg")
|
||||
plt.style.use("bmh")
|
||||
|
||||
|
@ -11,6 +11,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
|
||||
|
@ -12,6 +12,7 @@ import torch
|
||||
from torch.func import grad, vmap
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
|
||||
|
@ -9,11 +9,11 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from functorch import grad, make_functional, vmap
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
mpl.use("Agg")
|
||||
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from torch import cond # noqa: F401
|
||||
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
|
||||
|
||||
from torch._higher_order_ops.map import ( # noqa: F401
|
||||
_stack_pytree,
|
||||
_unstack_pytree,
|
||||
|
@ -20,6 +20,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@ -85,6 +86,7 @@ predictions2 = [model(minibatch) for model in models]
|
||||
# stateless version of the model (fmodel) and stacked parameters and buffers.
|
||||
from functorch import combine_state_for_ensemble
|
||||
|
||||
|
||||
fmodel, params, buffers = combine_state_for_ensemble(models)
|
||||
[p.requires_grad_() for p in params]
|
||||
|
||||
@ -97,6 +99,7 @@ print([p.size(0) for p in params])
|
||||
assert minibatches.shape == (num_models, 64, 1, 28, 28)
|
||||
from functorch import vmap
|
||||
|
||||
|
||||
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
|
||||
assert torch.allclose(
|
||||
predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
|
||||
|
@ -13,6 +13,7 @@ from functools import partial
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@ -54,6 +55,7 @@ jacobian = compute_jac(xp)
|
||||
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
|
||||
from functorch import vjp, vmap
|
||||
|
||||
|
||||
_, vjp_fn = vjp(partial(predict, weight, bias), x)
|
||||
(ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
|
||||
assert torch.allclose(ft_jacobian, jacobian)
|
||||
@ -69,6 +71,7 @@ assert torch.allclose(ft_jacobian, jacobian)
|
||||
# respect to.
|
||||
from functorch import jacrev
|
||||
|
||||
|
||||
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
|
||||
assert torch.allclose(ft_jacobian, jacobian)
|
||||
|
||||
@ -78,6 +81,7 @@ assert torch.allclose(ft_jacobian, jacobian)
|
||||
# eliminate overhead and give better utilization of your hardware.
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
|
||||
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
|
||||
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
|
||||
print(without_vmap.timeit(500))
|
||||
@ -108,6 +112,7 @@ ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
|
||||
# it column-by-column. The Jacobian matrix has M rows and N columns.
|
||||
from functorch import jacfwd, jacrev
|
||||
|
||||
|
||||
# Benchmark with more inputs than outputs
|
||||
Din = 32
|
||||
Dout = 2048
|
||||
@ -144,6 +149,7 @@ print(f"jacrev time: {using_bwd.timeit(500)}")
|
||||
# ``jacrev(jacrev(f))`` instead to compute hessians.
|
||||
from functorch import hessian
|
||||
|
||||
|
||||
# # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
|
||||
# hess0 = hessian(predict, argnums=2)(weight, bias, x)
|
||||
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
|
||||
|
@ -13,6 +13,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@ -94,6 +95,7 @@ print(per_sample_grads[0].shape)
|
||||
# ``functorch.make_functional_with_buffers``.
|
||||
from functorch import grad, make_functional_with_buffers, vmap
|
||||
|
||||
|
||||
fmodel, params, buffers = make_functional_with_buffers(model)
|
||||
|
||||
|
||||
|
@ -33,7 +33,6 @@ ISORT_SKIPLIST = re.compile(
|
||||
# .github/**
|
||||
# benchmarks/**
|
||||
# functorch/**
|
||||
"functorch/**",
|
||||
# tools/**
|
||||
# torchgen/**
|
||||
# test/**
|
||||
|
Reference in New Issue
Block a user