Skip to content

Commit b7ed683

Browse files
committed
Try to format code as in #5106
1 parent 94c7dde commit b7ed683

File tree

188 files changed

+553
-740
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

188 files changed

+553
-740
lines changed

references/classification/train_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torchvision
1010
import utils
1111
from torch import nn
12-
from train import train_one_epoch, evaluate, load_data
12+
from train import evaluate, load_data, train_one_epoch
1313

1414

1515
def main(args):

references/detection/group_by_aspect_ratio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
import math
44
from collections import defaultdict
5-
from itertools import repeat, chain
5+
from itertools import chain, repeat
66

77
import numpy as np
88
import torch

references/detection/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
import torchvision.models.detection.mask_rcnn
3030
import utils
3131
from coco_utils import get_coco, get_coco_kp
32-
from engine import train_one_epoch, evaluate
33-
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
32+
from engine import evaluate, train_one_epoch
33+
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
3434
from torchvision.transforms import InterpolationMode
3535
from transforms import SimpleCopyPaste
3636

references/detection/transforms.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import List, Tuple, Dict, Optional, Union
1+
from typing import Dict, List, Optional, Tuple, Union
22

33
import torch
44
import torchvision
55
from torch import nn, Tensor
66
from torchvision import ops
7-
from torchvision.transforms import functional as F
8-
from torchvision.transforms import transforms as T, InterpolationMode
7+
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
98

109

1110
def _flip_coco_person_keypoints(kps, width):

references/optical_flow/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import torch
77
import torchvision.models.optical_flow
88
import utils
9-
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
10-
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
9+
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
10+
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
1111

1212

1313
def get_train_dataset(stage, dataset_root):

references/optical_flow/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import datetime
22
import os
33
import time
4-
from collections import defaultdict
5-
from collections import deque
4+
from collections import defaultdict, deque
65

76
import torch
87
import torch.distributed as dist
@@ -158,7 +157,7 @@ def log_every(self, iterable, print_freq=5, header=None):
158157
def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
159158

160159
epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
161-
flow_norm = (flow_gt ** 2).sum(dim=1).sqrt()
160+
flow_norm = (flow_gt**2).sum(dim=1).sqrt()
162161

163162
if valid_flow_mask is not None:
164163
epe = epe[valid_flow_mask]
@@ -183,7 +182,7 @@ def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400)
183182
raise ValueError(f"Gamma should be < 1, got {gamma}.")
184183

185184
# exclude invalid pixels and extremely large displacements
186-
flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
185+
flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
187186
valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
188187

189188
valid_flow_mask = valid_flow_mask[:, None, :, :]

references/segmentation/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def update(self, a, b):
7575
with torch.inference_mode():
7676
k = (a >= 0) & (a < n)
7777
inds = n * a[k].to(torch.int64) + b[k]
78-
self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
78+
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
7979

8080
def reset(self):
8181
self.mat.zero_()

test/builtin_dataset_mocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import unittest.mock
1515
import warnings
1616
import xml.etree.ElementTree as ET
17-
from collections import defaultdict, Counter
17+
from collections import Counter, defaultdict
1818

1919
import numpy as np
2020
import pytest
2121
import torch
22-
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
22+
from datasets_utils import combinations_grid, create_image_file, create_image_folder, make_tar, make_zip
2323
from torch.nn.functional import one_hot
2424
from torch.testing import make_tensor as _make_tensor
2525
from torchvision.prototype import datasets

test/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55
import torch
6-
from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG
6+
from common_utils import CIRCLECI_GPU_NO_CUDA_MSG, CUDA_NOT_AVAILABLE_MSG, IN_CIRCLE_CI, IN_FBCODE, IN_RE_WORKER
77

88

99
def pytest_configure(config):

test/datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch
2323
import torchvision.datasets
2424
import torchvision.io
25-
from common_utils import get_tmp_dir, disable_console_output
25+
from common_utils import disable_console_output, get_tmp_dir
2626

2727

2828
__all__ = [

test/test_datasets_download.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from os import path
1010
from urllib.error import HTTPError, URLError
1111
from urllib.parse import urlparse
12-
from urllib.request import urlopen, Request
12+
from urllib.request import Request, urlopen
1313

1414
import pytest
1515
from torchvision import datasets
1616
from torchvision.datasets.utils import (
17-
download_url,
17+
_get_redirect_url,
1818
check_integrity,
1919
download_file_from_google_drive,
20-
_get_redirect_url,
20+
download_url,
2121
USER_AGENT,
2222
)
2323

test/test_datasets_samplers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import pytest
22
import torch
3-
from common_utils import get_list_of_videos, assert_equal
3+
from common_utils import assert_equal, get_list_of_videos
44
from torchvision import io
5-
from torchvision.datasets.samplers import (
6-
DistributedSampler,
7-
RandomClipSampler,
8-
UniformClipSampler,
9-
)
5+
from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
106
from torchvision.datasets.video_utils import VideoClips
117

128

test/test_datasets_video_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import pytest
22
import torch
3-
from common_utils import get_list_of_videos, assert_equal
3+
from common_utils import assert_equal, get_list_of_videos
44
from torchvision import io
5-
from torchvision.datasets.video_utils import VideoClips, unfold
5+
from torchvision.datasets.video_utils import unfold, VideoClips
66

77

88
class TestVideo:

test/test_extended_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import test_models as TM
66
import torch
77
from torchvision import models
8-
from torchvision.models._api import WeightsEnum, Weights
8+
from torchvision.models._api import Weights, WeightsEnum
99
from torchvision.models._utils import handle_legacy_interface
1010

1111

test/test_functional_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
import torchvision.transforms.functional_pil as F_pil
1515
import torchvision.transforms.functional_tensor as F_t
1616
from common_utils import (
17-
cpu_and_gpu,
18-
needs_cuda,
17+
_assert_approx_equal_tensor_to_pil,
18+
_assert_equal_tensor_to_pil,
1919
_create_data,
2020
_create_data_batch,
21-
_assert_equal_tensor_to_pil,
22-
_assert_approx_equal_tensor_to_pil,
2321
_test_fn_on_batch,
2422
assert_equal,
23+
cpu_and_gpu,
24+
needs_cuda,
2525
)
2626
from torchvision.transforms import InterpolationMode
2727

test/test_image.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88
import pytest
99
import torch
1010
import torchvision.transforms.functional as F
11-
from common_utils import needs_cuda, assert_equal
12-
from PIL import Image, __version__ as PILLOW_VERSION
11+
from common_utils import assert_equal, needs_cuda
12+
from PIL import __version__ as PILLOW_VERSION, Image
1313
from torchvision.io.image import (
14-
decode_png,
14+
_read_png_16,
15+
decode_image,
1516
decode_jpeg,
17+
decode_png,
1618
encode_jpeg,
17-
write_jpeg,
18-
decode_image,
19-
read_file,
2019
encode_png,
21-
write_png,
22-
write_file,
2320
ImageReadMode,
21+
read_file,
2422
read_image,
25-
_read_png_16,
23+
write_file,
24+
write_jpeg,
25+
write_png,
2626
)
2727

2828
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
@@ -168,7 +168,7 @@ def test_decode_png(img_path, pil_mode, mode):
168168
img_lpng = _read_png_16(img_path, mode=mode)
169169
assert img_lpng.dtype == torch.int32
170170
# PIL converts 16 bits pngs in uint8
171-
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
171+
img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
172172
else:
173173
data = read_file(img_path)
174174
img_lpng = decode_image(data, mode=mode)

test/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch.fx
1515
import torch.nn as nn
1616
from _utils_internal import get_relative_path
17-
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
17+
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
1818
from torchvision import models
1919

2020
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"

test/test_models_detection_negative_samples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from common_utils import assert_equal
55
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
66
from torchvision.models.detection.roi_heads import RoIHeads
7-
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
7+
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
88
from torchvision.ops import MultiScaleRoIAlign
99

1010

@@ -60,7 +60,7 @@ def test_assign_targets_to_proposals(self):
6060

6161
resolution = box_roi_pool.output_size[0]
6262
representation_size = 1024
63-
box_head = TwoMLPHead(4 * resolution ** 2, representation_size)
63+
box_head = TwoMLPHead(4 * resolution**2, representation_size)
6464

6565
representation_size = 1024
6666
box_predictor = FastRCNNPredictor(representation_size, 2)

test/test_onnx.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44

55
import pytest
66
import torch
7-
from common_utils import set_rng_seed, assert_equal
8-
from torchvision import models
9-
from torchvision import ops
7+
from common_utils import assert_equal, set_rng_seed
8+
from torchvision import models, ops
109
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
1110
from torchvision.models.detection.image_list import ImageList
1211
from torchvision.models.detection.roi_heads import RoIHeads
13-
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
12+
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
1413
from torchvision.models.detection.transform import GeneralizedRCNNTransform
1514
from torchvision.ops._register_onnx_ops import _onnx_opset_version
1615

@@ -265,7 +264,7 @@ def _init_test_roi_heads_faster_rcnn(self):
265264

266265
resolution = box_roi_pool.output_size[0]
267266
representation_size = 1024
268-
box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
267+
box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
269268

270269
representation_size = 1024
271270
box_predictor = FastRCNNPredictor(representation_size, num_classes)

test/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar
7979
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
8080
pool_size = 5
8181
# n_channels % (pool_size ** 2) == 0 required for PS operations.
82-
n_channels = 2 * (pool_size ** 2)
82+
n_channels = 2 * (pool_size**2)
8383
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
8484
if not contiguous:
8585
x = x.permute(0, 1, 3, 2)
@@ -115,7 +115,7 @@ def test_is_leaf_node(self, device):
115115
def test_backward(self, seed, device, contiguous):
116116
torch.random.manual_seed(seed)
117117
pool_size = 2
118-
x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
118+
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
119119
if not contiguous:
120120
x = x.permute(0, 1, 3, 2)
121121
rois = torch.tensor(

test/test_prototype_builtin_datasets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
import pytest
77
import torch
8-
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
9-
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
8+
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
9+
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair
1010
from torch.utils.data import DataLoader
1111
from torch.utils.data.graph import traverse
1212
from torch.utils.data.graph_settings import get_all_graph_pipes
13-
from torchdata.datapipes.iter import Shuffler, ShardingFilter
13+
from torchdata.datapipes.iter import ShardingFilter, Shuffler
1414
from torchvision._utils import sequence_to_str
15-
from torchvision.prototype import transforms, datasets
15+
from torchvision.prototype import datasets, transforms
1616
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
1717
from torchvision.prototype.features import Image, Label
1818

test/test_prototype_datasets_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
1010
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
1111
from torchvision.datasets.utils import _decompress
12-
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
13-
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
12+
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, HttpResource, OnlineResource
13+
from torchvision.prototype.datasets.utils._internal import fromfile, read_flo
1414

1515

1616
@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning")

test/test_prototype_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import test_models as TM
33
import torch
44
import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo
5-
from common_utils import set_rng_seed, cpu_and_gpu
5+
from common_utils import cpu_and_gpu, set_rng_seed
66

77

88
@pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime))

test/test_prototype_transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import torch
55
from common_utils import assert_equal
66
from test_prototype_transforms_functional import (
7-
make_images,
8-
make_bounding_boxes,
97
make_bounding_box,
10-
make_one_hot_labels,
8+
make_bounding_boxes,
9+
make_images,
1110
make_label,
11+
make_one_hot_labels,
1212
make_segmentation_mask,
1313
)
14-
from torchvision.prototype import transforms, features
15-
from torchvision.transforms.functional import to_pil_image, pil_to_tensor, InterpolationMode
14+
from torchvision.prototype import features, transforms
15+
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
1616

1717

1818
def make_vanilla_tensor_images(*args, **kwargs):

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
except ImportError:
2525
stats = None
2626

27-
from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal
27+
from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes
2828

2929

3030
GRACE_HOPPER = get_file_path_2(

0 commit comments

Comments
 (0)