@@ -507,6 +507,43 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
507507 )
508508
509509
510+ class DummyDecisionTransformerInputGenerator (DummyTextInputGenerator ):
511+ """
512+ Generates dummy decision transformer inputs.
513+ """
514+
515+ SUPPORTED_INPUT_NAMES = (
516+ "states" ,
517+ "actions" ,
518+ "timesteps" ,
519+ "returns_to_go" ,
520+ "attention_mask" ,
521+ )
522+
523+ def __init__ (self , * args , ** kwargs ):
524+ super ().__init__ (* args , ** kwargs )
525+ self .act_dim = self .normalized_config .config .act_dim
526+ self .state_dim = self .normalized_config .config .state_dim
527+ self .max_ep_len = self .normalized_config .config .max_ep_len
528+
529+ def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
530+ if input_name == "states" :
531+ shape = [self .batch_size , self .sequence_length , self .state_dim ]
532+ elif input_name == "actions" :
533+ shape = [self .batch_size , self .sequence_length , self .act_dim ]
534+ elif input_name == "rewards" :
535+ shape = [self .batch_size , self .sequence_length , 1 ]
536+ elif input_name == "returns_to_go" :
537+ shape = [self .batch_size , self .sequence_length , 1 ]
538+ elif input_name == "attention_mask" :
539+ shape = [self .batch_size , self .sequence_length ]
540+ elif input_name == "timesteps" :
541+ shape = [self .batch_size , self .sequence_length ]
542+ return self .random_int_tensor (shape = shape , max_value = self .max_ep_len , framework = framework , dtype = int_dtype )
543+
544+ return self .random_float_tensor (shape , min_value = - 2.0 , max_value = 2.0 , framework = framework , dtype = float_dtype )
545+
546+
510547class DummySeq2SeqDecoderTextInputGenerator (DummyDecoderTextInputGenerator ):
511548 SUPPORTED_INPUT_NAMES = (
512549 "decoder_input_ids" ,
0 commit comments