Skip to content

Commit 9993e00

Browse files
authored
Merge pull request #24 from lorenzoh/OneHot-MaskMulti
Add one hot encoding for `MaskMulti`s
2 parents b2f1c11 + 46fec33 commit 9993e00

File tree

4 files changed

+62
-1
lines changed

4 files changed

+62
-1
lines changed

docs/literate/preprocessing.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ This library also implements some general transformations useful for getting dat
55
- [`ToEltype`](#)`(T)` converts the element type of any [`AbstractArrayItem`](#) to `T`.
66
- [`ImageToTensor`](#) converts an image to an `ArrayItem` with another dimension for the color channels
77
- [`Normalize`](#) normalizes image tensors
8+
- [`OneHot`](#) to one-hot encode multi-class masks ([`MaskMulti`](#)s)
89

src/DataAugmentation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ export Item,
6161
CenterResizeCrop,
6262
Buffered,
6363
BufferedThreadsafe,
64+
OneHot,
6465
apply,
6566
Reflect,
6667
FlipX,

src/preprocessing.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ denormalize(a, means, stds) = denormalize!(copy(a), means, stds)
114114
"""
115115
NormalizeIntensity()
116116
117-
Normalizes the pixels of an array based on calculated mean and std.
117+
Normalizes the pixels of an array based on calculated mean and std.
118118
"""
119119

120120
struct NormalizeIntensity <: Transform end
@@ -183,6 +183,49 @@ end
183183
tensortoimage(a::AbstractArray{T, 3}) where T = colorview(RGB, permuteddimsview(a, (3, 1, 2)))
184184
tensortoimage(a::AbstractArray{T, 2}) where T = colorview(Gray, a)
185185

186+
187+
# OneHot encoding
188+
189+
"""
190+
OneHot([T = Float32])
191+
192+
One-hot encodes a `MaskMulti` with `n` classes and size `sz` into
193+
an array item of size `(sz..., n)` with element type `T`. Supports [`apply!`].
194+
195+
```julia
196+
item = MaskMulti(rand(1:4, 100, 100), 1:4)
197+
apply(OneHot(), item)
198+
```
199+
"""
200+
struct OneHot{T} <: Transform end
201+
OneHot() = OneHot{Float32}()
202+
203+
function apply(tfm::OneHot{T}, item::MaskMulti; randstate = nothing) where T
204+
mask = itemdata(item)
205+
a = zeros(T, size(mask)..., length(item.classes))
206+
for I in CartesianIndices(mask)
207+
a[I, mask[I]] = one(T)
208+
end
209+
210+
return ArrayItem(a)
211+
end
212+
213+
214+
function apply!(buf, tfm::OneHot{T}, item::MaskMulti; randstate = nothing) where T
215+
mask = itemdata(item)
216+
a = itemdata(buf)
217+
@show a[1:6]
218+
fill!(a, zero(T))
219+
@show a[1:6]
220+
221+
for I in CartesianIndices(mask)
222+
a[I, mask[I]] = one(T)
223+
end
224+
225+
return buf
226+
end
227+
228+
186229
function onehot(T, x::Int, n::Int)
187230
v = fill(zero(T), n)
188231
v[x] = one(T)

test/preprocessing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,22 @@ end
5555
end
5656
end
5757

58+
@testset ExtendedTestSet "OneHot" begin
59+
tfm = OneHot()
60+
mask = rand(1:4, 10, 10)
61+
item = MaskMulti(mask, 1:4)
62+
@test_nowarn apply(tfm, item)
63+
aitem = apply(tfm, item)
64+
@test size(itemdata(aitem)) == (10, 10, 4)
65+
66+
item2 = MaskMulti(rand(1:3, 10, 10), 1:4)
67+
buf = itemdata(aitem)
68+
bufcopy = copy(buf)
69+
apply!(aitem, tfm, item2)
70+
@test itemdata(item) == itemdata(item2) || itemdata(aitem) != bufcopy
71+
72+
end
73+
5874
@testset ExtendedTestSet "Image pipeline" begin
5975
image = Image(rand(RGB, 150, 150))
6076

0 commit comments

Comments
 (0)