Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 """
0019 Package for distributed linear algebra.
0020 """
0021 
0022 import sys
0023 
# Python 3 has no separate ``long`` type; alias it to ``int`` so that the
# ``long(...)`` conversions used throughout this module work on both major
# versions. The numeric major version is tested instead of the version
# string, because a lexicographic comparison such as ``sys.version >= '3'``
# misorders multi-digit major versions (e.g. ``'10.0' < '3'``).
if sys.version_info[0] >= 3:
    long = int
0026 
0027 from py4j.java_gateway import JavaObject
0028 
0029 from pyspark import RDD, since
0030 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
0031 from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
0032 from pyspark.mllib.stat import MultivariateStatisticalSummary
0033 from pyspark.sql import DataFrame
0034 from pyspark.storagelevel import StorageLevel
0035 
0036 
0037 __all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow',
0038            'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition']
0039 
0040 
class DistributedMatrix(object):
    """
    Base type for matrices that are stored in a distributed fashion,
    backed by one or more RDDs.

    Both dimension accessors raise :class:`NotImplementedError` here;
    the matrix classes in this module override them.
    """
    def numRows(self):
        """Return the number of rows, computing it if necessary."""
        raise NotImplementedError()

    def numCols(self):
        """Return the number of columns, computing it if necessary."""
        raise NotImplementedError()
0054 
0055 
class RowMatrix(DistributedMatrix):
    """
    Represents a row-oriented distributed Matrix with no meaningful
    row indices.

    :param rows: An RDD or DataFrame of vectors. If a DataFrame is provided, it must have a single
                 vector typed column.
    :param numRows: Number of rows in the matrix. A non-positive
                    value means unknown, at which point the number
                    of rows will be determined by the number of
                    records in the `rows` RDD.
    :param numCols: Number of columns in the matrix. A non-positive
                    value means unknown, at which point the number
                    of columns will be determined by the size of
                    the first row.
    """
    def __init__(self, rows, numRows=0, numCols=0):
        """
        Note: This docstring is not shown publicly.

        Create a wrapper over a Java RowMatrix.

        Publicly, we require that `rows` be an RDD or DataFrame.  However, for
        internal usage, `rows` can also be a Java RowMatrix
        object, in which case we can wrap it directly.  This
        assists in clean matrix conversions.

        :raises TypeError: if `rows` is none of the accepted types.

        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]])
        >>> mat = RowMatrix(rows)

        >>> mat_diff = RowMatrix(rows)
        >>> (mat_diff._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        False

        >>> mat_same = RowMatrix(mat._java_matrix_wrapper._java_model)
        >>> (mat_same._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        True
        """
        if isinstance(rows, RDD):
            # Coerce every record to an MLlib vector before it crosses to the JVM.
            rows = rows.map(_convert_to_vector)

        if isinstance(rows, (RDD, DataFrame)):
            java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
        elif (isinstance(rows, JavaObject)
              and rows.getClass().getSimpleName() == "RowMatrix"):
            # Internal path: already a Java RowMatrix, wrap it as-is.
            java_matrix = rows
        else:
            raise TypeError("rows should be an RDD of vectors, got %s" % type(rows))

        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)

    @property
    def rows(self):
        """
        Rows of the RowMatrix stored as an RDD of vectors.

        >>> mat = RowMatrix(sc.parallelize([[1, 2, 3], [4, 5, 6]]))
        >>> rows = mat.rows
        >>> rows.first()
        DenseVector([1.0, 2.0, 3.0])
        """
        return self._java_matrix_wrapper.call("rows")

    def numRows(self):
        """
        Get or compute the number of rows.

        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6],
        ...                        [7, 8, 9], [10, 11, 12]])

        >>> mat = RowMatrix(rows)
        >>> print(mat.numRows())
        4

        >>> mat = RowMatrix(rows, 7, 6)
        >>> print(mat.numRows())
        7
        """
        return self._java_matrix_wrapper.call("numRows")

    def numCols(self):
        """
        Get or compute the number of cols.

        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6],
        ...                        [7, 8, 9], [10, 11, 12]])

        >>> mat = RowMatrix(rows)
        >>> print(mat.numCols())
        3

        >>> mat = RowMatrix(rows, 7, 6)
        >>> print(mat.numCols())
        6
        """
        return self._java_matrix_wrapper.call("numCols")

    @since('2.0.0')
    def computeColumnSummaryStatistics(self):
        """
        Computes column-wise summary statistics.

        :return: :class:`MultivariateStatisticalSummary` object
                 containing column-wise summary statistics.

        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]])
        >>> mat = RowMatrix(rows)

        >>> colStats = mat.computeColumnSummaryStatistics()
        >>> colStats.mean()
        array([ 2.5,  3.5,  4.5])
        """
        java_col_stats = self._java_matrix_wrapper.call("computeColumnSummaryStatistics")
        return MultivariateStatisticalSummary(java_col_stats)

    @since('2.0.0')
    def computeCovariance(self):
        """
        Computes the covariance matrix, treating each row as an
        observation.

        .. note:: This cannot be computed on matrices with more than 65535 columns.

        >>> rows = sc.parallelize([[1, 2], [2, 1]])
        >>> mat = RowMatrix(rows)

        >>> mat.computeCovariance()
        DenseMatrix(2, 2, [0.5, -0.5, -0.5, 0.5], 0)
        """
        return self._java_matrix_wrapper.call("computeCovariance")

    @since('2.0.0')
    def computeGramianMatrix(self):
        """
        Computes the Gramian matrix `A^T A`.

        .. note:: This cannot be computed on matrices with more than 65535 columns.

        >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]])
        >>> mat = RowMatrix(rows)

        >>> mat.computeGramianMatrix()
        DenseMatrix(3, 3, [17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0], 0)
        """
        return self._java_matrix_wrapper.call("computeGramianMatrix")

    @since('2.0.0')
    def columnSimilarities(self, threshold=0.0):
        """
        Compute similarities between columns of this matrix.

        The threshold parameter is a trade-off knob between estimate
        quality and computational cost.

        The default threshold setting of 0 guarantees deterministically
        correct results, but uses the brute-force approach of computing
        normalized dot products.

        Setting the threshold to positive values uses a sampling
        approach and incurs strictly less computational cost than the
        brute-force approach. However the similarities computed will
        be estimates.

        The sampling guarantees relative-error correctness for those
        pairs of columns that have similarity greater than the given
        similarity threshold.

        To describe the guarantee, we set some notation:
            * Let A be the smallest in magnitude non-zero element of
              this matrix.
            * Let B be the largest in magnitude non-zero element of
              this matrix.
            * Let L be the maximum number of non-zeros per row.

        For example, for {0,1} matrices: A=B=1.
        Another example, for the Netflix matrix: A=1, B=5

        For those column pairs that are above the threshold, the
        computed similarity is correct to within 20% relative error
        with probability at least 1 - (0.981)^(10/B)

        The shuffle size is bounded by the *smaller* of the following
        two expressions:

            * O(n log(n) L / (threshold * A))
            * O(m L^2)

        The latter is the cost of the brute-force approach, so for
        non-zero thresholds, the cost is always cheaper than the
        brute-force approach.

        :param threshold: Set to 0 for deterministic guaranteed
                          correctness. Similarities above this
                          threshold are estimated with the cost vs
                          estimate quality trade-off described above.
        :return: An n x n sparse upper-triangular CoordinateMatrix of
                 cosine similarities between columns of this matrix.

        >>> rows = sc.parallelize([[1, 2], [1, 5]])
        >>> mat = RowMatrix(rows)

        >>> sims = mat.columnSimilarities()
        >>> sims.entries.first().value
        0.91914503...
        """
        java_sims_mat = self._java_matrix_wrapper.call("columnSimilarities", float(threshold))
        return CoordinateMatrix(java_sims_mat)

    @since('2.0.0')
    def tallSkinnyQR(self, computeQ=False):
        """
        Compute the QR decomposition of this RowMatrix.

        The implementation is designed to optimize the QR decomposition
        (factorization) for the RowMatrix of a tall and skinny shape.

        Reference:
         Paul G. Constantine, David F. Gleich. "Tall and skinny QR
         factorizations in MapReduce architectures"
         (https://doi.org/10.1145/1996092.1996103)

        :param computeQ: whether to compute Q
        :return: QRDecomposition(Q: RowMatrix, R: Matrix), where
                 Q = None if computeQ = false.

        >>> rows = sc.parallelize([[3, -6], [4, -8], [0, 1]])
        >>> mat = RowMatrix(rows)
        >>> decomp = mat.tallSkinnyQR(True)
        >>> Q = decomp.Q
        >>> R = decomp.R

        >>> # Test with absolute values
        >>> absQRows = Q.rows.map(lambda row: abs(row.toArray()).tolist())
        >>> absQRows.collect()
        [[0.6..., 0.0], [0.8..., 0.0], [0.0, 1.0]]

        >>> # Test with absolute values
        >>> abs(R.toArray()).tolist()
        [[5.0, 10.0], [0.0, 1.0]]
        """
        # Coerce to a real bool, as computeSVD does for computeU, so that
        # truthy non-bool values are forwarded as a boolean argument.
        computeQ = bool(computeQ)
        decomp = JavaModelWrapper(self._java_matrix_wrapper.call("tallSkinnyQR", computeQ))
        Q = RowMatrix(decomp.call("Q")) if computeQ else None
        R = decomp.call("R")
        return QRDecomposition(Q, R)

    @since('2.2.0')
    def computeSVD(self, k, computeU=False, rCond=1e-9):
        """
        Computes the singular value decomposition of the RowMatrix.

        The given row matrix A of dimension (m X n) is decomposed into
        U * s * V^T where

        * U: (m X k) (left singular vectors) is a RowMatrix whose
             columns are the eigenvectors of (A X A')
        * s: DenseVector consisting of square root of the eigenvalues
             (singular values) in descending order.
        * V: (n X k) (right singular vectors) is a Matrix whose columns
             are the eigenvectors of (A' X A)

        For more specific details on implementation, please refer
        to the Scala documentation.

        :param k: Number of leading singular values to keep (`0 < k <= n`).
                  It might return less than k if there are numerically zero singular values
                  or there are not enough Ritz values converged before the maximum number of
                  Arnoldi update iterations is reached (in case that matrix A is ill-conditioned).
        :param computeU: Whether or not to compute U. If set to be
                         True, then U is computed by A * V * s^-1
        :param rCond: Reciprocal condition number. All singular values
                      smaller than rCond * s[0] are treated as zero
                      where s[0] is the largest singular value.
        :returns: :py:class:`SingularValueDecomposition`

        >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]])
        >>> rm = RowMatrix(rows)

        >>> svd_model = rm.computeSVD(2, True)
        >>> svd_model.U.rows.collect()
        [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])]
        >>> svd_model.s
        DenseVector([3.4641, 3.1623])
        >>> svd_model.V
        DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0)
        """
        j_model = self._java_matrix_wrapper.call(
            "computeSVD", int(k), bool(computeU), float(rCond))
        return SingularValueDecomposition(j_model)

    @since('2.2.0')
    def computePrincipalComponents(self, k):
        """
        Computes the k principal components of the given row matrix

        .. note:: This cannot be computed on matrices with more than 65535 columns.

        :param k: Number of principal components to keep.
        :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix`

        >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]])
        >>> rm = RowMatrix(rows)

        >>> # Returns the two principal components of rm
        >>> pca = rm.computePrincipalComponents(2)
        >>> pca
        DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0)

        >>> # Transform into new dimensions with the greatest variance.
        >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE
        [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \
        DenseVector([-4.6102, -4.9745])]
        """
        # Coerce k to a plain int, matching computeSVD, so integer-like
        # inputs are forwarded to the JVM as an int.
        return self._java_matrix_wrapper.call("computePrincipalComponents", int(k))

    @since('2.2.0')
    def multiply(self, matrix):
        """
        Multiply this matrix by a local dense matrix on the right.

        :param matrix: a local dense matrix whose number of rows must match the number of columns
                       of this matrix
        :returns: :py:class:`RowMatrix`
        :raises ValueError: if `matrix` is not a :py:class:`DenseMatrix`.

        >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]]))
        >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect()
        [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])]
        """
        if not isinstance(matrix, DenseMatrix):
            raise ValueError("Only multiplication with DenseMatrix "
                             "is supported.")
        j_model = self._java_matrix_wrapper.call("multiply", matrix)
        return RowMatrix(j_model)
0394 
0395 
class SingularValueDecomposition(JavaModelWrapper):
    """
    Holds the factors produced by a singular value decomposition (SVD).

    .. versionadded:: 2.2.0
    """

    @property
    @since('2.2.0')
    def U(self):
        """
        Returns a distributed matrix whose columns are the left
        singular vectors of the SingularValueDecomposition if computeU was set to be True.

        The Java-side value is wrapped as a :py:class:`RowMatrix` or an
        :py:class:`IndexedRowMatrix` according to its class name; when the
        Java side holds no U, None is returned.
        """
        java_u = self.call("U")
        if java_u is None:
            return None

        class_name = java_u.getClass().getSimpleName()
        if class_name == "RowMatrix":
            return RowMatrix(java_u)
        if class_name == "IndexedRowMatrix":
            return IndexedRowMatrix(java_u)
        raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % class_name)

    @property
    @since('2.2.0')
    def s(self):
        """
        Returns a DenseVector with singular values in descending order.
        """
        singular_values = self.call("s")
        return singular_values

    @property
    @since('2.2.0')
    def V(self):
        """
        Returns a DenseMatrix whose columns are the right singular
        vectors of the SingularValueDecomposition.
        """
        right_vectors = self.call("V")
        return right_vectors
0436 
0437 
class IndexedRow(object):
    """
    A single row of an IndexedRowMatrix.

    This is a thin wrapper around a (long, vector) pair.

    :param index: The index for the given row.
    :param vector: The row in the matrix at the given index.
    """
    def __init__(self, index, vector):
        # Normalize both fields on construction: the index to ``long`` and
        # the payload to an MLlib vector.
        self.index = long(index)
        self.vector = _convert_to_vector(vector)

    def __repr__(self):
        fields = (self.index, self.vector)
        return "IndexedRow(%s, %s)" % fields
0453 
0454 
def _convert_to_indexed_row(row):
    """
    Return `row` as an :class:`IndexedRow`.

    An existing IndexedRow is passed through unchanged, and a 2-tuple is
    unpacked as ``(index, vector)``. Anything else raises TypeError.
    """
    if isinstance(row, IndexedRow):
        return row
    if isinstance(row, tuple) and len(row) == 2:
        index, vector = row
        return IndexedRow(index, vector)
    raise TypeError("Cannot convert type %s into IndexedRow" % type(row))
0462 
0463 
class IndexedRowMatrix(DistributedMatrix):
    """
    Represents a row-oriented distributed Matrix with indexed rows.

    :param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame consisting of a
                 long typed column of indices and a vector typed column.
    :param numRows: Number of rows in the matrix. A non-positive
                    value means unknown, at which point the number
                    of rows will be determined by the max row
                    index plus one.
    :param numCols: Number of columns in the matrix. A non-positive
                    value means unknown, at which point the number
                    of columns will be determined by the size of
                    the first row.
    """
    def __init__(self, rows, numRows=0, numCols=0):
        """
        Note: This docstring is not shown publicly.

        Create a wrapper over a Java IndexedRowMatrix.

        Publicly, we require that `rows` be an RDD or DataFrame.  However, for
        internal usage, `rows` can also be a Java IndexedRowMatrix
        object, in which case we can wrap it directly.  This
        assists in clean matrix conversions.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(1, [4, 5, 6])])
        >>> mat = IndexedRowMatrix(rows)

        >>> mat_diff = IndexedRowMatrix(rows)
        >>> (mat_diff._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        False

        >>> mat_same = IndexedRowMatrix(mat._java_matrix_wrapper._java_model)
        >>> (mat_same._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        True
        """
        if isinstance(rows, RDD):
            # IndexedRows are serialized from Python via DataFrames: turn
            # the RDD into a DataFrame here, so each IndexedRow becomes a
            # Row holding the 'index' and 'vector' values, both of which
            # serialize easily. The Scala side converts them back into
            # IndexedRows.
            rows = rows.map(_convert_to_indexed_row).toDF()

        if isinstance(rows, DataFrame):
            java_matrix = callMLlibFunc("createIndexedRowMatrix", rows,
                                        long(numRows), int(numCols))
        elif (isinstance(rows, JavaObject)
              and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
            # Internal path: already a Java IndexedRowMatrix, wrap it as-is.
            java_matrix = rows
        else:
            raise TypeError("rows should be an RDD of IndexedRows or (long, vector) tuples, "
                            "got %s" % type(rows))

        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)

    @property
    def rows(self):
        """
        Rows of the IndexedRowMatrix stored as an RDD of IndexedRows.

        >>> mat = IndexedRowMatrix(sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                                        IndexedRow(1, [4, 5, 6])]))
        >>> rows = mat.rows
        >>> rows.first()
        IndexedRow(0, [1.0,2.0,3.0])
        """
        # IndexedRows coming from Java are serialized via DataFrames: the
        # Scala/Java side converts the RDD of rows to a DataFrame, and each
        # Row of that DataFrame is mapped back to an IndexedRow here.
        java_model = self._java_matrix_wrapper._java_model
        indexed_df = callMLlibFunc("getIndexedRows", java_model)
        return indexed_df.rdd.map(lambda record: IndexedRow(record[0], record[1]))

    def numRows(self):
        """
        Get or compute the number of rows.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(1, [4, 5, 6]),
        ...                        IndexedRow(2, [7, 8, 9]),
        ...                        IndexedRow(3, [10, 11, 12])])

        >>> mat = IndexedRowMatrix(rows)
        >>> print(mat.numRows())
        4

        >>> mat = IndexedRowMatrix(rows, 7, 6)
        >>> print(mat.numRows())
        7
        """
        row_count = self._java_matrix_wrapper.call("numRows")
        return row_count

    def numCols(self):
        """
        Get or compute the number of cols.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(1, [4, 5, 6]),
        ...                        IndexedRow(2, [7, 8, 9]),
        ...                        IndexedRow(3, [10, 11, 12])])

        >>> mat = IndexedRowMatrix(rows)
        >>> print(mat.numCols())
        3

        >>> mat = IndexedRowMatrix(rows, 7, 6)
        >>> print(mat.numCols())
        6
        """
        col_count = self._java_matrix_wrapper.call("numCols")
        return col_count

    def columnSimilarities(self):
        """
        Compute all cosine similarities between columns.

        The result is wrapped as a CoordinateMatrix.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(6, [4, 5, 6])])
        >>> mat = IndexedRowMatrix(rows)
        >>> cs = mat.columnSimilarities()
        >>> print(cs.numCols())
        3
        """
        return CoordinateMatrix(self._java_matrix_wrapper.call("columnSimilarities"))

    @since('2.0.0')
    def computeGramianMatrix(self):
        """
        Computes the Gramian matrix `A^T A`.

        .. note:: This cannot be computed on matrices with more than 65535 columns.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(1, [4, 5, 6])])
        >>> mat = IndexedRowMatrix(rows)

        >>> mat.computeGramianMatrix()
        DenseMatrix(3, 3, [17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0], 0)
        """
        gramian = self._java_matrix_wrapper.call("computeGramianMatrix")
        return gramian

    def toRowMatrix(self):
        """
        Convert this matrix to a RowMatrix.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(6, [4, 5, 6])])
        >>> mat = IndexedRowMatrix(rows).toRowMatrix()
        >>> mat.rows.collect()
        [DenseVector([1.0, 2.0, 3.0]), DenseVector([4.0, 5.0, 6.0])]
        """
        return RowMatrix(self._java_matrix_wrapper.call("toRowMatrix"))

    def toCoordinateMatrix(self):
        """
        Convert this matrix to a CoordinateMatrix.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 0]),
        ...                        IndexedRow(6, [0, 5])])
        >>> mat = IndexedRowMatrix(rows).toCoordinateMatrix()
        >>> mat.entries.take(3)
        [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 0.0), MatrixEntry(6, 0, 0.0)]
        """
        return CoordinateMatrix(self._java_matrix_wrapper.call("toCoordinateMatrix"))

    def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024):
        """
        Convert this matrix to a BlockMatrix.

        :param rowsPerBlock: Number of rows that make up each block.
                             The blocks forming the final rows are not
                             required to have the given number of rows.
        :param colsPerBlock: Number of columns that make up each block.
                             The blocks forming the final columns are not
                             required to have the given number of columns.

        >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
        ...                        IndexedRow(6, [4, 5, 6])])
        >>> mat = IndexedRowMatrix(rows).toBlockMatrix()

        >>> # This IndexedRowMatrix will have 7 effective rows, due to
        >>> # the highest row index being 6, and the ensuing
        >>> # BlockMatrix will have 7 rows as well.
        >>> print(mat.numRows())
        7

        >>> print(mat.numCols())
        3
        """
        wrapper = self._java_matrix_wrapper
        java_block_matrix = wrapper.call("toBlockMatrix", rowsPerBlock, colsPerBlock)
        return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock)

    @since('2.2.0')
    def computeSVD(self, k, computeU=False, rCond=1e-9):
        """
        Computes the singular value decomposition of the IndexedRowMatrix.

        The given row matrix A of dimension (m X n) is decomposed into
        U * s * V^T where

        * U: (m X k) (left singular vectors) is an IndexedRowMatrix
             whose columns are the eigenvectors of (A X A')
        * s: DenseVector consisting of square root of the eigenvalues
             (singular values) in descending order.
        * V: (n X k) (right singular vectors) is a Matrix whose columns
             are the eigenvectors of (A' X A)

        For more specific details on implementation, please refer
        to the Scala documentation.

        :param k: Number of leading singular values to keep (`0 < k <= n`).
                  It might return less than k if there are numerically zero singular values
                  or there are not enough Ritz values converged before the maximum number of
                  Arnoldi update iterations is reached (in case that matrix A is ill-conditioned).
        :param computeU: Whether or not to compute U. If set to be
                         True, then U is computed by A * V * s^-1
        :param rCond: Reciprocal condition number. All singular values
                      smaller than rCond * s[0] are treated as zero
                      where s[0] is the largest singular value.
        :returns: SingularValueDecomposition object

        >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))]
        >>> irm = IndexedRowMatrix(sc.parallelize(rows))
        >>> svd_model = irm.computeSVD(2, True)
        >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE
        [IndexedRow(0, [-0.707106781187,0.707106781187]),\
        IndexedRow(1, [-0.707106781187,-0.707106781187])]
        >>> svd_model.s
        DenseVector([3.4641, 3.1623])
        >>> svd_model.V
        DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0)
        """
        svd_args = (int(k), bool(computeU), float(rCond))
        java_svd = self._java_matrix_wrapper.call("computeSVD", *svd_args)
        return SingularValueDecomposition(java_svd)

    @since('2.2.0')
    def multiply(self, matrix):
        """
        Multiply this matrix by a local dense matrix on the right.

        :param matrix: a local dense matrix whose number of rows must match the number of columns
                       of this matrix
        :returns: :py:class:`IndexedRowMatrix`

        >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))]))
        >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect()
        [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])]
        """
        if not isinstance(matrix, DenseMatrix):
            raise ValueError("Only multiplication with DenseMatrix is supported.")
        java_product = self._java_matrix_wrapper.call("multiply", matrix)
        return IndexedRowMatrix(java_product)
0728 
0729 
class MatrixEntry(object):
    """
    A single entry of a CoordinateMatrix.

    This is a thin wrapper around a (long, long, float) triple.

    :param i: The row index of the matrix.
    :param j: The column index of the matrix.
    :param value: The (i, j)th entry of the matrix, as a float.
    """
    def __init__(self, i, j, value):
        # Normalize on construction: both indices to ``long`` and the
        # value to ``float``.
        self.i = long(i)
        self.j = long(j)
        self.value = float(value)

    def __repr__(self):
        fields = (self.i, self.j, self.value)
        return "MatrixEntry(%s, %s, %s)" % fields
0747 
0748 
def _convert_to_matrix_entry(entry):
    """
    Coerce `entry` into a MatrixEntry.

    A MatrixEntry is returned unchanged and a 3-element tuple is
    unpacked into a new MatrixEntry; anything else raises TypeError.
    """
    if isinstance(entry, MatrixEntry):
        return entry
    if isinstance(entry, tuple) and len(entry) == 3:
        i, j, value = entry
        return MatrixEntry(i, j, value)
    raise TypeError("Cannot convert type %s into MatrixEntry" % type(entry))
0756 
0757 
class CoordinateMatrix(DistributedMatrix):
    """
    Represents a matrix in coordinate format.

    :param entries: An RDD of MatrixEntry inputs or
                    (long, long, float) tuples.
    :param numRows: Number of rows in the matrix. A non-positive
                    value means unknown, at which point the number
                    of rows will be determined by the max row
                    index plus one.
    :param numCols: Number of columns in the matrix. A non-positive
                    value means unknown, at which point the number
                    of columns will be determined by the max column
                    index plus one.
    """
    def __init__(self, entries, numRows=0, numCols=0):
        """
        Note: This docstring is not shown publicly.

        Create a wrapper over a Java CoordinateMatrix.

        Publicly, we require that `entries` be an RDD.  However, for
        internal usage, `entries` can also be a Java CoordinateMatrix
        object, in which case we can wrap it directly (`numRows` and
        `numCols` are not used in that case).  This assists in clean
        matrix conversions.

        :raises TypeError: if `entries` is neither an RDD nor a Java
                           CoordinateMatrix object.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(6, 4, 2.1)])
        >>> mat = CoordinateMatrix(entries)

        >>> mat_diff = CoordinateMatrix(entries)
        >>> (mat_diff._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        False

        >>> mat_same = CoordinateMatrix(mat._java_matrix_wrapper._java_model)
        >>> (mat_same._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        True
        """
        if isinstance(entries, RDD):
            entries = entries.map(_convert_to_matrix_entry)
            # We use DataFrames for serialization of MatrixEntry entries
            # from Python, so first convert the RDD to a DataFrame on
            # this side. This will convert each MatrixEntry to a Row
            # containing the 'i', 'j', and 'value' values, which can
            # each be easily serialized. We will convert back to
            # MatrixEntry inputs on the Scala side.
            java_matrix = callMLlibFunc("createCoordinateMatrix", entries.toDF(),
                                        long(numRows), long(numCols))
        elif (isinstance(entries, JavaObject)
              and entries.getClass().getSimpleName() == "CoordinateMatrix"):
            # Internal path: wrap an existing Java CoordinateMatrix as-is.
            java_matrix = entries
        else:
            raise TypeError("entries should be an RDD of MatrixEntry entries or "
                            "(long, long, float) tuples, got %s" % type(entries))

        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)

    @property
    def entries(self):
        """
        Entries of the CoordinateMatrix stored as an RDD of
        MatrixEntries.

        >>> mat = CoordinateMatrix(sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                                        MatrixEntry(6, 4, 2.1)]))
        >>> entries = mat.entries
        >>> entries.first()
        MatrixEntry(0, 0, 1.2)
        """
        # We use DataFrames for serialization of MatrixEntry entries
        # from Java, so we first convert the RDD of entries to a
        # DataFrame on the Scala/Java side. Then we map each Row in
        # the DataFrame back to a MatrixEntry on this side.
        entries_df = callMLlibFunc("getMatrixEntries", self._java_matrix_wrapper._java_model)
        entries = entries_df.rdd.map(lambda row: MatrixEntry(row[0], row[1], row[2]))
        return entries

    def numRows(self):
        """
        Get or compute the number of rows.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(1, 0, 2),
        ...                           MatrixEntry(2, 1, 3.7)])

        >>> mat = CoordinateMatrix(entries)
        >>> print(mat.numRows())
        3

        >>> mat = CoordinateMatrix(entries, 7, 6)
        >>> print(mat.numRows())
        7
        """
        return self._java_matrix_wrapper.call("numRows")

    def numCols(self):
        """
        Get or compute the number of cols.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(1, 0, 2),
        ...                           MatrixEntry(2, 1, 3.7)])

        >>> mat = CoordinateMatrix(entries)
        >>> print(mat.numCols())
        2

        >>> mat = CoordinateMatrix(entries, 7, 6)
        >>> print(mat.numCols())
        6
        """
        return self._java_matrix_wrapper.call("numCols")

    @since('2.0.0')
    def transpose(self):
        """
        Transpose this CoordinateMatrix.

        :returns: a new :py:class:`CoordinateMatrix` wrapping the
                  transposed Java matrix.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(1, 0, 2),
        ...                           MatrixEntry(2, 1, 3.7)])
        >>> mat = CoordinateMatrix(entries)
        >>> mat_transposed = mat.transpose()

        >>> print(mat_transposed.numRows())
        2

        >>> print(mat_transposed.numCols())
        3
        """
        java_transposed_matrix = self._java_matrix_wrapper.call("transpose")
        return CoordinateMatrix(java_transposed_matrix)

    def toRowMatrix(self):
        """
        Convert this matrix to a RowMatrix.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(6, 4, 2.1)])
        >>> mat = CoordinateMatrix(entries).toRowMatrix()

        >>> # This CoordinateMatrix will have 7 effective rows, due to
        >>> # the highest row index being 6, but the ensuing RowMatrix
        >>> # will only have 2 rows since there are only entries on 2
        >>> # unique rows.
        >>> print(mat.numRows())
        2

        >>> # This CoordinateMatrix will have 5 columns, due to the
        >>> # highest column index being 4, and the ensuing RowMatrix
        >>> # will have 5 columns as well.
        >>> print(mat.numCols())
        5
        """
        java_row_matrix = self._java_matrix_wrapper.call("toRowMatrix")
        return RowMatrix(java_row_matrix)

    def toIndexedRowMatrix(self):
        """
        Convert this matrix to an IndexedRowMatrix.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(6, 4, 2.1)])
        >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix()

        >>> # This CoordinateMatrix will have 7 effective rows, due to
        >>> # the highest row index being 6, and the ensuing
        >>> # IndexedRowMatrix will have 7 rows as well.
        >>> print(mat.numRows())
        7

        >>> # This CoordinateMatrix will have 5 columns, due to the
        >>> # highest column index being 4, and the ensuing
        >>> # IndexedRowMatrix will have 5 columns as well.
        >>> print(mat.numCols())
        5
        """
        java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix")
        return IndexedRowMatrix(java_indexed_row_matrix)

    def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024):
        """
        Convert this matrix to a BlockMatrix.

        :param rowsPerBlock: Number of rows that make up each block.
                             The blocks forming the final rows are not
                             required to have the given number of rows.
        :param colsPerBlock: Number of columns that make up each block.
                             The blocks forming the final columns are not
                             required to have the given number of columns.

        >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2),
        ...                           MatrixEntry(6, 4, 2.1)])
        >>> mat = CoordinateMatrix(entries).toBlockMatrix()

        >>> # This CoordinateMatrix will have 7 effective rows, due to
        >>> # the highest row index being 6, and the ensuing
        >>> # BlockMatrix will have 7 rows as well.
        >>> print(mat.numRows())
        7

        >>> # This CoordinateMatrix will have 5 columns, due to the
        >>> # highest column index being 4, and the ensuing
        >>> # BlockMatrix will have 5 columns as well.
        >>> print(mat.numCols())
        5
        """
        # Coerce to int, mirroring BlockMatrix.__init__, so both entry
        # points accept the same kinds of block-size arguments.
        rowsPerBlock = int(rowsPerBlock)
        colsPerBlock = int(colsPerBlock)
        java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix",
                                                           rowsPerBlock,
                                                           colsPerBlock)
        return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock)
0971 
0972 
0973 def _convert_to_matrix_block_tuple(block):
0974     if (isinstance(block, tuple) and len(block) == 2
0975             and isinstance(block[0], tuple) and len(block[0]) == 2
0976             and isinstance(block[1], Matrix)):
0977         blockRowIndex = int(block[0][0])
0978         blockColIndex = int(block[0][1])
0979         subMatrix = block[1]
0980         return ((blockRowIndex, blockColIndex), subMatrix)
0981     else:
0982         raise TypeError("Cannot convert type %s into a sub-matrix block tuple" % type(block))
0983 
0984 
class BlockMatrix(DistributedMatrix):
    """
    Represents a distributed matrix in blocks of local matrices.

    :param blocks: An RDD of sub-matrix blocks
                   ((blockRowIndex, blockColIndex), sub-matrix) that
                   form this distributed matrix. If multiple blocks
                   with the same index exist, the results for
                   operations like add and multiply will be
                   unpredictable.
    :param rowsPerBlock: Number of rows that make up each block.
                         The blocks forming the final rows are not
                         required to have the given number of rows.
    :param colsPerBlock: Number of columns that make up each block.
                         The blocks forming the final columns are not
                         required to have the given number of columns.
    :param numRows: Number of rows of this matrix. If the supplied
                    value is less than or equal to zero, the number
                    of rows will be calculated when `numRows` is
                    invoked.
    :param numCols: Number of columns of this matrix. If the supplied
                    value is less than or equal to zero, the number
                    of columns will be calculated when `numCols` is
                    invoked.
    """
    def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0):
        """
        Note: This docstring is not shown publicly.

        Create a wrapper over a Java BlockMatrix.

        Publicly, we require that `blocks` be an RDD.  However, for
        internal usage, `blocks` can also be a Java BlockMatrix
        object, in which case we can wrap it directly (the remaining
        arguments are not used in that case).  This assists in clean
        matrix conversions.

        :raises TypeError: if `blocks` is neither an RDD nor a Java
                           BlockMatrix object.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2)

        >>> mat_diff = BlockMatrix(blocks, 3, 2)
        >>> (mat_diff._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        False

        >>> mat_same = BlockMatrix(mat._java_matrix_wrapper._java_model, 3, 2)
        >>> (mat_same._java_matrix_wrapper._java_model ==
        ...  mat._java_matrix_wrapper._java_model)
        True
        """
        if isinstance(blocks, RDD):
            blocks = blocks.map(_convert_to_matrix_block_tuple)
            # We use DataFrames for serialization of sub-matrix blocks
            # from Python, so first convert the RDD to a DataFrame on
            # this side. This will convert each sub-matrix block
            # tuple to a Row containing the 'blockRowIndex',
            # 'blockColIndex', and 'subMatrix' values, which can
            # each be easily serialized.  We will convert back to
            # ((blockRowIndex, blockColIndex), sub-matrix) tuples on
            # the Scala side.
            java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(),
                                        int(rowsPerBlock), int(colsPerBlock),
                                        long(numRows), long(numCols))
        elif (isinstance(blocks, JavaObject)
              and blocks.getClass().getSimpleName() == "BlockMatrix"):
            # Internal path: wrap an existing Java BlockMatrix as-is.
            java_matrix = blocks
        else:
            raise TypeError("blocks should be an RDD of sub-matrix blocks as "
                            "((int, int), matrix) tuples, got %s" % type(blocks))

        self._java_matrix_wrapper = JavaModelWrapper(java_matrix)

    @property
    def blocks(self):
        """
        The RDD of sub-matrix blocks
        ((blockRowIndex, blockColIndex), sub-matrix) that form this
        distributed matrix.

        >>> mat = BlockMatrix(
        ...     sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                     ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2)
        >>> blocks = mat.blocks
        >>> blocks.first()
        ((0, 0), DenseMatrix(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 0))

        """
        # We use DataFrames for serialization of sub-matrix blocks
        # from Java, so we first convert the RDD of blocks to a
        # DataFrame on the Scala/Java side. Then we map each Row in
        # the DataFrame back to a sub-matrix block on this side.
        blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model)
        blocks = blocks_df.rdd.map(lambda row: ((row[0][0], row[0][1]), row[1]))
        return blocks

    @property
    def rowsPerBlock(self):
        """
        Number of rows that make up each block.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2)
        >>> mat.rowsPerBlock
        3
        """
        return self._java_matrix_wrapper.call("rowsPerBlock")

    @property
    def colsPerBlock(self):
        """
        Number of columns that make up each block.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2)
        >>> mat.colsPerBlock
        2
        """
        return self._java_matrix_wrapper.call("colsPerBlock")

    @property
    def numRowBlocks(self):
        """
        Number of rows of blocks in the BlockMatrix.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2)
        >>> mat.numRowBlocks
        2
        """
        return self._java_matrix_wrapper.call("numRowBlocks")

    @property
    def numColBlocks(self):
        """
        Number of columns of blocks in the BlockMatrix.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2)
        >>> mat.numColBlocks
        1
        """
        return self._java_matrix_wrapper.call("numColBlocks")

    def numRows(self):
        """
        Get or compute the number of rows.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])

        >>> mat = BlockMatrix(blocks, 3, 2)
        >>> print(mat.numRows())
        6

        >>> mat = BlockMatrix(blocks, 3, 2, 7, 6)
        >>> print(mat.numRows())
        7
        """
        return self._java_matrix_wrapper.call("numRows")

    def numCols(self):
        """
        Get or compute the number of cols.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])

        >>> mat = BlockMatrix(blocks, 3, 2)
        >>> print(mat.numCols())
        2

        >>> mat = BlockMatrix(blocks, 3, 2, 7, 6)
        >>> print(mat.numCols())
        6
        """
        return self._java_matrix_wrapper.call("numCols")

    @since('2.0.0')
    def cache(self):
        """
        Caches the underlying RDD.

        :returns: this BlockMatrix, so calls can be chained.
        """
        self._java_matrix_wrapper.call("cache")
        return self

    @since('2.0.0')
    def persist(self, storageLevel):
        """
        Persists the underlying RDD with the specified storage level.

        :param storageLevel: a :py:class:`StorageLevel` instance.
        :returns: this BlockMatrix, so calls can be chained.
        :raises TypeError: if `storageLevel` is not a StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise TypeError("`storageLevel` should be a StorageLevel, got %s" % type(storageLevel))
        javaStorageLevel = self._java_matrix_wrapper._sc._getJavaStorageLevel(storageLevel)
        self._java_matrix_wrapper.call("persist", javaStorageLevel)
        return self

    @since('2.0.0')
    def validate(self):
        """
        Validates the block matrix info against the matrix data (`blocks`)
        and throws an exception if any error is found.
        """
        self._java_matrix_wrapper.call("validate")

    def add(self, other):
        """
        Adds two block matrices together. The matrices must have the
        same size and matching `rowsPerBlock` and `colsPerBlock` values.
        If one of the sub matrix blocks that are being added is a
        SparseMatrix, the resulting sub matrix block will also be a
        SparseMatrix, even if it is being added to a DenseMatrix. If
        two dense sub matrix blocks are added, the output block will
        also be a DenseMatrix.

        :param other: the BlockMatrix to add to this one.
        :returns: a new BlockMatrix holding `this + other`.
        :raises TypeError: if `other` is not a BlockMatrix.

        >>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])
        >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
        >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12])
        >>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
        >>> blocks2 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
        >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)])
        >>> mat1 = BlockMatrix(blocks1, 3, 2)
        >>> mat2 = BlockMatrix(blocks2, 3, 2)
        >>> mat3 = BlockMatrix(blocks3, 3, 2)

        >>> mat1.add(mat2).toLocalMatrix()
        DenseMatrix(6, 2, [2.0, 4.0, 6.0, 14.0, 16.0, 18.0, 8.0, 10.0, 12.0, 20.0, 22.0, 24.0], 0)

        >>> mat1.add(mat3).toLocalMatrix()
        DenseMatrix(6, 2, [8.0, 2.0, 3.0, 14.0, 16.0, 18.0, 4.0, 16.0, 18.0, 20.0, 22.0, 24.0], 0)
        """
        if not isinstance(other, BlockMatrix):
            raise TypeError("Other should be a BlockMatrix, got %s" % type(other))

        other_java_block_matrix = other._java_matrix_wrapper._java_model
        java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix)
        return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)

    @since('2.0.0')
    def subtract(self, other):
        """
        Subtracts the given block matrix `other` from this block matrix:
        `this - other`. The matrices must have the same size and
        matching `rowsPerBlock` and `colsPerBlock` values.  If one of
        the sub matrix blocks that are being subtracted is a
        SparseMatrix, the resulting sub matrix block will also be a
        SparseMatrix, even if it is being subtracted from a DenseMatrix.
        If two dense sub matrix blocks are subtracted, the output block
        will also be a DenseMatrix.

        :param other: the BlockMatrix to subtract from this one.
        :returns: a new BlockMatrix holding `this - other`.
        :raises TypeError: if `other` is not a BlockMatrix.

        >>> dm1 = Matrices.dense(3, 2, [3, 1, 5, 4, 6, 2])
        >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
        >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [1, 2, 3])
        >>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
        >>> blocks2 = sc.parallelize([((0, 0), dm2), ((1, 0), dm1)])
        >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)])
        >>> mat1 = BlockMatrix(blocks1, 3, 2)
        >>> mat2 = BlockMatrix(blocks2, 3, 2)
        >>> mat3 = BlockMatrix(blocks3, 3, 2)

        >>> mat1.subtract(mat2).toLocalMatrix()
        DenseMatrix(6, 2, [-4.0, -7.0, -4.0, 4.0, 7.0, 4.0, -6.0, -5.0, -10.0, 6.0, 5.0, 10.0], 0)

        >>> mat2.subtract(mat3).toLocalMatrix()
        DenseMatrix(6, 2, [6.0, 8.0, 9.0, -4.0, -7.0, -4.0, 10.0, 9.0, 9.0, -6.0, -5.0, -10.0], 0)
        """
        if not isinstance(other, BlockMatrix):
            raise TypeError("Other should be a BlockMatrix, got %s" % type(other))

        other_java_block_matrix = other._java_matrix_wrapper._java_model
        java_block_matrix = self._java_matrix_wrapper.call("subtract", other_java_block_matrix)
        return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)

    def multiply(self, other):
        """
        Multiplies this BlockMatrix by `other`, another BlockMatrix,
        on the right: `this * other`. The `colsPerBlock` of this matrix
        must equal the `rowsPerBlock` of `other`. If `other` contains
        any SparseMatrix blocks, they will have to be converted to
        DenseMatrix blocks. The output BlockMatrix will only consist of
        DenseMatrix blocks. This may cause some performance issues until
        support for multiplying two sparse matrices is added.

        :param other: the BlockMatrix to multiply this one by.
        :returns: a new BlockMatrix holding `this * other`.
        :raises TypeError: if `other` is not a BlockMatrix.

        >>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6])
        >>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12])
        >>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])
        >>> dm4 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
        >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12])
        >>> blocks1 = sc.parallelize([((0, 0), dm1), ((0, 1), dm2)])
        >>> blocks2 = sc.parallelize([((0, 0), dm3), ((1, 0), dm4)])
        >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm4)])
        >>> mat1 = BlockMatrix(blocks1, 2, 3)
        >>> mat2 = BlockMatrix(blocks2, 3, 2)
        >>> mat3 = BlockMatrix(blocks3, 3, 2)

        >>> mat1.multiply(mat2).toLocalMatrix()
        DenseMatrix(2, 2, [242.0, 272.0, 350.0, 398.0], 0)

        >>> mat1.multiply(mat3).toLocalMatrix()
        DenseMatrix(2, 2, [227.0, 258.0, 394.0, 450.0], 0)
        """
        if not isinstance(other, BlockMatrix):
            raise TypeError("Other should be a BlockMatrix, got %s" % type(other))

        other_java_block_matrix = other._java_matrix_wrapper._java_model
        java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix)
        # Each product block has this matrix's block height and `other`'s
        # block width, so describe the result with other.colsPerBlock.
        # (__init__ does not use these arguments when wrapping a Java
        # object, but they should still name the right dimensions.)
        return BlockMatrix(java_block_matrix, self.rowsPerBlock, other.colsPerBlock)

    @since('2.0.0')
    def transpose(self):
        """
        Transpose this BlockMatrix. Returns a new BlockMatrix
        instance sharing the same underlying data. Is a lazy operation.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2)

        >>> mat_transposed = mat.transpose()
        >>> mat_transposed.toLocalMatrix()
        DenseMatrix(2, 6, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 7.0, 10.0, 8.0, 11.0, 9.0, 12.0], 0)
        """
        java_transposed_matrix = self._java_matrix_wrapper.call("transpose")
        # Block dimensions swap along with the matrix dimensions.
        return BlockMatrix(java_transposed_matrix, self.colsPerBlock, self.rowsPerBlock)

    def toLocalMatrix(self):
        """
        Collect the distributed matrix on the driver as a DenseMatrix.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix()

        >>> # This BlockMatrix will have 6 effective rows, due to
        >>> # having two sub-matrix blocks stacked, each with 3 rows.
        >>> # The ensuing DenseMatrix will also have 6 rows.
        >>> print(mat.numRows)
        6

        >>> # This BlockMatrix will have 2 effective columns, due to
        >>> # having two sub-matrix blocks stacked, each with 2
        >>> # columns. The ensuing DenseMatrix will also have 2 columns.
        >>> print(mat.numCols)
        2
        """
        return self._java_matrix_wrapper.call("toLocalMatrix")

    def toIndexedRowMatrix(self):
        """
        Convert this matrix to an IndexedRowMatrix.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
        ...                          ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
        >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix()

        >>> # This BlockMatrix will have 6 effective rows, due to
        >>> # having two sub-matrix blocks stacked, each with 3 rows.
        >>> # The ensuing IndexedRowMatrix will also have 6 rows.
        >>> print(mat.numRows())
        6

        >>> # This BlockMatrix will have 2 effective columns, due to
        >>> # having two sub-matrix blocks stacked, each with 2 columns.
        >>> # The ensuing IndexedRowMatrix will also have 2 columns.
        >>> print(mat.numCols())
        2
        """
        java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix")
        return IndexedRowMatrix(java_indexed_row_matrix)

    def toCoordinateMatrix(self):
        """
        Convert this matrix to a CoordinateMatrix.

        >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])),
        ...                          ((1, 0), Matrices.dense(1, 2, [7, 8]))])
        >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix()
        >>> mat.entries.take(3)
        [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 2.0), MatrixEntry(1, 0, 7.0)]
        """
        java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix")
        return CoordinateMatrix(java_coordinate_matrix)
1370 
1371 
def _test():
    """
    Run this module's doctests against a SparkSession created with
    master "local[2]".

    The doctest globals are a copy of this module's namespace plus `sc`
    (the session's SparkContext) and `Matrices`.  The session is always
    stopped afterwards, and the interpreter exits with status -1 if any
    doctest failed.
    """
    import doctest
    import numpy
    from pyspark.sql import SparkSession
    from pyspark.mllib.linalg import Matrices
    import pyspark.mllib.linalg.distributed
    try:
        # Numpy 1.14+ changed its string format.
        numpy.set_printoptions(legacy='1.13')
    except TypeError:
        # This numpy does not accept the `legacy` option; keep its defaults.
        pass
    globs = pyspark.mllib.linalg.distributed.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("mllib.linalg.distributed tests")\
        .getOrCreate()
    try:
        globs['sc'] = spark.sparkContext
        globs['Matrices'] = Matrices
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Release the session even if the doctest run itself raises.
        spark.stop()
    if failure_count:
        sys.exit(-1)
1394 
# Run the module's doctests when this file is executed as a script.
if __name__ == "__main__":
    _test()