Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/java/org/apache/sysds/hops/DataOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ public Lop constructLops()

case TEE:
l = new Tee(getInput(0).constructLops(), getDataType(), getValueType());
setOutputDimensions(l);
break;

default:
Expand Down Expand Up @@ -488,7 +489,7 @@ else if ( getInput().get(0).areDimsBelowThreshold() )

@Override
public void refreshSizeInformation() {
if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE ) {
if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE || _op == OpOpData.TEE ) {
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
setDim2(input1.getDim2());
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/lops/DataGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ private String getRandInstructionCPSpark(String output)
sb.append(iLop == null ? "" : iLop.prepScalarLabel());
sb.append(OPERAND_DELIMITOR);

if( getExecType() == ExecType.CP ) {
if( getExecType() == ExecType.CP || getExecType() == ExecType.OOC ) {
//append degree of parallelism
sb.append( _numThreads );
sb.append( OPERAND_DELIMITOR );
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/lops/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ private String getInstructions(String input1, int numInputs, String output) {
sb.append( OPERAND_DELIMITOR );
sb.append( this.prepOutputOperand(output));

if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED || getExecType()==ExecType.OOC)
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.IOException;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;

Expand Down Expand Up @@ -528,7 +529,12 @@ protected MatrixBlock readBlobFromRDD(RDDObject rdd, MutableBoolean writeStatus)

@Override
protected MatrixBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException {
MatrixBlock ret = new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
boolean dimsUnknown = getNumRows() < 0 || getNumColumns() < 0;
int nrows = (int)getNumRows();
int ncols = (int)getNumColumns();
MatrixBlock ret = dimsUnknown ? null : new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
// TODO if stream is CachingStream, block parts might be evicted resulting in null pointer exceptions
List<IndexedMatrixValue> blockCache = dimsUnknown ? new ArrayList<>() : null;
IndexedMatrixValue tmp = null;
try {
int blen = getBlocksize(), lnnz = 0;
Expand All @@ -537,12 +543,31 @@ protected MatrixBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stre
final int row_offset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen;
final int col_offset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen;

// Add the values of this block into the output block.
((MatrixBlock)tmp.getValue()).putInto(ret, row_offset, col_offset, true);
if (dimsUnknown) {
nrows = Math.max(nrows, row_offset + tmp.getValue().getNumRows());
ncols = Math.max(ncols, col_offset + tmp.getValue().getNumColumns());
blockCache.add(tmp);
} else {
// Add the values of this block into the output block.
((MatrixBlock) tmp.getValue()).putInto(ret, row_offset, col_offset, true);
}

// incremental maintenance nnz
lnnz += tmp.getValue().getNonZeros();
}

if (dimsUnknown) {
ret = new MatrixBlock(nrows, ncols, false);

for (IndexedMatrixValue _tmp : blockCache) {
// compute row/column block offsets
final int row_offset = (int) (_tmp.getIndexes().getRowIndex() - 1) * blen;
final int col_offset = (int) (_tmp.getIndexes().getColumnIndex() - 1) * blen;

((MatrixBlock) _tmp.getValue()).putInto(ret, row_offset, col_offset, true);
}
}

ret.setNonZeros(lnnz);
}
catch(Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;

public class OOCInstructionParser extends InstructionParser {
Expand Down Expand Up @@ -74,7 +74,7 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
case MMTSJ:
return TSMMOOCInstruction.parseInstruction(str);
case Reorg:
return TransposeOOCInstruction.parseInstruction(str);
return ReorgOOCInstruction.parseInstruction(str);
case Tee:
return TeeOOCInstruction.parseInstruction(str);
case CentralMoment:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
Expand Down Expand Up @@ -67,12 +69,55 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
OOCStream<IndexedMatrixValue> qOut = new SubscribableTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);

joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp1.getIndexes(),
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue()));
return tmpOut;
}, IndexedMatrixValue::getIndexes);
if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0)
throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions.");

boolean isColBroadcast = m1.getNumColumns() > 1 && m2.getNumColumns() == 1;
boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1;

if (isColBroadcast && !isRowBroadcast) {
final long maxProcessesPerBroadcast = m1.getNumColumns() / m1.getBlocksize();

broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp1.getIndexes(),
tmp1.getValue().binaryOperations((BinaryOperator)_optr, b.getValue().getValue(), tmpOut.getValue()));

if (b.incrProcessCtrAndGet() >= maxProcessesPerBroadcast)
b.release();

return tmpOut;
}, tmp -> tmp.getIndexes().getRowIndex());
}
else if (isRowBroadcast && !isColBroadcast) {
final long maxProcessesPerBroadcast = m1.getNumRows() / m1.getBlocksize();

broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp1.getIndexes(),
tmp1.getValue().binaryOperations((BinaryOperator)_optr, b.getValue().getValue(), tmpOut.getValue()));

if (b.incrProcessCtrAndGet() >= maxProcessesPerBroadcast)
b.release();

return tmpOut;
}, tmp -> tmp.getIndexes().getColumnIndex());
}
else {
if (m1.getNumColumns() != m2.getNumColumns() || m1.getNumRows() != m2.getNumRows())
throw new NotImplementedException("Invalid dimensions for matrix-matrix binary op: "
+ m1.getNumRows() + "x" + m1.getNumColumns() + " <=> "
+ m2.getNumRows() + "x" + m2.getNumColumns());

joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
tmpOut.set(tmp1.getIndexes(),
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue()));
return tmpOut;
}, IndexedMatrixValue::getIndexes);
}


}

protected void processScalarMatrixInstruction(ExecutionContext ec) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public class CachingStream implements OOCStreamable<IndexedMatrixValue> {
private boolean _cacheInProgress = true; // caching in progress, in the first pass.
private Map<MatrixIndexes, Integer> _index;

private DMLRuntimeException _failure;

public CachingStream(OOCStream<IndexedMatrixValue> source) {
this(source, _streamSeq.getNextID());
}
Expand All @@ -76,6 +78,22 @@ public CachingStream(OOCStream<IndexedMatrixValue> source, long streamId) {
}
} catch (InterruptedException e) {
throw new DMLRuntimeException(e);
} catch (DMLRuntimeException e) {
// Propagate failure to subscribers
_failure = e;
synchronized (this) {
notifyAll();
}

Runnable[] mSubscribers = _subscribers;
if(mSubscribers != null) {
for(Runnable mSubscriber : mSubscribers) {
try {
mSubscriber.run();
} catch (Exception ignored) {
}
}
}
}
});
}
Expand Down Expand Up @@ -103,7 +121,9 @@ private synchronized boolean fetchFromStream() throws InterruptedException {

public synchronized IndexedMatrixValue get(int idx) throws InterruptedException {
while (true) {
if (idx < _numBlocks) {
if (_failure != null)
throw _failure;
else if (idx < _numBlocks) {
IndexedMatrixValue out = OOCEvictionManager.get(_streamId, idx);

if (_index != null) // Ensure index is up to date
Expand Down
Loading
Loading