Skip to content

Commit c0c9565

Browse files
committed
Adding the subtract method.
1 parent f1410bf commit c0c9565

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

python/pyspark/mllib/linalg/distributed.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,41 @@ def add(self, other):
10361036
java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix)
10371037
return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)
10381038

1039+
@since('2.0.0')
1040+
def subtract(self, other):
1041+
"""
1042+
Subtracts the given block matrix `other` from this block matrix:
1043+
`this - other`. The matrices must have the same size and
1044+
matching `rowsPerBlock` and `colsPerBlock` values. If one of
1045+
the sub matrix blocks that are being subtracted is a
1046+
SparseMatrix, the resulting sub matrix block will also be a
1047+
SparseMatrix, even if it is being subtracted from a DenseMatrix.
1048+
If two dense sub matrix blocks are subtracted, the output block
1049+
will also be a DenseMatrix.
1050+
1051+
>>> dm1 = Matrices.dense(3, 2, [3, 1, 5, 4, 6, 2])
1052+
>>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
1053+
>>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [1, 2, 3])
1054+
>>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
1055+
>>> blocks2 = sc.parallelize([((0, 0), dm2), ((1, 0), dm1)])
1056+
>>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)])
1057+
>>> mat1 = BlockMatrix(blocks1, 3, 2)
1058+
>>> mat2 = BlockMatrix(blocks2, 3, 2)
1059+
>>> mat3 = BlockMatrix(blocks3, 3, 2)
1060+
1061+
>>> mat1.subtract(mat2).toLocalMatrix()
1062+
DenseMatrix(6, 2, [-4.0, -7.0, -4.0, 4.0, 7.0, 4.0, -6.0, -5.0, -10.0, 6.0, 5.0, 10.0], 0)
1063+
1064+
>>> mat2.subtract(mat3).toLocalMatrix()
1065+
DenseMatrix(6, 2, [6.0, 8.0, 9.0, -4.0, -7.0, -4.0, 10.0, 9.0, 9.0, -6.0, -5.0, -10.0], 0)
1066+
"""
1067+
if not isinstance(other, BlockMatrix):
1068+
raise TypeError("Other should be a BlockMatrix, got %s" % type(other))
1069+
1070+
other_java_block_matrix = other._java_matrix_wrapper._java_model
1071+
java_block_matrix = self._java_matrix_wrapper.call("subtract", other_java_block_matrix)
1072+
return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)
1073+
10391074
def multiply(self, other):
10401075
"""
10411076
Left multiplies this BlockMatrix by `other`, another

0 commit comments

Comments
 (0)