From 29cf163b95b7ebe4e6609d1e52c9ff226f4679c1 Mon Sep 17 00:00:00 2001 From: Chi Date: Thu, 2 Nov 2023 02:20:33 +0530 Subject: [PATCH] Remove Redundant Variables from Encoder and Decoder (#5569) * I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using. * Update src/diffusers/models/unet_2d_blocks.py This changes suggest by maintener. Co-authored-by: Sayak Paul * Update src/diffusers/models/unet_2d_blocks.py Add suggested text Co-authored-by: Sayak Paul * Update unet_2d_blocks.py I changed the Parameter to Args text. * Update unet_2d_blocks.py proper indentation set in this file. * Update unet_2d_blocks.py a little bit of change in the act_fun argument line. * I run the black command to reformat style in the code * Update unet_2d_blocks.py similar doc-string add to have in the original diffusion repository. * I removed the dummy variable defined in both the encoder and decoder. * Now, I run black package to reformat my file --------- Co-authored-by: Sayak Paul Co-authored-by: Dhruv Nair --- src/diffusers/models/vae.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index da08bc3609..0f849a66ea 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -130,9 +130,9 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" - sample = x + sample = self.conv_in(sample) if self.training and self.gradient_checkpointing: @@ -273,9 +273,11 @@ class Decoder(nn.Module): self.gradient_checkpointing = False - def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + def forward( + self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" - sample = z + sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype