Skip to content

Commit f70f154

Browse files
ryanjung-jmcopybara-github
authored andcommitted
chore: Add filtering for experiment.list method
PiperOrigin-RevId: 722502329
1 parent 49c0bf5 commit f70f154

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

google/cloud/aiplatform/metadata/experiment_resources.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ def get_or_create(
316316
def list(
317317
cls,
318318
*,
319+
filter: Optional[str] = None,
319320
project: Optional[str] = None,
320321
location: Optional[str] = None,
321322
credentials: Optional[auth_credentials.Credentials] = None,
@@ -327,6 +328,8 @@ def list(
327328
```
328329
329330
Args:
331+
filter (str):
332+
Optional. A query to filter available resources for matching results.
330333
project (str):
331334
Optional. Project to list these experiments from. Overrides project set in
332335
aiplatform.init.
@@ -343,6 +346,8 @@ def list(
343346
filter_str = metadata_utils._make_filter_string(
344347
schema_title=constants.SYSTEM_EXPERIMENT
345348
)
349+
if filter:
350+
filter_str = f"{filter_str} AND ({filter})"
346351

347352
with _SetLoggerLevel(resource):
348353
experiment_contexts = context.Context.list(

tests/system/aiplatform/test_experiments.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,27 @@ def test_get_run(self):
121121
assert run.name == _RUN
122122
assert run.state == aiplatform.gapic.Execution.State.RUNNING
123123

124+
def test_list_experiment(self):
125+
experiments = aiplatform.Experiment.list(
126+
project=e2e_base._PROJECT,
127+
location=e2e_base._LOCATION,
128+
)
129+
assert isinstance(experiments, list)
130+
assert any(
131+
experiment.name == self._experiment_name for experiment in experiments
132+
)
133+
134+
def test_list_experiment_filter(self):
135+
experiments = aiplatform.Experiment.list(
136+
filter=f"display_name = {self._experiment_name}",
137+
project=e2e_base._PROJECT,
138+
location=e2e_base._LOCATION,
139+
)
140+
assert len(experiments) == 1
141+
assert any(
142+
experiment.name == self._experiment_name for experiment in experiments
143+
)
144+
124145
def test_log_params(self):
125146
aiplatform.init(
126147
project=e2e_base._PROJECT,

0 commit comments

Comments
 (0)