diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_gpt2_optimus.py b/src/diffusers/pipelines/versatile_diffusion/modeling_gpt2_optimus.py
index 02a0ba822c..647eb841fc 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_gpt2_optimus.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_gpt2_optimus.py
@@ -1,8 +1,118 @@
+import math
+
 import torch
 from torch import nn
 
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2PreTrainedModel
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2PreTrainedModel
+from transformers.pytorch_utils import Conv1D
+
+
+class GPT2OptimusAttention(nn.Module):
+    def __init__(self, nx, n_ctx, config, scale=False):
+        super().__init__()
+        self.output_attentions = config.output_attentions
+
+        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
+        # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
+        assert n_state % config.n_head == 0
+        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.n_head = config.n_head
+        self.split_size = n_state
+        self.scale = scale
+
+        self.c_attn = Conv1D(n_state * 3, nx)
+        self.c_proj = Conv1D(n_state, nx)
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.pruned_heads = set()
+
+    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
+        w = torch.matmul(q, k)
+        if self.scale:
+            w = w / math.sqrt(v.size(-1))
+        nd, ns = w.size(-2), w.size(-1)
+        b = self.bias[:, :, ns - nd : ns, :ns]
+        w = w * b - 1e4 * (1 - b)  # causal mask: push masked positions to a large negative value
+
+        if attention_mask is not None:
+            # Apply the attention mask
+            w = w + attention_mask
+
+        w = nn.Softmax(dim=-1)(w)
+        w = self.attn_dropout(w)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            w = w * head_mask
+
+        outputs = [torch.matmul(w, v)]
+        if self.output_attentions:
+            outputs.append(w)
+        return outputs
+
+    def merge_heads(self, x):
+        x = x.permute(0, 2, 1, 3).contiguous()
+        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+        return x.view(*new_x_shape)  # in Tensorflow implementation: fct merge_states
+
+    def split_heads(self, x, k=False):
+        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+        x = x.view(*new_x_shape)  # in Tensorflow implementation: fct split_states
+        if k:
+            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
+        else:
+            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
+        x = self.c_attn(x)
+        query, key, value = x.split(self.split_size, dim=2)
+        query = self.split_heads(query)
+        key = self.split_heads(key, k=True)
+        value = self.split_heads(value)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past[0], layer_past[1]  # key is stored transposed, cf. `present` below
+
+            past_key = self.split_heads(past_key, k=True)
+            past_value = self.split_heads(past_value)
+            # Prepend the latent memory slot so every token can attend to it.
+            key = torch.cat((past_key, key), dim=-1)
+            value = torch.cat((past_value, value), dim=-2)
+        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+
+        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
+        a = attn_outputs[0]
+
+        a = self.merge_heads(a)
+        a = self.c_proj(a)
+        a = self.resid_dropout(a)
+
+        outputs = [a, present] + attn_outputs[1:]
+        return outputs  # a, present, (attentions)
+
+
+class GPT2OptimusBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        nx = config.n_embd
+        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+        self.attn = GPT2OptimusAttention(nx, config.n_ctx, config, scale=True)
+        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(4 * nx, config)
+
+    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
+        output_attn = self.attn(
+            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
+        )
+        a = output_attn[0]  # output_attn: a, present, (attentions)
+
+        x = x + a
+        m = self.mlp(self.ln_2(x))
+        x = x + m
+
+        outputs = [x] + output_attn[1:]
+        return outputs  # x, present, (attentions)
 
 
 class GPT2OptimusModel(GPT2PreTrainedModel):
@@ -17,7 +127,7 @@ class GPT2OptimusModel(GPT2PreTrainedModel):
         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([GPT2Block(config, i) for i in range(config.n_layer)])
+        self.h = nn.ModuleList([GPT2OptimusBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
         self.linear = nn.Linear(
@@ -48,17 +158,11 @@ class GPT2OptimusModel(GPT2PreTrainedModel):
 
         if self.latent_as_gpt_memory:
             past = self.linear(past)
-            share_latent = False
-            if share_latent:
-                # the same latent vector shared by all layers
-                past = [past.unsqueeze(-2), past.unsqueeze(-2)]  # query, key
-                past = [past] * len(self.h)
-                past_length = past[0][0].size(-2)
-            else:
-                # different latent vectors for each layer
-                past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
-                past = list(zip(past_split, past_split))
-                past_length = 1  # past[0][0].size(-2)
+
+            # different latent vectors for each layer
+            past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
+            past = list(zip(past_split, past_split))
+            past_length = 1  # past[0][0].size(-2)
         else:
             past_length = 0
             past = [None] * len(self.h)
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py
index 216ed9efe6..129134a479 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_to_text.py
@@ -265,9 +265,8 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
 
         return image_embeddings
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = latents.reshape(latents.shape[:-2]).unsqueeze(1)
+        latents = latents.reshape(latents.shape[:-2])
         self.text_vae_decoder = self.text_vae_decoder.to(self._execution_device)
         bos_token = self.text_vae_tokenizer.bos_token_id
         output = self.text_vae_decoder.generate(bos_token_id=bos_token, past=latents)
@@ -454,7 +453,7 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
 
         # 11. Convert to strings
         if output_type == "str":
-            text = self.text_vae_tokenizer.decode(text)
+            text = self.text_vae_tokenizer.batch_decode(text)
 
         if not return_dict:
             return (text,)
diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py
index 7e5cb92536..648ef96758 100644
--- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py
+++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_to_text.py
@@ -43,14 +43,12 @@ class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
             "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
         )
         generator = torch.Generator(device=torch_device).manual_seed(0)
-        tokens = pipe(
+        text = pipe(
             image=image_prompt,
             generator=generator,
             guidance_scale=7.5,
             num_inference_steps=50,
-            output_type="numpy",
+            output_type="str",
         ).text
-        assert tokens.shape == (1, 30)
-        expected_tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7])
-        assert self.assertItemsEqual(tokens[0], expected_tokens)
+        assert text == "Correct me"  # placeholder; pin down the real expected caption before merge
 
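Note on the latent wiring: `GPT2OptimusModel` feeds the Optimus latent to every block as a length-1 `layer_past`, so each generated token can attend to the latent as one extra memory slot. Below is a minimal, self-contained sketch of the split performed in the `latent_as_gpt_memory` hunk above; the sizes are toy assumptions for illustration, not values from the PR:

```python
import torch

# Assumed toy sizes: batch of 2, 3 layers, hidden size 4.
batch, n_layer, n_embd = 2, 3, 4

# Stand-in for `self.linear(past)`: the projected latent carries
# n_layer * n_embd features, one chunk per transformer block.
latent = torch.randn(batch, n_layer * n_embd)

# Mirrors the code path above: one (key, value) pair per layer,
# each with sequence length 1.
past_split = torch.split(latent.unsqueeze(1), n_embd, dim=2)
past = list(zip(past_split, past_split))

assert len(past) == n_layer
assert past[0][0].shape == (batch, 1, n_embd)  # a single memory slot per layer
```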
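End to end, the change is exercised the way the updated integration test drives it. A hedged usage sketch follows; the `VersatileDiffusionImageToTextPipeline` class and the `shi-labs/versatile-diffusion` checkpoint id are taken from this PR and its test suite, and `output_type="str"` is the string-decoding path added above:

```python
import torch

from diffusers import VersatileDiffusionImageToTextPipeline
from diffusers.utils import load_image

pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("shi-labs/versatile-diffusion")
pipe = pipe.to("cuda")

image = load_image(
    "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
)
generator = torch.Generator(device="cuda").manual_seed(0)

# output_type="str" routes the sampled latents through the Optimus GPT-2
# decoder and returns caption strings via tokenizer.batch_decode.
text = pipe(
    image=image,
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=50,
    output_type="str",
).text
print(text)
```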