Skip to content

Prototype transforms cleanup#5504

Merged
datumbox merged 11 commits into
pytorch:mainfrom
pmeier:prototype-transforms-cleanup
Mar 1, 2022
Merged

Prototype transforms cleanup#5504
datumbox merged 11 commits into
pytorch:mainfrom
pmeier:prototype-transforms-cleanup

Conversation

@pmeier

@pmeier pmeier commented Mar 1, 2022

Copy link
Copy Markdown
Contributor

Addresses minor comments from #5500.

@facebook-github-bot

facebook-github-bot commented Mar 1, 2022

Copy link
Copy Markdown
Contributor

💊 CI failures summary and remediations

As of commit 448bec2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

Comment thread torchvision/prototype/transforms/_augment.py
Comment thread torchvision/prototype/transforms/_augment.py
Comment thread torchvision/prototype/transforms/_auto_augment.py Outdated
Comment thread torchvision/prototype/transforms/_auto_augment.py Outdated

@datumbox datumbox left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@datumbox datumbox merged commit 64e7460 into pytorch:main Mar 1, 2022
@pmeier pmeier deleted the prototype-transforms-cleanup branch March 1, 2022 16:49
facebook-github-bot pushed a commit that referenced this pull request Mar 5, 2022
Summary:
* fix grayscale to RGB for batches

* make unsupported types in auto augment a parameter

* make auto augment kwargs explicit

* add missing error message

* add support for specifying probabilites on RandomChoice

* remove TODO for deprecating p on random transforms

* streamline sample type checking

* address comments

* split image_size into height and width in auto augment

Reviewed By: datumbox

Differential Revision: D34579511

fbshipit-source-id: 757663a5a77f229cd1592b4c23dc17c7e8fe4807

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Comment on lines +39 to +48
def _extract_types(sample: Any) -> Iterator[Type]:
return query_recursively(lambda id, input: type(input), sample)


def has_any(sample: Any, *types: Type) -> bool:
return any(issubclass(type, types) for type in _extract_types(sample))


def has_all(sample: Any, *types: Type) -> bool:
return not bool(set(types) - set(_extract_types(sample)))

@vfdev-5 vfdev-5 Jul 11, 2022

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier fyi maybe we could code has_all and has_any using torch tree_flatten and simplify the code:

from typing import Any, Type

from torch.utils._pytree import tree_flatten


def has_any(sample: Any, *types: Type) -> bool:
    flat_sample, _ = tree_flatten(sample)
    return any(issubclass(type(obj), types) for obj in flat_sample)


def has_all(sample: Any, *types: Type) -> bool:
    flat_sample, _ = tree_flatten(sample)
    return not bool(set(types) - set([type(obj) for obj in flat_sample]))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants