Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _check_input(

@staticmethod
def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample())
return torch.empty(1).uniform_(left, right).item()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching to this random generator we get a performance boost on GPU. Moreover this option is JIT-scriptable (if on the future we decide to add support) and doesn't require to constantly initialize a distribution object as before:

[--------- ColorJitter cpu torch.float32 ---------]
                     |   old random  |   new random
1 threads: ----------------------------------------
      (3, 400, 400)  |       17      |       17    
6 threads: ----------------------------------------
      (3, 400, 400)  |       21      |       21    

Times are in milliseconds (ms).

[--------- ColorJitter cuda torch.float32 --------]
                     |   old random  |   new random
1 threads: ----------------------------------------
      (3, 400, 400)  |      1090     |      883    
6 threads: ----------------------------------------
      (3, 400, 400)  |      1090     |      882    

Times are in microseconds (us).


def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
Expand Down