Skip to content

Commit 78255b6

Browse files
Update the tests and API
1 parent 188d1ec commit 78255b6

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tests/test_pedigree.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def simulate_pedigree(
4848
num_generations=3,
4949
sequence_length=1,
5050
random_seed=42,
51-
internal_sample_gen=(None,None,None),
51+
sample_gen=None,
5252
) -> tskit.TableCollection:
5353
"""
5454
Simulates pedigree.
@@ -62,17 +62,19 @@ def simulate_pedigree(
6262
num_generations: Number of generations to attempt to simulate
6363
sequence_length: The sequence_length of the output tables.
6464
random_seed: Random seed.
65+
sample_gen: Generations at which all individuals are samples. Defaults
66+
to the first generation (backwards in time).
6567
"""
66-
# Fill-in internal_sample_gen with None if shorter than number of generations
67-
if len(internal_sample_gen) < num_generations:
68-
tmp = internal_sample_gen
69-
internal_sample_gen = np.repeat(None,num_generations)
70-
internal_sample_gen[0:len(tmp)] = tmp
7168
rng = np.random.RandomState(random_seed)
7269
builder = msprime.PedigreeBuilder()
7370

7471
time = num_generations - 1
75-
curr_gen = [builder.add_individual(time=time,is_sample=internal_sample_gen[0]) for _ in range(num_founders)]
72+
if sample_gen is None:
73+
sample_gen = [0]
74+
curr_gen = [
75+
builder.add_individual(time=time, is_sample=time in sample_gen)
76+
for _ in range(num_founders)
77+
]
7678
for generation in range(1, num_generations):
7779
num_pairs = len(curr_gen) // 2
7880
if num_pairs == 0 and num_children_prob[0] != 1:
@@ -86,7 +88,9 @@ def simulate_pedigree(
8688
num_children = rng.choice(len(num_children_prob), p=num_children_prob)
8789
for _ in range(num_children):
8890
parents = np.sort(parents).astype(np.int32)
89-
ind_id = builder.add_individual(time=time, parents=parents, is_sample=internal_sample_gen[generation])
91+
ind_id = builder.add_individual(
92+
time=time, parents=parents, is_sample=time in sample_gen
93+
)
9094
curr_gen.append(ind_id)
9195
return builder.finalise(sequence_length)
9296

@@ -550,16 +554,16 @@ def test_shallow(self, num_founders, recombination_rate):
550554
sequence_length=100,
551555
)
552556
self.verify(tables, recombination_rate)
553-
554-
@pytest.mark.parametrize("num_founders", [2, 3, 5, 100])
557+
558+
@pytest.mark.parametrize("num_founders", [2, 3, 5])
555559
@pytest.mark.parametrize("recombination_rate", [0, 0.01])
556560
def test_shallow_internal(self, num_founders, recombination_rate):
557561
tables = simulate_pedigree(
558562
num_founders=num_founders,
559563
num_children_prob=[0, 0, 1],
560564
num_generations=2,
561-
sequence_length=100,
562-
internal_sample_gen=[True, False],
565+
sequence_length=100,
566+
sample_gen=[0, 1],
563567
)
564568
self.verify(tables, recombination_rate)
565569

0 commit comments

Comments
 (0)