0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.linalg;
0019
0020 import java.util.Random;
0021
0022 import static org.junit.Assert.assertArrayEquals;
0023 import static org.junit.Assert.assertEquals;
0024
0025 import org.junit.Test;
0026
0027 public class JavaMatricesSuite {
0028
0029 @Test
0030 public void randMatrixConstruction() {
0031 Random rng = new Random(24);
0032 Matrix r = Matrices.rand(3, 4, rng);
0033 rng.setSeed(24);
0034 DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
0035 assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
0036
0037 rng.setSeed(24);
0038 Matrix rn = Matrices.randn(3, 4, rng);
0039 rng.setSeed(24);
0040 DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
0041 assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
0042
0043 rng.setSeed(24);
0044 Matrix s = Matrices.sprand(3, 4, 0.5, rng);
0045 rng.setSeed(24);
0046 SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
0047 assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
0048
0049 rng.setSeed(24);
0050 Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
0051 rng.setSeed(24);
0052 SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
0053 assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
0054 }
0055
0056 @Test
0057 public void identityMatrixConstruction() {
0058 Matrix r = Matrices.eye(2);
0059 DenseMatrix dr = DenseMatrix.eye(2);
0060 SparseMatrix sr = SparseMatrix.speye(2);
0061 assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
0062 assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
0063 assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
0064 }
0065
0066 @Test
0067 public void diagonalMatrixConstruction() {
0068 Vector v = Vectors.dense(1.0, 0.0, 2.0);
0069 Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
0070
0071 Matrix m = Matrices.diag(v);
0072 Matrix sm = Matrices.diag(sv);
0073 DenseMatrix d = DenseMatrix.diag(v);
0074 DenseMatrix sd = DenseMatrix.diag(sv);
0075 SparseMatrix s = SparseMatrix.spdiag(v);
0076 SparseMatrix ss = SparseMatrix.spdiag(sv);
0077
0078 assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
0079 assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
0080 assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
0081 assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
0082 assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
0083 assertArrayEquals(s.values(), ss.values(), 0.0);
0084 assertEquals(2, s.values().length);
0085 assertEquals(2, ss.values().length);
0086 assertEquals(4, s.colPtrs().length);
0087 assertEquals(4, ss.colPtrs().length);
0088 }
0089
0090 @Test
0091 public void zerosMatrixConstruction() {
0092 Matrix z = Matrices.zeros(2, 2);
0093 Matrix one = Matrices.ones(2, 2);
0094 DenseMatrix dz = DenseMatrix.zeros(2, 2);
0095 DenseMatrix done = DenseMatrix.ones(2, 2);
0096
0097 assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
0098 assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
0099 assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
0100 assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
0101 }
0102
0103 @Test
0104 public void sparseDenseConversion() {
0105 int m = 3;
0106 int n = 2;
0107 double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
0108 double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
0109 int[] colPtrs = new int[]{0, 2, 4};
0110 int[] rowIndices = new int[]{0, 1, 1, 2};
0111
0112 SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
0113 DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
0114
0115 SparseMatrix spMat2 = deMat1.toSparse();
0116 DenseMatrix deMat2 = spMat1.toDense();
0117
0118 assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
0119 assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
0120 }
0121
0122 @Test
0123 public void concatenateMatrices() {
0124 int m = 3;
0125 int n = 2;
0126
0127 Random rng = new Random(42);
0128 SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
0129 rng.setSeed(42);
0130 DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
0131 Matrix deMat2 = Matrices.eye(3);
0132 Matrix spMat2 = Matrices.speye(3);
0133 Matrix deMat3 = Matrices.eye(2);
0134 Matrix spMat3 = Matrices.speye(2);
0135
0136 Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
0137 Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
0138 Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
0139 Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
0140
0141 assertEquals(3, deHorz1.numRows());
0142 assertEquals(3, deHorz2.numRows());
0143 assertEquals(3, deHorz3.numRows());
0144 assertEquals(3, spHorz.numRows());
0145 assertEquals(5, deHorz1.numCols());
0146 assertEquals(5, deHorz2.numCols());
0147 assertEquals(5, deHorz3.numCols());
0148 assertEquals(5, spHorz.numCols());
0149
0150 Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
0151 Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
0152 Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
0153 Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
0154
0155 assertEquals(5, deVert1.numRows());
0156 assertEquals(5, deVert2.numRows());
0157 assertEquals(5, deVert3.numRows());
0158 assertEquals(5, spVert.numRows());
0159 assertEquals(2, deVert1.numCols());
0160 assertEquals(2, deVert2.numCols());
0161 assertEquals(2, deVert3.numCols());
0162 assertEquals(2, spVert.numCols());
0163 }
0164 }