Skip to content

Commit 19869ea

Browse files
feat: query & cache package_info (#1083)
* feat: add package query in draft.py (not yet enabled) * feat: integrate package query into task_gen and cache runtime environment - Remove pkg_query modifications from draft components - Add package declaration requirement in task_gen prompts - Add optional packages field to CodingSketch model - Cache runtime_environment in scenario object for loop-wide reuse - Parse packages from LLM response and generate runtime environment dynamically * some refinement * feat: merge default packages with CLI args in package_info.py * fix: code style --------- Co-authored-by: Qizheng Li <jenssenlee@163.com>
1 parent a26f394 commit 19869ea

11 files changed

Lines changed: 120 additions & 35 deletions

File tree

rdagent/components/coder/data_science/pipeline/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from rdagent.components.coder.data_science.conf import DSCoderCoSTEERSettings
4040
from rdagent.components.coder.data_science.pipeline.eval import PipelineCoSTEEREvaluator
41-
from rdagent.components.coder.data_science.raw_data_loader.exp import DataLoaderTask
41+
from rdagent.components.coder.data_science.pipeline.exp import PipelineTask
4242
from rdagent.components.coder.data_science.share.eval import ModelDumpEvaluator
4343
from rdagent.core.exception import CoderError
4444
from rdagent.core.experiment import FBWorkspace
@@ -53,7 +53,7 @@
5353
class PipelineMultiProcessEvolvingStrategy(MultiProcessEvolvingStrategy):
5454
def implement_one_task(
5555
self,
56-
target_task: DataLoaderTask,
56+
target_task: PipelineTask,
5757
queried_knowledge: CoSTEERQueriedKnowledge | None = None,
5858
workspace: FBWorkspace | None = None,
5959
prev_task_feedback: CoSTEERSingleFeedback | None = None,
@@ -86,6 +86,7 @@ def implement_one_task(
8686
queried_former_failed_knowledge=queried_former_failed_knowledge[0],
8787
out_spec=PythonAgentOut.get_spec(),
8888
runtime_environment=runtime_environment,
89+
package_info=target_task.package_info,
8990
enable_model_dump=DS_RD_SETTING.enable_model_dump,
9091
enable_debug_mode=DS_RD_SETTING.sample_data_by_LLM,
9192
)

rdagent/components/coder/data_science/pipeline/exp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33

44
# Because we use isinstance to distinguish between different types of tasks, we need to use sub classes to represent different types of tasks
55
class PipelineTask(CoSTEERTask):
6-
def __init__(self, name: str = "Pipeline", *args, **kwargs) -> None:
6+
def __init__(self, name: str = "Pipeline", package_info: str | None = None, *args, **kwargs) -> None:
77
super().__init__(name=name, *args, **kwargs)
8+
self.package_info = package_info

rdagent/components/coder/data_science/pipeline/prompts.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ pipeline_coder:
1111
1212
## The runtime environment your code will running on
1313
{{ runtime_environment }}
14+
15+
{% if package_info is not none %}
16+
To help you write the runnable code, the user has provided the package information which contains the package names and versions.
17+
You should be careful about the package versions, as the code will be executed in the environment with the specified version and the api might be different from the latest version.
18+
The user might provide the packages the environment doesn't have, you should avoid using any of them.
19+
## Package Information
20+
{{ package_info }}
21+
{% endif %}
1422
1523
## Hyperparameters Specification
1624
Follow the hyperparameter choices if they are specified in the task description, unless they are unreasonable or incorrect.

rdagent/scenarios/data_science/dev/runner/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def evaluate(
133133
scenario=self.scen.get_scenario_all_desc(eda_output=implementation.file_dict.get("EDA.md", None)),
134134
is_sub_enabled=test_eval.is_sub_enabled(self.scen.competition),
135135
task_desc=target_task.get_task_information(),
136-
runtime_environment=self.scen.get_runtime_environment(),
136+
runtime_environment=self.scen.runtime_environment,
137137
)
138138
user_prompt = T(".prompts:DSCoSTEER_eval.user").r(
139139
code=implementation.all_codes,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import sys
2+
from importlib.metadata import distributions
3+
4+
5+
def get_installed_packages():
    """Map each installed distribution's lowercased name to its version.

    Distributions whose metadata carries no ``Name`` field (for example a
    corrupted ``*.dist-info`` directory) are skipped; indexing them would
    yield ``None`` and crash on ``.lower()``.

    Returns:
        dict: ``{lowercased distribution name: version}``. When two
        distributions share a name, the one iterated last wins.
    """
    packages = {}
    for dist in distributions():
        name = dist.metadata.get("Name")
        if not name:
            # Nothing usable to key on; ignore the broken entry.
            continue
        packages[name.lower()] = dist.version
    return packages
7+
8+
9+
def print_filtered_packages(installed_packages, filtered_packages):
    """Print ``name==version`` for every requested package that is installed.

    Args:
        installed_packages: Mapping of lowercased package name to version.
        filtered_packages: Iterable of requested package names; they are
            looked up case-insensitively but printed as given.

    Prints a "no matching packages" banner when none of the requested
    packages has a truthy version in ``installed_packages``.
    """
    resolved = [(name, installed_packages.get(name.lower())) for name in filtered_packages]
    found = [(name, version) for name, version in resolved if version]

    if found:
        print("=== Installed Packages ===")
        for name, version in found:
            # One requirement-style line per package: "package_name==version"
            print(f"{name}=={version}")
        return

    print("=== No matching packages found ===")
22+
23+
24+
def get_python_packages():
    """Report installed versions of the default packages plus any CLI extras.

    Extra package names may be passed on the command line, e.g.
    ``python package_info.py pandas torch scikit-learn``; they are appended
    to the default list. Without extra arguments only the defaults are
    reported.

    The merged list keeps a stable order (defaults first, then extras in the
    order given) and is de-duplicated case-insensitively, so the printed
    report is identical across runs for identical arguments.

    Output is written to stdout: the installed packages first, followed by a
    "Missing Packages" section listing requested packages that are not
    installed.
    """
    default_packages = [
        "transformers",
        "accelerate",
        "torch",
        "tensorflow",
        "pandas",
        "numpy",
        "scikit-learn",
        "scipy",
        "xgboost",
        "sklearn",
        "lightgbm",
        "vtk",
        "opencv-python",
        "keras",
        "matplotlib",
        "pydicom",
    ]

    # Lookups below are case-insensitive, so dedupe on the lowercased name
    # and keep the first spelling seen. A plain set union would shuffle the
    # order between interpreter runs (string hash randomisation).
    requested = {}
    for pkg in [*default_packages, *sys.argv[1:]]:
        requested.setdefault(pkg.lower(), pkg)
    packages_list = list(requested.values())

    installed_packages = get_installed_packages()

    print_filtered_packages(installed_packages, packages_list)

    # Report packages that were requested but are not installed, so the
    # consumer of this output knows to avoid them.
    missing_pkgs = [pkg for pkg in packages_list if pkg.lower() not in installed_packages]
    if missing_pkgs:
        print("\n=== Missing Packages (Avoid using these packages) ===")
        for pkg in missing_pkgs:
            print(pkg)
61+
62+
63+
# Produce the package report only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    get_python_packages()

rdagent/scenarios/data_science/proposal/exp_gen/prompts.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,4 +345,5 @@ output_format:
345345
The output should follow JSON format. The schema is as follows:
346346
{
347347
"description": "A detailed, step-by-step implementation guide for `main.py` that synthesizes planned modifications and code structure into a comprehensive coding plan. Must be formatted in Markdown with level-3 headings (###) organizing logical sections, key decision points, and implementation steps. Should provide sufficient detail covering implementation flow, algorithms, data handling, and key logic points for unambiguous developer execution.",
348+
"packages": ["package1", "package2", ...] # Optional, list of packages needed for the task. If no packages are needed, leave it empty.
348349
}

rdagent/scenarios/data_science/proposal/exp_gen/prompts_v2.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,11 @@ task_gen:
270270
- For neural networks, prefer PyTorch or PyTorch based library (over TensorFlow) unless the SOTA or hypothesis dictates otherwise.
271271
- For neural networks, prefer fine-tuning pre-trained models over training from scratch.
272272
273+
## Package Declaration
274+
At the end of your design, **you MUST** provide a key `packages` in the final JSON output.
275+
It should be an **array of PyPI package names** (strings) that you expect to `import` in the forthcoming implementation.
276+
List only third-party packages (do **NOT** include built-in modules like `os`, `json`).
277+
273278
# Guidelines for Sketching the `main.py` Workflow
274279
275280
YOUR TASK IS TO create a conceptual sketch for drafting or updating the `main.py` workflow. This is a plan, not code.

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DSDraftExpGen, # TODO: DSDraftExpGen should be moved to router in the further
2424
)
2525
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSIdea
26+
from rdagent.scenarios.data_science.proposal.exp_gen.utils import get_packages
2627
from rdagent.utils.agent.tpl import T
2728
from rdagent.utils.repo.diff import generate_diff_from_dict
2829
from rdagent.utils.workflow import wait_retry
@@ -274,6 +275,11 @@ class CodingSketch(BaseModel):
274275
"The content **must** be formatted using Markdown, with logical sections, key decision points, or implementation steps clearly organized by level-3 headings (i.e., `###`). "
275276
"This field should provide sufficient detail for a developer to understand the implementation flow, algorithms, data handling, and key logic points without ambiguity."
276277
)
278+
packages: List[str] = Field(
279+
default=None,
280+
description="A list of third-party package names (PyPI) that the planned implementation will import. "
281+
"Used to query the runtime environment dynamically. Leave `null` or omit if not applicable.",
282+
)
277283

278284

279285
def draft_exp_in_decomposition(scen: Scenario, trace: DSTrace) -> None | DSDraftExpGen:
@@ -775,6 +781,15 @@ def task_gen(
775781
name=task_name,
776782
description=task_desc,
777783
)
784+
785+
assert isinstance(task, PipelineTask), f"Task {task_name} is not a PipelineTask, got {type(task)}"
786+
# only for llm with response schema.(TODO: support for non-schema llm?)
787+
# If the LLM provides a "packages" field (list[str]), compute runtime environment now and cache it for subsequent prompts in later loops.
788+
if isinstance(task_dict, dict) and "packages" in task_dict and isinstance(task_dict["packages"], list):
789+
pkgs: list[str] = [str(p) for p in task_dict["packages"]]
790+
# Persist for later stages
791+
task.package_info = get_packages(pkgs)
792+
778793
exp = DSExperiment(pending_tasks_list=[[task]], hypothesis=hypotheses[0])
779794
if sota_exp is not None:
780795
exp.experiment_workspace.inject_code_from_file_dict(sota_exp.experiment_workspace)

rdagent/scenarios/data_science/proposal/exp_gen/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from pathlib import Path
12
from typing import Any, Dict, List, Optional, Tuple
23

34
from pydantic import BaseModel, Field
45

6+
from rdagent.components.coder.data_science.conf import get_ds_env
57
from rdagent.components.coder.data_science.ensemble.exp import EnsembleTask
68
from rdagent.components.coder.data_science.feature.exp import FeatureTask
79
from rdagent.components.coder.data_science.model.exp import ModelTask
810
from rdagent.components.coder.data_science.pipeline.exp import PipelineTask
911
from rdagent.components.coder.data_science.raw_data_loader.exp import DataLoaderTask
1012
from rdagent.components.coder.data_science.workflow.exp import WorkflowTask
13+
from rdagent.core.experiment import FBWorkspace
1114
from rdagent.utils.agent.tpl import T
1215

1316
_COMPONENT_META: Dict[str, Dict[str, Any]] = {
@@ -86,3 +89,20 @@ class CodingSketch(BaseModel):
8689
"The content **must** be formatted using Markdown, with logical sections, key decision points, or implementation steps clearly organized by level-3 headings (i.e., `###`). "
8790
"This field should provide sufficient detail for a developer to understand the implementation flow, algorithms, data handling, and key logic points without ambiguity."
8891
)
92+
93+
94+
def get_packages(self, pkgs: list[str] | None = None) -> str:
    """Run ``package_info.py`` in the data-science environment and return its stdout.

    The script's text is injected into a fresh ``FBWorkspace`` and executed
    as ``python package_info.py <pkgs...>`` using the environment returned
    by ``get_ds_env()``.

    Args:
        self: Kept for backward compatibility. This is a module-level
            function, so a call such as ``get_packages(["torch"])`` binds
            the package list to this first parameter; a list or tuple passed
            here is therefore used as the package list when ``pkgs`` is not
            given. Any other object is consulted for a ``required_packages``
            attribute, as before.
        pkgs: Package names to query in addition to the script's defaults.

    Returns:
        Whatever ``FBWorkspace.execute`` returns for the script run
        (annotated as ``str``).
    """
    # TODO: add it into base class. The environment (i.e. `DSDockerConf`) should be part of the scenario class.
    import re  # local import: only needed for sanitising names here

    if pkgs is None:
        if isinstance(self, (list, tuple)):
            # Called as get_packages(pkgs): the list landed in `self`.
            pkgs = list(self)
        elif hasattr(self, "required_packages"):
            # Reuse package list cached during Draft stage when available.
            pkgs = getattr(self, "required_packages")  # type: ignore[arg-type]

    # The names come from an LLM and are interpolated into a command string.
    # Keep only the leading distribution-name part of each entry (dropping
    # version specifiers, shell metacharacters, ...) and discard entries
    # that do not start with a valid name.
    safe_pkgs = []
    for pkg in pkgs or []:
        matched = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", str(pkg).strip())
        if matched:
            safe_pkgs.append(matched.group(0))

    env = get_ds_env()
    implementation = FBWorkspace()
    fname = "package_info.py"
    script_path = Path(__file__).absolute().resolve().parent / "package_info.py"
    implementation.inject_files(**{fname: script_path.read_text()})

    pkg_args = " ".join(safe_pkgs)
    stdout = implementation.execute(env=env, entry=f"python {fname} {pkg_args}")
    return stdout

rdagent/scenarios/data_science/scen/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def get_scenario_all_desc(self, eda_output=None) -> str:
166166

167167
def get_runtime_environment(self) -> str:
168168
# TODO: add it into base class. Environment should(i.e. `DSDockerConf`) should be part of the scenario class.
169+
"""Return runtime environment information."""
169170
env = get_ds_env()
170171
implementation = FBWorkspace()
171172
fname = "runtime_info.py"

0 commit comments

Comments
 (0)