Support Early Exit Loss and/or Layer Dropout #1076
Merged · +2,395 −3
Commits (90 total, all by mostafaelhoushi; the diff below shows changes from 18 commits):
97cb9a8  start of layer dropout implementation
4a25c5b  have different dropouts at different layers
ac8ad0b  add option to specify which layers to apply dropout
ae61c85  start early exit loss
735d2a8  parallelize processing of early exit losses
be912a6  use absolute imports
0686dd2  remove unnecessary sync
4e4783f  move early exit loss to separate file and add layers as arg
268813e  perform loss scaling every iteration
ccb4a50  return hidden states as an output rather than storing
ff7d157  ensure last layer is always included
5a23811  return either last logits or hidden states
e11aeba  fix scaling layers
f9e164f  rotational early exit curriculum
069b661  set early exit params from cli
954d097  ensure last layer loss is always calculated
5789745  implement gradual early exit
c3534e6  get streaming to work
d1c6963  Merge branch 'main' into layerskip
7849130  add separate recipe for early exit
df89c4f  port early exit loss code from PR
6cedb19  convert boolean array to indices
a83da5a  decide on hidden outputs by member variable not forward pass
2a8791d  add early exit recipe config
a326937  refactor unembedding
8ba6ab4  got early exit loss to work
681e7ca  add TopV2 instruction set
119ac7d  ensure all early exit loss params from cfg file are passed to code
3ec9d23  fix gradual early exit
04a590f  add test cases for early exit loss
9b5c96a  add more assertions for rotational early exit
3319ab0  test to follow training code
619b3eb  fix curriculum update
d376ddd  update recipe
ff3977b  reset changes to data loading
75b2e01  code cleanup
33a95f5  rename early_exit to early_exit_loss
5d7e903  address some early exit TODOs
87f2ee0  get layer dropout to work
1de0c2a  clean up early exit curriculum
2b0cdd1  enable grad curriculum for subset of layers + clear hidden_states at …
7973459  add docstring for slice_str_to_array
baed8a9  support commas and add assertion statements
27f6b56  add test cases for slice_to_str_array
63e7c5b  add copyright header
638056b  support single index
a20b07c  add new line at end of file
64210e6  Merge branch 'main' into layerskip
98897a8  add layer dropout test cases
2cc94cc  rename apply_layer_dropout to prepare_layer_dropout
f4f8e02  add test cases for get_scale
fed955e  cleanup get_scale + re-write mathematically equivalent + ensure max s…
ca7d8da  test layer_dropout
0146764  start adding early exit loss and layer dropout to docstring
f599eca  fix and update code and test cases to handle updating last layer sepa…
2437092  change match to if-else for CI
ad090af  add assertion on type of loss fn for early exit loss
cec8cd4  add docstring and slightly change attribute of layer_dropout and earl…
b69f2f3  refactor layer_dropout and add test cases on wrapper
a21cbd3  add TODO comment
eb37cb6  fix error in checking if early exit loss is enabled
2e3f502  change recipe defaults of dataset and layer_drop probability
66a41b2  add detailed docstring to training script
345a0a3  ensure we set last layer early exit enable correctly
20c618c  ensure uniform early exit loss works
f0e8d7f  add documentation to .yaml file and update doc in .py
b03cb57  remove commented lines
199b8dd  remove check on PyTorch version since we assume latest stable PyTorch
6a2d79b  load curriculum step when resuming
e5534ea  repeat arguments in derived classes
d270d1f  rename percent_scale to fraction_scale and change its implementation
e51419c  fixes to docstrings and config examples
40b7987  check if cfg_early_exit_loss has curriculum
0c18595  add comment to explain when has no effect
3e68696  organize early exit loss tests into classes
418951b  fix typo
e5a53f9  test all loss scale types
3567a24  use variable number of subset layers
ae2108d  ensure get_scale returns values between 0 and 1
71707de  add test cases for sigmoid
78aff5a  make prepare_layer_dropout apply on a list of layers rather than a model
0fb373b  Only add `optional` in docstring when argument is optional
b66e23b  add Dropout class and prepare_layer_dropout APIs to docs
cd8be64  add empty line between function description and Args
2675b4c  remove assert statement as we added the check in testing
00d8efa  change loss scale from enum to function
78b8996  change curriculum from enum to function
ed33ba9  rename scale_type to scale_fn
c7f02de  change default
69f840c  update docstring
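The file rendered below covers the layer dropout half of the PR; the early exit loss code itself is not shown in this view. As context only, here is a rough sketch of the idea (the helper names are illustrative placeholders, not the PR's API): hidden states from selected intermediate layers are passed through the model's shared norm and unembedding to produce early-exit logits, and their cross-entropy losses are combined with the final-layer loss using per-layer scales.

import torch
import torch.nn.functional as F

def early_exit_loss_sketch(hidden_states, labels, unembed, norm, layer_scales):
    """Illustrative only: a weighted sum of cross-entropy losses computed from the
    hidden states of selected layers. hidden_states maps layer_id -> [batch, seq, dim];
    unembed/norm stand in for the model's output projection and final norm."""
    total, total_scale = 0.0, 0.0
    for layer_id, h in hidden_states.items():
        logits = unembed(norm(h))  # shared unembedding for every exit
        loss = F.cross_entropy(logits.transpose(1, 2), labels)  # [batch, vocab, seq] vs [batch, seq]
        total = total + layer_scales[layer_id] * loss
        total_scale = total_scale + layer_scales[layer_id]
    return total / total_scale  # normalize the per-layer scales

The commit messages above (loss scaling every iteration, rotational/gradual curricula, ensuring the last layer's loss is always calculated) suggest the PR's real implementation adds curricula and scale schedules on top of this basic weighted sum.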
New file added by the diff (85 lines), implementing the layer dropout utilities:

@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Callable, Optional

import torch

from torchtune.modules.common_utils import slice_str_to_array


class LayerDropout(torch.nn.Module):
    """Wraps the call to a layer so that, during training, a random subset of the
    samples in the batch skips the layer entirely (stochastic depth / layer dropout)."""

    def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None):
        super().__init__()
        self.prob: float = prob
        self.dim = dim
        self.disable_on_eval: bool = disable_on_eval
        self.generator = torch.Generator(device="cpu")
        # Fraction of samples that actually went through the layer in the last forward call.
        self.inferred: Optional[float] = None

        if seed is not None:
            self.generator.manual_seed(seed)

    def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs):
        n = input.shape[self.dim]

        if self.prob == 0 or (self.disable_on_eval and self.training is False):
            self.inferred = 1.0
            return function(input, *args, **kwargs)

        # Sample, per element along `dim`, whether to skip the wrapped function.
        skip = (
            torch.bernoulli(torch.Tensor(n * [self.prob]), generator=self.generator)
            .to(input.device)
            .to(input.dtype)
        )
        self.inferred = 1 - torch.mean(skip)
        # squeeze(-1) (rather than squeeze()) keeps the index 1-D even when only one
        # element is selected, which index_select requires.
        ind_selected = (skip == 0).nonzero().squeeze(-1)

        if ind_selected.numel() > 0:
            x_selected = torch.index_select(input, self.dim, ind_selected)
            out_selected = function(x_selected, *args, **kwargs)

        out = input.clone()
        assert self.dim == 0, "Currently only supporting dropping elements along the 0th dimension"
        if ind_selected.numel() > 0:
            out[ind_selected] = out_selected
        return out


class ScaleType(str, Enum):
    UNIFORM = "uniform"
    EXP = "exp"
    LINEAR = "linear"
    LOG = "log"
    SIN = "sin"
    SIGMOID = "sigmoid"
    STEP = "step"


def get_scale(scale_type: ScaleType, scale_period: int, val: int):
    if scale_period == 0:
        return 1

    # All the equations below aim to make scale = 0 when val = 0, and scale = 1 when
    # val = scale_period. Note that ScaleType.STEP is not mapped here and would raise
    # a KeyError if requested.
    return {
        ScaleType.UNIFORM: 1,
        ScaleType.EXP: math.exp(val * math.log(2) / scale_period) - 1,
        ScaleType.LINEAR: val / scale_period,
        ScaleType.LOG: math.log(val + 1) / math.log(scale_period + 1),
        ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period),
        ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))),
    }[scale_type]

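# Illustrative aside, not part of the PR's file: how get_scale spreads dropout
# probabilities across layers. With prob_max = 0.2, EXP scaling, and a hypothetical
# 32-layer model (scale_period = 31):
#   layer 0:  0.2 * (exp( 0 * ln(2) / 31) - 1) = 0.0
#   layer 15: 0.2 * (exp(15 * ln(2) / 31) - 1) ≈ 0.080
#   layer 31: 0.2 * (exp(31 * ln(2) / 31) - 1) = 0.2
# i.e. early layers are rarely dropped and the deepest layer is dropped with prob_max.
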
def create_layer_dropout_modules(
    num_layers: int,
    prob_max: float = 0.0,
    prob_layer_scale: ScaleType = ScaleType.EXP,
    layers_str: Optional[str] = None,
    disable_on_eval: bool = True,
):
    """Create one LayerDropout per layer, with probabilities scaled from 0 up to
    prob_max according to prob_layer_scale; layers_str optionally restricts dropout
    to a subset of layers."""
    layer_dropouts = torch.nn.ModuleList()
    has_dropout = slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers

    for layer_id in range(num_layers):
        prob = (
            prob_max
            * get_scale(scale_type=prob_layer_scale, scale_period=num_layers - 1, val=layer_id)
            if has_dropout[layer_id]
            else 0.0
        )
        assert 0.0 <= prob <= prob_max, f"prob={prob} should be between 0 and {prob_max}"
        # We would like each layer to have a different seed, so that we don't have the
        # same samples skipped across layers. Hence, we use the layer_id as a seed for
        # each layer's dropout.
        layer_dropout = LayerDropout(prob, disable_on_eval=disable_on_eval, seed=layer_id)
        layer_dropouts.append(layer_dropout)

    return layer_dropouts
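As a quick orientation for reviewers, a minimal, hypothetical sketch of how these utilities could be wired up. The toy Linear "layers" and the inline loop are placeholders, not the PR's actual recipe integration (later commits rename apply_layer_dropout to prepare_layer_dropout and refactor the wrapper); only the calling convention dropout(layer, x) follows LayerDropout.forward above.

import torch
from torch import nn

# Toy stand-in for a decoder: four Linear "layers" (placeholder, not a real TransformerDecoder).
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
dropouts = create_layer_dropout_modules(
    num_layers=len(layers),
    prob_max=0.2,                    # deepest layer is skipped for ~20% of samples
    prob_layer_scale=ScaleType.EXP,
)

x = torch.randn(8, 16)               # batch of 8 samples; dropout acts along dim 0
for layer, dropout in zip(layers, dropouts):
    # LayerDropout.forward(function, input, ...): during training a random subset
    # of the batch bypasses `layer` and keeps its input unchanged.
    x = dropout(layer, x)
print([d.prob for d in dropouts])    # per-layer probabilities ramp from 0.0 up to 0.2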