From a4c91be73b871e2b1b0e934d893001978415e547 Mon Sep 17 00:00:00 2001
From: superhero-7 <57797766+superhero-7@users.noreply.github.com>
Date: Thu, 20 Apr 2023 01:00:29 +0800
Subject: [PATCH] Modified altdiffusion pipline to support altdiffusion-m18
 (#2993)

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

---------

Co-authored-by: root
---
 .../alt_diffusion/modeling_roberta_series.py | 33 ++++++++++++++-----
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
index 637d6dd186..f73ef15d7d 100644
--- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
+++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
 
 
 class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     base_model_prefix = "roberta"
     config_class = RobertaSeriesConfig
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
         super().__init__(config)
         self.roberta = XLMRobertaModel(config)
         self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_init()
 
     def forward(
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
             return_dict=return_dict,
         )
 
-        projection_state = self.transformation(outputs.last_hidden_state)
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
 
-        return TransformationModelOutput(
-            projection_state=projection_state,
-            last_hidden_state=outputs.last_hidden_state,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
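
A minimal sketch of how the has_pre_transformation switch introduced by this patch can be exercised once it is applied. The tiny config values below are illustrative assumptions for a quick smoke test, not the settings of the released AltDiffusion-m18 text encoder, and the import path simply mirrors the file touched by this diff:

import torch

from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)

# Tiny, randomly initialized model so the example runs quickly; the real
# AltDiffusion-m18 text encoder is far larger.  Extra kwargs such as
# has_pre_transformation are kept on the config object, which is what the
# getattr() added in __init__ relies on.
config = RobertaSeriesConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    project_dim=16,
    has_pre_transformation=True,
)
model = RobertaSeriesModelWithTransformation(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

# With the flag enabled, projection_state is computed from the second-to-last
# hidden state after pre_LN and transformation_pre, so its trailing dimension
# equals project_dim.
print(out.projection_state.shape)  # torch.Size([1, 8, 16])

With has_pre_transformation left unset (or False), the forward pass falls back to projecting the last hidden state through self.transformation, which is the behavior the pipeline had before this patch.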