|
| 1 | +# %% |
1 | 2 | import json |
2 | 3 | import subprocess |
3 | 4 | import time |
4 | 5 | import zipfile |
| 6 | +from itertools import chain |
5 | 7 | from pathlib import Path |
6 | 8 |
|
| 9 | +import nbformat |
| 10 | +from jinja2 import Environment, StrictUndefined |
| 11 | +from rich import print |
7 | 12 | from selenium import webdriver |
8 | 13 | from selenium.webdriver.chrome.service import Service |
9 | 14 | from selenium.webdriver.common.by import By |
10 | 15 |
|
11 | 16 | from rdagent.app.kaggle.conf import KAGGLE_IMPLEMENT_SETTING |
| 17 | +from rdagent.core.prompts import Prompts |
12 | 18 | from rdagent.log import rdagent_logger as logger |
| 19 | +from rdagent.oai.llm_utils import APIBackend |
13 | 20 |
|
| 21 | +# %% |
14 | 22 | options = webdriver.ChromeOptions() |
15 | 23 | options.add_argument("--no-sandbox") |
16 | 24 | options.add_argument("--disable-dev-shm-usage") |
@@ -79,6 +87,121 @@ def download_data(competition: str, local_path: str = "/data/userdata/share/kagg |
79 | 87 | zip_ref.extractall(data_path) |
80 | 88 |
|
81 | 89 |
|
def download_notebooks(
    competition: str, local_path: str = "/data/userdata/share/kaggle/notebooks", num: int = 15
) -> None:
    """Pull the first page of score-sorted kernels of a Kaggle competition to disk.

    Args:
        competition: Competition slug passed to the Kaggle API.
        local_path: Root directory; kernels are written below
            ``<local_path>/<competition>/<owner>``, where ``owner`` is the
            part of the kernel ref before the first ``/``.
        num: Page size requested from ``kernels_list`` (only page 1 is fetched).
    """
    from kaggle.api.kaggle_api_extended import KaggleApi

    client = KaggleApi()
    client.authenticate()

    # Compare the first and last leaderboard entries to choose the sort order.
    # NOTE(review): presumably the leaderboard is returned best-first, so a
    # positive difference means "higher is better" - confirm against the API.
    leaderboard = client.competition_leaderboard_view(competition)
    first_minus_last = float(leaderboard[0].score) - float(leaderboard[-1].score)
    sort_by = "scoreDescending" if first_minus_last > 0 else "scoreAscending"

    # Fetch one page of kernels and pull each into a per-owner sub-directory.
    target_root = Path(f"{local_path}/{competition}")
    kernels = client.kernels_list(competition=competition, sort_by=sort_by, page=1, page_size=num)
    for kernel in kernels:
        owner = kernel.ref.partition("/")[0]
        client.kernels_pull(kernel.ref, path=target_root / owner)

    # `print` is rich's print in this module, so the [red] tags are markup.
    print(f"Downloaded {len(kernels)} notebooks for {competition}. ([red]{sort_by}[/red])")
| 113 | + |
| 114 | + |
def notebook_to_knowledge(notebook_text: str) -> str:
    """Send notebook text to the LLM backend and return its reply.

    The system and user prompts are Jinja templates read from the
    ``gen_knowledge_from_code_DSAgent`` entry of ``prompts.yaml`` next to this
    file. The user template is rendered with ``notebook=notebook_text``; the
    system template is rendered without variables.

    Args:
        notebook_text: Text of the notebook or script to summarise.

    Returns:
        Whatever ``APIBackend.build_messages_and_create_chat_completion``
        returns for the rendered prompts (requested with ``json_mode=False``).
    """
    templates = Prompts(file_path=Path(__file__).parent / "prompts.yaml")["gen_knowledge_from_code_DSAgent"]

    # StrictUndefined makes rendering fail on any template variable that is
    # not supplied, instead of silently rendering it as empty.
    jinja_env = Environment(undefined=StrictUndefined)
    system_prompt = jinja_env.from_string(templates["system"]).render()
    user_prompt = jinja_env.from_string(templates["user"]).render(notebook=notebook_text)

    return APIBackend().build_messages_and_create_chat_completion(
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        json_mode=False,
    )
| 136 | + |
| 137 | + |
def _fenced(label: str, body: str) -> str:
    """Wrap *body* in a Markdown fence tagged *label*, closing fence on its own line."""
    # A closing fence glued to the last line of content is not a valid fence,
    # so insert a newline unless the body already ends with one.
    closing = "```" if body.endswith("\n") else "\n```"
    return f"```{label}\n{body}{closing}"


def _notebook_to_fenced_text(nb_path: Path) -> str:
    """Render the markdown and code cells of a notebook file as fenced blocks."""
    with nb_path.open("r", encoding="utf-8") as f:
        nb = nbformat.read(f, as_version=4)
    # Cells of any other type (e.g. raw) are left out; the fence label is the
    # cell type itself ("markdown" or "code").
    return "\n\n".join(
        _fenced(cell.cell_type, cell.source) for cell in nb.cells if cell.cell_type in ("markdown", "code")
    )


def convert_notebooks_to_text(competition: str, local_path: str = "/data/userdata/share/kaggle/notebooks") -> None:
    """Turn every downloaded notebook/script of a competition into a knowledge text file.

    Recursively scans ``<local_path>/<competition>`` for ``.ipynb``/``.irnb``
    notebooks and ``.py`` scripts, passes their fenced text through
    ``notebook_to_knowledge`` and writes the result next to the source file
    with a ``.txt`` suffix (an existing ``.txt`` of the same stem is
    overwritten). Prints the number of converted files when done.

    Args:
        competition: Competition slug, used as the sub-directory name.
        local_path: Root directory that holds the per-competition folders.
    """
    data_path = Path(f"{local_path}/{competition}")
    converted_num = 0

    # convert ipynb and irnb files
    for nb_path in chain(data_path.glob("**/*.ipynb"), data_path.glob("**/*.irnb")):
        knowledge = notebook_to_knowledge(_notebook_to_fenced_text(nb_path))
        nb_path.with_suffix(".txt").write_text(knowledge, encoding="utf-8")
        converted_num += 1

    # convert py files
    for py_path in data_path.glob("**/*.py"):
        script_text = _fenced("code", py_path.read_text(encoding="utf-8"))
        knowledge = notebook_to_knowledge(script_text)
        py_path.with_suffix(".txt").write_text(knowledge, encoding="utf-8")
        converted_num += 1

    print(f"Converted {converted_num} notebooks to text files.")
| 172 | + |
| 173 | + |
def collect_knowledge_texts(local_path: str = "/data/userdata/share/kaggle") -> dict[str, list[str]]:
    """Gather the knowledge ``.txt`` files of every competition into one mapping.

    Each sub-directory of ``<local_path>/notebooks`` is treated as one
    competition; all ``.txt`` files found recursively below it are read as
    UTF-8. Entries that are not directories are ignored. Competitions and
    their texts are visited in sorted path order so the result is
    reproducible across runs and filesystems.

    Args:
        local_path: Directory that contains the ``notebooks`` folder.

    Returns:
        A mapping of competition name to its knowledge texts, e.g.::

            {
                "competition1": [
                    "knowledge_text1",
                    "knowledge_text2",
                    ...
                ],
                "competition2": [
                    "knowledge_text1",
                    "knowledge_text2",
                    ...
                ],
                ...
            }

    Raises:
        FileNotFoundError: If ``<local_path>/notebooks`` does not exist.
    """
    notebooks_dir = Path(local_path) / "notebooks"

    competition_knowledge_texts_dict: dict[str, list[str]] = {}
    for competition_dir in sorted(notebooks_dir.iterdir()):
        # A stray file next to the competition folders is not a competition.
        if not competition_dir.is_dir():
            continue
        competition_knowledge_texts_dict[competition_dir.name] = [
            text_path.read_text(encoding="utf-8") for text_path in sorted(competition_dir.glob("**/*.txt"))
        ]

    return competition_knowledge_texts_dict
| 202 | + |
| 203 | + |
| 204 | +# %% |
82 | 205 | if __name__ == "__main__": |
83 | 206 | dsagent_cs = [ |
84 | 207 | "feedback-prize-english-language-learning", |
@@ -124,14 +247,16 @@ def download_data(competition: str, local_path: str = "/data/userdata/share/kagg |
124 | 247 | "store-sales-time-series-forecasting", |
125 | 248 | "titanic", |
126 | 249 | "tpu-getting-started", |
| 250 | + # scenario competition |
127 | 251 | "covid19-global-forecasting-week-1", |
128 | | - "birdsong-recognition", |
129 | | - "optiver-trading-at-the-close", |
| 252 | + "statoil-iceberg-classifier-challenge", |
| 253 | + "optiver-realized-volatility-prediction", |
130 | 254 | "facebook-v-predicting-check-ins", |
131 | 255 | ] |
132 | 256 |
|
133 | | - for i in dsagent_cs + other_cs: |
134 | | - crawl_descriptions(i) |
| 257 | + all_cs = dsagent_cs + other_cs |
| 258 | + for c in all_cs: |
| 259 | + convert_notebooks_to_text(c) |
135 | 260 | exit() |
136 | 261 | from kaggle.api.kaggle_api_extended import KaggleApi |
137 | 262 |
|
|
0 commit comments