From 5b5a8e6be918fefd114a2945ed89d8e8fa8be21b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 23 Jun 2023 03:11:55 +0000 Subject: [PATCH] move the rescale prompt_embeds from prior_transformer to pipeline --- src/diffusers/models/prior_transformer.py | 7 +------ src/diffusers/pipelines/shap_e/pipeline_shap_e.py | 4 ++++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 6c535a6cfc..4eda1721ad 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -1,4 +1,3 @@ -import math from dataclasses import dataclass from typing import Dict, Optional, Union @@ -249,11 +248,7 @@ class PriorTransformer(ModelMixin, ConfigMixin): # but time_embedding might be fp16, so we need to cast here. timesteps_projected = timesteps_projected.to(dtype=self.dtype) time_embeddings = self.time_embedding(timesteps_projected) - - # Rescale the features to have unit variance - # YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline - if self.clip_mean is None: - proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding + proj_embeddings = self.embedding_proj(proj_embedding) if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py index 23afa0b3b4..ddbf1a71a9 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from dataclasses import dataclass from typing import List, Optional, Union @@ -242,6 +243,9 @@ class ShapEPipeline(DiffusionPipeline): # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # Rescale the features to have unit variance (this step is taken from the original repo) + prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds return prompt_embeds