5 changes: 4 additions & 1 deletion .github/workflows/deep-learning-on-flink-unit-tests.yaml
@@ -37,7 +37,10 @@ jobs:
       - name: Build and test with maven
         env:
           TF_ON_FLINK_IP: 127.0.0.1
-        run: cd deep-learning-on-flink && mvn -DskipITs=true -B clean package
+        run: |
+          cd deep-learning-on-flink
+          mvn -DskipTests -B clean install
+          mvn -DskipITs=true -B test
       - name: Upload jars
         uses: actions/upload-artifact@v2
         with:
2 changes: 1 addition & 1 deletion deep-learning-on-flink/flink-ml-examples/pom.xml
@@ -55,7 +55,7 @@
         </dependency>
         <dependency>
             <groupId>org.apache.flink</groupId>
-            <artifactId>flink-table-planner-blink_${scala.major.version}</artifactId>
+            <artifactId>flink-table-planner_${scala.major.version}</artifactId>
             <scope>test</scope>
         </dependency>
     </dependencies>
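Context for this rename (an assumption based on the Flink 1.14 release, which this PR appears to track): the legacy planner was removed and the Blink planner became the only implementation, shipped without the -blink suffix. Table code itself needs no change; a minimal smoke check, where every name below is ours and not part of the PR:

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;

// Hypothetical smoke check, not part of the PR: with flink-table-planner
// on the test classpath, environment creation resolves to the one
// remaining (formerly "blink") planner.
public class PlannerSmokeCheck {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(
                env, EnvironmentSettings.newInstance().inStreamingMode().build());
        System.out.println(tEnv.getClass().getName());
    }
}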
@@ -130,7 +130,7 @@ private void runTable(String script) throws Exception {
     StatementSet statementSet = tableEnv.createStatementSet();
     PyTorchUtil.train(streamEnv, tableEnv, statementSet, null, pytorchConfig, null);
     statementSet.execute().getJobClient().get()
-            .getJobExecutionResult(Thread.currentThread().getContextClassLoader())
+            .getJobExecutionResult()
             .get();
 }

@@ -208,7 +208,7 @@ private void trainMnistTable(String trainPy) throws Exception {
     TFConfig tfConfig = prepareTrain(trainPy);
     TFUtils.train(flinkEnv, tableEnv, statementSet, null, tfConfig, null);
     statementSet.execute().getJobClient().get()
-            .getJobExecutionResult(Thread.currentThread().getContextClassLoader()).get();
+            .getJobExecutionResult().get();
 }

 private RestartStrategies.RestartStrategyConfiguration restartStrategy() {
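Both hunks above are the same mechanical migration: in the Flink version this PR moves to, JobClient#getJobExecutionResult no longer takes a ClassLoader argument. A minimal sketch of the resulting call pattern; the helper below is ours, not from the PR:

import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.table.api.StatementSet;

// Hypothetical helper, not part of the PR: submit a StatementSet and block
// until the job finishes, using the no-argument getJobExecutionResult().
public final class JobAwait {
    private JobAwait() {
    }

    public static JobExecutionResult await(StatementSet statementSet) throws Exception {
        return statementSet.execute()          // TableResult
                .getJobClient()                // Optional<JobClient>
                .orElseThrow(() -> new IllegalStateException("job was not submitted"))
                .getJobExecutionResult()       // ClassLoader parameter is gone
                .get();
    }
}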

Large diffs are not rendered by default.

@@ -20,15 +20,13 @@

 import com.alibaba.flink.ml.tensorflow.data.TFRecordReader;
 import com.alibaba.flink.ml.tensorflow.io.TFRExtractRowHelper;
-
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.types.Row;
-
 import org.apache.hadoop.fs.FSDataInputStream;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -40,81 +38,81 @@
 import java.util.concurrent.ThreadLocalRandom;

 public class DelayedTFRSourceFunction extends RichParallelSourceFunction<Row>
         implements ListCheckpointed<Long>, ResultTypeQueryable<Row> {

     private static Logger LOG = LoggerFactory.getLogger(DelayedTFRSourceFunction.class);

     private final String[] paths;
     private final long delayBound;
     private final RowTypeInfo outRowType;
     private final TFRExtractRowHelper extractRowHelper;
     private long offset = 0;
     private long numRead = 0;
     private volatile boolean cancelled;

     DelayedTFRSourceFunction(String[] paths, long delayBound, RowTypeInfo outRowType,
             TFRExtractRowHelper.ScalarConverter[] converters) {
         this.paths = paths;
         this.delayBound = delayBound;
         this.outRowType = outRowType;
         extractRowHelper = new TFRExtractRowHelper(outRowType, converters);
     }

     @Override
     public void open(Configuration parameters) throws Exception {
         super.open(parameters);
         if (offset != 0) {
             LOG.info("Restored from offset {}", offset);
         }
         numRead = 0;
         cancelled = false;
     }

     @Override
     public List<Long> snapshotState(long l, long l1) throws Exception {
         return Collections.singletonList(offset);
     }

     @Override
     public void restoreState(List<Long> list) throws Exception {
         offset = list.get(0);
     }

     @Override
     public void run(SourceContext<Row> sourceContext) throws Exception {
         final Object lock = sourceContext.getCheckpointLock();
         org.apache.hadoop.conf.Configuration hadoopConf = new org.apache.hadoop.conf.Configuration();
         for (String p : paths) {
             Path path = new Path(p);
             FileSystem fs = path.getFileSystem(hadoopConf);
             try (FSDataInputStream inputStream = fs.open(path)) {
                 TFRecordReader tfrReader = new TFRecordReader(inputStream, true);
                 byte[] bytes = tfrReader.read();
                 while (bytes != null) {
                     if (cancelled) {
                         return;
                     }
                     if (numRead == offset) {
                         synchronized (lock) {
                             sourceContext.collect(extractRowHelper.extract(bytes));
                             offset++;
                         }
                         Thread.sleep(ThreadLocalRandom.current().nextLong(delayBound) + 1);
                     }
                     numRead++;
                     bytes = tfrReader.read();
                 }
             }
         }
     }

     @Override
     public void cancel() {
         cancelled = true;
     }

     @Override
     public TypeInformation<Row> getProducedType() {
         return outRowType;
     }
 }
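The offset/numRead pair in run() above is what makes the source replayable: on restore, the files are re-read from the start and the first offset records are skipped rather than re-emitted. A tiny self-contained illustration of the same pattern, where everything below is ours and not from the PR:

import java.util.Arrays;
import java.util.List;

// Standalone illustration (hypothetical, not part of the PR) of the
// offset/numRead replay pattern used in DelayedTFRSourceFunction#run.
public class ReplayDemo {
    public static void main(String[] args) {
        List<String> records = Arrays.asList("a", "b", "c", "d");
        long offset = 2;   // pretend a checkpoint was taken after "a" and "b"
        long numRead = 0;
        for (String record : records) {
            if (numRead == offset) {                   // past the already-emitted prefix
                System.out.println("emit " + record);  // prints c, then d
                offset++;
            }
            numRead++;
        }
    }
}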
@@ -21,7 +21,6 @@

 import com.alibaba.flink.ml.operator.util.TypeUtil;
 import com.alibaba.flink.ml.tensorflow.io.TFRExtractRowHelper;
-
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -34,44 +33,44 @@

 public class DelayedTFRTableSourceStream implements StreamTableSource<Row> {

     private static final long DEFAULT_DELAY_BOUND = 20;

     private final String[] paths;
     private final Long delayBound;
     private final RowTypeInfo outRowType;
     private final TFRExtractRowHelper.ScalarConverter[] converters;

     private DelayedTFRTableSourceStream(String[] paths, int epochs, Long delayBound,
             RowTypeInfo outRowType, TFRExtractRowHelper.ScalarConverter[] converters) {
         this.paths = paths;
         this.delayBound = delayBound;
         this.outRowType = outRowType;
         this.converters = converters;
     }

     public DelayedTFRTableSourceStream(String[] paths, int epochs, RowTypeInfo outRowType,
             TFRExtractRowHelper.ScalarConverter[] converters) {
         this(paths, epochs, DEFAULT_DELAY_BOUND, outRowType, converters);
     }

     @Override
     public TypeInformation<Row> getReturnType() {
         return outRowType;
     }

     @Override
     public TableSchema getTableSchema() {
-        return TypeUtil.rowTypeInfoToSchema(outRowType);
+        return TypeUtil.rowTypeInfoToTableSchema(outRowType);
     }

     @Override
     public String explainSource() {
         return "Delayed TFRecord source " + Arrays.toString(paths);
     }

     @Override
     public DataStream<Row> getDataStream(StreamExecutionEnvironment execEnv) {
         return execEnv.addSource(new DelayedTFRSourceFunction(paths, delayBound, outRowType, converters))
                 .setParallelism(1).name(explainSource());
     }
 }
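A hedged usage sketch: the wrapper above is the public entry point, since DelayedTFRSourceFunction's constructor is package-private, and the epochs argument is accepted but unused by the private constructor shown. Every name below is ours, not from the PR:

import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.types.Row;

import com.alibaba.flink.ml.tensorflow.io.TFRExtractRowHelper;

// Hypothetical usage, not part of the PR: paths, row type, and converters
// are supplied by the caller; the returned stream runs at parallelism 1
// and is named via explainSource().
public class DelayedSourceExample {
    public static DataStream<Row> delayedRows(
            StreamExecutionEnvironment env,
            String[] tfrecordPaths,
            RowTypeInfo rowType,
            TFRExtractRowHelper.ScalarConverter[] converters) {
        DelayedTFRTableSourceStream source =
                new DelayedTFRTableSourceStream(tfrecordPaths, 1, rowType, converters);
        return source.getDataStream(env);
    }
}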