Skip to content

Support categoricals in alternating optimization #2866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8a08572
Initial setup for introducing categorical dims
TobyBoyne Jun 4, 2025
07ddd83
Implement categorical neighbors
TobyBoyne Jun 4, 2025
a58b7e7
Add initial sampling for categorical-mixed feature spaces
TobyBoyne Jun 4, 2025
01a4c10
Add perturbation of cat features in discrete step
TobyBoyne Jun 4, 2025
87bf8c4
Test for existence of integer/cat dimensions; begin writing test
TobyBoyne Jun 4, 2025
e7858f6
Untransform OneHotToNumeric
TobyBoyne Jun 4, 2025
83d9069
Update botorch/optim/optimize_mixed.py
TobyBoyne Jun 5, 2025
f277b0c
Update botorch/optim/optimize_mixed.py
TobyBoyne Jun 5, 2025
45dad1c
Discrete dims now Noneable; remove repeated input transform
TobyBoyne Jun 5, 2025
be461a4
Sample values for large categorical features
TobyBoyne Jun 5, 2025
50fa5ed
Fix tests by passing empty cat_dims
TobyBoyne Jun 5, 2025
e344040
Revert failure on purely continuous problem
TobyBoyne Jun 5, 2025
590791f
Add test for get_categorical_neighbors
TobyBoyne Jun 5, 2025
88d0d40
Update botorch/optim/optimize_mixed.py
TobyBoyne Jun 6, 2025
f165357
Update botorch/optim/optimize_mixed.py
TobyBoyne Jun 6, 2025
2d92a6e
Update test/optim/test_optimize_mixed.py
TobyBoyne Jun 6, 2025
e5586ce
Use Python primitives for constructing categorical neighbors
TobyBoyne Jun 6, 2025
f89910d
Update `optimize_acqf_mixed_alternating` docstring to reflect categor…
TobyBoyne Jun 6, 2025
d60b7b4
Set manual seed in all `test_optimize_acqf_mixed_*` tests
TobyBoyne Jun 6, 2025
05a0c13
Test random subsampling of categorical values
TobyBoyne Jun 6, 2025
1d7f2a2
Correct private function docstring
TobyBoyne Jun 6, 2025
d0c7fc5
Fix randomly sampling `current_x` as neighbor
TobyBoyne Jun 6, 2025
0d8e128
Fix random seed in test
TobyBoyne Jun 6, 2025
4a18054
Merge branch 'main' into optimize-cat-alternating
TobyBoyne Jun 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 167 additions & 37 deletions botorch/optim/optimize_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
# LICENSE file in the root directory of this source tree.

import dataclasses
import itertools
import random
import warnings
from typing import Any, Callable
from typing import Any, Callable, Sequence

import torch
from botorch.acquisition import AcquisitionFunction
Expand Down Expand Up @@ -164,10 +166,76 @@ def get_nearest_neighbors(
return unique_neighbors


def get_categorical_neighbors(
    current_x: Tensor,
    bounds: Tensor,
    cat_dims: Tensor,
    max_num_cat_values: int = MAX_DISCRETE_VALUES,
) -> Tensor:
    r"""Generate all 1-Hamming distance neighbors of a given input. The neighbors
    are generated for the categorical dimensions only.

    We assume that all categorical values are equidistant. If the number of values
    is greater than `max_num_cat_values`, we sample uniformly from the
    possible values for that dimension.

    NOTE: This assumes that `current_x` is detached and uses in-place operations,
    which are known to be incompatible with autograd.

    NOTE(review): this assumes `cat_dims` is non-empty (`torch.cat` over an empty
    tuple would raise) — presumably callers guard with `cat_dims.numel()`; confirm.

    Args:
        current_x: The design to find the neighbors of. A tensor of shape `d`.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
            For categorical dimensions the bounds are cast to integers via
            `.long()`, so they are assumed to be integer-valued.
        cat_dims: A tensor of indices corresponding to categorical parameters.
        max_num_cat_values: Maximum number of values for a categorical parameter,
            beyond which values are uniformly sampled.

    Returns:
        A tensor of shape `num_neighbors x d`, denoting up to `max_num_cat_values`
        unique 1-Hamming distance neighbors for each categorical dimension.
    """

    # Neighbors are generated by considering all possible values for each
    # categorical dimension, one at a time.
    def _get_cat_values(dim: int) -> Sequence[int]:
        r"""Get a sequence of up to `max_num_cat_values` values that a categorical
        feature may take."""
        lb, ub = bounds[:, dim].long()
        current_value = current_x[dim]
        cat_values = range(lb, ub + 1)
        if ub - lb + 1 <= max_num_cat_values:
            # Small cardinality: return every value. This includes the current
            # value; the resulting duplicate of `current_x` is filtered out at
            # the end of the enclosing function.
            return cat_values
        else:
            # Large cardinality: subsample exactly `max_num_cat_values` distinct
            # values. The current value is excluded so that no sample is wasted
            # on a non-neighbor. Uses the `random` module's global RNG.
            return random.sample(
                [v for v in cat_values if v != current_value], k=max_num_cat_values
            )

    # Flatten the per-dimension candidate values into one 1-d tensor, ordered
    # dimension-by-dimension in the same order as `cat_dims`.
    new_cat_values_lst = list(
        itertools.chain.from_iterable(_get_cat_values(dim) for dim in cat_dims)
    )
    new_cat_values = torch.tensor(
        new_cat_values_lst, device=current_x.device, dtype=current_x.dtype
    )

    # Per-dimension candidate counts. Clamping mirrors the subsampling in
    # `_get_cat_values` (which returns exactly `max_num_cat_values` values in the
    # sampling branch), so `new_cat_idcs` lines up element-for-element with
    # `new_cat_values`.
    num_cat_values = (bounds[1, :] - bounds[0, :] + 1).to(dtype=torch.long)
    num_cat_values.clamp_(max=max_num_cat_values)
    new_cat_idcs = torch.cat(
        tuple(torch.full((num_cat_values[dim].item(),), dim) for dim in cat_dims)
    )
    # One candidate row per new value; each row will differ from `current_x` in
    # at most one (categorical) column, i.e. 1-Hamming distance.
    neighbors = current_x.repeat(len(new_cat_values), 1)
    # Assign the new values to their corresponding columns.
    neighbors.scatter_(1, new_cat_idcs.view(-1, 1), new_cat_values.view(-1, 1))

    # Deduplicate rows (note `unique(dim=0)` also sorts them lexicographically).
    unique_neighbors = neighbors.unique(dim=0)
    # Also remove current_x if it is in unique_neighbors.
    unique_neighbors = unique_neighbors[~(unique_neighbors == current_x).all(dim=-1)]
    return unique_neighbors


def get_spray_points(
X_baseline: Tensor,
cont_dims: Tensor,
discrete_dims: Tensor,
cat_dims: Tensor,
bounds: Tensor,
num_spray_points: int,
std_cont_perturbation: float = STD_CONT_PERTURBATION,
Expand All @@ -182,6 +250,7 @@ def get_spray_points(
X_baseline: Tensor of best acquired points across BO run.
cont_dims: Indices of continuous parameters/input dimensions.
discrete_dims: Indices of binary/integer parameters/input dimensions.
cat_dims: Indices of categorical parameters/input dimensions.
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
num_spray_points: Number of spray points to return.
std_cont_perturbation: standard deviation of Normal perturbations of
Expand All @@ -194,20 +263,35 @@ def get_spray_points(
device, dtype = X_baseline.device, X_baseline.dtype
perturb_nbors = torch.zeros(0, dim, device=device, dtype=dtype)
for x in X_baseline:
discrete_perturbs = get_nearest_neighbors(
current_x=x, bounds=bounds, discrete_dims=discrete_dims
)
discrete_perturbs = discrete_perturbs[
torch.randint(len(discrete_perturbs), (num_spray_points,), device=device)
]
if discrete_dims.numel():
discrete_perturbs = get_nearest_neighbors(
current_x=x, bounds=bounds, discrete_dims=discrete_dims
)
discrete_perturbs = discrete_perturbs[
torch.randint(
len(discrete_perturbs), (num_spray_points,), device=device
)
]
if cat_dims.numel():
cat_perturbs = get_categorical_neighbors(
current_x=x, bounds=bounds, cat_dims=cat_dims
)
cat_perturbs = cat_perturbs[
torch.randint(len(cat_perturbs), (num_spray_points,), device=device)
]

cont_perturbs = x[cont_dims] + std_cont_perturbation * torch.randn(
num_spray_points, len(cont_dims), device=device, dtype=dtype
)
cont_perturbs = cont_perturbs.clamp_(
min=bounds[0, cont_dims], max=bounds[1, cont_dims]
)
nbds = torch.zeros(num_spray_points, dim, device=device, dtype=dtype)
nbds[..., discrete_dims] = discrete_perturbs[..., discrete_dims]
if discrete_dims.numel():
nbds[..., discrete_dims] = discrete_perturbs[..., discrete_dims]
if cat_dims.numel():
nbds[..., cat_dims] = cat_perturbs[..., cat_dims]

nbds[..., cont_dims] = cont_perturbs
perturb_nbors = torch.cat([perturb_nbors, nbds], dim=0)
return perturb_nbors
Expand All @@ -216,6 +300,7 @@ def get_spray_points(
def sample_feasible_points(
opt_inputs: OptimizeAcqfInputs,
discrete_dims: Tensor,
cat_dims: Tensor,
num_points: int,
) -> Tensor:
r"""Sample feasible points from the optimization domain.
Expand All @@ -235,6 +320,7 @@ def sample_feasible_points(
opt_inputs: Common set of arguments for acquisition optimization.
discrete_dims: A tensor of indices corresponding to binary and
integer parameters.
cat_dims: A tensor of indices corresponding to categorical parameters.
num_points: The number of points to sample.

Returns:
Expand Down Expand Up @@ -272,7 +358,8 @@ def generator(n: int) -> Tensor:
# Generate twice as many, since we're likely to filter out some points.
base_points = generator(n=num_remaining * 2)
# Round the discrete dimensions to the nearest integer.
base_points[:, discrete_dims] = base_points[:, discrete_dims].round()
non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)
base_points[:, non_cont_dims] = base_points[:, non_cont_dims].round()
# Fix the fixed features.
base_points = fix_features(
X=base_points, fixed_features=opt_inputs.fixed_features
Expand All @@ -293,6 +380,7 @@ def generator(n: int) -> Tensor:
def generate_starting_points(
opt_inputs: OptimizeAcqfInputs,
discrete_dims: Tensor,
cat_dims: Tensor,
cont_dims: Tensor,
) -> tuple[Tensor, Tensor]:
"""Generate initial starting points for the alternating optimization.
Expand All @@ -307,6 +395,7 @@ def generate_starting_points(
from `opt_inputs`.
discrete_dims: A tensor of indices corresponding to integer and
binary parameters.
cat_dims: A tensor of indices corresponding to categorical parameters.
cont_dims: A tensor of indices corresponding to continuous parameters.

Returns:
Expand Down Expand Up @@ -407,6 +496,7 @@ def generate_starting_points(
X_baseline=X_baseline,
cont_dims=cont_dims,
discrete_dims=discrete_dims,
cat_dims=cat_dims,
bounds=bounds,
num_spray_points=num_spray_points,
std_cont_perturbation=assert_is_instance(
Expand All @@ -429,6 +519,7 @@ def generate_starting_points(
new_x_init = sample_feasible_points(
opt_inputs=opt_inputs,
discrete_dims=discrete_dims,
cat_dims=cat_dims,
num_points=num_restarts - len(x_init_candts),
)
x_init_candts = torch.cat([x_init_candts, new_x_init], dim=0)
Expand All @@ -454,6 +545,7 @@ def generate_starting_points(
def discrete_step(
opt_inputs: OptimizeAcqfInputs,
discrete_dims: Tensor,
cat_dims: Tensor,
current_x: Tensor,
) -> tuple[Tensor, Tensor]:
"""Discrete nearest neighbour search.
Expand All @@ -464,6 +556,7 @@ def discrete_step(
and constraints from `opt_inputs`.
discrete_dims: A tensor of indices corresponding to binary and
integer parameters.
cat_dims: A tensor of indices corresponding to categorical parameters.
current_x: Starting point. A tensor of shape `d`.

Returns:
Expand All @@ -476,14 +569,32 @@ def discrete_step(
for _ in range(
assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
):
x_neighbors = get_nearest_neighbors(
current_x=current_x.detach(),
bounds=opt_inputs.bounds,
discrete_dims=discrete_dims,
)
x_neighbors = _filter_infeasible(
X=x_neighbors, inequality_constraints=opt_inputs.inequality_constraints
)
neighbors = []
if discrete_dims.numel():
x_neighbors_discrete = get_nearest_neighbors(
current_x=current_x.detach(),
bounds=opt_inputs.bounds,
discrete_dims=discrete_dims,
)
x_neighbors_discrete = _filter_infeasible(
X=x_neighbors_discrete,
inequality_constraints=opt_inputs.inequality_constraints,
)
neighbors.append(x_neighbors_discrete)

if cat_dims.numel():
x_neighbors_cat = get_categorical_neighbors(
current_x=current_x.detach(),
bounds=opt_inputs.bounds,
cat_dims=cat_dims,
)
x_neighbors_cat = _filter_infeasible(
X=x_neighbors_cat,
inequality_constraints=opt_inputs.inequality_constraints,
)
neighbors.append(x_neighbors_cat)

x_neighbors = torch.cat(neighbors, dim=0)
if x_neighbors.numel() == 0:
# Exit gracefully with last point if there are no feasible neighbors.
break
Expand All @@ -508,6 +619,7 @@ def discrete_step(
def continuous_step(
opt_inputs: OptimizeAcqfInputs,
discrete_dims: Tensor,
cat_dims: Tensor,
current_x: Tensor,
) -> tuple[Tensor, Tensor]:
"""Continuous search using L-BFGS-B through optimize_acqf.
Expand All @@ -518,14 +630,17 @@ def continuous_step(
`fixed_features` and constraints from `opt_inputs`.
discrete_dims: A tensor of indices corresponding to binary and
integer parameters.
cat_dims: A tensor of indices corresponding to categorical parameters.
current_x: Starting point. A tensor of shape `d`.

Returns:
A tuple of two tensors: a (1 x d)-dim tensor of optimized points
and a (1)-dim tensor of acquisition values.
"""
options = opt_inputs.options or {}
if len(discrete_dims) == len(current_x): # nothing continuous to optimize
non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)

if len(non_cont_dims) == len(current_x): # nothing continuous to optimize
with torch.no_grad():
return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))

Expand All @@ -536,7 +651,7 @@ def continuous_step(
raw_samples=None,
batch_initial_conditions=current_x.unsqueeze(0),
fixed_features={
**dict(zip(discrete_dims.tolist(), current_x[discrete_dims])),
**dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
**(opt_inputs.fixed_features or {}),
},
options={
Expand All @@ -551,7 +666,8 @@ def continuous_step(
def optimize_acqf_mixed_alternating(
acq_function: AcquisitionFunction,
bounds: Tensor,
discrete_dims: list[int],
discrete_dims: list[int] | None = None,
cat_dims: list[int] | None = None,
options: dict[str, Any] | None = None,
q: int = 1,
raw_samples: int = RAW_SAMPLES,
Expand All @@ -562,23 +678,25 @@ def optimize_acqf_mixed_alternating(
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
) -> tuple[Tensor, Tensor]:
r"""
Optimizes acquisition function over mixed binary and continuous input spaces.
Multiple random restarting starting points are picked by evaluating a large set
of initial candidates. From each starting point, alternating discrete local search
and continuous optimization via (L-BFGS) is performed for a fixed number of
iterations.

NOTE: This method assumes that all discrete variables are integer valued.
Optimizes acquisition function over mixed integer, categorical, and continuous
input spaces. Multiple random restarting starting points are picked by evaluating
a large set of initial candidates. From each starting point, alternating
discrete/categorical local search and continuous optimization via (L-BFGS)
is performed for a fixed number of iterations.

NOTE: This method assumes that all discrete and categorical variables are
integer valued.
The discrete dimensions that have more than
`options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
be optimized using continuous relaxation.

# TODO: Support categorical variables.
The categorical dimensions that have more than `MAX_DISCRETE_VALUES` values
will be optimized by selecting random subsamples of the possible values.

Args:
acq_function: BoTorch Acquisition function.
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
discrete_dims: A list of indices corresponding to integer and binary parameters.
cat_dims: A list of indices corresponding to categorical parameters.
options: Dictionary specifying optimization options. Supports the following:
- "initialization_strategy": Strategy used to generate the initial candidates.
"random", "continuous_relaxation" or "equally_spaced" (linspace style).
Expand Down Expand Up @@ -631,6 +749,9 @@ def optimize_acqf_mixed_alternating(
"sequential optimization."
)

cat_dims = cat_dims or []
discrete_dims = discrete_dims or []

fixed_features = fixed_features or {}
options = options or {}
options.setdefault("batch_limit", MAX_BATCH_SIZE)
Expand Down Expand Up @@ -676,22 +797,29 @@ def optimize_acqf_mixed_alternating(
tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype}
# Remove fixed features from dims, so they don't get optimized.
discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features]
if len(discrete_dims) == 0:
cat_dims = [dim for dim in cat_dims if dim not in fixed_features]
non_cont_dims = [*discrete_dims, *cat_dims]
if len(non_cont_dims) == 0:
# If the problem is fully continuous, fall back to standard optimization.
return _optimize_acqf(opt_inputs=opt_inputs)
if not (
isinstance(discrete_dims, list)
and len(set(discrete_dims)) == len(discrete_dims)
and min(discrete_dims) >= 0
and max(discrete_dims) <= dim - 1
isinstance(non_cont_dims, list)
and len(set(non_cont_dims)) == len(non_cont_dims)
and min(non_cont_dims) >= 0
and max(non_cont_dims) <= dim - 1
):
raise ValueError(
"`discrete_dims` must be a list with unique integers "
"between 0 and num_dims - 1."
"`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
"integers between 0 and num_dims - 1."
)
discrete_dims_t = torch.tensor(
discrete_dims, dtype=torch.long, device=tkwargs["device"]
)
cont_dims = complement_indices_like(indices=discrete_dims_t, d=dim)
cat_dims_t = torch.tensor(cat_dims, dtype=torch.long, device=tkwargs["device"])
non_cont_dims = torch.tensor(
non_cont_dims, dtype=torch.long, device=tkwargs["device"]
)
cont_dims = complement_indices_like(indices=non_cont_dims, d=dim)
# Fixed features are all in cont_dims. Remove them, so they don't get optimized.
ff_idcs = torch.tensor(
list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"]
Expand All @@ -703,6 +831,7 @@ def optimize_acqf_mixed_alternating(
best_X, best_acq_val = generate_starting_points(
opt_inputs=opt_inputs,
discrete_dims=discrete_dims_t,
cat_dims=cat_dims_t,
cont_dims=cont_dims,
)

Expand All @@ -718,6 +847,7 @@ def optimize_acqf_mixed_alternating(
best_X[i], best_acq_val[i] = step(
opt_inputs=opt_inputs,
discrete_dims=discrete_dims_t,
cat_dims=cat_dims_t,
current_x=best_X[i],
)

Expand Down
Loading