From dce22b3407d5854c14c3b06a59fd63c0bad8cc06 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 8 Dec 2021 14:52:40 +0000
Subject: [PATCH 1/4] Add pretrained weights on Chairs and Things for
 raft_large

---
 references/optical_flow/README.md                 |  55 ++++++++++
 references/optical_flow/train.py                  |  32 +++++-
 test/test_prototype_models.py                     |   7 +-
 torchvision/models/optical_flow/raft.py           |  28 +++--
 .../prototype/models/optical_flow/raft.py         | 102 ++++++++++++++----
 5 files changed, 192 insertions(+), 32 deletions(-)
 create mode 100644 references/optical_flow/README.md

diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md
new file mode 100644
index 00000000000..d0963f44317
--- /dev/null
+++ b/references/optical_flow/README.md
@@ -0,0 +1,55 @@
+# Optical flow reference training scripts
+
+This folder contains reference training scripts for optical flow.
+They serve as a log of how to train specific models and provide baseline
+training and evaluation scripts to quickly bootstrap research.
+
+
+### RAFT Large
+
+The RAFT large model was trained on Flying Chairs and then on Flying Things.
+Both stages used 8 A100 GPUs and a per-GPU batch size of 2 (an effective batch
+size of 16). The rest of the hyper-parameters exactly match the original RAFT
+recipe from https://github.com/princeton-vl/RAFT.
+
+```
+torchrun --nproc_per_node 8 --nnodes 1 train.py \
+    --dataset-root $dataset_root \
+    --name $name_chairs \
+    --train-dataset chairs \
+    --batch-size 2 \
+    --lr 0.0004 \
+    --weight-decay 0.0001 \
+    --num-steps 100000 \
+    --output-dir $chairs_dir
+```
+
+```
+torchrun --nproc_per_node 8 --nnodes 1 train.py \
+    --dataset-root $dataset_root \
+    --name $name_things \
+    --train-dataset things \
+    --batch-size 2 \
+    --lr 0.000125 \
+    --weight-decay 0.0001 \
+    --num-steps 100000 \
+    --freeze-batch-norm \
+    --output-dir $things_dir \
+    --resume $chairs_dir/$name_chairs.pth
+```
+
+
+### Evaluation
+
+```
+torchrun --nproc_per_node 8 --nnodes 1 train.py --val-dataset sintel --batch-size 10 --dataset-root $dataset_root --model raft_large --pretrained
+```
+
+This should give an epe of about 1.3825 on the clean pass and 2.7148 on the
+final pass of Sintel. Results may vary slightly depending on the batch size and
+the number of GPUs. For the most accurate results, use 1 GPU and `--batch-size 1`.
+ +``` +Sintel val clean epe: 1.3825 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3782 f1: 4.0234 +Sintel val final epe: 2.7148 1px: 0.8526 3px: 0.9203 5px: 0.9392 per_image_epe: 2.7199 f1: 7.6100 +``` diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index eaf03fbe4f3..860fb4ee1db 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -3,10 +3,16 @@ from pathlib import Path import torch +import torchvision.models.optical_flow import utils from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K -from torchvision.models.optical_flow import raft_large, raft_small + +try: + from torchvision.prototype import models as PM + from torchvision.prototype.models import optical_flow as PMOF +except ImportError: + PM = None def get_train_dataset(stage, dataset_root): @@ -125,6 +131,13 @@ def inner_loop(blob): def validate(model, args): val_datasets = args.val_dataset or [] + + if args.weights: + weights = PM.get_weight(args.weights) + preprocessing = weights.transforms() + else: + preprocessing = OpticalFlowPresetEval() + for name in val_datasets: if name == "kitti": # Kitti has different image sizes so we need to individually pad them, we can't batch. @@ -134,14 +147,14 @@ def validate(model, args): f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1." ) - val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval()) + val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing) _validate( model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1 ) elif name == "sintel": for pass_name in ("clean", "final"): val_dataset = Sintel( - root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval() + root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing ) _validate( model, @@ -187,7 +200,11 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s def main(args): utils.setup_ddp(args) - model = raft_small() if args.small else raft_large() + if args.weights: + model = PMOF.__dict__[args.model](weights=args.weights) + else: + model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) + model = model.to(args.local_rank) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) @@ -306,7 +323,12 @@ def get_args_parser(add_help=True): "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode." 
) - parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.") + parser.add_argument( + "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small" + ) + # TODO: resume, pretrained, and weights should be in an exclusive arg group + parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument( "--num_flow_updates", diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 56e91bb3d48..46c143302af 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -201,13 +201,18 @@ def test_old_vs_new_factory(model_fn, dev): if module_name == "detection": x = [x] + if module_name == "optical_flow": + args = [x, x] # RAFT model requires img1, img2 as input + else: + args = [x] + # compare with new model builder parameterized in the old fashion way try: model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev) model_new = _build_model(model_fn, **kwargs).to(device=dev) except ModuleNotFoundError: pytest.skip(f"Model '{model_name}' not available in both modules.") - torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False) + torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False) def test_smoke(): diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index f653895598f..ff851b6382e 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,6 +8,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import ConvNormActivation +from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once from ._utils import grid_sample, make_coords_grid, upsample_flow @@ -19,6 +20,9 @@ ) +_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"} + + class ResidualBlock(nn.Module): """Slightly modified Residual block with extra relu and biases.""" @@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12): hidden_state = torch.tanh(hidden_state) context = F.relu(context) - coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda() - coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda() + coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) + coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) flow_predictions = [] for _ in range(num_flow_updates): @@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12): def _raft( *, + arch=None, + pretrained=False, + progress=False, # Feature encoder feature_encoder_layers, feature_encoder_block, @@ -560,7 +567,7 @@ def _raft( multiplier=0.25, # See comment in MaskPredictor about this ) - return RAFT( + model = RAFT( feature_encoder=feature_encoder, context_encoder=context_encoder, corr_block=corr_block, @@ -568,6 +575,11 @@ def _raft( mask_predictor=mask_predictor, **kwargs, # not really needed, all params should be consumed by now ) + if pretrained: + state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) + model.load_state_dict(state_dict) + + return model def raft_large(*, pretrained=False, progress=True, **kwargs): @@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, 
progress=True, **kwargs): nn.Module: The model. """ - if pretrained: - raise ValueError("No checkpoint is available for raft_large") - return _raft( + arch="raft_large", + pretrained=pretrained, + progress=progress, # Feature encoder feature_encoder_layers=(64, 64, 96, 128, 256), feature_encoder_block=ResidualBlock, @@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): nn.Module: The model. """ - if pretrained: raise ValueError("No checkpoint is available for raft_small") return _raft( + arch="raft_small", + pretrained=pretrained, + progress=progress, # Feature encoder feature_encoder_layers=(32, 32, 64, 96, 128), feature_encoder_block=BottleneckBlock, diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 4dad4b3b6b1..49432302729 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -4,12 +4,10 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock - -# from torchvision.prototype.transforms import RaftEval +from torchvision.prototype.transforms import RaftEval from .._api import WeightsEnum - -# from .._api import Weights +from .._api import Weights from .._utils import handle_legacy_interface @@ -23,16 +21,16 @@ class Raft_Large_Weights(WeightsEnum): - pass - # C_T_V1 = Weights( - # # Chairs + Things - # url="", - # transforms=RaftEval, - # meta={ - # "recipe": "", - # "epe": -1234, - # }, - # ) + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + transforms=RaftEval, + meta={ + "recipe": "", # TODO + "sintel_train_cleanpass_epe": 1.3822, + "sintel_train_finalpass_epe": 2.7161, + }, + ) # C_T_SKHT_V1 = Weights( # # Chairs + Things + Sintel fine-tuning, i.e.: @@ -59,7 +57,7 @@ class Raft_Large_Weights(WeightsEnum): # }, # ) - # default = C_T_V1 + default = C_T_V2 class Raft_Small_Weights(WeightsEnum): @@ -75,13 +73,75 @@ class Raft_Small_Weights(WeightsEnum): # default = C_T_V1 -@handle_legacy_interface(weights=("pretrained", None)) +def _raft_builder( + *, + weights, + progress, + # Feature encoder + feature_encoder_layers, + feature_encoder_block, + feature_encoder_norm_layer, + # Context encoder + context_encoder_layers, + context_encoder_block, + context_encoder_norm_layer, + # Correlation block + corr_block_num_levels, + corr_block_radius, + # Motion encoder + motion_encoder_corr_layers, + motion_encoder_flow_layers, + motion_encoder_out_channels, + # Recurrent block + recurrent_block_hidden_state_size, + recurrent_block_kernel_size, + recurrent_block_padding, + # Flow Head + flow_head_hidden_size, + # Mask predictor + use_mask_predictor, + **kwargs, +): + model = _raft( + # Feature encoder + feature_encoder_layers=feature_encoder_layers, + feature_encoder_block=feature_encoder_block, + feature_encoder_norm_layer=feature_encoder_norm_layer, + # Context encoder + context_encoder_layers=context_encoder_layers, + context_encoder_block=context_encoder_block, + context_encoder_norm_layer=context_encoder_norm_layer, + # Correlation block + corr_block_num_levels=corr_block_num_levels, + corr_block_radius=corr_block_radius, + # Motion encoder + motion_encoder_corr_layers=motion_encoder_corr_layers, + motion_encoder_flow_layers=motion_encoder_flow_layers, + motion_encoder_out_channels=motion_encoder_out_channels, + # Recurrent 
block
+        recurrent_block_hidden_state_size=recurrent_block_hidden_state_size,
+        recurrent_block_kernel_size=recurrent_block_kernel_size,
+        recurrent_block_padding=recurrent_block_padding,
+        # Flow head
+        flow_head_hidden_size=flow_head_hidden_size,
+        # Mask predictor
+        use_mask_predictor=use_mask_predictor,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
+
+@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
 def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
     """RAFT model from
     `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
 
     Args:
-        weights(Raft_Large_weights, optinal): TODO not implemented yet
+        weights (Raft_Large_Weights, optional): The pretrained weights to use.
         progress (bool): If True, displays a progress bar of the download to stderr
         kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT`
             class to override any default.
@@ -92,7 +152,9 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
 
     weights = Raft_Large_Weights.verify(weights)
 
-    return _raft(
+    return _raft_builder(
+        weights=weights,
+        progress=progress,
         # Feature encoder
         feature_encoder_layers=(64, 64, 96, 128, 256),
         feature_encoder_block=ResidualBlock,
@@ -138,7 +200,9 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
 
     weights = Raft_Small_Weights.verify(weights)
 
-    return _raft(
+    return _raft_builder(
+        weights=weights,
+        progress=progress,
         # Feature encoder
         feature_encoder_layers=(32, 32, 64, 96, 128),
         feature_encoder_block=BottleneckBlock,

From d244401ea9e41d555497119123e4d43dfe63f607 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 8 Dec 2021 15:09:15 +0000
Subject: [PATCH 2/4] Minor stuff

---
 references/optical_flow/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index 860fb4ee1db..326f0be5f66 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -12,7 +12,7 @@
     from torchvision.prototype import models as PM
     from torchvision.prototype.models import optical_flow as PMOF
 except ImportError:
-    PM = None
+    PM = PMOF = None
 
 
 def get_train_dataset(stage, dataset_root):

From f186973ea46a483c2a0ec07355d84cff05db7f13 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 8 Dec 2021 17:12:13 +0000
Subject: [PATCH 3/4] Add pretrained weights from paper's repo as V1

---
 torchvision/prototype/models/optical_flow/raft.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py
index 49432302729..5f72af970fc 100644
--- a/torchvision/prototype/models/optical_flow/raft.py
+++ b/torchvision/prototype/models/optical_flow/raft.py
@@ -21,6 +21,17 @@
 class Raft_Large_Weights(WeightsEnum):
+    C_T_V1 = Weights(
+        # Chairs + Things, ported from original paper repo (raft-things.pth)
+        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
+        transforms=RaftEval,
+        meta={
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "sintel_train_cleanpass_epe": 1.4411,
+            "sintel_train_finalpass_epe": 2.7894,
+        },
+    )
+
     C_T_V2 = Weights(
         # Chairs + Things
         url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",

From 57aff364b06c58339645a4d76690c5b6e11d9174 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 8 
Dec 2021 17:54:12 +0000
Subject: [PATCH 4/4] Address comments

---
 references/optical_flow/README.md                 | 12 +--
 test/test_prototype_models.py                     |  4 +-
 .../prototype/models/optical_flow/raft.py         | 87 ++++---------------
 3 files changed, 28 insertions(+), 75 deletions(-)

diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md
index d0963f44317..660ee1b0c38 100644
--- a/references/optical_flow/README.md
+++ b/references/optical_flow/README.md
@@ -16,6 +16,7 @@ recipe from https://github.com/princeton-vl/RAFT.
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --dataset-root $dataset_root \
     --name $name_chairs \
+    --model raft_large \
     --train-dataset chairs \
     --batch-size 2 \
     --lr 0.0004 \
@@ -28,6 +29,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
 torchrun --nproc_per_node 8 --nnodes 1 train.py \
     --dataset-root $dataset_root \
     --name $name_things \
+    --model raft_large \
     --train-dataset things \
     --batch-size 2 \
     --lr 0.000125 \
@@ -42,14 +44,14 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
 ### Evaluation
 
 ```
-torchrun --nproc_per_node 8 --nnodes 1 train.py --val-dataset sintel --batch-size 10 --dataset-root $dataset_root --model raft_large --pretrained
+torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
 ```
 
-This should give an epe of about 1.3825 on the clean pass and 2.7148 on the
+This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
 final pass of Sintel. Results may vary slightly depending on the batch size and
-the number of GPUs. For the most accurate results, use 1 GPU and `--batch-size 1`.
+the number of GPUs. For the most accurate results, use 1 GPU and `--batch-size 1`:
 
 ```
-Sintel val clean epe: 1.3825 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3782 f1: 4.0234
-Sintel val final epe: 2.7148 1px: 0.8526 3px: 0.9203 5px: 0.9392 per_image_epe: 2.7199 f1: 7.6100
+Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
+Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
 ```
diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index 46c143302af..87a269c7a41 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -91,7 +91,8 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.detection)
     + TM.get_models_from_module(models.quantization)
     + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video),
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
 )
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5"]
@@ -102,6 +103,7 @@ def test_schema_meta_validation(model_fn):
         "quantization": classification_fields + ["backend", "quantization", "unquantized"],
         "segmentation": ["categories", "mIoU", "acc"],
         "video": classification_fields,
+        "optical_flow": [],
     }
     module_name = model_fn.__module__.split(".")[-2]
     fields = set(defaults["all"] + defaults[module_name])
diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py
index 5f72af970fc..4fc7e962864 100644
--- a/torchvision/prototype/models/optical_flow/raft.py
+++ b/torchvision/prototype/models/optical_flow/raft.py
@@ -5,6 +5,7 @@
 from torchvision.models.optical_flow import RAFT
 from torchvision.models.optical_flow.raft import _raft, 
BottleneckBlock, ResidualBlock from torchvision.prototype.transforms import RaftEval +from torchvision.transforms.functional import InterpolationMode from .._api import WeightsEnum from .._api import Weights @@ -20,12 +21,16 @@ ) +_COMMON_META = {"interpolation": InterpolationMode.BILINEAR} + + class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-things.pth) url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", transforms=RaftEval, meta={ + **_COMMON_META, "recipe": "https://github.com/princeton-vl/RAFT", "sintel_train_cleanpass_epe": 1.4411, "sintel_train_finalpass_epe": 2.7894, @@ -37,7 +42,8 @@ class Raft_Large_Weights(WeightsEnum): url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", transforms=RaftEval, meta={ - "recipe": "", # TODO + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", "sintel_train_cleanpass_epe": 1.3822, "sintel_train_finalpass_epe": 2.7161, }, @@ -84,68 +90,6 @@ class Raft_Small_Weights(WeightsEnum): # default = C_T_V1 -def _raft_builder( - *, - weights, - progress, - # Feature encoder - feature_encoder_layers, - feature_encoder_block, - feature_encoder_norm_layer, - # Context encoder - context_encoder_layers, - context_encoder_block, - context_encoder_norm_layer, - # Correlation block - corr_block_num_levels, - corr_block_radius, - # Motion encoder - motion_encoder_corr_layers, - motion_encoder_flow_layers, - motion_encoder_out_channels, - # Recurrent block - recurrent_block_hidden_state_size, - recurrent_block_kernel_size, - recurrent_block_padding, - # Flow Head - flow_head_hidden_size, - # Mask predictor - use_mask_predictor, - **kwargs, -): - model = _raft( - # Feature encoder - feature_encoder_layers=feature_encoder_layers, - feature_encoder_block=feature_encoder_block, - feature_encoder_norm_layer=feature_encoder_norm_layer, - # Context encoder - context_encoder_layers=context_encoder_layers, - context_encoder_block=context_encoder_block, - context_encoder_norm_layer=context_encoder_norm_layer, - # Correlation block - corr_block_num_levels=corr_block_num_levels, - corr_block_radius=corr_block_radius, - # Motion encoder - motion_encoder_corr_layers=motion_encoder_corr_layers, - motion_encoder_flow_layers=motion_encoder_flow_layers, - motion_encoder_out_channels=motion_encoder_out_channels, - # Recurrent block - recurrent_block_hidden_state_size=recurrent_block_hidden_state_size, - recurrent_block_kernel_size=recurrent_block_kernel_size, - recurrent_block_padding=recurrent_block_padding, - # Flow head - flow_head_hidden_size=flow_head_hidden_size, - # Mask predictor - use_mask_predictor=use_mask_predictor, - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): """RAFT model from @@ -163,9 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * weights = Raft_Large_Weights.verify(weights) - return _raft_builder( - weights=weights, - progress=progress, + model = _raft( # Feature encoder feature_encoder_layers=(64, 64, 96, 128, 256), feature_encoder_block=ResidualBlock, @@ -192,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * **kwargs, ) + if weights is not None: 
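+        # download the checkpoint behind the weights entry (cached after the first run) and load it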
+ model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + @handle_legacy_interface(weights=("pretrained", None)) def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): @@ -211,9 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, * weights = Raft_Small_Weights.verify(weights) - return _raft_builder( - weights=weights, - progress=progress, + model = _raft( # Feature encoder feature_encoder_layers=(32, 32, 64, 96, 128), feature_encoder_block=BottleneckBlock, @@ -239,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, * use_mask_predictor=False, **kwargs, ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model
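A quick way to exercise the new API end to end is the sketch below. It only uses builders and enum entries introduced in this series; the input tensors are random dummies with an arbitrarily chosen 384x512 resolution, so the predicted flow is meaningless. Real frames should first go through the preprocessing attached to the weights (`weights.transforms()`, i.e. `RaftEval`), and their height and width must be divisible by 8.

```
import torch
from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

# C_T_V2 is the checkpoint trained with the reference scripts above;
# Raft_Large_Weights.C_T_V1 selects the checkpoint ported from the paper repo.
weights = Raft_Large_Weights.C_T_V2
model = raft_large(weights=weights)
model = model.eval()

# RAFT takes two consecutive frames. Dummy data for brevity: real images
# should be preprocessed with weights.transforms() and have H, W divisible by 8.
img1 = torch.rand(1, 3, 384, 512)
img2 = torch.rand(1, 3, 384, 512)

with torch.no_grad():
    # The model returns one flow estimate per refinement iteration;
    # the last element is the most refined prediction.
    flow_predictions = model(img1, img2, num_flow_updates=12)
    flow = flow_predictions[-1]

print(flow.shape)  # torch.Size([1, 2, 384, 512]): per-pixel (dx, dy) flow
```

The same checkpoint stays reachable through the legacy path, `torchvision.models.optical_flow.raft_large(pretrained=True)`, which `handle_legacy_interface` maps to `C_T_V2`.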