 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from transformers.models.t5.modeling_t5 import T5Block
+from nlp import load_dataset
 
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
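A side note on the added import: `nlp` is the original, since-deprecated name of the HuggingFace `datasets` package. The commit keeps the old name, but on a current environment the equivalent import (shown for reference only, not part of this change) would be:

# `nlp` was renamed to `datasets` in late 2020; modern installs would use:
from datasets import load_dataset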
@@ -86,11 +87,11 @@ def fsdp_main(args):
     print("Size of train dataset: ", dataset['train'].shape)
     print("Size of Validation dataset: ", dataset['validation'].shape)
 
-
+
     #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
+    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
     val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
-
+
     sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
     sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
 
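The two DistributedSampler instances built above are what actually split the WikiHow data across workers: each rank iterates a disjoint shard of indices, and the shuffle order is reseeded whenever sampler.set_epoch(epoch) is called. A minimal single-process sketch of that behavior, with a plain rank loop standing in for separate processes:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
for rank in range(2):  # stand-in for two distributed workers
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=2, shuffle=True)
    sampler.set_epoch(0)  # reseeds the shuffle; real training calls this once per epoch
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    print(rank, [b[0].tolist() for b in loader])  # the two ranks' shards are disjoint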
@@ -107,20 +108,20 @@ def fsdp_main(args):
 
     train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
     val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
-
+
     torch.cuda.set_device(local_rank)
-
+
     # Set up FSDP parameters
     mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)
-
+
     # Apply FSDP wrapping to the model
     model = FSDP(model,
         auto_wrap_policy=t5_auto_wrap_policy,
         mixed_precision=mixed_precision_policy,
         sharding_strategy=fsdp_config.sharding_strategy,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
-
+
     # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
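get_policies(train_config, rank) is defined elsewhere in this repo; a minimal sketch of the two objects it plausibly returns, assuming an auto-wrap policy keyed on T5Block and a bfloat16 MixedPrecision policy (the repo's real version also gates on hardware support):

import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# Wrap each T5Block as its own FSDP unit so parameters are gathered
# one block at a time rather than all at once.
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)

mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype parameters are cast to for compute
    reduce_dtype=torch.bfloat16,  # dtype used for gradient reduction
    buffer_dtype=torch.bfloat16,  # dtype for buffers such as LayerNorm stats
)

Per-block wrapping is the main memory win here: only one T5Block's full parameters need to be materialized at any moment during forward and backward.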
@@ -150,7 +151,7 @@ def fsdp_main(args):
         if args.run_validation:
             curr_val_loss = validation(model, rank, world_size, val_loader)
         scheduler.step()
-
+
         if rank == 0:
 
             print(f"--> epoch {epoch} completed...entering save and stats zone")
@@ -170,7 +171,7 @@ def fsdp_main(args):
                 )
 
         if train_config.save_model and curr_val_loss < best_val_loss:
-
+
             if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                 model_checkpointing.save_model_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
@@ -183,7 +184,7 @@ def fsdp_main(args):
                 if fsdp_config.save_optimizer:
                     model_checkpointing.save_optimizer_checkpoint(
                         model, optimizer, rank, fsdp_config, epoch=1
-                    )
+                    )
         if curr_val_loss < best_val_loss:
 
             best_val_loss = curr_val_loss
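The save branches above dispatch on fsdp_config.checkpoint_type. For FULL_STATE_DICT, the model_checkpointing helper presumably consolidates the sharded weights onto rank 0 before writing them out; a minimal sketch of that standard pattern (the function name and path are illustrative, not this repo's API):

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

def save_full_checkpoint(model, rank, path="t5_checkpoint.pt"):
    # Gather the full, unsharded state dict on rank 0, offloaded to CPU
    # so the gathered weights do not have to fit on a single GPU.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state = model.state_dict()
    if rank == 0:
        torch.save(state, path)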
@@ -212,5 +213,5 @@ def fsdp_main(args):
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-
+
     fsdp_main(args)
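For reference, fsdp_main expects the usual environment variables (RANK, LOCAL_RANK, WORLD_SIZE) set by a distributed launcher; a typical single-node invocation, with the script name and flag assumed from this example's layout, would be:

torchrun --nnodes 1 --nproc_per_node 4 T5_training.py --run_validation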