// DgemmThread.java

package org.mklab.sdpj.gpack.blaswrap;

import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.RealNumericalScalar;
import org.mklab.sdpj.tool.Tools;


/**
 * Two-thread general matrix-matrix multiply (DGEMM-style) over generic real scalars.
 * <p>
 * Computes {@code C = alpha * op(A) * op(B) + beta * C}, where {@code op(X)} is either {@code X} or its
 * transpose. Matrices are passed as flat arrays addressed with a leading dimension (element
 * {@code (row, col)} of an untransposed operand is read at {@code row + col * ld}).
 * <p>
 * The rows of {@code C} are split in two: a {@link DgemmThreadCh} worker is started with the row count
 * {@code m / 2}, and the calling thread computes rows {@code m / 2 .. m - 1} itself.
 *
 * @author koga
 * @version $Revision$, 2009/04/24
 * @param <RS> type of the real scalar
 * @param <RM> type of the real matrix
 * @param <CS> type of the complex scalar
 * @param <CM> type of the complex matrix
 */
public class DgemmThread<RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> {

  /**
   * Combination of the transpose flags for the two operands: the first letter refers to {@code A}, the
   * second to {@code B}; {@code N} means "not transposed", {@code T} means "transposed".
   *
   * @author koga
   * @version $Revision$, 2010/11/15
   */
  enum Trans {
    /** Neither A nor B is transposed. */
    NN,
    /** A is not transposed, B is transposed. */
    NT,
    /** A is transposed, B is not transposed. */
    TN,
    /** Both A and B are transposed. */
    TT
  }

  /**
   * Computes {@code C = alpha * op(A) * op(B) + beta * C} using the calling thread plus one worker thread.
   * <p>
   * The transpose flags are compared against {@code "N"} and {@code "T"} with {@code BLAS.lsame}; any other
   * value is reported through {@code Tools.error}.
   *
   * @param transa {@code "N"} to use {@code a} as is, {@code "T"} to use its transpose
   * @param transb {@code "N"} to use {@code b} as is, {@code "T"} to use its transpose
   * @param m number of rows of {@code op(A)} and of {@code c}
   * @param n number of columns of {@code op(B)} and of {@code c}
   * @param k number of columns of {@code op(A)} and rows of {@code op(B)} (length of the inner product)
   * @param alpha scalar multiplying the product {@code op(A) * op(B)}
   * @param a flat storage of matrix A
   * @param lda leading dimension (index stride between consecutive columns) of {@code a}
   * @param b flat storage of matrix B
   * @param ldb leading dimension of {@code b}
   * @param beta scalar multiplying the previous contents of {@code c}
   * @param c flat storage of matrix C; updated in place
   * @param ldc leading dimension of {@code c}
   * @return always 0
   */
  public int dgemm(String transa, String transb, int m, int n, int k, RS alpha, RS[] a, int lda, RS[] b, int ldb, RS beta, RS[] c, int ldc) {
    // NOTE(review): if Tools.error returns normally instead of aborting, trans stays null and is passed
    // on as is - confirm against Tools.error.
    Trans trans = null;
    if (BLAS.lsame(transa, "N")) { //$NON-NLS-1$
      if (BLAS.lsame(transb, "N")) { //$NON-NLS-1$
        trans = Trans.NN;
      } else if (BLAS.lsame(transb, "T")) { //$NON-NLS-1$
        trans = Trans.NT;
      } else {
        Tools.error("BLAS::DGEMMTHREAD:error TRANS B"); //$NON-NLS-1$
      }
    } else if (BLAS.lsame(transa, "T")) { //$NON-NLS-1$
      if (BLAS.lsame(transb, "N")) { //$NON-NLS-1$
        trans = Trans.TN;
      } else if (BLAS.lsame(transb, "T")) { //$NON-NLS-1$
        trans = Trans.TT;
      } else {
        Tools.error("BLAS::DGEMMTHREAD:error TRANS B"); //$NON-NLS-1$
      }
    } else {
      Tools.error("BLAS::DGEMMTHREAD:error TRANS A"); //$NON-NLS-1$
    }
    dgemmChThread(trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    return 0;
  }

  /**
   * Runs the multiplication on two threads and returns once both halves are finished.
   * <p>
   * A worker thread is started with the row count {@code m / 2} (presumably it handles rows
   * {@code 0 .. m / 2 - 1}; see {@code DgemmThreadCh}), while the calling thread computes rows
   * {@code m / 2 .. m - 1}. The method then blocks on the worker instead of spinning on
   * {@code isAlive()}, so no CPU time is burned while waiting.
   *
   * @param ch transpose combination of the operands
   * @param m number of rows of C
   * @param n number of columns of C
   * @param k length of the inner product
   * @param alpha scalar multiplying the product
   * @param A flat storage of matrix A
   * @param lda leading dimension of {@code A}
   * @param B flat storage of matrix B
   * @param ldb leading dimension of {@code B}
   * @param beta scalar multiplying the previous contents of C
   * @param C flat storage of matrix C; updated in place
   * @param ldc leading dimension of {@code C}
   */
  private void dgemmChThread(Trans ch, int m, int n, int k, RS alpha, RS[] A, int lda, RS[] B, int ldb, RS beta, RS[] C, int ldc) {
    int m2 = m / 2;
    Thread subThread = new Thread(new DgemmThreadCh<>(ch, m2, n, k, alpha, A, lda, B, ldb, beta, C, ldc));
    subThread.start();
    try {
      subDgemmCh(ch, m2, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
    } finally {
      // Wait even when this thread's half failed, so the worker is never left writing to C after return.
      joinUninterruptibly(subThread);
    }
  }

  /**
   * Blocks until the given thread has terminated. An interrupt received while waiting does not cut the
   * wait short; it is remembered and the interrupt flag of the calling thread is restored before
   * returning.
   *
   * @param thread thread to wait for
   */
  private static void joinUninterruptibly(Thread thread) {
    boolean interrupted = false;
    try {
      while (true) {
        try {
          thread.join();
          return;
        } catch (InterruptedException e) {
          // Keep waiting: the result in C is only complete once the worker has finished.
          interrupted = true;
        }
      }
    } finally {
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }
  }

  /**
   * Computes rows {@code startLoopCount .. m2 - 1} of {@code C = alpha * op(A) * op(B) + beta * C} on the
   * calling thread.
   *
   * @param ch transpose combination of the operands
   * @param startLoopCount first row of C to compute (inclusive)
   * @param m2 row index at which to stop (exclusive)
   * @param n number of columns of C
   * @param k length of the inner product
   * @param alpha scalar multiplying the product
   * @param A flat storage of matrix A
   * @param lda leading dimension of {@code A}
   * @param B flat storage of matrix B
   * @param ldb leading dimension of {@code B}
   * @param beta scalar multiplying the previous contents of C
   * @param C flat storage of matrix C; updated in place
   * @param ldc leading dimension of {@code C}
   */
  private void subDgemmCh(Trans ch, int startLoopCount, int m2, int n, int k, RS alpha, RS[] A, int lda, RS[] B, int ldb, RS beta, RS[] C, int ldc) {
    for (int i = startLoopCount; i < m2; i++) {
      for (int j = 0; j < n; j++) {
        // Accumulator for the inner product of row i of op(A) with column j of op(B).
        RS temp = A[0].createZero();
        for (int indexk = 0; indexk < k; indexk++) {
          // A transposed operand is read by swapping the row and column roles in the index.
          if (ch == Trans.NN) {
            temp = temp.add(A[i + indexk * lda].multiply(B[indexk + j * ldb]));
          } else if (ch == Trans.NT) {
            temp = temp.add(A[i + indexk * lda].multiply(B[j + indexk * ldb]));
          } else if (ch == Trans.TN) {
            temp = temp.add(A[indexk + i * lda].multiply(B[indexk + j * ldb]));
          } else if (ch == Trans.TT) {
            temp = temp.add(A[indexk + i * lda].multiply(B[j + indexk * ldb]));
          }
        }
        C[i + j * ldc] = (beta.multiply(C[i + j * ldc])).add(alpha.multiply(temp));
      }
    }
  }
}