
Prototype transforms cleanup #5504

Merged
merged 11 commits into pytorch:main from prototype-transforms-cleanup on Mar 1, 2022

Conversation

@pmeier (Collaborator) commented Mar 1, 2022

Addresses minor comments from #5500.

@facebook-github-bot commented Mar 1, 2022

💊 CI failures summary and remediations

As of commit 448bec2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@datumbox (Contributor) left a comment

LGTM, thanks!

@datumbox datumbox merged commit 64e7460 into pytorch:main Mar 1, 2022
@pmeier pmeier deleted the prototype-transforms-cleanup branch March 1, 2022 16:49
facebook-github-bot pushed a commit that referenced this pull request Mar 5, 2022
Summary:
* fix grayscale to RGB for batches

* make unsupported types in auto augment a parameter

* make auto augment kwargs explicit

* add missing error message

* add support for specifying probabilities on RandomChoice (see the sketch after this commit message)

* remove TODO for deprecating p on random transforms

* streamline sample type checking

* address comments

* split image_size into height and width in auto augment

Reviewed By: datumbox

Differential Revision: D34579511

fbshipit-source-id: 757663a5a77f229cd1592b4c23dc17c7e8fe4807

Co-authored-by: Vasilis Vryniotis <[email protected]>
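
To make the RandomChoice item in the summary above concrete, here is a minimal, self-contained sketch of a transform container that picks one of its transforms according to user-supplied probabilities. The class name and constructor signature below are illustrative only and are not the actual torchvision API:

import random
from typing import Any, Callable, Optional, Sequence


class RandomChoiceWithProbabilities:
    """Hypothetical sketch: apply exactly one of the given transforms,
    sampled according to optional per-transform probabilities."""

    def __init__(
        self,
        transforms: Sequence[Callable[[Any], Any]],
        probabilities: Optional[Sequence[float]] = None,
    ) -> None:
        if probabilities is None:
            # Default to a uniform choice over all transforms.
            probabilities = [1.0] * len(transforms)
        if len(probabilities) != len(transforms):
            raise ValueError("transforms and probabilities must have the same length")
        total = sum(probabilities)
        self.transforms = list(transforms)
        self.probabilities = [p / total for p in probabilities]

    def __call__(self, sample: Any) -> Any:
        # Draw a single transform using the normalized weights.
        (transform,) = random.choices(self.transforms, weights=self.probabilities, k=1)
        return transform(sample)
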
Comment on lines +39 to +48
def _extract_types(sample: Any) -> Iterator[Type]:
return query_recursively(lambda id, input: type(input), sample)


def has_any(sample: Any, *types: Type) -> bool:
return any(issubclass(type, types) for type in _extract_types(sample))


def has_all(sample: Any, *types: Type) -> bool:
return not bool(set(types) - set(_extract_types(sample)))
@vfdev-5 (Collaborator) commented on Jul 11, 2022

@pmeier FYI, maybe we could implement has_all and has_any using torch's tree_flatten and simplify the code:

from typing import Any, Type

from torch.utils._pytree import tree_flatten


def has_any(sample: Any, *types: Type) -> bool:
    flat_sample, _ = tree_flatten(sample)
    return any(issubclass(type(obj), types) for obj in flat_sample)


def has_all(sample: Any, *types: Type) -> bool:
    flat_sample, _ = tree_flatten(sample)
    return not bool(set(types) - set([type(obj) for obj in flat_sample]))
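
As a quick usage illustration of the tree_flatten-based helpers suggested above (the nested sample and the leaf types here are made up for this example):

import torch

# A nested sample, e.g. an image tensor plus a target dict.
sample = {"image": torch.rand(3, 224, 224), "target": {"label": 17}}

print(has_any(sample, torch.Tensor))       # True: the image leaf is a tensor
print(has_all(sample, torch.Tensor, int))  # True: both a tensor and an int leaf are present
print(has_any(sample, float))              # False: there is no float leaf in the sample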
