 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from transformers.models.t5.modeling_t5 import T5Block
+from nlp import load_dataset
 
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
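The only substantive change in this import hunk is the new `from nlp import load_dataset`, which the example's `wikihow` dataset wrapper builds on (the `nlp` package is the predecessor of today's `datasets` library). A minimal sketch of how that call is typically used, assuming the WikiHow CSVs have already been downloaded locally; the `data_dir` value is illustrative, not taken from this diff:

```python
# Sketch: load the WikiHow summarization data with the legacy `nlp` package.
# The 'wikihow' builder expects its raw CSVs to be downloaded manually;
# data_dir='data/' is an assumed location.
from nlp import load_dataset

dataset = load_dataset('wikihow', 'all', data_dir='data/')
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)
```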
@@ -86,11 +87,11 @@ def fsdp_main(args):
     print("Size of train dataset: ", dataset['train'].shape)
     print("Size of Validation dataset: ", dataset['validation'].shape)
 
-
+
     #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
+    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
     val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
-
+
     sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
     sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
 
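The two `DistributedSampler` instances partition the train and validation sets so each rank sees a disjoint shard. One detail not visible in this hunk: when `shuffle=True`, `set_epoch` should be called once per epoch so the shuffle order changes between epochs. A small self-contained sketch with a toy dataset (the dataset and loop are illustrative, not part of the example):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

rank, world_size = 0, 2                      # illustrative; normally taken from torch.distributed
dataset = TensorDataset(torch.arange(100))   # toy stand-in for the wikihow dataset

# Each rank iterates over a disjoint 1/world_size shard of the data.
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)                 # reshuffles differently every epoch
    for (batch,) in loader:
        pass                                 # training step would go here
```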
@@ -107,20 +108,20 @@ def fsdp_main(args):
 
     train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
     val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
-
+
     torch.cuda.set_device(local_rank)
-
+
     # Set up FSDP parameters
     mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)
-
+
     # Apply FSDP wrapping to the model
     model = FSDP(model,
         auto_wrap_policy=t5_auto_wrap_policy,
         mixed_precision=mixed_precision_policy,
         sharding_strategy=fsdp_config.sharding_strategy,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
-
+
     # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
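`get_policies` lives in the example's own `policies` package and returns the mixed-precision and auto-wrap policies consumed by the `FSDP(...)` call above. A rough sketch of what such a helper builds for a T5 model using the public FSDP APIs; the bfloat16 dtype choice is an assumption, not read from this diff:

```python
import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# Wrap every T5Block as its own FSDP unit, so parameters are gathered one layer at a time.
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)

# Run compute and gradient reduction in bfloat16 (dtype choice is an assumption here).
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```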
@@ -150,7 +151,7 @@ def fsdp_main(args):
         if args.run_validation:
             curr_val_loss = validation(model, rank, world_size, val_loader)
         scheduler.step()
-
+
         if rank == 0:
 
             print(f"--> epoch {epoch} completed...entering save and stats zone")
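`validation` is defined elsewhere in this script; a rough sketch of what a distributed validation pass looks like under FSDP, where each rank evaluates its own shard and the losses are reduced so every rank reports the same number. The batch keys (`source_ids`, `source_mask`, `target_ids`) and the loss bookkeeping are assumptions, not taken from the diff:

```python
import torch
import torch.distributed as dist

def validation(model, rank, world_size, val_loader):
    # Sketch only: evaluate this rank's shard, then average the loss across all ranks.
    model.eval()
    local = torch.zeros(2, device=torch.cuda.current_device())  # [loss sum, batch count]
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
            output = model(input_ids=batch["source_ids"],
                           attention_mask=batch["source_mask"],
                           labels=batch["target_ids"])
            local[0] += output.loss.item()
            local[1] += 1
    dist.all_reduce(local, op=dist.ReduceOp.SUM)  # sum per-rank sums and counts
    return (local[0] / local[1]).item()
```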
@@ -170,7 +171,7 @@ def fsdp_main(args):
             )
 
         if train_config.save_model and curr_val_loss < best_val_loss:
-
+
             if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                 model_checkpointing.save_model_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
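`model_checkpointing.save_model_checkpoint` is part of the example's `model_checkpointing` package. For the `FULL_STATE_DICT` branch shown here, the core of such a helper typically looks like the sketch below, built on FSDP's public state-dict APIs; the helper signature and file name are illustrative:

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

def save_model_checkpoint(model, rank, save_path="t5-model-checkpoint.pt"):
    # Gather the full, unsharded state dict: offloaded to CPU and materialized only on rank 0.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        cpu_state = model.state_dict()
    if rank == 0:
        torch.save(cpu_state, save_path)   # save_path is illustrative
```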
@@ -183,7 +184,7 @@ def fsdp_main(args):
             if fsdp_config.save_optimizer:
                 model_checkpointing.save_optimizer_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
-                )
+                )
         if curr_val_loss < best_val_loss:
 
             best_val_loss = curr_val_loss
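The matching optimizer-state helper consolidates the sharded optimizer state before saving. A sketch of what `save_optimizer_checkpoint` typically does with FSDP's public API; the signature and file name are illustrative:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_optimizer_checkpoint(model, optimizer, rank, save_path="t5-optimizer-checkpoint.pt"):
    # Consolidate the sharded optimizer state into one full state dict (returned on rank 0).
    optim_state = FSDP.full_optim_state_dict(model, optimizer)
    if rank == 0:
        torch.save(optim_state, save_path)   # save_path is illustrative
```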
@@ -212,5 +213,5 @@ def fsdp_main(args):
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-
+
     fsdp_main(args)
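The script is meant to be launched with one process per GPU, e.g. via `torchrun --nnodes 1 --nproc_per_node 4 T5_training.py` (the file name here is an assumption). A sketch of the distributed setup that `fsdp_main` relies on, reading the rank variables torchrun exports for each worker; the helper name is illustrative:

```python
import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun sets these environment variables for every worker process.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")       # NCCL backend for multi-GPU training
    torch.cuda.set_device(local_rank)     # bind this process to its GPU
    return local_rank, rank, world_size
```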