mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixed some thread-safety issues in THTensor_(addmm)
This commit is contained in:
@ -209,7 +209,7 @@ void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
|
||||
|
||||
void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *m1, THTensor *m2)
|
||||
{
|
||||
char transpose, transpose_m1, transpose_m2;
|
||||
char transpose_r, transpose_m1, transpose_m2;
|
||||
THTensor *r__, *m1_, *m2_;
|
||||
|
||||
if( (m1->nDimension != 2) || (m2->nDimension != 2) )
|
||||
@ -227,10 +227,12 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
|
||||
THTensor_(copy)(r_, t);
|
||||
}
|
||||
|
||||
// printf("%ldx%ld = %ldx%ld X %ldx%ld\n", r_->size[0], r_->size[1], m1->size[0], m1->size[1], m2->size[0], m2->size[1]);
|
||||
|
||||
/* r_ */
|
||||
if(r_->stride[0] == 1)
|
||||
{
|
||||
transpose = 'n';
|
||||
transpose_r = 'n';
|
||||
r__ = r_;
|
||||
}
|
||||
else if(r_->stride[1] == 1)
|
||||
@ -238,69 +240,66 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
|
||||
THTensor *swap = m2;
|
||||
m2 = m1;
|
||||
m1 = swap;
|
||||
THTensor_(transpose)(r_, NULL, 0, 1);
|
||||
THTensor_(transpose)(m1, NULL, 0, 1);
|
||||
THTensor_(transpose)(m2, NULL, 0, 1);
|
||||
transpose = 't';
|
||||
transpose_r = 't';
|
||||
r__ = r_;
|
||||
}
|
||||
else
|
||||
{
|
||||
transpose = 'n';
|
||||
THTensor_(transpose)(r_, NULL, 0, 1);
|
||||
r__ = THTensor_(newClone)(r_);
|
||||
THTensor_(transpose)(r_, NULL, 0, 1);
|
||||
transpose_r = 'n';
|
||||
|
||||
r__ = THTensor_(newWithSize2d)(r_->size[1], r_->size[0]);
|
||||
THTensor_(copy)(r__, r_);
|
||||
THTensor_(transpose)(r__, NULL, 0, 1);
|
||||
}
|
||||
|
||||
/* m1 */
|
||||
if(m1->stride[0] == 1)
|
||||
if(m1->stride[(transpose_r == 'n' ? 0 : 1)] == 1)
|
||||
{
|
||||
transpose_m1 = 'n';
|
||||
m1_ = m1;
|
||||
}
|
||||
else if(m1->stride[1] == 1)
|
||||
else if(m1->stride[(transpose_r == 'n' ? 1 : 0)] == 1)
|
||||
{
|
||||
transpose_m1 = 't';
|
||||
m1_ = m1;
|
||||
}
|
||||
else
|
||||
{
|
||||
transpose_m1 = 't';
|
||||
transpose_m1 = (transpose_r == 'n' ? 't' : 'n');
|
||||
m1_ = THTensor_(newContiguous)(m1);
|
||||
}
|
||||
|
||||
/* m2 */
|
||||
if(m2->stride[0] == 1)
|
||||
if(m2->stride[(transpose_r == 'n' ? 0 : 1)] == 1)
|
||||
{
|
||||
transpose_m2 = 'n';
|
||||
m2_ = m2;
|
||||
}
|
||||
else if(m2->stride[1] == 1)
|
||||
else if(m2->stride[(transpose_r == 'n' ? 1 : 0)] == 1)
|
||||
{
|
||||
transpose_m2 = 't';
|
||||
m2_ = m2;
|
||||
}
|
||||
else
|
||||
{
|
||||
transpose_m2 = 't';
|
||||
transpose_m2 = (transpose_r == 'n' ? 't' : 'n');
|
||||
m2_ = THTensor_(newContiguous)(m2);
|
||||
}
|
||||
|
||||
/* do the operation */
|
||||
THBlas_(gemm)(transpose_m1,
|
||||
transpose_m2,
|
||||
r__->size[0],
|
||||
r__->size[1],
|
||||
m1_->size[1],
|
||||
r__->size[(transpose_r == 'n' ? 0 : 1)],
|
||||
r__->size[(transpose_r == 'n' ? 1 : 0)],
|
||||
m1_->size[(transpose_r == 'n' ? 1 : 0)],
|
||||
alpha,
|
||||
THTensor_(data)(m1_),
|
||||
(transpose_m1 == 'n' ? m1_->stride[1] : m1_->stride[0]),
|
||||
(transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]),
|
||||
THTensor_(data)(m2_),
|
||||
(transpose_m2 == 'n' ? m2_->stride[1] : m2_->stride[0]),
|
||||
(transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]),
|
||||
beta,
|
||||
THTensor_(data)(r__),
|
||||
r__->stride[1]);
|
||||
r__->stride[(transpose_r == 'n' ? 1 : 0)]);
|
||||
|
||||
/* free intermediate variables */
|
||||
if(m1_ != m1)
|
||||
@ -311,13 +310,6 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
|
||||
|
||||
if(r__ != r_)
|
||||
THTensor_(freeCopyTo)(r__, r_);
|
||||
|
||||
if(transpose == 't')
|
||||
{
|
||||
THTensor_(transpose)(r_, NULL, 0, 1);
|
||||
THTensor_(transpose)(m1, NULL, 0, 1);
|
||||
THTensor_(transpose)(m2, NULL, 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2)
|
||||
|
Reference in New Issue
Block a user