Skip to content

Commit 5092ea2

Browse files
authored
Fix negative bandwidth test and add online code path test. (gh-118600)
1 parent 9c13d9e commit 5092ea2

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

Lib/statistics.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,9 +1791,8 @@ def kde_random(data, h, kernel='normal', *, seed=None):
17911791
if h <= 0.0:
17921792
raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}')
17931793

1794-
try:
1795-
kernel_invcdf = _kernel_invcdfs[kernel]
1796-
except KeyError:
1794+
kernel_invcdf = _kernel_invcdfs.get(kernel)
1795+
if kernel_invcdf is None:
17971796
raise StatisticsError(f'Unknown kernel name: {kernel!r}')
17981797

17991798
prng = _random.Random(seed)

Lib/test/test_statistics.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,7 +2402,7 @@ def integrate(func, low, high, steps=10_000):
24022402
with self.assertRaises(StatisticsError):
24032403
kde(sample, h=0.0) # Zero bandwidth
24042404
with self.assertRaises(StatisticsError):
2405-
kde(sample, h=0.0) # Negative bandwidth
2405+
kde(sample, h=-1.0) # Negative bandwidth
24062406
with self.assertRaises(TypeError):
24072407
kde(sample, h='str') # Wrong bandwidth type
24082408
with self.assertRaises(StatisticsError):
@@ -2426,6 +2426,14 @@ def integrate(func, low, high, steps=10_000):
24262426
self.assertEqual(f_hat(-1.0), 1/2)
24272427
self.assertEqual(f_hat(1.0), 1/2)
24282428

2429+
# Test online updates to data
2430+
2431+
data = [1, 2]
2432+
f_hat = kde(data, 5.0, 'triangular')
2433+
self.assertEqual(f_hat(100), 0.0)
2434+
data.append(100)
2435+
self.assertGreater(f_hat(100), 0.0)
2436+
24292437
def test_kde_kernel_invcdfs(self):
24302438
kernel_invcdfs = statistics._kernel_invcdfs
24312439
kde = statistics.kde
@@ -2462,7 +2470,7 @@ def test_kde_random(self):
24622470
with self.assertRaises(TypeError):
24632471
kde_random(iter(sample), 1.5) # Data is not a sequence
24642472
with self.assertRaises(StatisticsError):
2465-
kde_random(sample, h=0.0) # Zero bandwidth
2473+
kde_random(sample, h=-1.0) # Negative bandwidth
24662474
with self.assertRaises(StatisticsError):
24672475
kde_random(sample, h=0.0) # Zero bandwidth
24682476
with self.assertRaises(TypeError):
@@ -2474,10 +2482,10 @@ def test_kde_random(self):
24742482

24752483
h = 1.5
24762484
kernel = 'cosine'
2477-
prng = kde_random(sample, h, kernel)
2478-
self.assertEqual(prng.__name__, 'rand')
2479-
self.assertIn(kernel, prng.__doc__)
2480-
self.assertIn(repr(h), prng.__doc__)
2485+
rand = kde_random(sample, h, kernel)
2486+
self.assertEqual(rand.__name__, 'rand')
2487+
self.assertIn(kernel, rand.__doc__)
2488+
self.assertIn(repr(h), rand.__doc__)
24812489

24822490
# Approximate distribution test: Compare a random sample to the expected distribution
24832491

@@ -2507,6 +2515,14 @@ def p_expected(x):
25072515
for x in xarr:
25082516
self.assertTrue(math.isclose(p_observed(x), p_expected(x), abs_tol=0.0005))
25092517

2518+
# Test online updates to data
2519+
2520+
data = [1, 2]
2521+
rand = kde_random(data, 5, 'triangular')
2522+
self.assertLess(max([rand() for i in range(5000)]), 10)
2523+
data.append(100)
2524+
self.assertGreater(max(rand() for i in range(5000)), 10)
2525+
25102526

25112527
class TestQuantiles(unittest.TestCase):
25122528

0 commit comments

Comments
 (0)