fixed some thread safety issues

This commit is contained in:
Ronan Collobert
2012-09-26 19:55:16 +02:00
parent de2295f507
commit c34c2e2de1

View File

@ -209,7 +209,7 @@ void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *m1, THTensor *m2)
{
char transpose, transpose_m1, transpose_m2;
char transpose_r, transpose_m1, transpose_m2;
THTensor *r__, *m1_, *m2_;
if( (m1->nDimension != 2) || (m2->nDimension != 2) )
@ -227,10 +227,12 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
THTensor_(copy)(r_, t);
}
// printf("%ldx%ld = %ldx%ld X %ldx%ld\n", r_->size[0], r_->size[1], m1->size[0], m1->size[1], m2->size[0], m2->size[1]);
/* r_ */
if(r_->stride[0] == 1)
{
transpose = 'n';
transpose_r = 'n';
r__ = r_;
}
else if(r_->stride[1] == 1)
@ -238,69 +240,66 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
THTensor *swap = m2;
m2 = m1;
m1 = swap;
THTensor_(transpose)(r_, NULL, 0, 1);
THTensor_(transpose)(m1, NULL, 0, 1);
THTensor_(transpose)(m2, NULL, 0, 1);
transpose = 't';
transpose_r = 't';
r__ = r_;
}
else
{
transpose = 'n';
THTensor_(transpose)(r_, NULL, 0, 1);
r__ = THTensor_(newClone)(r_);
THTensor_(transpose)(r_, NULL, 0, 1);
transpose_r = 'n';
r__ = THTensor_(newWithSize2d)(r_->size[1], r_->size[0]);
THTensor_(copy)(r__, r_);
THTensor_(transpose)(r__, NULL, 0, 1);
}
/* m1 */
if(m1->stride[0] == 1)
if(m1->stride[(transpose_r == 'n' ? 0 : 1)] == 1)
{
transpose_m1 = 'n';
m1_ = m1;
}
else if(m1->stride[1] == 1)
else if(m1->stride[(transpose_r == 'n' ? 1 : 0)] == 1)
{
transpose_m1 = 't';
m1_ = m1;
}
else
{
transpose_m1 = 't';
transpose_m1 = (transpose_r == 'n' ? 't' : 'n');
m1_ = THTensor_(newContiguous)(m1);
}
/* m2 */
if(m2->stride[0] == 1)
if(m2->stride[(transpose_r == 'n' ? 0 : 1)] == 1)
{
transpose_m2 = 'n';
m2_ = m2;
}
else if(m2->stride[1] == 1)
else if(m2->stride[(transpose_r == 'n' ? 1 : 0)] == 1)
{
transpose_m2 = 't';
m2_ = m2;
}
else
{
transpose_m2 = 't';
transpose_m2 = (transpose_r == 'n' ? 't' : 'n');
m2_ = THTensor_(newContiguous)(m2);
}
/* do the operation */
THBlas_(gemm)(transpose_m1,
transpose_m2,
r__->size[0],
r__->size[1],
m1_->size[1],
r__->size[(transpose_r == 'n' ? 0 : 1)],
r__->size[(transpose_r == 'n' ? 1 : 0)],
m1_->size[(transpose_r == 'n' ? 1 : 0)],
alpha,
THTensor_(data)(m1_),
(transpose_m1 == 'n' ? m1_->stride[1] : m1_->stride[0]),
(transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]),
THTensor_(data)(m2_),
(transpose_m2 == 'n' ? m2_->stride[1] : m2_->stride[0]),
(transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]),
beta,
THTensor_(data)(r__),
r__->stride[1]);
r__->stride[(transpose_r == 'n' ? 1 : 0)]);
/* free intermediate variables */
if(m1_ != m1)
@ -311,13 +310,6 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
if(r__ != r_)
THTensor_(freeCopyTo)(r__, r_);
if(transpose == 't')
{
THTensor_(transpose)(r_, NULL, 0, 1);
THTensor_(transpose)(m1, NULL, 0, 1);
THTensor_(transpose)(m2, NULL, 0, 1);
}
}
void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2)