Dgemm.java

package org.mklab.sdpj.gpack.blaswrap;

import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.RealNumericalScalar;


/**
 * Java port of the Level 3 BLAS routine DGEMM, generic over the NFC real scalar type.
 *
 * <p>Matrices are passed as flat arrays in column-major (Fortran) order: the 1-based element
 * {@code X(i, j)} of a matrix with leading dimension {@code ldx} is stored at index
 * {@code (i - 1) + (j - 1) * ldx}.
 *
 * @author koga
 * @version $Revision$, 2009/04/24
 * @param <RS> real scalar type
 * @param <RM> real matrix type
 * @param <CS> complex scalar type
 * @param <CM> complex matrix type
 */
public class Dgemm<RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> {

  /*  Purpose
  =======
  DGEMM  performs one of the matrix-matrix operations
     C := alpha*op( A )*op( B ) + beta*C,
  where  op( X ) is one of
     op( X ) = X   or   op( X ) = X',
  alpha and beta are scalars, and A, B and C are matrices, with op( A )
  an m by k matrix,  op( B )  a  k by n matrix and  C an m by n matrix.
  Parameters
  ==========
  TRANSA - CHARACTER*1.
           On entry, TRANSA specifies the form of op( A ) to be used in
           the matrix multiplication as follows:
              TRANSA = 'N' or 'n',  op( A ) = A.
              TRANSA = 'T' or 't',  op( A ) = A'.
              TRANSA = 'C' or 'c',  op( A ) = A'.
           Unchanged on exit.
  TRANSB - CHARACTER*1.
           On entry, TRANSB specifies the form of op( B ) to be used in
           the matrix multiplication as follows:
              TRANSB = 'N' or 'n',  op( B ) = B.
              TRANSB = 'T' or 't',  op( B ) = B'.
              TRANSB = 'C' or 'c',  op( B ) = B'.
           Unchanged on exit.
  M      - INTEGER.
           On entry,  M  specifies  the number  of rows  of the  matrix
           op( A )  and of the  matrix  C.  M  must  be at least  zero.
           Unchanged on exit.
  N      - INTEGER.
           On entry,  N  specifies the number  of columns of the matrix
           op( B ) and the number of columns of the matrix C. N must be
           at least zero.
           Unchanged on exit.
  K      - INTEGER.
           On entry,  K  specifies  the number of columns of the matrix
           op( A ) and the number of rows of the matrix op( B ). K must
           be at least  zero.
           Unchanged on exit.
  ALPHA  - DOUBLE PRECISION.
           On entry, ALPHA specifies the scalar alpha.
           Unchanged on exit.
  A      - DOUBLE PRECISION array of DIMENSION ( LDA, ka ), where ka is
           k  when  TRANSA = 'N' or 'n',  and is  m  otherwise.
           Before entry with  TRANSA = 'N' or 'n',  the leading  m by k
           part of the array  A  must contain the matrix  A,  otherwise
           the leading  k by m  part of the array  A  must contain  the
           matrix A.
           Unchanged on exit.
  LDA    - INTEGER.
           On entry, LDA specifies the first dimension of A as declared
           in the calling (sub) program. When  TRANSA = 'N' or 'n' then
           LDA must be at least  max( 1, m ), otherwise  LDA must be at
           least  max( 1, k ).
           Unchanged on exit.
  B      - DOUBLE PRECISION array of DIMENSION ( LDB, kb ), where kb is
           n  when  TRANSB = 'N' or 'n',  and is  k  otherwise.
           Before entry with  TRANSB = 'N' or 'n',  the leading  k by n
           part of the array  B  must contain the matrix  B,  otherwise
           the leading  n by k  part of the array  B  must contain  the
           matrix B.
           Unchanged on exit.
  LDB    - INTEGER.
           On entry, LDB specifies the first dimension of B as declared
           in the calling (sub) program. When  TRANSB = 'N' or 'n' then
           LDB must be at least  max( 1, k ), otherwise  LDB must be at
           least  max( 1, n ).
           Unchanged on exit.
  BETA   - DOUBLE PRECISION.
           On entry,  BETA  specifies the scalar  beta.  When  BETA  is
           supplied as zero then C need not be set on input.
           Unchanged on exit.
  C      - DOUBLE PRECISION array of DIMENSION ( LDC, n ).
           Before entry, the leading  m by n  part of the array  C must
           contain the matrix  C,  except when  beta  is zero, in which
           case C need not be set on entry.
           On exit, the array  C  is overwritten by the  m by n  matrix
           ( alpha*op( A )*op( B ) + beta*C ).
  LDC    - INTEGER.
           On entry, LDC specifies the first dimension of C as declared
           in  the  calling  (sub)  program.   LDC  must  be  at  least
           max( 1, m ).
           Unchanged on exit.
  Level 3 Blas routine.
  -- Written on 8-February-1989.
     Jack Dongarra, Argonne National Laboratory.
     Iain Duff, AERE Harwell.
     Jeremy Du Croz, Numerical Algorithms Group Ltd.
     Sven Hammarling, Numerical Algorithms Group Ltd. */

  /**
   * Computes {@code C := alpha*op(A)*op(B) + beta*C} in place.
   *
   * <p>{@code op(A)} is m-by-k, {@code op(B)} is k-by-n and {@code C} is m-by-n. When {@code beta}
   * is zero, the incoming contents of {@code c} are never read, so it need not be initialized.
   *
   * @param transa form of op(A): "N"/"n" for A; "T"/"t" or "C"/"c" for A'
   * @param transb form of op(B): "N"/"n" for B; "T"/"t" or "C"/"c" for B'
   * @param m number of rows of op(A) and of C; must be at least zero
   * @param n number of columns of op(B) and of C; must be at least zero
   * @param k number of columns of op(A) and of rows of op(B); must be at least zero
   * @param alpha scalar multiplier of op(A)*op(B)
   * @param a column-major array holding A (m-by-k when transa is "N", otherwise k-by-m)
   * @param lda leading dimension of a: at least max(1, m) when transa is "N", otherwise max(1, k)
   * @param b column-major array holding B (k-by-n when transb is "N", otherwise n-by-k)
   * @param ldb leading dimension of b: at least max(1, k) when transb is "N", otherwise max(1, n)
   * @param beta scalar multiplier of the incoming C
   * @param c column-major array holding C; overwritten with the result
   * @param ldc leading dimension of c: at least max(1, m)
   * @return always 0; invalid arguments are reported through {@code BLAS.xerbla} before returning
   */
  int dgemm(String transa, String transb, int m, int n, int k, RS alpha, RS[] a, int lda, RS[] b, int ldb, RS beta, RS[] c, int ldc) {
    // Derive scalar constants from alpha rather than a[0]: a may legitimately be empty when m or k
    // is zero, and it must not be touched before the argument checks / quick return.
    final RS unit = alpha.createUnit();

    // NOTA / NOTB are true when A / B are used untransposed. NROWA / NROWB are the row counts of
    // the stored arrays A and B; they bound the admissible leading dimensions.
    final boolean nota = BLAS.lsame(transa, "N"); //$NON-NLS-1$
    final boolean notb = BLAS.lsame(transb, "N"); //$NON-NLS-1$
    final int nrowa = nota ? m : k;
    final int nrowb = notb ? k : n;

    // Test the input parameters (error codes are the 1-based argument positions, as in BLAS).
    int info = 0;
    if (!nota && !BLAS.lsame(transa, "C") && !BLAS.lsame(transa, "T")) { //$NON-NLS-1$ //$NON-NLS-2$
      info = 1;
    } else if (!notb && !BLAS.lsame(transb, "C") && !BLAS.lsame(transb, "T")) { //$NON-NLS-1$ //$NON-NLS-2$
      info = 2;
    } else if (m < 0) {
      info = 3;
    } else if (n < 0) {
      info = 4;
    } else if (k < 0) {
      info = 5;
    } else if (lda < Math.max(1, nrowa)) {
      info = 8;
    } else if (ldb < Math.max(1, nrowb)) {
      info = 10;
    } else if (ldc < Math.max(1, m)) {
      info = 13;
    }
    if (info != 0) {
      BLAS.xerbla("DGEMM ", info); //$NON-NLS-1$
      return 0;
    }

    // Quick return if possible: C is empty, or the product term vanishes and beta is one so C is
    // left unchanged. (Testing beta == 0 here would wrongly skip zeroing C.)
    final boolean productVanishes = alpha.isZero() || k == 0;
    if (m == 0 || n == 0 || (productVanishes && beta.equals(unit))) {
      return 0;
    }

    // And if alpha is zero: C := beta*C.
    if (alpha.isZero()) {
      for (int j = 1; j <= n; ++j) {
        scaleColumn(j, m, beta, c, ldc, unit);
      }
      return 0;
    }

    // Start the operations.
    if (notb) {
      if (nota) {
        // Form C := alpha*A*B + beta*C.
        for (int j = 1; j <= n; ++j) {
          scaleColumn(j, m, beta, c, ldc, unit);
          for (int l = 1; l <= k; ++l) {
            final RS blj = b[index(l, j, ldb)];
            // Zero entries of B contribute nothing; skipping them mirrors the reference routine.
            if (!blj.isZero()) {
              final RS temp = alpha.multiply(blj);
              for (int i = 1; i <= m; ++i) {
                final int p = index(i, j, ldc);
                c[p] = c[p].add(temp.multiply(a[index(i, l, lda)]));
              }
            }
          }
        }
      } else {
        // Form C := alpha*A'*B + beta*C.
        for (int j = 1; j <= n; ++j) {
          for (int i = 1; i <= m; ++i) {
            RS temp = unit.createZero();
            for (int l = 1; l <= k; ++l) {
              temp = temp.add(a[index(l, i, lda)].multiply(b[index(l, j, ldb)]));
            }
            final int p = index(i, j, ldc);
            // With beta == 0 the old C is not read, so uninitialized entries are harmless.
            c[p] = beta.isZero() ? alpha.multiply(temp) : alpha.multiply(temp).add(beta.multiply(c[p]));
          }
        }
      }
    } else {
      if (nota) {
        // Form C := alpha*A*B' + beta*C.
        for (int j = 1; j <= n; ++j) {
          scaleColumn(j, m, beta, c, ldc, unit);
          for (int l = 1; l <= k; ++l) {
            final RS bjl = b[index(j, l, ldb)];
            if (!bjl.isZero()) {
              final RS temp = alpha.multiply(bjl);
              for (int i = 1; i <= m; ++i) {
                final int p = index(i, j, ldc);
                c[p] = c[p].add(temp.multiply(a[index(i, l, lda)]));
              }
            }
          }
        }
      } else {
        // Form C := alpha*A'*B' + beta*C.
        for (int j = 1; j <= n; ++j) {
          for (int i = 1; i <= m; ++i) {
            RS temp = unit.createZero();
            for (int l = 1; l <= k; ++l) {
              temp = temp.add(a[index(l, i, lda)].multiply(b[index(j, l, ldb)]));
            }
            final int p = index(i, j, ldc);
            c[p] = beta.isZero() ? alpha.multiply(temp) : alpha.multiply(temp).add(beta.multiply(c[p]));
          }
        }
      }
    }
    return 0;
  }

  /**
   * Replaces the first {@code m} entries of column {@code j} of C with {@code beta} times
   * themselves. A zero beta stores fresh zeros without reading C; a beta equal to one leaves the
   * column untouched.
   *
   * @param j 1-based column index
   * @param m number of rows to update
   * @param beta scale factor
   * @param c column-major array holding C
   * @param ldc leading dimension of c
   * @param unit the scalar one, used to create zeros and to test beta == 1
   */
  private void scaleColumn(int j, int m, RS beta, RS[] c, int ldc, RS unit) {
    if (beta.isZero()) {
      for (int i = 1; i <= m; ++i) {
        c[index(i, j, ldc)] = unit.createZero();
      }
    } else if (!beta.equals(unit)) {
      for (int i = 1; i <= m; ++i) {
        final int p = index(i, j, ldc);
        c[p] = beta.multiply(c[p]);
      }
    }
  }

  /**
   * Maps the 1-based element (i, j) of a column-major matrix with leading dimension {@code ld}
   * to its 0-based position in the backing array.
   *
   * @param i 1-based row index
   * @param j 1-based column index
   * @param ld leading dimension
   * @return 0-based array index
   */
  private static int index(int i, int j, int ld) {
    return (i - 1) + (j - 1) * ld;
  }
}