
Commit a4c91be

superhero-7 and root authored
Modified altdiffusion pipeline to support altdiffusion-m18 (#2993)
* Modified altdiffusion pipeline to support altdiffusion-m18 --------- Co-authored-by: root <[email protected]>
1 parent 3becd36 commit a4c91be
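
For context, this commit patches the XLM-Roberta text encoder used by the AltDiffusion pipeline so the multilingual (m18) checkpoint can be loaded. A minimal usage sketch follows; the Hub id BAAI/AltDiffusion-m18 is an assumption and is not part of this commit:

import torch
from diffusers import AltDiffusionPipeline

# Assumed checkpoint id for the multilingual m18 model; verify the Hub name before use.
pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m18", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# The pipeline's text encoder is the RobertaSeriesModelWithTransformation patched below.
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("altdiffusion_m18_sample.png")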

File tree

1 file changed: +25 -10 lines changed


src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py

Lines changed: 25 additions & 10 deletions
@@ -56,7 +56,7 @@ def __init__(
 
 
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ def __init__(self, config):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()
 
     def forward(
@@ -95,15 +99,26 @@ def forward(
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )
 
-        projection_state = self.transformation(outputs.last_hidden_state)
-
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
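
The new branch only activates when the text-encoder config sets has_pre_transformation; in that case the projection is computed from the layer-normalized second-to-last hidden state (pre_LN followed by transformation_pre) instead of last_hidden_state. A rough sketch exercising the flag directly, using an illustrative tiny config rather than the real m18 dimensions:

import torch
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)

# Tiny, illustrative config; the real m18 checkpoint defines its own sizes.
config = RobertaSeriesConfig(
    vocab_size=1000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    project_dim=32,
    has_pre_transformation=True,  # enables the transformation_pre / pre_LN branch added here
)
model = RobertaSeriesModelWithTransformation(config)

input_ids = torch.tensor([[0, 100, 200, 2]])  # arbitrary in-vocab token ids
out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))

# projection_state is now pre_LN(hidden_states[-2]) passed through transformation_pre
print(out.projection_state.shape)  # torch.Size([1, 4, 32])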
