mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Added python key stuff that builds with master
This commit is contained in:
@ -4,7 +4,8 @@ from . import _C
|
||||
from ._src.vmap import vmap
|
||||
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
|
||||
from ._src.make_functional import make_functional, make_functional_with_buffers
|
||||
from ._src.python_key import wrap_key, WrapModule
|
||||
from ._src.python_key import wrap_key, WrapModule, PythonTensor, pythonkey_trace
|
||||
from ._src.nnc_compile import nnc_compile
|
||||
|
||||
|
||||
# Monkeypatching lol
|
||||
|
287
functorch/functorch/_src/nnc_compile.py
Normal file
287
functorch/functorch/_src/nnc_compile.py
Normal file
@ -0,0 +1,287 @@
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch._C._te as te
|
||||
import torch.fx as fx
|
||||
from torch.fx import map_arg
|
||||
from torch.fx.passes.shape_prop import ShapeProp
|
||||
import operator
|
||||
|
||||
def truncate(model, k):
|
||||
model = fx.symbolic_trace(model)
|
||||
new_graph= fx.Graph()
|
||||
env = {}
|
||||
|
||||
cnt = 0
|
||||
for node in list(model.graph.nodes):
|
||||
new_node = new_graph.node_copy(node, lambda x: env[x.name])
|
||||
env[node.name] = new_node
|
||||
cnt += 1
|
||||
if cnt == k:
|
||||
new_graph.output(env[node.name])
|
||||
break
|
||||
|
||||
return fx.GraphModule(model, new_graph)
|
||||
|
||||
# NNC Lowering Pass
|
||||
def remove_args(model: torch.nn.Module, args):
|
||||
fx_model = fx.symbolic_trace(model)
|
||||
for node in fx_model.graph.nodes:
|
||||
if node.op == 'placeholder' and node.target in args:
|
||||
assert(len(node.users) == 0)
|
||||
fx_model.graph.erase_node(node)
|
||||
fx_model.recompile()
|
||||
return fx_model
|
||||
|
||||
class kernel_arena_scope(object):
|
||||
def __enter__(self):
|
||||
self.scope = te.KernelScope()
|
||||
|
||||
def __exit__(self, typ, val, traceback):
|
||||
self.scope = None
|
||||
|
||||
def get_dim_args(dims):
|
||||
dim_args = []
|
||||
for dim in dims:
|
||||
dim_args.append(te.DimArg(te.ExprHandle.int(dim), 'i' + str(len(dim_args))))
|
||||
return dim_args
|
||||
|
||||
def get_te_shapes(shape):
|
||||
return [te.ExprHandle.int(i) for i in shape]
|
||||
|
||||
def to_expr(x):
|
||||
if isinstance(x, int):
|
||||
return te.ExprHandle.int(x)
|
||||
elif isinstance(x, float):
|
||||
return te.ExprHandle.float(x)
|
||||
else:
|
||||
raise RuntimeError(f"type {type(x)} not supported")
|
||||
|
||||
def get_nnc_type(dtype):
|
||||
if dtype == torch.float:
|
||||
return te.Dtype.Float
|
||||
elif dtype == torch.long:
|
||||
return te.Dtype.Long
|
||||
elif dtype == torch.float64:
|
||||
return te.Dtype.Double
|
||||
else:
|
||||
raise RuntimeError("nyi")
|
||||
|
||||
|
||||
lowering_functions = { }
|
||||
def index_or_broadcast(shape, *args):
|
||||
out = []
|
||||
for idx, arg in enumerate(args):
|
||||
if idx >= len(shape): continue
|
||||
if shape[idx] == 1:
|
||||
out.append(to_expr(0))
|
||||
else:
|
||||
out.append(arg)
|
||||
return out
|
||||
|
||||
def ones_like_lower(name, out_shape, inp_shapes, args):
|
||||
def f(*idxs):
|
||||
return to_expr(1.0)
|
||||
res = te.Compute(name, get_dim_args(out_shape), f)
|
||||
return res
|
||||
|
||||
# def select_lower(name, out_shape, inp_shapes, args):
|
||||
# A = args[0]
|
||||
# dim = args[1]
|
||||
# idx = args[2]
|
||||
# import pdb; pdb.set_trace()
|
||||
# def f(*idxs):
|
||||
# # idxs = list(idxs)
|
||||
# idxs.insert(dim, to_expr(idx))
|
||||
# # idxs = [to_expr(0)]
|
||||
# return A.load(idxs)
|
||||
# res = te.Compute(name, get_dim_args(out_shape), f)
|
||||
# return res
|
||||
|
||||
def dot_lower(name, out_shape, inp_shapes, args):
|
||||
mul_te = te.lower('aten::mul', list(args), get_te_shapes(inp_shapes[0][0]), get_nnc_type(inp_shapes[0][1]))
|
||||
res = te.lower('aten::sum', [mul_te.buf()], get_te_shapes(out_shape), get_nnc_type(inp_shapes[0][1]))
|
||||
return (res.buf(), [mul_te.stmt(), res.stmt()])
|
||||
|
||||
def mv_lower(name, out_shape, inp_shapes, args):
|
||||
A = args[0]
|
||||
B = args[1]
|
||||
N, M = inp_shapes[0][0]
|
||||
|
||||
def f(n, m):
|
||||
return A.load([n, m]) * B.load([m])
|
||||
# mm = te.Compute('mm', get_dim_args([N,M]), f)
|
||||
# out = te.Reduce(name, get_dim_args([N]), te.Sum(), mm, get_dim_args([M]))
|
||||
# return out.buf(), [mm.stmt(), out.stmt()]
|
||||
C = torch._C._te.BufHandle('C', get_te_shapes([N]), get_nnc_type(inp_shapes[0][1]))
|
||||
s = torch._C._te.ExternalCall(C, "nnc_aten_mv", [A, B], [])
|
||||
return C, [s]
|
||||
|
||||
lowering_functions[torch.ops.aten.ones_like] = ones_like_lower
|
||||
lowering_functions[torch.ops.aten.dot] = dot_lower
|
||||
# lowering_functions[torch.ops.aten.select] = select_lower
|
||||
lowering_functions[torch.ops.aten.mv] = mv_lower
|
||||
|
||||
func_to_aten = {
|
||||
operator.getitem: torch.ops.aten.slice,
|
||||
operator.add: torch.ops.aten.add,
|
||||
operator.mul: torch.ops.aten.mul,
|
||||
torch.mul: torch.ops.aten.mul,
|
||||
torch.sin: torch.ops.aten.sin,
|
||||
torch.cos: torch.ops.aten.cos,
|
||||
}
|
||||
|
||||
|
||||
def process_shape(x):
|
||||
if len(x) == 0:
|
||||
return [1]
|
||||
return x
|
||||
|
||||
def lower_function(node, op, nnc_args, args):
|
||||
inp_shapes = fx.node.map_aggregate(args, lambda arg: (process_shape(arg.meta['tensor_meta'].shape), arg.meta['tensor_meta'].dtype) if isinstance(arg, fx.Node) and 'tensor_meta' in arg.meta else None)
|
||||
if op in lowering_functions:
|
||||
out = lowering_functions[op](node.name, process_shape(node.meta['tensor_meta'].shape), inp_shapes, nnc_args)
|
||||
else:
|
||||
if op in func_to_aten:
|
||||
op = func_to_aten[op]
|
||||
aten_str = f'aten::{op.__name__}'
|
||||
out_shape = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
|
||||
out = te.lower(aten_str, list(nnc_args), out_shape, get_nnc_type(node.meta['tensor_meta'].dtype))
|
||||
if isinstance(out, te.Tensor):
|
||||
return out.buf(), [out.stmt()]
|
||||
else:
|
||||
return out[0], out[1]
|
||||
|
||||
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
|
||||
"""
|
||||
nnc_compile(model, example_inputs) returns a function with the same args
|
||||
as `model.forward`, with an extra argument corresponding to where the
|
||||
output is stored. This function takes the inputs (which must be PyTorch
|
||||
tensors with the same shapes as example_inputs), and passes them to an
|
||||
NNC executor.
|
||||
"""
|
||||
fx_model = fx.symbolic_trace(model)
|
||||
ShapeProp(fx_model).propagate(*example_inputs)
|
||||
|
||||
# This env maps from nodes to `te.ExprHandle`, which represent the output
|
||||
# of an NNC computation.
|
||||
env = {}
|
||||
|
||||
|
||||
def get_te_type(node):
|
||||
return get_nnc_type(node.meta['tensor_meta'].dtype)
|
||||
|
||||
def gen_compute(args):
|
||||
te_args = [env[arg.name] for arg in args]
|
||||
|
||||
def lookup_env(l):
|
||||
res = fx.node.map_aggregate(l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
|
||||
return res
|
||||
|
||||
def fetch_attr(target : str):
|
||||
target_atoms = target.split('.')
|
||||
attr_itr = fx_model
|
||||
for i, atom in enumerate(target_atoms):
|
||||
if not hasattr(attr_itr, atom):
|
||||
raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
return attr_itr
|
||||
|
||||
outs = None
|
||||
inputs = []
|
||||
module_attrs = []
|
||||
compute_stmts = []
|
||||
for node in fx_model.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
# We simply map the input placeholder to a `te.Placeholder`, which
|
||||
# also represents an input to the NNC computation.
|
||||
shapes = get_te_shapes(node.meta['tensor_meta'].shape)
|
||||
placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
|
||||
env[node.name] = placeholder.data()
|
||||
inputs.append(placeholder)
|
||||
elif node.op == 'call_function':
|
||||
# This does the bulk of the work - we call `lower_function`, which
|
||||
# returns a `te.ExprHandle` (the output of a NNC computation), and
|
||||
# put it in our environment.
|
||||
if 'tensor_meta' in node.meta:
|
||||
# todo: fix kwargs handling
|
||||
if node.kwargs:
|
||||
raise RuntimeError("kwargs nyi")
|
||||
buf, stmt = lower_function(node, node.target, lookup_env(node.args), node.args)
|
||||
# if isinstance(stmt, list)
|
||||
compute_stmts.extend(stmt)
|
||||
env[node.name] = buf
|
||||
elif node.target == getattr or node.target == operator.getitem:
|
||||
# todo: handle non-tensor computations correctly
|
||||
continue
|
||||
elif node.op == 'output':
|
||||
args = node.args
|
||||
if not isinstance(args, tuple):
|
||||
args = (args,)
|
||||
if isinstance(args[0], tuple):
|
||||
args = args[0]
|
||||
te_args = lookup_env(args)
|
||||
outs = (list(te_args), [(i.meta['tensor_meta'].shape, i.meta['tensor_meta'].dtype) for i in args])
|
||||
elif node.op == 'get_attr':
|
||||
# As NNC doesn't have any concept of state, we pull out the module
|
||||
# attributes and pass them in as inputs to NNC.
|
||||
module_attrs.append(node)
|
||||
shapes = get_te_shapes(process_shape(node.meta['tensor_meta'].shape))
|
||||
placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
|
||||
env[node.name] = placeholder.data()
|
||||
else:
|
||||
print(node.op, node.target)
|
||||
raise RuntimeError("not yet implemented")
|
||||
|
||||
|
||||
loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
|
||||
# loopnest.inline_intermediate_bufs(True)
|
||||
loopnest.simplify()
|
||||
loopnest.prepare_for_codegen()
|
||||
stmt = te.simplify(loopnest.root_stmt())
|
||||
cg = te.construct_codegen('llvm', stmt, [te.BufferArg(x) for x in [env[i.name] for i in module_attrs] + inputs + outs[0]])
|
||||
alloc_results = [torch.empty(shape, dtype=dtype) for shape,dtype in outs[1]]
|
||||
if module_attrs:
|
||||
module_stuff = [fetch_attr(i.target).contiguous().data for i in module_attrs]
|
||||
else:
|
||||
module_stuff = []
|
||||
def f(*inps, out_tensors=None):
|
||||
if out_tensors is None:
|
||||
results = alloc_results
|
||||
else:
|
||||
results = out_tensors
|
||||
full_inps = module_stuff + list(inps) + results
|
||||
# begin = time.time()
|
||||
cg.call(full_inps)
|
||||
# print(time.time()-begin)
|
||||
if out_tensors is None:
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
return results
|
||||
return f
|
||||
|
||||
|
||||
################################
|
||||
# Example usage and Benchmarking
|
||||
################################
|
||||
|
||||
def bench(f, warmup=3, iters=1000):
|
||||
for _ in range(warmup):
|
||||
f()
|
||||
begin = time.time()
|
||||
for _ in range(iters):
|
||||
f()
|
||||
print(time.time()-begin)
|
||||
|
||||
if __name__ == '__main__':
|
||||
scope = te.KernelScope()
|
||||
def f(a, b):
|
||||
return (torch.cos(a)* torch.sin(b))[:2000]
|
||||
|
||||
mod = fx.symbolic_trace(f)
|
||||
inps = (torch.randn(5000), torch.randn(5000))
|
||||
ShapeProp(mod).propagate(*inps)
|
||||
cg = nnc_compile(mod, inps)
|
||||
bench(lambda: cg(*inps))
|
||||
bench(lambda: f(*inps))
|
108
functorch/functorch/_src/python_key.py
Normal file
108
functorch/functorch/_src/python_key.py
Normal file
@ -0,0 +1,108 @@
|
||||
import functools
|
||||
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union
|
||||
import torch
|
||||
from torch.fx.node import map_aggregate
|
||||
import torch.utils._pytree as pytree
|
||||
from functorch._C import hasPythonKey, addPythonKey, removePythonKey
|
||||
from torch.fx import Tracer, GraphModule
|
||||
|
||||
class PythonTensor(object):
|
||||
def __init__(self, out, proxy):
|
||||
if isinstance(out, torch.Tensor):
|
||||
self.value = torch.empty_like(out)
|
||||
else:
|
||||
self.value = torch.empty(out)
|
||||
self.proxy = proxy
|
||||
|
||||
def __repr__(self):
|
||||
return f"PythonTensor({tuple(self.value.shape)})"
|
||||
|
||||
def tensor(self):
|
||||
return self.value
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs={}):
|
||||
namespace, func_name = func.split("::")
|
||||
func = getattr(getattr(torch.ops, namespace), func_name)
|
||||
outs = kwargs['val']
|
||||
rets = []
|
||||
proxy_args = map_aggregate(args, lambda i: i.proxy if isinstance(i, PythonTensor) else i)
|
||||
out_proxy = func(*proxy_args)
|
||||
if len(outs) == 1 and isinstance(outs[0], torch.Tensor):
|
||||
return [PythonTensor(outs[0], out_proxy)]
|
||||
for idx, out in enumerate(outs):
|
||||
if isinstance(out, torch.Tensor):
|
||||
rets.append(PythonTensor(out, out_proxy[idx]))
|
||||
else:
|
||||
rets.append(out)
|
||||
return rets
|
||||
|
||||
class PythonKeyTracer(Tracer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
|
||||
return False
|
||||
|
||||
def module_getattr(self, attr, attr_val):
|
||||
if isinstance(attr_val, torch.nn.Parameter):
|
||||
for n, p in self.root.named_parameters():
|
||||
if attr_val is p:
|
||||
if n not in self.parameter_proxy_cache:
|
||||
proxy = self.create_proxy('get_attr', n, (), {})
|
||||
self.parameter_proxy_cache[n] = addPythonKey(PythonTensor(attr_val.shape, proxy))
|
||||
return self.parameter_proxy_cache[n]
|
||||
return attr_val
|
||||
|
||||
def pythonkey_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
|
||||
tracer = PythonKeyTracer()
|
||||
graph = tracer.trace(root, concrete_args)
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
return GraphModule(tracer.root, graph, name)
|
||||
|
||||
|
||||
class WrapModule(torch.nn.Module):
|
||||
def __init__(self, mod, inps):
|
||||
super().__init__()
|
||||
self.mod = mod
|
||||
self.inps = inps
|
||||
@functools.wraps(mod.forward)
|
||||
def forward_wrapped(self, *args):
|
||||
new_args = []
|
||||
for inp, arg in zip(inps, args):
|
||||
if isinstance(inp, torch.Tensor):
|
||||
new_arg = addPythonKey(PythonTensor(inp.shape, arg))
|
||||
else:
|
||||
new_arg = inp
|
||||
new_args.append(new_arg)
|
||||
out = self.mod(*new_args)
|
||||
|
||||
flat_outs, out_spec = pytree.tree_flatten(out)
|
||||
for idx in range(len(flat_outs)):
|
||||
if hasPythonKey(flat_outs[idx]):
|
||||
flat_outs[idx] = removePythonKey(flat_outs[idx]).proxy
|
||||
return pytree.tree_unflatten(flat_outs, out_spec)
|
||||
|
||||
type(self).forward = forward_wrapped
|
||||
|
||||
def wrap_key(f, inps):
|
||||
flat_inps, inp_spec = pytree.tree_flatten(inps)
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args):
|
||||
flat_args, args_spec = pytree.tree_flatten(args)
|
||||
import pdb; pdb.set_trace()
|
||||
assert(len(flat_args) == len(flat_inps))
|
||||
for idx, arg in enumerate(flat_args):
|
||||
if isinstance(flat_inps[idx], torch.Tensor):
|
||||
flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx].shape, arg))
|
||||
else:
|
||||
flat_args[idx] = flat_inps[idx]
|
||||
tree_args = pytree.tree_unflatten(flat_args, args_spec)
|
||||
out = f(*tree_args)
|
||||
|
||||
flat_outs, out_spec = pytree.tree_flatten(out)
|
||||
for idx in range(len(flat_outs)):
|
||||
if hasPythonKey(flat_outs[idx]):
|
||||
flat_outs[idx] = removePythonKey(flat_outs[idx]).proxy
|
||||
return pytree.tree_unflatten(flat_outs, out_spec)
|
||||
|
||||
return wrapped
|
199
functorch/functorch/csrc/PythonKey.cpp
Normal file
199
functorch/functorch/csrc/PythonKey.cpp
Normal file
@ -0,0 +1,199 @@
|
||||
#include <functorch/csrc/PythonKey.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
// The following are publically exposed as methods of Tensor
|
||||
bool PythonTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
|
||||
TORCH_CHECK(
|
||||
memory_format == at::MemoryFormat::Contiguous,
|
||||
"NYI: querying is_contiguous inside of python tensor for memory_format ",
|
||||
"other than torch.contiguous_format");
|
||||
return is_contiguous_;
|
||||
}
|
||||
|
||||
// The following are some internal inherited methods that we do not support.
|
||||
// They should never get called.
|
||||
void PythonTensorImpl::set_size(int64_t dim, int64_t new_size) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Can't set_size for PythonTensorImpl");
|
||||
}
|
||||
void PythonTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Can't set_stride for PythonTensorImpl");
|
||||
}
|
||||
void PythonTensorImpl::set_storage_offset(int64_t storage_offset) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for PythonTensorImpl");
|
||||
}
|
||||
|
||||
bool isPythonTensor(at::Tensor tensor) {
|
||||
return tensor.unsafeGetTensorImpl()->key_set().has(
|
||||
c10::DispatchKey::FuncTorchPython);
|
||||
}
|
||||
PythonTensorImpl* getPythonImpl(at::Tensor tensor) {
|
||||
return static_cast<PythonTensorImpl*>(tensor.unsafeGetTensorImpl());
|
||||
}
|
||||
|
||||
at::Tensor addPythonKey(const py::object& tensor) {
|
||||
return at::detail::make_tensor<PythonTensorImpl>(tensor);
|
||||
}
|
||||
bool hasPythonKey(at::Tensor tensor) {
|
||||
return isPythonTensor(tensor);
|
||||
}
|
||||
|
||||
py::object removePythonKey(at::Tensor tensor) {
|
||||
assert(isPythonTensor(tensor));
|
||||
return getPythonImpl(tensor)->value_;
|
||||
}
|
||||
|
||||
|
||||
py::object pyIdentity(py::object x) {
|
||||
return x;
|
||||
}
|
||||
template <class T>
|
||||
py::tuple vectorToPyTuple(
|
||||
const std::vector<T>& data,
|
||||
std::function<py::object(T)> converter) {
|
||||
PyObject* tuple = PyTuple_New(data.size());
|
||||
if (!tuple)
|
||||
throw std::runtime_error("Unable to allocate memory for Python tuple");
|
||||
for (unsigned int i = 0; i < data.size(); i++) {
|
||||
PyObject* num = converter(data[i]).ptr();
|
||||
if (!num) {
|
||||
Py_DECREF(tuple);
|
||||
throw std::runtime_error("Unable to allocate memory for Python tuple");
|
||||
}
|
||||
Py_INCREF(
|
||||
num); // todo: dunno?? Need it to fix segfaults, but probably not right
|
||||
PyTuple_SET_ITEM(tuple, i, num);
|
||||
}
|
||||
return py::cast<py::tuple>(tuple);
|
||||
}
|
||||
|
||||
void pythonFallBack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
const auto arguments = torch::jit::last(stack, num_arguments);
|
||||
py::gil_scoped_acquire g;
|
||||
std::vector<py::object> pyArgs;
|
||||
std::vector<py::object> pyTensorArgs;
|
||||
std::vector<torch::jit::IValue> unwrappedArgs;
|
||||
for (int idx = 0; idx < arguments.size(); idx++) {
|
||||
const auto ivalue = arguments[idx];
|
||||
if (ivalue.isTensor() && isPythonTensor(ivalue.toTensor())) {
|
||||
auto pyTensor = getPythonImpl(ivalue.toTensor());
|
||||
pyArgs.push_back(pyTensor->value_);
|
||||
pyTensorArgs.push_back(pyTensor->value_);
|
||||
unwrappedArgs.push_back(getValueFromPyTensor(pyTensor->value_));
|
||||
} else {
|
||||
if (ivalue.isList()) {
|
||||
auto l = ivalue.toList();
|
||||
auto unwrappedL =
|
||||
c10::impl::GenericList(op.schema()
|
||||
.arguments()[idx]
|
||||
.type()
|
||||
->expectRef<torch::jit::ListType>()
|
||||
.getElementType());
|
||||
py::list pyL;
|
||||
|
||||
for (int jdx = 0; jdx < l.size(); jdx++) {
|
||||
auto nv = l.get(jdx);
|
||||
if (nv.isTensor() && isPythonTensor(nv.toTensor())) {
|
||||
auto pyTensor = getPythonImpl(nv.toTensor());
|
||||
pyTensorArgs.push_back(pyTensor->value_);
|
||||
unwrappedL.push_back(getValueFromPyTensor(pyTensor->value_));
|
||||
pyL.append(pyTensor->value_);
|
||||
} else {
|
||||
unwrappedL.push_back(l.get(jdx));
|
||||
pyL.append(torch::jit::toPyObject(l.get(jdx)));
|
||||
}
|
||||
}
|
||||
pyArgs.push_back(pyL);
|
||||
unwrappedArgs.push_back(unwrappedL);
|
||||
} else {
|
||||
pyArgs.push_back(torch::jit::toPyObject(ivalue));
|
||||
unwrappedArgs.push_back(ivalue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object torch_function =
|
||||
PyObject_FastGetAttrString(pyTensorArgs[0].ptr(), "__torch_function__");
|
||||
for (auto v : unwrappedArgs) {
|
||||
torch::jit::push(stack, v);
|
||||
}
|
||||
op.callBoxed(stack);
|
||||
std::vector<c10::IValue> realOuts = torch::jit::pop(*stack, num_returns);
|
||||
py::tuple py_types = py::cast<py::tuple>(
|
||||
vectorToPyTuple<py::object>(pyArgs, [](py::object x) -> py::object {
|
||||
return py::reinterpret_borrow<py::object>(PyObject_Type(x.ptr()));
|
||||
}));
|
||||
|
||||
py::dict kwargs;
|
||||
std::vector<py::object> t;
|
||||
for (auto x : realOuts) {
|
||||
t.push_back(torch::jit::toPyObject(x));
|
||||
}
|
||||
kwargs["val"] = vectorToPyTuple<py::object>(t, pyIdentity);
|
||||
|
||||
std::string func_name = op.operator_name().name;
|
||||
std::string delimiter = "aten::";
|
||||
func_name = func_name.substr(func_name.find(delimiter) + delimiter.size());
|
||||
py::object torch_api_function =
|
||||
PyObject_FastGetAttrString(THPVariableClass, (char*)func_name.c_str());
|
||||
|
||||
torch_api_function = py::str(op.operator_name().name);
|
||||
auto pyTupleArgs = vectorToPyTuple<py::object>(pyArgs, pyIdentity);
|
||||
|
||||
auto out = PyObject_CallFunctionObjArgs(
|
||||
torch_function.ptr(),
|
||||
torch_api_function.ptr(),
|
||||
py_types.ptr(),
|
||||
pyTupleArgs.ptr(),
|
||||
kwargs.ptr(),
|
||||
0);
|
||||
if (out == nullptr) {
|
||||
throw std::runtime_error("call failed");
|
||||
}
|
||||
py::list outs = py::cast<py::list>(out);
|
||||
torch::jit::drop(stack, num_arguments);
|
||||
std::vector<c10::IValue> ret_ivalues;
|
||||
assert(outs.size() == op.schema().returns().size());
|
||||
for (int idx = 0; idx < outs.size(); idx++) {
|
||||
auto ret_type = op.schema().returns()[idx].type();
|
||||
if (ret_type->kind() == c10::TensorType::Kind) {
|
||||
torch::jit::push(stack, addPythonKey(py::cast<py::object>(outs[idx])));
|
||||
} else {
|
||||
auto ivalue_out = torch::jit::toTypeInferredIValue(outs[idx]);
|
||||
torch::jit::push(stack, ivalue_out);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
TORCH_LIBRARY_IMPL(_, FuncTorchPython, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallBack>());
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10::TensorImpl> PythonTensorImpl::shallow_copy_and_detach(
|
||||
c10::VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const {
|
||||
auto impl = c10::make_intrusive<PythonTensorImpl>(value_);
|
||||
copy_tensor_metadata(
|
||||
/*src_impl=*/this,
|
||||
/*dest_impl=*/impl.get(),
|
||||
/*version_counter=*/version_counter,
|
||||
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
||||
impl->set_version_counter(version_counter);
|
||||
impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
|
||||
return impl;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10::TensorImpl> PythonTensorImpl::shallow_copy_and_detach(
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) const {
|
||||
auto impl = c10::make_intrusive<PythonTensorImpl>(value_);
|
||||
impl->set_version_counter(version_counter);
|
||||
impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
|
||||
return impl;
|
||||
}
|
||||
}}
|
58
functorch/functorch/csrc/PythonKey.h
Normal file
58
functorch/functorch/csrc/PythonKey.h
Normal file
@ -0,0 +1,58 @@
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
inline at::Tensor getValueFromPyTensor(const py::object& pyTensor) {
|
||||
auto out = pyTensor.attr("value").cast<at::Tensor>();
|
||||
return out;
|
||||
}
|
||||
|
||||
struct TORCH_API PythonTensorImpl : public c10::TensorImpl {
|
||||
explicit PythonTensorImpl(py::object value): TensorImpl(c10::DispatchKeySet(c10::DispatchKey::FuncTorchPython), getValueFromPyTensor(value).dtype(), c10::Device(at::kCPU)), value_(value) {
|
||||
set_storage_access_should_throw();
|
||||
// asm("int $0x3\n");
|
||||
auto tensor = getValueFromPyTensor(value_);
|
||||
|
||||
const auto value_sizes = tensor.sizes();
|
||||
const auto value_strides = tensor.strides();
|
||||
sizes_and_strides_.resize(tensor.dim());
|
||||
for (int64_t dim = 0; dim < tensor.dim(); dim++) {
|
||||
sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(dim);
|
||||
sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(dim);
|
||||
}
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
}
|
||||
c10::intrusive_ptr<c10::TensorImpl> shallow_copy_and_detach(
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) const;
|
||||
|
||||
c10::intrusive_ptr<c10::TensorImpl> shallow_copy_and_detach(
|
||||
c10::VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const;
|
||||
|
||||
|
||||
// Returns a reference to BatchDims that represent which dimensions of this
|
||||
// tensor are private.
|
||||
|
||||
// Override a bunch of methods inherited from TensorImpl to return error messages.
|
||||
bool is_contiguous(at::MemoryFormat memory_format) const override;
|
||||
void set_size(int64_t dim, int64_t new_size) override;
|
||||
void set_stride(int64_t dim, int64_t new_stride) override;
|
||||
void set_storage_offset(int64_t storage_offset) override;
|
||||
// #ifdef DEBUG
|
||||
// bool has_storage() const override;
|
||||
// #endif
|
||||
|
||||
py::object value_;
|
||||
};
|
||||
|
||||
PythonTensorImpl* getPythonImpl(at::Tensor tensor);
|
||||
|
||||
at::Tensor addPythonKey(const py::object& tensor);
|
||||
bool hasPythonKey(at::Tensor tensor);
|
||||
|
||||
py::object removePythonKey(at::Tensor tensor);
|
||||
}}
|
@ -5,6 +5,7 @@
|
||||
#include <functorch/csrc/DynamicLayer.h>
|
||||
#include <functorch/csrc/BatchedTensorImpl.h>
|
||||
#include <functorch/csrc/VmapTransforms.h>
|
||||
#include <functorch/csrc/PythonKey.h>
|
||||
|
||||
namespace at {
|
||||
namespace functorch {
|
||||
@ -183,4 +184,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "add batch dim");
|
||||
m.def("dlevel", &at::functorch::dlevel, "add batch dim");
|
||||
m.def("dump_tensor", &at::functorch::dump_tensor, "add batch dim");
|
||||
|
||||
m.def(
|
||||
"addPythonKey",
|
||||
&at::functorch::addPythonKey,
|
||||
py::return_value_policy::copy); // not sure if needed - cargo cult
|
||||
m.def("removePythonKey", &at::functorch::removePythonKey);
|
||||
m.def("hasPythonKey", &at::functorch::hasPythonKey);
|
||||
}
|
||||
|
@ -563,7 +563,7 @@ class TestExamplesCorrectness(TestCase):
|
||||
return mse_loss(v_f, y2)
|
||||
|
||||
task = sample_tasks(num_tasks, K)
|
||||
|
||||
|
||||
# Compute with vmap+grad
|
||||
inner_losses = vmap(partial(get_loss_for_task, True))\
|
||||
(task[0], task[1], task[2], task[3])
|
||||
@ -609,7 +609,7 @@ class TestExamplesCorrectness(TestCase):
|
||||
nn.MaxPool2d(2, 2),
|
||||
Flatten(),
|
||||
nn.Linear(64, n_way)).to(device).to(dtype)
|
||||
|
||||
|
||||
params, buffers, fnet, _, _, = make_functional_with_buffers(net)
|
||||
net = (params, buffers, fnet)
|
||||
|
||||
@ -739,7 +739,7 @@ class TestExamplesCorrectness(TestCase):
|
||||
|
||||
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
|
||||
def test_resnet18_per_sample_grads(self, device):
|
||||
# Straight out of opacus
|
||||
# Straight out of opacus
|
||||
def _replace_child(
|
||||
root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
|
||||
) -> None:
|
||||
|
Reference in New Issue
Block a user