added mv, mm and ger + better checking of addmv, addmm and addr

This commit is contained in:
Ronan Collobert
2012-02-07 14:49:13 +01:00
parent bbf3970981
commit ed4857dfcf

View File

@ -352,32 +352,113 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
{name=Tensor},
{name=Tensor}})
for _,name in ipairs({"addmv", "addmm", "addr"}) do
interface:wrap(name,
cname(name),
wrap("mv",
cname("addmv"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_resize1d(%s, %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg())
}, '\n')
end,
precall=function(arg)
return table.concat(
{
string.format("TH%s_zero(%s);", Tensor, arg:carg()),
arg.__metatable.precall(arg)
}, '\n')
end
},
{name=real, default=1, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=2},
{name=Tensor, dim=1}}
)
wrap("mm",
cname("addmm"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_resize2d(%s, %s->size[0], %s->size[1]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
precall=function(arg)
return table.concat(
{
string.format("TH%s_zero(%s);", Tensor, arg:carg()),
arg.__metatable.precall(arg)
}, '\n')
end
},
{name=real, default=1, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=2},
{name=Tensor, dim=2}}
)
wrap("ger",
cname("addr"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_resize2d(%s, %s->size[0], %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
precall=function(arg)
return table.concat(
{
string.format("TH%s_zero(%s);", Tensor, arg:carg()),
arg.__metatable.precall(arg)
}, '\n')
end
},
{name=real, default=1, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=1},
{name=Tensor, dim=1}}
)
for _,f in ipairs({
{name="addmv", dim1=1, dim2=2, dim3=1},
{name="addmm", dim1=2, dim2=2, dim3=2},
{name="addr", dim1=2, dim2=1, dim3=1},
}
) do
interface:wrap(f.name,
cname(f.name),
{{name=Tensor, default=true, returned=true},
{name=real, default=1},
{name=Tensor, method={default=1}},
{name=Tensor, dim=f.dim1},
{name=real, default=1},
{name=Tensor},
{name=Tensor}})
{name=Tensor, dim=f.dim2},
{name=Tensor, dim=f.dim3}})
-- there is an ambiguity here, hence the more complicated setup
method:wrap(name,
cname(name),
{{name=Tensor, returned=true},
method:wrap(f.name,
cname(f.name),
{{name=Tensor, returned=true, dim=f.dim1},
{name=real, default=1, invisible=true},
{name=Tensor, default=1},
{name=Tensor, default=1, dim=f.dim1},
{name=real, default=1},
{name=Tensor},
{name=Tensor}},
cname(name),
{{name=Tensor, returned=true},
{name=Tensor, dim=f.dim2},
{name=Tensor, dim=f.dim3}},
cname(f.name),
{{name=Tensor, returned=true, dim=f.dim1},
{name=real},
{name=Tensor, default=1},
{name=Tensor, default=1, dim=f.dim1},
{name=real},
{name=Tensor},
{name=Tensor}})
{name=Tensor, dim=f.dim2},
{name=Tensor, dim=f.dim3}})
end
wrap("numel",
@ -448,10 +529,14 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
wrap("randperm",
cname("randperm"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
userpostcall=function(arg)
return string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg())
end},
{name="long"}})
postcall=function(arg)
return table.concat(
{
arg.__metatable.postcall(arg),
string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg())
}, '\n')
end},
{name="long"}})
wrap("sort",
cname("sort"),