Skip to content

Commit c31baed

Browse files
authored
[stubsabot] add script (#8035)
1 parent c383e95 commit c31baed

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed

scripts/stubsabot.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
#!/usr/bin/env python3
2+
from __future__ import annotations
3+
4+
import argparse
5+
import asyncio
6+
import datetime
7+
import enum
8+
import functools
9+
import io
10+
import os
11+
import re
12+
import subprocess
13+
import sys
14+
import tarfile
15+
import urllib.parse
16+
import zipfile
17+
from dataclasses import dataclass
18+
from pathlib import Path
19+
from typing import Any
20+
21+
import aiohttp
22+
import packaging.specifiers
23+
import packaging.version
24+
import tomli
25+
import tomlkit
26+
27+
28+
class ActionLevel(enum.IntEnum):
29+
def __new__(cls, value: int, doc: str):
30+
member = int.__new__(cls, value)
31+
member._value_ = value
32+
member.__doc__ = doc
33+
return member
34+
35+
nothing = 0, "make no changes"
36+
local = 1, "make changes that affect local repo"
37+
everything = 2, "do everything, e.g. open PRs"
38+
39+
40+
@dataclass
41+
class StubInfo:
42+
distribution: str
43+
version_spec: str
44+
obsolete: bool
45+
no_longer_updated: bool
46+
47+
48+
def read_typeshed_stub_metadata(stub_path: Path) -> StubInfo:
49+
with (stub_path / "METADATA.toml").open("rb") as f:
50+
meta = tomli.load(f)
51+
return StubInfo(
52+
distribution=stub_path.name,
53+
version_spec=meta["version"],
54+
obsolete="obsolete_since" in meta,
55+
no_longer_updated=meta.get("no_longer_updated", False),
56+
)
57+
58+
59+
@dataclass
60+
class PypiInfo:
61+
distribution: str
62+
version: packaging.version.Version
63+
upload_date: datetime.datetime
64+
# https://warehouse.pypa.io/api-reference/json.html#get--pypi--project_name--json
65+
# Corresponds to a single entry from `releases` for the given version
66+
release_to_download: dict[str, Any]
67+
68+
69+
async def fetch_pypi_info(distribution: str, session: aiohttp.ClientSession) -> PypiInfo:
70+
url = f"https://pypi.org/pypi/{urllib.parse.quote(distribution)}/json"
71+
async with session.get(url) as response:
72+
response.raise_for_status()
73+
j = await response.json()
74+
version = j["info"]["version"]
75+
# prefer wheels, since it's what most users will get / it's pretty easy to mess up MANIFEST
76+
release_to_download = sorted(j["releases"][version], key=lambda x: bool(x["packagetype"] == "bdist_wheel"))[-1]
77+
date = datetime.datetime.fromisoformat(release_to_download["upload_time"])
78+
return PypiInfo(
79+
distribution=distribution,
80+
version=packaging.version.Version(version),
81+
upload_date=date,
82+
release_to_download=release_to_download,
83+
)
84+
85+
86+
@dataclass
87+
class Update:
88+
distribution: str
89+
stub_path: Path
90+
old_version_spec: str
91+
new_version_spec: str
92+
93+
def __str__(self) -> str:
94+
return f"Updating {self.distribution} from {self.old_version_spec!r} to {self.new_version_spec!r}"
95+
96+
97+
@dataclass
98+
class Obsolete:
99+
distribution: str
100+
stub_path: Path
101+
obsolete_since_version: str
102+
obsolete_since_date: datetime.datetime
103+
104+
def __str__(self) -> str:
105+
return f"Marking {self.distribution} as obsolete since {self.obsolete_since_version!r}"
106+
107+
108+
@dataclass
109+
class NoUpdate:
110+
distribution: str
111+
reason: str
112+
113+
def __str__(self) -> str:
114+
return f"Skipping {self.distribution}: {self.reason}"
115+
116+
117+
async def package_contains_py_typed(release_to_download: dict[str, Any], session: aiohttp.ClientSession) -> bool:
118+
async with session.get(release_to_download["url"]) as response:
119+
body = io.BytesIO(await response.read())
120+
121+
packagetype = release_to_download["packagetype"]
122+
if packagetype == "bdist_wheel":
123+
assert release_to_download["filename"].endswith(".whl")
124+
with zipfile.ZipFile(body) as zf:
125+
return any(Path(f).name == "py.typed" for f in zf.namelist())
126+
elif packagetype == "sdist":
127+
assert release_to_download["filename"].endswith(".tar.gz")
128+
with tarfile.open(fileobj=body, mode="r:gz") as zf:
129+
return any(Path(f).name == "py.typed" for f in zf.getnames())
130+
else:
131+
raise AssertionError(f"Unknown package type: {packagetype}")
132+
133+
134+
def _check_spec(updated_spec: str, version: packaging.version.Version) -> str:
135+
assert version in packaging.specifiers.SpecifierSet("==" + updated_spec), f"{version} not in {updated_spec}"
136+
return updated_spec
137+
138+
139+
def get_updated_version_spec(spec: str, version: packaging.version.Version) -> str:
140+
"""
141+
Given the old specifier and an updated version, returns an updated specifier that has the
142+
specificity of the old specifier, but matches the updated version.
143+
144+
For example:
145+
spec="1", version="1.2.3" -> "1.2.3"
146+
spec="1.0.1", version="1.2.3" -> "1.2.3"
147+
spec="1.*", version="1.2.3" -> "1.*"
148+
spec="1.*", version="2.3.4" -> "2.*"
149+
spec="1.1.*", version="1.2.3" -> "1.2.*"
150+
spec="1.1.1.*", version="1.2.3" -> "1.2.3.*"
151+
"""
152+
if not spec.endswith(".*"):
153+
return _check_spec(version.base_version, version)
154+
155+
specificity = spec.count(".") if spec.removesuffix(".*") else 0
156+
rounded_version = version.base_version.split(".")[:specificity]
157+
rounded_version.extend(["0"] * (specificity - len(rounded_version)))
158+
159+
return _check_spec(".".join(rounded_version) + ".*", version)
160+
161+
162+
async def determine_action(stub_path: Path, session: aiohttp.ClientSession) -> Update | NoUpdate | Obsolete:
163+
stub_info = read_typeshed_stub_metadata(stub_path)
164+
if stub_info.obsolete:
165+
return NoUpdate(stub_info.distribution, "obsolete")
166+
if stub_info.no_longer_updated:
167+
return NoUpdate(stub_info.distribution, "no longer updated")
168+
169+
pypi_info = await fetch_pypi_info(stub_info.distribution, session)
170+
spec = packaging.specifiers.SpecifierSet("==" + stub_info.version_spec)
171+
if pypi_info.version in spec:
172+
return NoUpdate(stub_info.distribution, "up to date")
173+
174+
if await package_contains_py_typed(pypi_info.release_to_download, session):
175+
return Obsolete(
176+
stub_info.distribution,
177+
stub_path,
178+
obsolete_since_version=str(pypi_info.version),
179+
obsolete_since_date=pypi_info.upload_date,
180+
)
181+
182+
return Update(
183+
distribution=stub_info.distribution,
184+
stub_path=stub_path,
185+
old_version_spec=stub_info.version_spec,
186+
new_version_spec=get_updated_version_spec(stub_info.version_spec, pypi_info.version),
187+
)
188+
189+
190+
TYPESHED_OWNER = "python"
191+
192+
193+
@functools.lru_cache()
194+
def get_origin_owner():
195+
output = subprocess.check_output(["git", "remote", "get-url", "origin"], text=True)
196+
match = re.search(r"([email protected]:|https://github.com/)(?P<owner>[^/]+)/(?P<repo>[^/]+).git", output)
197+
assert match is not None
198+
assert match.group("repo") == "typeshed"
199+
return match.group("owner")
200+
201+
202+
async def create_or_update_pull_request(*, title: str, body: str, branch_name: str, session: aiohttp.ClientSession):
203+
secret = os.environ["GITHUB_TOKEN"]
204+
if secret.startswith("ghp"):
205+
auth = f"token {secret}"
206+
else:
207+
auth = f"Bearer {secret}"
208+
209+
fork_owner = get_origin_owner()
210+
211+
async with session.post(
212+
f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed/pulls",
213+
json={"title": title, "body": body, "head": f"{fork_owner}:{branch_name}", "base": "master"},
214+
headers={"Accept": "application/vnd.github.v3+json", "Authorization": auth},
215+
) as response:
216+
resp_json = await response.json()
217+
if response.status == 422 and any(
218+
"A pull request already exists" in e.get("message", "") for e in resp_json.get("errors", [])
219+
):
220+
# Find the existing PR
221+
async with session.get(
222+
f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed/pulls",
223+
params={"state": "open", "head": f"{fork_owner}:{branch_name}", "base": "master"},
224+
headers={"Accept": "application/vnd.github.v3+json", "Authorization": auth},
225+
) as response:
226+
response.raise_for_status()
227+
resp_json = await response.json()
228+
assert len(resp_json) >= 1
229+
pr_number = resp_json[0]["number"]
230+
# Update the PR's title and body
231+
async with session.patch(
232+
f"https://api.github.com/repos/{TYPESHED_OWNER}/typeshed/pulls/{pr_number}",
233+
json={"title": title, "body": body},
234+
headers={"Accept": "application/vnd.github.v3+json", "Authorization": auth},
235+
) as response:
236+
response.raise_for_status()
237+
return
238+
response.raise_for_status()
239+
240+
241+
def normalize(name: str) -> str:
242+
# PEP 503 normalization
243+
return re.sub(r"[-_.]+", "-", name).lower()
244+
245+
246+
# lock should be unnecessary, but can't hurt to enforce mutual exclusion
247+
_repo_lock = asyncio.Lock()
248+
249+
BRANCH_PREFIX = "stubsabot"
250+
251+
252+
async def suggest_typeshed_update(update: Update, session: aiohttp.ClientSession, action_level: ActionLevel) -> None:
253+
if action_level <= ActionLevel.nothing:
254+
return
255+
title = f"[stubsabot] Bump {update.distribution} to {update.new_version_spec}"
256+
async with _repo_lock:
257+
branch_name = f"{BRANCH_PREFIX}/{normalize(update.distribution)}"
258+
subprocess.check_call(["git", "checkout", "-B", branch_name, "origin/master"])
259+
with open(update.stub_path / "METADATA.toml", "rb") as f:
260+
meta = tomlkit.load(f)
261+
meta["version"] = update.new_version_spec
262+
with open(update.stub_path / "METADATA.toml", "w") as f:
263+
tomlkit.dump(meta, f)
264+
subprocess.check_call(["git", "commit", "--all", "-m", title])
265+
if action_level <= ActionLevel.local:
266+
return
267+
subprocess.check_call(["git", "push", "origin", branch_name, "--force-with-lease"])
268+
269+
body = """\
270+
If stubtest fails for this PR:
271+
- Leave this PR open (as a reminder, and to prevent stubsabot from opening another PR)
272+
- Fix stubtest failures in another PR, then close this PR
273+
"""
274+
await create_or_update_pull_request(title=title, body=body, branch_name=branch_name, session=session)
275+
276+
277+
async def suggest_typeshed_obsolete(obsolete: Obsolete, session: aiohttp.ClientSession, action_level: ActionLevel) -> None:
278+
if action_level <= ActionLevel.nothing:
279+
return
280+
title = f"[stubsabot] Mark {obsolete.distribution} as obsolete since {obsolete.obsolete_since_version}"
281+
async with _repo_lock:
282+
branch_name = f"{BRANCH_PREFIX}/{normalize(obsolete.distribution)}"
283+
subprocess.check_call(["git", "checkout", "-B", branch_name, "origin/master"])
284+
with open(obsolete.stub_path / "METADATA.toml", "rb") as f:
285+
meta = tomlkit.load(f)
286+
obs_string = tomlkit.string(obsolete.obsolete_since_version)
287+
obs_string.comment(f"Released on {obsolete.obsolete_since_date.date().isoformat()}")
288+
meta["obsolete_since"] = obs_string
289+
with open(obsolete.stub_path / "METADATA.toml", "w") as f:
290+
tomlkit.dump(meta, f)
291+
subprocess.check_call(["git", "commit", "--all", "-m", title])
292+
if action_level <= ActionLevel.local:
293+
return
294+
subprocess.check_call(["git", "push", "origin", branch_name, "--force-with-lease"])
295+
296+
await create_or_update_pull_request(title=title, body="", branch_name=branch_name, session=session)
297+
298+
299+
async def main() -> None:
300+
assert sys.version_info >= (3, 9)
301+
302+
parser = argparse.ArgumentParser()
303+
parser.add_argument(
304+
"--action-level",
305+
type=lambda x: getattr(ActionLevel, x), # type: ignore[no-any-return]
306+
default=ActionLevel.everything,
307+
help="Limit actions performed to achieve dry runs for different levels of dryness",
308+
)
309+
parser.add_argument(
310+
"--action-count-limit",
311+
type=int,
312+
default=None,
313+
help="Limit number of actions performed and the remainder are logged. Useful for testing",
314+
)
315+
args = parser.parse_args()
316+
317+
if args.action_level > ActionLevel.local:
318+
if os.environ.get("GITHUB_TOKEN") is None:
319+
raise ValueError("GITHUB_TOKEN environment variable must be set")
320+
321+
denylist = {"gdb"} # gdb is not a pypi distribution
322+
323+
try:
324+
conn = aiohttp.TCPConnector(limit_per_host=10)
325+
async with aiohttp.ClientSession(connector=conn) as session:
326+
tasks = [
327+
asyncio.create_task(determine_action(stubs_path, session))
328+
for stubs_path in Path("stubs").iterdir()
329+
if stubs_path.name not in denylist
330+
]
331+
332+
action_count = 0
333+
for task in asyncio.as_completed(tasks):
334+
update = await task
335+
print(update)
336+
337+
if isinstance(update, NoUpdate):
338+
continue
339+
340+
if args.action_count_limit is not None and action_count >= args.action_count_limit:
341+
print("... but we've reached action count limit")
342+
continue
343+
action_count += 1
344+
345+
if isinstance(update, Update):
346+
await suggest_typeshed_update(update, session, action_level=args.action_level)
347+
continue
348+
if isinstance(update, Obsolete):
349+
await suggest_typeshed_obsolete(update, session, action_level=args.action_level)
350+
continue
351+
raise AssertionError
352+
finally:
353+
# if you need to cleanup, try:
354+
# git branch -D $(git branch --list 'stubsabot/*')
355+
if args.action_level >= ActionLevel.local:
356+
subprocess.check_call(["git", "checkout", "master"])
357+
358+
359+
if __name__ == "__main__":
360+
asyncio.run(main())

0 commit comments

Comments
 (0)