mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Implement Variable._sparse_mask (#4124)
* Implement Variable._sparse_mask * Use SparseTensor as the dyanmic_type
This commit is contained in:
committed by
Soumith Chintala
parent
6d72c82985
commit
c813ce3787
@ -71,6 +71,7 @@ def create_python_bindings(
|
||||
|
||||
unpack_methods = {
|
||||
'const Tensor &': 'tensor',
|
||||
'SparseTensor': 'tensor',
|
||||
'Tensor &': 'tensor',
|
||||
'Generator *': 'generator',
|
||||
'Storage &': 'storage',
|
||||
@ -116,6 +117,8 @@ def create_python_bindings(
|
||||
expr = 'r.{}({})'.format(unpack, arg_idx)
|
||||
if typename == 'Storage &':
|
||||
expr = '*' + expr
|
||||
if typename == 'SparseTensor':
|
||||
expr = 'SparseTensor({})'.format(expr)
|
||||
actuals.append(expr)
|
||||
dispatch_type = typename
|
||||
if dispatch_type == 'Tensor':
|
||||
@ -167,6 +170,7 @@ def create_python_bindings(
|
||||
if not is_class:
|
||||
# Use 'input' instead of 'self' for NN functions
|
||||
prototype = prototype.replace('Tensor self', 'Tensor input')
|
||||
prototype = prototype.replace('SparseTensor', 'Tensor')
|
||||
if 'deprecated' in o:
|
||||
prototype += '|deprecated'
|
||||
env['prototypes'].append('"{}",'.format(prototype))
|
||||
|
Reference in New Issue
Block a user