
ViTDet object detection + segmentation implementation #7690


Open · wants to merge 13 commits into main

Conversation

@hgaiser (Contributor) commented Jun 21, 2023

This PR implements ViTDet, as per #7630. I needed this implementation regardless of the feedback from torchvision maintainers, but I figured it makes sense to try and merge it upstream. The implementation borrows heavily from detectron2. There is still some work to do, but since there is no feedback on whether this will ever be merged, I will pause development at this stage.

Discussion points

  1. I had to move some weights around and use a different implementation for the Attention layer, making existing weights incompatible.
  2. Currently the ViTDet implementation lives inside the mask_rcnn.py file, since the two are so much alike. Should I put it in a separate vitdet.py file instead?
  3. I have only added a convenience function for a MaskRCNN with a ViT-B/16 backbone. Do we want other backbones? If yes, which ones? For ResNet we also only provide convenience functions for ResNet50, so I'm not sure what to do here.
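On point 1: loading old checkpoints into a reorganized Attention layer generally means remapping state_dict keys. As a hypothetical illustration (the `q_proj`/`k_proj`/`v_proj` names are assumptions, not this PR's actual layout), splitting torch's packed `in_proj_weight` into separate q/k/v tensors could look like:

```python
import torch

def split_in_proj(state_dict, prefix):
    """Split nn.MultiheadAttention's packed qkv projection into separate
    q/k/v linears. The target key names are hypothetical."""
    w = state_dict.pop(f"{prefix}.in_proj_weight")  # (3*C, C)
    b = state_dict.pop(f"{prefix}.in_proj_bias")    # (3*C,)
    for name, wi, bi in zip("qkv", w.chunk(3), b.chunk(3)):
        state_dict[f"{prefix}.{name}_proj.weight"] = wi
        state_dict[f"{prefix}.{name}_proj.bias"] = bi
    return state_dict
```

The same approach extends to any other renamed or reshaped parameters; the key point is that the remap happens once at load time, so checkpoints on disk stay untouched.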

Current status

Training with the following command:

python train.py \
    --dataset coco --model maskrcnn_vit_b_16_sfpn --epochs 10 --batch-size 2 \
    --aspect-ratio-group-factor -1 --weights-backbone ViT_B_16_Weights.IMAGENET1K_V1 --data-path=/srv/data/coco \
    --opt vitdet --lr 8e-5 --wd 0.1 --data-augmentation lsj --lr-steps 3 6 --image-min-size 1024 --image-max-size 1024

Achieves the following result:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.475
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.691
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.524
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.322
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.512
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.612
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.366
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.579
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.606
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.432
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.645
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.757
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.424
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.662
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.455
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.236
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.454
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.617
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.337
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.525
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.548
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.364
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.590
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.715

The segmentation results are identical to those reported in their paper.

TODOs

  • I broke the implementation for the classification part of ViT. This needs some more work.
  • I removed the previously available argument to set trainable layers.
  • Train a MaskRCNN model + upload weights.
  • Double check all docstrings to make sure they are still correct.
  • Check formatting / unit tests.
  • Check conversion to torchscript.
  • Check conversion to ONNX (?).

My main intention with opening this PR is to allow torchvision maintainers to provide their feedback and opinion. @fmassa I'm not sure if you are still working on these things, but I tag you since we worked together on the RetinaNet implementation :).

@pytorch-bot (bot) commented Jun 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7690

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@hgaiser (Contributor, Author) commented Jun 28, 2023

I updated this PR so that the implementation more closely resembles the initial ViT implementation in torchvision. I have also updated the first post accordingly, to avoid unnecessary reading :p

The only difference this PR now makes is that pos_embedding has moved to the ViT class instead of the Encoder class (so that images of any size are accepted). This means that existing weights files are no longer compatible, for which I added a workaround in the _vision_transformer function. Is this acceptable?

(the rest of the TODOs still stand)
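With pos_embedding owned by the ViT class, the embedding grid can be resized at load time for arbitrary input resolutions. A minimal sketch of such a resize (assuming a square grid and no class token; this is an illustration, not the actual workaround in _vision_transformer):

```python
import torch
import torch.nn.functional as F

def interpolate_pos_embedding(pos_embed, new_h, new_w):
    """Resize a learned 2D positional embedding to a new spatial size.

    pos_embed: (1, H*W, C) tensor laid out as a square H x W grid,
    without a class token (both assumptions for this sketch).
    Returns a (1, new_h*new_w, C) tensor.
    """
    _, n, c = pos_embed.shape
    h = w = int(n ** 0.5)  # assume a square grid
    grid = pos_embed.reshape(1, h, w, c).permute(0, 3, 1, 2)  # -> (1, C, H, W)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, c)
```

For example, a 14x14 grid from 224px ViT-B/16 pretraining could be resized to the 64x64 grid a 1024px detection input needs.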

@hgaiser (Contributor, Author) commented Jul 11, 2023

I trained a COCO model using ViT-B as backbone with the following command:

python train.py \
    --dataset coco --model maskrcnn_vit_b_16_sfpn --epochs 10 --batch-size 4 \
    --aspect-ratio-group-factor -1 --weights-backbone ViT_B_16_Weights.IMAGENET1K_V1 --data-path=/srv/data/coco \
    --opt adamw --lr 8e-5 --wd 0.1 --data-augmentation lsj --lr-steps 3 6

And got the following results:

IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.320
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.534
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.331
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.138
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.492
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.285
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.440
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.465
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.258
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.502
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.640

This configuration should reach approximately 0.424 mAP according to the ViTDet paper (versus 0.320 in this training). This tells me that it is learning something, so the implementation is broadly correct, but something is still missing.

One thing to note is that I trained on a single GPU with a batch size of 4, whereas they trained on 64 GPUs (1 image per GPU). I'm not sure what the effect of this is, since I don't have 64 GPUs at my disposal. If someone has the resources to train with a batch size of 64, I would be very interested to see how it performs.
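
One common heuristic when changing the batch size is the linear scaling rule, which keeps the learning rate proportional to the total batch size; whether it would close the gap here is untested. As a trivial sketch (the function name and numbers are illustrative, not this PR's settings):

```python
def scale_lr(base_lr, base_batch_size, batch_size):
    """Linear scaling rule: keep the learning rate proportional
    to the total (global) batch size."""
    return base_lr * batch_size / base_batch_size

# e.g. a recipe tuned for 64 images per step, run at batch size 4:
smaller_lr = scale_lr(1e-4, 64, 4)
```

Whether linear scaling holds for AdamW-style optimizers is itself debated, so this is only a starting point for experimentation.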

In the meantime I will try and use this model some more to see if I can improve on these results.

@SeucheAchat9115

Is there any update on how to fix this? I would really like to have a working ViTDet implementation in torchvision.

@hgaiser (Contributor, Author) commented Jul 30, 2023

None that I have found. I modified the implementation to match that of detectron2 (to the point where both networks output the same features, given the same input and RNG seed), but surprisingly the results are even worse. I don't have the numbers on hand at the moment, but I will continue to look into this.

If you're interested, feel free to give it a go and see what performance you get.

@hgaiser (Contributor, Author) commented Sep 12, 2023

I'm slowly making progress on this, but I am not completely there yet. Is there still interest in this from the torchvision maintainers to merge this at some point?

@pmeier can I ask you for your feedback? Or alternatively can you let me know who best to ask?

@hgaiser (Contributor, Author) commented Oct 7, 2023

The latest changes did have an impact on the COCO evaluation score:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.418
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.635
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.458
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.267
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.559
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.334
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.531
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.559
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.377
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.598
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.715
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.380
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.604
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.405
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.194
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.409
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.573
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.313
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.488
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.512
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.316
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.558
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.688

Though 0.380 still isn't the expected 0.424. I suspect the remaining gap comes from the relative positional embedding in the multihead attention, which cannot be expressed with the Attention layer from torch. The easiest solution would be to implement a custom Attention layer in torchvision, a la detectron2.
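For reference, the detectron2-style decomposed relative position bias adds separate height and width terms to the attention logits. A simplified single-head sketch (my own condensation, not this PR's code; it assumes equal query and key resolution):

```python
import torch

def rel_pos_indices(q_size, k_size):
    # Relative coordinate of each (query, key) pair, shifted to be >= 0
    # so it can index a table of size q_size + k_size - 1.
    coords = torch.arange(q_size)[:, None] - torch.arange(k_size)[None, :]
    return coords + (k_size - 1)

def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, hw):
    """Add a decomposed relative position bias to attention logits.

    attn: (B, H*W, H*W) logits; q: (B, H*W, C) queries;
    rel_pos_h: (2*H-1, C) and rel_pos_w: (2*W-1, C) learned tables.
    """
    h, w = hw
    Rh = rel_pos_h[rel_pos_indices(h, h)]  # (h, h, C)
    Rw = rel_pos_w[rel_pos_indices(w, w)]  # (w, w, C)
    b, _, c = q.shape
    r_q = q.reshape(b, h, w, c)
    # Bias from vertical offsets (summed over channels) ...
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (b, h, w, key_h)
    # ... and from horizontal offsets.
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (b, h, w, key_w)
    attn = attn.view(b, h, w, h, w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return attn.view(b, h * w, h * w)
```

Because the bias depends on relative offsets only, the tables can also be interpolated when the input resolution changes, which is what makes it attractive for detection-sized inputs.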

@hgaiser (Contributor, Author) commented Oct 17, 2023

Good news, the accuracy has gone up significantly by changing the attention layer. The main difference should be that it uses a relative positional embedding. The score I am getting on COCO now is:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.471
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.691
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.519
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.315
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.509
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.608
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.363
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.576
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.604
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.427
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.645
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.755
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.421
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.660
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.453
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.224
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.455
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.609
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.334
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.521
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.544
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.355
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.591
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.71

That 0.421 is awfully close to the 0.424 reported in their paper. I will update the first post with the TODOs that are still left. Considering there seems to be little to no interest in this, I will stop development here, as this was all I needed (a working ViTDet in torchvision).

@hgaiser (Contributor, Author) commented Nov 7, 2023

I found a bug in the learning rate decay; with that fixed, the results are:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.475
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.691
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.524
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.322
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.512
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.612
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.366
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.579
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.606
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.432
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.645
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.757
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.424
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.662
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.455
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.236
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.454
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.617
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.337
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.525
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.548
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.364
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.590
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.715

The segmentation results are identical to those in the paper, and the bbox results are nearly identical.

🥳
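
For context, the learning rate decay in question is ViTDet's layer-wise decay, where parameters in earlier transformer blocks receive exponentially smaller learning rates. A sketch of the rate schedule (mirroring the detectron2 recipe; the function name is mine, and the 0.7 default is the paper's ViT-B value):

```python
def layerwise_lr(base_lr, layer_id, num_layers, decay_rate=0.7):
    """Layer-wise learning rate decay for a ViT backbone.

    layer_id 0 is the patch embedding, 1..num_layers are the
    transformer blocks, and num_layers + 1 is the detection head
    (which gets the full base_lr).
    """
    return base_lr * decay_rate ** (num_layers + 1 - layer_id)
```

An off-by-one in `layer_id` assignment silently gives whole blocks the wrong rate, which is exactly the kind of bug that shows up only as a small mAP gap.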

@hgaiser hgaiser marked this pull request as ready for review November 7, 2023 15:49
@JohannesTheo (Contributor)

Any update on this? Would be really cool to have ViTDet in torchvision :)

@hgaiser (Contributor, Author) commented Aug 28, 2024

If necessary I can rebase this PR, but I haven't heard from any torchvision maintainer yet, so I will wait ^^

@JohannesTheo (Contributor)

Makes sense. @datumbox @pmeier @NicolasHug @fmassa sorry for the ping if you are on holiday! Can someone maybe leave a short comment on whether this PR has a chance of being considered? Would be really cool to have ViTDet in torchvision.

@denn-s commented Sep 20, 2024

It would be great to have this model in Torchvision's model zoo.

@ArthurOuaknine

Thanks for this great contribution! Any update on the release of this code?

@hgaiser (Contributor, Author) commented Jan 21, 2025

I am curious to know too :).

@NicolasHug apologies for tagging, but you seem to be actively working on torchvision from the Meta AI group. Are you in a position to help guide this PR to a merge-able state, or do you know someone who is? I still think it would be a good addition to have.

@NicolasHug (Member)

@hgaiser I'm really sorry, I appreciate the work, but we've been unable to prioritize model authoring in torchvision for a while. We won't be adding new models in the foreseeable future.

I think the best way for you to make this available would be to publish it through torch.hub, and/or the huggingface hub?

@hgaiser (Contributor, Author) commented Feb 19, 2025

@NicolasHug thanks for the response, it is what it is :(.

@ArthurOuaknine

@hgaiser Please let us know in this thread if you plan to release this work in another hub or repo. That would be awesome.
Thanks again for your contribution.

@hgaiser (Contributor, Author) commented Feb 20, 2025

> @hgaiser Please let us know in this thread if you plan to release this work in another hub or repo. That would be awesome. Thanks again for your contribution.

At the moment, I don't have any plans to release it anywhere else. To anyone interested, feel free to pick it up.

Things might change in the future, but for now I have no time to work on this.
