Skip to content

Commit 2524d20

Browse files
committed
Converts Spectrum to a dataclass.
Formerly abstract, `Spectrum` has to be made concrete due to [mypy issue#5374](python/mypy#5374) `Only concrete class can be given where "Type[Abstract]" is expected`
1 parent 2f3e2e9 commit 2524d20

File tree

8 files changed

+40
-35
lines changed

8 files changed

+40
-35
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ classifiers =
1717
Programming Language :: Python :: 3.9
1818
Topic :: Software Development :: Libraries :: Python Modules
1919

20-
2120
[options]
2221
packages = spectra
2322
python_requires = >= 3.9
@@ -38,6 +37,7 @@ line_length = 120
3837
[mypy]
3938
files = spectra, tests
4039
ignore_missing_imports = true
40+
plugins = numpy.typing.mypy_plugin
4141

4242
[flake8]
4343
ignore = E203, E266, E501, W503, E731

spectra/_abc_spectrum.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,49 @@
11
from __future__ import annotations
22

3-
from abc import ABC, abstractmethod
3+
from dataclasses import dataclass
4+
from datetime import datetime
45
from typing import TYPE_CHECKING, Generator, Iterable, Optional
56

67
import numpy as np
7-
from numpy.typing import ArrayLike
88

9+
from ._typing import ArrayLike
910
from .tools import index_of_x, integrate, read_csvs
1011

1112

12-
class Spectrum(ABC):
13+
@dataclass
14+
class Spectrum:
15+
name: str
16+
energies: np.ndarray
17+
intensities: np.ndarray
18+
units: Optional[str] = None
19+
style: Optional[str] = None
20+
time: Optional[datetime] = None
21+
1322
def __init__(
1423
self,
1524
name: str,
1625
energies: ArrayLike,
1726
intensities: ArrayLike,
1827
units: Optional[str] = None,
1928
style: Optional[str] = None,
20-
time=None,
29+
time: Optional[datetime] = None,
2130
):
22-
energies = np.asarray(energies)
23-
intensities = np.asarray(intensities)
24-
assert len(energies.shape) == 1
25-
assert energies.shape == intensities.shape
26-
2731
self.name = name
2832
self.energies = np.asarray(energies, dtype=float)
2933
self.intensities = np.asarray(intensities, dtype=float)
3034
self.units = units
3135
self.style = style
3236
self.time = time
3337

38+
def __post_init__(self):
39+
self.energies = np.asarray(self.energies, dtype=float)
40+
self.intensities = np.asarray(self.intensities, dtype=float)
41+
assert len(self.energies.shape) == 1
42+
assert self.energies.shape == self.intensities.shape
43+
3444
def __repr__(self) -> str:
3545
return f"<{self.__class__.__name__}: {self.name}>"
3646

37-
def __str__(self) -> str:
38-
return repr(self)
39-
4047
def __iter__(self) -> Generator[tuple[float, float], None, None]:
4148
"""
4249
Iterate over points in the Spectrum.
@@ -73,9 +80,8 @@ def __abs__(self) -> Spectrum:
7380
def __radd__(self, other: float) -> Spectrum:
7481
return self.__add__(other)
7582

76-
@abstractmethod
7783
def __add__(self, other: float) -> Spectrum:
78-
pass
84+
raise NotImplementedError()
7985

8086
def __rtruediv__(self, other: float) -> Spectrum:
8187
new = self.copy()
@@ -282,9 +288,8 @@ def sliced(self, start: Optional[float] = None, end: Optional[float] = None) ->
282288
new.intensities = new.intensities[start_i:end_i]
283289
return new
284290

285-
@abstractmethod
286291
def smoothed(self, box_pts: int | bool = True) -> Spectrum:
287-
pass
292+
raise NotImplementedError()
288293

289294
@classmethod
290295
def from_csvs(cls, *inps: str, names: Optional[Iterable[str]] = None):

spectra/_typing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44

55
from matplotlib.axes import Axes
66
from matplotlib.figure import Figure
7+
from numpy import ndarray
78

89
ITER_STR = Optional[Union[Iterable[str], str]]
910
ITER_FLOAT = Optional[Union[Iterable[float], float]]
1011

1112
OPT_PLOT = Optional[tuple[Figure, Axes]]
13+
14+
ArrayLike = Union[Iterable, ndarray]

spectra/progress.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55
import matplotlib.pyplot as plt
66
import numpy as np
7-
from numpy.typing import ArrayLike
87

9-
from ._typing import OPT_PLOT
8+
from ._typing import OPT_PLOT, ArrayLike
109
from .conv_spectrum import ConvSpectrum
1110
from .tools import integrate, smooth_curve
1211

spectra/shapes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
2-
from numpy.typing import ArrayLike
2+
3+
from ._typing import ArrayLike
34

45

56
def gaussian(energy: float, width: float, xs: ArrayLike) -> np.ndarray:

spectra/sticks_spectrum.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,26 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from typing import Optional
45

56
import numpy as np
6-
from numpy.typing import ArrayLike
77

88
from ._abc_spectrum import Spectrum
99
from .conv_spectrum import ConvSpectrum
1010
from .shapes import gaussian
1111

1212

13+
@dataclass
1314
class SticksSpectrum(Spectrum):
1415
"""
1516
A SticksSpectrum is a collection of intensities at various energies
1617
These may be convolved with a shape to produce a ConvSpectrum.
1718
"""
1819

19-
def __init__(
20-
self,
21-
name: str,
22-
energies: ArrayLike,
23-
intensities: ArrayLike,
24-
units: Optional[str] = None,
25-
style: Optional[str] = None,
26-
time=None,
27-
y_shift: float = 0,
28-
):
29-
super().__init__(name, energies, intensities, units, style, time)
30-
self.y_shift = y_shift
20+
y_shift: float = 0
21+
22+
def __eq__(self, other):
23+
return self.y_shift == other.y_shift and super().__eq__(other)
3124

3225
def __rsub__(self, other: float) -> SticksSpectrum:
3326
"""

spectra/tools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from typing import TYPE_CHECKING, Generator, Iterable, Optional, Sequence
77

88
import numpy as np
9-
from numpy.typing import ArrayLike
109
from scipy import constants
1110

11+
from ._typing import ArrayLike
12+
1213
if TYPE_CHECKING:
1314
from ._abc_spectrum import Spectrum
1415

tests/test_sticks_spectrum.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ def test_str():
5757
energies, intensities = np.arange(10), np.arange(10)
5858
s1 = SticksSpectrum("Hello World", energies, intensities)
5959

60-
assert str(s1) == "<SticksSpectrum: Hello World>"
60+
assert (
61+
str(s1)
62+
== "SticksSpectrum(name='Hello World', energies=array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), intensities=array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), units=None, style=None, time=None, y_shift=0)"
63+
)
6164

6265

6366
def test_add_sub():

0 commit comments

Comments
 (0)