@@ -48,7 +48,7 @@ def simulate_pedigree(
48
48
num_generations = 3 ,
49
49
sequence_length = 1 ,
50
50
random_seed = 42 ,
51
- internal_sample_gen = ( None , None , None ) ,
51
+ sample_gen = None ,
52
52
) -> tskit .TableCollection :
53
53
"""
54
54
Simulates pedigree.
@@ -62,17 +62,19 @@ def simulate_pedigree(
62
62
num_generations: Number of generations to attempt to simulate
63
63
sequence_length: The sequence_length of the output tables.
64
64
random_seed: Random seed.
65
+ sample_gen: Generations at which all individuals are samples. Defaults
66
+ to the first generation (backwards in time).
65
67
"""
66
- # Fill-in internal_sample_gen with None if shorter than number of generations
67
- if len (internal_sample_gen ) < num_generations :
68
- tmp = internal_sample_gen
69
- internal_sample_gen = np .repeat (None ,num_generations )
70
- internal_sample_gen [0 :len (tmp )] = tmp
71
68
rng = np .random .RandomState (random_seed )
72
69
builder = msprime .PedigreeBuilder ()
73
70
74
71
time = num_generations - 1
75
- curr_gen = [builder .add_individual (time = time ,is_sample = internal_sample_gen [0 ]) for _ in range (num_founders )]
72
+ if sample_gen is None :
73
+ sample_gen = [0 ]
74
+ curr_gen = [
75
+ builder .add_individual (time = time , is_sample = time in sample_gen )
76
+ for _ in range (num_founders )
77
+ ]
76
78
for generation in range (1 , num_generations ):
77
79
num_pairs = len (curr_gen ) // 2
78
80
if num_pairs == 0 and num_children_prob [0 ] != 1 :
@@ -86,7 +88,9 @@ def simulate_pedigree(
86
88
num_children = rng .choice (len (num_children_prob ), p = num_children_prob )
87
89
for _ in range (num_children ):
88
90
parents = np .sort (parents ).astype (np .int32 )
89
- ind_id = builder .add_individual (time = time , parents = parents , is_sample = internal_sample_gen [generation ])
91
+ ind_id = builder .add_individual (
92
+ time = time , parents = parents , is_sample = time in sample_gen
93
+ )
90
94
curr_gen .append (ind_id )
91
95
return builder .finalise (sequence_length )
92
96
@@ -550,16 +554,16 @@ def test_shallow(self, num_founders, recombination_rate):
550
554
sequence_length = 100 ,
551
555
)
552
556
self .verify (tables , recombination_rate )
553
-
554
- @pytest .mark .parametrize ("num_founders" , [2 , 3 , 5 , 100 ])
557
+
558
+ @pytest .mark .parametrize ("num_founders" , [2 , 3 , 5 ])
555
559
@pytest .mark .parametrize ("recombination_rate" , [0 , 0.01 ])
556
560
def test_shallow_internal (self , num_founders , recombination_rate ):
557
561
tables = simulate_pedigree (
558
562
num_founders = num_founders ,
559
563
num_children_prob = [0 , 0 , 1 ],
560
564
num_generations = 2 ,
561
- sequence_length = 100 ,
562
- internal_sample_gen = [ True , False ],
565
+ sequence_length = 100 ,
566
+ sample_gen = [ 0 , 1 ],
563
567
)
564
568
self .verify (tables , recombination_rate )
565
569
0 commit comments