diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py
index 6c535a6cfc..4eda1721ad 100644
--- a/src/diffusers/models/prior_transformer.py
+++ b/src/diffusers/models/prior_transformer.py
@@ -1,4 +1,3 @@
-import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Union
 
@@ -249,11 +248,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         # but time_embedding might be fp16, so we need to cast here.
         timesteps_projected = timesteps_projected.to(dtype=self.dtype)
         time_embeddings = self.time_embedding(timesteps_projected)
-
-        # Rescale the features to have unit variance
-        # YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline
-        if self.clip_mean is None:
-            proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
+        proj_embeddings = self.embedding_proj(proj_embedding)
 
         if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
             encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
index 23afa0b3b4..ddbf1a71a9 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -242,6 +243,9 @@ class ShapEPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        # Rescale the features to have unit variance (this step is taken from the original repo)
+        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
 
         return prompt_embeds
 