mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
AudioDiffusionPipeline - fix encode method after config changes (#3114)
* config fixes * deprecate get_input_dims
This commit is contained in:
committed by
GitHub
parent
eb29dbad17
commit
b63419a28a
@@ -51,21 +51,6 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
|
||||
|
||||
def get_input_dims(self) -> Tuple:
|
||||
"""Returns dimension of input image
|
||||
|
||||
Returns:
|
||||
`Tuple`: (height, width)
|
||||
"""
|
||||
input_module = self.vqvae if self.vqvae is not None else self.unet
|
||||
# For backwards compatibility
|
||||
sample_size = (
|
||||
(input_module.config.sample_size, input_module.config.sample_size)
|
||||
if type(input_module.config.sample_size) == int
|
||||
else input_module.config.sample_size
|
||||
)
|
||||
return sample_size
|
||||
|
||||
def get_default_steps(self) -> int:
|
||||
"""Returns default number of steps recommended for inference
|
||||
|
||||
@@ -123,8 +108,6 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
||||
# For backwards compatibility
|
||||
if type(self.unet.config.sample_size) == int:
|
||||
self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
|
||||
input_dims = self.get_input_dims()
|
||||
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
|
||||
if noise is None:
|
||||
noise = randn_tensor(
|
||||
(
|
||||
@@ -234,7 +217,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
||||
sample = torch.Tensor(sample).to(self.device)
|
||||
|
||||
for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
|
||||
prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps
|
||||
prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = (
|
||||
self.scheduler.alphas_cumprod[prev_timestep]
|
||||
|
||||
Reference in New Issue
Block a user