diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index aae4371c8a..e889d23ccf 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -8,6 +8,10 @@ which are the minimum and maximum among the node times and mutation times, respectively. (:user:`szhan`, :pr:`2612`, :issue:`2271`) +- The ``msprime.RateMap`` class has been ported into tskit: functionality should + be identical to the version in msprime, apart from minor changes in the formatting + of tabular text output (:user:`hyanwong`, :pr:`2636`) + **Breaking Changes** - the ``filter_populations``, ``filter_individuals``, and ``filter_sites`` diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py new file mode 100644 index 0000000000..2c558ac264 --- /dev/null +++ b/python/tests/test_intervals.py @@ -0,0 +1,857 @@ +# MIT License +# +# Copyright (c) 2019-2022 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for the intervals module. +""" +import decimal +import fractions +import gzip +import io +import os +import pickle +import textwrap +import xml + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import tskit + + +class TestRateMapErrors: + @pytest.mark.parametrize( + ("position", "rate"), + [ + ([], []), + ([0], []), + ([0], [0]), + ([1, 2], [0]), + ([0, -1], [0]), + ([0, 1], [-1]), + ], + ) + def test_bad_input(self, position, rate): + with pytest.raises(ValueError): + tskit.RateMap(position=position, rate=rate) + + def test_zero_length_interval(self): + with pytest.raises(ValueError, match=r"at indexes \[2 4\]"): + tskit.RateMap(position=[0, 1, 1, 2, 2, 3], rate=[0, 0, 0, 0, 0]) + + def test_bad_length(self): + positions = np.array([0, 1, 2]) + rates = np.array([0, 1, 2]) + with pytest.raises(ValueError, match="one less entry"): + tskit.RateMap(position=positions, rate=rates) + + def test_bad_first_pos(self): + positions = np.array([1, 2, 3]) + rates = np.array([1, 1]) + with pytest.raises(ValueError, match="First position"): + tskit.RateMap(position=positions, rate=rates) + + def test_bad_rate(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, -1]) + with pytest.raises(ValueError, match="negative.*1"): + tskit.RateMap(position=positions, rate=rates) + + def test_bad_rate_with_missing(self): + positions = np.array([0, 1, 2]) + rates = np.array([np.nan, -1]) + with pytest.raises(ValueError, match="negative.*1"): + tskit.RateMap(position=positions, rate=rates) + + def test_read_only(self): + positions = np.array([0, 0.25, 0.5, 0.75, 1]) + rates = np.array([0.125, 0.25, 0.5, 0.75]) # 1 shorter than positions + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.all(rates == rate_map.rate) + assert np.all(positions == rate_map.position) + with pytest.raises(AttributeError): + rate_map.rate = 2 * rate_map.rate + with pytest.raises(AttributeError): + rate_map.position = 2 * rate_map.position + with pytest.raises(AttributeError): + rate_map.left = 1234 + with pytest.raises(AttributeError): + rate_map.right = 1234 + with pytest.raises(AttributeError): + rate_map.mid = 1234 + with pytest.raises(ValueError): + rate_map.rate[0] = 1 + with pytest.raises(ValueError): + rate_map.position[0] = 1 + with pytest.raises(ValueError): + rate_map.left[0] = 1 + with pytest.raises(ValueError): + rate_map.mid[0] = 1 + with pytest.raises(ValueError): + rate_map.right[0] = 1 + + +class TestGetRateAllKnown: + examples = [ + tskit.RateMap(position=[0, 1], rate=[0]), + tskit.RateMap(position=[0, 1], rate=[0.1]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + tskit.RateMap(position=range(100), rate=range(99)), + ] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_mid(self, rate_map): + rate = rate_map.get_rate(rate_map.mid) + assert len(rate) == len(rate_map) + for j in range(len(rate_map)): + assert rate[j] == rate_map[rate_map.mid[j]] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_left(self, rate_map): + rate = rate_map.get_rate(rate_map.left) + assert len(rate) == len(rate_map) + for j in range(len(rate_map)): + assert rate[j] == rate_map[rate_map.left[j]] + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate_right(self, rate_map): + rate = rate_map.get_rate(rate_map.right[:-1]) + assert len(rate) == len(rate_map) - 1 + for j in range(len(rate_map) - 1): + assert rate[j] == rate_map[rate_map.right[j]] + + +class TestOperations: + examples = [ + tskit.RateMap.uniform(sequence_length=1, rate=0), + tskit.RateMap.uniform(sequence_length=1, rate=0.1), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0, 0.2]), + tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), + tskit.RateMap(position=range(100), rate=range(99)), + # Missing data + tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0]), + tskit.RateMap(position=[0, 1, 2], rate=[0, np.nan]), + tskit.RateMap(position=[0, 1, 2, 3], rate=[0, np.nan, 1]), + ] + + @pytest.mark.parametrize("rate_map", examples) + def test_num_intervals(self, rate_map): + assert rate_map.num_intervals == len(rate_map.rate) + assert rate_map.num_missing_intervals == np.sum(np.isnan(rate_map.rate)) + assert rate_map.num_non_missing_intervals == np.sum(~np.isnan(rate_map.rate)) + + @pytest.mark.parametrize("rate_map", examples) + def test_mask_arrays(self, rate_map): + assert_array_equal(rate_map.missing, np.isnan(rate_map.rate)) + assert_array_equal(rate_map.non_missing, ~np.isnan(rate_map.rate)) + + @pytest.mark.parametrize("rate_map", examples) + def test_missing_intervals(self, rate_map): + missing = [] + for left, right, rate in zip(rate_map.left, rate_map.right, rate_map.rate): + if np.isnan(rate): + missing.append([left, right]) + if len(missing) == 0: + assert len(rate_map.missing_intervals()) == 0 + else: + assert_array_equal(missing, rate_map.missing_intervals()) + + @pytest.mark.parametrize("rate_map", examples) + def test_mean_rate(self, rate_map): + total_span = 0 + total_mass = 0 + for span, mass in zip(rate_map.span, rate_map.mass): + if not np.isnan(mass): + total_span += span + total_mass += mass + assert total_mass / total_span == rate_map.mean_rate + + @pytest.mark.parametrize("rate_map", examples) + def test_total_mass(self, rate_map): + assert rate_map.total_mass == np.nansum(rate_map.mass) + + @pytest.mark.parametrize("rate_map", examples) + def test_get_cumulative_mass(self, rate_map): + assert list(rate_map.get_cumulative_mass([0])) == [0] + assert list(rate_map.get_cumulative_mass([rate_map.sequence_length])) == [ + rate_map.total_mass + ] + assert_array_equal( + rate_map.get_cumulative_mass(rate_map.right), np.nancumsum(rate_map.mass) + ) + + @pytest.mark.parametrize("rate_map", examples) + def test_get_rate(self, rate_map): + assert_array_equal(rate_map.get_rate([0]), rate_map.rate[0]) + assert_array_equal( + rate_map.get_rate([rate_map.sequence_length - 1e-9]), rate_map.rate[-1] + ) + assert_array_equal(rate_map.get_rate(rate_map.left), rate_map.rate) + + @pytest.mark.parametrize("rate_map", examples) + def test_map_semantics(self, rate_map): + assert len(rate_map) == rate_map.num_non_missing_intervals + assert_array_equal(list(rate_map.keys()), rate_map.mid[rate_map.non_missing]) + for x in rate_map.left[rate_map.missing]: + assert x not in rate_map + for x in rate_map.mid[rate_map.missing]: + assert x not in rate_map + + def test_asdict(self): + rate_map = tskit.RateMap.uniform(sequence_length=2, rate=4) + d = rate_map.asdict() + assert_array_equal(d["position"], np.array([0.0, 2.0])) + assert_array_equal(d["rate"], np.array([4.0])) + + +class TestFindIndex: + def test_one_interval(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + for j in range(10): + assert rate_map.find_index(j) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(9.999) == 0 + + def test_two_intervals(self): + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[0.1, 0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(4.9999) == 0 + assert rate_map.find_index(5) == 1 + assert rate_map.find_index(5.1) == 1 + assert rate_map.find_index(7) == 1 + assert rate_map.find_index(9.999) == 1 + + def test_three_intervals(self): + rate_map = tskit.RateMap(position=[0, 5, 10, 15], rate=[0.1, 0.1, 0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0001) == 0 + assert rate_map.find_index(4.9999) == 0 + assert rate_map.find_index(5) == 1 + assert rate_map.find_index(5.1) == 1 + assert rate_map.find_index(7) == 1 + assert rate_map.find_index(9.999) == 1 + assert rate_map.find_index(10) == 2 + assert rate_map.find_index(10.1) == 2 + assert rate_map.find_index(12) == 2 + assert rate_map.find_index(14.9999) == 2 + + def test_out_of_bounds(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + for bad_value in [-1, -0.0001, 10, 10.0001, 1e9]: + with pytest.raises(KeyError, match="out of bounds"): + rate_map.find_index(bad_value) + + def test_input_types(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + assert rate_map.find_index(0) == 0 + assert rate_map.find_index(0.0) == 0 + assert rate_map.find_index(np.zeros(1)[0]) == 0 + + +class TestSimpleExamples: + def test_all_missing_one_interval(self): + with pytest.raises(ValueError, match="missing data"): + tskit.RateMap(position=[0, 10], rate=[np.nan]) + + def test_all_missing_two_intervals(self): + with pytest.raises(ValueError, match="missing data"): + tskit.RateMap(position=[0, 5, 10], rate=[np.nan, np.nan]) + + def test_count(self): + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + assert rate_map.num_intervals == 2 + assert rate_map.num_missing_intervals == 1 + assert rate_map.num_non_missing_intervals == 1 + + def test_missing_arrays(self): + rate_map = tskit.RateMap(position=[0, 5, 10], rate=[np.nan, 1]) + assert list(rate_map.missing) == [True, False] + assert list(rate_map.non_missing) == [False, True] + + def test_missing_at_start_mean_rate(self): + positions = np.array([0, 0.5, 1, 2]) + rates = np.array([np.nan, 0, 1]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) + + def test_missing_at_end_mean_rate(self): + positions = np.array([0, 1, 1.5, 2]) + rates = np.array([1, 0, np.nan]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.isclose(rate_map.mean_rate, 1 / (1 + 0.5)) + + def test_interval_properties_all_known(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.left) == [0, 1, 2] + assert list(rate_map.right) == [1, 2, 3] + assert list(rate_map.mid) == [0.5, 1.5, 2.5] + assert list(rate_map.span) == [1, 1, 1] + assert list(rate_map.mass) == [0.1, 0.2, 0.3] + + def test_pickle_non_missing(self): + r1 = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + r2 = pickle.loads(pickle.dumps(r1)) + assert r1 == r2 + + def test_pickle_missing(self): + r1 = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, np.nan, 0.3]) + r2 = pickle.loads(pickle.dumps(r1)) + assert r1 == r2 + + def test_get_cumulative_mass_all_known(self): + rate_map = tskit.RateMap(position=[0, 10, 20, 30], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.mass) == [1, 2, 3] + assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 3, 6] + + def test_cumulative_mass_missing(self): + rate_map = tskit.RateMap(position=[0, 10, 20, 30], rate=[0.1, np.nan, 0.3]) + assert list(rate_map.get_cumulative_mass([10, 20, 30])) == [1, 1, 4] + + +class TestDisplay: + def test_str(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + s = """\ + ╔════╤═════╤═══╤════╤════╗ + ║left│right│mid│span│rate║ + ╠════╪═════╪═══╪════╪════╣ + ║0 │10 │ 5│ 10│ 0.1║ + ╚════╧═════╧═══╧════╧════╝ + """ + assert textwrap.dedent(s) == str(rate_map) + + def test_str_scinot(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.000001]) + s = """\ + ╔════╤═════╤═══╤════╤═════╗ + ║left│right│mid│span│rate ║ + ╠════╪═════╪═══╪════╪═════╣ + ║0 │10 │ 5│ 10│1e-06║ + ╚════╧═════╧═══╧════╧═════╝ + """ + print(str(rate_map)) + assert textwrap.dedent(s) == str(rate_map) + + def test_repr(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + s = "RateMap(position=array([ 0., 10.]), rate=array([0.1]))" + assert repr(rate_map) == s + + def test_repr_html(self): + rate_map = tskit.RateMap(position=[0, 10], rate=[0.1]) + html = rate_map._repr_html_() + root = xml.etree.ElementTree.fromstring(html) + assert root.tag == "div" + table = root.find("table") + rows = list(table.find("tbody")) + assert len(rows) == 1 + + def test_long_table(self): + n = 100 + rate_map = tskit.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._text_header_and_rows(limit=20) + assert len(headers) == 5 + assert len(data) == 21 + # check some left values + assert int(data[0][0]) == 0 + assert int(data[-1][0]) == n - 1 + + def test_short_table(self): + n = 10 + rate_map = tskit.RateMap(position=range(n + 1), rate=[0.1] * n) + headers, data = rate_map._text_header_and_rows(limit=20) + assert len(headers) == 5 + assert len(data) == n + # check some left values. + assert int(data[0][0]) == 0 + assert int(data[-1][0]) == n - 1 + + +class TestRateMapIsMapping: + def test_items(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + items = list(rate_map.items()) + assert items[0] == (0.5, 0.1) + assert items[1] == (1.5, 0.2) + assert items[2] == (2.5, 0.3) + + def test_keys(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.keys()) == [0.5, 1.5, 2.5] + + def test_values(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert list(rate_map.values()) == [0.1, 0.2, 0.3] + + def test_in_points(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + # Any point within the map are True + for x in [0, 0.5, 1, 2.9999]: + assert x in rate_map + # Points outside the map are False + for x in [-1, -0.0001, 3, 3.1]: + assert x not in rate_map + + def test_in_slices(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + # slices that are within the map are "in" + for x in [slice(0, 0.5), slice(0, 1), slice(0, 2), slice(2, 3), slice(0, 3)]: + assert x in rate_map + # Any slice that doesn't fully intersect with the map "not in" + assert slice(-0.001, 1) not in rate_map + assert slice(0, 3.0001) not in rate_map + assert slice(2.9999, 3.0001) not in rate_map + assert slice(3, 4) not in rate_map + assert slice(-2, -1) not in rate_map + + def test_other_types_not_in(self): + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + for other_type in [None, "sdf", "123", {}, [], Exception]: + assert other_type not in rate_map + + def test_len(self): + rate_map = tskit.RateMap(position=[0, 1], rate=[0.1]) + assert len(rate_map) == 1 + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert len(rate_map) == 2 + rate_map = tskit.RateMap(position=[0, 1, 2, 3], rate=[0.1, 0.2, 0.3]) + assert len(rate_map) == 3 + + def test_immutable(self): + rate_map = tskit.RateMap(position=[0, 1], rate=[0.1]) + with pytest.raises(TypeError, match="item assignment"): + rate_map[0] = 1 + with pytest.raises(TypeError, match="item deletion"): + del rate_map[0] + + def test_eq(self): + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + r2 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert r1 == r1 + assert r1 == r2 + r2 = tskit.RateMap(position=[0, 1, 3], rate=[0.1, 0.2]) + assert r1 != r2 + assert tskit.RateMap(position=[0, 1], rate=[0.1]) != tskit.RateMap( + position=[0, 1], rate=[0.2] + ) + assert tskit.RateMap(position=[0, 1], rate=[0.1]) != tskit.RateMap( + position=[0, 10], rate=[0.1] + ) + + def test_getitem_value(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + assert rate_map[0] == 0.1 + assert rate_map[0.5] == 0.1 + assert rate_map[1] == 0.2 + assert rate_map[1.5] == 0.2 + assert rate_map[1.999] == 0.2 + # Try other types + assert rate_map[np.array([1], dtype=np.float32)[0]] == 0.2 + assert rate_map[np.array([1], dtype=np.int32)[0]] == 0.2 + assert rate_map[np.array([1], dtype=np.float64)[0]] == 0.2 + assert rate_map[1 / 2] == 0.1 + assert rate_map[fractions.Fraction(1, 3)] == 0.1 + assert rate_map[decimal.Decimal(1)] == 0.2 + + def test_getitem_slice(self): + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + # The semantics of the slice() function are tested elsewhere. + assert r1[:] == r1.copy() + assert r1[:] is not r1 + assert r1[1:] == r1.slice(left=1) + assert r1[:1.5] == r1.slice(right=1.5) + assert r1[0.5:1.5] == r1.slice(left=0.5, right=1.5) + + def test_getitem_slice_step(self): + r1 = tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]) + # Trying to set a "step" is a error + with pytest.raises(TypeError, match="interval slicing"): + r1[0:3:1] + + +class TestMappingMissingData: + def test_get_missing(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + with pytest.raises(KeyError, match="within a missing interval"): + rate_map[0] + with pytest.raises(KeyError, match="within a missing interval"): + rate_map[0.999] + + def test_in_missing(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + assert 0 not in rate_map + assert 0.999 not in rate_map + assert 1 in rate_map + + def test_keys_missing(self): + rate_map = tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0.2]) + assert list(rate_map.keys()) == [1.5] + + +class TestGetIntermediates: + def test_get_rate(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.all(rate_map.get_rate([0.5, 1.5]) == rates) + + def test_get_rate_out_of_bounds(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + with pytest.raises(ValueError, match="out of bounds"): + rate_map.get_rate([1, -0.1]) + with pytest.raises(ValueError, match="out of bounds"): + rate_map.get_rate([2]) + + def test_get_cumulative_mass(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + assert np.allclose(rate_map.get_cumulative_mass([0.5, 1.5]), np.array([0.5, 3])) + assert rate_map.get_cumulative_mass([2]) == rate_map.total_mass + + def test_get_bad_cumulative_mass(self): + positions = np.array([0, 1, 2]) + rates = np.array([1, 4]) + rate_map = tskit.RateMap(position=positions, rate=rates) + with pytest.raises(ValueError, match="positions"): + rate_map.get_cumulative_mass([1, -0.1]) + with pytest.raises(ValueError, match="positions"): + rate_map.get_cumulative_mass([1, 2.1]) + + +class TestSlice: + def test_slice_no_params(self): + # test RateMap.slice(..., trim=False) + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice() + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + assert a == b + + def test_slice_left_examples(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(left=50) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 0, 1, 2, 3], b.rate) + + b = a.slice(left=100) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, 3], b.rate) + + b = a.slice(left=150) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, 3], b.rate) + + def test_slice_right_examples(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 300, 400], b.position) + assert_array_equal([0, 1, 2, np.nan], b.rate) + + b = a.slice(right=250) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 100, 200, 250, 400], b.position) + assert_array_equal([0, 1, 2, np.nan], b.rate) + + def test_slice_left_right_examples(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, 3]) + b = a.slice(left=50, right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50, 100, 200, 300, 400], b.position) + assert_array_equal([np.nan, 0, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=250) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 250, 400], b.position) + assert_array_equal([np.nan, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 200, 300, 400], b.position) + assert_array_equal([np.nan, 1, 2, np.nan], b.rate) + + b = a.slice(left=150, right=160) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 150, 160, 400], b.position) + assert_array_equal([np.nan, 1, np.nan], b.rate) + + def test_slice_right_missing(self): + # If we take a right-slice into a trailing missing region, + # we should recover the same map. + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[0, 1, 2, np.nan]) + b = a.slice(right=350) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + b = a.slice(right=300) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + def test_slice_left_missing(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[np.nan, 1, 2, 3]) + b = a.slice(left=50) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + b = a.slice(left=100) + assert a.sequence_length == b.sequence_length + assert_array_equal(a.position, b.position) + assert_array_equal(a.rate, b.rate) + + def test_slice_with_floats(self): + # test RateMap.slice(..., trim=False) with floats + a = tskit.RateMap( + position=[np.pi * x for x in [0, 100, 200, 300, 400]], rate=[0, 1, 2, 3] + ) + b = a.slice(left=50 * np.pi) + assert a.sequence_length == b.sequence_length + assert_array_equal([0, 50 * np.pi] + list(a.position[1:]), b.position) + assert_array_equal([np.nan] + list(a.rate), b.rate) + + def test_slice_trim_left(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + b = a.slice(left=100, trim=True) + assert b == tskit.RateMap(position=[0, 100, 200, 300], rate=[2, 3, 4]) + b = a.slice(left=50, trim=True) + assert b == tskit.RateMap(position=[0, 50, 150, 250, 350], rate=[1, 2, 3, 4]) + + def test_slice_trim_right(self): + a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) + b = a.slice(right=300, trim=True) + assert b == tskit.RateMap(position=[0, 100, 200, 300], rate=[1, 2, 3]) + b = a.slice(right=350, trim=True) + assert b == tskit.RateMap(position=[0, 100, 200, 300, 350], rate=[1, 2, 3, 4]) + + def test_slice_error(self): + recomb_map = tskit.RateMap(position=[0, 100], rate=[1]) + with pytest.raises(KeyError): + recomb_map.slice(left=-1) + with pytest.raises(KeyError): + recomb_map.slice(right=-1) + with pytest.raises(KeyError): + recomb_map.slice(left=200) + with pytest.raises(KeyError): + recomb_map.slice(right=200) + with pytest.raises(KeyError): + recomb_map.slice(left=20, right=10) + + +class TestReadHapmap: + def test_read_hapmap_simple(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 x 0 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile) + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) + + def test_read_hapmap_from_filename(self, tmp_path): + with open(tmp_path / "hapfile.txt", "w") as hapfile: + hapfile.write( + """\ + HEADER + chr1 1 x 0 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x""" + ) + rm = tskit.RateMap.read_hapmap(tmp_path / "hapfile.txt") + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [np.nan, 1e-8, 5e-8], equal_nan=True) + + @pytest.mark.filterwarnings("ignore:loadtxt") + def test_read_hapmap_empty(self): + hapfile = io.StringIO( + """\ + HEADER""" + ) + with pytest.raises(ValueError, match="Empty"): + tskit.RateMap.read_hapmap(hapfile) + + def test_read_hapmap_col_pos(self): + hapfile = io.StringIO( + """\ + HEADER + 0 0 + 0.000001 1 x + 0.000006 2 x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile, position_col=1, map_col=0) + assert_array_equal(rm.position, [0, 1, 2]) + assert np.allclose(rm.rate, [1e-8, 5e-8]) + + def test_read_hapmap_map_and_rate(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 0 0 + chr1 1 1 0.000001 x + chr1 2 2 0.000006 x x""" + ) + with pytest.raises(ValueError, match="both rate_col and map_col"): + tskit.RateMap.read_hapmap(hapfile, rate_col=2, map_col=3) + + def test_read_hapmap_duplicate_pos(self): + hapfile = io.StringIO( + """\ + HEADER + 0 0 + 0.000001 1 x + 0.000006 2 x x""" + ) + with pytest.raises(ValueError, match="same columns"): + tskit.RateMap.read_hapmap(hapfile, map_col=1) + + def test_read_hapmap_nonzero_rate_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 5 x + chr1 2 0 x x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile, rate_col=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert_array_equal(rm.rate, [np.nan, 5e-8]) + + def test_read_hapmap_nonzero_rate_end(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 5 x + chr1 2 1 x x x""" + ) + with pytest.raises(ValueError, match="last entry.*must be zero"): + tskit.RateMap.read_hapmap(hapfile, rate_col=2) + + def test_read_hapmap_gzipped(self, tmp_path): + hapfile = os.path.join(tmp_path, "hapmap.txt.gz") + with gzip.GzipFile(hapfile, "wb") as gzfile: + gzfile.write(b"HEADER\n") + gzfile.write(b"chr1 0 1\n") + gzfile.write(b"chr1 1 5.5\n") + gzfile.write(b"chr1 2 0\n") + rm = tskit.RateMap.read_hapmap(hapfile, rate_col=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert_array_equal(rm.rate, [1e-8, 5.5e-8]) + + def test_read_hapmap_nonzero_map_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 1 x 0.000001 + chr1 2 x 0.000001 x + chr1 3 x 0.000006 x x x""" + ) + rm = tskit.RateMap.read_hapmap(hapfile) + assert_array_equal(rm.position, [0, 1, 2, 3]) + assert np.allclose(rm.rate, [1e-8, 0, 5e-8]) + + def test_read_hapmap_bad_nonzero_map_start(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0.0000005 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + with pytest.raises(ValueError, match="start.*must be zero"): + tskit.RateMap.read_hapmap(hapfile) + + def test_sequence_length(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + # test identical seq len + rm = tskit.RateMap.read_hapmap(hapfile, sequence_length=2) + assert_array_equal(rm.position, [0, 1, 2]) + assert np.allclose(rm.rate, [1e-8, 5e-8]) + + hapfile.seek(0) + rm = tskit.RateMap.read_hapmap(hapfile, sequence_length=10) + assert_array_equal(rm.position, [0, 1, 2, 10]) + assert np.allclose(rm.rate, [1e-8, 5e-8, np.nan], equal_nan=True) + + def test_bad_sequence_length(self): + hapfile = io.StringIO( + """\ + HEADER + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + ) + with pytest.raises(ValueError, match="sequence_length"): + tskit.RateMap.read_hapmap(hapfile, sequence_length=1.999) + + def test_no_header(self): + data = """\ + chr1 0 x 0 + chr1 1 x 0.000001 x + chr1 2 x 0.000006 x x x""" + hapfile_noheader = io.StringIO(data) + hapfile_header = io.StringIO("chr pos rate cM\n" + data) + with pytest.raises(ValueError): + tskit.RateMap.read_hapmap(hapfile_header, has_header=False) + rm1 = tskit.RateMap.read_hapmap(hapfile_header) + rm2 = tskit.RateMap.read_hapmap(hapfile_noheader, has_header=False) + assert_array_equal(rm1.rate, rm2.rate) + assert_array_equal(rm1.position, rm2.position) + + def test_hapmap_fragment(self): + hapfile = io.StringIO( + """\ + chr pos rate cM + 1 4283592 3.79115663174456 0 + 1 4361401 0.0664276817058413 0.294986106359414 + 1 7979763 10.9082897515584 0.535345505591925 + 1 8007051 0.0976780648822495 0.833010916332456 + 1 8762788 0.0899929572085616 0.906829844052373 + 1 9477943 0.0864382908650907 0.971188757364862 + 1 9696341 4.76495005895746 0.990066707213216 + 1 9752154 0.0864316558730679 1.25601286485381 + 1 9881751 0.0 1.26721414815999""" + ) + rm1 = tskit.RateMap.read_hapmap(hapfile) + hapfile.seek(0) + rm2 = tskit.RateMap.read_hapmap(hapfile, rate_col=2) + assert np.allclose(rm1.position, rm2.position) + assert np.allclose(rm1.rate, rm2.rate, equal_nan=True) diff --git a/python/tests/test_util.py b/python/tests/test_util.py index cc4f9d45da..1d78dd0a08 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -489,6 +489,27 @@ def test_unicode_table(): ) +def test_unicode_table_column_alignments(): + assert ( + util.unicode_table( + [["5", "6", "7", "8"], ["90", "10", "11", "12"]], + header=["1", "2", "3", "4"], + column_alignments="<>><", + ) + == textwrap.dedent( + """ + ╔══╤══╤══╤══╗ + ║1 │2 │3 │4 ║ + ╠══╪══╪══╪══╣ + ║5 │ 6│ 7│8 ║ + ╟──┼──┼──┼──╢ + ║90│10│11│12║ + ╚══╧══╧══╧══╝ + """ + )[1:] + ) + + def test_set_printoptions(): assert tskit._print_options == {"max_lines": 40} util.set_print_options(max_lines=None) diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py index 09e16091e5..07dc8dc241 100644 --- a/python/tskit/__init__.py +++ b/python/tskit/__init__.py @@ -90,3 +90,4 @@ from tskit.util import * # NOQA from tskit.metadata import * # NOQA from tskit.text_formats import * # NOQA +from tskit.intervals import RateMap # NOQA diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py new file mode 100644 index 0000000000..1b671a2143 --- /dev/null +++ b/python/tskit/intervals.py @@ -0,0 +1,600 @@ +# +# MIT License +# +# Copyright (c) 2022 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Utilities for working with intervals and interval maps. +""" +from __future__ import annotations + +import collections.abc +import numbers + +import numpy as np + +import tskit +import tskit.util as util + + +class RateMap(collections.abc.Mapping): + """ + A class mapping a non-negative rate value to a set of non-overlapping intervals + along the genome. Intervals for which the rate is unknown (i.e., missing data) + are encoded by NaN values in the ``rate`` array. + + :param list position: A list of :math:`n+1` positions, starting at 0, and ending + in the sequence length over which the RateMap will apply. + :param list rate: A list of :math:`n` positive rates that apply between each + position. Intervals with missing data are encoded by NaN values. + """ + + # The args are marked keyword only to give us some flexibility in how we + # create class this in the future. + def __init__( + self, + *, + position, + rate, + ): + # Making the arrays read-only guarantees rate and cumulative mass stay in sync + # We prevent the arrays themselves being overwritten by making self.position, + # etc properties. + + # TODO we always coerce the position type to float here, but we may not + # want to do this. int32 is a perfectly good choice a lot of the time. + self._position = np.array(position, dtype=float) + self._position.flags.writeable = False + self._rate = np.array(rate, dtype=float) + self._rate.flags.writeable = False + size = len(self._position) + if size < 2: + raise ValueError("Must have at least two positions") + if len(self._rate) != size - 1: + raise ValueError( + "Rate array must have one less entry than the position array" + ) + if self._position[0] != 0: + raise ValueError("First position must be zero") + + span = self.span + if np.any(span <= 0): + bad_pos = np.where(span <= 0)[0] + 1 + raise ValueError( + f"Position values not strictly increasing at indexes {bad_pos}" + ) + if np.any(self._rate < 0): + bad_rates = np.where(self._rate < 0)[0] + raise ValueError(f"Rate values negative at indexes {bad_rates}") + self._missing = np.isnan(self.rate) + self._num_missing_intervals = np.sum(self._missing) + if self._num_missing_intervals == len(self.rate): + raise ValueError("All intervals are missing data") + # We don't expose the cumulative mass array as a part of the array + # API is it's not quite as obvious how it lines up for each interval. + # It's really the sum of the mass up to but not including the current + # interval, which is a bit confusing. Probably best to just leave + # it as a function, so that people can sample at regular positions + # along the genome anyway, emphasising that it's a continuous function, + # not a step function like the other interval attributes. + self._cumulative_mass = np.insert(np.nancumsum(self.mass), 0, 0) + assert self._cumulative_mass[0] == 0 + self._cumulative_mass.flags.writeable = False + + @property + def left(self): + """ + The left position of each interval (inclusive). + """ + return self._position[:-1] + + @property + def right(self): + """ + The right position of each interval (exclusive). + """ + return self._position[1:] + + @property + def mid(self): + """ + Returns the midpoint of each interval. + """ + mid = self.left + self.span / 2 + mid.flags.writeable = False + return mid + + @property + def span(self): + """ + Returns the span (i.e., ``right - left``) of each of the intervals. + """ + span = self.right - self.left + span.flags.writeable = False + return span + + @property + def position(self): + """ + The breakpoint positions between intervals. This is equal to the + :attr:`~.RateMap.left` array with the :attr:`sequence_length` + appended. + """ + return self._position + + @property + def rate(self): + """ + The rate associated with each interval. Missing data is encoded + by NaN values. + """ + return self._rate + + @property + def mass(self): + r""" + The "mass" of each interval, defined as the :attr:`~.RateMap.rate` + :math:`\times` :attr:`~.RateMap.span`. This is NaN for intervals + containing missing data. + """ + return self._rate * self.span + + @property + def missing(self): + """ + A boolean array encoding whether each interval contains missing data. + Equivalent to ``np.isnan(rate_map.rate)`` + """ + return self._missing + + @property + def non_missing(self): + """ + A boolean array encoding whether each interval contains non-missing data. + Equivalent to ``np.logical_not(np.isnan(rate_map.rate))`` + """ + return ~self._missing + + # + # Interval counts + # + + @property + def num_intervals(self) -> int: + """ + The total number of intervals in this map. Equal to + :attr:`~.RateMap.num_missing_intervals` + + :attr:`~.RateMap.num_non_missing_intervals`. + """ + return len(self._rate) + + @property + def num_missing_intervals(self) -> int: + """ + Returns the number of missing intervals, i.e., those in which the + :attr:`~.RateMap.rate` value is a NaN. + """ + return self._num_missing_intervals + + @property + def num_non_missing_intervals(self) -> int: + """ + The number of non missing intervals, i.e., those in which the + :attr:`~.RateMap.rate` value is not a NaN. + """ + return self.num_intervals - self.num_missing_intervals + + @property + def sequence_length(self): + """ + The sequence length covered by this map + """ + return self.position[-1] + + @property + def total_mass(self): + """ + The cumulative total mass over the entire map. + """ + return self._cumulative_mass[-1] + + @property + def mean_rate(self): + """ + The mean rate over this map weighted by the span covered by each rate. + Unknown intervals are excluded. + """ + total_span = np.sum(self.span[self.non_missing]) + return self.total_mass / total_span + + def get_rate(self, x): + """ + Return the rate at the specified list of positions. + + .. note:: This function will return a NaN value for any positions + that contain missing data. + + :param numpy.ndarray x: The positions for which to return values. + :return: An array of rates, the same length as ``x``. + :rtype: numpy.ndarray + """ + loc = np.searchsorted(self.position, x, side="right") - 1 + if np.any(loc < 0) or np.any(loc >= len(self.rate)): + raise ValueError("position out of bounds") + return self.rate[loc] + + def get_cumulative_mass(self, x): + """ + Return the cumulative mass of the map up to (but not including) a + given point for a list of positions along the map. This is equal to + the integral of the rate from 0 to the point. + + :param numpy.ndarray x: The positions for which to return values. + + :return: An array of cumulative mass values, the same length as ``x`` + :rtype: numpy.ndarray + """ + x = np.array(x) + if np.any(x < 0) or np.any(x > self.sequence_length): + raise ValueError(f"Cannot have positions < 0 or > {self.sequence_length}") + return np.interp(x, self.position, self._cumulative_mass) + + def find_index(self, x: float) -> int: + """ + Returns the index of the interval that the specified position falls within, + such that ``rate_map.left[index] <= x < self.rate_map.right[index]``. + + :param float x: The position to search. + :return: The index of the interval containing this point. + :rtype: int + :raises: KeyError if the position is not contained in any of the intervals. + """ + if x < 0 or x >= self.sequence_length: + raise KeyError(f"Position {x} out of bounds") + index = np.searchsorted(self.position, x, side="left") + if x < self.position[index]: + index -= 1 + assert self.left[index] <= x < self.right[index] + return index + + def missing_intervals(self): + """ + Returns the left and right coordinates of the intervals containing + missing data in this map as a 2D numpy array + with shape (:attr:`~.RateMap.num_missing_intervals`, 2). Each row + of this returned array is therefore a ``left``, ``right`` tuple + corresponding to the coordinates of the missing intervals. + + :return: A numpy array of the coordinates of intervals containing + missing data. + :rtype: numpy.ndarray + """ + out = np.empty((self.num_missing_intervals, 2)) + out[:, 0] = self.left[self.missing] + out[:, 1] = self.right[self.missing] + return out + + def asdict(self): + return {"position": self.position, "rate": self.rate} + + # + # Dunder methods. We implement the Mapping protocol via __iter__, __len__ + # and __getitem__. We have some extra semantics for __getitem__, providing + # slice notation. + # + + def __iter__(self): + # The clinching argument for using mid here is that if we used + # left instead we would have + # RateMap([0, 1], [0.1]) == RateMap([0, 100], [0.1]) + # by the inherited definition of equality since the dictionary items + # would be equal. + # Similarly, we only return the midpoints of known intervals + # because NaN values are not equal, and we would need to do + # something to work around this. It seems reasonable that + # this high-level operation returns the *known* values only + # anyway. + yield from self.mid[self.non_missing] + + def __len__(self): + return np.sum(self.non_missing) + + def __getitem__(self, key): + if isinstance(key, slice): + if key.step is not None: + raise TypeError("Only interval slicing is supported") + return self.slice(key.start, key.stop) + if isinstance(key, numbers.Number): + index = self.find_index(key) + if np.isnan(self.rate[index]): + # To be consistent with the __iter__ definition above we + # don't consider these missing positions to be "in" the map. + raise KeyError(f"Position {key} is within a missing interval") + return self.rate[index] + # TODO we could implement numpy array indexing here and call + # to get_rate. Note we'd need to take care that we return a keyerror + # if the returned array contains any nans though. + raise KeyError("Key {key} not in map") + + def _text_header_and_rows(self, limit=None): + headers = ("left", "right", "mid", "span", "rate") + num_rows = len(self.left) + rows = [] + row_indexes = util.truncate_rows(num_rows, limit) + for j in row_indexes: + if j == -1: + rows.append(f"__skipped__{num_rows-limit}") + else: + rows.append( + [ + f"{self.left[j]:.10g}", + f"{self.right[j]:.10g}", + f"{self.mid[j]:.10g}", + f"{self.span[j]:.10g}", + f"{self.rate[j]:.2g}", + ] + ) + return headers, rows + + def __str__(self): + header, rows = self._text_header_and_rows( + limit=tskit._print_options["max_lines"] + ) + table = util.unicode_table( + rows=rows, + header=header, + column_alignments="<<>>>", + ) + return table + + def _repr_html_(self): + header, rows = self._text_header_and_rows( + limit=tskit._print_options["max_lines"] + ) + return util.html_table(rows, header=header) + + def __repr__(self): + return f"RateMap(position={repr(self.position)}, rate={repr(self.rate)})" + + # + # Methods for building rate maps. + # + + def copy(self) -> RateMap: + """ + Returns a deep copy of this RateMap. + """ + # We take read-only copies of the arrays in the constructor anyway, so + # no need for copying. + return RateMap(position=self.position, rate=self.rate) + + def slice(self, left=None, right=None, *, trim=False) -> RateMap: # noqa: A003 + """ + Returns a subset of this rate map in the specified interval. + + :param float left: The left coordinate (inclusive) of the region to keep. + If ``None``, defaults to 0. + :param float right: The right coordinate (exclusive) of the region to keep. + If ``None``, defaults to the sequence length. + :param bool trim: If True, remove the flanking regions such that the + sequence length of the new rate map is ``right`` - ``left``. If ``False`` + (default), do not change the coordinate system and mark the flanking + regions as "unknown". + :return: A new RateMap instance + :rtype: RateMap + """ + left = 0 if left is None else left + right = self.sequence_length if right is None else right + if not (0 <= left < right <= self.sequence_length): + raise KeyError(f"Invalid slice: left={left}, right={right}") + + i = self.find_index(left) + j = i + np.searchsorted(self.position[i:], right, side="right") + if right > self.position[j - 1]: + j += 1 + + position = self.position[i:j].copy() + rate = self.rate[i : j - 1].copy() + position[0] = left + position[-1] = right + + if trim: + # Return trimmed map with changed coords + return RateMap(position=position - left, rate=rate) + + # Need to check regions before & after sliced region are filled out: + if left != 0: + if np.isnan(rate[0]): + position[0] = 0 # Extend + else: + rate = np.insert(rate, 0, np.nan) # Prepend + position = np.insert(position, 0, 0) + if right != self.position[-1]: + if np.isnan(rate[-1]): + position[-1] = self.sequence_length # Extend + else: + rate = np.append(rate, np.nan) # Append + position = np.append(position, self.position[-1]) + return RateMap(position=position, rate=rate) + + @staticmethod + def uniform(sequence_length, rate) -> RateMap: + """ + Create a uniform rate map + """ + return RateMap(position=[0, sequence_length], rate=[rate]) + + @staticmethod + def read_hapmap( + fileobj, + sequence_length=None, + *, + has_header=True, + position_col=None, + rate_col=None, + map_col=None, + ): + # Black barfs with an INTERNAL_ERROR trying to reformat this docstring, + # so we explicitly disable reformatting here. + # fmt: off + """ + Parses the specified file in HapMap format and returns a :class:`.RateMap`. + HapMap files must white-space-delimited, and by default are assumed to + contain a single header line (which is ignored). Each subsequent line + then contains a physical position (in base pairs) and either a genetic + map position (in centiMorgans) or a recombination rate (in centiMorgans + per megabase). The value in the rate column in a given line gives the + constant rate between the physical position in that line (inclusive) and the + physical position on the next line (exclusive). + By default, the second column of the file is taken + as the physical position and the fourth column is taken as the genetic + position, as seen in the following sample of the format:: + + Chromosome Position(bp) Rate(cM/Mb) Map(cM) + chr10 48232 0.1614 0.002664 + chr10 48486 0.1589 0.002705 + chr10 50009 0.159 0.002947 + chr10 52147 0.1574 0.003287 + ... + chr10 133762002 3.358 181.129345 + chr10 133766368 0.000 181.144008 + + In the example above, the first row has a nonzero genetic map position + (last column, cM), implying a nonzero recombination rate before that + position, that is assumed to extend to the start of the chromosome + (at position 0 bp). However, if the first line has a nonzero bp position + (second column) and a zero genetic map position (last column, cM), + then the recombination rate before that position is *unknown*, producing + :ref:`missing data `. + + .. note:: + The rows are all assumed to come from the same contig, and the + first column is currently ignored. Therefore if you have a single + file containing several contigs or chromosomes, you must must split + it up into multiple files, and pass each one separately to this + function. + + :param str fileobj: Filename or file to read. This is passed directly + to :func:`numpy.loadtxt`, so if the filename extension is .gz or .bz2, + the file is decompressed first + :param float sequence_length: The total length of the map. If ``None``, + then assume it is the last physical position listed in the file. + Otherwise it must be greater then or equal to the last physical + position in the file, and the region between the last physical position + and the sequence_length is padded with a rate of zero. + :param bool has_header: If True (default), assume the file has a header row + and ignore the first line of the file. + :param int position_col: The zero-based index of the column in the file + specifying the physical position in base pairs. If ``None`` (default) + assume an index of 1 (i.e. the second column). + :param int rate_col: The zero-based index of the column in the file + specifying the rate in cM/Mb. If ``None`` (default) do not use the rate + column, but calculate rates using the genetic map positions, as + specified in ``map_col``. If the rate column is used, the + interval from 0 to first physical position in the file is marked as + unknown, and the last value in the rate column must be zero. + :param int map_col: The zero-based index of the column in the file + specifying the genetic map position in centiMorgans. If ``None`` + (default), assume an index of 3 (i.e. the fourth column). If the first + genetic position is 0 the interval from position 0 to the first + physical position in the file is marked as unknown. Otherwise, act + as if an additional row, specifying physical position 0 and genetic + position 0, exists at the start of the file. + :return: A RateMap object. + :rtype: RateMap + """ + # fmt: on + column_defs = {} # column definitions passed to np.loadtxt + if rate_col is None and map_col is None: + # Default to map_col + map_col = 3 + elif rate_col is not None and map_col is not None: + raise ValueError("Cannot specify both rate_col and map_col") + if map_col is not None: + column_defs[map_col] = ("map", float) + else: + column_defs[rate_col] = ("rate", float) + position_col = 1 if position_col is None else position_col + if position_col in column_defs: + raise ValueError( + "Cannot specify the same columns for position_col and " + "rate_col or map_col" + ) + column_defs[position_col] = ("pos", int) + + column_names = [c[0] for c in column_defs.values()] + column_data = np.loadtxt( + fileobj, + skiprows=1 if has_header else 0, + dtype=list(column_defs.values()), + usecols=list(column_defs.keys()), + unpack=True, + ) + data = dict(zip(column_names, column_data)) + + if "map" not in data: + assert "rate" in data + if data["rate"][-1] != 0: + raise ValueError("The last entry in the 'rate' column must be zero") + pos_Mb = data["pos"] / 1e6 + map_pos = np.cumsum(data["rate"][:-1] * np.diff(pos_Mb)) + data["map"] = np.insert(map_pos, 0, 0) / 100 + else: + data["map"] /= 100 # Convert centiMorgans to Morgans + if len(data["map"]) == 0: + raise ValueError("Empty hapmap file") + + # TO DO: read in chrom name from col 0 and poss set as .name + # attribute on the RateMap + + physical_positions = data["pos"] + genetic_positions = data["map"] + start = physical_positions[0] + end = physical_positions[-1] + + if genetic_positions[0] > 0 and start == 0: + raise ValueError( + "The map distance at the start of the chromosome must be zero" + ) + if start > 0: + physical_positions = np.insert(physical_positions, 0, 0) + if genetic_positions[0] > 0: + # Exception for a map that starts > 0cM: include the start rate + # in the mean + start = 0 + genetic_positions = np.insert(genetic_positions, 0, 0) + + if sequence_length is not None: + if sequence_length < end: + raise ValueError( + "The sequence_length cannot be less than the last physical position " + f" ({physical_positions[-1]})" + ) + if sequence_length > end: + physical_positions = np.append(physical_positions, sequence_length) + genetic_positions = np.append(genetic_positions, genetic_positions[-1]) + + assert genetic_positions[0] == 0 + rate = np.diff(genetic_positions) / np.diff(physical_positions) + if start != 0: + rate[0] = np.nan + if end != physical_positions[-1]: + rate[-1] = np.nan + return RateMap(position=physical_positions, rate=rate) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 443b3a5ba8..6cd29c56f4 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -27,7 +27,6 @@ import collections.abc import dataclasses import datetime -import itertools import json import numbers import warnings @@ -657,41 +656,10 @@ def __str__(self): return util.unicode_table(rows, header=headers, row_separator=False) def _repr_html_(self): - """ - Called by jupyter notebooks to render tables - """ headers, rows = self._text_header_and_rows( limit=tskit._print_options["max_lines"] ) - headers = "".join(f"{header}" for header in headers) - rows = ( - f'{row[11:]}' - f" rows skipped (tskit.set_print_options)" - if "__skipped__" in row - else "".join(f"{cell}" for cell in row) - for row in rows - ) - rows = "".join(f"{row}\n" for row in rows) - return f""" -
- - - - - {headers} - - - - {rows} - -
-
- """ + return util.html_table(rows, header=headers) def _columns_all_integer(self, *colnames): # For displaying floating point values without loads of decimal places @@ -852,15 +820,8 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "flags", "location", "parents", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1105,16 +1066,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "flags", "population", "individual", "time", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: row = self[j] if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") @@ -1306,16 +1260,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "left", "right", "parent", "child", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places = 0 if self._columns_all_integer("left", "right") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1528,17 +1475,10 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "left", "right", "node", "source", "dest", "time", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places_coords = 0 if self._columns_all_integer("left", "right") else 8 decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1748,16 +1688,9 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "position", "ancestral_state", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) decimal_places = 0 if self._columns_all_integer("position") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -1971,17 +1904,10 @@ def __init__(self, max_rows_increment=0, ll_table=None): def _text_header_and_rows(self, limit=None): headers = ("id", "site", "node", "time", "derived_state", "parent", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) + row_indexes = util.truncate_rows(self.num_rows, limit) # Currently mutations do not have discretised times: this for consistency decimal_places_times = 0 if self._columns_all_integer("time") else 8 - for j in indexes: + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -2232,15 +2158,8 @@ def add_row(self, metadata=None): def _text_header_and_rows(self, limit=None): headers = ("id", "metadata") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: @@ -2490,15 +2409,8 @@ def append_columns( def _text_header_and_rows(self, limit=None): headers = ("id", "timestamp", "record") rows = [] - if limit is None or self.num_rows <= limit: - indexes = range(self.num_rows) - else: - indexes = itertools.chain( - range(limit // 2), - [-1], - range(self.num_rows - (limit - (limit // 2)), self.num_rows), - ) - for j in indexes: + row_indexes = util.truncate_rows(self.num_rows, limit) + for j in row_indexes: if j == -1: rows.append(f"__skipped__{self.num_rows-limit}") else: diff --git a/python/tskit/util.py b/python/tskit/util.py index 9baa298ceb..bf5c95a2a1 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -23,6 +23,7 @@ Module responsible for various utility functions used in other modules. """ import dataclasses +import itertools import json import numbers import os @@ -320,7 +321,7 @@ def obj_to_collapsed_html(d, name=None, open_depth=0): :param str name: Name for this object :param int open_depth: By default sub-sections are collapsed. If this number is - non-zero the first layers up to open_depth will be opened. + non-zero the first layers up to open_depth will be opened. :return: The HTML as a string :rtype: str """ @@ -369,19 +370,25 @@ def render_metadata(md, length=40): return truncate_string_end(str(md), length) -def unicode_table(rows, title=None, header=None, row_separator=True): +def unicode_table( + rows, *, title=None, header=None, row_separator=True, column_alignments=None +): """ Convert a table (list of lists) of strings to a unicode table. If a row contains the string "__skipped__NNN" then "skipped N rows" is displayed. :param list[list[str]] rows: List of rows, each of which is a list of strings for - each cell. The first column will be left justified, the others right. Each row must - have the same number of cells. + each cell. Each row must have the same number of cells. :param str title: If specified the first output row will be a single cell - containing this string, left-justified. [optional] + containing this string, left-justified. [optional] :param list[str] header: Specifies a row above the main rows which will be in double - lined borders and left justified. Must be same length as each row. [optional] + lined borders and left justified. Must be same length as each row. [optional] :param boolean row_separator: If True add lines between each row. [Default: True] + :param column_alignments str: A string of the same length as the number of cells in + a row (i.e. columns) where each character specifies an alignment such as ``<``, + ``>`` or ``^`` as used in Python's string formatting mini-language. If ``None``, + set the first column to be left justified and the remaining columns to be right + justified [Default: ``None``] :return: The table as a string :rtype: str """ @@ -392,6 +399,8 @@ def unicode_table(rows, title=None, header=None, row_separator=True): widths = [ max(len(row[i_col]) for row in all_rows) for i_col in range(len(all_rows[0])) ] + if column_alignments is None: + column_alignments = "<" + ">" * (len(widths) - 1) out = [] inner_width = sum(widths) + len(header or rows[0]) - 1 if title is not None: @@ -423,9 +432,13 @@ def unicode_table(rows, title=None, header=None, row_separator=True): else: if i != 0 and not last_skipped and row_separator: out.append(f"╟{'┼'.join('─' * w for w in widths)}╢\n") + out.append( - f"║{row[0].ljust(widths[0])}│" - f"{'│'.join(cell.rjust(w) for cell, w in zip(row[1:], widths[1:]))}║\n" + "║" + + "│".join( + f"{r:{a}{w}}" for r, w, a in zip(row, widths, column_alignments) + ) + + "║\n" ) last_skipped = False @@ -433,6 +446,41 @@ def unicode_table(rows, title=None, header=None, row_separator=True): return "".join(out) +def html_table(rows, *, header): + """ + Called e.g. by jupyter notebooks to render tables + """ + headers = "".join(f"{h}" for h in header) + rows = ( + f'{row[11:]}' + f" rows skipped (tskit.set_print_options)" + if "__skipped__" in row + else "".join(f"{cell}" for cell in row) + for row in rows + ) + rows = "".join(f"{row}\n" for row in rows) + return f""" +
+ + + + + {headers} + + + + {rows} + +
+
+ """ + + def tree_sequence_html(ts): table_rows = "".join( f""" @@ -674,6 +722,20 @@ def set_print_options(*, max_lines=40): tskit._print_options = {"max_lines": max_lines} +def truncate_rows(num_rows, limit=None): + """ + Return a list of indexes into a set of rows, but if a ``limit`` is set, truncate the + number of rows and place a single ``-1`` entry, instead of the intermediate indexes + """ + if limit is None or num_rows <= limit: + return range(num_rows) + return itertools.chain( + range(limit // 2), + [-1], + range(num_rows - (limit - (limit // 2)), num_rows), + ) + + def random_nucleotides(length: numbers.Number, *, seed: Union[int, None] = None) -> str: """ Returns a random string of nucleotides of the specified length. Characters