DgemmThread.java
package org.mklab.sdpj.gpack.blaswrap;
import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.RealNumericalScalar;
import org.mklab.sdpj.tool.Tools;
/**
 * Pure-Java, two-threaded implementation of the BLAS {@code DGEMM} operation
 * {@code C := alpha * op(A) * op(B) + beta * C} on column-major arrays of generic real scalars, where
 * {@code op(X)} is either {@code X} or its transpose.
 *
 * <p>The rows of {@code C} are split at {@code m / 2}: the calling thread computes rows
 * {@code m/2 .. m-1}, while a worker thread running {@link DgemmThreadCh} is given {@code m/2} as its
 * row count (presumably covering rows {@code 0 .. m/2-1} — confirm against {@code DgemmThreadCh}).
 *
 * @author koga
 * @version $Revision$, 2009/04/24
 * @param <RS> type of real scalars
 * @param <RM> type of real matrices
 * @param <CS> type of complex scalars
 * @param <CM> type of complex matrices
 */
public class DgemmThread<RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> {

    /**
     * Combination of transpose flags for {@code A} and {@code B}: the first letter refers to
     * {@code A}, the second to {@code B}; {@code N} = used as is, {@code T} = transposed.
     *
     * @author koga
     * @version $Revision$, 2010/11/15
     */
    enum Trans {
        /** op(A) = A, op(B) = B. */
        NN,
        /** op(A) = A, op(B) = B^T. */
        NT,
        /** op(A) = A^T, op(B) = B. */
        TN,
        /** op(A) = A^T, op(B) = B^T. */
        TT
    }

    /**
     * Computes {@code C := alpha * op(A) * op(B) + beta * C} with column-major storage.
     *
     * <p>If {@code transa} or {@code transb} is invalid, the error is reported through
     * {@code Tools.error} and {@code c} is left unmodified.
     *
     * @param transa {@code "N"} to use {@code A} as is, {@code "T"} to use its transpose
     * @param transb {@code "N"} to use {@code B} as is, {@code "T"} to use its transpose
     * @param m number of rows of {@code op(A)} and of {@code C}
     * @param n number of columns of {@code op(B)} and of {@code C}
     * @param k number of columns of {@code op(A)} / rows of {@code op(B)}
     * @param alpha scalar multiplier of {@code op(A) * op(B)}
     * @param a elements of {@code A}, column-major
     * @param lda leading dimension of {@code a}
     * @param b elements of {@code B}, column-major
     * @param ldb leading dimension of {@code b}
     * @param beta scalar multiplier of the incoming {@code C}
     * @param c elements of {@code C}, column-major; overwritten with the result
     * @param ldc leading dimension of {@code c}
     * @return always {@code 0}
     */
    public int dgemm(String transa, String transb, int m, int n, int k, RS alpha, RS[] a, int lda, RS[] b, int ldb, RS beta, RS[] c, int ldc) {
        final Trans trans = toTrans(transa, transb);
        if (trans == null) {
            // The bad argument has already been reported; like reference BLAS, do not touch C.
            return 0;
        }
        dgemmChThread(trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        return 0;
    }

    /**
     * Maps the BLAS transpose characters to a {@link Trans} value, reporting invalid input.
     *
     * @param transa transpose flag of A
     * @param transb transpose flag of B
     * @return the matching mode, or {@code null} after reporting an error through {@code Tools.error}
     */
    private static Trans toTrans(String transa, String transb) {
        if (BLAS.lsame(transa, "N")) { //$NON-NLS-1$
            if (BLAS.lsame(transb, "N")) { //$NON-NLS-1$
                return Trans.NN;
            }
            if (BLAS.lsame(transb, "T")) { //$NON-NLS-1$
                return Trans.NT;
            }
            Tools.error("BLAS::DGEMMTHREAD:error TRANS B"); //$NON-NLS-1$
            return null;
        }
        if (BLAS.lsame(transa, "T")) { //$NON-NLS-1$
            if (BLAS.lsame(transb, "N")) { //$NON-NLS-1$
                return Trans.TN;
            }
            if (BLAS.lsame(transb, "T")) { //$NON-NLS-1$
                return Trans.TT;
            }
            Tools.error("BLAS::DGEMMTHREAD:error TRANS B"); //$NON-NLS-1$
            return null;
        }
        Tools.error("BLAS::DGEMMTHREAD:error TRANS A"); //$NON-NLS-1$
        return null;
    }

    /**
     * Splits the rows of {@code C} between a worker thread and the calling thread, and returns only
     * after both halves are finished. A failure in the worker is rethrown on the calling thread.
     *
     * @param ch transpose mode (non-null)
     * @param m number of rows of {@code C}
     * @param n number of columns of {@code C}
     * @param k inner dimension
     * @param alpha alpha
     * @param A elements of A
     * @param lda leading dimension of A
     * @param B elements of B
     * @param ldb leading dimension of B
     * @param beta beta
     * @param C elements of C (overwritten)
     * @param ldc leading dimension of C
     */
    private void dgemmChThread(Trans ch, int m, int n, int k, RS alpha, RS[] A, int lda, RS[] B, int ldb, RS beta, RS[] C, int ldc) {
        final int m2 = m / 2;
        final Runnable lowerHalf = new DgemmThreadCh<>(ch, m2, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        // Written by the worker, read here only after join(), which provides the happens-before edge.
        final Throwable[] workerFailure = new Throwable[1];
        final Thread worker = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    lowerHalf.run();
                } catch (RuntimeException | Error e) {
                    workerFailure[0] = e;
                }
            }
        });
        worker.start();
        try {
            subDgemmCh(ch, m2, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        } finally {
            // Never return (normally or exceptionally) while the worker may still be writing into C.
            joinUninterruptibly(worker);
        }
        final Throwable failure = workerFailure[0];
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
    }

    /**
     * Blocks until {@code thread} terminates without busy-waiting. Interrupts are deferred rather
     * than honoured so that the shared output is never left half-written; the interrupt status is
     * restored before returning.
     *
     * @param thread thread to wait for
     */
    private static void joinUninterruptibly(Thread thread) {
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    thread.join();
                    return;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Computes rows {@code startLoopCount .. endRow-1} of {@code C := alpha * op(A) * op(B) + beta * C}.
     *
     * @param ch transpose mode (non-null)
     * @param startLoopCount first row of C to compute (inclusive)
     * @param endRow last row of C to compute (exclusive)
     * @param n number of columns of C
     * @param k inner dimension
     * @param alpha alpha
     * @param A elements of A
     * @param lda leading dimension of A
     * @param B elements of B
     * @param ldb leading dimension of B
     * @param beta beta
     * @param C elements of C (overwritten)
     * @param ldc leading dimension of C
     */
    private void subDgemmCh(Trans ch, int startLoopCount, int endRow, int n, int k, RS alpha, RS[] A, int lda, RS[] B, int ldb, RS beta, RS[] C, int ldc) {
        final boolean transA = ch == Trans.TN || ch == Trans.TT;
        final boolean transB = ch == Trans.NT || ch == Trans.TT;
        // op(A)(i, p) == A[i * aRowStride + p * aInnerStride]
        final int aRowStride = transA ? lda : 1;
        final int aInnerStride = transA ? 1 : lda;
        // op(B)(p, j) == B[p * bInnerStride + j * bColStride]
        final int bInnerStride = transB ? ldb : 1;
        final int bColStride = transB ? 1 : ldb;

        for (int i = startLoopCount; i < endRow; i++) {
            for (int j = 0; j < n; j++) {
                // alpha has the same scalar type as the elements; A[0] may not exist when k == 0.
                RS sum = alpha.createZero();
                int aIndex = i * aRowStride;
                int bIndex = j * bColStride;
                for (int p = 0; p < k; p++) {
                    sum = sum.add(A[aIndex].multiply(B[bIndex]));
                    aIndex += aInnerStride;
                    bIndex += bInnerStride;
                }
                final int cIndex = i + j * ldc;
                C[cIndex] = beta.multiply(C[cIndex]).add(alpha.multiply(sum));
            }
        }
    }
}