Lanczos.java

package org.mklab.sdpj.algorithm;

import org.mklab.nfc.matrix.ComplexNumericalMatrix;
import org.mklab.nfc.matrix.RealNumericalMatrix;
import org.mklab.nfc.scalar.ComplexNumericalScalar;
import org.mklab.nfc.scalar.RealNumericalScalar;
import org.mklab.sdpj.algebra.Algebra;
import org.mklab.sdpj.datastructure.BlockDenseMatrix;
import org.mklab.sdpj.datastructure.DenseMatrix;
import org.mklab.sdpj.datastructure.Vector;
import org.mklab.sdpj.gpack.blaswrap.BLAS;
import org.mklab.sdpj.gpack.lapackwrap.Clapack;
import org.mklab.sdpj.tool.Tools;


/**
 * Computes the minimum eigenvalue of {@code lMat * xMat * lMat^T} block by block, following the
 * Lanczos routines of SDPA 6.2.0 ({@code rsdpa_parts.cpp}). Diagonal blocks are evaluated exactly;
 * dense blocks are estimated with the Lanczos method.
 * 
 * @author koga
 * @version $Revision$, 2009/04/24
 * @param <RS> type of real scalar
 * @param <RM> type of real matrix
 * @param <CS> type of complex scalar
 * @param <CM> type of complex Matrix
 */
public class Lanczos<RS extends RealNumericalScalar<RS, RM, CS, CM>, RM extends RealNumericalMatrix<RS, RM, CS, CM>, CS extends ComplexNumericalScalar<RS, RM, CS, CM>, CM extends ComplexNumericalMatrix<RS, RM, CS, CM>> {

  /**
   * Returns the smallest eigenvalue of {@code x_i * y_i * x_i^T} over all blocks {@code i}.
   * <p>
   * Blocks whose structure entry in {@code y} is negative are handed to the diagonal routine
   * (exact); all other blocks are handed to the dense Lanczos routine, whose value is the smallest
   * eigenvalue of the Lanczos tridiagonal matrix shifted down by a residual term.
   * <p>
   * At least one block is assumed to exist (block 0 is always evaluated).
   * 
   * @param x block matrix whose blocks act as the lower-triangular factor {@code lMat}
   * @param y block matrix whose blocks act as {@code xMat}
   * @return the minimum (estimated) eigenvalue over all blocks
   */
  public RS getMinEigen(BlockDenseMatrix<RS,RM,CS,CM> x, BlockDenseMatrix<RS,RM,CS,CM> y) {
    final int blockSize = y.getBlockSize();

    RS min = getMinEigenOfBlock(x, y, 0);
    for (int i = 1; i < blockSize; ++i) {
      final RS value = getMinEigenOfBlock(x, y, i);
      if (value.isLessThan(min)) {
        min = value;
      }
    }
    return min;
  }

  /**
   * Dispatches block {@code i} to the diagonal routine (negative block structure) or to the dense
   * Lanczos routine (otherwise).
   * 
   * @param x block matrix of factors
   * @param y block matrix of examined matrices
   * @param i block index
   * @return minimum (estimated) eigenvalue of block {@code i}
   */
  private RS getMinEigenOfBlock(BlockDenseMatrix<RS,RM,CS,CM> x, BlockDenseMatrix<RS,RM,CS,CM> y, int i) {
    if (y.getBlockStruct(i) < 0) {
      return getMinEigen(x.getBlock(i), y.getBlock(i));
    }
    return getMinEigen2(x.getBlock(i), y.getBlock(i));
  }

  /**
   * Estimates the minimum eigenvalue of {@code lMat * xMat * lMat^T} for a dense block with the
   * Lanczos method.
   * <p>
   * Refer to rsdpa_parts.cpp : SDPA 6.2.0 Line No. : 2545
   * <p>
   * The iteration starts from the all-ones vector and runs while {@code k < nDim} and
   * {@code k < sqrt(nDim) + 10}, stopping earlier when the residual norm drops to 1e-16 or both the
   * minimum and the residual term have converged. If {@code nDim == 0} no iteration runs and 1e51 is
   * returned.
   * 
   * @param x matrix whose lower triangle (dense storage) is used as {@code lMat}
   * @param y matrix {@code xMat} (assumed symmetric, as Lanczos requires — confirm at call sites)
   * @return the smallest eigenvalue of the Lanczos tridiagonal matrix minus
   *         {@code |error * beta|}, where {@code error} is the last component of its eigenvector
   *         and {@code beta} the final residual norm
   */
  private RS getMinEigen2(DenseMatrix<RS,RM,CS,CM> x, DenseMatrix<RS,RM,CS,CM> y) {
    final RS unit = x.getElementUnit();

    // TODO These constants should be extracted as parameters.
    final RS D1E50 = unit.create(10).power(50);
    final RS D1E51 = unit.create(10).power(51);
    final RS D1E52 = unit.create(10).power(52);
    final RS D1Em16 = unit.create(10).power(-16);
    // Convergence-tolerance constants; loop-invariant, so they are created once.
    final RS D1Em2 = unit.create(10).power(-2);
    final RS D1Em4 = unit.create(10).power(-4);
    final RS D1Em5 = unit.create(10).power(-5);
    final RS D1Em8 = unit.create(10).power(-8);

    RS min = D1E51;
    RS minOld = D1E52;
    RS error = unit.create(10).power(10);

    final int nDim = y.getRowSize();

    // diagVec / diagVec2 accumulate the diagonal (alpha) and off-diagonal (beta) of the tridiagonal matrix.
    final Vector<RS,RM,CS,CM> diagVec = new Vector<>(nDim, D1E50);
    final Vector<RS,RM,CS,CM> diagVec2 = new Vector<>(nDim, unit.create(0));
    Vector<RS,RM,CS,CM> q = new Vector<>(nDim, unit.create(0));
    Vector<RS,RM,CS,CM> r = new Vector<>(nDim, unit.create(1));
    final DenseMatrix<RS,RM,CS,CM> Q = new DenseMatrix<>(nDim, nDim, x.getDenseOrDiagonal(), unit);
    final Vector<RS,RM,CS,CM> workVec = new Vector<>(Math.max(1, 2 * nDim - 2), unit.create(0));
    // Norm of the initial residual r = (1, ..., 1).
    RS beta = unit.create(nDim).sqrt();

    // SDPA loops while k < nDim && k < sqrt(nDim) + 10. Compare against the real-valued bound:
    // truncating sqrt(nDim) + 10 to an int would drop one iteration whenever sqrt(nDim) is not integral.
    final double iterationBound = Math.sqrt(nDim) + 10;
    int k = 0;
    int kk = 0;
    while (k < nDim && k < iterationBound) {
      // Residual norm is numerically zero: no further Lanczos vector can be formed.
      if (beta.isLessThanOrEquals(D1Em16)) {
        break;
      }

      final RS minTolerance = min.abs().multiply(D1Em5).add(D1Em8);
      final RS errorTolerance = min.abs().multiply(D1Em2).add(D1Em4);
      if (min.subtract(minOld).abs().isLessThanOrEquals(minTolerance) && error.multiply(beta).abs().isLessThanOrEquals(errorTolerance)) {
        break;
      }

      final Vector<RS,RM,CS,CM> qOld = q.createClone();
      q = Algebra.let(r, '*', unit.create(1).divide(beta));

      // w = (lMat^T)*q
      Vector<RS,RM,CS,CM> w = q.createClone();
      BLAS.dtrmv("Lower", "Transpose", "NotUnit", nDim, x.denseElements, nDim, w.getElements(), 1); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$
      w = Algebra.let(y, '*', w);
      BLAS.dtrmv("Lower", "NoTranspose", "NotUnit", nDim, x.denseElements, nDim, w.getElements(), 1); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$
      // w = lMat*xMat*(lMat^T)*q

      final RS alpha = Algebra.run(q, '.', w);
      diagVec.setElement(k, alpha);

      // Lanczos three-term recurrence (presumably r = w - alpha*q - beta*qOld, mirroring SDPA's Lal::let).
      r = Algebra.let(w, '-', q, alpha);
      r = Algebra.let(r, '-', qOld, beta);

      if (kk >= Math.sqrt(k) || k == nDim - 1 || k > Math.sqrt(nDim + 9)) {
        kk = 0;
        // dsteqr overwrites its diagonal/off-diagonal inputs, so it works on copies.
        final Vector<RS,RM,CS,CM> out = diagVec.createClone();
        final Vector<RS,RM,CS,CM> b = diagVec2.createClone();
        out.setElement(nDim - 1, diagVec.getElement(k));
        b.setElement(nDim - 1, unit.create(0));
        final int kp1 = k + 1;
        final int info = Clapack.dsteqr("I_withEigenvalues", kp1, out.getElements(), b.getElements(), Q.denseElements, Q.getRowSize(), workVec.getElements()); //$NON-NLS-1$
        if (info < 0) {
          Tools.error(" rLanczos :: bad argument " + (-info) + " Q.nRow = " + Q.getRowSize() + ": nDim = " + nDim + ": kp1 = " + kp1); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$
        } else if (info > 0) {
          Tools.message(" rLanczos :: cannot converge " + info); //$NON-NLS-1$
          break;
        }
        minOld = min;
        // out holds the eigenvalues in ascending order, so out[0] is the smallest one.
        min = out.getElement(0);
        // Presumably (column-major storage, LDZ = Q.getRowSize()) the last component of the
        // eigenvector belonging to out[0]; used for the residual term error * beta.
        error = Q.denseElements[k];
      }

      beta = Algebra.run(r, '.', r).sqrt();
      diagVec2.setElement(k, beta);
      ++k;
      ++kk;
    }
    return min.subtract(error.multiply(beta).abs());
  }

  /**
   * Returns the minimum of {@code l_i * x_i * l_i} over the diagonal entries, i.e. the exact
   * minimum eigenvalue of {@code lMat * xMat * lMat} for diagonal {@code lMat} and {@code xMat}.
   * <p>
   * Refer to rsdpa_parts.cpp : SDPA 6.2.0 Line No. : 2655
   * <p>
   * SDPA unrolls this scan by four. That gives no benefit on scalar objects, so a plain linear
   * scan is used here. At least one diagonal entry is assumed.
   * 
   * @param lMat diagonal factor
   * @param xMat diagonal matrix
   * @return minimum eigenvalue
   */
  private RS getMinEigen(DenseMatrix<RS,RM,CS,CM> lMat, DenseMatrix<RS,RM,CS,CM> xMat) {
    final int nDim = xMat.getRowSize();
    final RS[] lElements = lMat.diagonalElements;
    final RS[] xElements = xMat.diagonalElements;

    RS min = lElements[0].multiply(xElements[0]).multiply(lElements[0]);
    for (int i = 1; i < nDim; ++i) {
      final RS value = lElements[i].multiply(xElements[i]).multiply(lElements[i]);
      if (value.isLessThan(min)) {
        min = value;
      }
    }
    return min;
  }
}