Skip to content

Commit f8fd4e8

Browse files
committed
refactor: improve quantization classes and documentation
1 parent f326d7d commit f8fd4e8

File tree

4 files changed

+29
-36
lines changed

4 files changed

+29
-36
lines changed

src/komm/_quantization/LloydMaxQuantizer.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,26 @@ def lloyd_max_quantizer(
147147
) -> tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]:
148148
# See [Say06, eqs. (9.27) and (9.28)].
149149
x_min, x_max = input_range
150-
delta = (x_max - x_min) / num_levels
151150

152151
# Initial guess
153-
levels = np.linspace(x_min + delta / 2, x_max - delta / 2, num=num_levels)
154-
thresholds = np.empty(num_levels + 1, dtype=float)
155-
new_levels = np.empty_like(levels)
152+
delta = (x_max - x_min) / num_levels
153+
y = np.linspace(x_min + delta / 2, x_max - delta / 2, num=num_levels)
156154

157-
for _ in range(max_iter):
158-
thresholds[0] = x_min
159-
thresholds[1:-1] = 0.5 * (levels[:-1] + levels[1:])
160-
thresholds[-1] = x_max
155+
λ = np.concatenate([[x_min], np.empty(num_levels - 1), [x_max]])
156+
y_new = np.empty_like(y)
161157

158+
for _ in range(max_iter):
159+
λ[1:-1] = 0.5 * (y[:-1] + y[1:])
162160
for i in range(num_levels):
163-
left, right = thresholds[i], thresholds[i + 1]
164-
x = np.linspace(left, right, num=points_per_interval, dtype=float)
161+
x = np.linspace(λ[i], λ[i + 1], num=points_per_interval, dtype=float)
165162
pdf = input_pdf(x)
166-
numerator = np.trapezoid(x * pdf, x)
167163
denominator = np.trapezoid(pdf, x)
168164
if denominator != 0:
169-
new_levels[i] = numerator / denominator
165+
y_new[i] = np.trapezoid(x * pdf, x) / denominator
170166
else: # Keep old level
171-
new_levels[i] = levels[i]
172-
if np.allclose(levels, new_levels):
167+
y_new[i] = y[i]
168+
if np.allclose(y, y_new):
173169
break
174-
levels = new_levels.copy()
170+
y = y_new.copy()
175171

176-
return new_levels, thresholds[1:-1]
172+
return y_new, λ[1:-1]

src/komm/_quantization/ScalarQuantizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,25 @@
99

1010
class ScalarQuantizer(abc.ScalarQuantizer):
1111
r"""
12-
General scalar quantizer. It is defined by a list of *levels*, $v_0, v_1, \ldots, v_{L-1}$, and a list of *thresholds*, $t_0, t_1, \ldots, t_L$, satisfying
12+
General scalar quantizer. It is defined by a list of *levels*, $y_0, y_1, \ldots, y_{L-1}$, and a list of *thresholds*, $\lambda_0, \lambda_1, \ldots, \lambda_L$, satisfying
1313
$$
14-
-\infty = t_0 < v_0 < t_1 < v_1 < \cdots < t_{L - 1} < v_{L - 1} < t_L = +\infty.
14+
-\infty = \lambda_0 < y_0 < \lambda_1 < y_1 < \cdots < \lambda_{L - 1} < y_{L - 1} < \lambda_L = +\infty.
1515
$$
16-
Given an input $x \in \mathbb{R}$, the output of the quantizer is given by $y = v_i$ if and only if $t_i \leq x < t_{i+1}$, where $i \in [0:L)$. For more details, see <cite>Say06, Ch. 9</cite>.
16+
Given an input $x \in \mathbb{R}$, the output of the quantizer is given by $y = y_i$ if and only if $\lambda_i \leq x < \lambda_{i+1}$, where $i \in [0:L)$. For more details, see <cite>Say06, Ch. 9</cite>.
1717
1818
Parameters:
19-
levels: The quantizer levels $v_0, v_1, \ldots, v_{L-1}$. It should be a list floats of length $L$.
19+
levels: The quantizer levels $y_0, y_1, \ldots, y_{L-1}$. It should be a list floats of length $L$.
2020
21-
thresholds: The quantizer finite thresholds $t_1, t_2, \ldots, t_{L-1}$. It should be a list of floats of length $L - 1$.
21+
thresholds: The quantizer finite thresholds $\lambda_1, \lambda_2, \ldots, \lambda_{L-1}$. It should be a list of floats of length $L - 1$.
2222
2323
Examples:
2424
The $5$-level scalar quantizer whose characteristic (input × output) curve is depicted in the figure below has levels
2525
$$
26-
v_0 = -2, ~ v_1 = -1, ~ v_2 = 0, ~ v_3 = 1, ~ v_4 = 2,
26+
y_0 = -2, ~ y_1 = -1, ~ y_2 = 0, ~ y_3 = 1, ~ y_4 = 2,
2727
$$
2828
and thresholds
2929
$$
30-
t_0 = -\infty, ~ t_1 = -1.5, ~ t_2 = -0.3, ~ t_3 = 0.8, ~ t_4 = 1.4, ~ t_5 = \infty.
30+
\lambda_0 = -\infty, ~ \lambda_1 = -1.5, ~ \lambda_2 = -0.3, ~ \lambda_3 = 0.8, ~ \lambda_4 = 1.4, ~ \lambda_5 = \infty.
3131
$$
3232
3333
<figure markdown>

src/komm/_quantization/UniformQuantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ class UniformQuantizer(abc.ScalarQuantizer):
1111
r"""
1212
Uniform scalar quantizer. It is a [scalar quantizer](/ref/ScalarQuantizer) in which the separation between levels is a constant $\Delta$, called the *quantization step*, and the thresholds are the mid-point between adjacent levels. More precisely, the levels are given by
1313
$$
14-
v_i = (i - (L - 1)/2 + \theta) \Delta, \qquad i \in [0 : L),
14+
y_i = (i - (L - 1)/2 + \theta) \Delta, \qquad i \in [0 : L),
1515
$$
1616
where $\theta \in \mathbb{R}$ is an arbitrary *offset* (normalized by $\Delta$), and the finite thresholds are given by
1717
$$
18-
t_i = \frac{v_{i-1} + v_i}{2}, \qquad i \in [1 : L).
18+
\lambda_i = \frac{y_{i-1} + y_i}{2}, \qquad i \in [1 : L).
1919
$$
2020
For more details, see <cite>Say06, Sec. 9.4</cite>.
2121

src/komm/_quantization/base.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ class ScalarQuantizer(ABC):
1111
@abstractmethod
1212
def levels(self) -> npt.NDArray[np.floating]:
1313
r"""
14-
The quantizer levels $v_0, v_1, \ldots, v_{L-1}$.
14+
The quantizer levels $y_0, y_1, \ldots, y_{L-1}$.
1515
"""
1616
raise NotImplementedError
1717

1818
@cached_property
1919
@abstractmethod
2020
def thresholds(self) -> npt.NDArray[np.floating]:
2121
r"""
22-
The quantizer finite thresholds $t_1, t_2, \ldots, t_{L-1}$.
22+
The quantizer finite thresholds $\lambda_1, \lambda_2, \ldots, \lambda_{L-1}$.
2323
"""
2424
raise NotImplementedError
2525

@@ -68,20 +68,17 @@ def mean_squared_error(
6868
Parameters:
6969
input_pdf: The pdf $f_X(x)$ of the input signal.
7070
input_range: The range $(x_\mathrm{min}, x_\mathrm{max})$ of the input signal.
71-
points_per_interval: The number of points per interval for numerical integration (default: 4096).
71+
points_per_interval: The number of points per interval for numerical integration (default: `4096`).
7272
7373
Returns:
7474
mse: The mean square quantization error.
7575
"""
7676
# See [Say06, eq. (9.3)].
7777
x_min, x_max = input_range
78-
thresholds = np.concatenate(([x_min], self.thresholds, [x_max]))
78+
λ = np.concatenate(([x_min], self.thresholds, [x_max]))
7979
mse = 0.0
8080
for i, level in enumerate(self.levels):
81-
left, right = thresholds[i], thresholds[i + 1]
82-
x = np.linspace(left, right, num=points_per_interval, dtype=float)
83-
pdf = input_pdf(x)
84-
integrand: npt.NDArray[np.floating] = (level - x) ** 2 * pdf
85-
integral = np.trapezoid(integrand, x)
86-
mse += float(integral)
87-
return mse
81+
x = np.linspace(λ[i], λ[i + 1], num=points_per_interval, dtype=float)
82+
integrand = (level - x) ** 2 * input_pdf(x)
83+
mse += np.trapezoid(integrand, x)
84+
return float(mse)

0 commit comments

Comments
 (0)