mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
added mv, mm and ger + better checking of addmv, addmm and addr
This commit is contained in:
127
TensorMath.lua
127
TensorMath.lua
@ -352,32 +352,113 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
|
||||
for _,name in ipairs({"addmv", "addmm", "addr"}) do
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
wrap("mv",
|
||||
cname("addmv"),
|
||||
{{name=Tensor, default=true, returned=true, method={default='nil'},
|
||||
init=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
arg.__metatable.init(arg),
|
||||
string.format("TH%s_resize1d(%s, %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg())
|
||||
}, '\n')
|
||||
end,
|
||||
precall=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
string.format("TH%s_zero(%s);", Tensor, arg:carg()),
|
||||
arg.__metatable.precall(arg)
|
||||
}, '\n')
|
||||
end
|
||||
},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=2},
|
||||
{name=Tensor, dim=1}}
|
||||
)
|
||||
|
||||
wrap("mm",
|
||||
cname("addmm"),
|
||||
{{name=Tensor, default=true, returned=true, method={default='nil'},
|
||||
init=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
arg.__metatable.init(arg),
|
||||
string.format("TH%s_resize2d(%s, %s->size[0], %s->size[1]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
|
||||
}, '\n')
|
||||
end,
|
||||
precall=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
string.format("TH%s_zero(%s);", Tensor, arg:carg()),
|
||||
arg.__metatable.precall(arg)
|
||||
}, '\n')
|
||||
end
|
||||
},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=2},
|
||||
{name=Tensor, dim=2}}
|
||||
)
|
||||
|
||||
wrap("ger",
|
||||
cname("addr"),
|
||||
{{name=Tensor, default=true, returned=true, method={default='nil'},
|
||||
init=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
arg.__metatable.init(arg),
|
||||
string.format("TH%s_resize2d(%s, %s->size[0], %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
|
||||
}, '\n')
|
||||
end,
|
||||
precall=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
string.format("TH%s_zero(%s);", Tensor, arg:carg()),
|
||||
arg.__metatable.precall(arg)
|
||||
}, '\n')
|
||||
end
|
||||
},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=1},
|
||||
{name=Tensor, dim=1}}
|
||||
)
|
||||
|
||||
for _,f in ipairs({
|
||||
{name="addmv", dim1=1, dim2=2, dim3=1},
|
||||
{name="addmm", dim1=2, dim2=2, dim3=2},
|
||||
{name="addr", dim1=2, dim2=1, dim3=1},
|
||||
}
|
||||
) do
|
||||
|
||||
interface:wrap(f.name,
|
||||
cname(f.name),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=1},
|
||||
{name=Tensor, method={default=1}},
|
||||
{name=Tensor, dim=f.dim1},
|
||||
{name=real, default=1},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
{name=Tensor, dim=f.dim2},
|
||||
{name=Tensor, dim=f.dim3}})
|
||||
|
||||
-- there is an ambiguity here, hence the more complicated setup
|
||||
method:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor, returned=true},
|
||||
method:wrap(f.name,
|
||||
cname(f.name),
|
||||
{{name=Tensor, returned=true, dim=f.dim1},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, default=1},
|
||||
{name=Tensor, default=1, dim=f.dim1},
|
||||
{name=real, default=1},
|
||||
{name=Tensor},
|
||||
{name=Tensor}},
|
||||
cname(name),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=Tensor, dim=f.dim2},
|
||||
{name=Tensor, dim=f.dim3}},
|
||||
cname(f.name),
|
||||
{{name=Tensor, returned=true, dim=f.dim1},
|
||||
{name=real},
|
||||
{name=Tensor, default=1},
|
||||
{name=Tensor, default=1, dim=f.dim1},
|
||||
{name=real},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
{name=Tensor, dim=f.dim2},
|
||||
{name=Tensor, dim=f.dim3}})
|
||||
end
|
||||
|
||||
wrap("numel",
|
||||
@ -448,10 +529,14 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
|
||||
wrap("randperm",
|
||||
cname("randperm"),
|
||||
{{name=Tensor, default=true, returned=true, method={default='nil'},
|
||||
userpostcall=function(arg)
|
||||
return string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg())
|
||||
end},
|
||||
{name="long"}})
|
||||
postcall=function(arg)
|
||||
return table.concat(
|
||||
{
|
||||
arg.__metatable.postcall(arg),
|
||||
string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg())
|
||||
}, '\n')
|
||||
end},
|
||||
{name="long"}})
|
||||
|
||||
wrap("sort",
|
||||
cname("sort"),
|
||||
|
Reference in New Issue
Block a user