@@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
8989 pick_categories: Restrict the dataset to the given list of categories.
9090 pick_sequences: A Sequence of sequence names to restrict the dataset to.
9191 exclude_sequences: A Sequence of the names of the sequences to exclude.
92+ limit_sequences_per_category_to: Limit the dataset to at most the first
93+ N sequences within each category (applied after all other sequence
94+ filters but before `limit_sequences_to`).
9295 limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
9396 sequences (after other sequence filters have been applied but before
9497 frame-based filters).
@@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
115118
116119 pick_sequences : Tuple [str , ...] = ()
117120 exclude_sequences : Tuple [str , ...] = ()
121+ limit_sequences_per_category_to : int = 0
118122 limit_sequences_to : int = 0
119123 limit_to : int = 0
120124 n_frames_per_sequence : int = - 1
@@ -373,27 +377,46 @@ def is_filtered(self) -> bool:
373377 self .remove_empty_masks
374378 or self .limit_to > 0
375379 or self .limit_sequences_to > 0
380+ or self .limit_sequences_per_category_to > 0
376381 or len (self .pick_sequences ) > 0
377382 or len (self .exclude_sequences ) > 0
378383 or len (self .pick_categories ) > 0
379384 or self .n_frames_per_sequence > 0
380385 )
381386
382387 def _get_filtered_sequences_if_any (self ) -> Optional [pd .Series ]:
383- # maximum possible query: WHERE category IN 'self.pick_categories'
388+ # maximum possible filter (if limit_sequences_per_category_to == 0):
389+ # WHERE category IN 'self.pick_categories'
384390 # AND sequence_name IN 'self.pick_sequences'
385391 # AND sequence_name NOT IN 'self.exclude_sequences'
386392 # LIMIT 'self.limit_sequences_to'
387393
388- stmt = sa .select (SqlSequenceAnnotation .sequence_name )
389-
390394 where_conditions = [
391395 * self ._get_category_filters (),
392396 * self ._get_pick_filters (),
393397 * self ._get_exclude_filters (),
394398 ]
395- if where_conditions :
396- stmt = stmt .where (* where_conditions )
399+
400+ def add_where (stmt ):
401+ return stmt .where (* where_conditions ) if where_conditions else stmt
402+
403+ if self .limit_sequences_per_category_to <= 0 :
404+ stmt = add_where (sa .select (SqlSequenceAnnotation .sequence_name ))
405+ else :
406+ subquery = sa .select (
407+ SqlSequenceAnnotation .sequence_name ,
408+ sa .func .row_number ()
409+ .over (
410+ order_by = sa .text ("ROWID" ), # NOTE: ROWID is SQLite-specific
411+ partition_by = SqlSequenceAnnotation .category ,
412+ )
413+ .label ("row_number" ),
414+ )
415+
416+ subquery = add_where (subquery ).subquery ()
417+ stmt = sa .select (subquery .c .sequence_name ).where (
418+ subquery .c .row_number <= self .limit_sequences_per_category_to
419+ )
397420
398421 if self .limit_sequences_to > 0 :
399422 logger .info (
@@ -402,7 +425,11 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
402425 # NOTE: ROWID is SQLite-specific
403426 stmt = stmt .order_by (sa .text ("ROWID" )).limit (self .limit_sequences_to )
404427
405- if not where_conditions and self .limit_sequences_to <= 0 :
428+ if (
429+ not where_conditions
430+ and self .limit_sequences_to <= 0
431+ and self .limit_sequences_per_category_to <= 0
432+ ):
406433 # we will not need to filter by sequences
407434 return None
408435
0 commit comments