mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
move the rescale prompt_embeds from prior_transformer to pipeline
@@ -1,4 +1,3 @@
-import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Union
 
@@ -249,11 +248,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         # but time_embedding might be fp16, so we need to cast here.
         timesteps_projected = timesteps_projected.to(dtype=self.dtype)
         time_embeddings = self.time_embedding(timesteps_projected)
 
-        # Rescale the features to have unit variance
-        # YiYi TO-DO: It was normalized before during encode_prompt step, move this step to pipeline
-        if self.clip_mean is None:
-            proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
         proj_embeddings = self.embedding_proj(proj_embedding)
         if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
             encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
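The branch removed above is the rescale that this commit relocates: it sits in what is presumably PriorTransformer's forward pass and was only taken when no CLIP statistics are registered on the model (self.clip_mean is None). Multiplying by math.sqrt(proj_embedding.shape[1]) undoes the earlier L2 normalization of the prompt embedding, bringing its per-component variance back to roughly one, which is what the "unit variance" comment refers to. A minimal standalone sketch of that effect, with a made-up embedding size rather than anything read from the diffusers config:

import math

import torch

# Hypothetical L2-normalized CLIP embedding of dimension 768.
proj_embedding = torch.randn(1, 768)
proj_embedding = proj_embedding / proj_embedding.norm(dim=-1, keepdim=True)
print(proj_embedding.var().item())  # roughly 1/768

# The scaling being moved out of the prior: sqrt(dim) restores ~unit variance.
proj_embedding = math.sqrt(proj_embedding.shape[1]) * proj_embedding
print(proj_embedding.var().item())  # roughly 1.0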
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -242,6 +243,9 @@ class ShapEPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        # Rescale the features to have unit variance (this step is taken from the original repo)
+        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
+
         return prompt_embeds
 
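After this change the rescale lives at the end of the pipeline's prompt-encoding step: it is applied once, after the unconditional and text embeddings have been concatenated for classifier-free guidance, and just before the embeddings are returned. A rough sketch of that ordering, using a hypothetical helper rather than the pipeline's actual _encode_prompt:

import math

import torch

def rescale_prompt_embeds(prompt_embeds, negative_prompt_embeds=None):
    # If classifier-free guidance is used, concatenate the unconditional and text
    # embeddings into a single batch to avoid doing two forward passes.
    if negative_prompt_embeds is not None:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Rescale the features to have unit variance, now done in the pipeline
    # instead of inside the prior transformer.
    prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
    return prompt_embeds

embeds = torch.randn(1, 768) / math.sqrt(768)   # stand-in for a normalized text embedding
uncond = torch.zeros_like(embeds)               # stand-in unconditional embedding
print(rescale_prompt_embeds(embeds, uncond).shape)  # torch.Size([2, 768])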