-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtree_booster.jl
77 lines (55 loc) · 2.19 KB
/
tree_booster.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# # Using IterationControl.jl to train a tree-booster on the iris data set
# In this demonstration we show how to use the controls in
# [IterationControl.jl](https://github.com/ablaom/IterationControl.jl)
# with an iterative
# [MLJ](https://github.com/alan-turing-institute/MLJ.jl) model, using
# our bare hands. (MLJ will ultimately provide its own canned
# `IteratedModel` wrapper to make this more convenient and
# compositional.)
using Pkg
Pkg.activate(@__DIR__)  # use the project environment in this script's own directory
Pkg.instantiate()       # install that environment's dependencies if they are missing
import MLJ
using IterationControl
using Statistics
using Random
Random.seed!(123)  # seed the global RNG; presumably makes the shuffled split reproducible
MLJ.color_off()    # presumably disables coloured output so generated docs stay plain
# Loading some data and splitting observation indices into test/train:
X, y = MLJ.@load_iris;
train, test = MLJ.partition(eachindex(y), 0.7, shuffle=true)  # ~70% train, rest test
# Import a model type:
Booster = MLJ.@load EvoTreeClassifier verbosity=0  # `Booster` is a model *type*
# Note that in MLJ a "model" is just a container for
# hyper-parameters. The objects we will iterate here are MLJ
# [*machines*](https://alan-turing-institute.github.io/MLJ.jl/dev/machines/);
# these bind the model to data (the train/test rows are selected when
# fitting and predicting) and, in the case of iterative models, can be
# retrained using a warm restart.
# Creating a machine:
mach = MLJ.machine(Booster(nrounds=1), X, y);  # NOTE(review): 1st train!(mach, n) fits nrounds=1+n
# Lifting MLJ's `fit!(::Machine)` method to `IterationControl.train!`:
function IterationControl.train!(mach::MLJ.Machine{<:Booster}, n::Int)
    # Raise the boosting budget by `n` rounds, then refit on the training
    # rows only. Since only `nrounds` changes between calls, MLJ can
    # presumably warm-restart instead of refitting from scratch.
    booster = mach.model
    booster.nrounds += n
    return MLJ.fit!(mach, rows=train, verbosity=0)
end
# Lifting the out-of-sample loss:
function IterationControl.loss(mach::MLJ.Machine{<:Booster})
    # Score predictions on the held-out `test` rows (disjoint from the
    # `train` rows used in `train!`), giving an out-of-sample loss.
    yhat = MLJ.predict(mach, rows=test)
    # `mean` reduces the log loss to the single number the controls consume.
    return MLJ.log_loss(yhat, y[test]) |> mean
end
# Iterating with controls:
function logging(mach)
    # Message reported by the `Info` control after each control cycle.
    current_loss = IterationControl.loss(mach)
    return "loss: $current_loss"
end
IterationControl.train!(mach,
Step(5),  # presumably: 5 more boosting rounds per control cycle (see train! above)
GL(),  # presumably: stop once the generalization loss grows too large
Info(logging))  # report `logging(mach)` as training proceeds
# Continuing iteration with a different stopping criterion:
IterationControl.train!(mach,  # same `mach`: training resumes from its current nrounds
Step(5),
NumberLimit(10),  # presumably: stop after at most 10 control cycles
Info(logging))
using Literate # tooling only: regenerate the docs from this script #src
Literate.markdown(@__FILE__, @__DIR__, execute=false) # write markdown, code not run #src
Literate.notebook(@__FILE__, @__DIR__, execute=true) # write notebook, cells executed #src