mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

move the prompt_embeds rescale from prior_transformer to the pipeline

yiyixuxu
2023-06-23 03:11:55 +00:00
parent 6ec68eec40
commit 5b5a8e6be9
2 changed files with 5 additions and 6 deletions

View File

@@ -1,4 +1,3 @@
-import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Union
@@ -249,11 +248,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         # but time_embedding might be fp16, so we need to cast here.
         timesteps_projected = timesteps_projected.to(dtype=self.dtype)
         time_embeddings = self.time_embedding(timesteps_projected)
-        # Rescale the features to have unit variance
-        # YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline
-        if self.clip_mean is None:
-            proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
         proj_embeddings = self.embedding_proj(proj_embedding)
         if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
             encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
@@ -242,6 +243,9 @@ class ShapEPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        # Rescale the features to have unit variance (this step is taken from the original repo)
+        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
         return prompt_embeds
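
A side note on the rescale itself (not part of the diff): the removed TO-DO comment says the embeddings were already normalized during the encode_prompt step, which leaves each of the dim feature components with variance on the order of 1/dim; multiplying by sqrt(dim) brings them back to roughly unit variance. Below is a minimal PyTorch sketch of that effect; the batch and embedding sizes are hypothetical and the L2 normalization stands in for whatever encode_prompt does.

import math

import torch

# Sketch only: batch and dim are hypothetical, not taken from the diff.
batch, dim = 2, 768
prompt_embeds = torch.randn(batch, dim)

# Assumption: encode_prompt L2-normalizes along the feature dimension (the removed
# TO-DO notes the embeddings were normalized there), so each row has unit norm and
# each of the dim components has variance of roughly 1/dim.
prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
print(prompt_embeds.var().item())  # ~1/dim, about 0.0013 for dim=768

# The rescale this commit moves into the pipeline multiplies by sqrt(dim),
# restoring roughly unit variance.
prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
print(prompt_embeds.var().item())  # ~1.0

With the multiplication done in the pipeline, PriorTransformer.forward no longer needs the clip_mean is None special case for scaling, which is what the removed TO-DO comment asked for.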