|
| 1 | +import json |
| 2 | +import sys |
| 3 | +from typing import Any, Dict, Set, Tuple |
| 4 | + |
| 5 | +import requests |
| 6 | + |
| 7 | +# If the PR has any of these labels, we mark it as properly labeled. |
| 8 | +REQUIRED_LABELS = { |
| 9 | + "new feature", |
| 10 | + "bug", |
| 11 | + "code quality", |
| 12 | + "enhancement", |
| 13 | + "bc-breaking", |
| 14 | + "dependency issue", |
| 15 | + "deprecation", |
| 16 | + "module: c++ frontend", |
| 17 | + "module: ci", |
| 18 | + "module: datasets", |
| 19 | + "module: documentation", |
| 20 | + "module: io", |
| 21 | + "module: models.quantization", |
| 22 | + "module: models", |
| 23 | + "module: onnx", |
| 24 | + "module: ops", |
| 25 | + "module: reference scripts", |
| 26 | + "module: rocm", |
| 27 | + "module: tests", |
| 28 | + "module: transforms", |
| 29 | + "module: utils", |
| 30 | + "module: video", |
| 31 | + "Perf", |
| 32 | + "Revert(ed)", |
| 33 | +} |
| 34 | + |
| 35 | + |
| 36 | +def main(commit_hash: str) -> Dict[str, Any]: |
| 37 | + pr_number = get_pr_number(commit_hash) |
| 38 | + merger, labels = get_pr_merger_and_labels(pr_number) |
| 39 | + is_properly_labeled = bool(REQUIRED_LABELS.intersection(labels)) |
| 40 | + if not is_properly_labeled: |
| 41 | + users = {merger, *get_pr_reviewers(pr_number)} |
| 42 | + else: |
| 43 | + users = () |
| 44 | + return dict( |
| 45 | + is_properly_labeled=is_properly_labeled, |
| 46 | + responsible_users=", ".join(sorted([f"@{user}" for user in users])), |
| 47 | + ) |
| 48 | + |
| 49 | + |
| 50 | +def _query_torchvision(cmd: str, *, accept) -> Any: |
| 51 | + response = requests.get(f"https://api.github.com/repos/pytorch/vision/{cmd}", headers=dict(Accept=accept)) |
| 52 | + return response.json() |
| 53 | + |
| 54 | + |
| 55 | +def get_pr_number(commit_hash: str) -> int: |
| 56 | + # See https://docs.github.com/en/rest/reference/repos#list-pull-requests-associated-with-a-commit |
| 57 | + data = _query_torchvision(f"commits/{commit_hash}/pulls", accept="application/vnd.github.groot-preview+json") |
| 58 | + return data[0]["number"] |
| 59 | + |
| 60 | + |
| 61 | +def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]: |
| 62 | + # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request |
| 63 | + data = _query_torchvision(f"pulls/{pr_number}", accept="application/vnd.github.v3+json") |
| 64 | + merger = data["merged_by"]["login"] |
| 65 | + labels = {label["name"] for label in data["labels"]} |
| 66 | + return merger, labels |
| 67 | + |
| 68 | + |
| 69 | +def get_pr_reviewers(pr_number: int) -> Set[str]: |
| 70 | + # See https://docs.github.com/en/rest/reference/pulls#list-reviews-for-a-pull-request |
| 71 | + data = _query_torchvision(f"pulls/{pr_number}/reviews", accept="application/vnd.github.v3+json") |
| 72 | + return {review["user"]["login"] for review in data if review["state"] == "APPROVED"} |
| 73 | + |
| 74 | + |
| 75 | +if __name__ == "__main__": |
| 76 | + commit_hash = sys.argv[1] |
| 77 | + data = main(commit_hash) |
| 78 | + print(json.dumps(data)) |
0 commit comments