Codegen: python_torch_functions only include relevant operators (#68693)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68693

Generation of python bindings for native functions is split over 8
different files. One for each namespace, with the torch namespace
split into 3 shards, and methods in their own file as well. This
change ensures that editing any single (non-method) operator only
causes one of these files to be rebuilt.

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D32596270

Pulled By: albanD

fbshipit-source-id: 0570ec69e7476b8f1bc21138ba18fe8f95ebbe3f
(cherry picked from commit ba0fc71a3a6835e49b332a8be52bf798fa2726b3)
This commit is contained in:
Peter Bell
2022-01-21 07:32:59 -08:00
committed by PyTorch MergeBot
parent 7680a0ae9d
commit 40d1f77384
39 changed files with 307 additions and 192 deletions

View File

@ -65,6 +65,7 @@ file(GLOB cuda_nvrtc_stub_cpp "cuda/nvrtc_stub/*.cpp")
file(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu")
file(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh")
file(GLOB cudnn_cpp "cudnn/*.cpp")
file(GLOB ops_h "ops/*.h")
file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h")
file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp")
@ -488,7 +489,8 @@ foreach(HEADER ${core_generated_headers})
install(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/core)
endforeach()
install(FILES ${ops_generated_headers} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/ops)
install(FILES ${ops_h} ${ops_generated_headers}
DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/ops)
install(FILES ${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml
DESTINATION ${AT_INSTALL_SHARE_DIR}/ATen)

View File

@ -0,0 +1,138 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at {
namespace detail {
TORCH_API inline void noopDelete(void*) {}
} // namespace detail
/// Provides a fluent API to construct tensors from external data.
///
/// The fluent API can be used instead of `from_blob` functions in case the
/// required set of parameters does not align with the existing overloads.
///
/// at::Tensor tensor = at::for_blob(data, sizes)
/// .strides(strides)
/// .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx); })
/// .options(...)
/// .make_tensor();
///
class TORCH_API TensorMaker {
friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept;
public:
using ContextDeleter = DeleterFnPtr;
TensorMaker& strides(optional<IntArrayRef> value) noexcept {
strides_ = value;
return *this;
}
TensorMaker& storage_offset(optional<int64_t> value) noexcept {
storage_offset_ = value;
return *this;
}
TensorMaker& deleter(std::function<void(void*)> value) noexcept {
deleter_ = std::move(value);
return *this;
}
TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept {
ctx_ = std::unique_ptr<void, ContextDeleter>{
value, deleter != nullptr ? deleter : detail::noopDelete};
return *this;
}
TensorMaker& target_device(optional<Device> value) noexcept {
device_ = value;
return *this;
}
TensorMaker& options(TensorOptions value) noexcept {
opts_ = value;
return *this;
}
Tensor make_tensor();
private:
explicit TensorMaker(void* data, IntArrayRef sizes) noexcept
: data_{data}, sizes_{sizes} {}
std::size_t computeStorageSize() const noexcept;
DataPtr makeDataPtrFromDeleter() const;
DataPtr makeDataPtrFromContext() noexcept;
IntArrayRef makeTempSizes() const noexcept;
void* data_;
IntArrayRef sizes_;
optional<IntArrayRef> strides_{};
optional<int64_t> storage_offset_{};
std::function<void(void*)> deleter_{};
std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete};
optional<Device> device_{};
TensorOptions opts_{};
};
inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept {
return TensorMaker{data, sizes};
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
IntArrayRef strides,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {},
const c10::optional<Device> target_device = c10::nullopt) {
return for_blob(data, sizes)
.strides(strides)
.deleter(deleter)
.options(options)
.target_device(target_device)
.make_tensor();
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {}) {
return for_blob(data, sizes)
.deleter(deleter)
.options(options)
.make_tensor();
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
IntArrayRef strides,
const TensorOptions& options = {}) {
return for_blob(data, sizes)
.strides(strides)
.options(options)
.make_tensor();
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
const TensorOptions& options = {}) {
return for_blob(data, sizes).options(options).make_tensor();
}
} // namespace at

View File

@ -0,0 +1,30 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
namespace at {
// These functions are defined in ATen/Utils.cpp.
#define TENSOR(T, S) \
TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
inline Tensor tensor( \
std::initializer_list<T> values, const TensorOptions& options) { \
return at::tensor(ArrayRef<T>(values), options); \
} \
inline Tensor tensor(T value, const TensorOptions& options) { \
return at::tensor(ArrayRef<T>(value), options); \
} \
inline Tensor tensor(ArrayRef<T> values) { \
return at::tensor(std::move(values), at::dtype(k##S)); \
} \
inline Tensor tensor(std::initializer_list<T> values) { \
return at::tensor(ArrayRef<T>(values)); \
} \
inline Tensor tensor(T value) { \
return at::tensor(ArrayRef<T>(value)); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR
} // namespace at

View File

@ -71,33 +71,13 @@
${static_dispatch_extra_headers}
#include <ATen/ops/from_blob.h>
#include <ATen/ops/tensor.h>
${Functions_includes}
namespace at {
// These functions are defined in ATen/Utils.cpp.
#define TENSOR(T, S) \
TORCH_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
inline Tensor tensor( \
std::initializer_list<T> values, const TensorOptions& options) { \
return at::tensor(ArrayRef<T>(values), options); \
} \
inline Tensor tensor(T value, const TensorOptions& options) { \
return at::tensor(ArrayRef<T>(value), options); \
} \
inline Tensor tensor(ArrayRef<T> values) { \
return at::tensor(std::move(values), at::dtype(k##S)); \
} \
inline Tensor tensor(std::initializer_list<T> values) { \
return at::tensor(ArrayRef<T>(values)); \
} \
inline Tensor tensor(T value) { \
return at::tensor(ArrayRef<T>(value)); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR
${Functions_declarations}
// Special C++ only overloads for std()-like functions (See gh-40287)
@ -116,138 +96,6 @@ TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim
return at::std_mean(self, IntArrayRef{dim});
}
namespace detail {
TORCH_API inline void noopDelete(void*) {}
} // namespace detail
/// Provides a fluent API to construct tensors from external data.
///
/// The fluent API can be used instead of `from_blob` functions in case the
/// required set of parameters does not align with the existing overloads.
///
/// at::Tensor tensor = at::for_blob(data, sizes)
/// .strides(strides)
/// .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx); })
/// .options(...)
/// .make_tensor();
///
class TORCH_API TensorMaker {
friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept;
public:
using ContextDeleter = DeleterFnPtr;
TensorMaker& strides(optional<IntArrayRef> value) noexcept {
strides_ = value;
return *this;
}
TensorMaker& storage_offset(optional<int64_t> value) noexcept {
storage_offset_ = value;
return *this;
}
TensorMaker& deleter(std::function<void(void*)> value) noexcept {
deleter_ = std::move(value);
return *this;
}
TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept {
ctx_ = std::unique_ptr<void, ContextDeleter>{
value, deleter != nullptr ? deleter : detail::noopDelete};
return *this;
}
TensorMaker& target_device(optional<Device> value) noexcept {
device_ = value;
return *this;
}
TensorMaker& options(TensorOptions value) noexcept {
opts_ = value;
return *this;
}
Tensor make_tensor();
private:
explicit TensorMaker(void* data, IntArrayRef sizes) noexcept
: data_{data}, sizes_{sizes} {}
std::size_t computeStorageSize() const noexcept;
DataPtr makeDataPtrFromDeleter() const;
DataPtr makeDataPtrFromContext() noexcept;
IntArrayRef makeTempSizes() const noexcept;
void* data_;
IntArrayRef sizes_;
optional<IntArrayRef> strides_{};
optional<int64_t> storage_offset_{};
std::function<void(void*)> deleter_{};
std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete};
optional<Device> device_{};
TensorOptions opts_{};
};
inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept {
return TensorMaker{data, sizes};
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
IntArrayRef strides,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {},
const c10::optional<Device> target_device = c10::nullopt) {
return for_blob(data, sizes)
.strides(strides)
.deleter(deleter)
.options(options)
.target_device(target_device)
.make_tensor();
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {}) {
return for_blob(data, sizes)
.deleter(deleter)
.options(options)
.make_tensor();
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
IntArrayRef strides,
const TensorOptions& options = {}) {
return for_blob(data, sizes)
.strides(strides)
.options(options)
.make_tensor();
}
inline Tensor from_blob(
void* data,
IntArrayRef sizes,
const TensorOptions& options = {}) {
return for_blob(data, sizes).options(options).make_tensor();
}
inline int64_t numel(const Tensor& tensor) {
return tensor.numel();
}

View File

@ -9,6 +9,8 @@
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/TensorOperators.h>
namespace torch {
namespace jit {

View File

@ -219,6 +219,7 @@ def create_python_bindings(
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: List[str] = []
ops_headers: List[str] = []
py_method_defs: List[str] = []
py_forwards: List[str] = []
@ -229,9 +230,11 @@ def create_python_bindings(
py_methods.append(method_impl(name, module, overloads, method=method))
py_method_defs.append(method_def(name, module, overloads, method=method))
py_forwards.extend(forward_decls(name, overloads, method=method))
ops_headers.append(f'#include <ATen/ops/{name.base}.h>')
fm.write_with_template(filename, filename, lambda: {
'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
'ops_headers': ops_headers,
'py_forwards': py_forwards,
'py_methods': py_methods,
'py_method_defs': py_method_defs,
@ -278,15 +281,17 @@ def create_python_bindings_sharded(
grouped = group_filter_overloads(pairs, pred)
def key_func(kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]) -> str:
return str(kv[0])
return kv[0].base
def env_func(
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
) -> Dict[str, List[str]]:
name, fn_pairs = kv
return {
'py_forwards': list(forward_decls(kv[0], kv[1], method=method)),
'py_methods': [method_impl(kv[0], module, kv[1], method=method)],
'py_method_defs': [method_def(kv[0], module, kv[1], method=method)],
'ops_headers': [f'#include <ATen/ops/{name.base}.h>'],
'py_forwards': list(forward_decls(name, fn_pairs, method=method)),
'py_methods': [method_impl(name, module, fn_pairs, method=method)],
'py_method_defs': [method_def(name, module, fn_pairs, method=method)],
}
fm.write_sharded(
@ -299,7 +304,7 @@ def create_python_bindings_sharded(
key_fn=key_func,
env_callable=env_func,
num_shards=num_shards,
sharded_keys={'py_forwards', 'py_methods', 'py_method_defs'}
sharded_keys={'ops_headers', 'py_forwards', 'py_methods', 'py_method_defs'}
)
def load_signatures(

View File

@ -35,12 +35,23 @@ def fully_qualified_type(argument_type: str) -> str:
def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
native_functions = parse_native_yaml(native_yaml_path).native_functions
factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: {
'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
'function_definitions': list(mapMaybe(process_function, native_functions)),
'ops_headers': [f'#include <ATen/ops/{fn.root_name}.h>' for fn in factory_functions],
'function_definitions': list(mapMaybe(process_function, factory_functions)),
})
@with_native_function
def is_factory_function(f: NativeFunction) -> bool:
if Variant.function not in f.variants:
return False
name = cpp.name(f.func)
has_tensor_options = python.has_tensor_options(f)
return has_tensor_options or name.endswith("_like")
@with_native_function
def process_function(f: NativeFunction) -> Optional[str]:
name = cpp.name(f.func)

View File

@ -2,7 +2,8 @@
// ${generated_comment}
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <c10/util/intrusive_ptr.h>

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include "torch/csrc/Device.h"
@ -15,7 +16,13 @@
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/utils/cuda_lazy_init.h"
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::Tensor;
using at::Device;

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include "torch/csrc/Device.h"
@ -12,6 +13,12 @@
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::Tensor;
using at::Scalar;
using at::ScalarType;

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include "torch/csrc/Device.h"
@ -12,6 +13,12 @@
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::Tensor;
using at::Scalar;
using at::MemoryFormat;

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include "torch/csrc/Device.h"
@ -11,6 +12,12 @@
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::Tensor;
using at::Scalar;
using at::ScalarType;

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include "torch/csrc/Device.h"
@ -15,7 +16,11 @@
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/utils/cuda_lazy_init.h"
#include <ATen/ATen.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::Tensor;
using at::Device;

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
// Python bindings for torch.* functions implemented through ATen.
@ -33,7 +34,13 @@
#include "torch/csrc/utils/cuda_lazy_init.h"
#include "torch/csrc/autograd/python_return_types.h"
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
#include <functional>
#include <initializer_list>

View File

@ -1,3 +1,4 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include <Python.h>
@ -35,12 +36,18 @@
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/autograd/python_return_types.h"
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include "c10/util/Optional.h"
#include "c10/core/Stream.h"
#include <stdexcept>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::DeviceGuard;
using at::device_of;
using at::OptionalDeviceGuard;
@ -256,11 +263,12 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec
// manually.
if (jit::tracer::isTracing()) {
auto tracer_state = jit::tracer::getTracingState();
auto node = tracer_state->graph->create(jit::aten::contiguous, /*num_outputs=*/0);
auto op_name = c10::Symbol::fromQualString("aten::contiguous");
auto node = tracer_state->createNode(op_name, /*num_outputs=*/0);
jit::tracer::recordSourceLocation(node);
jit::tracer::addInputs(node, "self", self_);
jit::tracer::addInputs(node, "memory_format", memory_format);
tracer_state->graph->insertNode(node);
tracer_state->insertNode(node);
jit::tracer::addOutput(node, self_);
}
Py_INCREF(self);

View File

@ -2,15 +2,20 @@
// ${generated_comment}
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/TracerMode.h>
#include <ATen/core/grad_mode.h>
#include <c10/util/ArrayRef.h>
#include <c10/core/MemoryFormat.h>
#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/from_blob.h>
$ops_headers
#endif
#include <functional>
#include <initializer_list>

View File

@ -1,6 +1,6 @@
#pragma once
#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/Export.h>

View File

@ -2,7 +2,7 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <ATen/core/Generator.h>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)

View File

@ -2,6 +2,8 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/Types.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/TensorImpl.h>
namespace torch {

View File

@ -1,8 +1,18 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif
#include <initializer_list>
namespace torch {

View File

@ -7,6 +7,11 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
// TODO: These don't really belong here but torchvision builds in CI need them
// Remove once the torchvision version being compiled in CI is updated
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace torch {
// NOTE [ Exposing declarations in `at::` to `torch::` ]

View File

@ -1,6 +1,5 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/library.h>
namespace torch {

View File

@ -2,7 +2,7 @@
#include <torch/csrc/python_headers.h>
#include <memory>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/Export.h>

View File

@ -1,7 +1,7 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <torch/csrc/utils/python_arg_parser.h>

View File

@ -2,7 +2,8 @@
// Wrap tensor operation outputs as PyObject*
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <c10/util/irange.h>
#include <torch/csrc/python_headers.h>
#include <tuple>

View File

@ -1,4 +1,5 @@
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/custom_class.h>
namespace torch {
namespace distributed {

View File

@ -6,6 +6,10 @@
#include <algorithm>
#include <cctype>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>
namespace torch {
namespace distributed {

View File

@ -3,6 +3,7 @@
#include <ATen/core/ivalue.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <algorithm>
#include <bitset>

View File

@ -1,6 +1,6 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
namespace torch {
namespace utils {

View File

@ -2,7 +2,7 @@
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

View File

@ -1146,7 +1146,7 @@ at::Scalar PythonArgs::scalar_slow(int i) {
if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
auto& var = THPVariable_Unpack(args[i]);
jit::tracer::ArgumentStash::stashValue(
signature.params[i].name, idx, var, jit::NumberType::get());
signature.params[i].name, idx, var, c10::NumberType::get());
}
return scalar_slow(args[i]);

View File

@ -53,7 +53,6 @@
#include <torch/csrc/Layout.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/python_dimname.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/object_ptr.h>
@ -64,7 +63,7 @@
#include <torch/csrc/utils/six.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
@ -623,7 +622,7 @@ inline int64_t PythonArgs::toInt64(int i) {
if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
auto & var = THPVariable_Unpack(args[i]);
jit::tracer::ArgumentStash::stashValue(
signature.params[i].name, idx, var, jit::IntType::get());
signature.params[i].name, idx, var, c10::IntType::get());
}
return THPUtils_unpackLong(args[i]);
}

View File

@ -1,7 +1,7 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
namespace torch { namespace utils {

View File

@ -1,7 +1,10 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
namespace at {
class Tensor;
}
namespace torch { namespace utils {

View File

@ -2,7 +2,7 @@
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
namespace torch { namespace utils {

View File

@ -1,7 +1,7 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
namespace torch { namespace utils {

View File

@ -6,6 +6,7 @@
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <ATen/Context.h>
#include <ATen/Formatting.h>
#include <sstream>
#include <unordered_map>

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <c10/core/TensorOptions.h>
#include <utility>
#include <vector>

View File

@ -61,9 +61,7 @@
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_allowlist.h>
#include <c10/core/DispatchKey.h>
#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#endif
// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>