Skip to content
This repository was archived by the owner on Mar 16, 2022. It is now read-only.

Commit dcc1dc1

Browse files
authored
Merge pull request #13 from datumbox/custom_weights_example
Adding an example of passing custom weights
2 parents c1e8511 + 51fdaff commit dcc1dc1

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

README.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,18 @@ as it has previously been discussed by the domain library maintainers.
5555

5656
In TorchVision, model versioning related issues have been handled on a case-by-case basis
5757
[[8](https://github.com/pytorch/vision/issues/2599), [9](https://github.com/pytorch/vision/issues/2326),
58-
[10](https://github.com/pytorch/vision/pull/3205), [11](https://github.com/pytorch/vision/pull/1224)]. We typically
59-
try to maintain BC as much as possible except in cases where the issue is considered a bug or extremely detrimental to
60-
the user experience. In the latter cases, we often deploy BC-breaking fixes but try to reduce the effects of the change
61-
as much as possible [[12](https://github.com/pytorch/vision/pull/2940)]. TorchAudio has only recently introduced the
62-
pre-trained models and their API is currently in beta. TorchText is currently working on introducing its first
63-
pre-trained models so their solution is in prototype phase. Finally, though PyTorch Core doesn't provide pre-trained
64-
models, occasionally there is the need of handling code changes on existing Layers and operators. These are typically
65-
handled with a mix of version parameters, deprecation warnings and method renamings
66-
[[13](https://github.com/pytorch/pytorch/blob/294db060/torch/nn/quantized/dynamic/modules/linear.py#L44-L49),
67-
[14](https://github.com/pytorch/pytorch/blob/c371542e/caffe2/python/cnn.py#L182-L183),
68-
[15](https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/caffe2/python/brew.py#L65-L66)]
58+
[10](https://github.com/pytorch/vision/pull/3205), [11](https://github.com/pytorch/vision/pull/1224),
59+
[12](https://github.com/pytorch/vision/pull/4487)]. We typically try to maintain BC as much as possible except in cases
60+
where the issue is considered a bug or extremely detrimental to the user experience. In the latter cases, we often
61+
deploy BC-breaking fixes but try to reduce the effects of the change as much as possible
62+
[[13](https://github.com/pytorch/vision/pull/2940)]. TorchAudio has only recently introduced the pre-trained models and
63+
their API is currently in beta. TorchText is currently working on introducing its first pre-trained models so their
64+
solution is in prototype phase. Finally, though PyTorch Core doesn't provide pre-trained models, occasionally there is
65+
the need of handling code changes on existing Layers and operators. These are typically handled with a mix of version
66+
parameters, deprecation warnings and method renamings
67+
[[14](https://github.com/pytorch/pytorch/blob/294db060/torch/nn/quantized/dynamic/modules/linear.py#L44-L49),
68+
[15](https://github.com/pytorch/pytorch/blob/c371542e/caffe2/python/cnn.py#L182-L183),
69+
[16](https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/caffe2/python/brew.py#L65-L66)]
6970

7071
## Repository structure
7172

examples/scenario1.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from functools import partial
2424
from torch import nn, Tensor
25-
from typing import Any, Optional
25+
from typing import Any, Optional, Union
2626

2727
from dapi_lib.models._api import register, Weights, WeightEntry
2828
from dapi_lib.transforms.vision_presets import ConvertImageDtype
@@ -110,7 +110,8 @@ class MySOTAV2Weights(Weights):
110110
# assessed on a case-by-case basis. See https://github.com/pytorch/vision/pull/1224 and
111111
# https://github.com/pytorch/pytorch/blob/294db060/torch/nn/quantized/dynamic/modules/linear.py#L44-L49
112112
@register
113-
def mysota_v2(weights: Optional[MySOTAV2Weights] = None, progress: bool = True, **kwargs: Any) -> MySOTA:
113+
def mysota_v2(weights: Optional[Union[MySOTAV2Weights, WeightEntry]] = None, progress: bool = True,
114+
**kwargs: Any) -> MySOTA:
114115
# Confirm we got the right weights
115116
MySOTAV2Weights.check_type(weights)
116117

@@ -134,3 +135,12 @@ def mysota_v2(weights: Optional[MySOTAV2Weights] = None, progress: bool = True,
134135
assert v1 == v2
135136

136137
mysota_v2(MySOTAV2Weights.NOTHOTDOG)
138+
139+
# Ability to introduce and pass custom weights
140+
custom_weights = WeightEntry(
141+
'https://fake/models/my_custom_weights_v2.pth',
142+
partial(ConvertImageDtype, dtype=torch.float16),
143+
{'size': (32, 32), 'classes': ['not burger', 'burger']},
144+
True
145+
)
146+
mysota_v2(custom_weights)

0 commit comments

Comments
 (0)