diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
index 637d6dd18698..f73ef15d7de7 100644
--- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
+++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
@@ -56,7 +56,7 @@ def __init__(
 
 
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ def __init__(self, config):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()
 
     def forward(
@@ -95,15 +99,26 @@ def forward(
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )
 
-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
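
A minimal sketch of how the new has_pre_transformation path could be exercised, assuming this branch of diffusers is installed. The config sizes and token ids below are toy values chosen only so the snippet runs quickly; they are not taken from a released AltDiffusion checkpoint.

import torch

from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)

# Toy config: the extra kwarg has_pre_transformation is stored on the config by
# transformers' PretrainedConfig, so the getattr(config, "has_pre_transformation", False)
# in __init__ picks it up and creates pre_LN / transformation_pre.
config = RobertaSeriesConfig(
    vocab_size=1000,          # toy vocabulary, not XLM-R's real size
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    project_dim=32,
    has_pre_transformation=True,
)
model = RobertaSeriesModelWithTransformation(config).eval()

input_ids = torch.tensor([[0, 5, 6, 2]])   # dummy token ids within the toy vocab
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

# With has_pre_transformation=True the projection is taken from the penultimate
# hidden state, normalized by pre_LN and projected by transformation_pre.
print(out.projection_state.shape)  # torch.Size([1, 4, 32])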