Tuesday, April 28, 2009

Moore-Penrose Pseudoinverse in JAMA

import Jama.Matrix;
import Jama.SingularValueDecomposition;

public class Matrices {
 /**
  * The difference between 1 and the smallest exactly representable number
  * greater than one. Gives an upper bound on the relative error due to
  * rounding of floating point numbers.
  *
  * Java's double is IEEE 754 binary64, so this is exactly 2^-52
  * (about 2.22e-16); Math.ulp(1.0) yields that value directly.
  */
 public static double MACHEPS = Math.ulp(1.0);

 /**
  * Resets MACHEPS to the machine epsilon of double.
  *
  * Math.ulp(1.0) is the exact distance from 1.0 to the next larger double,
  * which is the value a repeated-halving search converges to.
  */
 public static void updateMacheps() {
  MACHEPS = Math.ulp(1.0);
 }

 /**
  * Computes the Moore–Penrose pseudoinverse using the SVD method.
  *
  * Singular values smaller than max(rows, cols) * sigma_max * MACHEPS are
  * treated as zero.
  *
  * Modified version of the original implementation by Kim van der Linde.
  *
  * @param x the rows x cols matrix to pseudo-invert; must not be null
  * @return the cols x rows pseudoinverse, or null when x has no rows or no
  *         columns, or when the SVD reports a rank below 1
  */
 public static Matrix pinv(Matrix x) {
  int rows = x.getRowDimension();
  int cols = x.getColumnDimension();
  // The code below reads singularValues[0] and u[0]; an empty matrix has
  // neither, so report it the same way as a rank-0 matrix.
  if (rows == 0 || cols == 0)
   return null;
  if (rows < cols) {
   // Work on the tall transpose, since JAMA's SVD is documented for
   // rows >= cols, and use pinv(A^T) = pinv(A)^T.
   Matrix result = pinv(x.transpose());
   return result == null ? null : result.transpose();
  }
  SingularValueDecomposition svdX = new SingularValueDecomposition(x);
  if (svdX.rank() < 1)
   return null;
  double[] singularValues = svdX.getSingularValues();
  double tol = Math.max(rows, cols) * singularValues[0] * MACHEPS;
  // JAMA returns singular values in non-increasing order, so the ones that
  // pass the tolerance form a prefix of length 'retained'. The rest would get
  // a reciprocal of 0 and add nothing to the product, so they are skipped.
  double[] singularValueReciprocals = new double[singularValues.length];
  int retained = 0;
  while (retained < singularValues.length
    && Math.abs(singularValues[retained]) >= tol) {
   singularValueReciprocals[retained] = 1.0 / singularValues[retained];
   retained++;
  }
  double[][] u = svdX.getU().getArray();
  double[][] v = svdX.getV().getArray();
  int terms = Math.min(retained, Math.min(cols, u[0].length));
  // pinv(A) = V * diag(1 / sigma) * U^T, summed over the retained terms only.
  double[][] inverse = new double[cols][rows];
  for (int i = 0; i < cols; i++) {
   double[] vRow = v[i];
   double[] inverseRow = inverse[i];
   for (int j = 0; j < rows; j++) {
    double[] uRow = u[j];
    double sum = 0.0;
    for (int k = 0; k < terms; k++)
     sum += vRow[k] * singularValueReciprocals[k] * uRow[k];
    inverseRow[j] = sum;
   }
  }
  return new Matrix(inverse);
 }
}
Update 11/20/2014: Since this code continues to live out there, I'm adding a basic test so you know what you're getting. Put the methods below inside the Matrices class, because they call pinv directly. (Provided as is, with no warranty. I wish you the best.)
public static boolean checkEquality(Matrix A, Matrix B) {
 // Infinity norm of the difference (JAMA: maximum row sum); the two matrices
 // are considered equal when it falls below a fixed absolute tolerance.
 double deviation = A.minus(B).normInf();
 return deviation < 1e-9;
}
 
/**
 * Runs one randomized sanity check of pinv and prints a summary line.
 *
 * Builds a random rows x cols matrix A (each dimension roughly between 100
 * and 300, entries from Math.random()), computes A+ = pinv(A), and tests the
 * four Penrose conditions with checkEquality:
 * A A+ A = A, A+ A A+ = A+, A A+ symmetric, A+ A symmetric.
 * Prints the dimensions, the four results separated by '/', and the time
 * spent in pinv in milliseconds.
 */
public static void testPinv() {
 int rows = (int) Math.floor(100 + Math.random() * 200);
 int cols = (int) Math.floor(100 + Math.random() * 200);
 double[][] mat = new double[rows][cols];
 for (int i = 0; i < rows; i++)
  for (int j = 0; j < cols; j++)
   mat[i][j] = Math.random();
 Matrix A = new Matrix(mat);
 // nanoTime is monotonic, so clock adjustments cannot distort the timing.
 long start = System.nanoTime();
 Matrix Aplus = pinv(A);
 long millis = (System.nanoTime() - start) / 1000000L;
 if (Aplus == null) {
  System.out.println("Always check for null");
  return;
 }
 // The four Penrose conditions that uniquely characterize the pseudoinverse.
 boolean c1 = checkEquality(A.times(Aplus).times(A), A);
 boolean c2 = checkEquality(Aplus.times(A).times(Aplus), Aplus);
 boolean c3 = checkEquality(A.times(Aplus), A.times(Aplus).transpose());
 boolean c4 = checkEquality(Aplus.times(A), Aplus.times(A).transpose());
 System.out.println(rows + " x " + cols + "\t" +
                    c1 + "/" + c2 + "/" + c3 + "/" + c4 + "\t" + millis);
}
 
/**
 * Entry point: performs 100 independent randomized pinv checks.
 */
public static void main(String[] args) {
 int remaining = 100;
 while (remaining-- > 0) {
  testPinv();
 }
}