// Dpotrf.java

package org.mklab.sdpj.gpack.lapackwrap;

import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.RealNumericalScalar;
import org.mklab.sdpj.gpack.blaswrap.BLAS;


/**
 * Computes the Cholesky factorization of a symmetric/Hermitian positive definite matrix.
 * <p>
 * This is a port of the LAPACK routine DPOTRF. The matrix is passed as a flat array that is
 * addressed column by column: the (1-based) element (i, j) lives at index
 * {@code (j - 1) * lda + (i - 1)}.
 * 
 * @author takafumi
 * @param <RS> type of real scalar
 * @param <RM> type of real matrix
 * @param <CS> type of complex scalar
 * @param <CM> type of complex Matrix
 */
public class Dpotrf<RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> {

  /*  -- LAPACK routine (version 3.0) --
  Univ. of Tennessee, Univ. of California Berkeley, NAG Ltd.,
  Courant Institute, Argonne National Lab, and Rice University
  March 31, 1993


  Purpose
  =======

  DPOTRF computes the Cholesky factorization of a real symmetric
  positive definite matrix A.

  The factorization has the form
  A = U**T * U,  if UPLO = 'U', or
  A = L  * L**T,  if UPLO = 'L',
  where U is an upper triangular matrix and L is lower triangular.

  This is the block version of the algorithm, calling Level 3 BLAS.

  Arguments
  =========

  UPLO    (input) CHARACTER*1
       = 'U':  Upper triangle of A is stored;
       = 'L':  Lower triangle of A is stored.

  N       (input) INTEGER
       The order of the matrix A.  N >= 0.

  A       (input/output) DOUBLE PRECISION array, dimension (LDA,N)
       On entry, the symmetric matrix A.  If UPLO = 'U', the leading
       N-by-N upper triangular part of A contains the upper
       triangular part of the matrix A, and the strictly lower
       triangular part of A is not referenced.  If UPLO = 'L', the
       leading N-by-N lower triangular part of A contains the lower
       triangular part of the matrix A, and the strictly upper
       triangular part of A is not referenced.

       On exit, if INFO = 0, the factor U or L from the Cholesky
       factorization A = U**T*U or A = L*L**T.

  LDA     (input) INTEGER
       The leading dimension of the array A.  LDA >= max(1,N).

  INFO    (output) INTEGER
       = 0:  successful exit
       < 0:  if INFO = -i, the i-th argument had an illegal value
       > 0:  if INFO = i, the leading minor of order i is not
             positive definite, and the factorization could not be
             completed.

  =====================================================================
  */
  /**
   * Factorizes the matrix stored in {@code a} in place.
   * <p>
   * The arguments are validated, and the quick return for {@code n == 0} is taken, before
   * {@code a} is touched. A call with {@code n == 0} therefore succeeds even if {@code a} is
   * empty. For {@code n > 0} the array must contain at least one element, because the scalar
   * constants are derived from {@code a[0]}.
   * 
   * @param uplo selects the stored triangle; accepted when {@code BLAS.lsame} matches it against
   *        "U" (upper) or "L" (lower)
   * @param n order of the matrix; must not be negative
   * @param a the matrix as a flat array with leading dimension {@code lda}; overwritten with the
   *        factor
   * @param lda leading dimension of {@code a}; must be at least {@code max(1, n)}
   * @return 0 on success; {@code -i} if the i-th argument had an illegal value; {@code i > 0} if
   *         the leading minor of order i is not positive definite
   */
  public int execute(String uplo, int n, RS[] a, int lda) {
    final boolean upper = BLAS.lsame(uplo, "U"); //$NON-NLS-1$

    int info = 0;
    if (!upper && !BLAS.lsame(uplo, "L")) { //$NON-NLS-1$
      info = -1;
    } else if (n < 0) {
      info = -2;
    } else if (lda < Math.max(1, n)) {
      info = -4;
    }

    if (info != 0) {
      BLAS.xerbla("DPOTRF", -info); //$NON-NLS-1$
      return info;
    }

    /* Quick return if possible */
    if (n == 0) {
      return info;
    }

    // Only now is the array dereferenced: everything above must work for an empty array.
    final RS unit = a[0].createUnit();
    final RS minusOne = unit.create(-1);
    final RS one = unit.create(1);

    /* Determine the block size for this environment. */
    final int nb = Clapack.ilaenv(1, "DPOTRF", uplo, n, -1, -1, -1, 6, 1, unit); //$NON-NLS-1$
    if (nb <= 1 || nb >= n) {
      /* Use unblocked code on a copy of the whole array, then store the result. */
      final RS[] whole = copyTail(unit, a, 0);
      info = Clapack.dpotf2(uplo, n, whole, lda);
      writeBack(whole, a, 0);
      return info;
    }

    /* Use blocked code; here 1 < nb < n, so the block loops always step forward. */
    if (upper) {
      return factorizeUpper(n, a, lda, nb, unit, minusOne, one);
    }
    return factorizeLower(n, a, lda, nb, unit, minusOne, one);
  }

  /**
   * Blocked factorization A = U'U working on the upper triangle.
   * 
   * @param n order of the matrix
   * @param a flat matrix storage, updated in place
   * @param lda leading dimension of {@code a}
   * @param nb block size, greater than 1
   * @param unit scalar used to allocate work arrays
   * @param minusOne scalar constant -1
   * @param one scalar constant 1
   * @return 0 on success, otherwise the order of the leading minor that is not positive definite
   */
  private int factorizeUpper(int n, RS[] a, int lda, int nb, RS unit, RS minusOne, RS one) {
    for (int j = 1; j <= n; j += nb) {
      /* Update and factorize the current diagonal block and test for non-positive-definiteness. */
      final int jb = Math.min(nb, n - j + 1);
      final int colTop = index(1, j, lda);
      final int diag = index(j, j, lda);

      RS[] colTopView = copyTail(unit, a, colTop);
      RS[] diagView = copyTail(unit, a, diag);
      BLAS.dsyrk("Upper", "Transpose", jb, j - 1, minusOne, colTopView, lda, one, diagView, lda); //$NON-NLS-1$ //$NON-NLS-2$
      final int info = Clapack.dpotf2("Upper", jb, diagView, lda); //$NON-NLS-1$
      // The copied tails overlap inside a; the tail passed in the output position is stored last
      // so that its updated values are the ones that remain.
      writeBack(colTopView, a, colTop);
      writeBack(diagView, a, diag);

      if (info != 0) {
        return info + j - 1;
      }

      if (j + jb <= n) {
        /* Compute the current block row. */
        final int nextTop = index(1, j + jb, lda);
        final int nextRow = index(j, j + jb, lda);
        final int remaining = n - j - jb + 1;

        colTopView = copyTail(unit, a, colTop);
        diagView = copyTail(unit, a, diag);
        final RS[] nextTopView = copyTail(unit, a, nextTop);
        final RS[] nextRowView = copyTail(unit, a, nextRow);
        BLAS.dgemm("Transpose", "No transpose", jb, remaining, j - 1, minusOne, colTopView, lda, nextTopView, lda, one, nextRowView, lda); //$NON-NLS-1$ //$NON-NLS-2$
        BLAS.dtrsm("Left", "Upper", "Transpose", "Non-unit", jb, remaining, one, diagView, lda, nextRowView, lda); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$

        // Same ordering rule as above: the updated block row goes back last.
        writeBack(colTopView, a, colTop);
        writeBack(diagView, a, diag);
        writeBack(nextTopView, a, nextTop);
        writeBack(nextRowView, a, nextRow);
      }
    }
    return 0;
  }

  /**
   * Blocked factorization A = LL' working on the lower triangle.
   * 
   * @param n order of the matrix
   * @param a flat matrix storage, updated in place
   * @param lda leading dimension of {@code a}
   * @param nb block size, greater than 1
   * @param unit scalar used to allocate work arrays
   * @param minusOne scalar constant -1
   * @param one scalar constant 1
   * @return 0 on success, otherwise the order of the leading minor that is not positive definite
   */
  private int factorizeLower(int n, RS[] a, int lda, int nb, RS unit, RS minusOne, RS one) {
    for (int j = 1; j <= n; j += nb) {
      /* Update and factorize the current diagonal block and test for non-positive-definiteness. */
      final int jb = Math.min(nb, n - j + 1);
      final int rowLeft = index(j, 1, lda);
      final int diag = index(j, j, lda);

      RS[] rowLeftView = copyTail(unit, a, rowLeft);
      RS[] diagView = copyTail(unit, a, diag);
      BLAS.dsyrk("Lower", "No transpose", jb, j - 1, minusOne, rowLeftView, lda, one, diagView, lda); //$NON-NLS-1$ //$NON-NLS-2$
      final int info = Clapack.dpotf2("Lower", jb, diagView, lda); //$NON-NLS-1$
      // The copied tails overlap inside a; the tail passed in the output position is stored last
      // so that its updated values are the ones that remain.
      writeBack(rowLeftView, a, rowLeft);
      writeBack(diagView, a, diag);

      if (info != 0) {
        return info + j - 1;
      }

      if (j + jb <= n) {
        /* Compute the current block column. */
        final int belowLeft = index(j + jb, 1, lda);
        final int belowDiag = index(j + jb, j, lda);
        final int remaining = n - j - jb + 1;

        final RS[] belowLeftView = copyTail(unit, a, belowLeft);
        final RS[] belowDiagView = copyTail(unit, a, belowDiag);
        rowLeftView = copyTail(unit, a, rowLeft);
        diagView = copyTail(unit, a, diag);
        BLAS.dgemm("No transpose", "Transpose", remaining, jb, j - 1, minusOne, belowLeftView, lda, rowLeftView, lda, one, belowDiagView, lda); //$NON-NLS-1$ //$NON-NLS-2$
        BLAS.dtrsm("Right", "Lower", "Transpose", "Non-unit", remaining, jb, one, diagView, lda, belowDiagView, lda); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$

        // Same ordering rule as above: the updated block column goes back last.
        writeBack(rowLeftView, a, rowLeft);
        writeBack(diagView, a, diag);
        writeBack(belowLeftView, a, belowLeft);
        writeBack(belowDiagView, a, belowDiag);
      }
    }
    return 0;
  }

  /**
   * Returns the flat array index of the 1-based matrix element (row, column).
   * 
   * @param row 1-based row
   * @param column 1-based column
   * @param lda leading dimension
   * @return zero-based index into the flat array
   */
  private static int index(int row, int column, int lda) {
    return (column - 1) * lda + (row - 1);
  }

  /**
   * Creates a new array holding the elements of {@code source} from {@code start} to the end.
   * 
   * @param unit scalar used to allocate the new array
   * @param source array to copy from
   * @param start index of the first element to copy
   * @return the copied tail
   */
  private RS[] copyTail(RS unit, RS[] source, int start) {
    final RS[] tail = unit.createArray(source.length - start);
    System.arraycopy(source, start, tail, 0, tail.length);
    return tail;
  }

  /**
   * Stores a tail obtained from {@link #copyTail} back into {@code target} at {@code start}.
   * 
   * @param tail elements to store
   * @param target array to overwrite
   * @param start index in {@code target} where the first element is written
   */
  private void writeBack(RS[] tail, RS[] target, int start) {
    System.arraycopy(tail, 0, target, start, tail.length);
  }
}