
Commit 83a6f9d

fix tracking of blocks
1 parent: 276e5fa

2 files changed: 27 additions, 25 deletions


promptolution/tasks/base_task.py (27 additions, 18 deletions)
@@ -88,7 +88,7 @@ def __init__(
         # If no y_column is provided, create a dummy y array
         self.ys = [""] * len(self.xs)
 
-        self.block_idx: int | list[int] = 0
+        self.block_idx: int = 0
         self.n_blocks: int = len(self.xs) // self.n_subsamples if self.n_subsamples > 0 else 1
         self.rng = np.random.default_rng(seed)
 
@@ -98,18 +98,17 @@ def __init__(
         self.prompt_evaluated_blocks: Dict[Prompt, List[int]] = {}  # prompt_str: set of evaluated block indices
 
     def subsample(
-        self, eval_strategy: Optional["EvalStrategy"] = None, block_idx: int | list[int] | None = None
+        self, eval_strategy: Optional["EvalStrategy"] = None, block_idx: List[int] | None = None
     ) -> Tuple[List[str], List[str]]:
         """Subsample the dataset based on the specified parameters.
 
         Args:
             eval_strategy (EvalStrategy, optional): Subsampling strategy to use instead of self.eval_strategy. Defaults to None.
+            block_idx (List[int] | None, optional): Specific block index or indices to evaluate, overriding eval_strategy. Defaults to None.
 
         Returns:
             Tuple[List[str], List[str]]: Subsampled input data and labels.
         """
-        if block_idx is not None and isinstance(block_idx, int):
-            block_idx = [block_idx]
 
         if block_idx is not None:
             return [self.xs[i] for i in block_idx], [self.ys[i] for i in block_idx]
@@ -128,17 +127,9 @@ def subsample(
             indices = np.arange(start_idx, end_idx)
             return [self.xs[i] for i in indices], [self.ys[i] for i in indices]
         elif eval_strategy == "sequential_block":
-            if isinstance(self.block_idx, list):
-                block_indices: List[int] = []
-                for block_id in self.block_idx:
-                    start_idx = block_id * self.n_subsamples
-                    end_idx = min((block_id + 1) * self.n_subsamples, len(self.xs))
-                    block_indices.extend(range(start_idx, end_idx))
-                indices = np.array(sorted(set(block_indices)), dtype=int)
-            else:
-                start_idx = self.block_idx * self.n_subsamples
-                end_idx = min((self.block_idx + 1) * self.n_subsamples, len(self.xs))
-                indices = np.arange(start_idx, end_idx)
+            start_idx = self.block_idx * self.n_subsamples
+            end_idx = min((self.block_idx + 1) * self.n_subsamples, len(self.xs))
+            indices = np.arange(start_idx, end_idx)
 
             return [self.xs[i] for i in indices], [self.ys[i] for i in indices]
         else:
@@ -268,9 +259,20 @@ def evaluate(
 
         This method orchestrates subsampling, prediction, caching, and result collection.
         Sequences, token costs, raw scores, and aggregated scores are always returned.
+
+        Args:
+            prompts (Union[Prompt, List[Prompt]]): A single prompt or a list of prompts to evaluate. Results will be returned in the same order.
+            predictor (BasePredictor): The predictor to evaluate the prompts with.
+            system_prompts (Optional[Union[str, List[str]]], optional): Optional system prompts to parse to the predictor.
+            eval_strategy (Optional[EvalStrategy], optional): Subsampling strategy to use instead of self.eval_strategy. Defaults to None, which uses self.eval_strategy.
+            block_idx (Optional[int | list[int]], optional): Specific block index or indices to evaluate, overriding eval_strategy. Defaults to None.
         """
         prompts_list: List[Prompt] = [prompts] if isinstance(prompts, Prompt) else list(prompts)
         eval_strategy = eval_strategy or self.eval_strategy
+
+        if block_idx is not None and isinstance(block_idx, int):
+            block_idx = [block_idx]
+
         xs, ys = self.subsample(eval_strategy=eval_strategy, block_idx=block_idx)
         (
             prompts_to_evaluate,
@@ -298,10 +300,17 @@ def evaluate(
 
         # Record evaluated block for block strategies
        for prompt in prompts_list:
-            if isinstance(self.block_idx, list):
-                self.prompt_evaluated_blocks.setdefault(prompt, []).extend(self.block_idx)
-            else:
+            if eval_strategy == "evaluated":
+                continue
+            elif block_idx is not None:
+                self.prompt_evaluated_blocks.setdefault(prompt, []).extend(block_idx)
+            elif eval_strategy in ["sequential_block", "random_block"]:
                 self.prompt_evaluated_blocks.setdefault(prompt, []).append(self.block_idx)
+            else:
+                self.prompt_evaluated_blocks.setdefault(prompt, []).extend(
+                    list(range(self.n_blocks))
+                )
+
 
         input_tokens, output_tokens, agg_input_tokens, agg_output_tokens = self._compute_costs(
             prompts_list, xs, ys, predictor
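
Taken together, the base_task.py changes make self.block_idx a plain int, restrict the block_idx parameter of subsample to an explicit list of indices, and move both the int-to-list normalization and the per-strategy block bookkeeping into evaluate. The snippet below is a minimal standalone sketch of the new recording branches, not code from the repository: prompt keys are simplified to strings, and the strategy name "full" is a made-up placeholder for any strategy other than "evaluated", "sequential_block", or "random_block".

from typing import Dict, List, Optional

def record_evaluated_blocks(
    evaluated: Dict[str, List[int]],
    prompt: str,
    eval_strategy: str,
    current_block: int,
    n_blocks: int,
    block_idx: Optional[List[int]] = None,
) -> None:
    # Illustrative sketch mirroring the recording branches added to evaluate() in base_task.py.
    if eval_strategy == "evaluated":
        return  # only cached results are reused, so no new blocks are seen
    if block_idx is not None:
        # An explicit block_idx overrides the strategy.
        evaluated.setdefault(prompt, []).extend(block_idx)
    elif eval_strategy in ["sequential_block", "random_block"]:
        # Block strategies evaluate exactly the current block.
        evaluated.setdefault(prompt, []).append(current_block)
    else:
        # Any other strategy touches the whole dataset, i.e. every block.
        evaluated.setdefault(prompt, []).extend(range(n_blocks))

evaluated: Dict[str, List[int]] = {}
record_evaluated_blocks(evaluated, "prompt-a", "sequential_block", current_block=2, n_blocks=4)
record_evaluated_blocks(evaluated, "prompt-a", "full", current_block=2, n_blocks=4)
print(evaluated)  # {'prompt-a': [2, 0, 1, 2, 3]}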

promptolution/tasks/multi_objective_task.py (0 additions, 7 deletions)
@@ -130,13 +130,6 @@ def evaluate( # type: ignore
             prompts_list, xs, ys, predictor
         )
 
-        # Record evaluated block for block strategies
-        for prompt in prompts_list:
-            block_set = task.prompt_evaluated_blocks.setdefault(prompt, [])
-            if isinstance(task.block_idx, list):
-                block_set.extend(task.block_idx)
-            else:
-                block_set.append(task.block_idx)
         per_task_results.append(
             EvalResult(
                 scores=scores_array,
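
With this loop removed, per-prompt block recording no longer happens in multi_objective_task.py; within this diff, the branches in base_task.py above are the only place evaluated blocks are tracked. As a purely hypothetical illustration of what that bookkeeping enables (this helper does not exist in the codebase), a caller could use the recorded indices to pick the next block a prompt still needs:

from typing import Dict, List, Optional

def next_unevaluated_block(
    evaluated: Dict[str, List[int]], prompt: str, n_blocks: int
) -> Optional[int]:
    # Return the lowest block index this prompt has not been scored on yet,
    # or None once every block has been covered.
    seen = set(evaluated.get(prompt, []))
    for block in range(n_blocks):
        if block not in seen:
            return block
    return None

evaluated = {"prompt-a": [0, 2]}
print(next_unevaluated_block(evaluated, "prompt-a", n_blocks=4))  # 1
print(next_unevaluated_block(evaluated, "prompt-b", n_blocks=4))  # 0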
