mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Remove Redundant Variables from Encoder and Decoder (#5569)
* I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using. * Update src/diffusers/models/unet_2d_blocks.py This changes suggest by maintener. Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update src/diffusers/models/unet_2d_blocks.py Add suggested text Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update unet_2d_blocks.py I changed the Parameter to Args text. * Update unet_2d_blocks.py proper indentation set in this file. * Update unet_2d_blocks.py a little bit of change in the act_fun argument line. * I run the black command to reformat style in the code * Update unet_2d_blocks.py similar doc-string add to have in the original diffusion repository. * I removed the dummy variable defined in both the encoder and decoder. * Now, I run black package to reformat my file --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -130,9 +130,9 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
sample = x
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -273,9 +273,11 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
sample = z
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
Reference in New Issue
Block a user