Implement Variable._sparse_mask (#4124)

* Implement Variable._sparse_mask

* Use SparseTensor as the dyanmic_type
This commit is contained in:
Sam Gross
2017-12-15 17:25:20 -05:00
committed by Soumith Chintala
parent 6d72c82985
commit c813ce3787
8 changed files with 50 additions and 1 deletions

View File

@ -71,6 +71,7 @@ def create_python_bindings(
unpack_methods = {
'const Tensor &': 'tensor',
'SparseTensor': 'tensor',
'Tensor &': 'tensor',
'Generator *': 'generator',
'Storage &': 'storage',
@ -116,6 +117,8 @@ def create_python_bindings(
expr = 'r.{}({})'.format(unpack, arg_idx)
if typename == 'Storage &':
expr = '*' + expr
if typename == 'SparseTensor':
expr = 'SparseTensor({})'.format(expr)
actuals.append(expr)
dispatch_type = typename
if dispatch_type == 'Tensor':
@ -167,6 +170,7 @@ def create_python_bindings(
if not is_class:
# Use 'input' instead of 'self' for NN functions
prototype = prototype.replace('Tensor self', 'Tensor input')
prototype = prototype.replace('SparseTensor', 'Tensor')
if 'deprecated' in o:
prototype += '|deprecated'
env['prototypes'].append('"{}",'.format(prototype))