fix image to text
@@ -1,8 +1,118 @@
import math

import torch
from torch import nn

from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2PreTrainedModel
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2PreTrainedModel
from transformers.pytorch_utils import Conv1D

class GPT2OptimusAttention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd : ns, :ns]
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        if layer_past is not None:
            # layer_past carries the latent vector, injected as extra key/value memory
            past_key, past_value = layer_past[0], layer_past[1]

            past_key = self.split_heads(past_key, k=True)
            past_value = self.split_heads(past_value)
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)

class GPT2OptimusBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = GPT2OptimusAttention(nx, config.n_ctx, config, scale=True)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        output_attn = self.attn(
            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)

class GPT2OptimusModel(GPT2PreTrainedModel):
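The two classes above port the Optimus-style GPT-2 attention, where a latent vector is handed in as `layer_past` and concatenated onto the keys and values as one extra memory slot. A minimal smoke test of that behaviour, not part of the commit and with purely illustrative sizes, might look like this:

```python
import torch
from transformers import GPT2Config

# Minimal sketch (sizes are illustrative, not from the commit): exercise
# GPT2OptimusAttention with a single latent token injected as key/value memory.
config = GPT2Config(n_embd=64, n_head=4, n_positions=32)
attn = GPT2OptimusAttention(nx=config.n_embd, n_ctx=config.n_positions, config=config, scale=True)

hidden = torch.randn(2, 8, config.n_embd)  # (batch, seq_len, n_embd)
latent = torch.randn(2, 1, config.n_embd)  # one latent "memory" token per sequence
a, present = attn(hidden, layer_past=(latent, latent))

print(a.shape)        # torch.Size([2, 8, 64]) -> same shape as the input hidden states
print(present.shape)  # torch.Size([2, 2, 4, 9, 16]) -> stacked key/value incl. the memory slot
```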
@@ -17,7 +127,7 @@ class GPT2OptimusModel(GPT2PreTrainedModel):
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([GPT2Block(config, i) for i in range(config.n_layer)])
+        self.h = nn.ModuleList([GPT2OptimusBlock(config) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.linear = nn.Linear(
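The `nn.Linear(` call above is truncated by the hunk boundary. Judging from how `past` is consumed in the next hunk (split into `n_layer` chunks of `hidden_size`), it presumably projects the latent to `n_layer * hidden_size` features; the sketch below only makes that assumption concrete, with made-up sizes.

```python
import torch
from torch import nn

# Assumed shapes only; the actual Linear arguments are cut off in the hunk above.
latent_size, hidden_size, n_layer = 768, 768, 12
linear = nn.Linear(latent_size, n_layer * hidden_size)

latent = torch.randn(2, latent_size)  # pooled latent, one vector per sample
past = linear(latent)
print(past.shape)                     # torch.Size([2, 9216]) -> split per layer below
```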
@@ -48,17 +158,11 @@ class GPT2OptimusModel(GPT2PreTrainedModel):
        if self.latent_as_gpt_memory:
            past = self.linear(past)
-            share_latent = False
-            if share_latent:
-                # the same latent vector shared by all layers
-                past = [past.unsqueeze(-2), past.unsqueeze(-2)]  # query, key
-                past = [past] * len(self.h)
-                past_length = past[0][0].size(-2)
-            else:
-                # different latent vectors for each layer
-                past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
-                past = list(zip(past_split, past_split))
-                past_length = 1  # past[0][0].size(-2)
+            # different latent vectors for each layer
+            past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
+            past = list(zip(past_split, past_split))
+            past_length = 1  # past[0][0].size(-2)
        else:
            past_length = 0
            past = [None] * len(self.h)
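The surviving branch converts one flat latent per sample into a per-layer `(key, value)` pair holding a single memory token, which is why `past_length` is hard-coded to 1. A shape walk-through with illustrative sizes (not taken from the commit):

```python
import torch

# Illustrative sizes: one flat latent per sample, split into one memory token per layer.
batch, n_layer, hidden_size = 2, 12, 768
past = torch.randn(batch, n_layer * hidden_size)        # output of self.linear(past)

past_split = torch.split(past.unsqueeze(1), hidden_size, dim=2)
past = list(zip(past_split, past_split))                # (key, value) per layer

print(len(past))         # 12 -> one entry per GPT2OptimusBlock
print(past[0][0].shape)  # torch.Size([2, 1, 768]) -> a single memory token, hence past_length = 1
```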
@@ -265,9 +265,8 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
        return image_embeddings

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
-        latents = latents.reshape(latents.shape[:-2]).unsqueeze(1)
+        latents = latents.reshape(latents.shape[:-2])
        self.text_vae_decoder = self.text_vae_decoder.to(self._execution_device)
        bos_token = self.text_vae_tokenizer.bos_token_id
        output = self.text_vae_decoder.generate(bos_token_id=bos_token, past=latents)
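The `.unsqueeze(1)` is dropped here presumably because the GPT-2 side already calls `past.unsqueeze(1)` before splitting the latent per layer (see the hunk above), so `decode_latents` only needs to flatten the denoised latent to `(batch, hidden)`. A shape sketch under the assumption that the latent arrives as a 4-D tensor with trailing singleton spatial dims:

```python
import torch

# Assumed input shape: (batch, hidden, 1, 1) coming out of the image-to-text UNet.
latents = torch.randn(2, 768, 1, 1)
latents = latents.reshape(latents.shape[:-2])  # drop the trailing 1x1 spatial dims

print(latents.shape)  # torch.Size([2, 768]) -> fed to text_vae_decoder.generate(past=latents)
```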
@@ -454,7 +453,7 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
        # 11. Convert to strings
        if output_type == "str":
-            text = self.text_vae_tokenizer.decode(text)
+            text = self.text_vae_tokenizer.batch_decode(text)

        if not return_dict:
            return (text,)
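`generate()` returns a whole batch of token-id sequences, and `tokenizer.decode()` only handles a single sequence, so `batch_decode()` is the right call and yields one string per generated sequence. A small illustration using the stock GPT-2 tokenizer as a stand-in for `text_vae_tokenizer` (an assumption, not the pipeline's actual checkpoint):

```python
from transformers import GPT2Tokenizer

# Stand-in tokenizer; the pipeline's text_vae_tokenizer plays the same role.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
batch_ids = [tokenizer.encode("a photo of a car"), tokenizer.encode("a red benz")]

print(tokenizer.decode(batch_ids[0]))     # a single string for one sequence
print(tokenizer.batch_decode(batch_ids))  # a list with one string per sequence
```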
@@ -43,14 +43,12 @@ class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
        )
        generator = torch.Generator(device=torch_device).manual_seed(0)
-        tokens = pipe(
+        text = pipe(
            image=image_prompt,
            generator=generator,
            guidance_scale=7.5,
            num_inference_steps=50,
-            output_type="numpy",
+            output_type="str",
        ).text

-        assert tokens.shape == (1, 30)
-        expected_tokens = np.array([0, 1, 2, 3, 4, 5, 6, 7])
-        assert self.assertItemsEqual(tokens[0], expected_tokens)
+        assert text == "Corret me"