mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[PRX] Improve model compilation (#12787)
* Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT).
* Apply style fixes
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2246d2c7c4
commit
3d02cd543e
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.functional import fold, unfold
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
@@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
|
||||
// patch_size)` is the number of patches.
|
||||
"""
|
||||
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
|
||||
b, c, h, w = img.shape
|
||||
p = patch_size
|
||||
|
||||
# Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
|
||||
img = img.reshape(b, c, h // p, p, w // p, p)
|
||||
|
||||
# Permute to (B, H//p, W//p, C, p, p) using einsum
|
||||
# n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
|
||||
img = torch.einsum("nchpwq->nhwcpq", img)
|
||||
|
||||
# Flatten to (B, L, C * p * p)
|
||||
img = img.reshape(b, -1, c * p * p)
|
||||
return img
|
||||
|
||||
|
||||
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
|
||||
@@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
|
||||
Reconstructed image tensor of shape `(B, C, H, W)`.
|
||||
"""
|
||||
if isinstance(shape, tuple):
|
||||
shape = shape[-2:]
|
||||
h, w = shape[-2:]
|
||||
elif isinstance(shape, torch.Tensor):
|
||||
shape = (int(shape[0]), int(shape[1]))
|
||||
h, w = (int(shape[0]), int(shape[1]))
|
||||
else:
|
||||
raise NotImplementedError(f"shape type {type(shape)} not supported")
|
||||
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
b, l, d = seq.shape
|
||||
p = patch_size
|
||||
c = d // (p * p)
|
||||
|
||||
# Reshape back to grid structure: (B, H//p, W//p, C, p, p)
|
||||
seq = seq.reshape(b, h // p, w // p, c, p, p)
|
||||
|
||||
# Permute back to image layout: (B, C, H//p, p, W//p, p)
|
||||
# n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
|
||||
seq = torch.einsum("nhwcpq->nchpwq", seq)
|
||||
|
||||
# Final reshape to (B, C, H, W)
|
||||
seq = seq.reshape(b, c, h, w)
|
||||
return seq
|
||||
|
||||
|
||||
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
|
||||
Reference in New Issue
Block a user