Skip to content

Commit 600a3f2

Browse files
authored
pre-download model weights in CI docs build (#5625)
* pre-download model weights in CI docs build * move changes into template * change docs image
1 parent 9edd22c commit 600a3f2

File tree

3 files changed

+56
-28
lines changed

3 files changed

+56
-28
lines changed

.circleci/config.yml

Lines changed: 24 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,26 @@ commands:
174174
- store_test_results:
175175
path: test-results
176176

177+
download_model_weights:
178+
parameters:
179+
extract_roots:
180+
type: string
181+
default: "torchvision/models"
182+
background:
183+
type: boolean
184+
default: true
185+
steps:
186+
- apt_install:
187+
args: parallel wget
188+
descr: Install download utilitites
189+
- run:
190+
name: Download model weights
191+
background: << parameters.background >>
192+
command: |
193+
mkdir -p ~/.cache/torch/hub/checkpoints
194+
python scripts/collect_model_urls.py << parameters.extract_roots >> \
195+
| parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci'
196+
177197
binary_common: &binary_common
178198
parameters:
179199
# Edit these defaults to do a release
@@ -340,14 +360,8 @@ jobs:
340360
resource_class: xlarge
341361
steps:
342362
- checkout
343-
- run:
344-
name: Download model weights
345-
background: true
346-
command: |
347-
sudo apt update -qy && sudo apt install -qy parallel wget
348-
mkdir -p ~/.cache/torch/hub/checkpoints
349-
python scripts/collect_model_urls.py torchvision/prototype/models \
350-
| parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci'
363+
- download_model_weights:
364+
extract_roots: torchvision/prototype/models
351365
- install_torchvision
352366
- install_prototype_dependencies
353367
- pip_install:
@@ -1011,12 +1025,13 @@ jobs:
10111025
build_docs:
10121026
<<: *binary_common
10131027
docker:
1014-
- image: "pytorch/manylinux-cuda100"
1028+
- image: circleci/python:3.7
10151029
resource_class: 2xlarge+
10161030
steps:
10171031
- attach_workspace:
10181032
at: ~/workspace
10191033
- checkout
1034+
- download_model_weights
10201035
- run:
10211036
name: Setup
10221037
command: .circleci/unittest/linux/scripts/setup_env.sh

scripts/collect_model_urls.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,19 @@
22
import re
33
import sys
44

5-
MODEL_URL_PATTERN = re.compile(r"https://download[.]pytorch[.]org/models/.*?[.]pth")
5+
MODEL_URL_PATTERN = re.compile(r"https://download[.]pytorch[.]org/models/.+?[.]pth")
66

77

8-
def main(root):
8+
def main(*roots):
99
model_urls = set()
10-
for path in pathlib.Path(root).glob("**/*"):
11-
if path.name.startswith("_") or not path.suffix == ".py":
12-
continue
13-
14-
with open(path, "r") as file:
15-
for line in file:
16-
model_urls.update(MODEL_URL_PATTERN.findall(line))
10+
for root in roots:
11+
for path in pathlib.Path(root).rglob("*.py"):
12+
with open(path, "r") as file:
13+
for line in file:
14+
model_urls.update(MODEL_URL_PATTERN.findall(line))
1715

1816
print("\n".join(sorted(model_urls)))
1917

2018

2119
if __name__ == "__main__":
22-
main(sys.argv[1])
20+
main(*sys.argv[1:])

0 commit comments

Comments
 (0)