Skip to content

Commit 6c8f73a

Browse files
will-keenimgtec-admin
authored andcommitted
Merge pull request #8 from imaginationtech/rand_len_next
Pass random length list tests with sparse/thorough solver
2 parents d1178c8 + 46e44c9 commit 6c8f73a

File tree

4 files changed

+233
-122
lines changed

4 files changed

+233
-122
lines changed

constrainedrandom/internal/multivar.py

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (c) 2023 Imagination Technologies Ltd. All Rights Reserved
33

4-
import constraint
54
from collections import defaultdict
6-
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union
5+
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
76

87
from .vargroup import VarGroup
98

109
from .. import utils
11-
from ..debug import RandomizationDebugInfo, RandomizationFail
10+
from ..debug import RandomizationDebugInfo
1211

1312
if TYPE_CHECKING:
1413
from ..randobj import RandObj
@@ -138,74 +137,65 @@ def solve_groups(
138137
'''
139138
constraints = self.constraints
140139
sparse_solver = solutions_per_group is not None
141-
solutions = []
142-
solved_vars = []
140+
solutions : List[Dict[str, Any]] = []
141+
solved_vars : List[str] = []
143142

144-
# Respect assigned temporary values
143+
# Respect assigned temporary values.
145144
if len(with_values) > 0:
146145
for var_name in with_values.keys():
147146
solved_vars.append(var_name)
148147
solutions.append(with_values)
149148

150-
# If solving sparsely, we'll create a new problem for each group.
151-
# If not solving sparsely, just create one big problem that we add to
152-
# as we go along.
153-
if not sparse_solver:
154-
problem = constraint.Problem()
155-
for var_name, value in with_values.items():
156-
problem.addVariable(var_name, (value,))
157-
149+
# For each group, construct a problem and solve it.
158150
for group in groups:
159-
if sparse_solver:
160-
# Construct one problem per group, add solved variables from previous groups.
161-
problem = constraint.Problem()
162-
# Construct the appropriate group variable problem
163-
group_problem = VarGroup(
164-
group,
165-
solved_vars,
166-
problem,
167-
constraints,
168-
self.max_domain_size,
169-
self.debug,
170-
)
171-
172151
group_solutions = None
152+
group_problem = None
173153
attempts = 0
174154
while group_solutions is None or len(group_solutions) == 0:
155+
# Early loop exit cases
175156
if attempts >= max_iterations:
176157
# We have failed, give up
177158
return None
178159
if attempts > 0 and not group_problem.can_retry():
179160
# Not worth retrying - the same result will be obtained.
180161
return None
181-
if sparse_solver:
182-
if len(solutions) > 0:
183-
# Respect a proportion of the solution space, determined
184-
# by the sparsity/solutions_per_group.
162+
163+
# Determine what the starting state space for this group
164+
# should be.
165+
if sparse_solver and len(solutions) > 0:
166+
# Respect a proportion of the solution space, determined
167+
# by the sparsity/solutions_per_group.
168+
# Start by choosing a subset of the possible solutions.
185169
if solutions_per_group >= len(solutions):
186-
solution_subset = solutions
170+
solution_subset = list(solutions)
187171
else:
188172
solution_subset = self.parent._get_random().choices(
189173
solutions,
190174
k=solutions_per_group
191175
)
192-
if solutions_per_group == 1:
193-
for var_name, value in solution_subset[0].items():
194-
if var_name in problem._variables:
195-
del problem._variables[var_name]
196-
problem.addVariable(var_name, (value,))
197-
else:
198-
solution_space = defaultdict(list)
199-
for soln in solution_subset:
200-
for var_name, value in soln.items():
201-
# List is ~2x slower than set for 'in',
202-
# but variables might be non-hashable.
203-
if value not in solution_space[var_name]:
204-
solution_space[var_name].append(value)
205-
for var_name, values in solution_space.items():
206-
if var_name in problem._variables:
207-
del problem._variables[var_name]
208-
problem.addVariable(var_name, values)
176+
else:
177+
# If not sparse, maintain the entire list of possible solutions.
178+
solution_subset = list(solutions)
179+
180+
# Translate this subset into a dictionary of the
181+
# possible values for each variable.
182+
solution_space = defaultdict(list)
183+
for soln in solution_subset:
184+
for var_name, value in soln.items():
185+
# List is ~2x slower than set for 'in',
186+
# but variables might be non-hashable.
187+
if value not in solution_space[var_name]:
188+
solution_space[var_name].append(value)
189+
190+
# Construct the appropriate group variable problem.
191+
# Must be done after selecting the solution space.
192+
group_problem = VarGroup(
193+
group,
194+
solution_space,
195+
constraints,
196+
self.max_domain_size,
197+
self.debug,
198+
)
209199

210200
# Attempt to solve the group
211201
group_solutions = group_problem.solve(

constrainedrandom/internal/randvar.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def get_length(self) -> int:
253253
" but none was given when get_length was called.")
254254
return self.rand_length_val
255255

256+
def is_list(self) -> bool:
257+
'''
258+
Returns ``True`` if this is a list variable.
259+
260+
:return: ``True`` if this is a list variable, otherwise ``False``.
261+
'''
262+
return self.length is not None or self.rand_length is not None
263+
256264
def set_rand_length(self, length: int) -> None:
257265
'''
258266
Function to set the random length.
@@ -282,31 +290,48 @@ def _get_random(self) -> random.Random:
282290
return random
283291
return self._random
284292

285-
def get_domain_size(self) -> int:
293+
def get_domain_size(self, possible_lengths: Optional[List[int]]=None) -> int:
286294
'''
287295
Return total domain size, accounting for length of this random variable.
288296
297+
:param possible_lengths: Optional, when there is more than one possiblity
298+
for the value of the random length, specifies a list of the
299+
possibilities.
289300
:return: domain size, integer.
290301
'''
291302
if self.domain is None:
292303
# If there's no domain, it means we can't estimate the complexity
293304
# of this variable. Return 1.
294305
return 1
295306
else:
296-
length = self.get_length()
297-
if length is None:
298-
# length is None implies a scalar variable.
299-
return len(self.domain)
300-
elif length == 0:
301-
# This is a zero-length list, adding no complexity.
302-
return 1
303-
elif length == 1:
304-
return len(self.domain)
307+
# possible_lengths is used when the variable has a random
308+
# length and that length is not yet fully determined.
309+
if possible_lengths is None:
310+
# Normal, fixed length of some description.
311+
length = self.get_length()
312+
if length is None:
313+
# length is None implies a scalar variable.
314+
return len(self.domain)
315+
elif length == 0:
316+
# This is a zero-length list, adding no complexity.
317+
return 1
318+
elif length == 1:
319+
return len(self.domain)
320+
else:
321+
# In this case it is effectively cartesian product, i.e.
322+
# n ** k, where n is the size of the domain and k is the length
323+
# of the list.
324+
return len(self.domain) ** length
305325
else:
306-
# In this case it is effectively cartesian product, i.e.
307-
# n ** k, where n is the size of the domain and k is the length
308-
# of the list.
309-
return len(self.domain) ** length
326+
# Random length which could be one of a number of values.
327+
assert self.rand_length is not None, "Cannot use possible_lengths " \
328+
"for a variable with non-random length."
329+
# For each possible length, the domain is the cartesian
330+
# product as above, but added together.
331+
total = 0
332+
for poss_len in possible_lengths:
333+
total += len(self.domain) ** poss_len
334+
return total
310335

311336
def can_use_with_constraint(self) -> bool:
312337
'''
@@ -321,30 +346,42 @@ def can_use_with_constraint(self) -> bool:
321346
# and the domain isn't a dictionary.
322347
return self.domain is not None and not isinstance(self.domain, dict)
323348

324-
def get_constraint_domain(self) -> utils.Domain:
349+
def get_constraint_domain(self, possible_lengths: Optional[List[int]]=None) -> utils.Domain:
325350
'''
326351
Get a ``constraint`` package friendly version of the domain
327352
of this random variable.
328353
354+
:param possible_lengths: Optional, when there is more than one possiblity
355+
for the value of the random length, specifies a list of the
356+
possibilities.
329357
:return: the variable's domain in a format that will work
330358
with the ``constraint`` package.
331359
'''
332-
length = self.get_length()
333-
if length is None:
334-
# Straightforward, scalar
335-
return self.domain
336-
elif length == 0:
337-
# List of length zero - an empty list is only correct choice.
338-
return [[]]
339-
elif length == 1:
340-
# List of length one
341-
return [[x] for x in self.domain]
360+
if possible_lengths is None:
361+
length = self.get_length()
362+
if length is None:
363+
# Straightforward, scalar
364+
return self.domain
365+
elif length == 0:
366+
# List of length zero - an empty list is only correct choice.
367+
return [[]]
368+
elif length == 1:
369+
# List of length one
370+
return [[x] for x in self.domain]
371+
else:
372+
# List of greater length, cartesian product.
373+
# Beware that this may be an extremely large domain.
374+
# Ensure each element is of type list, which is what
375+
# we want to return.
376+
return [list(x) for x in product(self.domain, repeat=length)]
342377
else:
343-
# List of greater length, cartesian product.
344-
# Beware that this may be an extremely large domain.
345-
# Ensure each element is of type list, which is what
346-
# we want to return.
347-
return [list(x) for x in product(self.domain, repeat=length)]
378+
# For each possible length, return the possible domains.
379+
# This can get extremely large, even more so than
380+
# the regular product.
381+
result = []
382+
for poss_len in possible_lengths:
383+
result += [list(x) for x in product(self.domain, repeat=poss_len)]
384+
return result
348385

349386
def randomize_once(self, constraints: Iterable[utils.Constraint], check_constraints: bool, debug: bool) -> Any:
350387
'''

0 commit comments

Comments
 (0)