Apply UFMT to all non test/torch files (#106205)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106205
Approved by: https://github.com/albanD
Edward Z. Yang
2023-07-28 16:04:39 -04:00
committed by PyTorch MergeBot
parent 1163800d0f
commit e6ec0efaf8
51 changed files with 1671 additions and 941 deletions
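For context, ufmt chains usort (import sorting) and black (code formatting), so the hunks below are mechanical reformats rather than behavior changes: imports are alphabetized and grouped, single-quoted strings become double-quoted, definitions and calls that overflow the line limit are exploded one argument per line with a trailing comma, and top-level definitions are separated by two blank lines. A minimal before/after sketch of the same transformations, using a hypothetical write_report function that is not taken from this diff:

# before ufmt
import sys
import re
import pathlib

def write_report(kernels, family_name, impl_file, autogen_dir, disable_def=None, verbose=False):
    print(f'writing {family_name}')

# after ufmt (usort orders the imports; black reformats the rest)
import pathlib
import re
import sys


def write_report(
    kernels,
    family_name,
    impl_file,
    autogen_dir,
    disable_def=None,
    verbose=False,
):
    print(f"writing {family_name}")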


@@ -908,62 +908,6 @@ exclude_patterns = [
'third_party/**/*.pyi',
# These files are all grandfathered in, feel free to remove from this list
# as necessary
'aten/src/ATen/function_wrapper.py',
'aten/src/ATen/native/quantized/cpu/qnnpack/configure.py',
'aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py',
'aten/src/ATen/native/quantized/cpu/qnnpack/generate-wrapper.py',
'aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py',
'aten/src/ATen/nnapi/codegen.py',
'functorch/__init__.py',
'functorch/_src/__init__.py',
'functorch/_src/aot_autograd/__init__.py',
'functorch/_src/eager_transforms/__init__.py',
'functorch/_src/make_functional/__init__.py',
'functorch/_src/vmap/__init__.py',
'functorch/benchmarks/chrome_trace_parser.py',
'functorch/benchmarks/cse.py',
'functorch/benchmarks/operator_authoring.py',
'functorch/benchmarks/per_sample_grads.py',
'functorch/benchmarks/pointwise_scorecard.py',
'functorch/benchmarks/process_scorecard.py',
'functorch/compile/__init__.py',
'functorch/dim/__init__.py',
'functorch/dim/batch_tensor.py',
'functorch/dim/delayed_mul_tensor.py',
'functorch/dim/dim.py',
'functorch/dim/magic_trace.py',
'functorch/dim/op_properties.py',
'functorch/dim/reference.py',
'functorch/dim/tree_map.py',
'functorch/dim/wrap_type.py',
'functorch/docs/source/conf.py',
'functorch/einops/__init__.py',
'functorch/einops/_parsing.py',
'functorch/einops/rearrange.py',
'functorch/examples/compilation/eager_fusion.py',
'functorch/examples/compilation/fuse_module.py',
'functorch/examples/compilation/linear_train.py',
'functorch/examples/compilation/simple_function.py',
'functorch/examples/dp_cifar10/cifar10_opacus.py',
'functorch/examples/dp_cifar10/cifar10_transforms.py',
'functorch/examples/ensembling/parallel_train.py',
'functorch/examples/lennard_jones/lennard_jones.py',
'functorch/examples/maml_omniglot/maml-omniglot-higher.py',
'functorch/examples/maml_omniglot/maml-omniglot-ptonly.py',
'functorch/examples/maml_omniglot/maml-omniglot-transforms.py',
'functorch/examples/maml_omniglot/support/omniglot_loaders.py',
'functorch/examples/maml_regression/evjang.py',
'functorch/examples/maml_regression/evjang_transforms.py',
'functorch/examples/maml_regression/evjang_transforms_module.py',
'functorch/experimental/__init__.py',
'functorch/experimental/_cond.py',
'functorch/experimental/_map.py',
'functorch/experimental/control_flow.py',
'functorch/experimental/ops.py',
'functorch/notebooks/_src/plot_ensembling.py',
'functorch/notebooks/_src/plot_jacobians_and_hessians.py',
'functorch/notebooks/_src/plot_per_sample_gradients.py',
'functorch/op_analysis/gen_data.py',
'test/_nvfuser/__init__.py',
'test/_nvfuser/test_dynamo.py',
'test/_nvfuser/test_python_frontend.py',
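Each entry deleted from this grandfather list is now covered by the UFMT lint. A hedged sketch of re-running that lint locally, assuming PyTorch's lintrunner setup (pip install lintrunner, then lintrunner init at the repo root); the flag names are my understanding of lintrunner's CLI, not something this commit defines:

import subprocess

# Check formatting for everything the UFMT linter covers; append "-a"
# to the argument list to apply fixes in place instead of just reporting.
subprocess.run(["lintrunner", "--take", "UFMT", "--all-files"], check=True)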


@@ -31,7 +31,6 @@ def main(args):
],
extra_include_dirs="src",
):
requantization_objects = [
build.cc("requantization/precise-scalar.c"),
build.cc("requantization/fp32-scalar.c"),
@@ -192,7 +191,6 @@ def main(args):
},
extra_include_dirs=["src", "test"],
):
build.unittest("hgemm-test", build.cxx("hgemm.cc"))
build.unittest("q8avgpool-test", build.cxx("q8avgpool.cc"))
build.unittest("q8conv-test", build.cxx("q8conv.cc"))
@@ -252,7 +250,6 @@ def main(args):
isa=benchmark_isa,
extra_include_dirs="src",
):
build.benchmark("add-bench", build.cxx("add.cc"))
build.benchmark("average-pooling-bench", build.cxx("average-pooling.cc"))
build.benchmark("channel-shuffle-bench", build.cxx("channel-shuffle.cc"))


@@ -7,6 +7,7 @@
# LICENSE file in the root directory of this source tree.
import confu
parser = confu.standard_parser("clog configuration script")
@@ -19,13 +20,16 @@ def main(args):
with build.options(source_dir="src", extra_include_dirs="src"):
build.static_library("clog", build.cc("clog.c"))
with build.options(source_dir="test", deps={
(build, build.deps.googletest): all,
"log": build.target.is_android}):
with build.options(
source_dir="test",
deps={(build, build.deps.googletest): all, "log": build.target.is_android},
):
build.unittest("clog-test", build.cxx("clog.cc"))
return build
if __name__ == "__main__":
import sys
main(sys.argv[1:]).generate()


@@ -8,12 +8,12 @@
# Kernels are ordered (see `sort_index`), and when dispatching,
# we select the first kernel in the list that supports the inputs
import argparse
import collections
import itertools
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypeVar
import argparse
DTYPES = {
"f32": "float",
@@ -303,7 +303,11 @@ T = TypeVar("T", FwdKernel, BwdKernel)
def write_decl_impl(
kernels: List[T], family_name: str, impl_file: str, autogen_dir: Path, disable_def: str = None
kernels: List[T],
family_name: str,
impl_file: str,
autogen_dir: Path,
disable_def: str = None,
) -> None:
cpp_file_header = """/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
@@ -382,22 +386,28 @@ def main(output_dir: Optional[str]) -> None:
FwdKernel.get_all(),
"cutlassF",
impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>",
autogen_dir=output_dir
autogen_dir=output_dir,
)
write_decl_impl(
BwdKernel.get_all(),
"cutlassB",
impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>",
autogen_dir=output_dir
autogen_dir=output_dir,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='generate_kernels',
description='Generate the mem-eff kernels template instantiations')
prog="generate_kernels",
description="Generate the mem-eff kernels template instantiations",
)
# Set an optional output directory
parser.add_argument('-o', '--output_dir', required=False, help="Where to generate the kernels "
" will default to <ATen/native/transformers/cuda/mem_eff_attention/kernels/> ")
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="Where to generate the kernels "
" will default to <ATen/native/transformers/cuda/mem_eff_attention/kernels/> ",
)
args = parser.parse_args()
main(args.output_dir)
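An aside on the comment at the top of this file's hunk (kernels are ordered by sort_index, and dispatch selects the first kernel in the list that supports the inputs): schematically, that rule amounts to the sketch below. Kernel and its supports predicate are hypothetical stand-ins for the generated classes and checks, not code from this file.

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Kernel:
    sort_index: int
    supports: Callable[[Any], bool]  # stand-in for the generated support checks


def select_kernel(kernels, inputs):
    # "first supporting kernel wins", trying kernels in sort_index order
    for kernel in sorted(kernels, key=lambda k: k.sort_index):
        if kernel.supports(inputs):
            return kernel
    raise RuntimeError("no kernel supports these inputs")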


@@ -7,9 +7,9 @@ that opens libneuralnetworks.so with dlopen and finds the functions
we need with dlsym. We also generate a "check" wrapper that checks
return values and throws C++ exceptions on errors.
"""
import sys
import re
import pathlib
import re
import sys
import textwrap
@@ -36,39 +36,155 @@ PREFIX = """\
NNAPI_FUNCTIONS = [
("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), # noqa: B950
("int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device"), # noqa: B950
("int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name"), # noqa: B950
("int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version"), # noqa: B950
("int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel"), # noqa: B950
("int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps"), # noqa: B950
("int", "ANeuralNetworksCompilation_createForDevices", "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation"), # noqa: B950
("int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution"), # noqa: B950
("int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory"), # noqa: B950
("void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory"), # noqa: B950
("int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model"), # noqa: B950
(
"int",
"ANeuralNetworks_getDevice",
"uint32_t devIndex, ANeuralNetworksDevice** device",
), # noqa: B950
(
"int",
"ANeuralNetworksDevice_getName",
"const ANeuralNetworksDevice* device, const char** name",
), # noqa: B950
(
"int",
"ANeuralNetworksDevice_getVersion",
"const ANeuralNetworksDevice* device, const char** version",
), # noqa: B950
(
"int",
"ANeuralNetworksDevice_getFeatureLevel",
"const ANeuralNetworksDevice* device, int64_t* featureLevel",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_getSupportedOperationsForDevices",
" const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_createForDevices",
"ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_compute",
"ANeuralNetworksExecution* execution",
), # noqa: B950
(
"int",
"ANeuralNetworksMemory_createFromFd",
"size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory",
), # noqa: B950
(
"void",
"ANeuralNetworksMemory_free",
"ANeuralNetworksMemory* memory",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_create",
"ANeuralNetworksModel** model",
), # noqa: B950
("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), # noqa: B950
("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), # noqa: B950
("int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type"), # noqa: B950
("int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length"), # noqa: B950
("int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950
("int", "ANeuralNetworksModel_addOperation", "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950
("int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950
("int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow"), # noqa: B950
("int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation"), # noqa: B950
("void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation"), # noqa: B950
("int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference"), # noqa: B950
("int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation"), # noqa: B950
("int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution"), # noqa: B950
("void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution"), # noqa: B950
("int", "ANeuralNetworksExecution_setInput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length"), # noqa: B950
("int", "ANeuralNetworksExecution_setInputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950
("int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length"), # noqa: B950
("int", "ANeuralNetworksExecution_setOutputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950
("int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event"), # noqa: B950
(
"int",
"ANeuralNetworksModel_addOperand",
"ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_setOperandValue",
"ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_setOperandValueFromMemory",
"ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_addOperation",
"ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs", # noqa: B950
),
(
"int",
"ANeuralNetworksModel_identifyInputsAndOutputs",
"ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_relaxComputationFloat32toFloat16",
"ANeuralNetworksModel* model, bool allow",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_create",
"ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation",
), # noqa: B950
(
"void",
"ANeuralNetworksCompilation_free",
"ANeuralNetworksCompilation* compilation",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_setPreference",
"ANeuralNetworksCompilation* compilation, int32_t preference",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_finish",
"ANeuralNetworksCompilation* compilation",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_create",
"ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution",
), # noqa: B950
(
"void",
"ANeuralNetworksExecution_free",
"ANeuralNetworksExecution* execution",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_setInput",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_setInputFromMemory",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_setOutput",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_setOutputFromMemory",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_startCompute",
"ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event",
), # noqa: B950
("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), # noqa: B950
("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), # noqa: B950
("int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank"), # noqa: B950
("int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions"), # noqa: B950
(
"int",
"ANeuralNetworksExecution_getOutputOperandRank",
"ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_getOutputOperandDimensions",
"ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions",
), # noqa: B950
]
@@ -82,18 +198,26 @@ def main(argv):
struct_members.append(f" {ret}(*{short_name})({args});")
load_functions.append(f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");')
load_functions.append(f' check_nnapi_.{short_name} = check_{short_name};')
load_functions.append(
f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");'
)
load_functions.append(f" check_nnapi_.{short_name} = check_{short_name};")
call_args = "".join(re.findall(r"\w+(?:,|$)", args))
if ret == "void":
define_checks.append(textwrap.dedent(f"""\
define_checks.append(
textwrap.dedent(
f"""\
{ret} check_{short_name}({args}) {{
CAFFE_ENFORCE(nnapi_.{short_name});
nnapi_.{short_name}({call_args});
}}"""))
}}"""
)
)
if ret == "int":
define_checks.append(textwrap.dedent(f"""\
define_checks.append(
textwrap.dedent(
f"""\
{ret} check_{short_name}({args}) {{
CAFFE_ENFORCE(nnapi_.{short_name});
int ret = nnapi_.{short_name}({call_args});
@@ -103,13 +227,16 @@ def main(argv):
"{short_name}", "failed with error ", ret
);
return ret;
}}"""))
}}"""
)
)
out_dir = pathlib.Path(__file__).parent
(out_dir / "nnapi_wrapper.h").write_text(
PREFIX +
textwrap.dedent("""\
PREFIX
+ textwrap.dedent(
"""\
#ifndef NNAPI_WRAPPER_H_
#define NNAPI_WRAPPER_H_
#include <stddef.h>
@@ -122,13 +249,14 @@ def main(argv):
void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi);
#endif
#endif
""")
.replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
"""
).replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
)
(out_dir / "nnapi_wrapper.cpp").write_text(
PREFIX +
textwrap.dedent("""\
PREFIX
+ textwrap.dedent(
"""\
#ifndef _WIN32
#include <dlfcn.h>
#endif
@@ -157,7 +285,8 @@ def main(argv):
*check_nnapi = &check_nnapi_;
#endif
}
""")
"""
)
.replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks))
.replace("__LOAD_FUNCTIONS__", "\n".join(load_functions))
)


@@ -5,6 +5,27 @@
# LICENSE file in the root directory of this source tree.
import torch
from torch._functorch.deprecated import (
combine_state_for_ensemble,
functionalize,
grad,
grad_and_value,
hessian,
jacfwd,
jacrev,
jvp,
make_functional,
make_functional_with_buffers,
vjp,
vmap,
)
# utilities. Maybe these should go in their own namespace in the future?
from torch._functorch.make_functional import (
FunctionalModule,
FunctionalModuleWithBuffers,
)
# Top-level APIs. Please think carefully before adding something to the
# top-level namespace:
# - private helper functions should go into torch._functorch
@@ -14,15 +35,4 @@ import torch
# Was never documented
from torch._functorch.python_key import make_fx
from torch._functorch.deprecated import (
vmap, grad, grad_and_value, vjp, jvp, jacrev, jacfwd, hessian, functionalize,
make_functional, make_functional_with_buffers, combine_state_for_ensemble,
)
# utilities. Maybe these should go in their own namespace in the future?
from torch._functorch.make_functional import (
FunctionalModule,
FunctionalModuleWithBuffers,
)
__version__ = torch.__version__


@@ -2,6 +2,6 @@
# If you are not a PyTorch developer and you are relying on the following
# imports, please file an issue.
from torch._functorch.eager_transforms import (
_unwrap_functional_tensor,
_assert_wrapped_functional,
_unwrap_functional_tensor,
)


@@ -4,13 +4,13 @@
from torch._functorch.vmap import (
_add_batch_dim,
_broadcast_to_and_flatten,
_create_batched_inputs,
_get_name,
_process_batched_inputs,
_remove_batch_dim,
_unwrap_batched,
_validate_and_get_batch_size,
Tensor,
tree_flatten,
tree_unflatten,
_process_batched_inputs,
_create_batched_inputs,
_unwrap_batched,
)


@@ -1,8 +1,9 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import logging
import pandas as pd
from torch._functorch.benchmark_utils import compute_utilization
@@ -20,6 +21,7 @@ def get_model_name(filename):
modelname = tail[: tail.find("_chrome_trace")]
return modelname
def get_total_length(run_times_df, modelname):
return float(run_times_df[run_times_df["name"] == modelname]["runtime"])
@@ -31,14 +33,14 @@ def main():
"--runtime", "-runf", help="file name of the runtime file", required=True
)
group.add_argument(
"--filename", "-f", action="append", help="a filename of the json file to process"
)
group.add_argument(
"--folder", "-fd", help="a folder of the json files to process"
"--filename",
"-f",
action="append",
help="a filename of the json file to process",
)
group.add_argument("--folder", "-fd", help="a folder of the json files to process")
args = parser.parse_args()
if args.filename:
filenames = args.filename
elif args.folder:
@@ -58,11 +60,14 @@ def main():
try:
modelname = get_model_name(filename)
total_length = get_total_length(run_times_df, modelname) * 1e6
utilization, mm_conv_utilization = compute_utilization(filenames, total_length)
utilization, mm_conv_utilization = compute_utilization(
filenames, total_length
)
print(f"{modelname}, {utilization}, {mm_conv_utilization}")
except BaseException:
logging.exception("%s, ERROR", filename)
print(f"{filename}, ERROR")
if __name__ == "__main__":
main()


@@ -1,9 +1,10 @@
import torch
import torch.fx as fx
from functorch import make_fx
from torch.profiler import profile, ProfilerActivity
from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity
def profile_it(f, inp):
for _ in range(5):
@@ -20,6 +21,7 @@ def profile_it(f, inp):
cuda_time_total = cuda_time_total + e.cuda_time_total
return cuda_time_total / itr
def profile_function(name, f, inp):
fx_g = make_fx(f)(inp)
@@ -34,17 +36,23 @@ def profile_function(name, f, inp):
avg_cuda_time_g = profile_it(new_g, inp)
num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)
print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
print(
f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}"
)
g_gpu = torch.Generator(device='cuda')
g_gpu = torch.Generator(device="cuda")
g_gpu.manual_seed(2147483647)
inp = torch.randn(2**20, device='cuda', generator=g_gpu)
inp = torch.randn(2**20, device="cuda", generator=g_gpu)
def f1(x):
return x.cos().cos()
profile_function("f1", f1, inp)
def fsum(x):
a = x.sum()
b = x.sum()
@@ -52,22 +60,29 @@ def fsum(x):
d = x.sum()
return a + b + c + d
profile_function("fsum", fsum, inp)
def fconcat(x):
a = torch.cat((x, x))
b = torch.cat((x, x))
return a + b
profile_function("fconcat", fconcat, inp)
def fsum2(x):
a = x.sum()
for _ in range(30):
a = a + x.sum()
return a
profile_function("fsum2", fsum2, inp)
def fsummulti(x):
a = 0
for _ in range(3):
@@ -75,8 +90,10 @@ def fsummulti(x):
a = a * x.sum()
return a
profile_function("fsummulti", fsummulti, inp)
def fsummulti2(x):
a = 0
for _ in range(30):
@@ -84,20 +101,25 @@ def fsummulti2(x):
a = a * x.sum()
return a
profile_function("fsummulti2", fsummulti2, inp)
def fcos(x):
a = 0
for _ in range(3):
a = a + x.cos()
return a
profile_function("fcos", fcos, inp)
def fcos2(x):
a = 0
for _ in range(30):
a = a + x.cos()
return a
profile_function("fcos2", fcos2, inp)


@@ -1,7 +1,8 @@
import timeit
from functools import partial
import numpy as np
import pandas as pd
import timeit
import torch
from functorch.compile import pointwise_operator


@@ -1,14 +1,14 @@
import time
import torch
import torch.nn as nn
import torchvision.models as models
from opacus.utils.module_modification import convert_batchnorm_modules
import time
from functorch import vmap, grad
from functorch import make_functional
from functorch import grad, make_functional, vmap
from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules
device = 'cuda'
device = "cuda"
batch_size = 128
torch.manual_seed(0)
@@ -20,6 +20,7 @@ images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device)
func_model, weights = make_functional(model_functorch)
def compute_loss(weights, image, target):
images = image.unsqueeze(0)
targets = target.unsqueeze(0)
@@ -27,11 +28,11 @@ def compute_loss(weights, image, target):
loss = criterion(output, targets)
return loss
def functorch_per_sample_grad():
compute_grad = grad(compute_loss)
compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))
start = time.time()
result = compute_per_sample_grad(weights, images, targets)
torch.cuda.synchronize()
@@ -39,6 +40,7 @@ def functorch_per_sample_grad():
return result, end - start # end - start in seconds
torch.manual_seed(0)
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_opacus = model_opacus.to(device)
@@ -54,6 +56,7 @@ privacy_engine = PrivacyEngine(
max_grad_norm=10000.0,
)
def opacus_per_sample_grad():
start = time.time()
output = model_opacus(images)
@@ -63,7 +66,7 @@ def opacus_per_sample_grad():
end = time.time()
expected = [p.grad_sample for p in model_opacus.parameters()]
for p in model_opacus.parameters():
delattr(p, 'grad_sample')
delattr(p, "grad_sample")
p.grad = None
return expected, end - start


@@ -1,14 +1,16 @@
import sys
import time
import torch
import inspect
import itertools
import sys
import time
import torch
from functorch import pointwise_operator
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)
def rand(*shape):
return torch.rand(*shape).mul(16).add(1)
@@ -19,105 +21,139 @@ def rand(*shape):
def scalar():
return (rand(1), rand(1))
def small():
return (rand(32), rand(32))
def small_2d():
return (rand(1, 32), rand(1, 32))
def small_broadcast():
return (rand(4, 32), rand(32))
def medium():
return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))
def medium_sliced():
return (rand(32, 12, 64, 64)[..., ::2],
rand(32, 12, 64, 64)[..., ::2])
return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])
def medium_transpose():
return (rand(32, 12, 64, 64).transpose(-1, -2),
rand(32, 12, 64, 64).transpose(-1, -2))
return (
rand(32, 12, 64, 64).transpose(-1, -2),
rand(32, 12, 64, 64).transpose(-1, -2),
)
def medium2():
return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))
def medium3d():
return (rand(16, 32, 64), rand(16, 32, 64))
def medium_channels_last():
return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
rand(32, 3, 224, 224).to(memory_format=torch.channels_last))
return (
rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
)
def medium_broadcast():
return (rand(32, 12, 64, 64), rand(64))
def medium_broadcast_channels_last():
return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last),
rand(3, 1, 1))
return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))
def large():
return (rand(8192, 8192), rand(8192, 8192))
def large_transpose():
return (rand(8192, 8192).transpose(0, 1),
rand(8192, 8192).transpose(0, 1))
return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))
def large_channels_last():
return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
rand(32, 32, 256, 256).to(memory_format=torch.channels_last))
return (
rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
)
def pathological_broadcast():
return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))
# ------------------------------------------------------------------------------
# Operator test cases
# ------------------------------------------------------------------------------
def add(a, b):
return a + b
def sub(a, b):
return a - b
def mul(a, b):
return a * b
def div(a, b):
return a / b
def relu(a):
return a.relu()
def sigmoid(a):
return a.sigmoid()
def tanh(a):
return a.tanh()
def log(a):
return a.log()
def exp(a):
return a.exp()
def square(a):
return a**2
def fma(a, b):
return a * b + b
def hardswish(a):
return a * (a + 3.0).clamp(0.0, 6.0) / 6.0
def native_hardswish(a):
return torch._C._nn.hardswish(a)
def softplus(a):
return (a * 1.0).exp().log1p() / 1.0
def mish(a):
return a * ((a * 1.0).exp().log1p() / 1.0).tanh()
# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------
@@ -128,6 +164,7 @@ def time_cpu(fn, args, iters):
e = time.perf_counter()
return e - s
def time_cuda(fn, args, iters):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
@@ -138,19 +175,23 @@ def time_cuda(fn, args, iters):
torch.cuda.synchronize()
return start.elapsed_time(end) / 1e3
def benchmark_with_timer(fn, args, timer):
timer(fn, args, 3)
calibration = timer(fn, args, 1)
iters = int(1.0 / calibration)
return timer(fn, args, iters) / iters
def benchmark(fn, args):
timer = time_cpu if args[0].device.type == "cpu" else time_cuda
return benchmark_with_timer(fn, args, timer)
def micros(s):
return f"{s * 1e6:.1f}"
shapes = [
scalar,
small,
@@ -211,7 +252,17 @@ for shape, operator in itertools.product(shapes, operators):
args = shape()[:nargs]
result = benchmark(operator, args)
print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
print(
",".join(
[
"eager",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(result),
]
)
)
try:
if shape == medium_transpose:
raise RuntimeError("pointwise_operator hangs on medium_transpose")
@@ -219,11 +270,41 @@ for shape, operator in itertools.product(shapes, operators):
raise RuntimeError("pointwise_operator fails on medium_transpose")
pw_op = pointwise_operator(operator)
result = benchmark(pw_op, args)
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
print(
",".join(
[
"pointwise",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(result),
]
)
)
except Exception:
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))
print(
",".join(
[
"pointwise",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(float("nan")),
]
)
)
ts_op = torch.jit.script(operator)
result = benchmark(ts_op, args)
print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
print(
",".join(
[
"fuser",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(result),
]
)
)
sys.stdout.flush()


@@ -1,11 +1,13 @@
import pandas
import matplotlib.pyplot as plt
import pandas
df = pandas.read_csv("perf.csv")
ops = pandas.unique(df["operator"])
nops = len(ops)
pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"])
pivot_op_shape = df.pivot_table(
values="time", index=["operator", "shape"], columns=["fuser"]
)
pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T
plt.rcParams["figure.figsize"] = (20, 100)


@@ -1,31 +1,31 @@
from torch._functorch.python_key import pythonkey_decompose
from torch._functorch.fx_minifier import minifier
from torch._functorch import config
from torch._functorch.aot_autograd import (
aot_function,
aot_module,
aot_module_simplified,
compiled_function,
compiled_module,
aot_module_simplified,
get_graph_being_compiled,
get_aot_graph_name,
get_aot_compilation_context,
get_aot_graph_name,
get_graph_being_compiled,
make_boxed_compiler,
make_boxed_func,
make_boxed_compiler
)
from torch._functorch.compilers import (
ts_compile,
draw_graph_compile,
nop,
nnc_jit,
memory_efficient_fusion,
debug_compile,
default_decompositions,
draw_graph_compile,
memory_efficient_fusion,
nnc_jit,
nop,
print_compile,
default_decompositions
ts_compile,
)
from torch._functorch.fx_minifier import minifier
from torch._functorch.partitioners import (
min_cut_rematerialization_partition,
default_partition,
draw_graph,
draw_joint_graph,
min_cut_rematerialization_partition,
)
from torch._functorch import config
from torch._functorch.python_key import pythonkey_decompose


@@ -1,20 +1,26 @@
import torch
from typing import Union, Sequence
import inspect
import dis
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
import inspect
from typing import Sequence, Union
import torch
import functorch._C
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
class DimensionMismatchError(Exception):
pass
class DimensionBindError(Exception):
pass
from . import op_properties
# use dict to avoid writing C++ bindings for set
@@ -24,11 +30,11 @@ use_c = True
if not use_c:
from . import reference
class _Tensor:
# fast path around slow wrapping/unwrapping logic for simply queries used
# by the implementation...
@property
def dims(self):
return tuple(d for d in self._levels if isinstance(d, Dim))
@@ -47,11 +53,12 @@ class _Tensor:
def __repr__(self):
tensor, levels, ndim = self._tensor, self._levels, self.ndim
return f'{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}'
return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor):
# note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence.
# Tensor defines format, but we want to print Dims with special formatting
@@ -69,6 +76,7 @@ def cat(tensors, dim, new_dim):
n = dims()
return stack(tensors, n, dim).index([n, dim], new_dim)
if use_c:
_wrap = _C._wrap
@@ -107,41 +115,41 @@ if use_c:
else:
_Tensor.order = reference.positional
_def('mean')
_def('sum')
_def('all')
_def('amax')
_def('amin')
_def('aminmax')
_def('any')
_def('count_nonzero')
_def('logsumexp')
_def('nanmean')
_def('nansum')
_def('prod')
_def('std', keepdim_offset=2)
_def('var', keepdim_offset=2)
_def('max', single_dim=True)
_def('min', single_dim=True)
_def('argmax', single_dim=True)
_def('argmin', single_dim=True)
_def('kthvalue', single_dim=True)
_def('median', single_dim=True)
_def('nanmedian', single_dim=True)
_def('mode', single_dim=True)
_def('sort', reduce=False)
_def('argsort', reduce=False)
_def('unbind', single_dim=True)
_def('chunk', dim_offset=1, reduce=False)
_def('cummax', single_dim=True, reduce=False)
_def('cummin', single_dim=True, reduce=False)
_def('cumprod', single_dim=True, reduce=False)
_def('cumprod_', single_dim=True, reduce=False)
_def('cumsum', single_dim=True, reduce=False)
_def('cumsum_', single_dim=True, reduce=False)
_def('logcumsumexp', single_dim=True, reduce=False)
_def('renorm', dim_offset=1, single_dim=True, reduce=False)
_def('softmax', single_dim=True, reduce=False)
_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)
_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
# stuff to handle in the future, because they require special


@@ -3,14 +3,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch._C._functorch import (
_vmap_add_layers,
_vmap_remove_layers,
)
from contextlib import contextmanager
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
_enabled = False
@contextmanager
def _enable_layers(dims):
global _enabled


@@ -4,9 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple
class DelayedMulTensor(_Tensor):
def __init__(self, lhs, rhs):
self._lhs, self._rhs = lhs, rhs
@@ -37,7 +39,9 @@ class DelayedMulTensor(_Tensor):
@property
def _tensor(self):
if self._tensor_data is None:
self._tensor_data = Tensor.from_batched(self._batchtensor, self._has_device)._tensor
self._tensor_data = Tensor.from_batched(
self._batchtensor, self._has_device
)._tensor
return self._tensor_data
@property
@@ -48,20 +52,26 @@ class DelayedMulTensor(_Tensor):
def dims(self):
return ltuple(super().dims)
def sum(self, dim):
dims = _dims(dim, 0, False, False)
n = ord('a')
n = ord("a")
all_levels = self._levels
def to_char(d):
return chr(n + all_levels.index(d))
plhs, levelslhs = self._lhs._tensor, self._lhs._levels
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
new_dims = tuple(d for d in self.dims if d not in dims)
new_levels = [l for l in self._levels if l not in dims]
fmt = ''.join([*(to_char(d) for d in levelslhs), ',',
*(to_char(d) for d in levelsrhs), '->',
*(to_char(d) for d in new_levels)])
fmt = "".join(
[
*(to_char(d) for d in levelslhs),
",",
*(to_char(d) for d in levelsrhs),
"->",
*(to_char(d) for d in new_levels),
]
)
result_data = torch.einsum(fmt, (plhs, prhs))
return Tensor.from_positional(result_data, new_levels, True)


@@ -4,11 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
_vmap_levels = []
@dataclass
class LevelInfo:
level: int
alive: bool = True
class Dim:
def __init__(self, name: str, size: Union[None, int] = None):
self.name = name
@@ -20,7 +23,9 @@ class Dim:
def __del__(self):
if self._vmap_level is not None:
_vmap_active_levels[self._vmap_stack].alive = False
while not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level:
while (
not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
):
_vmap_decrement_nesting()
_vmap_levels.pop()
@@ -33,13 +38,14 @@ class Dim:
def size(self, size: int):
if self._size is None:
self._size = size
self._vmap_level = _vmap_increment_nesting(size, 'same')
self._vmap_level = _vmap_increment_nesting(size, "same")
self._vmap_stack = len(_vmap_levels)
_vmap_levels.append(LevelInfo(self._vmap_level))
elif self._size != size:
raise DimensionBindError(
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
)
@property
def is_bound(self):
@@ -50,10 +56,13 @@ class Dim:
def extract_name(inst):
assert inst.opname == 'STORE_FAST' or inst.opname == 'STORE_NAME'
assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
return inst.argval
_cache = {}
def dims(lists=0):
frame = inspect.currentframe()
assert frame is not None
@@ -66,17 +75,22 @@ def dims(lists=0):
instructions = list(dis.get_instructions(calling_frame.f_code))
unpack = instructions[first]
if unpack.opname == 'STORE_FAST' or unpack.opname == 'STORE_NAME':
if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
# just a single dim, not a list
name = unpack.argval
ctor = Dim if lists == 0 else DimList
_cache[key] = lambda: ctor(name=name)
else:
assert unpack.opname == 'UNPACK_SEQUENCE'
assert unpack.opname == "UNPACK_SEQUENCE"
ndims = unpack.argval
names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
names = tuple(
extract_name(instructions[first + 1 + i]) for i in range(ndims)
)
first_list = len(names) - lists
_cache[key] = lambda: tuple(Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names))
_cache[key] = lambda: tuple(
Dim(n) if i < first_list else DimList(name=n)
for i, n in enumerate(names)
)
return _cache[key]()
@@ -87,6 +101,7 @@ def _dim_set(positional, arg):
else:
assert isinstance(a, int)
return positional[a]
if arg is None:
return positional
elif not isinstance(arg, (Dim, int)):


@@ -3,25 +3,33 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
import os
import subprocess
import signal
import subprocess
from contextlib import contextmanager
@contextmanager
def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'):
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
pid = os.getpid()
if not os.path.exists(magic_trace_cache):
print(f"Downloading magic_trace to: {magic_trace_cache}")
subprocess.run(['wget', '-O', magic_trace_cache, '-q',
'https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace'])
subprocess.run(['chmod', '+x', magic_trace_cache])
args = [magic_trace_cache, 'attach', '-pid', str(pid), '-o', output]
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding='utf-8')
subprocess.run(
[
"wget",
"-O",
magic_trace_cache,
"-q",
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
]
)
subprocess.run(["chmod", "+x", magic_trace_cache])
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
while True:
x = p.stderr.readline()
print(x)
if 'Attached' in x:
if "Attached" in x:
break
try:
yield
@@ -31,4 +39,4 @@ def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'):
print(p.stderr.read())
p.stderr.close()
if r != 0:
raise ValueError(f'magic_trace exited abnormally: {r}')
raise ValueError(f"magic_trace exited abnormally: {r}")


@@ -4,29 +4,58 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
# pointwise operators can go through a faster pathway
tensor_magic_methods = [
'add',
''
]
tensor_magic_methods = ["add", ""]
pointwise_magic_methods_with_reverse = (
'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod',
'pow', 'lshift', 'rshift', 'and', 'or', 'xor'
"add",
"sub",
"mul",
"floordiv",
"div",
"truediv",
"mod",
"pow",
"lshift",
"rshift",
"and",
"or",
"xor",
)
pointwise_magic_methods = (
*(x for m in pointwise_magic_methods_with_reverse for x in (m, 'r' + m)),
'eq', 'gt', 'le', 'lt', 'ge', 'gt', 'ne', 'neg', 'pos',
'abs', 'invert',
'iadd', 'isub', 'imul', 'ifloordiv', 'idiv',
'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand',
'ior', 'ixor',
'int', 'long', 'float', 'complex',
*(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
"eq",
"gt",
"le",
"lt",
"ge",
"gt",
"ne",
"neg",
"pos",
"abs",
"invert",
"iadd",
"isub",
"imul",
"ifloordiv",
"idiv",
"itruediv",
"imod",
"ipow",
"ilshift",
"irshift",
"iand",
"ior",
"ixor",
"int",
"long",
"float",
"complex",
)
pointwise_methods = (
*(f'__{m}__' for m in pointwise_magic_methods),
)
pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
pointwise = (
*(getattr(torch.Tensor, m) for m in pointwise_methods),


@@ -6,23 +6,28 @@
# reference python implementations for C ops
import torch
from .tree_map import tree_flatten, tree_map
from .batch_tensor import _enable_layers
from . import op_properties
from functorch._C import dim as _C
from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map
DimList = _C.DimList
from functools import reduce
import operator
from functools import reduce
# use dict to avoid writing C++ bindings for set
pointwise = set(op_properties.pointwise)
def prod(x):
return reduce(operator.mul, x, 1)
def _wrap_dim(d, N, keepdim):
from . import Dim
if isinstance(d, Dim):
assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
return d
@@ -31,40 +36,52 @@ def _wrap_dim(d, N, keepdim):
else:
return d
def _dims(d, N, keepdim, single_dim):
from . import Dim
if isinstance(d, (Dim, int)):
return ltuple((_wrap_dim(d, N, keepdim),))
assert not single_dim, f"expected a single dimension or int but found: {d}"
return ltuple(_wrap_dim(x, N, keepdim) for x in d)
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
from . import DimensionMismatchError
not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
if len(not_bound) == 1:
idx, d = not_bound[0]
rhs_so_far = prod(r.size for r in rhs if r.is_bound)
if lhs_size % rhs_so_far != 0:
rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}")
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(
f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
)
new_size = lhs_size // rhs_so_far
d.size = new_size
elif len(not_bound) > 1:
rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}")
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(
f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
)
else:
rhs_size = prod(r.size for r in rhs)
if lhs_size != rhs_size:
raise DimensionMismatchError(
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}")
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
)
def _tensor_levels(inp):
from . import _Tensor
if isinstance(inp, _Tensor):
return inp._tensor, llist(inp._levels), inp._has_device
else:
return inp, llist(range(-inp.ndim, 0)), True
def _match_levels(v, from_levels, to_levels):
view = []
permute = []
@@ -90,6 +107,7 @@ def _match_levels(v, from_levels, to_levels):
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
from . import Tensor
ptensor, levels = self._tensor, llist(self._levels)
try:
idx = levels.index(dim)
@@ -107,8 +125,10 @@ def _positional_no_permute(self, dim, expand_dim=False):
levels[idx] = -idx_batched - 1
return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
def seq(a, b):
from . import Dim
if isinstance(a, Dim) != isinstance(b, Dim):
return False
if isinstance(a, Dim):
@@ -116,6 +136,7 @@ def seq(a, b):
else:
return a == b
class isin:
def __contains__(self, item):
for x in self:
@@ -133,18 +154,27 @@ class isin:
class llist(isin, list):
pass
class ltuple(isin, tuple):
pass
empty_dict = {}
@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
from . import _Tensor, TensorLike, Tensor
from . import _Tensor, Tensor, TensorLike
from .delayed_mul_tensor import DelayedMulTensor
if orig is torch.Tensor.__mul__:
lhs, rhs = args
if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0:
if (
isinstance(lhs, _Tensor)
and isinstance(rhs, _Tensor)
and lhs.ndim == 0
and rhs.ndim == 0
):
return DelayedMulTensor(lhs, rhs)
all_dims = llist()
flat_args, unflatten = tree_flatten((args, kwargs))
@@ -172,7 +202,11 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
for i, f in enumerate(flat_args):
if isinstance(f, TensorLike):
ptensor, levels, _ = _tensor_levels(f)
if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None:
if (
isinstance(f, _Tensor)
and not f._has_device
and device_holding_tensor is not None
):
ptensor = ptensor.to(device=device_holding_tensor.device)
flat_args[i] = ptensor
for l in levels:
@@ -187,14 +221,19 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_positional(t, result_levels, device_holding_tensor is not None)
return Tensor.from_positional(
t, result_levels, device_holding_tensor is not None
)
return t
return tree_map(wrap, result)
else:
def wrap(t):
if isinstance(t, TensorLike):
return Tensor.from_batched(t, device_holding_tensor is not None)
return t
with _enable_layers(all_dims):
print(f"batch_tensor for {orig}")
args, kwargs = unflatten(unwrap(f) for f in flat_args)
@@ -202,8 +241,10 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
# print("END", orig)
return tree_map(wrap, result)
def positional(self, *dims):
from . import Dim, Tensor
ptensor, levels = self._tensor, llist(self._levels)
flat_dims = llist()
view = []
@@ -231,7 +272,9 @@ def positional(self, *dims):
try:
idx = levels.index(d)
except ValueError as e:
raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e
raise DimensionBindError(
f"tensor of dimensions {self.dims} does not contain dim {d}"
) from e
p = permute[idx]
del levels[idx]
del permute[idx]
@@ -248,12 +291,15 @@ def positional(self, *dims):
result = result.reshape(*view, *result.size()[len(flat_dims) :])
return result
def _contains_dim(input):
from . import Dim
for i in input:
if isinstance(i, Dim):
return True
def expand(self, *sizes):
if not _contains_dim(sizes):
return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
@@ -265,27 +311,36 @@ def expand(self, *sizes):
_not_present = object()
def _getarg(name, offset, args, kwargs, default):
if len(args) > offset:
return args[offset]
return kwargs.get(name, default)
def _patcharg(name, offset, args, kwargs, value):
if len(args) > offset:
args[offset] = value
else:
kwargs[name] = value
def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
from . import TensorLike, Dim, Tensor
def _wrap(
orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
from . import Dim, Tensor, TensorLike
def fn(self, *args, **kwargs):
dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
with _enable_layers(self.dims):
print(f"dim fallback batch_tensor for {orig}")
return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device)
keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False
return Tensor.from_batched(
orig(self._batchtensor, *args, **kwargs), self._has_device
)
keepdim = (
_getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
)
t, levels = self._tensor, llist(self._levels)
dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
dim_indices = tuple(levels.index(d) for d in dims)
@@ -295,7 +350,9 @@ def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False
new_levels = levels
if len(dim_indices) == 1:
dim_indices = dim_indices[0] # so that dims that really only take a single argument work...
dim_indices = dim_indices[
0
] # so that dims that really only take a single argument work...
args = list(args)
_patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
@@ -303,21 +360,27 @@ def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False
if isinstance(t, TensorLike):
return Tensor.from_positional(t, new_levels, self._has_device)
return t
with _enable_layers(new_levels):
print(f"dim used batch_tensor for {orig}")
r = orig(t, *args, **kwargs)
return tree_map(wrap, r)
return fn
def _def(name, *args, **kwargs):
from . import _Tensor
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
no_slice = slice(None)
_orig_getitem = torch.Tensor.__getitem__
class dim_tracker:
def __init__(self):
self.dims = llist()
@@ -331,8 +394,10 @@ class dim_tracker:
def __getitem__(self, d):
return self.count[self.dims.index(d)]
def t__getitem__(self, input):
from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor
from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
# * bail to original example if we have a single non-Dim tensor, or a non-tensor
# * locate ... or an unbound tensor list, and determine its size, bind dim list
# (remember that None does not count to the total dim count)
@@ -345,10 +410,13 @@ def t__getitem__(self, input):
# this handles bool indexing handling, as well as some other simple cases.
is_simple = (not isinstance(input, Dim) and
not isinstance(input, (tuple, list)) and
is_simple = (
not isinstance(input, Dim)
and not isinstance(input, (tuple, list))
and
# WAR for functorch bug where zero time tensors in getitem are not handled correctly.
not (isinstance(input, TensorLike) and input.ndim == 0))
not (isinstance(input, TensorLike) and input.ndim == 0)
)
if is_simple:
if isinstance(self, _Tensor):
@@ -368,8 +436,10 @@ def t__getitem__(self, input):
for i, s in enumerate(input):
if s is ... or isinstance(s, DimList) and not s.is_bound:
if expanding_object is not None:
msg = 'at most one ... or unbound dimension list can exist in indexing list but' \
f' found 2 at offsets {i} and {expanding_object}'
msg = (
"at most one ... or unbound dimension list can exist in indexing list but"
f" found 2 at offsets {i} and {expanding_object}"
)
raise DimensionBindError(msg)
expanding_object = i
@@ -381,12 +451,16 @@ def t__getitem__(self, input):
ndim = self.ndim
if dims_indexed > ndim:
raise IndexError(f'at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions.')
raise IndexError(
f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
)
if expanding_object is not None:
expanding_ndims = ndim - dims_indexed
obj = input[expanding_object]
if obj is ...:
input[expanding_object:expanding_object + 1] = [no_slice] * expanding_ndims
input[expanding_object : expanding_object + 1] = [
no_slice
] * expanding_ndims
else:
obj.bind_len(expanding_ndims)
# flatten the dimslists into the indexing
@@ -420,7 +494,7 @@ def t__getitem__(self, input):
elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
for d in idx:
dims_seen.record(idx)
_bind_dims_to_size(sz, idx, f'offset {i}')
_bind_dims_to_size(sz, idx, f"offset {i}")
view_sizes.extend(d.size for d in idx)
requires_view = True
dim_packs.append(i)
@@ -499,6 +573,7 @@ def t__getitem__(self, input):
return Tensor.from_positional(result, result_levels, has_device)
# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
if isinstance(dim, int):
@@ -517,12 +592,20 @@ def stack(tensors, new_dim, dim=0, out=None):
pr = torch.stack(ptensors, index, out=out)
return pr.index((index, index + 1), (new_dim, dim))
_orig_split = torch.Tensor.split
def split(self, split_size_or_sections, dim=0):
from . import Dim, _Tensor
if isinstance(split_size_or_sections, int) or any(isinstance(t, int) for t in split_size_or_sections):
from . import _Tensor, Dim
if isinstance(split_size_or_sections, int) or any(
isinstance(t, int) for t in split_size_or_sections
):
if isinstance(dim, Dim):
raise ValueError('when dim is specified as a Dim object, split sizes must also be dimensions.')
raise ValueError(
"when dim is specified as a Dim object, split sizes must also be dimensions."
)
return _orig_split(self, split_size_or_sections, dim=dim)
if isinstance(dim, Dim):
@@ -542,8 +625,9 @@ def split(self, split_size_or_sections, dim=0):
unbound.append(i)
if unbound:
assert total_bound_size <= size, \
f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
assert (
total_bound_size <= size
), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
remaining_size = size - total_bound_size
chunk_size = -(-remaining_size // len(unbound))
for u in unbound:
@@ -552,6 +636,10 @@ def split(self, split_size_or_sections, dim=0):
sizes[u] = sz
remaining_size -= sz
else:
assert total_bound_size == size, \
f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)))
assert (
total_bound_size == size
), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
return tuple(
t.index(dim, d)
for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
)


@@ -5,8 +5,10 @@
# LICENSE file in the root directory of this source tree.
from functorch._C import dim
tree_flatten = dim.tree_flatten
def tree_map(fn, tree):
vs, unflatten = tree_flatten(tree)
return unflatten(fn(v) for v in vs)


@@ -4,22 +4,35 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from types import FunctionType, BuiltinMethodType, MethodDescriptorType, WrapperDescriptorType, GetSetDescriptorType
from types import (
BuiltinMethodType,
FunctionType,
GetSetDescriptorType,
MethodDescriptorType,
WrapperDescriptorType,
)
from functorch._C import dim as _C
_wrap_method = _C._wrap_method
FUNC_TYPES = (FunctionType, MethodDescriptorType, BuiltinMethodType, WrapperDescriptorType)
FUNC_TYPES = (
FunctionType,
MethodDescriptorType,
BuiltinMethodType,
WrapperDescriptorType,
)
PROPERTY_TYPES = (GetSetDescriptorType, property)
def _py_wrap_method(orig, __torch_function__):
def impl(*args, **kwargs):
return __torch_function__(orig, None, args, kwargs)
return impl
def wrap_type(use_c, to_patch, pattern, __torch_function__):
if use_c:
wrap_method = _wrap_method
else:
@ -29,18 +42,27 @@ def wrap_type(use_c, to_patch, pattern, __torch_function__):
for t in reversed(pattern.mro()[:-1]): # skip object
all.update(t.__dict__)
def wrap_attr(orig):
return property(wrap_method(orig.__get__, __torch_function__))
for name, obj in all.items():
if name in ('__dict__', '__new__', '__init__', '__repr__', '__weakref__', '__doc__', '__module__', '__dir__'):
if name in (
"__dict__",
"__new__",
"__init__",
"__repr__",
"__weakref__",
"__doc__",
"__module__",
"__dir__",
):
continue
# skip things that have been overloaded
# things that come from object like `__eq__` still need to be patched, however.
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(object, name, None):
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
object, name, None
):
continue
if isinstance(obj, FUNC_TYPES):
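
The `_py_wrap_method` closure above routes every wrapped callable through a single `__torch_function__`-style handler; a toy demonstration of the pattern (not functorch's real dispatcher, and the handler here is purely illustrative):

def wrap_method(orig, torch_function):
    def impl(*args, **kwargs):
        return torch_function(orig, None, args, kwargs)
    return impl

calls = []

def handler(orig, types, args, kwargs):
    calls.append(orig.__name__)   # observe which method was intercepted
    return orig(*args, **kwargs)  # then defer to the original

wrapped_upper = wrap_method(str.upper, handler)
assert wrapped_upper("dim") == "DIM"
assert calls == ["upper"]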

View File

@ -14,18 +14,21 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import functorch
# import sys
# source code directory, relative to this file, for sphinx-autobuild
# sys.path.insert(0, os.path.abspath('../..'))
import torch
import functorch
RELEASE = os.environ.get('RELEASE', False)
RELEASE = os.environ.get("RELEASE", False)
import sys
import pytorch_sphinx_theme
import sys
# -- General configuration ------------------------------------------------
@ -35,18 +38,18 @@ import sys
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
# 'sphinxcontrib.katex',
'sphinx.ext.autosectionlabel',
'sphinx_copybutton',
'myst_nb',
"sphinx.ext.autosectionlabel",
"sphinx_copybutton",
"myst_nb",
]
# sys.path.insert(0, os.path.abspath('./notebooks'))
@ -75,21 +78,21 @@ napoleon_use_ivar = True
autosummary_generate = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ".rst"
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'functorch'
copyright = 'PyTorch Contributors'
author = 'PyTorch Contributors'
project = "functorch"
copyright = "PyTorch Contributors"
author = "PyTorch Contributors"
functorch_version = str(functorch.__version__)
# The version info for the project you're documenting, acts as replacement for
@ -98,16 +101,16 @@ functorch_version = str(functorch.__version__)
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = 'nightly (' + functorch_version + ')'
version = "nightly (" + functorch_version + ")"
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = 'nightly'
release = "nightly"
# Customized html_title here.
# Default is " ".join(project, release, "documentation") if not set
# TODO: I don't know if this flag works, please check before using it
if RELEASE:
raise RuntimeError('NYI')
raise RuntimeError("NYI")
# remove hash (start with 'a') from version number if any
# version_end = functorch_version.find('a')
# if version_end == -1:
@ -128,10 +131,10 @@ language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['notebooks/colab**', 'notebooks/_src/**']
exclude_patterns = ["notebooks/colab**", "notebooks/_src/**"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@ -140,7 +143,7 @@ todo_include_todos = True
autodoc_inherit_docstrings = False
# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'
autodoc_typehints = "none"
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
@ -159,7 +162,7 @@ autodoc_docstring_signature = True
#
#
html_theme = 'pytorch_sphinx_theme'
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
@ -178,10 +181,10 @@ html_theme_options = {
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
html_css_files = [
'css/custom.css',
"css/custom.css",
]
@ -191,19 +194,20 @@ def setup(app):
# and can be moved outside of this function (and the setup(app) function
# can be deleted).
html_css_files = [
'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css'
"https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css"
]
# In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is
# `add_stylesheet` (deprecated in 1.8).
add_css = getattr(app, 'add_css_file', app.add_stylesheet)
add_css = getattr(app, "add_css_file", app.add_stylesheet)
for css_file in html_css_files:
add_css(css_file)
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
htmlhelp_basename = "PyTorchdoc"
# -- Options for LaTeX output ---------------------------------------------
@ -212,15 +216,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
@ -230,8 +231,13 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'PyTorch Documentation',
'Torch Contributors', 'manual'),
(
master_doc,
"pytorch.tex",
"PyTorch Documentation",
"Torch Contributors",
"manual",
),
]
@ -239,10 +245,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'functorch', 'functorch Documentation',
[author], 1)
]
man_pages = [(master_doc, "functorch", "functorch Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
@ -251,37 +254,44 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'functorch', 'functorch Documentation',
author, 'functorch', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
"functorch",
"functorch Documentation",
author,
"functorch",
"One line description of project.",
"Miscellaneous",
),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
'numpy': ('https://numpy.org/doc/stable', None),
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable/", None),
}
import sphinx.ext.doctest
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes
import sphinx.ext.doctest
from sphinx.util.docfields import TypedField
# Without this, doctest adds any example with a `>>>` as a test
doctest_test_doctest_blocks = ''
doctest_test_doctest_blocks = ""
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
doctest_global_setup = '''
doctest_global_setup = """
import torch
try:
import torchvision
except ImportError:
torchvision = None
'''
"""
def patched_make_field(self, types, domain, items, **kw):
@ -291,43 +301,51 @@ def patched_make_field(self, types, domain, items, **kw):
# (List, unicode, Tuple) -> nodes.field
def handle_item(fieldarg, content):
par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
par += addnodes.literal_strong("", fieldarg) # Patch: this line added
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong))
if fieldarg in types:
par += nodes.Text(' (')
par += nodes.Text(" (")
# NOTE: using .pop() here to prevent a single type node to be
# inserted twice into the doctree, which leads to
# inconsistencies later when references are resolved
fieldtype = types.pop(fieldarg)
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
typename = u''.join(n.astext() for n in fieldtype)
typename = typename.replace('int', 'python:int')
typename = typename.replace('long', 'python:long')
typename = typename.replace('float', 'python:float')
typename = typename.replace('bool', 'python:bool')
typename = typename.replace('type', 'python:type')
par.extend(self.make_xrefs(self.typerolename, domain, typename,
addnodes.literal_emphasis, **kw))
typename = "".join(n.astext() for n in fieldtype)
typename = typename.replace("int", "python:int")
typename = typename.replace("long", "python:long")
typename = typename.replace("float", "python:float")
typename = typename.replace("bool", "python:bool")
typename = typename.replace("type", "python:type")
par.extend(
self.make_xrefs(
self.typerolename,
domain,
typename,
addnodes.literal_emphasis,
**kw,
)
)
else:
par += fieldtype
par += nodes.Text(')')
par += nodes.Text(' -- ')
par += nodes.Text(")")
par += nodes.Text(" -- ")
par += content
return par
fieldname = nodes.field_name('', self.label)
fieldname = nodes.field_name("", self.label)
if len(items) == 1 and self.can_collapse:
fieldarg, content = items[0]
bodynode = handle_item(fieldarg, content)
else:
bodynode = self.list_type()
for fieldarg, content in items:
bodynode += nodes.list_item('', handle_item(fieldarg, content))
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
bodynode += nodes.list_item("", handle_item(fieldarg, content))
fieldbody = nodes.field_body("", bodynode)
return nodes.field("", fieldname, fieldbody)
TypedField.make_field = patched_make_field
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
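
The `getattr(app, "add_css_file", app.add_stylesheet)` line above is a general pattern for surviving API renames across library versions; a self-contained sketch of the idea (OldApp/NewApp are hypothetical stand-ins, no Sphinx required):

class OldApp:
    def add_stylesheet(self, css):  # pre-1.8 name
        return f"old:{css}"

class NewApp:
    def add_css_file(self, css):    # 1.8+ name
        return f"new:{css}"

def register_css(app, css):
    # Prefer the modern method if present, else fall back to the old one.
    add_css = getattr(app, "add_css_file", None) or app.add_stylesheet
    return add_css(css)

assert register_css(OldApp(), "custom.css") == "old:custom.css"
assert register_css(NewApp(), "custom.css") == "new:custom.css"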

View File

@ -1,3 +1,3 @@
from .rearrange import rearrange
__all__ = ['rearrange']
__all__ = ["rearrange"]

View File

@ -40,7 +40,9 @@ class AnonymousAxis:
def __init__(self, value: str) -> None:
self.value = int(value)
if self.value < 1:
raise ValueError(f'Anonymous axis should have positive length, not {self.value}')
raise ValueError(
f"Anonymous axis should have positive length, not {self.value}"
)
def __repr__(self) -> str:
return f"{self.value}-axis"
@ -49,7 +51,13 @@ class AnonymousAxis:
class ParsedExpression:
"""Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
def __init__(self, expression: str, *, allow_underscore: bool = False, allow_duplicates: bool = False) -> None:
def __init__(
self,
expression: str,
*,
allow_underscore: bool = False,
allow_duplicates: bool = False,
) -> None:
"""Parse the expression and store relevant metadata.
Args:
@ -66,10 +74,13 @@ class ParsedExpression:
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
if "." in expression:
if "..." not in expression:
raise ValueError("Expression may contain dots only inside ellipsis (...)")
raise ValueError(
"Expression may contain dots only inside ellipsis (...)"
)
if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
raise ValueError(
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ")
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
)
expression = expression.replace("...", _ellipsis)
self.has_ellipsis = True
@ -78,7 +89,9 @@ class ParsedExpression:
def add_axis_name(x: str) -> None:
if x in self.identifiers:
if not (allow_underscore and x == "_") and not allow_duplicates:
raise ValueError(f"Indexing expression contains duplicate dimension '{x}'")
raise ValueError(
f"Indexing expression contains duplicate dimension '{x}'"
)
if x == _ellipsis:
self.identifiers.add(_ellipsis)
if bracket_group is None:
@ -96,10 +109,14 @@ class ParsedExpression:
else:
pass # no need to think about 1s inside parenthesis
return
is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
is_axis_name, reason = self.check_axis_name_return_reason(
x, allow_underscore=allow_underscore
)
if not (is_number or is_axis_name):
raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
axis_name: Union[str, AnonymousAxis] = AnonymousAxis(x) if is_number else x
axis_name: Union[str, AnonymousAxis] = (
AnonymousAxis(x) if is_number else x
)
self.identifiers.add(axis_name)
if is_number:
self.has_non_unitary_anonymous_axes = True
@ -116,7 +133,9 @@ class ParsedExpression:
current_identifier = None
if char == "(":
if bracket_group is not None:
raise ValueError("Axis composition is one-level (brackets inside brackets not allowed)")
raise ValueError(
"Axis composition is one-level (brackets inside brackets not allowed)"
)
bracket_group = []
elif char == ")":
if bracket_group is None:
@ -137,7 +156,9 @@ class ParsedExpression:
add_axis_name(current_identifier)
@staticmethod
def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
def check_axis_name_return_reason(
name: str, allow_underscore: bool = False
) -> Tuple[bool, str]:
"""Check if the given axis name is valid, and a message explaining why if not.
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@ -157,10 +178,14 @@ class ParsedExpression:
return False, "axis name should should not start or end with underscore"
else:
if keyword.iskeyword(name):
warnings.warn(f"It is discouraged to use axes names that are keywords: {name}", RuntimeWarning)
warnings.warn(
f"It is discouraged to use axes names that are keywords: {name}",
RuntimeWarning,
)
if name in ["axis"]:
warnings.warn(
"It is discouraged to use 'axis' as an axis name and will raise an error in future", FutureWarning
"It is discouraged to use 'axis' as an axis name and will raise an error in future",
FutureWarning,
)
return True, ""
@ -178,8 +203,9 @@ class ParsedExpression:
return is_valid
def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[ParsedExpression, ParsedExpression]:
def parse_pattern(
pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
Args:
@ -203,9 +229,13 @@ def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[Parsed
right = ParsedExpression(right_str)
if not left.has_ellipsis and right.has_ellipsis:
raise ValueError(f"Ellipsis found in right side, but not left side of a pattern {pattern}")
raise ValueError(
f"Ellipsis found in right side, but not left side of a pattern {pattern}"
)
if left.has_ellipsis and left.has_ellipsis_parenthesized:
raise ValueError(f"Ellipsis is parenthesis in the left side is not allowed: {pattern}")
raise ValueError(
f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
)
return left, right
@ -222,18 +252,24 @@ def validate_rearrange_expressions(
"""
for length in axes_lengths.values():
if (length_type := type(length)) is not int:
raise TypeError(f"rearrange axis lengths must be integers, got: {length_type}")
raise TypeError(
f"rearrange axis lengths must be integers, got: {length_type}"
)
if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
raise ValueError("rearrange only supports unnamed axes of size 1")
difference = set.symmetric_difference(left.identifiers, right.identifiers)
if len(difference) > 0:
raise ValueError(f"Identifiers only on one side of rearrange expression (should be on both): {difference}")
raise ValueError(
f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
)
unmatched_axes = axes_lengths.keys() - left.identifiers
if len(unmatched_axes) > 0:
raise ValueError(f"Identifiers not found in rearrange expression: {unmatched_axes}")
raise ValueError(
f"Identifiers not found in rearrange expression: {unmatched_axes}"
)
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
@ -259,6 +295,8 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
'(d0,), (), (d1,), (d2,), (d3, d4)'
"""
return ", ".join(
item if isinstance(item, str) else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
item
if isinstance(item, str)
else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
for item in collection
)
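
For intuition about the validation above, the mismatched-identifier check reduces to a symmetric set difference; a minimal sketch (mismatched_identifiers is a hypothetical helper, not part of the module):

def mismatched_identifiers(left_ids: set, right_ids: set) -> set:
    # Axis names appearing on exactly one side of the pattern are errors.
    return set.symmetric_difference(left_ids, right_ids)

assert mismatched_identifiers({"b", "c", "h"}, {"h", "c", "b"}) == set()
assert mismatched_identifiers({"b", "c"}, {"b", "d"}) == {"c", "d"}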

View File

@ -4,8 +4,15 @@ import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union
import torch
from functorch._C import dim as _C
from ._parsing import AnonymousAxis, _ellipsis, comma_separate, parse_pattern, validate_rearrange_expressions
from ._parsing import (
_ellipsis,
AnonymousAxis,
comma_separate,
parse_pattern,
validate_rearrange_expressions,
)
__all__ = ["rearrange"]
@ -79,10 +86,12 @@ def _create_rearrange_callable(
dims_i += 1
elif dimension == _ellipsis:
identifier = _ellipsis
identifier_dim_map[identifier] = tuple(first_class_dims[dims_i + j] for j in range(n_ellipsis_dims))
identifier_dim_map[identifier] = tuple(
first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
)
dims_i += n_ellipsis_dims
else:
raise ValueError(f'Unexpected dimension: {dimension}')
raise ValueError(f"Unexpected dimension: {dimension}")
def composition_to_dims(
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
@ -92,11 +101,17 @@ def _create_rearrange_callable(
dim_composition: List[Union[str, Tuple[str, ...]]] = []
for dimension in composition:
if isinstance(dimension, list):
dim_composition.append(tuple(dim for identifier in dimension for dim in identifier_dim_map[identifier]))
dim_composition.append(
tuple(
dim
for identifier in dimension
for dim in identifier_dim_map[identifier]
)
)
elif dimension == _ellipsis:
dim_composition.extend(identifier_dim_map[_ellipsis])
else:
raise ValueError(f'Unexpected dimension: {dimension}')
raise ValueError(f"Unexpected dimension: {dimension}")
return dim_composition
left_dims = composition_to_dims(left.composition)
@ -108,16 +123,22 @@ def _create_rearrange_callable(
custom_rearrange_callable_name = "do_rearrange"
custom_rearrange_callable_code = (
(
f"def {custom_rearrange_callable_name}(tensor):\n"
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
)
+ (
"".join(f" {dim}.size = {length}\n" for (dim, length) in specified_lengths)
if specified_lengths else ""
"".join(
f" {dim}.size = {length}\n" for (dim, length) in specified_lengths
)
if specified_lengths
else ""
)
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+ (
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
if anon_dims else " return tensor\n"
if anon_dims
else " return tensor\n"
)
)
@ -126,7 +147,9 @@ def _create_rearrange_callable(
def rearrange(
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], pattern: str, **axes_lengths: int
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
pattern: str,
**axes_lengths: int,
) -> torch.Tensor:
r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
@ -177,6 +200,8 @@ def rearrange(
if not isinstance(tensor, torch.Tensor):
tensor = torch.stack(tensor)
rearrange_callable = _create_rearrange_callable(tensor.ndim, pattern, **axes_lengths)
rearrange_callable = _create_rearrange_callable(
tensor.ndim, pattern, **axes_lengths
)
return rearrange_callable(tensor)
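
For a concrete feel of what `_create_rearrange_callable` assembles, splitting a composed axis works like this (the commented source below is a plausible rendering of the generated function, with dim names guessed; the exact text is an implementation detail):

import torch
from functorch.einops import rearrange

x = torch.arange(6).reshape(1, 6)
y = rearrange(x, "b (h w) -> b h w", h=2)
assert y.shape == (1, 2, 3)

# Hypothetical generated source for the call above:
#     def do_rearrange(tensor):
#         d0, d1, d2 = dims(3)
#         d1.size = 2
#         tensor = tensor[d0, (d1, d2)].order(d0, d1, d2)
#         return tensor
# i.e. first-class dims split the composed axis on indexing, then
# .order() lays the named dims out in the requested order.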

View File

@ -1,7 +1,8 @@
from functorch.compile import aot_function, tvm_compile
import torch
import time
import torch
import torch.utils
from functorch.compile import aot_function, tvm_compile
a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4)
@ -11,8 +12,8 @@ def f(a):
return (a * b).sum(dim=0)
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
compiled_f = aot_function(f, fw_compiler, bw_compiler)
# fw_compiler = lambda x, _: x
@ -32,13 +33,15 @@ def bench(func):
def bench_jax():
import jax.numpy as jnp
import jax
import jax.numpy as jnp
jax_a = jnp.array(a.detach().numpy())
jax_b = jnp.array(b.detach().numpy())
def f(a):
return jnp.sin((a * jax_b).sum(axis=[0])).sum()
jit_f = jax.jit(jax.grad(f))
jit_f(jax_a)
begin = time.time()
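
The body of `bench` is not shown in this hunk; a plausible minimal harness in the same spirit (an assumption, not the file's actual code) times a few runs and keeps the best:

import time

def bench(func, iters=10):
    best = float("inf")
    for _ in range(iters):
        begin = time.time()
        func()
        best = min(best, time.time() - begin)
    return best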

View File

@ -1,15 +1,16 @@
import timeit
from functorch.compile import compiled_module, tvm_compile
import torch.nn as nn
import torch
import torch.nn as nn
from functorch.compile import compiled_module, tvm_compile
def nop(f, _):
return f
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
fw_compiler = nop
bw_compiler = nop

View File

@ -4,11 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch import make_functional
from functorch.compile import nnc_jit
import time
import torch
import torch.nn as nn
import time
from functorch import make_functional
from functorch.compile import nnc_jit
torch._C._jit_override_can_fuse_on_cpu(True)
@ -54,7 +56,9 @@ def functional_step(x, weights):
return out, new_weights
optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0)
optim = torch.optim.SGD(
jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
)
def jit_step(x, weights):
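
The functional training step above returns fresh weights rather than mutating the module; a minimal sketch of that update pattern (sgd_step is illustrative, not the file's code):

import torch

def sgd_step(weights, grads, lr=1e-3):
    # Produce new weights instead of updating parameters in place.
    return [w - lr * g for w, g in zip(weights, grads)]

w = [torch.ones(3, requires_grad=True)]
loss = (w[0] ** 2).sum()
grads = torch.autograd.grad(loss, w)
new_w = sgd_step(w, grads, lr=0.1)
assert torch.allclose(new_w[0], torch.full((3,), 0.8))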

View File

@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import time
import torch
from functorch import grad, make_fx
from functorch.compile import nnc_jit
import torch
import time
def f(x):

View File

@ -17,8 +17,8 @@ import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
from torchvision import models
from opacus import PrivacyEngine
from torchvision import models
from torchvision.datasets import CIFAR10
from tqdm import tqdm
@ -52,7 +52,6 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
@ -279,6 +278,7 @@ def main():
)
logger.info(metrics)
def parse_args():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
@ -309,7 +309,7 @@ def parse_args():
default=256,
type=int,
metavar="N",
help="mini-batch size for test dataset (default: 256)"
help="mini-batch size for test dataset (default: 256)",
)
parser.add_argument(
"--sample-rate",

View File

@ -17,12 +17,12 @@ import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
from torch.func import functional_call, grad_and_value, vmap
from torchvision import models
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch.func import vmap, grad_and_value, functional_call
logging.basicConfig(
format="%(asctime)s:%(levelname)s:%(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
@ -44,12 +44,16 @@ def accuracy(preds, labels):
def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = [
sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads
]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms, batch_size
def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
def clip_and_accumulate_and_add_noise(
model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0
):
sample_grads = tuple(param.grad_sample for param in model.parameters())
# step 0: compute the norms
@ -60,13 +64,16 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise
clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad)
for sample_grad in sample_grads)
grads = tuple(
torch.einsum("i,i...", clip_factor, sample_grad) for sample_grad in sample_grads
)
# step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads)
noises = tuple(
torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads
)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
# step 4: assign the new grads, delete the sample grads
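
Condensing steps 0-3 above into a self-contained sketch for a single parameter (clip_and_noise is illustrative; sample_grad has shape [batch, *param_shape]):

import torch

def clip_and_noise(sample_grad, max_norm=1.0, noise_multiplier=1.0):
    batch_size = sample_grad.shape[0]
    # steps 0-1: per-sample norms and clip factors
    norms = sample_grad.view(batch_size, -1).norm(2, dim=-1)
    clip_factor = (max_norm / (norms + 1e-6)).clamp(max=1.0)
    # step 2: scale each sample's gradient, then sum over the batch
    grad = torch.einsum("i,i...", clip_factor, sample_grad)
    # step 3: add gaussian noise calibrated to the clipping norm
    noise = torch.normal(0, max_norm * noise_multiplier, grad.shape)
    return grad + noise

assert clip_and_noise(torch.randn(8, 4, 4)).shape == (4, 4)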
@ -84,7 +91,6 @@ def train(args, model, train_loader, optimizer, epoch, device):
top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device)
target = target.to(device)
@ -120,8 +126,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
# detaching weights since we don't need to track gradients outside of transforms
# and this is more performant
detached_weights = {k: v.detach() for k, v in weights.items()}
sample_grads, (sample_loss, output) = \
vmap(grads_loss_output, (None, 0, 0))(detached_weights, images, target)
sample_grads, (sample_loss, output) = vmap(grads_loss_output, (None, 0, 0))(
detached_weights, images, target
)
loss = sample_loss.mean()
for name, grad_sample in sample_grads.items():
@ -129,7 +136,8 @@ def train(args, model, train_loader, optimizer, epoch, device):
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
clip_and_accumulate_and_add_noise(
model, args.max_per_sample_grad_norm, args.sigma)
model, args.max_per_sample_grad_norm, args.sigma
)
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy()
@ -270,9 +278,7 @@ def main():
for param_group in optimizer.param_groups:
param_group["lr"] = lr
train_duration = train(
args, model, train_loader, optimizer, epoch, device
)
train_duration = train(args, model, train_loader, optimizer, epoch, device)
top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint
@ -308,6 +314,7 @@ def main():
)
logger.info(metrics)
def parse_args():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument(
@ -338,7 +345,7 @@ def parse_args():
default=256,
type=int,
metavar="N",
help="mini-batch size for test dataset (default: 256)"
help="mini-batch size for test dataset (default: 256)",
)
parser.add_argument(
"--sample-rate",

View File

@ -1,9 +1,10 @@
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, vmap, stack_module_state
from torch.func import functional_call, grad_and_value, stack_module_state, vmap
# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
@ -33,15 +34,21 @@ DEVICE = args.device
# Step 1: Make some spirals
def make_spirals(n_samples, noise_std=0., rotations=1.):
def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
ts = torch.linspace(0, 1, n_samples, device=DEVICE)
rs = ts**0.5
thetas = rs * rotations * 2 * math.pi
signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
labels = (signs > 0).to(torch.long).to(DEVICE)
xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std
xs = (
rs * signs * torch.cos(thetas)
+ torch.randn(n_samples, device=DEVICE) * noise_std
)
ys = (
rs * signs * torch.sin(thetas)
+ torch.randn(n_samples, device=DEVICE) * noise_std
)
points = torch.stack([xs, ys], dim=1)
return points, labels
@ -70,6 +77,7 @@ class MLPClassifier(nn.Module):
loss_fn = nn.NLLLoss()
model = MLPClassifier().to(DEVICE)
def train_step_fn(weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):
output = functional_call(model, weights, batch)
@ -109,6 +117,7 @@ def init_fn(num_models):
params, _ = stack_module_state(models)
return params
# Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps
# on decreasing
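
Condensed sketch of the recipe above: stack the per-model parameters once, then vmap a functional forward pass over the stacked leading dimension (module and shapes here are arbitrary):

import torch
from torch.func import functional_call, stack_module_state, vmap

models = [torch.nn.Linear(2, 2) for _ in range(4)]
params, buffers = stack_module_state(models)
base = torch.nn.Linear(2, 2)  # supplies the module structure only

def fwd(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 2)
out = vmap(fwd, in_dims=(0, 0, None))(params, buffers, x)
assert out.shape == (4, 8, 2)  # one output per ensemble member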

View File

@ -4,11 +4,11 @@
import torch
from torch import nn
from torch.nn.functional import mse_loss
from torch.func import jacrev, vmap
from torch.nn.functional import mse_loss
sigma = 0.5
epsilon = 4.
epsilon = 4.0
def lennard_jones(r):
@ -29,7 +29,9 @@ norms = torch.norm(drs, dim=1).reshape(-1, 1)
# Create training energies
training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1)
# Create forces with random direction vectors
training_forces = torch.stack([force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)])
training_forces = torch.stack(
[force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]
)
model = nn.Sequential(
nn.Linear(1, 16),
@ -40,7 +42,7 @@ model = nn.Sequential(
nn.Tanh(),
nn.Linear(16, 16),
nn.Tanh(),
nn.Linear(16, 1)
nn.Linear(16, 1),
)
@ -54,7 +56,10 @@ def make_prediction(model, drs):
def loss_fn(energies, forces, predicted_energies, predicted_forces):
return mse_loss(energies, predicted_energies) + 0.01 * mse_loss(forces, predicted_forces) / 3
return (
mse_loss(energies, predicted_energies)
+ 0.01 * mse_loss(forces, predicted_forces) / 3
)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

View File

@ -27,38 +27,43 @@ Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
from support.omniglot_loaders import OmniglotNShot
import higher
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse
import time
import pandas as pd
import numpy as np
import higher
import matplotlib as mpl
mpl.use('Agg')
plt.style.use('bmh')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from support.omniglot_loaders import OmniglotNShot
from torch import nn
mpl.use("Agg")
plt.style.use("bmh")
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
"--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
)
argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
"--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
)
argparser.add_argument("--device", type=str, help="device", default="cuda")
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task-num', '--task_num',
"--task-num",
"--task_num",
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
help="meta batch size, namely task num",
default=32,
)
argparser.add_argument("--seed", type=int, help="random seed", default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
@ -69,7 +74,7 @@ def main():
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
"/tmp/omniglot-data",
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
@ -97,7 +102,8 @@ def main():
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(64, args.n_way)).to(device)
nn.Linear(64, args.n_way),
).to(device)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
@ -134,9 +140,10 @@ def train(db, net, device, meta_opt, epoch, log):
qry_accs = []
meta_opt.zero_grad()
for i in range(task_num):
with higher.innerloop_ctx(
net, inner_opt, copy_initial_weights=False
) as (fnet, diffopt):
with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as (
fnet,
diffopt,
):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
@ -153,8 +160,7 @@ def train(db, net, device, meta_opt, epoch, log):
qry_logits = fnet(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(
dim=1) == y_qry[i]).sum().item() / querysz
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
# print([b.shape for b in fnet[1].buffers()])
@ -166,21 +172,23 @@ def train(db, net, device, meta_opt, epoch, log):
meta_opt.step()
qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num
qry_accs = 100.0 * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
log.append(
{
"epoch": i,
"loss": qry_losses,
"acc": qry_accs,
"mode": "train",
"time": time.time(),
}
)
def test(db, net, device, epoch, log):
@ -196,7 +204,7 @@ def test(db, net, device, epoch, log):
qry_accs = []
for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
x_spt, y_spt, x_qry, y_qry = db.next("test")
task_num, setsz, c_, h, w = x_spt.size()
@ -206,7 +214,10 @@ def test(db, net, device, epoch, log):
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
for i in range(task_num):
with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (
fnet,
diffopt,
):
# Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task.
@ -217,24 +228,22 @@ def test(db, net, device, epoch, log):
# The query loss and acc induced by these parameters.
qry_logits = fnet(x_qry[i]).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
log.append(
{
"epoch": epoch + 1,
"loss": qry_losses,
"acc": qry_accs,
"mode": "test",
"time": time.time(),
}
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
@ -243,17 +252,17 @@ def plot(log):
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
train_df = df[df["mode"] == "train"]
test_df = df[df["mode"] == "test"]
ax.plot(train_df["epoch"], train_df["acc"], label="Train")
ax.plot(test_df["epoch"], test_df["acc"], label="Test")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.legend(ncol=2, loc="lower right")
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fname = "maml-accs.png"
print(f"--- Plotting accuracy to {fname}")
fig.savefig(fname)
plt.close(fig)
@ -265,5 +274,5 @@ class Flatten(nn.Module):
return input.view(input.size(0), -1)
if __name__ == '__main__':
if __name__ == "__main__":
main()
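
Schematic of the inner/outer optimization above, shrunk to a single scalar "task" so the two levels are visible (an illustration, not the Omniglot pipeline):

import torch

w = torch.zeros(1, requires_grad=True)      # meta-parameter
meta_opt = torch.optim.Adam([w], lr=1e-2)

def task_loss(param, target):
    return ((param - target) ** 2).sum()

for target in (1.0, -1.0, 0.5):             # stand-ins for tasks
    adapted = w
    for _ in range(3):                      # inner adaptation steps
        (g,) = torch.autograd.grad(
            task_loss(adapted, target), adapted, create_graph=True
        )
        adapted = adapted - 0.1 * g         # differentiable update
    meta_opt.zero_grad()
    task_loss(adapted, target).backward()   # backprop through inner steps
    meta_opt.step()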

View File

@ -27,38 +27,43 @@ Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
from support.omniglot_loaders import OmniglotNShot
from functorch import make_functional_with_buffers
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse
import time
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
plt.style.use('bmh')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers
from support.omniglot_loaders import OmniglotNShot
from torch import nn
mpl.use("Agg")
plt.style.use("bmh")
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
"--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
)
argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
"--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
)
argparser.add_argument("--device", type=str, help="device", default="cuda")
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task-num', '--task_num',
"--task-num",
"--task_num",
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
help="meta batch size, namely task num",
default=32,
)
argparser.add_argument("--seed", type=int, help="random seed", default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
@ -69,7 +74,7 @@ def main():
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
"/tmp/omniglot-data",
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
@ -97,7 +102,8 @@ def main():
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(64, args.n_way)).to(device)
nn.Linear(64, args.n_way),
).to(device)
net.train()
fnet, params, buffers = make_functional_with_buffers(net)
@ -153,8 +159,7 @@ def train(db, net, device, meta_opt, epoch, log):
qry_logits = fnet(new_params, buffers, x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(
dim=1) == y_qry[i]).sum().item() / querysz
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
# Update the model's meta-parameters to optimize the query
@ -164,21 +169,23 @@ def train(db, net, device, meta_opt, epoch, log):
meta_opt.step()
qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num
qry_accs = 100.0 * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
log.append(
{
"epoch": i,
"loss": qry_losses,
"acc": qry_accs,
"mode": "train",
"time": time.time(),
}
)
def test(db, net, device, epoch, log):
@ -194,7 +201,7 @@ def test(db, net, device, epoch, log):
qry_accs = []
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
x_spt, y_spt, x_qry, y_qry = db.next("test")
task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it
@ -211,24 +218,22 @@ def test(db, net, device, epoch, log):
# The query loss and acc induced by these parameters.
qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
log.append(
{
"epoch": epoch + 1,
"loss": qry_losses,
"acc": qry_accs,
"mode": "test",
"time": time.time(),
}
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
@ -237,17 +242,17 @@ def plot(log):
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
train_df = df[df["mode"] == "train"]
test_df = df[df["mode"] == "test"]
ax.plot(train_df["epoch"], train_df["acc"], label="Train")
ax.plot(test_df["epoch"], test_df["acc"], label="Test")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.legend(ncol=2, loc="lower right")
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fname = "maml-accs.png"
print(f"--- Plotting accuracy to {fname}")
fig.savefig(fname)
plt.close(fig)
@ -259,5 +264,5 @@ class Flatten(nn.Module):
return input.view(input.size(0), -1)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -27,39 +27,44 @@ Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
from support.omniglot_loaders import OmniglotNShot
from torch.func import vmap, grad, functional_call
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse
import time
import functools
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
plt.style.use('bmh')
import torch
import torch.nn.functional as F
import torch.optim as optim
from support.omniglot_loaders import OmniglotNShot
from torch import nn
from torch.func import functional_call, grad, vmap
mpl.use("Agg")
plt.style.use("bmh")
def main():
argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5)
argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5)
"--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
)
argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15)
"--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
)
argparser.add_argument("--device", type=str, help="device", default="cuda")
argparser.add_argument(
'--device', type=str, help='device', default='cuda')
argparser.add_argument(
'--task-num', '--task_num',
"--task-num",
"--task_num",
type=int,
help='meta batch size, namely task num',
default=32)
argparser.add_argument('--seed', type=int, help='random seed', default=1)
help="meta batch size, namely task num",
default=32,
)
argparser.add_argument("--seed", type=int, help="random seed", default=1)
args = argparser.parse_args()
torch.manual_seed(args.seed)
@ -70,7 +75,7 @@ def main():
# Set up the Omniglot loader.
device = args.device
db = OmniglotNShot(
'/tmp/omniglot-data',
"/tmp/omniglot-data",
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
@ -95,7 +100,8 @@ def main():
nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2),
nn.Flatten(),
nn.Linear(64, args.n_way)).to(device)
nn.Linear(64, args.n_way),
).to(device)
net.train()
@ -132,8 +138,7 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
# These will be used to update the model's meta-parameters.
qry_logits = functional_call(net, (new_params, buffers), x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry)
qry_acc = (qry_logits.argmax(
dim=1) == y_qry).sum() / querysz
qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
return qry_loss, qry_acc
@ -163,21 +168,23 @@ def train(db, net, device, meta_opt, epoch, log):
meta_opt.step()
qry_losses = qry_losses.detach().sum() / task_num
qry_accs = 100. * qry_accs.sum() / task_num
qry_accs = 100.0 * qry_accs.sum() / task_num
i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
)
log.append({
'epoch': i,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'train',
'time': time.time(),
})
log.append(
{
"epoch": i,
"loss": qry_losses,
"acc": qry_accs,
"mode": "train",
"time": time.time(),
}
)
def test(db, net, device, epoch, log):
@ -194,7 +201,7 @@ def test(db, net, device, epoch, log):
qry_accs = []
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
x_spt, y_spt, x_qry, y_qry = db.next("test")
task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it
@ -207,28 +214,28 @@ def test(db, net, device, epoch, log):
spt_logits = functional_call(net, (new_params, buffers), x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params.values())
new_params = {k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)}
new_params = {
k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)
}
# The query loss and acc induced by these parameters.
qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach()
qry_loss = F.cross_entropy(
qry_logits, y_qry[i], reduction='none')
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
qry_losses.append(qry_loss.detach())
qry_accs.append(
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
print(
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
log.append(
{
"epoch": epoch + 1,
"loss": qry_losses,
"acc": qry_accs,
"mode": "test",
"time": time.time(),
}
)
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log):
@ -237,20 +244,20 @@ def plot(log):
df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train']
test_df = df[df['mode'] == 'test']
ax.plot(train_df['epoch'], train_df['acc'], label='Train')
ax.plot(test_df['epoch'], test_df['acc'], label='Test')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
train_df = df[df["mode"] == "train"]
test_df = df[df["mode"] == "test"]
ax.plot(train_df["epoch"], train_df["acc"], label="Train")
ax.plot(test_df["epoch"], test_df["acc"], label="Test")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right')
fig.legend(ncol=2, loc="lower right")
fig.tight_layout()
fname = 'maml-accs.png'
print(f'--- Plotting accuracy to {fname}')
fname = "maml-accs.png"
print(f"--- Plotting accuracy to {fname}")
fig.savefig(fname)
plt.close(fig)
if __name__ == '__main__':
if __name__ == "__main__":
main()
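
The transforms-based variant above replaces the explicit Python loop over tasks with vmap; a tiny model of that idea using torch.func (the linear loss is illustrative):

import torch
from torch.func import grad, vmap

def loss(theta, x):
    return ((theta * x).sum() - 1.0) ** 2

theta = torch.randn(3)
xs = torch.randn(5, 3)                          # five "tasks"
per_task_grads = vmap(grad(loss), in_dims=(None, 0))(theta, xs)
assert per_task_grads.shape == (5, 3)           # one gradient per task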

View File

@ -17,38 +17,38 @@
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py
import torchvision.transforms as transforms
from PIL import Image
import errno
import os
import os.path
import numpy as np
import torch
import torch.utils.data as data
import os
import os.path
import errno
import torchvision.transforms as transforms
from PIL import Image
class Omniglot(data.Dataset):
urls = [
'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
"https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip",
"https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip",
]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
raw_folder = "raw"
processed_folder = "processed"
training_file = "training.pt"
test_file = "test.pt"
'''
"""
The items are (filename,category). The index of all the categories can be found in self.idx_classes
Args:
- root: the directory where the dataset will be stored
- transform: how to transform the input
- target_transform: how to transform the target
- download: need to download the dataset
'''
"""
def __init__(self, root, transform=None, target_transform=None,
download=False):
def __init__(self, root, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
@ -57,14 +57,16 @@ class Omniglot(data.Dataset):
if download:
self.download()
else:
raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
raise RuntimeError(
"Dataset not found." + " You can use download=True to download it"
)
self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
self.idx_classes = index_classes(self.all_items)
def __getitem__(self, index):
filename = self.all_items[index][0]
img = str.join('/', [self.all_items[index][2], filename])
img = str.join("/", [self.all_items[index][2], filename])
target = self.idx_classes[self.all_items[index][1]]
if self.transform is not None:
@ -78,8 +80,11 @@ class Omniglot(data.Dataset):
return len(self.all_items)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
return os.path.exists(
os.path.join(self.root, self.processed_folder, "images_evaluation")
) and os.path.exists(
os.path.join(self.root, self.processed_folder, "images_background")
)
def download(self):
import urllib
@ -99,15 +104,15 @@ class Omniglot(data.Dataset):
raise
for url in self.urls:
print('== Downloading ' + url)
print("== Downloading " + url)
data = urllib.request.urlopen(url)
filename = url.rpartition('/')[2]
filename = url.rpartition("/")[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
with open(file_path, 'wb') as f:
with open(file_path, "wb") as f:
f.write(data.read())
file_processed = os.path.join(self.root, self.processed_folder)
print("== Unzip from " + file_path + " to " + file_processed)
zip_ref = zipfile.ZipFile(file_path, 'r')
zip_ref = zipfile.ZipFile(file_path, "r")
zip_ref.extractall(file_processed)
zip_ref.close()
print("Download finished.")
@ -115,10 +120,10 @@ class Omniglot(data.Dataset):
def find_classes(root_dir):
retour = []
for (root, dirs, files) in os.walk(root_dir):
for root, dirs, files in os.walk(root_dir):
for f in files:
if (f.endswith("png")):
r = root.split('/')
if f.endswith("png"):
r = root.split("/")
lr = len(r)
retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
print(f"== Found {len(retour)} items ")
@ -135,7 +140,6 @@ def index_classes(items):
class OmniglotNShot:
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
"""
Different from mnistNShot, the
@ -149,41 +153,52 @@ class OmniglotNShot:
self.resize = imgsz
self.device = device
if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
if not os.path.isfile(os.path.join(root, "omniglot.npy")):
# if root/data.npy does not exist, just download it
self.x = Omniglot(
root, download=True,
root,
download=True,
transform=transforms.Compose(
[lambda x: Image.open(x).convert('L'),
[
lambda x: Image.open(x).convert("L"),
lambda x: x.resize((imgsz, imgsz)),
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x / 255.]),
lambda x: x / 255.0,
]
),
)
temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
for (img, label) in self.x:
temp = (
{}
) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
for img, label in self.x:
if label in temp.keys():
temp[label].append(img)
else:
temp[label] = [img]
self.x = []
for label, imgs in temp.items(): # label info discarded; each label contains 20 imgs
for (
label,
imgs,
) in temp.items(): # label info discarded; each label contains 20 imgs
self.x.append(np.array(imgs))
# as different class may have different number of imgs
self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total]
self.x = np.array(self.x).astype(
np.float
) # [[20 imgs],..., 1623 classes in total]
# each character contains 20 imgs
print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1]
print("data shape:", self.x.shape) # [1623, 20, 84, 84, 1]
temp = [] # Free memory
# save all dataset into npy file.
np.save(os.path.join(root, 'omniglot.npy'), self.x)
print('write into omniglot.npy.')
np.save(os.path.join(root, "omniglot.npy"), self.x)
print("write into omniglot.npy.")
else:
# if data.npy exists, just load it.
self.x = np.load(os.path.join(root, 'omniglot.npy'))
print('load from omniglot.npy.')
self.x = np.load(os.path.join(root, "omniglot.npy"))
print("load from omniglot.npy.")
# [1623, 20, 84, 84, 1]
# TODO: can not shuffle here, we must keep training and test set distinct!
@ -200,11 +215,18 @@ class OmniglotNShot:
# save pointer of current read batch in total cache
self.indexes = {"train": 0, "test": 0}
self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached
self.datasets = {
"train": self.x_train,
"test": self.x_test,
} # original data cached
print("DB: train", self.x_train.shape, "test", self.x_test.shape)
self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached
"test": self.load_data_cache(self.datasets["test"])}
self.datasets_cache = {
"train": self.load_data_cache(
self.datasets["train"]
), # current epoch data cached
"test": self.load_data_cache(self.datasets["test"]),
}
def normalization(self):
"""
@ -238,16 +260,15 @@ class OmniglotNShot:
# print('preload next 50 caches of batchsz of batch.')
for sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
for i in range(self.batchsz): # one batch means one set
x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
for j, cur_class in enumerate(selected_cls):
selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
selected_img = np.random.choice(
20, self.k_shot + self.k_query, False
)
# meta-training and meta-test
x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]])
@ -257,10 +278,14 @@ class OmniglotNShot:
# shuffle inside a batch
perm = np.random.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
x_spt = np.array(x_spt).reshape(
self.n_way * self.k_shot, 1, self.resize, self.resize
)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = np.random.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
x_qry = np.array(x_qry).reshape(
self.n_way * self.k_query, 1, self.resize, self.resize
)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
# append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
@ -270,22 +295,30 @@ class OmniglotNShot:
y_qrys.append(y_qry)
# [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
x_spts = (
np.array(x_spts)
.astype(np.float32)
.reshape(self.batchsz, setsz, 1, self.resize, self.resize)
)
y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
# [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
x_qrys = (
np.array(x_qrys)
.astype(np.float32)
.reshape(self.batchsz, querysz, 1, self.resize, self.resize)
)
y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)
x_spts, y_spts, x_qrys, y_qrys = (
torch.from_numpy(z).to(self.device) for z in
[x_spts, y_spts, x_qrys, y_qrys]
torch.from_numpy(z).to(self.device)
for z in [x_spts, y_spts, x_qrys, y_qrys]
)
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
return data_cache
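# A minimal usage sketch of this loader (illustrative values, assuming a
# local "omniglot" root directory; nothing below is taken from this file):
# db = OmniglotNShot("omniglot", batchsz=32, n_way=5, k_shot=1, k_query=15, imgsz=28, device="cpu")
# x_spt, y_spt, x_qry, y_qry = db.next("train")  # support [b, setsz, 1, 28, 28], query [b, qrysz, 1, 28, 28]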
def next(self, mode='train'):
def next(self, mode="train"):
"""
Gets next batch from the dataset with name.
:param mode: The name of the splitting (one of "train", "val", "test")
View File
@ -2,13 +2,15 @@
# (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch.
import matplotlib.pyplot as plt
import math
import torch
import numpy as np
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
mpl.use("Agg")
def net(x, params):
@ -23,13 +25,15 @@ def net(x, params):
params = [
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
torch.Tensor(40, 40)
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
torch.Tensor(1, 40)
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(1).zero_().requires_grad_(),
]
@ -46,17 +50,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5))
phases.append(np.random.uniform(low=0., high=np.pi))
As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
@ -80,14 +85,17 @@ for it in range(20000):
return F.mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K)
inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)]
inner_losses = [
get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
for i in range(num_tasks)
]
loss2 = sum(inner_losses) / len(inner_losses)
loss2.backward()
opt.step()
if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
losses.append(loss2.detach())
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@ -112,11 +120,11 @@ test_y = t_A * torch.sin(test_x + t_b)
test_f = net(test_x, t_params)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend()
plt.savefig('maml-sine.png')
plt.savefig("maml-sine.png")
plt.figure()
plt.plot(np.convolve(losses, [.05] * 20))
plt.savefig('losses.png')
plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig("losses.png")
View File
@ -2,14 +2,16 @@
# (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch.
from torch.func import grad, vmap
import matplotlib.pyplot as plt
import math
import torch
import numpy as np
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.func import grad, vmap
from torch.nn import functional as F
mpl.use("Agg")
def net(params, x):
@ -24,13 +26,15 @@ def net(params, x):
params = [
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
torch.Tensor(40, 40)
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(),
torch.Tensor(1, 40)
.uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(1).zero_().requires_grad_(),
]
@ -54,17 +58,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5))
phases.append(np.random.uniform(low=0., high=np.pi))
As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
@ -94,7 +99,7 @@ for it in range(20000):
opt.step()
if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
losses.append(loss2.detach())
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@ -119,11 +124,11 @@ test_y = t_A * torch.sin(test_x + t_b)
test_f = net(t_params, test_x)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend()
plt.savefig('maml-sine.png')
plt.savefig("maml-sine.png")
plt.figure()
plt.plot(np.convolve(losses, [.05] * 20))
plt.savefig('losses.png')
plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig("losses.png")
View File
@ -2,15 +2,17 @@
# (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch.
from functorch import grad, vmap, make_functional
import matplotlib.pyplot as plt
import math
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from functorch import grad, make_functional, vmap
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
mpl.use("Agg")
class ThreeLayerNet(nn.Module):
@ -30,6 +32,7 @@ class ThreeLayerNet(nn.Module):
x = self.fc3(x)
return x
# TODO: Use F.mse_loss
@ -51,17 +54,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
As = []
phases = []
for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5))
phases.append(np.random.uniform(low=0., high=np.pi))
As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
y = A * np.sin(x + phase)
xs.append(x)
ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
@ -91,7 +95,7 @@ for it in range(20000):
opt.step()
if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
losses.append(loss2.detach())
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@ -116,11 +120,11 @@ test_y = t_A * torch.sin(test_x + t_b)
test_f = net(t_params, test_x)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend()
plt.savefig('maml-sine.png')
plt.savefig("maml-sine.png")
plt.figure()
plt.plot(np.convolve(losses, [.05] * 20))
plt.savefig('losses.png')
plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig("losses.png")
View File
@ -1,5 +1,6 @@
# PyTorch forward-mode is not mature yet
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp
from torch._functorch.vmap import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from functorch import functionalize
View File
@ -1,26 +1,31 @@
from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import StorageWeakRef
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dynamo.exc import CondOpArgsMismatchError
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
_wrap_all_tensors_to_functional,
functionalize,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten
from torch._dynamo.exc import CondOpArgsMismatchError
@dataclass
@ -34,9 +39,14 @@ In order to do this, we need implementations for each of the dispatch keys.
"""
cond = HigherOrderOperator("cond")
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
assert isinstance(
operands, (list, tuple)
), "Cond operands must be a list or tuple of tensors"
assert all(
isinstance(o, torch.Tensor) for o in operands
), "Cond operands must be a list of tensors"
with disable_proxy_modes_tracing():
true_graph = make_fx(true_fn)(*operands)
@ -45,11 +55,11 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
true_outs = []
false_outs = []
for node in true_graph.graph.nodes:
if node.op == 'output':
if node.op == "output":
true_outs.extend(node.args)
for node in false_graph.graph.nodes:
if node.op == 'output':
if node.op == "output":
false_outs.extend(node.args)
flat_true_outs, _ = pytree.tree_flatten(true_outs)
@ -64,7 +74,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
raise CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
@ -85,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
true_name = next_name
false_name = f"false_graph_{i}"
assert(not hasattr(proxy_mode.tracer.root, false_name))
assert not hasattr(proxy_mode.tracer.root, false_name)
proxy_mode.tracer.root.register_module(true_name, true_graph)
proxy_mode.tracer.root.register_module(false_name, false_graph)
@ -94,8 +104,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="conditional")
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="conditional"
)
# At this point, we're *guaranteed* that whether an output came from the
# true or false branch is indistinguishable. So, as this is just for tracing
@ -112,7 +123,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@cond.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_dense(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
if pred:
return true_fn(*operands)
else:
@ -125,8 +136,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
requires_grad = any(
isinstance(arg, torch.Tensor) and arg.requires_grad
for arg in flat_operands
isinstance(arg, torch.Tensor) and arg.requires_grad for arg in flat_operands
)
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
@ -148,6 +158,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
var = var.detach()
var.requires_grad = True
return var
return err_fn(fake_requires_grad(result))
return result
@ -156,7 +167,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
assert mode is not None, "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode:
if mode.enable_tracing:
return trace_cond(mode, cond, pred, true_fn, false_fn, operands)
@ -177,7 +188,8 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
false_meta = _extract_tensor_metadata(false_out)
if true_meta != false_meta:
raise RuntimeError(
f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}")
f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}"
)
return true_outs
@ -203,7 +215,10 @@ def _has_potential_branch_input_mutation(branch, inputs):
input_nodes.add(node)
if node.op == "call_function":
target = node.target
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
if (
isinstance(target, torch._ops.OpOverload)
and target._schema.is_mutable
):
for arg in node.args:
if arg in input_nodes:
return True
@ -241,13 +256,15 @@ def _has_potential_branch_input_alias(branch, inputs):
# for map operator, where num_mapped_args is a scalar
# and doesn't have a "val" meta.
if node.op == "placeholder" and "val" in node.meta:
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
if node.op == "output":
def check_alias(out):
if out is not None and "val" in out.meta:
out_storage = StorageWeakRef(out.meta['val']._typed_storage())
out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
return out_storage in input_storages
return False
if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]):
return True
@ -263,22 +280,30 @@ def _has_potential_branch_input_alias(branch, inputs):
@cond.py_impl(DispatchKey.Functionalize)
def cond_func(pred, true_fn, false_fn, inputs):
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
unwrapped_inputs = _unwrap_all_tensors_from_functional(
inputs, reapply_views=reapply_views
)
unwrapped_pred = _unwrap_all_tensors_from_functional(
pred, reapply_views=reapply_views
)
mode = "mutations_and_views" if reapply_views else "mutations"
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_true = functionalize(true_fn, remove=mode)
functional_false = functionalize(false_fn, remove=mode)
for branch in [true_fn, false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be modifying the input!")
raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be modifying the input!"
)
if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be aliasing the input!")
raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be aliasing the input!"
)
cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
cond_return = cond(
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
)
return _wrap_all_tensors_to_functional(cond_return, level=0)
@ -290,10 +315,14 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
2. Our check for above condition is not exhaustive
"""
reapply_views = interpreter.functionalize_add_back_views()
mode = 'mutations_and_views' if reapply_views else 'mutations'
mode = "mutations_and_views" if reapply_views else "mutations"
# At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
unwrapped_inputs = _unwrap_all_tensors_from_functional(
inputs, reapply_views=reapply_views
)
unwrapped_pred = _unwrap_all_tensors_from_functional(
pred, reapply_views=reapply_views
)
functional_true_fn = functionalize(true_fn, remove=mode)
functional_false_fn = functionalize(false_fn, remove=mode)
@ -301,16 +330,21 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
with interpreter.lower():
for branch in [functional_true_fn, functional_false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be modifying the input!")
raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be modifying the input!"
)
for branch in [true_fn, false_fn]:
if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch "
"might be aliasing the input!")
raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be aliasing the input!"
)
cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
cond_return = cond(
unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs
)
return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonDispatcher)
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
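# A minimal sketch of calling the operator defined above, assuming the
# two-line re-export in functorch.experimental.control_flow shown later in
# this change. In eager mode the call should reach cond_dense, so pred may
# be a bool or a boolean scalar tensor, and both branches must accept the
# same operands and return tensors with matching metadata.
import torch
from functorch.experimental.control_flow import cond

x = torch.randn(4)
out = cond(x.sum() > 0, lambda t: t.sin(), lambda t: t.cos(), [x])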
View File
@ -1,23 +1,32 @@
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._functorch.aot_autograd import create_joint, AOTConfig
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
_wrap_all_tensors_to_functional,
functionalize,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
from torch._dispatch.python import suspend_functionalization
from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException
from ._cond import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
UnsupportedAliasMutationException,
)
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
@ -26,16 +35,19 @@ class MapWrapper(HigherOrderOperator):
def __call__(self, xs, *args):
return map_wrapper(xs, *args)
map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
dummy_aot_config = AOTConfig(fw_compiler=None,
dummy_aot_config = AOTConfig(
fw_compiler=None,
bw_compiler=None,
partition_fn=None,
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False)
keep_inference_input_mutations=False,
)
def create_fw_bw_graph(f, num_mapped_args, *args):
@ -59,20 +71,33 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
with suspend_functionalization():
with disable_proxy_modes_tracing():
def from_fun(t):
if isinstance(t, torch.Tensor):
return torch.empty_strided(t.size(), t.stride(), requires_grad=t.requires_grad)
return torch.empty_strided(
t.size(), t.stride(), requires_grad=t.requires_grad
)
return t
example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
example_pos_args = [from_fun(arg) if isinstance(arg, torch.Tensor) else arg for arg in pos_args]
example_flat_out = pytree.tree_map(from_fun, f(*example_xs, *example_pos_args))
if any(not isinstance(out, torch.Tensor) for out in example_flat_out if out is not None):
raise RuntimeError("Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in example_flat_out]}.")
example_pos_args = [
from_fun(arg) if isinstance(arg, torch.Tensor) else arg
for arg in pos_args
]
example_flat_out = pytree.tree_map(
from_fun, f(*example_xs, *example_pos_args)
)
if any(
not isinstance(out, torch.Tensor)
for out in example_flat_out
if out is not None
):
raise RuntimeError(
"Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in example_flat_out]}."
)
example_grad = [from_fun(out) for out in example_flat_out]
fw_graph = make_fx(f)(*example_xs, *example_pos_args)
def joint_f(*example_args):
@ -84,20 +109,39 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
def fw_with_masks(*args):
fw_out = f(*args)
return fw_out, [True if isinstance(ret, torch.Tensor) and ret.requires_grad else False for ret in fw_out]
return fw_out, [
True
if isinstance(ret, torch.Tensor) and ret.requires_grad
else False
for ret in fw_out
]
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
_, grads = joint(list(mapped_input) + list(args),
[grad for grad in mapped_grads if grad is not None and grad.requires_grad])
_, grads = joint(
list(mapped_input) + list(args),
[
grad
for grad in mapped_grads
if grad is not None and grad.requires_grad
],
)
# In order to keep map functional for backward graph,
# we clone outputs that are aliasing inputs
input_storage = {StorageWeakRef(arg._typed_storage()) for arg in example_args if isinstance(arg, torch.Tensor)}
input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in example_args
if isinstance(arg, torch.Tensor)
}
def maybe_clone(t):
if isinstance(t, torch.Tensor) and StorageWeakRef(t._typed_storage()) in input_storage:
if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone()
return t
return pytree.tree_map(maybe_clone, grads)
joint_num_mapped = len(example_grad) + len(example_xs)
@ -114,12 +158,12 @@ def map_wrapper(f, xs, *args):
shapes = [xs.shape for xs in flat_xs]
leading_dim_size = shapes[0][0]
if leading_dim_size == 0:
raise RuntimeError(
"Leading dimensions of mapped xs cannot be 0.")
raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
raise RuntimeError(
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}.")
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
)
out_spec = None
@ -131,7 +175,11 @@ def map_wrapper(f, xs, *args):
nonlocal out_spec
out_spec = tmp_out_spec
return flat_out
return pytree.tree_unflatten(map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec)
return pytree.tree_unflatten(
map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
)
class MapAutogradOp(torch.autograd.Function):
@staticmethod
@ -148,9 +196,16 @@ class MapAutogradOp(torch.autograd.Function):
fw_mapped_args = fw_args[: ctx._num_mapped_args]
pos_args = fw_args[ctx._num_mapped_args :]
grads = map_impl(ctx._joint_graph, ctx._num_mapped_args + len(flat_grads), *fw_mapped_args, *flat_grads, *pos_args)
grads = map_impl(
ctx._joint_graph,
ctx._num_mapped_args + len(flat_grads),
*fw_mapped_args,
*flat_grads,
*pos_args,
)
return None, None, None, *grads
def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
xs = list(args[:num_mapped])
pos_args = list(args[num_mapped:])
@ -168,6 +223,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
if isinstance(t, torch.Tensor):
return t.expand(leading_dim_size, *t.shape)
return t
expanded_outs = pytree.tree_map(expand_tensor, example_outs)
next_name = None
@ -182,9 +238,13 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
proxy_mode.tracer.root.register_module(next_name, body_graph)
node_args = (body_graph, num_mapped, *args)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="map_impl")
return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="map_impl"
)
return track_tensor_tree(
expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
)
def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
@ -192,7 +252,9 @@ def _unstack_pytree(xs):
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}")
raise RuntimeError(
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
)
a = zip(*flat_xs)
pytrees = []
@ -200,6 +262,7 @@ def _unstack_pytree(xs):
pytrees.append(pytree.tree_unflatten(tuple, inspec))
return pytrees
def _stack_pytree(pytrees):
flat_out = []
out_spec = None
@ -220,6 +283,7 @@ def _stack_pytree(pytrees):
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)
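# A small round-trip illustration of the two helpers above (hypothetical
# values): _unstack_pytree splits every leaf along its leading dimension,
# and _stack_pytree reassembles the resulting list of pytrees.
# xs = {"a": torch.ones(3, 2), "b": torch.zeros(3)}
# chunks = _unstack_pytree(xs)   # 3 dicts with leaves of shape (2,) and ()
# assert _stack_pytree(chunks)["a"].shape == (3, 2)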
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
xs = args[:num_mapped_args]
@ -240,7 +304,7 @@ def map_autograd(f, num_mapped_args, *args):
@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
assert mode is not None, "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode:
if mode.enable_tracing:
return trace_map(mode, map_impl, f, num_mapped, *args)
@ -259,8 +323,10 @@ def map_func(f, num_mapped, *args):
xs = args[:num_mapped]
pos_args = args[num_mapped:]
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
unwrapped_args = _unwrap_all_tensors_from_functional(
pos_args, reapply_views=reapply_views
)
mode = "mutations_and_views" if reapply_views else "mutations"
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_map_fn = functionalize(f, remove=mode)
@ -268,18 +334,17 @@ def map_func(f, num_mapped, *args):
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
map_return = map_impl(
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
)
return _wrap_all_tensors_to_functional(map_return, level=0)
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
"""
@ -290,10 +355,12 @@ def map_functionalize(interpreter, f, num_mapped, *args):
xs = args[:num_mapped]
pos_args = args[num_mapped:]
reapply_views = interpreter.functionalize_add_back_views()
mode = 'mutations_and_views' if reapply_views else 'mutations'
mode = "mutations_and_views" if reapply_views else "mutations"
# At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(
pos_args, reapply_views=reapply_views
)
functional_map_fn = functionalize(f, remove=mode)
@ -301,18 +368,17 @@ def map_functionalize(interpreter, f, num_mapped, *args):
with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is mutating the input!"
)
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException(
"torch.map is aliasing the input!"
)
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
map_return = map_impl(
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
)
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
# TODO(voz) Make this automatic for keys, this is very ugly atm
map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
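# A minimal sketch of the wrapper defined above, assuming the re-export in
# functorch.experimental.control_flow: f runs once per slice of xs along the
# leading dimension, with the extra args passed unchanged to every call.
import torch
from functorch.experimental.control_flow import map as pytorch_map  # alias is ours

xs = torch.randn(3, 4)
bias = torch.randn(4)
out = pytorch_map(lambda row, b: (row + b).sin(), xs, bias)  # shape (3, 4)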
View File
@ -1,2 +1,2 @@
from ._map import map # noqa: F401
from ._cond import cond, UnsupportedAliasMutationException # noqa: F401
from ._map import map # noqa: F401
View File
@ -19,8 +19,10 @@ Let's demonstrate how to do this using an ensemble of simple CNNs.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple CNN
class SimpleCNN(nn.Module):
def __init__(self):
@ -44,11 +46,12 @@ class SimpleCNN(nn.Module):
output = x
return output
# Let's generate some dummy data. Pretend that we're working with an MNIST dataset
# where the images are 28 by 28.
# Furthermore, let's say we wish to combine the predictions from 10 different
# models.
device = 'cuda'
device = "cuda"
num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)
@ -81,6 +84,7 @@ predictions2 = [model(minibatch) for model in models]
# functorch offers the following convenience function to do that. It returns a
# stateless version of the model (fmodel) and stacked parameters and buffers.
from functorch import combine_state_for_ensemble
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]
@ -92,15 +96,20 @@ fmodel, params, buffers = combine_state_for_ensemble(models)
print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6)
assert torch.allclose(
predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
)
# Option 2: get predictions using the same minibatch of data
# vmap has an in_dims arg that specifies which dimensions to map over.
# Using ``None``, we tell vmap we want the same minibatch to apply for all of
# the 10 models.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6)
assert torch.allclose(
predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6
)
# A quick note: there are limitations around what types of functions can be
# transformed by vmap. The best functions to transform are ones that are
View File
@ -8,11 +8,14 @@ deep learning models. It is difficult (or annoying) to compute these quantities
efficiently using a standard autodiff system like PyTorch Autograd; functorch
provides ways of computing various higher-order autodiff quantities efficiently.
"""
from functools import partial
import torch
import torch.nn.functional as F
from functools import partial
torch.manual_seed(0)
######################################################################
# Setup: Comparing functorch vs the naive approach
# --------------------------------------------------------------------
@ -21,6 +24,7 @@ torch.manual_seed(0)
def predict(weight, bias, x):
return F.linear(x, weight, bias).tanh()
# Here's some dummy data: a weight, a bias, and a feature vector.
D = 16
weight = torch.randn(D, D)
@ -34,19 +38,24 @@ x = torch.randn(D)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
jacobian_rows = [
torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors
]
return torch.stack(jacobian_rows)
jacobian = compute_jac(xp)
# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
# of the for-loop and vectorize the computation. We can't directly apply vmap
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vmap, vjp
from functorch import vjp, vmap
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
(ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)
# In another tutorial a composition of reverse-mode AD and vmap gave us
@ -59,6 +68,7 @@ assert torch.allclose(ft_jacobian, jacobian)
# argument that says which argument we would like to compute Jacobians with
# respect to.
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)
@ -67,6 +77,7 @@ assert torch.allclose(ft_jacobian, jacobian)
# there are). In general, we expect that vectorization via ``vmap`` can help
# eliminate overhead and give better utilization of your hardware.
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(without_vmap.timeit(500))
@ -95,7 +106,7 @@ ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
# In reverse-mode AD, we are computing the jacobian row-by-row, while in
# forward-mode AD (which computes Jacobian-vector products), we are computing
# it column-by-column. The Jacobian matrix has M rows and N columns.
from functorch import jacrev, jacfwd
from functorch import jacfwd, jacrev
# Benchmark with more inputs than outputs
Din = 32
@ -106,8 +117,8 @@ x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f'jacfwd time: {using_fwd.timeit(500)}')
print(f'jacrev time: {using_bwd.timeit(500)}')
print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f"jacrev time: {using_bwd.timeit(500)}")
# Benchmark with more outputs than inputs
Din = 2048
@ -118,8 +129,8 @@ x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f'jacfwd time: {using_fwd.timeit(500)}')
print(f'jacrev time: {using_bwd.timeit(500)}')
print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f"jacrev time: {using_bwd.timeit(500)}")
######################################################################
# Hessian computation with functorch.hessian
@ -132,6 +143,7 @@ print(f'jacrev time: {using_bwd.timeit(500)}')
# Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or
# ``jacrev(jacrev(f))`` instead to compute hessians.
from functorch import hessian
# # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
# hess0 = hessian(predict, argnums=2)(weight, bias, x)
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
@ -148,9 +160,11 @@ hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
# The easiest way to do this is to sum over the batch dimension and then
# compute the Jacobian of that function:
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_size = 64
Din = 31
Dout = 33
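# A sketch of how predict_with_output_summed would be evaluated on batched
# inputs; the tensor setup below is an assumption mirroring the earlier
# sections, not text from this file. Summing over the batch works because
# each output row depends only on its own input row, so the Jacobian of the
# summed output recovers the per-sample Jacobians.
# weight = torch.randn(Dout, Din)
# bias = torch.randn(Dout)
# x = torch.randn(batch_size, Din)
# batch_jacobian = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x)
# assert batch_jacobian.shape == (Dout, batch_size, Din)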
View File
@ -12,8 +12,10 @@ and optimization research.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple CNN
class SimpleCNN(nn.Module):
def __init__(self):
@ -37,12 +39,14 @@ class SimpleCNN(nn.Module):
output = x
return output
def loss_fn(predictions, targets):
return F.nll_loss(predictions, targets)
# Let's generate a batch of dummy data. Pretend that we're working with an
# MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64.
device = 'cuda'
device = "cuda"
num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
@ -56,6 +60,7 @@ predictions = model(data)
loss = loss_fn(predictions, targets)
loss.backward()
# Conceptually, per-sample-gradient computation is equivalent to: for each sample
# of the data, perform a forward and a backward pass to get a gradient.
def compute_grad(sample, target):
@ -65,12 +70,14 @@ def compute_grad(sample, target):
loss = loss_fn(prediction, target)
return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets):
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads
per_sample_grads = compute_sample_grads(data, targets)
# sample_grads[0] is the per-sample-grad for model.conv1.weight
@ -85,9 +92,11 @@ print(per_sample_grads[0].shape)
# We can compute per-sample-gradients efficiently by using function transforms.
# First, let's create a stateless functional version of ``model`` by using
# ``functorch.make_functional_with_buffers``.
from functorch import make_functional_with_buffers, vmap, grad
from functorch import grad, make_functional_with_buffers, vmap
fmodel, params, buffers = make_functional_with_buffers(model)
# Next, let's define a function to compute the loss of the model given a single
# input rather than a batch of inputs. It is important that this function accepts the
# parameters, the input, and the target, because we will be transforming over them.
@ -100,6 +109,7 @@ def compute_loss(params, buffers, sample, target):
loss = loss_fn(predictions, targets)
return loss
# Now, let's use ``grad`` to create a new function that computes the gradient
# with respect to the first argument of compute_loss (i.e. the params).
ft_compute_grad = grad(compute_loss)
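# A sketch of the step that would naturally follow, assuming the names above:
# vmap the gradient function over the batched sample/target dims while
# holding params and buffers fixed via in_dims=None.
# ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
# ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)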
View File
@ -1,8 +1,9 @@
import yaml
import csv
import torch
from collections import defaultdict
import torch
import yaml
def get_ops_for_key(key):
# Needs modified PyTorch C++ code to work
@ -12,7 +13,7 @@ def get_ops_for_key(key):
ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
cleaned_ops = []
for i in ops:
if 'aten::' not in i:
if "aten::" not in i:
continue
cleaned_ops.append(i[6:].strip())
return set(cleaned_ops)
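# Illustrative call only, assuming the modified build mentioned in the
# comment above; the key name is a standard dispatch key.
# composite = get_ops_for_key("CompositeImplicitAutograd")
# print(f"{len(composite)} ops decompose via CompositeImplicitAutograd")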
@ -20,12 +21,17 @@ def get_ops_for_key(key):
def gen_data(special_op_lists, analysis_name):
all_ops = get_ops_for_key(None)
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
composite_ops = get_ops_for_key("CompositeImplicitAutograd")
noncomposite_ops = all_ops - composite_ops
ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml').read(), Loader=yaml.CLoader)
ops = yaml.load(
open("../../aten/src/ATen/native/native_functions.yaml").read(),
Loader=yaml.CLoader,
)
annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
annotated_ops = {
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
}
from collections import defaultdict
uniq_ops = []
@ -33,18 +39,18 @@ def gen_data(special_op_lists, analysis_name):
overload_types = defaultdict(list)
cnt = 0
for op in ops:
func_str = op['func']
name = func_str[:func_str.index('(')]
if '.' in name:
uniq_name = name[:name.index('.')]
overload_types[name[name.index('.') + 1:]].append(name)
func_str = op["func"]
name = func_str[: func_str.index("(")]
if "." in name:
uniq_name = name[: name.index(".")]
overload_types[name[name.index(".") + 1 :]].append(name)
else:
uniq_name = name
op['name'] = uniq_name
full_name = func_str[:func_str.index('(')]
op['full_name'] = full_name
ret_type = func_str[func_str.index('->') + 3:]
op['ret_type'] = ret_type
op["name"] = uniq_name
full_name = func_str[: func_str.index("(")]
op["full_name"] = full_name
ret_type = func_str[func_str.index("->") + 3 :]
op["ret_type"] = ret_type
cnt += 1
if uniq_name in uniq_names:
continue
@ -54,70 +60,78 @@ def gen_data(special_op_lists, analysis_name):
def annotate_ops(ops, is_unique):
categorization = defaultdict(int)
for op in ops:
if op['name'][-1] == '_':
categorization['inplace'] += 1
op['meta'] = 'inplace'
if op["name"][-1] == "_":
categorization["inplace"] += 1
op["meta"] = "inplace"
continue
if not is_unique and 'a!' in op['func'].lower():
categorization['out'] += 1
op['meta'] = 'out'
if not is_unique and "a!" in op["func"].lower():
categorization["out"] += 1
op["meta"] = "out"
continue
if 'conv' in op['name']:
categorization['conv'] += 1
op['meta'] = 'conv'
if "conv" in op["name"]:
categorization["conv"] += 1
op["meta"] = "conv"
continue
if 'pool' in op['name']:
categorization['pool'] += 1
op['meta'] = 'pool'
if "pool" in op["name"]:
categorization["pool"] += 1
op["meta"] = "pool"
continue
if 'backward' in op['name']:
categorization['backward'] += 1
op['meta'] = 'backward'
if "backward" in op["name"]:
categorization["backward"] += 1
op["meta"] = "backward"
continue
if op['name'][0] == '_' and op['name'][1] != '_':
categorization['private'] += 1
op['meta'] = 'private'
if op["name"][0] == "_" and op["name"][1] != "_":
categorization["private"] += 1
op["meta"] = "private"
continue
if 'batch_norm' in op['name']:
categorization['batch_norm'] += 1
op['meta'] = 'batch_norm'
if "batch_norm" in op["name"]:
categorization["batch_norm"] += 1
op["meta"] = "batch_norm"
continue
if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']:
categorization['non_tensor'] += 1
op['meta'] = 'non_tensor'
if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
categorization["non_tensor"] += 1
op["meta"] = "non_tensor"
continue
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or \
'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
categorization['backend'] += 1
op['meta'] = 'backend'
if (
"cudnn" in op["name"]
or "mkldnn" in op["name"]
or "miopen" in op["name"]
or "native" in op["name"]
or "thnn" in op["name"]
or "slow" in op["name"]
):
categorization["backend"] += 1
op["meta"] = "backend"
continue
if op['name'] in annotated_ops:
categorization['core'] += 1
op['meta'] = 'core ' + annotated_ops[op['name']]
if op["name"] in annotated_ops:
categorization["core"] += 1
op["meta"] = "core " + annotated_ops[op["name"]]
continue
categorization['core'] += 1
op['meta'] = 'core unknown'
categorization["core"] += 1
op["meta"] = "core unknown"
return categorization
annotate_ops(ops, is_unique=False)
with open(f"{analysis_name}", 'w') as f:
with open(f"{analysis_name}", "w") as f:
for op in ops:
info = [
op['full_name'], op['meta'], op['full_name'] not in noncomposite_ops
op["full_name"],
op["meta"],
op["full_name"] not in noncomposite_ops,
] + [check(op) for check in special_op_lists]
f.write(','.join([str(i) for i in info]) + '\n')
f.write(",".join([str(i) for i in info]) + "\n")
def name_check(lst):
return lambda x: x['name'] in lst
return lambda x: x["name"] in lst
def full_name_check(lst):
return lambda x: x['full_name'] in lst
return lambda x: x["full_name"] in lst
# Generates batching rule data
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt')
gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")
def remove_suffix(input_string, suffix):
@ -125,6 +139,7 @@ def remove_suffix(input_string, suffix):
return input_string[: -len(suffix)]
return input_string
def remove_prefix(input_string, prefix):
if prefix and input_string.startswith(prefix):
return input_string[len(prefix) :]
@ -132,26 +147,36 @@ def remove_prefix(input_string, prefix):
if True:
with open('run_ops.txt') as f:
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
with open('count_ops.txt') as f:
with open("run_ops.txt") as f:
opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
with open("count_ops.txt") as f:
opinfo_counts = [i.strip() for i in f.readlines()]
opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
def count_fn(x):
return opinfo_counts[x['full_name']]
return opinfo_counts[x["full_name"]]
with open('run_decompositions.txt') as f:
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
with open("run_decompositions.txt") as f:
decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
with open('public_api') as f:
with open("public_api") as f:
ref_api = [i.strip() for i in f.readlines()]
def has_ref_impl(x):
name = x['name']
name = x["name"]
for prefix in ["linalg_", "special_"]:
name = remove_prefix(name, prefix)
prefixes = ['nn.functional', 'fft', 'special', 'linalg']
return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
prefixes = ["nn.functional", "fft", "special", "linalg"]
return (
any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
)
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt')
gen_data(
[
full_name_check(opinfo_ops),
full_name_check(decomposed_ops),
count_fn,
has_ref_impl,
],
"decompositions.txt",
)