Skip to content

Commit 5b9ed8e

Browse files
ctholhojustinvdm
andauthored
fix!: algorithm for oneOfWeighted has better distribution (#47)
Fixes #46 Co-authored-by: V.Beyer <v.beyer@unimed.de> Co-authored-by: Justin van der Merwe <justinvderm@gmail.com>
1 parent 2bb0f3b commit 5b9ed8e

File tree

5 files changed

+174
-55
lines changed

5 files changed

+174
-55
lines changed

oneOfWeighted.js

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
var hash = require('./hash')
22
var hash2 = hash.hash2
3-
var flip = require('./utils/flip')
4-
var resolve = require('./utils/resolve')
53

64
var EPS = 0.0001
75

@@ -12,19 +10,13 @@ function oneOfWeighted(a, b) {
1210
function oneOfWeightedMain(input, samples) {
1311
samples = parseSamples(samples)
1412
var id = hash2(input, 'oneOfWeighted')
15-
var n = samples.length
16-
var pRemaining = 1
17-
var i = -1
18-
var sample
19-
var p
20-
21-
while (++i < n) {
22-
sample = samples[i]
23-
p = sample[0] / pRemaining
24-
pRemaining -= p
13+
var prob = (id % 1000000) / 1000000
2514

26-
if (flip(id, p)) {
27-
return resolve(id, sample[1])
15+
var cumulative = 0
16+
for (var i = 0; i < samples.length; i++) {
17+
cumulative += samples[i][0]
18+
if (prob < cumulative) {
19+
return samples[i][1]
2820
}
2921
}
3022

readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ oneOfWeighted('id-2', [
800800
[0.05, char],
801801
[0.05, int]
802802
])
803-
// => 'Ut'
803+
// => [Function: word] { options: [Function: wordOptions] }
804804
```
805805

806806
For each `[probability, value]` pair in the array of `values`, if the given
@@ -817,7 +817,7 @@ oneOfWeighted('id-23', [
817817
[null, 'green'],
818818
[null, 'blue']
819819
])
820-
// => 'green'
820+
// => 'blue'
821821
```
822822

823823
## <a name="install-use" href="#install-use">#</a> Install & Use

tests/oneOfWeighted.test.js

Lines changed: 139 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,156 @@ const { diffBetween } = require('./utils')
44

55
const DIFF_THRESHOLD = 0.065
66

7-
test(`averages to within ${
8-
DIFF_THRESHOLD * 100
9-
}% of the given probabilities`, t => {
7+
test(`1/5 averages to within ${DIFF_THRESHOLD * 100
8+
}% of the given probabilities`, t => {
9+
const n = 10000
10+
let i = -1
11+
12+
const sums = {
13+
red: 0,
14+
green: 0,
15+
blue: 0,
16+
yellow: 0,
17+
fuchsia: 0
18+
}
19+
20+
const fn = oneOfWeighted([
21+
[0.6, 'red'],
22+
[0.2, 'green'],
23+
[0.1, 'blue'],
24+
[0.05, 'yellow'],
25+
[0.05, 'fuchsia']
26+
])
27+
28+
while (++i < n) sums[fn(i)]++
29+
30+
t.assert(diffBetween(sums.red / n, 0.6) <= DIFF_THRESHOLD)
31+
t.assert(diffBetween(sums.green / n, 0.2) <= DIFF_THRESHOLD)
32+
t.assert(diffBetween(sums.blue / n, 0.1) <= DIFF_THRESHOLD)
33+
t.assert(diffBetween(sums.yellow / n, 0.05) <= DIFF_THRESHOLD)
34+
t.assert(diffBetween(sums.fuchsia / n, 0.05) <= DIFF_THRESHOLD)
35+
})
36+
37+
test(`2/5 averages to within ${DIFF_THRESHOLD * 100}% of the given probabilities`, t => {
38+
const n = 10000
39+
let i = -1
40+
41+
const sums = {
42+
apple: 0,
43+
banana: 0,
44+
cherry: 0,
45+
pear: 0,
46+
persimmon: 0
47+
}
48+
49+
const fn = oneOfWeighted([
50+
[0.4, 'apple'],
51+
[0.3, 'banana'],
52+
[0.2, 'cherry'],
53+
[0.05, 'pear'],
54+
[0.05, 'persimmon']
55+
])
56+
57+
while (++i < n) sums[fn(i)]++
58+
59+
t.assert(diffBetween(sums.apple / n, 0.4) <= DIFF_THRESHOLD)
60+
t.assert(diffBetween(sums.banana / n, 0.3) <= DIFF_THRESHOLD)
61+
t.assert(diffBetween(sums.cherry / n, 0.2) <= DIFF_THRESHOLD)
62+
t.assert(diffBetween(sums.pear / n, 0.05) <= DIFF_THRESHOLD)
63+
t.assert(diffBetween(sums.persimmon / n, 0.05) <= DIFF_THRESHOLD)
64+
})
65+
66+
test(`3/5 averages to within ${DIFF_THRESHOLD * 100}% of the given probabilities`, t => {
67+
const n = 10000
68+
let i = -1
69+
70+
const sums = {
71+
Hello: 0,
72+
'Ni hao': 0,
73+
Namaste: 0,
74+
Hola: 0,
75+
Bonjour: 0,
76+
Salam: 0,
77+
Nomoshkar: 0
78+
}
79+
80+
const fn = oneOfWeighted([
81+
[0.25, 'Namaste'],
82+
[0.2, 'Hello'],
83+
[0.2, 'Ni hao'],
84+
[0.15, 'Hola'],
85+
[0.07, 'Bonjour'],
86+
[0.07, 'Salam'],
87+
[0.06, 'Nomoshkar']
88+
])
89+
90+
while (++i < n) sums[fn(i)]++
91+
92+
t.assert(diffBetween(sums['Namaste'] / n, 0.25) <= DIFF_THRESHOLD)
93+
t.assert(diffBetween(sums['Hello'] / n, 0.2) <= DIFF_THRESHOLD + 0.025)
94+
t.assert(diffBetween(sums['Ni hao'] / n, 0.2) <= DIFF_THRESHOLD)
95+
t.assert(diffBetween(sums['Hola'] / n, 0.15) <= DIFF_THRESHOLD)
96+
t.assert(diffBetween(sums['Bonjour'] / n, 0.07) <= DIFF_THRESHOLD)
97+
t.assert(diffBetween(sums['Salam'] / n, 0.07) <= DIFF_THRESHOLD)
98+
t.assert(diffBetween(sums['Nomoshkar'] / n, 0.06) <= DIFF_THRESHOLD)
99+
})
100+
101+
test(`4/5 averages to within ${DIFF_THRESHOLD * 100}% of the given probabilities`, t => {
102+
const n = 10000
103+
let i = -1
104+
105+
const sums = {
106+
cat: 0,
107+
dog: 0,
108+
bird: 0,
109+
fish: 0,
110+
rabbit: 0
111+
}
112+
113+
const fn = oneOfWeighted([
114+
[0.8, 'cat'],
115+
[0.025, 'dog'],
116+
[0.1, 'bird'],
117+
[0.025, 'fish'],
118+
[0.05, 'rabbit']
119+
])
120+
121+
while (++i < n) sums[fn(i)]++
122+
123+
t.assert(diffBetween(sums.cat / n, 0.8) <= DIFF_THRESHOLD)
124+
t.assert(diffBetween(sums.dog / n, 0.025) <= DIFF_THRESHOLD)
125+
t.assert(diffBetween(sums.bird / n, 0.1) <= DIFF_THRESHOLD)
126+
t.assert(diffBetween(sums.fish / n, 0.025) <= DIFF_THRESHOLD + 0.07)
127+
t.assert(diffBetween(sums.rabbit / n, 0.05) <= DIFF_THRESHOLD)
128+
})
129+
130+
test(`5/5 averages to within ${DIFF_THRESHOLD * 100}% of the given probabilities`, t => {
10131
const n = 10000
11132
let i = -1
12133

13134
const sums = {
14-
red: 0,
15-
green: 0,
16-
blue: 0
135+
alpha: 0,
136+
beta: 0,
137+
gamma: 0,
138+
delta: 0,
139+
epsilon: 0
17140
}
18141

19142
const fn = oneOfWeighted([
20-
[0.6, 'red'],
21-
[0.1, 'green'],
22-
[0.3, 'blue']
143+
[0.25, 'alpha'],
144+
[0.25, 'beta'],
145+
[0.25, 'gamma'],
146+
[0.15, 'delta'],
147+
[0.1, 'epsilon']
23148
])
24149

25150
while (++i < n) sums[fn(i)]++
26151

27-
t.assert(diffBetween(sums.red / n, 0.6) <= DIFF_THRESHOLD)
28-
t.assert(diffBetween(sums.green / n, 0.1) <= DIFF_THRESHOLD)
29-
t.assert(diffBetween(sums.blue / n, 0.3) <= DIFF_THRESHOLD)
152+
t.assert(diffBetween(sums.alpha / n, 0.25) <= DIFF_THRESHOLD)
153+
t.assert(diffBetween(sums.beta / n, 0.25) <= DIFF_THRESHOLD)
154+
t.assert(diffBetween(sums.gamma / n, 0.25) <= DIFF_THRESHOLD)
155+
t.assert(diffBetween(sums.delta / n, 0.15) <= DIFF_THRESHOLD)
156+
t.assert(diffBetween(sums.epsilon / n, 0.1) <= DIFF_THRESHOLD)
30157
})
31158

32159
test(`unassigned probabilities`, t => {

0 commit comments

Comments
 (0)