Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Support native namespace functions with type dispatch. (#5576)
* Support native namespace functions with type dispatch. Use 'ones' as an example.

  Note this is a "halfway" solution; i.e. the call chain is:
  at::ones(shape, dtype) -> dtype.ones(shape, dtype) -> CPUFloatType.ones(shape, dtype) -> at::native::ones(shape, dtype)

  The "nicer" solution would probably be something like:
  at::ones(shape, dtype) -> dtype.ones(shape) -> CPUFloatType.ones(shape) -> at::native::ones(shape, this)

* Fix type inference.
* Fix test install.
* Fix extensions.
* Put dtype argument at the beginning.
* Fix extension.cpp.
* Fix rnn.
* Move zeros in the same manner.
* Fix cuda.
* Change randn.
* Change rand.
* Change randperm.
* Fix aten contrib.
* Resize in randperm_out.
* Implement eye.
* Fix sparse zeros.
* linspace, logspace.
* arange.
* range.
* Remove type dispatch from gen_python_functions.
* Properly generate maybe_init_cuda for type dispatch functions not named type.
* Don't duplicate dtype, this parameters for native type dispatched functions.
* Call VariableType factory methods from the base type so it gets version number 0.
* Address review comments.
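For illustration, a minimal usage sketch of the new Type-first factory API introduced by this commit (assumes a standard ATen build; shapes and dtypes are arbitrary examples, not part of the diff):

#include "ATen/ATen.h"
#include <iostream>

using namespace at;

int main() {
  // Factory functions now live in the at:: namespace and take the Type
  // (dtype/backend) as their first argument instead of being Type methods.
  Tensor a = ones(CPU(kFloat), {2, 3});    // was: CPU(kFloat).ones({2, 3})
  Tensor b = zeros(CPU(kDouble), {2, 3});  // was: CPU(kDouble).zeros({2, 3})
  Tensor c = rand(CPU(kFloat), {4});       // was: CPU(kFloat).rand({4})
  std::cout << a << b << c << std::endl;
  return 0;
}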
@@ -9,7 +9,7 @@ int main()
 {
 std::cout << "hello\n";
 
-Tensor tensor = CPU(kDouble).rand({256,32});
+Tensor tensor = rand(CPU(kDouble), {256,32});
 
 TensorDataset dataset(tensor);
 DatasetIterator datasetiterator(dataset);
@@ -29,8 +29,8 @@ void AUCMeter::value(Tensor& val) {
 int64_t * sortidx_d = sortidx.data<int64_t>();
 int16_t * targets_d = sortidx.data<int16_t>();
 // construct the ROC curve:
-Tensor tpr = CPU(kDouble).zeros({numel(outputs)});
-Tensor fpr = CPU(kDouble).zeros({numel(outputs)});
+Tensor tpr = zeros(CPU(kDouble), {numel(outputs)});
+Tensor fpr = zeros(CPU(kDouble), {numel(outputs)});
 
 double * tpr_d = tpr.data<double>();
 double * fpr_d = fpr.data<double>();
@@ -8,8 +8,8 @@ int main()
 auto && T = CPU(kFloat);
 std::cout << "hello\n";
 APMeter meter;
-Tensor output = T.randn({10, 7});
-Tensor target = T.zeros({10, 7});
+Tensor output = at::randn(T, {10, 7});
+Tensor target = at::zeros(T, {10, 7});
 for(uint64_t n = 0; n < 10; ++n) {
 Tensor row = target.select(0,n);
 auto row_d = row.data<float>();
@ -31,31 +31,6 @@
|
||||
long_args: True
|
||||
- CONSTANT NULL
|
||||
]]
|
||||
[[
|
||||
name: zeros
|
||||
variants:
|
||||
- function
|
||||
aten_sparse: True
|
||||
auto_gpu: False
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- arg: THSize* size
|
||||
long_args: True
|
||||
]]
|
||||
[[
|
||||
name: ones
|
||||
variants:
|
||||
- function
|
||||
auto_gpu: False
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- arg: THSize* size
|
||||
long_args: True
|
||||
]]
|
||||
[[
|
||||
name: numel
|
||||
return: long
|
||||
@ -369,7 +344,8 @@
|
||||
- long step
|
||||
]]
|
||||
[[
|
||||
name: range
|
||||
name: _range
|
||||
cname: range
|
||||
variants:
|
||||
- function
|
||||
backends:
|
||||
@ -389,7 +365,7 @@
|
||||
default: 1
|
||||
]]
|
||||
[[
|
||||
name: arange
|
||||
name: _arange
|
||||
variants:
|
||||
- function
|
||||
backends:
|
||||
@ -2268,7 +2244,8 @@
|
||||
- real weight
|
||||
]]
|
||||
[[
|
||||
name: linspace
|
||||
name: _linspace
|
||||
cname: linspace
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
@ -2286,7 +2263,8 @@
|
||||
default: 100
|
||||
]]
|
||||
[[
|
||||
name: logspace
|
||||
name: _logspace
|
||||
cname: logspace
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
@ -2857,21 +2835,6 @@
|
||||
- arg: long dim
|
||||
default: -1
|
||||
]]
|
||||
[[
|
||||
name: eye
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- long n
|
||||
- arg: long m
|
||||
default: -1
|
||||
]]
|
||||
[[
|
||||
name: diag
|
||||
variants:
|
||||
@ -3681,21 +3644,6 @@
|
||||
- THTensor* LU_data
|
||||
- THIntegerTensor* LU_pivots
|
||||
]]
|
||||
[[
|
||||
name: randperm
|
||||
backends:
|
||||
- CPU
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- arg: THGenerator* generator
|
||||
default: nullptr
|
||||
kwarg_only: True
|
||||
- long n
|
||||
]]
|
||||
[[
|
||||
name: random_
|
||||
backends:
|
||||
@ -3881,44 +3829,6 @@
|
||||
- arg: double lambd
|
||||
default: 1
|
||||
]]
|
||||
[[
|
||||
name: rand
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- arg: THGenerator* generator
|
||||
default: nullptr
|
||||
kwarg_only: True
|
||||
- arg: THSize* size
|
||||
long_args: True
|
||||
]]
|
||||
[[
|
||||
name: randn
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- arg: THGenerator* generator
|
||||
default: nullptr
|
||||
kwarg_only: True
|
||||
- arg: THSize* size
|
||||
long_args: True
|
||||
]]
|
||||
[[
|
||||
name: geometric_
|
||||
backends:
|
||||
|
@ -39,11 +39,11 @@ else:
|
||||
# declaration under Type.h (right now, we call this template
|
||||
# BROADCAST but it also handles default arguments)
|
||||
TYPE_METHOD_DECLARATION_BROADCAST = CodeTemplate("""\
|
||||
${return_type} ${api_name}(${formals_with_defaults}) const;
|
||||
${return_type} ${api_name}(${type_method_formals_with_defaults}) const;
|
||||
""")
|
||||
# 2. broadcasting functions are implemented in Type.cpp
|
||||
TYPE_METHOD_DEFINITION_BROADCAST = CodeTemplate("""\
|
||||
${return_type} Type::${api_name}(${formals}) const {
|
||||
${return_type} Type::${api_name}(${type_method_formals}) const {
|
||||
Tensor ${broadcast_returns};
|
||||
std::tie(${broadcast_returns}) = ${broadcast_function}(${broadcast_actuals}, "${api_name}");
|
||||
return ${method_prefix_derived}${api_name}(${broadcast_modified_actuals});
|
||||
@ -59,28 +59,28 @@ ${return_type} Type::${api_name}(${formals}) const {
|
||||
# for 'native' declarations (so the native dispatch is hardcoded into
|
||||
# the template here.)
|
||||
TYPE_METHOD_DECLARATION_ABSTRACT = CodeTemplate("""\
|
||||
virtual ${return_type} ${method_prefix_derived}${api_name}(${formals_with_defaults}) const;
|
||||
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals_with_defaults}) const;
|
||||
""")
|
||||
TYPE_METHOD_DEFINITION_ABSTRACT = CodeTemplate("""\
|
||||
${return_type} Type::${method_prefix_derived}${api_name}(${formals}) const {
|
||||
${return_type} Type::${method_prefix_derived}${api_name}(${type_method_formals}) const {
|
||||
runtime_error("${method_prefix_derived}${api_name} is not implemented for type %s", toString());
|
||||
}
|
||||
""")
|
||||
TYPE_METHOD_DECLARATION_CONCRETE = CodeTemplate("""\
|
||||
virtual ${return_type} ${api_name}(${formals_with_defaults}) const;
|
||||
virtual ${return_type} ${api_name}(${type_method_formals_with_defaults}) const;
|
||||
""")
|
||||
TYPE_METHOD_DEFINITION_CONCRETE = CodeTemplate("""\
|
||||
${return_type} Type::${api_name}(${formals}) const {
|
||||
${return_type} Type::${api_name}(${type_method_formals}) const {
|
||||
${type_definition_body}
|
||||
}
|
||||
""")
|
||||
# 4. add virtual override to TypeDerived.h
|
||||
TYPE_DERIVED_DECLARATION = CodeTemplate("""\
|
||||
virtual ${return_type} ${method_prefix_derived}${api_name}(${formals}) const override;
|
||||
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
|
||||
""")
|
||||
# 5. add override definition to TypeDerived.cpp
|
||||
TYPE_DERIVED_DEFINITION = CodeTemplate("""\
|
||||
${return_type} ${Type}::${method_prefix_derived}${api_name}(${formals}) const {
|
||||
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
|
||||
${type_definition_body}
|
||||
}
|
||||
""")
|
||||
@@ -88,12 +88,12 @@ ${return_type} ${Type}::${method_prefix_derived}${api_name}(${formals}) const {
 # because we will inherit it from the TYPE_METHOD_DEFINITION_CONCRETE in
 # the superclass. But it doesn't seem to be harmful.
 TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate("""\
-${return_type} ${Type}::${api_name}(${formals}) const {
+${return_type} ${Type}::${api_name}(${type_method_formals}) const {
     ${return_call} at::native::${native_type_method_dispatch}(${actuals});
 }
 """)
 TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\
-    ${return_call} at::native::${native_type_method_dispatch}(${actuals});
+    ${return_call} at::native::${native_type_method_dispatch}(${native_actuals});
 """)
 
 # 6. add non-virtual declaration to Tensor.h
@@ -113,7 +113,7 @@ static inline ${return_type} ${api_name}(${formals_with_defaults});
 # 9. add method definition in Functions.h
 FUNCTION_DEFINITION = CodeTemplate("""\
 static inline ${return_type} ${api_name}(${formals}) {
-    return ${inferred_type}.${api_name}(${actuals});
+    return ${inferred_type}.${api_name}(${type_method_actuals});
 }
 """)
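Roughly, for a type-dispatched native function such as ones, the substituted output of these templates could look like the sketch below (illustrative names only; the dispatching Type formal is dropped from the Type-method signature and handed back to the native function as *this):

// Functions.h: namespace-level wrapper; the dispatching Type is the inferred_type.
static inline Tensor ones(const Type & dtype, IntList size) {
  return dtype.ones(size);
}

// Type.cpp: concrete Type method body (TYPE_DEFINITION_BODY_NATIVE) forwards to the
// native implementation, substituting *this for the stripped Type formal.
Tensor Type::ones(IntList size) const {
  return at::native::ones(*this, size);
}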
# 10. add a native declaration for a native function
|
||||
@ -341,24 +341,10 @@ THFormal = TypedDict('THFormal', {
|
||||
'resize': str,
|
||||
'cpu_zero': bool,
|
||||
'zero': bool,
|
||||
'is_type_dispatched': bool,
|
||||
}, total=False)
|
||||
|
||||
# A native_functions.yaml formal argument
|
||||
# type can contain Tensor, BoolTensor, IndexTensor types
|
||||
NativeFormal = TypedDict('NativeFormal', {
|
||||
'name': str,
|
||||
'type': str,
|
||||
'dynamic_type': str,
|
||||
'kwarg_only': bool,
|
||||
'is_nullable': bool,
|
||||
'default': str,
|
||||
'default_init': str,
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
}, total=False)
|
||||
|
||||
# Generic ATen formal.
|
||||
# Generic ATen formal or native_functions.yaml formal argument.
|
||||
# type can contain Tensor& reference types.
|
||||
AtFormal = TypedDict('AtFormal', {
|
||||
'name': str,
|
||||
@ -371,6 +357,7 @@ AtFormal = TypedDict('AtFormal', {
|
||||
'python_default_init': str,
|
||||
'output': bool,
|
||||
'size': int,
|
||||
'is_type_dispatched': bool,
|
||||
}, total=False)
|
||||
|
||||
ReturnType = TypedDict('ReturnType', {
|
||||
@ -397,6 +384,9 @@ FunctionOption = TypedDict('FunctionOption', {
|
||||
'return': ReturnDecl,
|
||||
'variants': str,
|
||||
'type_method_definition_dispatch': str,
|
||||
'type_method_formals': List[str],
|
||||
'type_method_formals_with_defaults': List[str],
|
||||
'type_method_actuals': List[str],
|
||||
'cname': str,
|
||||
'backends': List[str],
|
||||
'api_name': str,
|
||||
@ -437,6 +427,7 @@ FunctionOption = TypedDict('FunctionOption', {
|
||||
'broadcast_function': str,
|
||||
'broadcast_modified_actuals': List[str],
|
||||
'native_type_method_dispatch': str,
|
||||
'native_actuals': List[str],
|
||||
})
|
||||
|
||||
OutputDeclaration = NamedTuple('OutputDeclaration', [
|
||||
@ -699,6 +690,11 @@ def create_generic(top_env, declarations):
|
||||
option['method_actuals'] = [
|
||||
f['name'] if f['name'] != 'self' else '*this' for f in formals]
|
||||
|
||||
# There are no cases where these differ, but they do in native_functions
|
||||
option['type_method_formals'] = option['formals']
|
||||
option['type_method_formals_with_defaults'] = option['formals_with_defaults']
|
||||
option['type_method_actuals'] = option['actuals']
|
||||
|
||||
option['const_mark'] = '' if option['inplace'] else ' const'
|
||||
|
||||
is_method = 'method' in option['variants']
|
||||
@ -790,7 +786,7 @@ def create_generic(top_env, declarations):
|
||||
kwd_args = []
|
||||
|
||||
def insert(argument):
|
||||
# type: (NativeFormal) -> None
|
||||
# type: (AtFormal) -> None
|
||||
if argument['name'] not in seen:
|
||||
seen.add(argument['name'])
|
||||
if argument.get('kwarg_only', False):
|
||||
@ -804,7 +800,7 @@ def create_generic(top_env, declarations):
|
||||
# not clear we need dynamic_type translation as we can specify the correct type
|
||||
# directly in native functions
|
||||
def add_type_as_dynamic_type(argument, option):
|
||||
# type: (NativeFormal, FunctionOption) -> NativeFormal
|
||||
# type: (AtFormal, FunctionOption) -> AtFormal
|
||||
argument['dynamic_type'] = argument['type']
|
||||
return argument
|
||||
|
||||
@ -813,7 +809,7 @@ def create_generic(top_env, declarations):
|
||||
|
||||
# ensure we get reference-type formals when appropriate
|
||||
def native_translate_formals(argument, option):
|
||||
# type: (NativeFormal, FunctionOption) -> AtFormal
|
||||
# type: (AtFormal, FunctionOption) -> AtFormal
|
||||
def translate_map(const):
|
||||
# type: (bool) -> Dict[str, str]
|
||||
return {
|
||||
@ -885,12 +881,27 @@ def create_generic(top_env, declarations):
|
||||
option['method_actuals'] = [
|
||||
f['name'] if f['name'] != 'self' else '*this' for f in formals]
|
||||
|
||||
def find_dispatch_type(formals):
|
||||
for formal in formals:
|
||||
if 'Type' == formal['dynamic_type']:
|
||||
return formal
|
||||
return None
|
||||
|
||||
dispatch_tensor = find_dispatch_tensor(formals)
|
||||
dispatch_type = None if dispatch_tensor else find_dispatch_type(formals)
|
||||
if dispatch_type:
|
||||
dispatch_type['is_type_dispatched'] = True
|
||||
|
||||
option['type_method_formals'] = [format_formal(f) for f in formals if f != dispatch_type]
|
||||
option['type_method_formals_with_defaults'] = [formal_with_default(f) for f in formals if f != dispatch_type]
|
||||
option['type_method_actuals'] = [f['name'] for f in formals if f != dispatch_type]
|
||||
option['native_actuals'] = [f['name'] if f != dispatch_type else '*this' for f in formals]
|
||||
|
||||
option['const_mark'] = '' if option['inplace'] else ' const'
|
||||
|
||||
is_method = 'method' in option['variants']
|
||||
is_function = 'function' in option['variants']
|
||||
dispatch_tensor = find_dispatch_tensor(formals)
|
||||
is_namespace_function = is_function and dispatch_tensor is not None
|
||||
is_namespace_function = is_function and (dispatch_tensor or dispatch_type)
|
||||
|
||||
option['method_prefix_derived'] = ''
|
||||
env = nested_dict(option, top_env)
|
||||
@ -947,7 +958,10 @@ def create_generic(top_env, declarations):
|
||||
method_of.append('Tensor')
|
||||
|
||||
if is_namespace_function:
|
||||
option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
|
||||
if dispatch_type:
|
||||
option['inferred_type'] = dispatch_type['name']
|
||||
else:
|
||||
option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
|
||||
top_env['function_declarations'].append(
|
||||
FUNCTION_DECLARATION.substitute(env))
|
||||
top_env['function_definitions'].append(
|
||||
|
@ -200,7 +200,7 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
|
||||
}
|
||||
|
||||
Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
|
||||
Tensor ret = lambda.type().zeros(lambda.sizes());
|
||||
Tensor ret = at::zeros(lambda.type(), lambda.sizes());
|
||||
auto lambda_ = lambda.toType(ScalarType::Double);
|
||||
AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] {
|
||||
THGenerator* generator = get_generator(gen);
|
||||
|
@ -99,7 +99,7 @@ Tensor embedding_backward_cpu(
|
||||
}
|
||||
|
||||
auto grad = grad_.contiguous().view({numel, grad_.size(-1)});
|
||||
auto grad_weight = grad_.type().zeros({num_weights, grad_.size(-1)});
|
||||
auto grad_weight = at::zeros(grad_.type(), {num_weights, grad_.size(-1)});
|
||||
|
||||
#ifdef _OPENMP
|
||||
if (numel > 1000) {
|
||||
|
@ -81,11 +81,11 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
|
||||
Tensor indices = indices__.contiguous();
|
||||
Tensor offsets = offsets__.contiguous();
|
||||
|
||||
auto bag_size = indices.type().zeros(offsets.sizes());
|
||||
auto bag_size = at::zeros(indices.type(), offsets.sizes());
|
||||
auto offset2bag =
|
||||
indices__.type().zeros({indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
|
||||
at::zeros(indices__.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
|
||||
make_offset2bag(offsets, indices, offset2bag);
|
||||
auto output = weight.type().zeros({offsets.sizes()[0], weight.sizes()[1]});
|
||||
auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
|
||||
auto index_output = weight.index_select(0, indices);
|
||||
output.index_add_(0, offset2bag, index_output);
|
||||
make_bag_size(offsets, indices, mode, bag_size);
|
||||
@ -168,7 +168,7 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
|
||||
}
|
||||
|
||||
auto index_grad_weight =
|
||||
grad.type().zeros({num_weights, grad.sizes()[1]}).contiguous();
|
||||
at::zeros(grad.type(), {num_weights, grad.sizes()[1]}).contiguous();
|
||||
|
||||
#pragma omp parallel for if (numel > 1000)
|
||||
for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
|
||||
|
@ -174,13 +174,13 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
|
||||
// Compute the linear indices for the parts of the tensor not being indexed
|
||||
Tensor beforeIndex;
|
||||
if (emptyBefore > 0) {
|
||||
auto index = longType.arange(0, nElemBefore) * strides[emptyBefore - 1];
|
||||
auto index = at::arange(longType, 0, nElemBefore) * strides[emptyBefore - 1];
|
||||
index = index.view(src.sizes().slice(0, emptyBefore));
|
||||
beforeIndex = unsqueezeN(index, 0, linearIndex.dim() + emptyAfter);
|
||||
}
|
||||
Tensor afterIndex;
|
||||
if (emptyAfter > 0) {
|
||||
auto index = longType.arange(0, nElemAfter);
|
||||
auto index = at::arange(longType, 0, nElemAfter);
|
||||
index = index.view(src.sizes().slice(src.dim() - emptyAfter, emptyAfter));
|
||||
afterIndex = unsqueezeN(index, linearIndex.dim() + emptyBefore, 0);
|
||||
}
|
||||
|
@ -74,14 +74,14 @@ Tensor& ger_out(Tensor& result, const Tensor& self, const Tensor& vec2) {
|
||||
|
||||
Tensor mm(const Tensor& self, const Tensor& mat2) {
|
||||
if (self.is_sparse()) {
|
||||
return mat2.type().addmm(mat2.type().zeros({}), self, mat2, 0, 1);
|
||||
return mat2.type().addmm(at::zeros(mat2.type(), {}), self, mat2, 0, 1);
|
||||
}
|
||||
return self.type()._mm(self, mat2);
|
||||
}
|
||||
|
||||
Tensor& mm_out(Tensor& result, const Tensor& self, const Tensor& mat2) {
|
||||
if (self.is_sparse()) {
|
||||
return mat2.type().addmm_out(result, mat2.type().zeros({}), self, mat2, 0, 1);
|
||||
return mat2.type().addmm_out(result, at::zeros(mat2.type(), {}), self, mat2, 0, 1);
|
||||
}
|
||||
return self.type()._mm_out(result, self, mat2);
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ Tensor stft(const Tensor& self, const int64_t frame_length,
|
||||
}
|
||||
// pad zeros
|
||||
if (pad_end != 0) {
|
||||
Tensor padded_input = self.type().zeros({batch, len + pad_end});
|
||||
Tensor padded_input = at::zeros(self.type(), {batch, len + pad_end});
|
||||
padded_input.narrow(1, 0, len).copy_(input);
|
||||
len += pad_end;
|
||||
input = padded_input;
|
||||
@ -90,8 +90,8 @@ Tensor stft(const Tensor& self, const int64_t frame_length,
|
||||
// build ft kernel
|
||||
// k[omega, t] = cos (2 pi omega t / N) - j sin (2 pi omega t / N)
|
||||
double N = static_cast<double>(fft_size);
|
||||
auto freq_arange = self.type().arange(0, return_size).mul_(M_PI * 2. / N);
|
||||
auto time_arange = self.type().arange(0, frame_length);
|
||||
auto freq_arange = at::arange(self.type(), 0, return_size).mul_(M_PI * 2. / N);
|
||||
auto time_arange = at::arange(self.type(), 0, frame_length);
|
||||
auto arange_2d = at::ger(freq_arange, time_arange);
|
||||
auto re_kernel = arange_2d.cos();
|
||||
auto im_kernel = arange_2d.sin().mul_(-1);
|
||||
|
@ -1,9 +1,31 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/NativeFunctions.h"
|
||||
#include "TH/THRandom.h"
|
||||
#include "ATen/CheckGenerator.h"
|
||||
#include "ATen/CPUGenerator.h"
|
||||
#include "ATen/Dispatch.h"
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
Tensor arange(const Type& dtype, Scalar start, Scalar end, Scalar step) {
|
||||
return dtype._arange(start, end, step);
|
||||
}
|
||||
|
||||
Tensor& arange_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
|
||||
return at::_arange_out(result, start, end, step);
|
||||
}
|
||||
|
||||
Tensor arange(const Type& dtype, Scalar end) {
|
||||
return dtype._arange(end);
|
||||
}
|
||||
|
||||
Tensor& arange_out(Tensor& result, Scalar end) {
|
||||
return at::_arange_out(result, end);
|
||||
}
|
||||
|
||||
Tensor empty_like(const Tensor& self) {
|
||||
return self.type().tensor(self.sizes());
|
||||
}
|
||||
@ -12,28 +34,178 @@ Tensor empty_like(const Tensor& self, const Type& dtype) {
|
||||
return dtype.tensor(self.sizes());
|
||||
}
|
||||
|
||||
Tensor eye(const Type& dtype, int64_t n, int64_t m) {
  auto result = dtype.tensor();
  return at::eye_out(result, n, m);
}

Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) {
  if (n <= 0) {
    std::ostringstream oss;
    oss << "n must be greater than 0, got: " << n;
    throw std::runtime_error(oss.str());
  }
  if(m <= 0) {
    m = n;
  }

  result.resize_({n, m});
  result.zero_();

  int64_t sz = std::min<int64_t>(n, m);
  AT_DISPATCH_ALL_TYPES(result.type(), "eye", [&]() -> void {
    scalar_t* result_data = result.data<scalar_t>();
    for(int64_t i = 0; i < sz; i++) {
      result_data[i*(result.strides()[0] + result.strides()[1])] = 1;
    }
  });

  return result;
}
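A brief usage sketch of the new eye entry points (illustrative only; assumes a CPU build):

// Allocates a fresh tensor of the requested type.
Tensor id3 = at::eye(CPU(kFloat), 3);   // 3x3 identity
// Writes into an existing tensor, resizing it as needed.
Tensor out = CPU(kFloat).tensor();
at::eye_out(out, 4, 2);                 // 4x2, ones on the main diagonal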
|
||||
Tensor linspace(const Type& dtype, Scalar start, Scalar end, int64_t steps) {
|
||||
return dtype._linspace(start, end, steps);
|
||||
}
|
||||
|
||||
Tensor& linspace_out(Tensor& result, Scalar start, Scalar end, int64_t steps) {
|
||||
return at::_linspace_out(result, start, end, steps);
|
||||
}
|
||||
|
||||
Tensor logspace(const Type& dtype, Scalar start, Scalar end, int64_t steps) {
|
||||
return dtype._logspace(start, end, steps);
|
||||
}
|
||||
|
||||
Tensor& logspace_out(Tensor& result, Scalar start, Scalar end, int64_t steps) {
|
||||
return at::_logspace_out(result, start, end, steps);
|
||||
}
|
||||
|
||||
Tensor ones(const Type& dtype, IntList size) {
|
||||
auto result = dtype.tensor(size);
|
||||
return result.fill_(1);
|
||||
}
|
||||
|
||||
Tensor& ones_out(Tensor& result, IntList size) {
|
||||
result.resize_(size);
|
||||
return result.fill_(1);
|
||||
}
|
||||
|
||||
Tensor ones_like(const Tensor& self) {
|
||||
return self.type().ones(self.sizes());
|
||||
return at::native::ones(self.type(), self.sizes());
|
||||
}
|
||||
|
||||
Tensor ones_like(const Tensor& self, const Type& dtype) {
|
||||
return dtype.ones(self.sizes());
|
||||
return at::native::ones(dtype, self.sizes());
|
||||
}
|
||||
|
||||
Tensor rand(const Type& dtype, IntList size, Generator* generator) {
|
||||
Tensor result = dtype.tensor(size);
|
||||
return result.uniform_(0, 1, generator);
|
||||
}
|
||||
|
||||
Tensor& rand_out(Tensor& result, IntList size, Generator* generator) {
|
||||
result.resize_(size);
|
||||
return result.uniform_(0, 1, generator);
|
||||
}
|
||||
|
||||
Tensor rand_like(const Tensor& self) {
|
||||
return self.type().rand(self.sizes());
|
||||
return at::native::rand_like(self, self.type());
|
||||
}
|
||||
|
||||
Tensor rand_like(const Tensor& self, const Type& dtype) {
|
||||
return dtype.rand(self.sizes());
|
||||
return at::native::rand(dtype, self.sizes());
|
||||
}
|
||||
|
||||
Tensor randn(const Type& dtype, IntList size, Generator* generator) {
|
||||
Tensor result = dtype.tensor(size);
|
||||
return result.normal_(0, 1, generator);
|
||||
}
|
||||
|
||||
Tensor& randn_out(Tensor& result, IntList size, Generator* generator) {
|
||||
result.resize_(size);
|
||||
return result.normal_(0, 1, generator);
|
||||
}
|
||||
|
||||
Tensor randn_like(const Tensor& self) {
|
||||
return self.type().randn(self.sizes());
|
||||
return at::native::randn_like(self, self.type());
|
||||
}
|
||||
|
||||
Tensor randn_like(const Tensor& self, const Type& dtype) {
|
||||
return dtype.randn(self.sizes());
|
||||
return at::native::randn(dtype, self.sizes(), nullptr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void randperm_cpu(Tensor& result, int64_t n, THGenerator* generator) {
|
||||
scalar_t *r__data = result.data<scalar_t>();
|
||||
|
||||
result.resize_({n});
|
||||
int64_t r__stride_0 = result.stride(0);
|
||||
|
||||
for(int64_t i = 0; i < n; i++) {
|
||||
r__data[i*r__stride_0] = static_cast<scalar_t>(i);
|
||||
}
|
||||
|
||||
for(int64_t i = 0; i < n - 1; i++)
|
||||
{
|
||||
int64_t z = THRandom_random(generator) % (n-i);
|
||||
scalar_t sav = r__data[i*r__stride_0];
|
||||
r__data[i*r__stride_0] = r__data[(z+i)*r__stride_0];
|
||||
r__data[(z+i)*r__stride_0] = sav;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
||||
THGenerator* get_generator(at::Generator* gen) {
|
||||
auto default_gen = &at::globalContext().defaultGenerator(at::Backend::CPU);
|
||||
auto gen_ = at::check_generator<at::CPUGenerator>(gen, default_gen);
|
||||
return gen_->generator;
|
||||
}
|
||||
|
||||
Tensor randperm(const Type& dtype, int64_t n, Generator* generator) {
  Tensor result = dtype.tensor(n);
  return at::native::randperm_out(result, n, generator);
}

Tensor& randperm_out(Tensor& result, int64_t n, Generator* generator) {
  if (n <= 0) {
    std::ostringstream oss;
    oss << "n must be strictly positive, got " << n;
    throw std::runtime_error(oss.str());
  }
  if (result.type().backend() != at::kCPU) {
    throw std::runtime_error("randperm is only implemented for CPU");
  }

  result.resize_({n});
  auto gen = get_generator(generator);
  AT_DISPATCH_ALL_TYPES(result.type(), "randperm", [&]() -> void {
    randperm_cpu<scalar_t>(result, n, gen);
  });

  return result;
}
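A small sketch of how the new randperm surface behaves (illustrative; per the check above, only CPU types are supported):

Tensor p = at::randperm(CPU(kLong), 10);  // permutation of 0..9
Tensor q = CPU(kLong).tensor();
at::randperm_out(q, 5);                   // fills q with a permutation of 0..4
// at::randperm(CUDA(kLong), 10) would throw: "randperm is only implemented for CPU"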
|
||||
Tensor range(const Type& dtype, Scalar start, Scalar end, Scalar step) {
|
||||
return dtype._range(start, end, step);
|
||||
}
|
||||
|
||||
Tensor& range_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
|
||||
return at::_range_out(result, start, end, step);
|
||||
}
|
||||
|
||||
Tensor zeros(const Type& dtype, IntList size) {
  auto result = dtype.tensor(size);
  return at::native::zeros_out(result, size);
}

Tensor& zeros_out(Tensor& result, IntList size) {
  if (result.is_sparse()) {
    result.sparse_raw_resize_(size, size.size(), 0);
  } else {
    result.resize_(size);
  }
  return result.zero_();
}
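For reference, a minimal sketch of the dense and preallocated-output paths of the new zeros functions (illustrative):

Tensor z = at::zeros(CPU(kFloat), {3, 4});  // fresh dense tensor, zero-filled
Tensor out = CPU(kFloat).tensor();
at::zeros_out(out, {2, 2});                 // resizes out, then zeroes it
// For a sparse result tensor, zeros_out uses sparse_raw_resize_ instead of resize_.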
|
||||
Tensor zeros_like(const Tensor& self) {
|
||||
@ -47,7 +219,7 @@ Tensor zeros_like(const Tensor& self, const Type& dtype) {
|
||||
res.zero_();
|
||||
return res;
|
||||
}
|
||||
return dtype.zeros(self.sizes());
|
||||
return at::native::zeros(dtype, self.sizes());
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ Tensor embedding_backward_cuda(const Tensor & grad_, const Tensor & indices,
|
||||
|
||||
auto num_indices = indices.numel();
|
||||
auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
|
||||
auto grad_weight = grad_.type().zeros({num_weights, grad_.size(-1)});
|
||||
auto grad_weight = at::zeros(grad_.type(), {num_weights, grad_.size(-1)});
|
||||
|
||||
int64_t stride = grad_weight.stride(0);
|
||||
cudaStream_t stream = globalContext().getCurrentCUDAStream();
|
||||
|
@ -162,13 +162,13 @@ embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
|
||||
int64_t numBags = offsets.sizes()[0];
|
||||
int64_t stride = weight.sizes()[1];
|
||||
|
||||
auto bag_size = indices.type().zeros(offsets.sizes());
|
||||
auto bag_size = at::zeros(indices.type(), offsets.sizes());
|
||||
auto offset2bag =
|
||||
indices.type().zeros({indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
|
||||
at::zeros(indices.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
|
||||
|
||||
cudaStream_t stream = globalContext().getCurrentCUDAStream();
|
||||
|
||||
auto output = weight.type().zeros({offsets.sizes()[0], weight.sizes()[1]});
|
||||
auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
|
||||
|
||||
dim3 block = dim3(32, 8);
|
||||
int grid = 1024;
|
||||
@ -204,7 +204,7 @@ Tensor embedding_bag_backward_cuda(const Tensor &grad_, const Tensor &indices,
|
||||
|
||||
Tensor &bag_size = const_cast<Tensor &>(bag_size_);
|
||||
|
||||
auto grad_weight = grad_.type().zeros({num_weights, grad.sizes()[1]});
|
||||
auto grad_weight = at::zeros(grad_.type(), {num_weights, grad.sizes()[1]});
|
||||
|
||||
int nDim = indices.ndimension();
|
||||
|
||||
|
aten/src/ATen/native/cuda/TensorFactories.cu (new file, 29 lines)
@@ -0,0 +1,29 @@
#include "ATen/NativeFunctions.h"
#include <algorithm>
#include <sstream>

namespace at {
namespace native {

Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
  if (n <= 0) {
    std::ostringstream oss;
    oss << "n must be greater than 0, got: " << n;
    throw std::runtime_error(oss.str());
  }
  if(m <= 0) {
    m = n;
  }

  result.resize_({n, m});
  result.zero_();

  int64_t sz = std::min<int64_t>(n, m);
  int64_t stride = result.stride(0) + result.stride(1);

  Tensor diag = result.as_strided({sz}, {stride});
  diag.fill_(1);
  return result;
}

}} // namespace at::native
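The CUDA path fills the diagonal through a strided view rather than an elementwise loop; the same trick works from user code as well, e.g. (illustrative sketch):

Tensor m = at::zeros(CPU(kFloat), {4, 4});
int64_t stride = m.stride(0) + m.stride(1);
m.as_strided({4}, {stride}).fill_(1);  // writes ones along the main diagonal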
@@ -23,6 +23,18 @@
 - func: addr_out(Tensor result, Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: function
 
+- func: arange(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor
+  variants: function
+
+- func: arange_out(Tensor result, Scalar start, Scalar end, Scalar step=1) -> Tensor
+  variants: function
+
+- func: arange(Type dtype, Scalar end) -> Tensor
+  variants: function
+
+- func: arange_out(Tensor result, Scalar end) -> Tensor
+  variants: function
+
 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
   variants: function
 
@ -207,6 +219,15 @@
|
||||
- func: expand_as(Tensor self, Tensor other) -> Tensor
|
||||
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
|
||||
|
||||
- func: eye(Type dtype, int64_t n, int64_t m=-1) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: eye_out(Tensor result, int64_t n, int64_t m=-1) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: eye_out_cpu
|
||||
CUDA: eye_out_cuda
|
||||
|
||||
- func: hinge_embedding_loss(Tensor self, Tensor target, double margin, bool size_average, bool reduce) -> Tensor
|
||||
variants: function
|
||||
|
||||
@ -234,6 +255,18 @@
|
||||
|
||||
- func: is_sparse(Tensor self) -> bool
|
||||
|
||||
- func: linspace(Type dtype, Scalar start, Scalar end, int64_t steps=100) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: linspace_out(Tensor result, Scalar start, Scalar end, int64_t steps=100) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: logspace(Type dtype, Scalar start, Scalar end, int64_t steps=100) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: logspace_out(Tensor result, Scalar start, Scalar end, int64_t steps=100) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: matmul(Tensor self, Tensor other) -> Tensor
|
||||
|
||||
- func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor)
|
||||
@ -265,6 +298,12 @@
|
||||
- func: nnpack_spatial_convolution_backward_weight(Tensor input, IntList weight_size, Tensor grad_output, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: ones(Type dtype, IntList size) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: ones_out(Tensor result, IntList size) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: ones_like(Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
||||
@ -276,18 +315,42 @@
|
||||
|
||||
- func: pin_memory(Tensor self) -> Tensor
|
||||
|
||||
- func: rand(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: rand_out(Tensor result, IntList size, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: rand_like(Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: rand_like(Tensor self, *, Type dtype) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: randn(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: randn_out(Tensor result, IntList size, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: randn_like(Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: randn_like(Tensor self, *, Type dtype) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: randperm(Type dtype, int64_t n, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: range_out(Tensor result, Scalar start, Scalar end, Scalar step=1) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: repeat(Tensor self, IntList repeats) -> Tensor
|
||||
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
|
||||
|
||||
@ -393,6 +456,12 @@
|
||||
CPU: _s_where_cpu
|
||||
CUDA: _s_where_cuda
|
||||
|
||||
- func: zeros(Type dtype, IntList size) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: zeros_out(Tensor result, IntList size) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: zeros_like(Tensor self) -> Tensor
|
||||
variants: function
|
||||
|
||||
|
@ -12,7 +12,7 @@ void check(bool c) {
|
||||
}
|
||||
|
||||
void trace() {
|
||||
Tensor foo = CPU(kFloat).rand({12,12});
|
||||
Tensor foo = rand(CPU(kFloat), {12,12});
|
||||
|
||||
// ASSERT foo is 2-dimensional and holds floats.
|
||||
auto foo_a = foo.accessor<float,2>();
|
||||
@ -26,7 +26,7 @@ void trace() {
|
||||
int main() {
|
||||
manual_seed(123);
|
||||
|
||||
auto foo = CPU(kFloat).rand({12,6});
|
||||
auto foo = rand(CPU(kFloat), {12,6});
|
||||
ASSERT(foo.data<float>() == foo.toFloatData());
|
||||
|
||||
cout << foo << "\n" << foo.size(0) << " " << foo.size(1) << endl;
|
||||
|
@ -32,7 +32,7 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "ones and dot:" << std::endl;
|
||||
Tensor b = type.ones({3, 4});
|
||||
Tensor b = ones(type, {3, 4});
|
||||
std::cout << b << std::endl;
|
||||
ASSERT(24 == (b+b).sum().toCDouble());
|
||||
std::cout << b.numel() << std::endl;
|
||||
@ -44,14 +44,14 @@ static void test(Type & type) {
|
||||
{
|
||||
std::cout << "rand:" << std::endl;
|
||||
for(auto i = 0; i < 10; i++) {
|
||||
Tensor a = type.toScalarType(i % 2 == 0 ? kFloat : kDouble).rand({3,4});
|
||||
Tensor a = rand(type.toScalarType(i % 2 == 0 ? kFloat : kDouble), {3,4});
|
||||
std::cout << a << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "sort:" << std::endl;
|
||||
Tensor b = type.rand({3, 4});
|
||||
Tensor b = rand(type, {3, 4});
|
||||
|
||||
std::cout << b << std::endl;
|
||||
auto z = b.sort(1);
|
||||
@ -61,7 +61,7 @@ static void test(Type & type) {
|
||||
if(type.backend() != kCUDA)
|
||||
{
|
||||
std::cout << "randperm:" << std::endl;
|
||||
Tensor b = type.randperm(15);
|
||||
Tensor b = randperm(type, 15);
|
||||
std::cout << b << std::endl;
|
||||
Tensor rv, ri;
|
||||
std::tie(rv, ri) = sort(b, 0);
|
||||
@ -75,8 +75,8 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "add:" << std::endl;
|
||||
Tensor a = type.rand({3, 4});
|
||||
Tensor b = type.rand({3, 4});
|
||||
Tensor a = rand(type, {3, 4});
|
||||
Tensor b = rand(type, {3, 4});
|
||||
std::cout << a << std::endl;
|
||||
std::cout << b << std::endl;
|
||||
Tensor c = add(a, add(a, b));
|
||||
@ -91,8 +91,8 @@ static void test(Type & type) {
|
||||
{
|
||||
std::cout << "loads of adds:" << std::endl;
|
||||
auto begin = std::chrono::high_resolution_clock::now();
|
||||
Tensor d = type.ones({3, 4});
|
||||
Tensor r = type.zeros({3,4});
|
||||
Tensor d = ones(type, {3, 4});
|
||||
Tensor r = zeros(type, {3, 4});
|
||||
for(auto i = 0; i < 100000; i++) {
|
||||
add_out(r, r, d);
|
||||
}
|
||||
@ -105,8 +105,8 @@ static void test(Type & type) {
|
||||
{
|
||||
std::cout << "loads of adds (with copy):" << std::endl;
|
||||
auto begin = std::chrono::high_resolution_clock::now();
|
||||
Tensor d = type.ones({3, 4});
|
||||
Tensor r = type.zeros({3, 4});
|
||||
Tensor d = ones(type, {3, 4});
|
||||
Tensor r = zeros(type, {3, 4});
|
||||
for(auto i = 0; i < 100000; i++) {
|
||||
r = add(r, d);
|
||||
}
|
||||
@ -118,7 +118,7 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "isContiguous:" << std::endl;
|
||||
Tensor a = type.rand({3, 4});
|
||||
Tensor a = rand(type, {3, 4});
|
||||
std::cout << a.is_contiguous() << std::endl;
|
||||
ASSERT(a.is_contiguous());
|
||||
a = a.transpose(0, 1);
|
||||
@ -126,7 +126,7 @@ static void test(Type & type) {
|
||||
}
|
||||
|
||||
{
|
||||
Tensor a = type.rand({3, 4, 5});
|
||||
Tensor a = rand(type, {3, 4, 5});
|
||||
Tensor b = a.permute({1, 2, 0});
|
||||
ASSERT(b.sizes().equals({4, 5, 3}));
|
||||
ASSERT(b.strides().equals({5, 1, 20}));
|
||||
@ -134,23 +134,23 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "mm:" << std::endl;
|
||||
Tensor a = type.rand({3, 4});
|
||||
Tensor b = type.rand({4});
|
||||
Tensor a = rand(type, {3, 4});
|
||||
Tensor b = rand(type, {4});
|
||||
Tensor c = mv(a, b);
|
||||
std::cout << a << std::endl;
|
||||
std::cout << b << std::endl;
|
||||
std::cout << c << std::endl;
|
||||
ASSERT(c.equal(addmv(type.zeros({3}), a, b, 0, 1)));
|
||||
ASSERT(c.equal(addmv(zeros(type, {3}), a, b, 0, 1)));
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "squeeze:" << std::endl;
|
||||
Tensor a = type.rand({2, 1});
|
||||
Tensor a = rand(type, {2, 1});
|
||||
std::cout << a << std::endl;
|
||||
Tensor b = squeeze(a);
|
||||
ASSERT(b.dim() == 1);
|
||||
std::cout << b << std::endl;
|
||||
a = type.rand({1});
|
||||
a = rand(type, {1});
|
||||
std::cout << a << std::endl;
|
||||
b = squeeze(a);
|
||||
//TODO 0-dim squeeze
|
||||
@ -159,9 +159,9 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "copy:" << std::endl;
|
||||
Tensor a = type.zeros({4, 3});
|
||||
Tensor a = zeros(type, {4, 3});
|
||||
std::cout << a << std::endl;
|
||||
Tensor e = type.rand({4, 3});
|
||||
Tensor e = rand(type, {4, 3});
|
||||
std::cout << e << std::endl;
|
||||
a.copy_(e);
|
||||
std::cout << a << std::endl;
|
||||
@ -170,8 +170,8 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "copy [broadcasting]:" << std::endl;
|
||||
Tensor a = type.zeros({4, 3});
|
||||
Tensor e = type.rand({3});
|
||||
Tensor a = zeros(type, {4, 3});
|
||||
Tensor e = rand(type, {3});
|
||||
a.copy_(e);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
ASSERT(a[i].equal(e));
|
||||
@ -198,15 +198,15 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "adding a value with a salar:" << std::endl;
|
||||
Tensor a = type.rand({4, 3});
|
||||
Tensor a = rand(type, {4, 3});
|
||||
std::cout << a << std::endl;
|
||||
std::cout << add(a, 1) << std::endl;
|
||||
ASSERT((type.ones({4,3}) + a).equal(add(a,1)));
|
||||
ASSERT((ones(type, {4,3}) + a).equal(add(a,1)));
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "select:" << std::endl;
|
||||
Tensor a = type.rand({3, 7});
|
||||
Tensor a = rand(type, {3, 7});
|
||||
std::cout << a << std::endl;
|
||||
std::cout << select(a, 1, 3) << std::endl;
|
||||
std::cout << select(select(a, 1, 3), 0, 2) << std::endl;
|
||||
@ -214,20 +214,20 @@ static void test(Type & type) {
|
||||
|
||||
{
|
||||
std::cout << "zero-dim: " << std::endl;
|
||||
Tensor a = type.scalarTensor(4); //type.rand({1});
|
||||
Tensor a = type.scalarTensor(4); //rand(type, {1});
|
||||
|
||||
std::cout << a << "dims: " << a.dim() << std::endl;
|
||||
std::cout << Scalar(a) << std::endl;
|
||||
Tensor b = type.rand({3,4});
|
||||
Tensor b = rand(type, {3,4});
|
||||
std::cout << b + a << std::endl;
|
||||
std::cout << a + b << std::endl;
|
||||
ASSERT((a+a).dim() == 0);
|
||||
ASSERT((1+a).dim() == 0);
|
||||
auto c = type.rand({3,4});
|
||||
auto c = rand(type, {3,4});
|
||||
std::cout << c[1][2] << std::endl;
|
||||
|
||||
auto f = type.rand({3,4});
|
||||
f[2] = type.zeros({4});
|
||||
auto f = rand(type, {3,4});
|
||||
f[2] = zeros(type, {4});
|
||||
f[1][0] = -1;
|
||||
std:: cout << f << std::endl;
|
||||
ASSERT(Scalar(f[2][0]).toDouble() == 0);
|
||||
@ -240,18 +240,18 @@ static void test(Type & type) {
|
||||
std::cout << tt << std::endl;
|
||||
}
|
||||
{
|
||||
Tensor a = CPU(kFloat).zeros({3,4});
|
||||
Tensor b = CPU(kFloat).ones({3,7});
|
||||
Tensor a = zeros(CPU(kFloat), {3,4});
|
||||
Tensor b = ones(CPU(kFloat), {3,7});
|
||||
Tensor c = cat({a,b},1);
|
||||
std::cout << c.sizes() << std::endl;
|
||||
ASSERT(c.size(1) == 11);
|
||||
std::cout << c << std::endl;
|
||||
|
||||
Tensor e = CPU(kFloat).rand({});
|
||||
Tensor e = rand(CPU(kFloat), {});
|
||||
ASSERT(*e.data<float>()== e.sum().toCFloat());
|
||||
}
|
||||
{
|
||||
Tensor b = CPU(kFloat).ones({3,7})*.0000001f;
|
||||
Tensor b = ones(CPU(kFloat), {3,7})*.0000001f;
|
||||
std::stringstream s;
|
||||
s << b << "\n";
|
||||
std::string expect = "1e-07 *";
|
||||
|
@ -12,35 +12,35 @@ int main() {
|
||||
// 0) pre-req tests:
|
||||
// can't expand empty tensor
|
||||
{
|
||||
auto empty = T.randn({0});
|
||||
auto empty = randn(T, {0});
|
||||
ASSERT_THROWS(empty.expand({3}));
|
||||
}
|
||||
|
||||
// 1) out-place function with 2 args
|
||||
{
|
||||
// basic
|
||||
auto a = T.randn({3, 1});
|
||||
auto b = T.randn({5});
|
||||
auto a = randn(T, {3, 1});
|
||||
auto b = randn(T, {5});
|
||||
std::vector<int64_t> expanded_sizes = {3, 5};
|
||||
ASSERT((a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
|
||||
|
||||
// with scalar
|
||||
auto aScalar = T.ones({1});
|
||||
auto aScalar = ones(T, {1});
|
||||
aScalar.get()->maybeScalar(true);
|
||||
b = T.randn({3, 5});
|
||||
b = randn(T, {3, 5});
|
||||
ASSERT((aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
|
||||
|
||||
// old fallback behavior yields error
|
||||
{
|
||||
auto a = T.randn({3, 5});
|
||||
auto b = T.randn({5, 3});
|
||||
auto a = randn(T, {3, 5});
|
||||
auto b = randn(T, {5, 3});
|
||||
ASSERT_THROWS(a + b);
|
||||
}
|
||||
|
||||
// with mismatched sizes
|
||||
{
|
||||
auto a = T.randn({3, 5});
|
||||
auto b = T.randn({7, 5});
|
||||
auto a = randn(T, {3, 5});
|
||||
auto b = randn(T, {7, 5});
|
||||
ASSERT_THROWS(a + b);
|
||||
}
|
||||
}
|
||||
@ -48,31 +48,31 @@ int main() {
|
||||
// 2) out-place function with 3 args
|
||||
{
|
||||
// basic
|
||||
auto a = T.randn({3, 1, 1});
|
||||
auto b = T.randn({1, 2, 1});
|
||||
auto c = T.randn({1, 1, 5});
|
||||
auto a = randn(T, {3, 1, 1});
|
||||
auto b = randn(T, {1, 2, 1});
|
||||
auto c = randn(T, {1, 1, 5});
|
||||
std::vector<int64_t> expanded_sizes = {3, 2, 5};
|
||||
ASSERT((a + b + c).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes) + c.expand(expanded_sizes)));
|
||||
|
||||
// with scalar
|
||||
auto aTensorScalar = T.ones({1});
|
||||
auto aTensorScalar = ones(T, {1});
|
||||
aTensorScalar.get()->maybeScalar(true);
|
||||
b = T.randn({3, 2, 1});
|
||||
c = T.randn({1, 2, 5});
|
||||
b = randn(T, {3, 2, 1});
|
||||
c = randn(T, {1, 2, 5});
|
||||
ASSERT(aTensorScalar.addcmul(b, c).equal(
|
||||
aTensorScalar.expand(expanded_sizes).addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
|
||||
|
||||
// old fallback behavior yields error
|
||||
{
|
||||
auto a = T.randn({3, 2, 5});
|
||||
auto b = T.randn({2, 3, 5});
|
||||
auto c = T.randn({5, 3, 2});
|
||||
auto a = randn(T, {3, 2, 5});
|
||||
auto b = randn(T, {2, 3, 5});
|
||||
auto c = randn(T, {5, 3, 2});
|
||||
ASSERT_THROWS(a.addcmul(b, c));
|
||||
}
|
||||
|
||||
// with mismatched sizes
|
||||
{
|
||||
auto c = T.randn({5, 5, 5});
|
||||
auto c = randn(T, {5, 5, 5});
|
||||
ASSERT_THROWS(a.addcmul(b, c));
|
||||
}
|
||||
}
|
||||
@ -80,19 +80,19 @@ int main() {
|
||||
// 3) in-place function with 2 args
|
||||
{
|
||||
// basic
|
||||
auto a = T.randn({3, 5});
|
||||
auto b = T.randn({3, 1});
|
||||
auto a = randn(T, {3, 5});
|
||||
auto b = randn(T, {3, 1});
|
||||
ASSERT((a + b).equal(a + b.expand({3, 5})));
|
||||
|
||||
// with scalar
|
||||
auto bScalar = T.ones({1});
|
||||
auto bScalar = ones(T, {1});
|
||||
bScalar.get()->maybeScalar(true);
|
||||
ASSERT((a + bScalar).equal(a + bScalar.expand(a.sizes())));
|
||||
|
||||
// error: would have to expand inplace arg
|
||||
{
|
||||
auto a = T.randn({1, 5});
|
||||
auto b = T.randn({3, 1});
|
||||
auto a = randn(T, {1, 5});
|
||||
auto b = randn(T, {3, 1});
|
||||
ASSERT_THROWS(a.add_(b));
|
||||
}
|
||||
}
|
||||
@ -100,23 +100,23 @@ int main() {
|
||||
// 4) in-place function with 3 args
|
||||
{
|
||||
// basic
|
||||
auto a = T.randn({3, 5, 2});
|
||||
auto a = randn(T, {3, 5, 2});
|
||||
auto aClone = a.clone();
|
||||
auto b = T.randn({3, 1, 2});
|
||||
auto c = T.randn({1, 5, 1});
|
||||
auto b = randn(T, {3, 1, 2});
|
||||
auto c = randn(T, {1, 5, 1});
|
||||
|
||||
ASSERT(a.addcmul_(b, c).equal(aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
|
||||
|
||||
// with scalar
|
||||
auto bScalar = T.ones({1});
|
||||
auto bScalar = ones(T, {1});
|
||||
bScalar.get()->maybeScalar(true);
|
||||
ASSERT(a.addcmul_(bScalar, c).equal(aClone.addcmul_(bScalar.expand(a.sizes()), c.expand(a.sizes()))));
|
||||
|
||||
// error: would have to expand inplace arg
|
||||
{
|
||||
auto a = T.randn({1, 3, 5});
|
||||
auto b = T.randn({4, 1, 1});
|
||||
auto c = T.randn({1, 3, 1});
|
||||
auto a = randn(T, {1, 3, 5});
|
||||
auto b = randn(T, {4, 1, 1});
|
||||
auto c = randn(T, {1, 3, 1});
|
||||
ASSERT_THROWS(a.addcmul_(b, c));
|
||||
}
|
||||
}
|
||||
@ -124,19 +124,19 @@ int main() {
|
||||
// explicit dim specification
|
||||
{
|
||||
// basic
|
||||
auto a = T.randn({1});
|
||||
auto b = T.randn({5, 3});
|
||||
auto c = T.randn({3, 7});
|
||||
auto a = randn(T, {1});
|
||||
auto b = randn(T, {5, 3});
|
||||
auto c = randn(T, {3, 7});
|
||||
ASSERT(a.addmm(b, c).equal(a.expand({5,7}).addmm(b, c)));
|
||||
|
||||
// with scalar
|
||||
Tensor aScalar = T.ones({1});
|
||||
Tensor aScalar = ones(T, {1});
|
||||
aScalar.get()->maybeScalar(true);
|
||||
ASSERT(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
|
||||
|
||||
// with mismatched sizes
|
||||
{
|
||||
auto a = T.randn({3, 3});
|
||||
auto a = randn(T, {3, 3});
|
||||
ASSERT_THROWS(a.addmm(b, c));
|
||||
}
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ using namespace at;
|
||||
static void test() {
|
||||
{
|
||||
std::cout << "dlconvertor: convert ATen to DLTensor" << std::endl;
|
||||
Tensor a = CPU(at::kFloat).rand({3,4});
|
||||
Tensor a = rand(CPU(at::kFloat), {3,4});
|
||||
std::cout << a.numel() << std::endl;
|
||||
DLManagedTensor* dlMTensor = toDLPack(a);
|
||||
std::cout << "dlconvertor: convert DLTensor to ATen" << std::endl;
|
||||
|
@ -12,7 +12,7 @@ void assertEqualTensorList(TensorList t1, TensorList t2) {
|
||||
}
|
||||
|
||||
void test(Type & T, Type & AccT) {
|
||||
auto t = T.randn({3, 3});
|
||||
auto t = randn(T, {3, 3});
|
||||
// split
|
||||
{
|
||||
// test method, type, namespace give same result
|
||||
@ -40,9 +40,9 @@ void test(Type & T, Type & AccT) {
|
||||
|
||||
// stack
|
||||
{
|
||||
auto x = T.rand({2, 3, 4});
|
||||
auto y = T.rand({2, 3, 4});
|
||||
auto z = T.rand({2, 3, 4});
|
||||
auto x = rand(T, {2, 3, 4});
|
||||
auto y = rand(T, {2, 3, 4});
|
||||
auto z = rand(T, {2, 3, 4});
|
||||
for (int64_t dim = 0; dim < 4; ++dim) {
|
||||
auto res = at::stack({x, y, z}, dim);
|
||||
auto res_neg = at::stack({x, y, z}, dim - 4);
|
||||
@ -61,13 +61,13 @@ void test(Type & T, Type & AccT) {
|
||||
|
||||
// size / stride
|
||||
{
|
||||
auto scalar = T.randn({});
|
||||
auto scalar = randn(T, {});
|
||||
ASSERT_THROWSM(scalar.size(0), "dimension specified as 0 but tensor has no dimensions");
|
||||
ASSERT_THROWSM(scalar.size(-1), "dimension specified as -1 but tensor has no dimensions");
|
||||
ASSERT_THROWSM(scalar.stride(0), "dimension specified as 0 but tensor has no dimensions");
|
||||
ASSERT_THROWSM(scalar.stride(-1), "dimension specified as -1 but tensor has no dimensions");
|
||||
|
||||
auto empty = T.randn({0});
|
||||
auto empty = randn(T, {0});
|
||||
ASSERT(empty.size(0) == 0);
|
||||
ASSERT(empty.size(-1) == 0);
|
||||
ASSERT(empty.stride(0) == 1);
|
||||
@ -76,9 +76,9 @@ void test(Type & T, Type & AccT) {
|
||||
|
||||
// matmul
|
||||
{
|
||||
auto scalar = T.randn({});
|
||||
auto d1 = T.randn({3});
|
||||
auto d2 = T.randn({2, 3});
|
||||
auto scalar = randn(T, {});
|
||||
auto d1 = randn(T, {3});
|
||||
auto d2 = randn(T, {2, 3});
|
||||
|
||||
// 0-d
|
||||
ASSERT_THROWSM(scalar.matmul(d2), "both arguments to matmul need to be at least 1D");
|
||||
@ -87,19 +87,19 @@ void test(Type & T, Type & AccT) {
|
||||
// 1-d
|
||||
ASSERT_ALLCLOSE(d1.matmul(d1), d1.dot(d1));
|
||||
ASSERT_ALLCLOSE(d2.matmul(d1), d2.mv(d1));
|
||||
auto d1o = T.randn({2});
|
||||
auto d1o = randn(T, {2});
|
||||
ASSERT_ALLCLOSE(d1o.matmul(d2), d1o.unsqueeze(0).mm(d2).squeeze(0));
|
||||
|
||||
// 2-d
|
||||
auto d2o = T.randn({3, 5});
|
||||
auto d2o = randn(T, {3, 5});
|
||||
ASSERT_ALLCLOSE(d2.matmul(d2o), d2.mm(d2o));
|
||||
|
||||
// > 2-d, 1-d
|
||||
auto d3 = T.randn({5, 2, 3});
|
||||
auto d3 = randn(T, {5, 2, 3});
|
||||
ASSERT_ALLCLOSE(d3.matmul(d1), d3.bmm(d1.view({1, 3, 1}).expand({5, 3, 1})).view({5, 2}));
|
||||
ASSERT_ALLCLOSE(d1o.matmul(d3), d1o.expand({5, 1, 2}).bmm(d3).view({5, 3}));
|
||||
|
||||
auto d5 = T.randn({3, 2, 4, 2, 3});
|
||||
auto d5 = randn(T, {3, 2, 4, 2, 3});
|
||||
ASSERT_ALLCLOSE(d5.matmul(d1), d5.view({24, 2, 3}).bmm(d1.view({1, 3, 1}).expand({24, 3, 1})).view({3, 2, 4, 2}));
|
||||
ASSERT_ALLCLOSE(d1o.matmul(d5), d1o.expand({24, 1, 2}).bmm(d5.view({24, 2, 3})).view({3, 2, 4, 3}));
|
||||
|
||||
@ -109,8 +109,8 @@ void test(Type & T, Type & AccT) {
|
||||
// Tolerances are selected empirically.
|
||||
double atol = 1e-04;
|
||||
double rtol = 1e-06;
|
||||
d2 = T.randn({3, 4});
|
||||
d2o = T.randn({4, 2});
|
||||
d2 = randn(T, {3, 4});
|
||||
d2o = randn(T, {4, 2});
|
||||
auto result = d5.matmul(d2).toType(AccT);
|
||||
|
||||
auto d5Acc = d5.toType(AccT);
|
||||
@ -120,37 +120,37 @@ void test(Type & T, Type & AccT) {
|
||||
ASSERT_ALLCLOSE(d2o.matmul(d5), d2o.expand({24, 4, 2}).bmm(d5.view({24, 2, 3})).view({3, 2, 4, 4, 3}));
|
||||
|
||||
// > 2-d, > 2-d
|
||||
auto d5o = T.randn({2, 1, 2, 4, 3, 2});
|
||||
auto d5o = randn(T, {2, 1, 2, 4, 3, 2});
|
||||
auto d5_bmm_view = d5.expand({2, 3, 2, 4, 2, 3}).contiguous().view({48, 2, 3});
|
||||
auto d5o_bmm_view = d5o.expand({2, 3, 2, 4, 3, 2}).contiguous().view({48, 3, 2});
|
||||
ASSERT_ALLCLOSE(d5.matmul(d5o), d5_bmm_view.bmm(d5o_bmm_view).view({2, 3, 2, 4, 2, 2}));
|
||||
|
||||
// non-expandable case
|
||||
auto d5wrong = T.randn({2, 4, 2, 4, 3, 2});
|
||||
auto d5wrong = randn(T, {2, 4, 2, 4, 3, 2});
|
||||
ASSERT_THROWSM(d5.matmul(d5wrong), "must match the size");
|
||||
}
|
||||
|
||||
// _standard_gamma_grad
|
||||
if (!T.is_cuda()) {
|
||||
// check empty
|
||||
auto empty = T.ones({0});
|
||||
auto empty = ones(T, {0});
|
||||
ASSERT_EQUAL(empty, empty._standard_gamma_grad(empty));
|
||||
|
||||
// check scalar equals one element
|
||||
auto one_scalar = T.ones({}).mul(5);
|
||||
auto one_with_dim = T.ones({1}).mul(5);
|
||||
auto one_scalar = ones(T, {}).mul(5);
|
||||
auto one_with_dim = ones(T, {1}).mul(5);
|
||||
ASSERT_ALLCLOSE(one_scalar._standard_gamma_grad(one_scalar),
|
||||
one_with_dim._standard_gamma_grad(one_with_dim).sum());
|
||||
|
||||
// check mixing types
|
||||
Type & DT = CPU(kDouble);
|
||||
auto t1 = T.randn({3, 4});
|
||||
auto t2 = DT.randn({3, 4});
|
||||
auto t1 = randn(T, {3, 4});
|
||||
auto t2 = randn(DT, {3, 4});
|
||||
ASSERT_THROWSM(t1._standard_gamma_grad(t2), "expected scalar type");
|
||||
} else {
|
||||
auto ct1 = T.randn({3, 4});
|
||||
auto ct2 = T.randn({3, 4});
|
||||
auto t1 = T.toBackend(Backend::CPU).randn({3, 4});
|
||||
auto ct1 = randn(T, {3, 4});
|
||||
auto ct2 = randn(T, {3, 4});
|
||||
auto t1 = randn(T.toBackend(Backend::CPU), {3, 4});
|
||||
ASSERT_THROWSM(ct1._standard_gamma_grad(ct2), "not implemented");
|
||||
ASSERT_THROWSM(ct1._standard_gamma_grad(t1), "not implemented");
|
||||
ASSERT_THROWSM(t1._standard_gamma_grad(ct2), "CUDA Backend");
|
||||
@ -159,15 +159,15 @@ void test(Type & T, Type & AccT) {
|
||||
// where
|
||||
{
|
||||
// empty
|
||||
auto empty = T.ones({0});
|
||||
auto empty = ones(T, {0});
|
||||
auto &bT = T.toScalarType(ScalarType::Byte);
|
||||
auto empty_byte = bT.ones({0});
|
||||
auto empty_byte = ones(bT, {0});
|
||||
ASSERT_EQUAL(empty, at::where(empty_byte, empty, empty));
|
||||
|
||||
// check scalar equals one element
|
||||
auto x_scalar = T.ones({}).mul(5);
|
||||
auto y_scalar = T.ones({}).mul(7);
|
||||
auto cond_scalar = bT.zeros({});
|
||||
auto x_scalar = ones(T, {}).mul(5);
|
||||
auto y_scalar = ones(T, {}).mul(7);
|
||||
auto cond_scalar = zeros(bT, {});
|
||||
auto x_1d = x_scalar.unsqueeze(0);
|
||||
auto y_1d = y_scalar.unsqueeze(0);
|
||||
auto cond_1d = cond_scalar.unsqueeze(0);
|
||||
|
@ -32,7 +32,7 @@ void test(Type &T) {
// single-tensor/size tests
for (auto s = sizes.begin(); s != sizes.end(); ++s) {
// verify that the dim, sizes, strides, etc match what was requested.
auto t = T.ones(*s);
auto t = ones(T, *s);
ASSERT((std::size_t)t.dim() == s->size());
ASSERT((std::size_t)t.ndimension() == s->size());
ASSERT(t.sizes().equals(*s));
@ -43,9 +43,9 @@ void test(Type &T) {
std::cout << t << std::endl;

// set_
auto t2 = T.ones(*s);
auto t2 = ones(T, *s);
t2.set_();
assert_equal_size_dim(t2, T.ones({0}));
assert_equal_size_dim(t2, ones(T, {0}));

// unsqueeze
if (t.numel() != 0) {
@ -56,7 +56,7 @@ void test(Type &T) {

// unsqueeze_
{
auto t2 = T.ones(*s);
auto t2 = ones(T, *s);
if (t2.numel() != 0) {
auto r = t2.unsqueeze_(0);
ASSERT(r.dim() == t.dim() + 1);
@ -83,12 +83,12 @@ void test(Type &T) {
}
}
auto result = t.squeeze();
assert_equal_size_dim(result, T.ones(size_without_ones));
assert_equal_size_dim(result, ones(T, size_without_ones));
}

{
// squeeze_ (with dimension argument)
auto t2 = T.ones(*s);
auto t2 = ones(T, *s);
if (t2.dim() == 0 || t2.sizes()[0] == 1) {
ASSERT(t2.squeeze_(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
} else {
@ -100,7 +100,7 @@ void test(Type &T) {

// squeeze_ (with no dimension argument)
{
auto t2 = T.ones(*s);
auto t2 = ones(T, *s);
std::vector<int64_t> size_without_ones;
for (auto size : *s) {
if (size != 1) {
@ -108,7 +108,7 @@ void test(Type &T) {
}
}
auto r = t2.squeeze_();
assert_equal_size_dim(t2, T.ones(size_without_ones));
assert_equal_size_dim(t2, ones(T, size_without_ones));
}

// reduce (with dimension argument and with 1 return argument)
@ -146,8 +146,8 @@ void test(Type &T) {
for (auto rhs_it = sizes.begin(); rhs_it != sizes.end(); ++rhs_it) {
// is_same_size should only match if they are the same shape
{
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
if(*lhs_it != *rhs_it) {
ASSERT(!lhs.is_same_size(rhs));
ASSERT(!rhs.is_same_size(lhs));
@ -157,15 +157,15 @@ void test(Type &T) {
{
// resize_
{
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
lhs.resize_(*rhs_it);
assert_equal_size_dim(lhs, rhs);
}
// resize_as_
{
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
lhs.resize_as_(rhs);
assert_equal_size_dim(lhs, rhs);
}
@ -173,15 +173,15 @@ void test(Type &T) {
{
{
// with tensor
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
lhs.set_(rhs);
assert_equal_size_dim(lhs, rhs);
}
{
// with storage
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
auto storage = T.storage(rhs.numel());
lhs.set_(*storage);
// should not be dim 0 because an empty storage is dim 1; all other storages aren't scalars
@ -189,8 +189,8 @@ void test(Type &T) {
}
{
// with storage, offset, sizes, strides
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
auto storage = T.storage(rhs.numel());
lhs.set_(*storage, rhs.storage_offset(), rhs.sizes(), rhs.strides());
assert_equal_size_dim(lhs, rhs);
@ -200,8 +200,8 @@ void test(Type &T) {

// view
{
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
auto rhs_size = *rhs_it;
TRY_CATCH_ELSE(auto result = lhs.view(rhs_size),
ASSERT(lhs.numel() != rhs.numel()),
@ -210,8 +210,8 @@ void test(Type &T) {

// take
{
auto lhs = T.ones(*lhs_it);
auto rhs = T.zeros(*rhs_it).toType(ScalarType::Long);
auto lhs = ones(T, *lhs_it);
auto rhs = zeros(T, *rhs_it).toType(ScalarType::Long);
TRY_CATCH_ELSE(auto result = lhs.take(rhs),
ASSERT(lhs.numel() == 0 && rhs.numel() != 0),
assert_equal_size_dim(result, rhs));
@ -220,8 +220,8 @@ void test(Type &T) {

// ger
{
auto lhs = T.ones(*lhs_it);
auto rhs = T.ones(*rhs_it);
auto lhs = ones(T, *lhs_it);
auto rhs = ones(T, *rhs_it);
TRY_CATCH_ELSE(auto result = lhs.ger(rhs),
ASSERT(lhs.numel() == 0 || rhs.numel() == 0 || lhs.dim() != 1 || rhs.dim() != 1),
[&]() {
@ -233,9 +233,9 @@ void test(Type &T) {

// expand
{
auto lhs = T.ones(*lhs_it);
auto lhs = ones(T, *lhs_it);
auto lhs_size = *lhs_it;
auto rhs = T.ones(*rhs_it);
auto rhs = ones(T, *rhs_it);
auto rhs_size = *rhs_it;
bool should_pass = should_expand(lhs_size, rhs_size);
TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size),
@ -249,7 +249,7 @@ void test(Type &T) {
bool should_pass_inplace = should_expand(rhs_size, lhs_size);
TRY_CATCH_ELSE(lhs.add_(rhs),
ASSERT(!should_pass_inplace),
ASSERT(should_pass_inplace); assert_equal_size_dim(lhs, T.ones(*lhs_it)););
ASSERT(should_pass_inplace); assert_equal_size_dim(lhs, ones(T, *lhs_it)););
}
}
}

@ -94,7 +94,7 @@ int main() {
auto && C = at::globalContext();
if(at::hasCUDA()) {
auto & CUDAFloat = C.getType(Backend::CPU,ScalarType::Float);
auto t2 = CUDAFloat.zeros({4,4});
auto t2 = zeros(CUDAFloat, {4,4});
cout << &t2 << "\n";
cout << "AFTER GET TYPE " << &CUDAFloat << "\n";
cout << "STORAGE: " << CUDAFloat.storage(4).get() << "\n";
@ -102,18 +102,18 @@ int main() {
s->fill(7);
cout << "GET " << s->get(3).toFloat() << "\n";
}
auto t = CPU(Float).ones({4,4});
auto t = ones(CPU(Float), {4,4});

auto wha2 = CPU(Float).zeros({4,4}).add(t).sum();
auto wha2 = zeros(CPU(Float), {4,4}).add(t).sum();
cout << wha2.toCDouble() << " <-ndim\n";

cout << t.sizes() << " " << t.strides() << "\n";

Type & T = CPU(Float);
Tensor x = T.randn({1,10});
Tensor prev_h = T.randn({1,20});
Tensor W_h = T.randn({20,20});
Tensor W_x = T.randn({20,10});
Tensor x = randn(T, {1,10});
Tensor prev_h = randn(T, {1,20});
Tensor W_h = randn(T, {20,20});
Tensor W_x = randn(T, {20,10});
Tensor i2h = at::mm(W_x, x.t());
Tensor h2h = at::mm(W_h, prev_h.t());
Tensor next_h = i2h.add(h2h);
@ -129,12 +129,12 @@ int main() {

cout << r << "\n";
}
cout << T.randn({10,10,2}) << "\n";
cout << randn(T, {10,10,2}) << "\n";

// check Scalar.toTensor on Scalars backed by different data types
ASSERT(bar.toTensor().type().scalarType() == kDouble);
ASSERT(what.toTensor().type().scalarType() == kLong);
ASSERT(Scalar(CPU(kFloat).ones({})).toTensor().type().scalarType() == kFloat);
ASSERT(Scalar(ones(CPU(kFloat), {})).toTensor().type().scalarType() == kFloat);

if (x.type().scalarType() != ScalarType::Half) {
AT_DISPATCH_ALL_TYPES(x.type(), "foo", [&] {
@ -147,10 +147,10 @@ int main() {

// test direct C-scalar type conversions
{
auto x = T.ones({1,2});
auto x = ones(T, {1,2});
ASSERT_THROWS(x.toCFloat());
}
auto float_one = T.ones({});
auto float_one = ones(T, {});
ASSERT(float_one.toCFloat() == 1);
ASSERT(float_one.toCInt() == 1);
ASSERT(float_one.toCHalf() == 1);

@ -1,5 +1,5 @@
#include "ATen/ATen.h"

int main() {
std::cout << at::CPU(at::kFloat).ones({3,4}) << "\n";
std::cout << at::ones(at::CPU(at::kFloat), {3,4}) << "\n";
}

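For orientation, every hunk above applies the same mechanical rewrite: tensor factories move off the Type object and into the at:: namespace, with the Type passed as the first argument. A minimal sketch of the resulting calling convention (illustrative only; this snippet is not part of the commit):

#include "ATen/ATen.h"
#include <iostream>

int main() {
  // Former spelling replaced by these hunks (factory as a Type method):
  //   at::Tensor t = at::CPU(at::kFloat).ones({3, 4});
  // Namespace factory introduced by this change, Type passed first:
  at::Tensor t = at::ones(at::CPU(at::kFloat), {3, 4});
  std::cout << t << "\n";
  return 0;
}
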
@ -11,7 +11,7 @@ int main() {

// mainly test ops on undefined tensors don't segfault and give a reasonable errror message.
Tensor und;
Tensor ft = CPU(kFloat).ones({1});
Tensor ft = ones(CPU(kFloat), {1});

std::cout << und << std::endl;
ASSERT(!und.defined());
@ -43,7 +43,7 @@ int main() {
ASSERT_THROWSM(und.toBackend(Backend::CPU), "toBackend");
ASSERT_THROWSM(ft.toBackend(Backend::Undefined), "UndefinedType");

Tensor to_move = CPU(kFloat).ones({1});
Tensor to_move = ones(CPU(kFloat), {1});
Tensor m(std::move(to_move));
ASSERT(!to_move.defined());
ASSERT(to_move.get() == UndefinedTensor::singleton());

@ -11,32 +11,32 @@ int main() {

// test simple case
{
auto a = T.randn({2, 3, 4, 5});
auto a = randn(T, {2, 3, 4, 5});
ASSERT(a.prod(-4).equal(a.prod(0)));
ASSERT(a.prod(3).equal(a.prod(-1)));
}

// test case with expression specification
{
auto a = T.randn({2, 3, 4, 5});
auto a = randn(T, {2, 3, 4, 5});
ASSERT(a.unsqueeze(-5).equal(a.unsqueeze(0)));
ASSERT(a.unsqueeze(4).equal(a.unsqueeze(-1)));

// can unsqueeze scalar
auto b = T.randn(1);
auto b = randn(T, 1);
b.get()->maybeScalar(true);
ASSERT(b.unsqueeze(0).equal(b.unsqueeze(-1)));
}

// test case with empty tensor
{
auto a = T.randn(0);
auto a = randn(T, 0);
ASSERT_THROWS(a.prod(0));
}

// test case with scalar vs 1-dim, 1-size
{
auto a = T.randn(1);
auto a = randn(T, 1);
ASSERT(a.prod(0).equal(a.prod(-1)));
a.get()->maybeScalar(true);
ASSERT(a.get()->isScalar());

@ -2,7 +2,7 @@

struct Doubler {
Doubler(int A, int B) {
tensor_ = at::CPU(at::kDouble).ones({A, B});
tensor_ = at::ones(at::CPU(at::kDouble), {A, B});
}
at::Tensor forward() {
return tensor_ * 2;

@ -8,7 +8,7 @@ Tensor sigmoid_add(Tensor x, Tensor y) {

struct MatrixMultiplier {
MatrixMultiplier(int A, int B) {
tensor_ = CPU(kDouble).ones({A, B});
tensor_ = ones(CPU(kDouble), {A, B});
}
Tensor forward(Tensor weights) {
return tensor_.mm(weights);

@ -264,7 +264,7 @@
self: grad

- name: gather(Tensor self, int64_t dim, Tensor index)
self: grad.type().zeros(self.sizes()).scatter_add_(dim, index, grad)
self: at::zeros(grad.type(), self.sizes()).scatter_add_(dim, index, grad)

- name: ge_(Tensor self, Scalar other)
self: zeros_like(self)
@ -317,7 +317,7 @@
value: grad.index_select(dim, index).sum()

- name: index_select(Tensor self, int64_t dim, Tensor index)
self: grad.type().zeros(self.sizes()).index_add_(dim, index, grad)
self: at::zeros(grad.type(), self.sizes()).index_add_(dim, index, grad)

- name: inverse(Tensor self)
self: -at::mm(output.t(), at::mm(grad, output.t()))
@ -445,14 +445,14 @@
self: zeros_like(grad)

- name: normal(Tensor mean, double std, Generator generator)
mean: grad.type().zeros(mean.sizes())
mean: at::zeros(grad.type(), mean.sizes())

- name: normal(double mean, Tensor std, Generator generator)
std: grad.type().zeros(std.sizes())
std: at::zeros(grad.type(), std.sizes())

- name: normal(Tensor mean, Tensor std, Generator generator)
mean: grad.type().zeros(mean.sizes())
std: grad.type().zeros(std.sizes())
mean: at::zeros(grad.type(), mean.sizes())
std: at::zeros(grad.type(), std.sizes())

- name: orgqr(Tensor self, Tensor input2)
self: not_implemented("orgqr")

@ -63,6 +63,11 @@ def load_aten_declarations(path):
declaration['formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['arguments']]
declaration['args'] = [arg['name'] for arg in declaration['arguments']]
declaration['type_method_formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['arguments']
if not arg.get('is_type_dispatched')]
declaration['type_method_args'] = [arg['name'] for arg in declaration['arguments']
if not arg.get('is_type_dispatched')]
declaration['api_name'] = declaration['name']
declaration['return_type'] = format_return_type(declaration['returns'])

@ -17,6 +17,7 @@ SKIP_PYTHON_BINDINGS = [
'alias', 'contiguous', 'clamp.*', 'is_cuda', 'is_sparse', 'size', 'stride',
'.*_backward', '.*_backward_out', '.*_forward', '.*_forward_out',
'sparse_raw_resize_', '_unsafe_view', 'tensor', 'sparse_coo_tensor',
'_arange.*', '_range.*', '_linspace.*', '_logspace.*'
]

PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
@ -66,12 +67,12 @@ if (r.isNone(${out_idx})) {
}
""")

PY_VARIABLE_OUT_CHECK_DTYPE = CodeTemplate("""\
PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
${call_dispatch}
} else {
if (!r.isNone(${dtype_idx})) {
check_out_dtype_matches(r.tensor(${out_idx}), r.type(${dtype_idx}));
if (!r.isNone(${type_idx})) {
check_out_type_matches(r.tensor(${out_idx}), r.type(${type_idx}));
}
${call_dispatch_out}
}
@ -161,23 +162,10 @@ def gen_py_nn_functions(out, declarations):


def gen_py_torch_functions(out, declarations):
def is_namespace_or_type_api_function(declaration):
# These are functions that should end up on the torch module. This
# includes functions on the at:: namespace and ones that are typically
# called via the type (e.g. Type::randn()). Since every function is
# implemented on the Type, we exclude functions that are also declared
# as methods on Tensor, since one shouldn't generally call these from
# the Type object.
if 'namespace' in declaration['method_of']:
return True
if 'Tensor' in declaration['method_of']:
return False
return 'Type' in declaration['method_of']

def should_bind(declaration):
return (should_generate_python_binding(declaration) and
declaration['mode'] != 'NN' and
is_namespace_or_type_api_function(declaration))
'namespace' in declaration['method_of'])

py_torch_functions = group_declarations_by_name(declarations, should_bind)

@ -199,6 +187,13 @@ def group_declarations_by_name(declarations, should_bind_fn):
return groups


def get_type_default(declaration):
if declaration['name'].startswith('randperm'):
return 'torch.int64'
else:
return 'None'


def create_python_bindings(python_functions, has_self, is_module=False):
"""Generates Python bindings to ATen functions"""
py_methods = []
@ -260,6 +255,8 @@ def create_python_bindings(python_functions, has_self, is_module=False):

inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
outputs = [arg for arg in declaration['arguments'] if is_output(arg)]
type_dispatched_args = [arg for arg in declaration['arguments'] if arg.get('is_type_dispatched')]
assert len(type_dispatched_args) <= 1

def parse_arg(arg, arg_index, unpack_args=False):
name = arg['name']
@ -302,6 +299,8 @@ def create_python_bindings(python_functions, has_self, is_module=False):

unpack = any(arg.get('python_default_init') for arg in inputs)
for arg in inputs:
if arg.get('is_type_dispatched'):
continue
if has_self and arg['name'] == 'self':
formal_args.append('Tensor & self')
actuals.append('self_')
@ -318,23 +317,26 @@ def create_python_bindings(python_functions, has_self, is_module=False):
formal_args.append('Tensor & {}'.format(arg['name']))
actuals.append('results[{}]'.format(i))

# this goes after the outputs to match the signature generation.
arg_idx = arg_idx if out_idx is None else out_idx + 1
for arg in type_dispatched_args:
append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
arg_idx += 1

# check python_binding_arguments
has_dtype_bind = False
has_device_bind = False
requires_grad = None
python_binding_arguments = declaration.get('python_binding_arguments', [])
bind_arg_idx = arg_idx if out_idx is None else out_idx + 1
if 'dtype' in (a['name'] for a in python_binding_arguments):
dtype_idx, device_idx, requires_grad_idx = (bind_arg_idx, bind_arg_idx + 1, bind_arg_idx + 2)
dtype_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
else:
device_idx, requires_grad_idx = (bind_arg_idx, bind_arg_idx + 1)
device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

for arg in python_binding_arguments:
if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
# out(s) determines the dtype if it is present, so don't pass the dtype to the dispatch.
if len(outputs) == 0:
# we have to use out_idx if there is an out variant because the base variant
# won't have the full arg_idx count
has_dtype_bind = True
append_actuals_formals(*parse_arg(arg, dtype_idx))
elif len(outputs) > 1:
@ -352,8 +354,10 @@ def create_python_bindings(python_functions, has_self, is_module=False):
env['unpack_args'] = []
env['formal_args'] = formal_args
env['actuals'] = actuals
has_any_dtype = (has_dtype_bind or any('dtype' in a['name'] for a in inputs))
env['initialize_cuda'] = 'maybe_initialize_cuda(dtype);' if has_any_dtype else []
has_any_dtype = has_dtype_bind or any(a['name'] == 'dtype' and a['simple_type'] == 'Type' for a in inputs)
type_dispatched_name = type_dispatched_args[0]['name'] if len(type_dispatched_args) > 0 else None
maybe_init_cuda = 'dtype' if has_any_dtype else type_dispatched_name
env['initialize_cuda'] = 'maybe_initialize_cuda({});'.format(maybe_init_cuda) if maybe_init_cuda else []
if 'call_args' in declaration:
env['dispatch_args'] = declaration['call_args']
else:
@ -362,14 +366,9 @@ def create_python_bindings(python_functions, has_self, is_module=False):
env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
env['dispatch_call'] = 'self.{}'.format(declaration['name'])
elif 'namespace' in declaration['method_of']:
if has_dtype_bind:
raise RuntimeError(("dtype with namespace dispatch currently not supported, "
"consider writing as a native function"))
env['dispatch_call'] = 'at::{}'.format(declaration['name'])
elif has_dtype_bind:
env['dispatch_call'] = 'dtype.{}'.format(declaration['name'])
else:
env['dispatch_call'] = 'default_type().{}'.format(declaration['name'])
raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')
env['AutoNoGIL'] = 'AutoNoGIL no_gil;'
env['AutoGPU'] = auto_gpu(declaration, has_device_bind)

@ -392,7 +391,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):

has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
if has_dtype_bind:
body = PY_VARIABLE_OUT_CHECK_DTYPE.substitute(env, out_idx=out_idx, dtype_idx=out_idx + 1).split('\n')
body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1).split('\n')
else:
body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
else:
@ -404,12 +403,15 @@ def create_python_bindings(python_functions, has_self, is_module=False):
def get_python_binding_arguments(declaration):
python_binding_arguments = []
has_tensor_input_arg = False
has_type_dispatched = False
for arg in declaration['arguments']:
if arg.get('output', False):
continue
typename = arg['simple_type']
if typename in ['Tensor', 'TensorList']:
has_tensor_input_arg = True
if arg.get('is_type_dispatched'):
has_type_dispatched = True
if arg['name'] == 'requires_grad':
raise ValueError("argument named requires_grad not supported")

@ -420,11 +422,8 @@ def create_python_bindings(python_functions, has_self, is_module=False):
# produce a compile-time error that is obvious
has_tensor_return = True

if has_tensor_return and not has_tensor_input_arg:
if declaration['name'].startswith('randperm'):
default_type = 'torch.int64'
else:
default_type = 'None'
if has_tensor_return and not has_tensor_input_arg and not has_type_dispatched:
default_type = get_type_default(declaration)
dtype_arg = {
'default': default_type,
'dynamic_type': 'Type',
@ -432,6 +431,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
'name': 'dtype',
'type': 'const Type &',
'simple_type': 'Type',
'is_type_dispatched': True,
}
python_binding_arguments.append(dtype_arg)
if (not has_tensor_input_arg or name.endswith('_like')) and has_tensor_return:
@ -538,7 +538,8 @@ def group_declarations(declarations):

result = []
for _, dictionary in sorted(grouped.items()):
assert 'base' in dictionary
if 'base' not in dictionary:
raise RuntimeError('\'base\' not in dictionary', dictionary)
result.append(dictionary)
return result

@ -547,6 +548,7 @@ def get_python_signature(declaration, include_out):
# Compute the Python function signature for argument parsing
typed_args = []
output_args = []
type_dispatch_args = []
positional = True

def get_typed_arg(arg):
@ -563,6 +565,11 @@ def get_python_signature(declaration, include_out):
default = 'None'
if arg.get('python_default_init') is not None:
default = 'None'
if default is None and arg.get('is_type_dispatched', False):
# this is necessary because ATen does not have default_types; in this case,
# the type exists in the public API (at:: namespace), but not in the type interface;
# to match the PyTorch default_type API, we set the default to None.
default = get_type_default(declaration)
if default is not None:
param += '=' + str(default)
return param
@ -571,6 +578,9 @@ def get_python_signature(declaration, include_out):
if arg.get('output', False):
output_args.append(arg)
continue
if arg.get('is_type_dispatched', False):
type_dispatch_args.append(arg)
continue
if arg.get('kwarg_only', False) and positional:
typed_args.append('*')
positional = False
@ -594,7 +604,17 @@ def get_python_signature(declaration, include_out):
typename = typenames[0]
typed_args.append(typename + ' out=None')

# we could put this in the loop above but we want to ensure it is after the out argument
# we could put this in the loop above but we want to ensure both type dispatched args
# and python binding arguments are after the out argument; this matches the case
# where there is a python binding argument dtype, which is necessary to match
# the function signatures between the out and non-out variant.
assert len(type_dispatch_args) <= 1
for arg in type_dispatch_args:
if positional: # assume type_dispatch_args should be kwarg_only.
typed_args.append('*')
positional = False
typed_args.append(get_typed_arg(arg))

if len(declaration['python_binding_arguments']) > 0:
for arg in declaration['python_binding_arguments']:
if arg.get('kwarg_only', False) and positional:

@ -70,11 +70,11 @@ DONT_REQUIRE_DERIVATIVE = {
}

METHOD_DECLARATION = CodeTemplate("""\
virtual ${return_type} ${method_prefix_derived}${api_name}(${formals}) const override;
virtual ${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")

METHOD_DEFINITION = CodeTemplate("""\
${return_type} VariableType::${method_prefix_derived}${api_name}(${formals}) const {
${return_type} VariableType::${method_prefix_derived}${api_name}(${type_method_formals}) const {
${type_definition_body}
}
""")
@ -98,7 +98,7 @@ grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
""")

CALL_VIA_TYPE = CodeTemplate("""\
Type::${method_prefix_derived}${api_name}(${args})""")
Type::${method_prefix_derived}${api_name}(${type_method_args})""")

CALL_VIA_DERIVED = CodeTemplate("""\
baseType->${method_prefix_derived}${base_name}(${unpacked_args})""")
@ -496,6 +496,9 @@ def unpack_args(env, declaration):
body = []
unpacked_args = []
for i, arg in enumerate(declaration['arguments']):
# these arguments are skipped from the Type method.
if arg.get('is_type_dispatched'):
continue
if not requires_unpack(arg):
unpacked_args.append(arg['name'])
continue
@ -537,14 +540,22 @@ def dispatch_strategy(declaration):
get dispatched back to VariableType (which will ensure that they
are differentiable.)
"""
if declaration['abstract'] or declaration['derivative'] is not None:
if (declaration['abstract'] or declaration['derivative'] is not None or
any(arg.get('is_type_dispatched') for arg in declaration['arguments'])):
# If the function is abstract (not implemented on at::Type), we must
# call the implementation on the derived type with unpacked tensors.

# If the function has a derivative specified and is concrete, we could
# call either implementation. We prefer the calling the derived
# type's implementation with unpacked tensors because it is more
# performant in some cases: any internal calls to other ATen functions
# won't have the history tracked.

# If the function has a type dispatched argument (i.e. is a factory),
# we prefer calling the derived type's implementation both because it is
# more performant and to ensure factory functions return tensors with _version
# of 0 (probably not strictly necessary, but nice to have to keeps versions simple
# to understand.
return 'use_derived'
else:
# If the function is concrete (we don't have to override it) and we

@ -137,7 +137,7 @@ Tensor sum_backward(const Tensor & grad, IntList sizes, int64_t dim, bool keepdi
}

Tensor reverse_dim(const Tensor& t, int64_t dim) {
Tensor index = t.type().toScalarType(at::ScalarType::Long).arange(t.size(dim) - 1, -1, -1);
Tensor index = at::arange(t.type().toScalarType(at::ScalarType::Long), t.size(dim) - 1, -1, -1);
return t.index_select(dim, index);
}

@ -148,7 +148,7 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d

std::vector<int64_t> ones_size(inp.sizes());
ones_size[dim] = 1;
Tensor ones = grad.type().ones(ones_size);
Tensor ones = at::ones(grad.type(), ones_size);
Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim);
Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim);

@ -301,8 +301,8 @@ Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim) {

std::vector<int64_t> ones_size(input.sizes());
ones_size[dim] = 1;
Tensor ones = grad.type().ones({1}).expand(ones_size);
Tensor grad_input = grad.type().zeros(input.sizes());
Tensor ones = at::ones(grad.type(), {1}).expand(ones_size);
Tensor grad_input = at::zeros(grad.type(), input.sizes());
Tensor prods_from_k_plus_1;
Tensor omitted_products;
for (int k = 0; k < dim_size; ++k) {
@ -371,7 +371,7 @@ std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<
auto size = sizes[i];
accumulate += size;
if (size == 0) {
grad_inputs[i] = grad.type().zeros({0});
grad_inputs[i] = at::zeros(grad.type(), {0});
} else {
grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
}
@ -461,7 +461,7 @@ Tensor select_backward(Tensor grad, int64_t dim, Tensor indices, IntList sizes,
grad = grad.unsqueeze(dim);
indices = indices.unsqueeze(dim);
}
return grad.type().zeros(sizes).scatter_(dim, indices, grad);
return at::zeros(grad.type(), sizes).scatter_(dim, indices, grad);
}

Tensor trace_backward(const Tensor & grad, IntList sizes) {
@ -471,8 +471,8 @@ Tensor trace_backward(const Tensor & grad, IntList sizes) {

auto& long_type = grad.type().toScalarType(at::kLong);

auto grad_input = grad.type().zeros(sizes[0] * sizes[1]);
auto indices = long_type.arange(0, grad_input.numel(), sizes[1] + 1);
auto grad_input = at::zeros(grad.type(), sizes[0] * sizes[1]);
auto indices = at::arange(long_type, 0, grad_input.numel(), sizes[1] + 1);
grad_input.index_fill_(0, indices, grad);
return grad_input.view(sizes);
}
@ -485,9 +485,9 @@ Tensor unfold_backward(const Tensor & grad, IntList input_sizes, int64_t dim, in
numel *= size;
}

auto idx = long_type.arange(0, numel).view(input_sizes);
auto idx = at::arange(long_type, 0, numel).view(input_sizes);
auto idx_unfolded = idx.unfold(dim, size, step).contiguous().view(-1);
auto grad_input = grad.type().zeros({numel});
auto grad_input = at::zeros(grad.type(), {numel});
grad_input.index_add_(0, idx_unfolded, grad.contiguous().view(-1));
return grad_input.view(input_sizes);
}
@ -516,7 +516,7 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList
if (diff_nelem > 0) {
// because mask_selected returns a 1-d tensor with size of masked elements that are 1,
// we need to fill out the rest with zeros then reshape back to tensor2's size.
auto zeros_fillin = grad.type().zeros({diff_nelem});
auto zeros_fillin = at::zeros(grad.type(), {diff_nelem});
mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
}
return mask_selected.view(sizes);
@ -564,7 +564,7 @@ Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &g
auto length = split_sizes[j];
std::vector<int64_t> grad_size(sizes);
grad_size[dim] = length;
grads_all_defined[j] = type.zeros(grad_size);
grads_all_defined[j] = at::zeros(type, grad_size);
}
}

@ -613,7 +613,7 @@ Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input
if (dim < 0) dim += input.dim();
std::vector<int64_t> sizes = input.sizes();
sizes[dim] /= 2;
auto tmp = grad * glu_backward(input.type().ones(sizes), input, dim);
auto tmp = grad * glu_backward(at::ones(input.type(), sizes), input, dim);
return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]);
}

@ -705,7 +705,7 @@ Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal)
}

// Input was a matrix but was not square
auto grad_input = grad.type().zeros(input_sizes);
auto grad_input = at::zeros(grad.type(), input_sizes);
auto diagonal_size = diag_size(input_sizes[0], input_sizes[1], diagonal);
auto storage_offset = diagonal >= 0 ? diagonal : -diagonal * input_sizes[1];
auto diag = grad_input.as_strided({diagonal_size}, {input_sizes[1] + 1}, storage_offset);
@ -867,8 +867,8 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
auto gu = grads[0];
auto gsigma = grads[1];
auto gv = grads[2];
auto im = self.type().eye(m);
auto in = self.type().eye(n);
auto im = at::eye(self.type(), m);
auto in = at::eye(self.type(), n);
auto ut = u.t();
auto vt = v.t();
auto sigma_mat = sigma.diag();
@ -876,7 +876,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat);
auto F = (sigma_expanded_sq - sigma_expanded_sq.t()).pow(-1);
auto& long_type = sigma.type().toScalarType(at::kLong);
auto diag_indices = long_type.arange(0, F.numel(), k + 1);
auto diag_indices = at::arange(long_type, 0, F.numel(), k + 1);
F.view({-1}).index_fill_(0, diag_indices, 0);

Tensor u_term, sigma_term, v_term;
@ -888,13 +888,13 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
}
u_term = u_term.mm(vt);
} else {
u_term = self.type().zeros({1}).expand_as(self);
u_term = at::zeros(self.type(), {1}).expand_as(self);
}

if (gsigma.defined()) {
sigma_term = u.mm(gsigma.diag()).mm(vt);
} else {
sigma_term = self.type().zeros({1}).expand_as(self);
sigma_term = at::zeros(self.type(), {1}).expand_as(self);
}

if (gv.defined()) {
@ -905,7 +905,7 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
}
v_term = u.mm(v_term);
} else {
v_term = self.type().zeros({1}).expand_as(self);
v_term = at::zeros(self.type(), {1}).expand_as(self);
}

return u_term + sigma_term + v_term;
@ -961,10 +961,10 @@ std::tuple<Tensor, Tensor> trtrs_backward(
}
}
if (!grad_a.defined()) {
grad_a = a.type().zeros({1}).expand_as(a);
grad_a = at::zeros(a.type(), {1}).expand_as(a);
}
if (!grad_b.defined()) {
grad_b = b.type().zeros({1}).expand_as(b);
grad_b = at::zeros(b.type(), {1}).expand_as(b);
}
if (output_mask[1] && grad_m.defined()) {
grad_a = grad_a.add(grad_m);

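The autograd backward helpers above are all migrated the same way: wherever a zero- or one-filled buffer of a given Type is needed, the Type is forwarded to the at:: factory instead of being used as the call target. A compact sketch of the migrated pattern (zeros_of is a hypothetical helper written for illustration, not code from this commit):

#include "ATen/ATen.h"

// Hypothetical helper: allocate a zero-filled gradient buffer with the same
// type/backend as `grad`, using the namespace factory introduced by this change.
static at::Tensor zeros_of(const at::Tensor & grad, at::IntList sizes) {
  return at::zeros(grad.type(), sizes);
}
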
@ -30,7 +30,7 @@ struct TypeAndSize {
: sizes(t.sizes())
, type(&t.type()) {}

Tensor zeros() { return type->zeros(sizes); }
Tensor zeros() { return at::zeros(*type, sizes); }

private:
std::vector<int64_t> sizes;

@ -30,9 +30,9 @@ static Tensor set_requires_grad(Tensor self, bool requires_grad) {
return self;
}

static void check_out_dtype_matches(Tensor result, const at::Type &type) {
static void check_out_type_matches(Tensor result, const at::Type &type) {
if (result.type() != type) {
at::runtime_error("dtype corresponding to %s does not match type of out parameter (%s)",
at::runtime_error("type corresponding to %s does not match type of out parameter (%s)",
type.toString(), result.type().toString());
}
}

@ -40,7 +40,7 @@ def init_dropout_state(ty, dropout, train, dropout_seed, dropout_state):
dropout_p = dropout if train else 0
if (dropout_desc_name not in dropout_state) or (dropout_state[dropout_desc_name].get() is None):
dropout_state[dropout_desc_name] = Unserializable(
torch._C._VariableFunctions._cudnn_init_dropout_state(ty, dropout_p, train, dropout_seed)
torch._C._VariableFunctions._cudnn_init_dropout_state(dropout_p, train, dropout_seed, ty=ty)
if dropout_p != 0 else None
)
dropout_ts = dropout_state[dropout_desc_name].get()

@ -45,7 +45,7 @@ VariableInfo::VariableInfo(const Variable& var)

Variable VariableInfo::zeros(AutoGPU& gpu_guard) const {
gpu_guard.setDevice(device);
return type->zeros(size);
return at::zeros(*type, size);
}

auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {

@ -198,7 +198,7 @@ static Value* createZerosLike(Value *v) {
AutoGPU gpu_guard(type->device());

auto & at_type = type->device() == -1 ? at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
auto zeros = at_type.zeros({1}).expand(type->sizes());
auto zeros = at::zeros(at_type, {1}).expand(type->sizes());
Node *constant = graph->createConstant(zeros)
->i_(kis_zero, 1);
graph->insertNode(constant);

@ -26,7 +26,7 @@ static at::Tensor zeroTensorWithType(const TensorType & type) {
auto & at_type = at::getType(device, type.scalarType());
// note: this has to be a contiguous tensor of zeros, because the fusion engine
// specialized to what is normally here which might be fully dense
return at_type.zeros(type.sizes());
return at::zeros(at_type, type.sizes());
}

autograd::variable_list InterpreterAutogradFunction::apply(

@ -631,7 +631,7 @@ struct to_ir {
return emitNode(
Symbol("type_as"),
input->range(),
{emitExpr(input, 1)[0], createConstant(input->range(), at::CPU(t).ones({1}))},
{emitExpr(input, 1)[0], createConstant(input->range(), at::ones(at::CPU(t), {1}))},
1)
->outputs();
}

@ -120,9 +120,9 @@ static void fusionTests() {
Var i1 = Var::asNewInput(graph);
auto o0 = i0 * i1;
o0.addAsOutput();
auto a = at::CUDA(at::kFloat).rand({3,4});
auto b = at::CUDA(at::kFloat).rand({4,3}).transpose(0,1);
auto o = at::CUDA(at::kFloat).zeros({3,4});
auto a = at::rand(at::CUDA(at::kFloat), {3,4});
auto b = at::rand(at::CUDA(at::kFloat), {4,3}).transpose(0,1);
auto o = at::zeros(at::CUDA(at::kFloat), {3,4});
comp.debugLaunchGraph(graph, 0, {a,b}, {o});
auto o2 = a*b;
float max_diff = (o2 - o).abs().max().toCDouble();
@ -164,12 +164,12 @@ static void fusionTests() {
for(size_t i = 0; i < graph.inputs().size(); i++) {
std::vector<int64_t> dims = {128, 128, 32};
std::swap(dims[ti],dims[tj]);
inputs.push_back(at::CUDA(at::kFloat).rand(dims).transpose(ti, tj));
inputs.push_back(at::rand(at::CUDA(at::kFloat), dims).transpose(ti, tj));
}
for(size_t i = 0; i < graph.outputs().size(); i++) {
std::vector<int64_t> dims = {128, 128, 32};
std::swap(dims[toi],dims[toj]);
outputs.push_back(at::CUDA(at::kFloat).zeros(dims).transpose(toi,toj));
outputs.push_back(at::zeros(at::CUDA(at::kFloat), dims).transpose(toi,toj));
}

auto t22 = inputs[4].sigmoid();
@ -209,13 +209,13 @@ static void fusionTests() {
o0.addAsOutput();
Var::cat({i0, o0}, dim).addAsOutput();

auto a = at::CUDA(at::kFloat).rand({3,4,5});
auto b = at::CUDA(at::kFloat).rand({4,3,5}).transpose(0,1);
auto o = at::CUDA(at::kFloat).zeros({3,4,5});
auto a = at::rand(at::CUDA(at::kFloat), {3,4,5});
auto b = at::rand(at::CUDA(at::kFloat), {4,3,5}).transpose(0,1);
auto o = at::zeros(at::CUDA(at::kFloat), {3,4,5});

auto o_r = a*b;
auto o2_r = at::cat({a, o_r}, dim);
auto o2 = at::CUDA(at::kFloat).zeros(o2_r.sizes());
auto o2 = at::zeros(at::CUDA(at::kFloat), o2_r.sizes());
comp.debugLaunchGraph(graph, 0, {a,b}, {o, o2});

float max_diff = (o_r - o).abs().max().toCDouble();
@ -407,11 +407,11 @@ void interpTest() {

int hidden_size = 2*input_size;

auto input = at::CUDA(at::kFloat).randn({seq_len, batch_size, input_size});
auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
auto input = at::randn(at::CUDA(at::kFloat), {seq_len, batch_size, input_size});
auto hx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto cx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto w_ih = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, input_size}));
auto w_hh = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, hidden_size}));

auto lstm_g = build_lstm();
Code lstm_function(lstm_g, /*values_are_variables=*/false);
@ -431,12 +431,12 @@ void interpStageTest() {
constexpr int seq_len = 32;

int hidden_size = 2*input_size;
auto input = at::CUDA(at::kFloat).randn({seq_len, batch_size, input_size});
auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto cx1 = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
auto input = at::randn(at::CUDA(at::kFloat), {seq_len, batch_size, input_size});
auto hx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto cx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto cx1 = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto w_ih = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, input_size}));
auto w_hh = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, hidden_size}));


auto lstm_g = build_lstm_stages();
@ -655,7 +655,7 @@ void testCreateAutodiffSubgraphs(std::ostream & out) {
}

autograd::Variable var(at::Type & t, at::IntList sizes, bool requires_grad) {
return autograd::make_variable(t.rand(sizes), requires_grad);
return autograd::make_variable(at::rand(t, sizes), requires_grad);
}
autograd::Variable undef() {
return autograd::Variable();
@ -740,11 +740,11 @@ void shapeAnalysisTest() {

auto v = [](at::Tensor t) { return autograd::make_variable(t, false); };

auto input = at::CUDA(at::kFloat).randn({batch_size, input_size});
auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
auto input = at::randn(at::CUDA(at::kFloat), {batch_size, input_size});
auto hx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto cx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto w_ih = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, input_size}));
auto w_hh = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, hidden_size}));

auto g = build_lstm();
ArgumentSpec spec(false, createVarList({v(input), v(hx), v(cx), v(w_ih), v(w_hh) }));
@ -766,11 +766,11 @@ void testGraphExecutor() {

auto v = [](at::Tensor t) { return autograd::make_variable(t, false); };

auto input = at::CUDA(at::kFloat).randn({batch_size, input_size});
auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
auto input = at::randn(at::CUDA(at::kFloat), {batch_size, input_size});
auto hx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto cx = at::randn(at::CUDA(at::kFloat), {batch_size, hidden_size});
auto w_ih = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, input_size}));
auto w_hh = t_def(at::randn(at::CUDA(at::kFloat), {4 * hidden_size, hidden_size}));

std::vector<at::Tensor> inputs = {v(input), v(hx), v(cx), v(w_ih), v(w_hh) };
auto g = build_lstm();