[functorch] Added python key stuff that builds with master

Horace He
2021-05-01 20:22:52 -07:00
committed by Jon Janzen
parent 9d36895a83
commit 98e995c467
7 changed files with 665 additions and 4 deletions

View File

@@ -4,7 +4,8 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
from ._src.python_key import wrap_key, WrapModule
from ._src.python_key import wrap_key, WrapModule, PythonTensor, pythonkey_trace
from ._src.nnc_compile import nnc_compile
# Monkeypatching lol
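# The new exports compose roughly like this (a sketch, not part of this file:
# wrap_key attaches PythonTensors to the traced proxies, pythonkey_trace
# records every dispatched op as an FX node, and nnc_compile lowers the
# resulting graph):
#
#   from functorch import pythonkey_trace, wrap_key, nnc_compile
#
#   def f(x):
#       return torch.sin(x) * torch.cos(x)
#
#   inps = (torch.randn(5000),)
#   traced = pythonkey_trace(wrap_key(f, inps))  # FX graph of aten calls
#   compiled = nnc_compile(traced, inps)
#   out = compiled(*inps)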

View File

@@ -0,0 +1,287 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._C._te as te
import torch.fx as fx
from torch.fx import map_arg
from torch.fx.passes.shape_prop import ShapeProp
import operator
def truncate(model, k):
model = fx.symbolic_trace(model)
new_graph = fx.Graph()
env = {}
cnt = 0
for node in list(model.graph.nodes):
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
cnt += 1
if cnt == k:
new_graph.output(env[node.name])
break
return fx.GraphModule(model, new_graph)
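# Example (sketch): bisect a failing lowering by keeping only the first k
# nodes of the traced graph (placeholders count toward k); `my_model` is any
# FX-traceable nn.Module:
#   head = truncate(my_model, 4)
#   print(head.graph)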
# NNC Lowering Pass
def remove_args(model: torch.nn.Module, args):
fx_model = fx.symbolic_trace(model)
for node in fx_model.graph.nodes:
if node.op == 'placeholder' and node.target in args:
assert(len(node.users) == 0)
fx_model.graph.erase_node(node)
fx_model.recompile()
return fx_model
class kernel_arena_scope(object):
def __enter__(self):
self.scope = te.KernelScope()
def __exit__(self, typ, val, traceback):
self.scope = None
def get_dim_args(dims):
dim_args = []
for dim in dims:
dim_args.append(te.DimArg(te.ExprHandle.int(dim), 'i' + str(len(dim_args))))
return dim_args
def get_te_shapes(shape):
return [te.ExprHandle.int(i) for i in shape]
def to_expr(x):
if isinstance(x, int):
return te.ExprHandle.int(x)
elif isinstance(x, float):
return te.ExprHandle.float(x)
else:
raise RuntimeError(f"type {type(x)} not supported")
def get_nnc_type(dtype):
if dtype == torch.float:
return te.Dtype.Float
elif dtype == torch.long:
return te.Dtype.Long
elif dtype == torch.float64:
return te.Dtype.Double
else:
raise RuntimeError("nyi")
lowering_functions = { }
def index_or_broadcast(shape, *args):
out = []
for idx, arg in enumerate(args):
if idx >= len(shape): continue
if shape[idx] == 1:
out.append(to_expr(0))
else:
out.append(arg)
return out
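# e.g. indexing a shape-[1, N] buffer at (i, j) while broadcasting the first
# dim: index_or_broadcast([1, N], i, j) -> [to_expr(0), j]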
def ones_like_lower(name, out_shape, inp_shapes, args):
def f(*idxs):
return to_expr(1.0)
res = te.Compute(name, get_dim_args(out_shape), f)
return res
# def select_lower(name, out_shape, inp_shapes, args):
#     A = args[0]
#     dim = args[1]
#     idx = args[2]
#     def f(*idxs):
#         idxs = list(idxs)
#         idxs.insert(dim, to_expr(idx))
#         return A.load(idxs)
#     res = te.Compute(name, get_dim_args(out_shape), f)
#     return res
def dot_lower(name, out_shape, inp_shapes, args):
mul_te = te.lower('aten::mul', list(args), get_te_shapes(inp_shapes[0][0]), get_nnc_type(inp_shapes[0][1]))
res = te.lower('aten::sum', [mul_te.buf()], get_te_shapes(out_shape), get_nnc_type(inp_shapes[0][1]))
return (res.buf(), [mul_te.stmt(), res.stmt()])
def mv_lower(name, out_shape, inp_shapes, args):
A = args[0]
B = args[1]
N, M = inp_shapes[0][0]
def f(n, m):
return A.load([n, m]) * B.load([m])
# mm = te.Compute('mm', get_dim_args([N,M]), f)
# out = te.Reduce(name, get_dim_args([N]), te.Sum(), mm, get_dim_args([M]))
# return out.buf(), [mm.stmt(), out.stmt()]
C = torch._C._te.BufHandle('C', get_te_shapes([N]), get_nnc_type(inp_shapes[0][1]))
s = torch._C._te.ExternalCall(C, "nnc_aten_mv", [A, B], [])
return C, [s]
lowering_functions[torch.ops.aten.ones_like] = ones_like_lower
lowering_functions[torch.ops.aten.dot] = dot_lower
# lowering_functions[torch.ops.aten.select] = select_lower
lowering_functions[torch.ops.aten.mv] = mv_lower
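# A new lowering can be registered the same way: it receives the node name,
# the (processed) output shape, the input (shape, dtype) pairs, and the NNC
# buffer arguments, and returns either a te.Tensor or a (buf, [stmts]) pair.
# Hypothetical sketch (aten.neg is not actually wired up in this commit):
#
# def neg_lower(name, out_shape, inp_shapes, args):
#     A = args[0]
#     def f(*idxs):
#         return to_expr(-1.0) * A.load(list(idxs))
#     return te.Compute(name, get_dim_args(out_shape), f)
#
# lowering_functions[torch.ops.aten.neg] = neg_lower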
func_to_aten = {
operator.getitem: torch.ops.aten.slice,
operator.add: torch.ops.aten.add,
operator.mul: torch.ops.aten.mul,
torch.mul: torch.ops.aten.mul,
torch.sin: torch.ops.aten.sin,
torch.cos: torch.ops.aten.cos,
}
def process_shape(x):
if len(x) == 0:
return [1]
return x
def lower_function(node, op, nnc_args, args):
    inp_shapes = fx.node.map_aggregate(
        args,
        lambda arg: (process_shape(arg.meta['tensor_meta'].shape), arg.meta['tensor_meta'].dtype)
        if isinstance(arg, fx.Node) and 'tensor_meta' in arg.meta else None)
if op in lowering_functions:
out = lowering_functions[op](node.name, process_shape(node.meta['tensor_meta'].shape), inp_shapes, nnc_args)
else:
if op in func_to_aten:
op = func_to_aten[op]
aten_str = f'aten::{op.__name__}'
out_shape = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
out = te.lower(aten_str, list(nnc_args), out_shape, get_nnc_type(node.meta['tensor_meta'].dtype))
if isinstance(out, te.Tensor):
return out.buf(), [out.stmt()]
else:
return out[0], out[1]
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
"""
nnc_compile(model, example_inputs) returns a function with the same args
as `model.forward`, with an extra argument corresponding to where the
output is stored. This function takes the inputs (which must be PyTorch
tensors with the same shapes as example_inputs), and passes them to an
NNC executor.
"""
fx_model = fx.symbolic_trace(model)
ShapeProp(fx_model).propagate(*example_inputs)
# This env maps from nodes to `te.BufHandle`s, which represent the output
# buffers of NNC computations.
env = {}
def get_te_type(node):
return get_nnc_type(node.meta['tensor_meta'].dtype)
def gen_compute(args):
te_args = [env[arg.name] for arg in args]
def lookup_env(l):
res = fx.node.map_aggregate(l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
return res
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = fx_model
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
outs = None
inputs = []
module_attrs = []
compute_stmts = []
for node in fx_model.graph.nodes:
if node.op == 'placeholder':
# We simply map the input placeholder to a `te.Placeholder`, which
# also represents an input to the NNC computation.
shapes = get_te_shapes(node.meta['tensor_meta'].shape)
placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
env[node.name] = placeholder.data()
inputs.append(placeholder)
elif node.op == 'call_function':
# This does the bulk of the work - we call `lower_function`, which
# returns a `te.BufHandle` (the output of an NNC computation) plus the
# statements computing it, and put the buffer in our environment.
if 'tensor_meta' in node.meta:
# todo: fix kwargs handling
if node.kwargs:
raise RuntimeError("kwargs nyi")
buf, stmt = lower_function(node, node.target, lookup_env(node.args), node.args)
compute_stmts.extend(stmt)
env[node.name] = buf
elif node.target == getattr or node.target == operator.getitem:
# todo: handle non-tensor computations correctly
continue
elif node.op == 'output':
args = node.args
if not isinstance(args, tuple):
args = (args,)
if isinstance(args[0], tuple):
args = args[0]
te_args = lookup_env(args)
outs = (list(te_args), [(i.meta['tensor_meta'].shape, i.meta['tensor_meta'].dtype) for i in args])
elif node.op == 'get_attr':
# As NNC doesn't have any concept of state, we pull out the module
# attributes and pass them in as inputs to NNC.
module_attrs.append(node)
shapes = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
env[node.name] = placeholder.data()
else:
print(node.op, node.target)
raise RuntimeError("not yet implemented")
loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
# loopnest.inline_intermediate_bufs(True)
loopnest.simplify()
loopnest.prepare_for_codegen()
stmt = te.simplify(loopnest.root_stmt())
cg = te.construct_codegen('llvm', stmt, [te.BufferArg(x) for x in [env[i.name] for i in module_attrs] + inputs + outs[0]])
alloc_results = [torch.empty(shape, dtype=dtype) for shape,dtype in outs[1]]
if module_attrs:
module_stuff = [fetch_attr(i.target).contiguous().data for i in module_attrs]
else:
module_stuff = []
def f(*inps, out_tensors=None):
if out_tensors is None:
results = alloc_results
else:
results = out_tensors
full_inps = module_stuff + list(inps) + results
# begin = time.time()
cg.call(full_inps)
# print(time.time()-begin)
if out_tensors is None:
if len(results) == 1:
return results[0]
return results
return f
################################
# Example usage and Benchmarking
################################
def bench(f, warmup=3, iters=1000):
for _ in range(warmup):
f()
begin = time.time()
for _ in range(iters):
f()
print(time.time()-begin)
if __name__ == '__main__':
scope = te.KernelScope()
def f(a, b):
return (torch.cos(a) * torch.sin(b))[:2000]
mod = fx.symbolic_trace(f)
inps = (torch.randn(5000), torch.randn(5000))
ShapeProp(mod).propagate(*inps)
cg = nnc_compile(mod, inps)
bench(lambda: cg(*inps))
bench(lambda: f(*inps))
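    # The compiled wrapper can also write into preallocated outputs via its
    # `out_tensors` keyword argument:
    #   out = torch.empty(2000)
    #   cg(*inps, out_tensors=[out])  # fills `out` in place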

View File

@@ -0,0 +1,108 @@
import functools
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union
import torch
from torch.fx.node import map_aggregate
import torch.utils._pytree as pytree
from functorch._C import hasPythonKey, addPythonKey, removePythonKey
from torch.fx import Tracer, GraphModule
class PythonTensor(object):
def __init__(self, out, proxy):
if isinstance(out, torch.Tensor):
self.value = torch.empty_like(out)
else:
self.value = torch.empty(out)
self.proxy = proxy
def __repr__(self):
return f"PythonTensor({tuple(self.value.shape)})"
def tensor(self):
return self.value
def __torch_function__(self, func, types, args=(), kwargs={}):
namespace, func_name = func.split("::")
func = getattr(getattr(torch.ops, namespace), func_name)
outs = kwargs['val']
rets = []
proxy_args = map_aggregate(args, lambda i: i.proxy if isinstance(i, PythonTensor) else i)
out_proxy = func(*proxy_args)
if len(outs) == 1 and isinstance(outs[0], torch.Tensor):
return [PythonTensor(outs[0], out_proxy)]
for idx, out in enumerate(outs):
if isinstance(out, torch.Tensor):
rets.append(PythonTensor(out, out_proxy[idx]))
else:
rets.append(out)
return rets
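# Note on PythonTensor.__torch_function__ above: the boxed fallback in
# PythonKey.cpp calls it with the qualified op name as a *string* (e.g.
# "aten::mul") and the real outputs in kwargs['val'], which is why `func`
# is split on "::" rather than treated as a callable.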
class PythonKeyTracer(Tracer):
def __init__(self):
super().__init__()
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
return False
def module_getattr(self, attr, attr_val):
if isinstance(attr_val, torch.nn.Parameter):
for n, p in self.root.named_parameters():
if attr_val is p:
if n not in self.parameter_proxy_cache:
proxy = self.create_proxy('get_attr', n, (), {})
self.parameter_proxy_cache[n] = addPythonKey(PythonTensor(attr_val.shape, proxy))
return self.parameter_proxy_cache[n]
return attr_val
def pythonkey_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
tracer = PythonKeyTracer()
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)
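# Sketch of intended usage (composing with wrap_key below so dispatcher calls
# are recorded as FX nodes):
#   traced = pythonkey_trace(wrap_key(f, example_inputs))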
class WrapModule(torch.nn.Module):
def __init__(self, mod, inps):
super().__init__()
self.mod = mod
self.inps = inps
@functools.wraps(mod.forward)
def forward_wrapped(self, *args):
new_args = []
for inp, arg in zip(inps, args):
if isinstance(inp, torch.Tensor):
new_arg = addPythonKey(PythonTensor(inp.shape, arg))
else:
new_arg = inp
new_args.append(new_arg)
out = self.mod(*new_args)
flat_outs, out_spec = pytree.tree_flatten(out)
for idx in range(len(flat_outs)):
if hasPythonKey(flat_outs[idx]):
flat_outs[idx] = removePythonKey(flat_outs[idx]).proxy
return pytree.tree_unflatten(flat_outs, out_spec)
type(self).forward = forward_wrapped
def wrap_key(f, inps):
flat_inps, inp_spec = pytree.tree_flatten(inps)
@functools.wraps(f)
def wrapped(*args):
flat_args, args_spec = pytree.tree_flatten(args)
assert(len(flat_args) == len(flat_inps))
for idx, arg in enumerate(flat_args):
if isinstance(flat_inps[idx], torch.Tensor):
flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx].shape, arg))
else:
flat_args[idx] = flat_inps[idx]
tree_args = pytree.tree_unflatten(flat_args, args_spec)
out = f(*tree_args)
flat_outs, out_spec = pytree.tree_flatten(out)
for idx in range(len(flat_outs)):
if hasPythonKey(flat_outs[idx]):
flat_outs[idx] = removePythonKey(flat_outs[idx]).proxy
return pytree.tree_unflatten(flat_outs, out_spec)
return wrapped

View File

@@ -0,0 +1,199 @@
#include <functorch/csrc/PythonKey.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace at {
namespace functorch {
// The following are publicly exposed as methods of Tensor
bool PythonTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
TORCH_CHECK(
memory_format == at::MemoryFormat::Contiguous,
"NYI: querying is_contiguous inside of python tensor for memory_format ",
"other than torch.contiguous_format");
return is_contiguous_;
}
// The following are some internal inherited methods that we do not support.
// They should never get called.
void PythonTensorImpl::set_size(int64_t dim, int64_t new_size) {
TORCH_INTERNAL_ASSERT(false, "Can't set_size for PythonTensorImpl");
}
void PythonTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
TORCH_INTERNAL_ASSERT(false, "Can't set_stride for PythonTensorImpl");
}
void PythonTensorImpl::set_storage_offset(int64_t storage_offset) {
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for PythonTensorImpl");
}
bool isPythonTensor(at::Tensor tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(
c10::DispatchKey::FuncTorchPython);
}
PythonTensorImpl* getPythonImpl(at::Tensor tensor) {
return static_cast<PythonTensorImpl*>(tensor.unsafeGetTensorImpl());
}
at::Tensor addPythonKey(const py::object& tensor) {
return at::detail::make_tensor<PythonTensorImpl>(tensor);
}
bool hasPythonKey(at::Tensor tensor) {
return isPythonTensor(tensor);
}
py::object removePythonKey(at::Tensor tensor) {
assert(isPythonTensor(tensor));
return getPythonImpl(tensor)->value_;
}
py::object pyIdentity(py::object x) {
return x;
}
template <class T>
py::tuple vectorToPyTuple(
const std::vector<T>& data,
std::function<py::object(T)> converter) {
PyObject* tuple = PyTuple_New(data.size());
if (!tuple)
throw std::runtime_error("Unable to allocate memory for Python tuple");
for (unsigned int i = 0; i < data.size(); i++) {
PyObject* num = converter(data[i]).ptr();
    if (!num) {
      Py_DECREF(tuple);
      throw std::runtime_error("Unable to convert element to a Python object");
    }
    // PyTuple_SET_ITEM steals a reference, but `converter` returned a
    // temporary py::object whose destructor drops one ref, so INCREF to
    // keep the element alive. TODO: audit this refcounting properly.
    Py_INCREF(num);
    PyTuple_SET_ITEM(tuple, i, num);
}
return py::cast<py::tuple>(tuple);
}
void pythonFallBack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
const auto arguments = torch::jit::last(stack, num_arguments);
py::gil_scoped_acquire g;
std::vector<py::object> pyArgs;
std::vector<py::object> pyTensorArgs;
std::vector<torch::jit::IValue> unwrappedArgs;
for (int idx = 0; idx < arguments.size(); idx++) {
const auto ivalue = arguments[idx];
if (ivalue.isTensor() && isPythonTensor(ivalue.toTensor())) {
auto pyTensor = getPythonImpl(ivalue.toTensor());
pyArgs.push_back(pyTensor->value_);
pyTensorArgs.push_back(pyTensor->value_);
unwrappedArgs.push_back(getValueFromPyTensor(pyTensor->value_));
} else {
if (ivalue.isList()) {
auto l = ivalue.toList();
auto unwrappedL =
c10::impl::GenericList(op.schema()
.arguments()[idx]
.type()
->expectRef<torch::jit::ListType>()
.getElementType());
py::list pyL;
for (int jdx = 0; jdx < l.size(); jdx++) {
auto nv = l.get(jdx);
if (nv.isTensor() && isPythonTensor(nv.toTensor())) {
auto pyTensor = getPythonImpl(nv.toTensor());
pyTensorArgs.push_back(pyTensor->value_);
unwrappedL.push_back(getValueFromPyTensor(pyTensor->value_));
pyL.append(pyTensor->value_);
} else {
unwrappedL.push_back(l.get(jdx));
pyL.append(torch::jit::toPyObject(l.get(jdx)));
}
}
pyArgs.push_back(pyL);
unwrappedArgs.push_back(unwrappedL);
} else {
pyArgs.push_back(torch::jit::toPyObject(ivalue));
unwrappedArgs.push_back(ivalue);
}
}
}
py::object torch_function =
PyObject_FastGetAttrString(pyTensorArgs[0].ptr(), "__torch_function__");
for (auto v : unwrappedArgs) {
torch::jit::push(stack, v);
}
op.callBoxed(stack);
std::vector<c10::IValue> realOuts = torch::jit::pop(*stack, num_returns);
py::tuple py_types = py::cast<py::tuple>(
vectorToPyTuple<py::object>(pyArgs, [](py::object x) -> py::object {
return py::reinterpret_borrow<py::object>(PyObject_Type(x.ptr()));
}));
py::dict kwargs;
std::vector<py::object> t;
for (auto x : realOuts) {
t.push_back(torch::jit::toPyObject(x));
}
kwargs["val"] = vectorToPyTuple<py::object>(t, pyIdentity);
  // Pass the qualified operator name (e.g. "aten::mul") as the "function";
  // PythonTensor.__torch_function__ splits it on "::" and resolves it to
  // torch.ops.<namespace>.<name>.
  py::object torch_api_function = py::str(op.operator_name().name);
auto pyTupleArgs = vectorToPyTuple<py::object>(pyArgs, pyIdentity);
auto out = PyObject_CallFunctionObjArgs(
torch_function.ptr(),
torch_api_function.ptr(),
py_types.ptr(),
pyTupleArgs.ptr(),
kwargs.ptr(),
0);
if (out == nullptr) {
throw std::runtime_error("call failed");
}
py::list outs = py::cast<py::list>(out);
torch::jit::drop(stack, num_arguments);
std::vector<c10::IValue> ret_ivalues;
assert(outs.size() == op.schema().returns().size());
for (int idx = 0; idx < outs.size(); idx++) {
auto ret_type = op.schema().returns()[idx].type();
if (ret_type->kind() == c10::TensorType::Kind) {
torch::jit::push(stack, addPythonKey(py::cast<py::object>(outs[idx])));
} else {
auto ivalue_out = torch::jit::toTypeInferredIValue(outs[idx]);
torch::jit::push(stack, ivalue_out);
}
}
return;
}
TORCH_LIBRARY_IMPL(_, FuncTorchPython, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallBack>());
}
c10::intrusive_ptr<c10::TensorImpl> PythonTensorImpl::shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
auto impl = c10::make_intrusive<PythonTensorImpl>(value_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
/*version_counter=*/version_counter,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->set_version_counter(version_counter);
impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return impl;
}
c10::intrusive_ptr<c10::TensorImpl> PythonTensorImpl::shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
auto impl = c10::make_intrusive<PythonTensorImpl>(value_);
impl->set_version_counter(version_counter);
impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return impl;
}
}}

View File

@@ -0,0 +1,58 @@
#include <torch/csrc/utils/python_strings.h>
#include <ATen/Tensor.h>
#include <torch/library.h>
namespace at {
namespace functorch {
inline at::Tensor getValueFromPyTensor(const py::object& pyTensor) {
auto out = pyTensor.attr("value").cast<at::Tensor>();
return out;
}
struct TORCH_API PythonTensorImpl : public c10::TensorImpl {
explicit PythonTensorImpl(py::object value): TensorImpl(c10::DispatchKeySet(c10::DispatchKey::FuncTorchPython), getValueFromPyTensor(value).dtype(), c10::Device(at::kCPU)), value_(value) {
set_storage_access_should_throw();
auto tensor = getValueFromPyTensor(value_);
const auto value_sizes = tensor.sizes();
const auto value_strides = tensor.strides();
sizes_and_strides_.resize(tensor.dim());
for (int64_t dim = 0; dim < tensor.dim(); dim++) {
sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(dim);
sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(dim);
}
refresh_numel();
refresh_contiguous();
}
c10::intrusive_ptr<c10::TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const;
c10::intrusive_ptr<c10::TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
// Override a bunch of methods inherited from TensorImpl to return error
// messages.
bool is_contiguous(at::MemoryFormat memory_format) const override;
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
py::object value_;
};
PythonTensorImpl* getPythonImpl(at::Tensor tensor);
at::Tensor addPythonKey(const py::object& tensor);
bool hasPythonKey(at::Tensor tensor);
py::object removePythonKey(at::Tensor tensor);
}}

View File

@@ -5,6 +5,7 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/PythonKey.h>
namespace at {
namespace functorch {
@@ -183,4 +184,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "add batch dim");
m.def("dlevel", &at::functorch::dlevel, "add batch dim");
m.def("dump_tensor", &at::functorch::dump_tensor, "add batch dim");
m.def(
"addPythonKey",
&at::functorch::addPythonKey,
py::return_value_policy::copy); // copy policy may be unnecessary here; kept conservatively
m.def("removePythonKey", &at::functorch::removePythonKey);
m.def("hasPythonKey", &at::functorch::hasPythonKey);
}
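These bindings let Python move tensors in and out of the python-key subsystem directly. A minimal round-trip sketch (assuming a functorch build that includes this commit):

import torch
from functorch import PythonTensor
from functorch._C import addPythonKey, hasPythonKey, removePythonKey

pt = PythonTensor((3,), proxy=None)   # proxy is normally an fx.Proxy
wrapped = addPythonKey(pt)            # a real at::Tensor carrying FuncTorchPython
assert hasPythonKey(wrapped)
assert removePythonKey(wrapped) is pt # returns the original PythonTensor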

View File

@@ -563,7 +563,7 @@ class TestExamplesCorrectness(TestCase):
return mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K)
# Compute with vmap+grad
inner_losses = vmap(partial(get_loss_for_task, True))\
(task[0], task[1], task[2], task[3])
@@ -609,7 +609,7 @@ class TestExamplesCorrectness(TestCase):
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(64, n_way)).to(device).to(dtype)
params, buffers, fnet, _, _, = make_functional_with_buffers(net)
net = (params, buffers, fnet)
@@ -739,7 +739,7 @@ class TestExamplesCorrectness(TestCase):
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_resnet18_per_sample_grads(self, device):
# Straight out of opacus
def _replace_child(
root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
) -> None: