Skip to content

Commit 421a3c0

Browse files
committed
android: added TrainingJob which will run training jobs in the background for the model
1 parent ffcabf6 commit 421a3c0

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package ai.luk;
2+
3+
import android.app.job.JobInfo;
4+
import android.app.job.JobScheduler;
5+
import android.app.job.JobService;
6+
import android.app.job.JobParameters;
7+
import android.content.Context;
8+
import android.content.ComponentName;
9+
10+
import ai.luk.ModelType;
11+
12+
/*
13+
To use this training job you'll need to extend it and implement the domain,
14+
modelType and dataDir methods. Your code needs to call the schedule() method at
15+
least once.
16+
You'll also need to define the service in your AndroidManifest.xml file.
17+
18+
<service
19+
android:name=".yourpackage.TrainingJob"
20+
android:permission="android.permission.BIND_JOB_SERVICE"
21+
android:permission="android.permission.RECEIVE_BOOT_COMPLETED"
22+
android:exported="true"/>
23+
24+
*/
25+
public abstract class TrainingJob extends JobService {
26+
private ModelType mt;
27+
28+
public void schedule() {
29+
JobScheduler jobScheduler =
30+
(JobScheduler) getSystemService(Context.JOB_SCHEDULER_SERVICE);
31+
int id = jobId();
32+
33+
// Check if the job is already scheduled and reschedule it in case the job
34+
// has changed.
35+
for (JobInfo ji : jobScheduler.getAllPendingJobs()) {
36+
if (ji.getId() == id) {
37+
jobScheduler.cancel(id);
38+
}
39+
}
40+
41+
jobScheduler.schedule(new JobInfo.Builder(id,
42+
new ComponentName(this, this.getClass()))
43+
.setRequiredNetworkType(JobInfo.NETWORK_TYPE_UNMETERED)
44+
.setPersisted(true)
45+
.setRequiresCharging(true)
46+
.setRequiresDeviceIdle(true)
47+
.setPeriodic(intervalMillis())
48+
.build());
49+
}
50+
51+
@Override
52+
public boolean onStartJob(final JobParameters params) {
53+
try {
54+
mt = new ModelType(domain(), modelType(), dataDir());
55+
mt.startTraining();
56+
} catch (Exception e) {
57+
throw new RuntimeException(e);
58+
}
59+
return true;
60+
}
61+
62+
@Override
63+
public boolean onStopJob(final JobParameters params) {
64+
if (mt != null) {
65+
try {
66+
mt.stopTraining();
67+
} catch (Exception e) {
68+
throw new RuntimeException(e);
69+
}
70+
}
71+
return true;
72+
}
73+
74+
public int jobId() {
75+
return ("training:"+domain()+"/"+modelType()+"/"+dataDir()).hashCode();
76+
}
77+
78+
// intervalMillis is the number of milliseconds that the training job will run
79+
// at most once per.
80+
public long intervalMillis() {
81+
// default to 6 hours.
82+
return 6 * 60 * 60 * 1000;
83+
}
84+
85+
public abstract String domain();
86+
public abstract String modelType();
87+
public abstract String dataDir();
88+
}

0 commit comments

Comments
 (0)