diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index cfd6630f274..7b912bd39ed 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -297,6 +297,7 @@ public Lop constructLops() case TEE: l = new Tee(getInput(0).constructLops(), getDataType(), getValueType()); + setOutputDimensions(l); break; default: @@ -488,7 +489,7 @@ else if ( getInput().get(0).areDimsBelowThreshold() ) @Override public void refreshSizeInformation() { - if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE ) { + if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE || _op == OpOpData.TEE ) { Hop input1 = getInput().get(0); setDim1(input1.getDim1()); setDim2(input1.getDim2()); diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java b/src/main/java/org/apache/sysds/lops/DataGen.java index d8ac1b8a4a7..93237f9b7fc 100644 --- a/src/main/java/org/apache/sysds/lops/DataGen.java +++ b/src/main/java/org/apache/sysds/lops/DataGen.java @@ -199,7 +199,7 @@ private String getRandInstructionCPSpark(String output) sb.append(iLop == null ? 
"" : iLop.prepScalarLabel()); sb.append(OPERAND_DELIMITOR); - if( getExecType() == ExecType.CP ) { + if( getExecType() == ExecType.CP || getExecType() == ExecType.OOC ) { //append degree of parallelism sb.append( _numThreads ); sb.append( OPERAND_DELIMITOR ); diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java index 0ac36a37e4e..0d2e79f83a8 100644 --- a/src/main/java/org/apache/sysds/lops/Transform.java +++ b/src/main/java/org/apache/sysds/lops/Transform.java @@ -179,7 +179,7 @@ private String getInstructions(String input1, int numInputs, String output) { sb.append( OPERAND_DELIMITOR ); sb.append( this.prepOutputOperand(output)); - if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED) + if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED || getExecType()==ExecType.OOC) && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) { sb.append( OPERAND_DELIMITOR ); sb.append( _numThreads ); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java index 496bca87642..8191040eb18 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.lang.ref.SoftReference; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.Future; @@ -528,7 +529,12 @@ protected MatrixBlock readBlobFromRDD(RDDObject rdd, MutableBoolean writeStatus) @Override protected MatrixBlock readBlobFromStream(LocalTaskQueue stream) throws IOException { - MatrixBlock ret = new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false); + boolean dimsUnknown = getNumRows() < 0 || getNumColumns() < 0; + int nrows = (int)getNumRows(); + int ncols = (int)getNumColumns(); + 
MatrixBlock ret = dimsUnknown ? null : new MatrixBlock(nrows, ncols, false); + // TODO if stream is CachingStream, block parts might be evicted resulting in null pointer exceptions + List<IndexedMatrixValue> blockCache = dimsUnknown ? new ArrayList<>() : null; IndexedMatrixValue tmp = null; try { int blen = getBlocksize(), lnnz = 0; @@ -537,12 +543,31 @@ protected MatrixBlock readBlobFromStream(LocalTaskQueue stre final int row_offset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen; final int col_offset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen; - // Add the values of this block into the output block. - ((MatrixBlock)tmp.getValue()).putInto(ret, row_offset, col_offset, true); + if (dimsUnknown) { + nrows = Math.max(nrows, row_offset + tmp.getValue().getNumRows()); + ncols = Math.max(ncols, col_offset + tmp.getValue().getNumColumns()); + blockCache.add(tmp); + } else { + // Add the values of this block into the output block. + ((MatrixBlock) tmp.getValue()).putInto(ret, row_offset, col_offset, true); + } // incremental maintenance nnz lnnz += tmp.getValue().getNonZeros(); } + + if (dimsUnknown) { + ret = new MatrixBlock(nrows, ncols, false); + + for (IndexedMatrixValue _tmp : blockCache) { + // compute row/column block offsets + final int row_offset = (int) (_tmp.getIndexes().getRowIndex() - 1) * blen; + final int col_offset = (int) (_tmp.getIndexes().getColumnIndex() - 1) * blen; + + ((MatrixBlock) _tmp.getValue()).putInto(ret, row_offset, col_offset, true); + } + } + ret.setNonZeros(lnnz); } catch(Exception ex) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index f23ad6d67a6..feefe5f63d6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -36,7 +36,7 @@ import 
org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; -import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; public class OOCInstructionParser extends InstructionParser { @@ -74,7 +74,7 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str case MMTSJ: return TSMMOOCInstruction.parseInstruction(str); case Reorg: - return TransposeOOCInstruction.parseInstruction(str); + return ReorgOOCInstruction.parseInstruction(str); case Tee: return TeeOOCInstruction.parseInstruction(str); case CentralMoment: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index 148592b6a99..01c7a525bcd 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.instructions.ooc; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -67,12 +69,55 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) { OOCStream qOut = new SubscribableTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); - joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp1.getIndexes(), - 
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue())); - return tmpOut; - }, IndexedMatrixValue::getIndexes); + if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0) + throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions."); + + boolean isColBroadcast = m1.getNumColumns() > 1 && m2.getNumColumns() == 1; + boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1; + + if (isColBroadcast && !isRowBroadcast) { + final long maxProcessesPerBroadcast = (long) Math.ceil((double) m1.getNumColumns() / m1.getBlocksize()); // ceil: a partial trailing column block is one more consumer of the broadcast block + + broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp1.getIndexes(), + tmp1.getValue().binaryOperations((BinaryOperator)_optr, b.getValue().getValue(), tmpOut.getValue())); + + if (b.incrProcessCtrAndGet() >= maxProcessesPerBroadcast) + b.release(); + + return tmpOut; + }, tmp -> tmp.getIndexes().getRowIndex()); + } + else if (isRowBroadcast && !isColBroadcast) { + final long maxProcessesPerBroadcast = (long) Math.ceil((double) m1.getNumRows() / m1.getBlocksize()); // ceil: a partial trailing row block is one more consumer of the broadcast block + + broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp1.getIndexes(), + tmp1.getValue().binaryOperations((BinaryOperator)_optr, b.getValue().getValue(), tmpOut.getValue())); + + if (b.incrProcessCtrAndGet() >= maxProcessesPerBroadcast) + b.release(); + + return tmpOut; + }, tmp -> tmp.getIndexes().getColumnIndex()); + } + else { + if (m1.getNumColumns() != m2.getNumColumns() || m1.getNumRows() != m2.getNumRows()) + throw new NotImplementedException("Invalid dimensions for matrix-matrix binary op: " + + m1.getNumRows() + "x" + m1.getNumColumns() + " <=> " + + m2.getNumRows() + "x" + m2.getNumColumns()); + + joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp1.getIndexes(), + 
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue())); + return tmpOut; + }, IndexedMatrixValue::getIndexes); + } + + } protected void processScalarMatrixInstruction(ExecutionContext ec) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index b74f7ed5e13..d7c80e4de3c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -52,6 +52,8 @@ public class CachingStream implements OOCStreamable { private boolean _cacheInProgress = true; // caching in progress, in the first pass. private Map _index; + private DMLRuntimeException _failure; + public CachingStream(OOCStream source) { this(source, _streamSeq.getNextID()); } @@ -76,6 +78,22 @@ public CachingStream(OOCStream source, long streamId) { } } catch (InterruptedException e) { throw new DMLRuntimeException(e); + } catch (DMLRuntimeException e) { + // Propagate failure to subscribers + _failure = e; + synchronized (this) { + notifyAll(); + } + + Runnable[] mSubscribers = _subscribers; + if(mSubscribers != null) { + for(Runnable mSubscriber : mSubscribers) { + try { + mSubscriber.run(); + } catch (Exception ignored) { + } + } + } } }); } @@ -103,7 +121,9 @@ private synchronized boolean fetchFromStream() throws InterruptedException { public synchronized IndexedMatrixValue get(int idx) throws InterruptedException { while (true) { - if (idx < _numBlocks) { + if (_failure != null) + throw _failure; + else if (idx < _numBlocks) { IndexedMatrixValue out = OOCEvictionManager.get(_streamId, idx); if (_index != null) // Ensure index is up to date diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java index 355c8ddea1e..f9162a47f51 100644 
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java @@ -20,8 +20,12 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -30,25 +34,88 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.runtime.util.UtilFunctions; public class DataGenOOCInstruction extends UnaryOOCInstruction { + private static final Log LOG = LogFactory.getLog(DataGenOOCInstruction.class.getName()); + private Types.OpOpDG method; + private final CPOperand rows, cols, dims; private final int blen; - private Types.OpOpDG method; + private boolean minMaxAreDoubles; + private final String minValueStr, maxValueStr; + private final double minValue, maxValue, sparsity; + private final String pdf, pdfParams, frame_data, schema; + private final long seed; + private Long runtimeSeed; // sequence specific attributes private final CPOperand seq_from, seq_to, seq_incr; - public DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, int blen, CPOperand seqFrom, - CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) { - super(OOCType.Rand, op, in, out, opcode, istr); - this.blen = 
blen; + // sample specific attributes + private final boolean replace; + private final int numThreads; + + // seed positions + private static final int SEED_POSITION_RAND = 8; + private static final int SEED_POSITION_SAMPLE = 4; + + private DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, + CPOperand dims, int blen, String minValue, String maxValue, double sparsity, long seed, + String probabilityDensityFunction, String pdfParams, int k, CPOperand seqFrom, CPOperand seqTo, + CPOperand seqIncr, boolean replace, String data, String schema, String opcode, String istr) { + super(OOCInstruction.OOCType.Rand, op, in, out, opcode, istr); this.method = mthd; + this.rows = rows; + this.cols = cols; + this.dims = dims; + this.blen = blen; + this.minValueStr = minValue; + this.maxValueStr = maxValue; + double minDouble, maxDouble; + try { + minDouble = !minValue.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Double.valueOf(minValue) : -1; + maxDouble = !maxValue.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? 
Double.valueOf(maxValue) : -1; + minMaxAreDoubles = true; + } + catch(NumberFormatException e) { + // Non double values + if(!minValueStr.equals(maxValueStr)) { + throw new DMLRuntimeException( + "Rand instruction does not support " + "non numeric Datatypes for range initializations."); + } + minDouble = -1; + maxDouble = -1; + minMaxAreDoubles = false; + } + this.minValue = minDouble; + this.maxValue = maxDouble; + this.sparsity = sparsity; + this.seed = seed; + this.pdf = probabilityDensityFunction; + this.pdfParams = pdfParams; + this.numThreads = k; this.seq_from = seqFrom; this.seq_to = seqTo; this.seq_incr = seqIncr; + this.replace = replace; + this.frame_data = data; + this.schema = schema; + } + + private DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, + CPOperand dims, int blen, CPOperand seqFrom, CPOperand seqTo, CPOperand seqIncr, String opcode, String istr) { + this(op, mthd, in, out, rows, cols, dims, blen, "0", "1", 1.0, -1, null, null, 1, seqFrom, seqTo, seqIncr, + false, null, null, opcode, istr); + } + + private DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols, + CPOperand dims, int blen, String minValue, String maxValue, double sparsity, long seed, + String probabilityDensityFunction, String pdfParams, int k, String opcode, String istr) { + this(op, mthd, in, out, rows, cols, dims, blen, minValue, maxValue, sparsity, seed, probabilityDensityFunction, + pdfParams, k, null, null, null, false, null, null, opcode, istr); } public static DataGenOOCInstruction parseInstruction(String str) { @@ -56,7 +123,11 @@ public static DataGenOOCInstruction parseInstruction(String str) { String[] s = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = s[0]; - if(opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString())) { + if(opcode.equalsIgnoreCase(Opcodes.RANDOM.toString())) { + method = Types.OpOpDG.RAND; + 
InstructionUtils.checkNumFields(s, 10, 11); + } + else if(opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString())) { method = Types.OpOpDG.SEQ; // 8 operands: rows, cols, blen, from, to, incr, outvar InstructionUtils.checkNumFields(s, 7); @@ -67,13 +138,37 @@ public static DataGenOOCInstruction parseInstruction(String str) { CPOperand out = new CPOperand(s[s.length - 1]); UnaryOperator op = null; - if(method == Types.OpOpDG.SEQ) { + if(method == Types.OpOpDG.RAND) { + int missing; // number of missing params (row & cols or dims) + CPOperand rows = null, cols = null, dims = null; + if(s.length == 12) { + missing = 1; + rows = new CPOperand(s[1]); + cols = new CPOperand(s[2]); + } + else { + missing = 2; + dims = new CPOperand(s[1]); + } + int blen = Integer.parseInt(s[4 - missing]); + double sparsity = !s[7 - missing].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Double + .parseDouble(s[7 - missing]) : -1; + long seed = !s[SEED_POSITION_RAND - missing].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Long + .parseLong(s[SEED_POSITION_RAND - missing]) : -1; + String pdf = s[9 - missing]; + String pdfParams = !s[10 - missing].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? 
s[10 - missing] : null; + int k = Integer.parseInt(s[11 - missing]); + + return new DataGenOOCInstruction(op, method, null, out, rows, cols, dims, blen, s[5 - missing], + s[6 - missing], sparsity, seed, pdf, pdfParams, k, opcode, str); + } + else if(method == Types.OpOpDG.SEQ) { int blen = Integer.parseInt(s[3]); CPOperand from = new CPOperand(s[4]); CPOperand to = new CPOperand(s[5]); CPOperand incr = new CPOperand(s[6]); - return new DataGenOOCInstruction(op, method, null, out, blen, from, to, incr, opcode, str); + return new DataGenOOCInstruction(op, method, null, out, null, null, null, blen, from, to, incr, opcode, str); } else throw new NotImplementedException(); @@ -84,7 +179,45 @@ public void processInstruction(ExecutionContext ec) { final OOCStream qOut = createWritableStream(); // process specific datagen operator - if(method == Types.OpOpDG.SEQ) { + if (method == Types.OpOpDG.RAND) { + if (!output.isMatrix()) + throw new NotImplementedException(); + + long lSeed = generateSeed(); + long lrows = ec.getScalarInput(rows).getLongValue(); + long lcols = ec.getScalarInput(cols).getLongValue(); + checkValidDimensions(lrows, lcols); + + if (!pdf.equalsIgnoreCase("uniform") || minValue != maxValue) + throw new NotImplementedException(); // TODO modified version of rng as in LibMatrixDatagen to handle blocks independently + + OOCStream qIn = createWritableStream(); + int nrb = (int)((lrows-1) / blen)+1; + int ncb = (int)((lcols-1) / blen)+1; + + for (int row = 0; row < nrb; row++) + for (int col = 0; col < ncb; col++) + qIn.enqueue(new MatrixIndexes(row+1, col+1)); + + qIn.closeInput(); + + if(sparsity == 0.0 && lrows < Integer.MAX_VALUE && lcols < Integer.MAX_VALUE) { + mapOOC(qIn, qOut, idx -> { + long rlen = Math.min(blen, lrows - (idx.getRowIndex()-1) * blen); + long clen = Math.min(blen, lcols - (idx.getColumnIndex()-1) * blen); + return new IndexedMatrixValue(idx, new MatrixBlock((int)rlen, (int)clen, 0.0)); + }); + return; + } + + mapOOC(qIn, qOut, idx -> { 
+ long rlen = Math.min(blen, lrows - (idx.getRowIndex()-1) * blen); + long clen = Math.min(blen, lcols - (idx.getColumnIndex()-1) * blen); + MatrixBlock mout = MatrixBlock.randOperations(getGenerator(rlen, clen), lSeed); + return new IndexedMatrixValue(idx, mout); + }); + } + else if(method == Types.OpOpDG.SEQ) { double lfrom = ec.getScalarInput(seq_from).getDoubleValue(); double lto = ec.getScalarInput(seq_to).getDoubleValue(); double lincr = ec.getScalarInput(seq_incr).getDoubleValue(); @@ -133,4 +266,39 @@ public void processInstruction(ExecutionContext ec) { ec.getMatrixObject(output).setStreamHandle(qOut); } + + + + private long generateSeed() { + // generate pseudo-random seed (because not specified) + long lSeed = seed; // seed per invocation + if(lSeed == DataGenOp.UNSPECIFIED_SEED) { + if(runtimeSeed == null) + runtimeSeed = DataGenOp.generateRandomSeed(); + lSeed = runtimeSeed; + } + + if(LOG.isTraceEnabled()) + LOG.trace("Process DataGenOOCInstruction rand with seed = " + lSeed + "."); + + return lSeed; + } + + private static void checkValidDimensions(long rows, long cols) { + // check valid for integer dimensions (we cannot even represent empty blocks with larger dimensions) + if(rows > Integer.MAX_VALUE || cols > Integer.MAX_VALUE) + throw new DMLRuntimeException("DataGenOOCInstruction does not " + + "support dimensions larger than integer: rows=" + rows + ", cols=" + cols + "."); + } + + private RandomMatrixGenerator getGenerator(long lrows, long lcols) { + return LibMatrixDatagen.createRandomMatrixGenerator(pdf, + (int) lrows, + (int) lcols, + blen, + sparsity, + minValue, + maxValue, + pdfParams); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index eb7cf55f51d..ca13cfdb2c3 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -38,8 +38,10 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -150,6 +152,100 @@ protected CompletableFuture mapOOC(OOCStream qIn, OOCStream q }, qOut::closeInput); } + protected CompletableFuture broadcastJoinOOC(OOCStream qIn, OOCStream broadcast, OOCStream qOut, BiFunction mapper, Function on) { + addInStream(qIn, broadcast); + addOutStream(qOut); + + boolean explicitLeftCaching = !qIn.hasStreamCache(); + boolean explicitRightCaching = !broadcast.hasStreamCache(); + CachingStream leftCache = explicitLeftCaching ? new CachingStream(new SubscribableTaskQueue<>()) : qIn.getStreamCache(); + CachingStream rightCache = explicitRightCaching ? 
new CachingStream(new SubscribableTaskQueue<>()) : broadcast.getStreamCache(); + leftCache.activateIndexing(); + rightCache.activateIndexing(); + + Map> availableLeftInput = new ConcurrentHashMap<>(); + Map availableBroadcastInput = new ConcurrentHashMap<>(); + + return submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { + P key = on.apply(tmp); + + if (i == 0) { // qIn stream + BroadcastedElement b = availableBroadcastInput.get(key); + + if (b == null) { + // Matching broadcast element is not available -> cache element + if (explicitLeftCaching) + leftCache.getWriteStream().enqueue(tmp); + + availableLeftInput.compute(key, (k, v) -> { + if (v == null) + v = new ArrayList<>(); + v.add(tmp.getIndexes()); + return v; + }); + } else { + // Directly emit + qOut.enqueue(mapper.apply(tmp, b)); + + if (b.canRelease()) + availableBroadcastInput.remove(key); + } + } else { // broadcast stream + if (explicitRightCaching) + rightCache.getWriteStream().enqueue(tmp); + + BroadcastedElement b = new BroadcastedElement(tmp.getIndexes()); + availableBroadcastInput.put(key, b); + + List queued = availableLeftInput.remove(key); + + if (queued != null) { + for(MatrixIndexes idx : queued) { + b.value = rightCache.findCached(b.idx); + qOut.enqueue(mapper.apply(leftCache.findCached(idx), b)); + b.value = null; + } + } + + if (b.canRelease()) + availableBroadcastInput.remove(key); + } + }, qOut::closeInput); + } + + protected static class BroadcastedElement { + private final MatrixIndexes idx; + private IndexedMatrixValue value; + private boolean release; + private int processCtr; + + public BroadcastedElement(MatrixIndexes idx) { + this.idx = idx; + this.release = false; + } + + public synchronized void release() { + release = true; + } + + public synchronized boolean canRelease() { + return release; + } + + public synchronized int incrProcessCtrAndGet() { + processCtr++; + return processCtr; + } + + public MatrixIndexes getIndex() { + return idx; + } + + public IndexedMatrixValue 
getValue() { + return value; + } + }; + protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function on) { return joinOOC(qIn1, qIn2, qOut, mapper, on, on); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java new file mode 100644 index 00000000000..a87a3498329 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.SortIndex; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.util.DataConverter; + +public class ReorgOOCInstruction extends ComputationOOCInstruction { + // sort-specific attributes (to enable variable attributes) + private final CPOperand _col; + private final CPOperand _desc; + private final CPOperand _ixret; + + protected ReorgOOCInstruction(ReorgOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { + this(op, in1, out, null, null, null, opcode, istr); + } + + private ReorgOOCInstruction(Operator op, CPOperand in, CPOperand out, CPOperand col, CPOperand desc, CPOperand ixret, + String opcode, String istr) { + super(OOCType.Reorg, op, in, out, opcode, istr); + _col = col; + _desc = desc; + _ixret = ixret; + } + + public static ReorgOOCInstruction parseInstruction(String str) { + CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN); + CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN); + + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + + if 
(opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) { + InstructionUtils.checkNumFields(str, 2, 3); + in.split(parts[1]); + out.split(parts[2]); + + ReorgOperator reorg = new ReorgOperator(SwapIndex.getSwapIndexFnObject()); + return new ReorgOOCInstruction(reorg, in, out, opcode, str); + } + else if (opcode.equalsIgnoreCase(Opcodes.SORT.toString())) { + InstructionUtils.checkNumFields(str, 5,6); + in.split(parts[1]); + out.split(parts[5]); + CPOperand col = new CPOperand(parts[2]); + CPOperand desc = new CPOperand(parts[3]); + CPOperand ixret = new CPOperand(parts[4]); + int k = parts.length > 6 ? Integer.parseInt(parts[6]) : 1; // thread count is only present in the 6-field form; default to one thread otherwise + return new ReorgOOCInstruction(new ReorgOperator(new SortIndex(1,false,false), k), + in, out, col, desc, ixret, opcode, str); + } + else + throw new NotImplementedException(); + } + + public void processInstruction( ExecutionContext ec ) { + // Create thread and process the transpose operation + MatrixObject min = ec.getMatrixObject(input1); + + + ReorgOperator r_op = (ReorgOperator) _optr; + + if(r_op.fn instanceof SortIndex) { + //additional attributes for sort + int[] cols = _col.getDataType().isMatrix() ? 
DataConverter.convertToIntVector(ec.getMatrixInput(_col.getName())) : + new int[]{(int)ec.getScalarInput(_col).getLongValue()}; + boolean desc = ec.getScalarInput(_desc).getBooleanValue(); + boolean ixret = ec.getScalarInput(_ixret).getBooleanValue(); + r_op = r_op.setFn(new SortIndex(cols, desc, ixret)); + + // For now, we reuse the CP instruction + // In future, we could optimize by building the permutation and streaming blocks column by column + MatrixBlock matBlock = min.acquireRead(); + MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); + if (_col.getDataType().isMatrix()) + ec.releaseMatrixInput(_col.getName()); + ec.releaseMatrixInput(input1.getName()); + ec.setMatrixOutput(output.getName(), soresBlock); + } else if (r_op.fn instanceof SwapIndex) { + OOCStream qIn = min.getStreamHandle(); + OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + // Transpose operation + mapOOC(qIn, qOut, tmp -> { + MatrixBlock inBlock = (MatrixBlock) tmp.getValue(); + long oldRowIdx = tmp.getIndexes().getRowIndex(); + long oldColIdx = tmp.getIndexes().getColumnIndex(); + + MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); + return new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock); + }); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java index 9040c369a24..c0d430fa703 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java @@ -22,7 +22,6 @@ import org.apache.sysds.common.Opcodes; import org.apache.sysds.lops.MMTSJ; import org.apache.sysds.lops.MMTSJ.MMTSJType; -import org.apache.sysds.runtime.DMLRuntimeException; import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; @@ -77,18 +76,20 @@ public void processInstruction( ExecutionContext ec ) { } int dim = _type.isLeft() ? nCols : nRows; - MatrixBlock resultBlock = new MatrixBlock(dim, dim, false); - try { - IndexedMatrixValue tmp = null; - // aggregate partial tsmm outputs into result as inputs stream in - while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock partialResult = ((MatrixBlock) tmp.getValue()) - .transposeSelfMatrixMultOperations(new MatrixBlock(), _type); - resultBlock.binaryOperationsInPlace(plus, partialResult); - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + MatrixBlock resultBlock = null; + + OOCStream tmpStream = createWritableStream(); + + mapOOC(qIn, tmpStream, + tmp -> ((MatrixBlock) tmp.getValue()) + .transposeSelfMatrixMultOperations(new MatrixBlock(), _type)); + + MatrixBlock tmp; + while ((tmp = tmpStream.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + if (resultBlock == null) + resultBlock = tmp; + else + resultBlock.binaryOperationsInPlace(plus, tmp); } ec.setMatrixOutput(output.getName(), resultBlock); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java deleted file mode 100644 index 6558145ec21..00000000000 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.runtime.instructions.ooc; - -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.functionobjects.SwapIndex; -import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.cp.CPOperand; -import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixIndexes; -import org.apache.sysds.runtime.matrix.operators.ReorgOperator; - -public class TransposeOOCInstruction extends ComputationOOCInstruction { - - protected TransposeOOCInstruction(OOCType type, ReorgOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { - super(type, op, in1, out, opcode, istr); - - } - - public static TransposeOOCInstruction parseInstruction(String str) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - InstructionUtils.checkNumFields(parts, 2); - String opcode = parts[0]; - CPOperand in1 = new CPOperand(parts[1]); - CPOperand out = new CPOperand(parts[2]); - - ReorgOperator reorg = new ReorgOperator(SwapIndex.getSwapIndexFnObject()); - return new TransposeOOCInstruction(OOCType.Reorg, reorg, in1, out, opcode, str); - } - - public void 
processInstruction( ExecutionContext ec ) { - - // Create thread and process the transpose operation - MatrixObject min = ec.getMatrixObject(input1); - OOCStream qIn = min.getStreamHandle(); - OOCStream qOut = createWritableStream(); - ec.getMatrixObject(output).setStreamHandle(qOut); - - mapOOC(qIn, qOut, tmp -> { - MatrixBlock inBlock = (MatrixBlock) tmp.getValue(); - long oldRowIdx = tmp.getIndexes().getRowIndex(); - long oldColIdx = tmp.getIndexes().getColumnIndex(); - - MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); - return new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock); - }); - } -} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java new file mode 100644 index 00000000000..40430aa49f9 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class RandTest extends AutomatedTestBase { + private final static String TEST_NAME_1 = "Rand1"; + private final static String TEST_NAME_2 = "Rand2"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + RandTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1500; + private final static int cols = 1200; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_1); + addTestConfiguration(TEST_NAME_1, config); + TestConfiguration config2 = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2); + addTestConfiguration(TEST_NAME_2, config2); + } + + // Actual rand operation not yet supported, so this test stays disabled + /*@Test + public void testRand() { + runRandTest(TEST_NAME_1); + }*/ + + @Test + public void testConstInit() { + runRandTest(TEST_NAME_2); + } + + private void runRandTest(String TEST_NAME) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)}; + + runTest(true, false, null, 
-1); + + //check that the OOC rand instruction was executed + Assert.assertTrue("OOC wasn't used for rand", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.RANDOM)); + + //compare results + + // rerun without ooc flag to produce the reference output + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SortTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SortTest.java new file mode 100644 index 00000000000..61fc35a0b52 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/SortTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class SortTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "Sort"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + SortTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String OUTPUT_NAME = "res"; + + private final static int maxVal = 7; + private final static double sparsity1 = 1; + private final static double sparsity2 = 0.05; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testSortDenseMatrix() { + runSortTest(1500, 800, false); + } + + @Test + public void testSortSparseMatrix() { + runSortTest(1500, 800, true); + } + + @Test + public void testSortDenseVector() { + runSortTest(1500, 1, false); + } + + @Test + public void testSortSparseVector() { + runSortTest(1500, 1, true); + } + + private void runSortTest(int rows, int cols, boolean sparse) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + 
getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as a double array + double[][] X_data = getRandomMatrix(rows, cols, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); + + // 2. Convert the double array to a MatrixBlock object + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix X and its metadata file in binary format + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, false, null, -1); + + //check that the OOC sort instruction was executed + Assert.assertTrue("OOC wasn't used for sort", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.SORT)); + + //compare results + + // rerun without ooc flag to produce the reference output + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java index 877a0dd0b17..b1f35e4a501 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java @@ -26,7 +26,7 @@ import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; -import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction; import org.apache.sysds.runtime.io.MatrixWriter; import org.apache.sysds.runtime.io.MatrixWriterFactory; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -86,7 +86,7 @@ public void runRawInstructionSequenceTest() { VariableCPInstruction createOut = VariableCPInstruction.parseInstruction( "CP°createvar°_mVar1°" + input("tmp1") + "°true°MATRIX°binary°" + rows + "°" + cols + "°1000°700147°copy"); - TransposeOOCInstruction oocTranspose = TransposeOOCInstruction.parseInstruction( + ReorgOOCInstruction oocTranspose = ReorgOOCInstruction.parseInstruction( "OOC°r'°_mVar0·MATRIX·FP64°_mVar1·MATRIX·FP64"); VariableCPInstruction createOut2 = VariableCPInstruction.parseInstruction( "CP°createvar°_mVar2°" + input("tmp2") + "°true°MATRIX°binary°" + rows + "°" + cols + diff --git a/src/test/scripts/functions/ooc/Rand1.dml b/src/test/scripts/functions/ooc/Rand1.dml new file mode 100644 index 00000000000..2861f294620 --- /dev/null +++ b/src/test/scripts/functions/ooc/Rand1.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +res = rand(rows=1500, cols=1200, min=-1, max=1); + +write(res, $2, format="binary"); diff --git a/src/test/scripts/functions/ooc/Rand2.dml b/src/test/scripts/functions/ooc/Rand2.dml new file mode 100644 index 00000000000..033632edb00 --- /dev/null +++ b/src/test/scripts/functions/ooc/Rand2.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +res = matrix(1, rows=1500, cols=1200); + +write(res, $2, format="binary"); diff --git a/src/test/scripts/functions/ooc/Sort.dml b/src/test/scripts/functions/ooc/Sort.dml new file mode 100644 index 00000000000..30ccfd03e99 --- /dev/null +++ b/src/test/scripts/functions/ooc/Sort.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); + +res = order(target=X, by=1); + +write(res, $2, format="binary");