sdnext/pipelines/model_glm.py

import time
import rich.progress as rp
import transformers
import diffusers
from modules import shared, devices, sd_models, model_quant, sd_hijack_te
from pipelines import generic


class GLMTokenProgressProcessor(transformers.LogitsProcessor):
    """LogitsProcessor that tracks autoregressive token generation progress for GLM-Image."""

    def __init__(self):
        self.total_tokens = 0
        self.current_step = 0
        self.task_id = None
        self.pbar = None
        self.pbar_task = None
        self.start_time = 0

    def set_total(self, total_tokens: int):
        self.total_tokens = total_tokens
        self.current_step = 0

    def __call__(self, input_ids, scores):
        if self.current_step == 0:  # first token: open the job and start the progress bar
            self.task_id = shared.state.begin('AR Generation')
            self.start_time = time.time()
            self.pbar = rp.Progress(
                rp.TextColumn('[cyan]AR Generation'),
                rp.TextColumn('{task.fields[speed]}'),
                rp.BarColumn(bar_width=40, complete_style='#327fba', finished_style='#327fba'),
                rp.TaskProgressColumn(),
                rp.MofNCompleteColumn(),
                rp.TimeElapsedColumn(),
                rp.TimeRemainingColumn(),
                console=shared.console,
            )
            self.pbar.start()
            self.pbar_task = self.pbar.add_task(description='', total=self.total_tokens, speed='')
        self.current_step += 1
        shared.state.sampling_step = self.current_step
        shared.state.sampling_steps = self.total_tokens
        if self.pbar is not None and self.pbar_task is not None:
            elapsed = time.time() - self.start_time
            speed = f'{self.current_step / elapsed:.2f}tok/s' if elapsed > 0 else ''
            self.pbar.update(self.pbar_task, completed=self.current_step, speed=speed)
        if self.current_step >= self.total_tokens:  # final token: stop the bar and end the job
            if self.pbar is not None:
                self.pbar.stop()
                self.pbar = None
            if self.task_id is not None:
                shared.state.end(self.task_id)
                self.task_id = None
        return scores
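

# Sketch (illustration only, not used by the module): transformers calls a
# LogitsProcessor once per generated token, so driving the processor by hand
# with dummy tensors shows its lifecycle; the torch import and tensor shapes
# below are assumptions made for the sketch, not part of the GLM-Image API:
#
#   import torch
#   proc = GLMTokenProgressProcessor()
#   proc.set_total(5)
#   input_ids = torch.zeros((1, 1), dtype=torch.long)  # (batch, sequence)
#   scores = torch.zeros((1, 32))                      # (batch, vocab) dummy logits
#   for _ in range(5):
#       scores = proc(input_ids, scores)  # first call starts the bar, fifth call stops it
#   # scores pass through unchanged; the processor only reports progress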


def hijack_vision_language_generate(pipe):
    """Wrap vision_language_encoder.generate to add progress tracking."""
    if not hasattr(pipe, 'vision_language_encoder') or pipe.vision_language_encoder is None:
        return
    original_generate = pipe.vision_language_encoder.generate
    progress_processor = GLMTokenProgressProcessor()

    def wrapped_generate(*args, **kwargs):
        # get max_new_tokens to determine total tokens
        max_new_tokens = kwargs.get('max_new_tokens', 0)
        progress_processor.set_total(max_new_tokens)
        # add progress processor to logits_processor list
        existing_processors = kwargs.get('logits_processor', None)
        if existing_processors is None:
            existing_processors = []
        elif not isinstance(existing_processors, list):
            existing_processors = list(existing_processors)
        kwargs['logits_processor'] = existing_processors + [progress_processor]
        return original_generate(*args, **kwargs)

    pipe.vision_language_encoder.generate = wrapped_generate
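

# Usage sketch: once hijacked, any generate() call the pipeline issues is
# tracked automatically; the direct call below is illustrative only and its
# arguments are assumptions. Note that if max_new_tokens is not passed as a
# keyword, the total falls back to 0 and the bar closes on the first token:
#
#   hijack_vision_language_generate(pipe)
#   tokens = pipe.vision_language_encoder.generate(input_ids, max_new_tokens=256)
#   # progress is reported for 256 autoregressive steps via GLMTokenProgressProcessor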


def load_glm_image(checkpoint_info, diffusers_load_config=None):
    if diffusers_load_config is None:
        diffusers_load_config = {}
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)
    if not hasattr(transformers, 'GlmImageForConditionalGeneration'):
        shared.log.error(f'Load model: type=GLM-Image repo="{repo_id}" transformers={transformers.__version__} not supported')
        return None
    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    shared.log.debug(f'Load model: type=GLM-Image repo="{repo_id}" offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
    # load transformer (7B DiT decoder) with quantization support
    transformer = generic.load_transformer(
        repo_id,
        cls_name=diffusers.GlmImageTransformer2DModel,
        load_config=diffusers_load_config,
    )
    # load text encoder (ByT5 for glyphs): cannot use the shared T5 since GLM-Image requires a specific ByT5 encoder (hidden size 1472)
    text_encoder = generic.load_text_encoder(
        repo_id,
        cls_name=transformers.T5EncoderModel,
        load_config=diffusers_load_config,
        allow_shared=False,
    )
    # load vision-language encoder (9B AR model)
    # note: this is a conditional generation model, different from typical text encoders
    vision_language_encoder = generic.load_text_encoder(
        repo_id,
        cls_name=transformers.GlmImageForConditionalGeneration,  # pylint: disable=no-member
        subfolder="vision_language_encoder",
        load_config=diffusers_load_config,
        allow_shared=False,
    )
    pipe = diffusers.GlmImagePipeline.from_pretrained(
        repo_id,
        cache_dir=shared.opts.diffusers_dir,
        transformer=transformer,
        text_encoder=text_encoder,
        vision_language_encoder=vision_language_encoder,
        **load_args,
    )
    pipe.task_args = {
        'output_type': 'np',
        'generate_kwargs': {
            'eos_token_id': None,  # disable EOS early stopping to ensure all required tokens are generated
        },
    }
    del transformer, text_encoder, vision_language_encoder
    sd_hijack_te.init_hijack(pipe)
    hijack_vision_language_generate(pipe)  # add progress tracking for AR token generation
    devices.torch_gc(force=True, reason='load')
    return pipe
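

# Loader usage sketch: within sdnext this function is reached through model
# loader dispatch rather than called directly; the call below is illustrative
# and the origin of checkpoint_info is an assumption about the surrounding API:
#
#   pipe = load_glm_image(checkpoint_info)  # checkpoint_info comes from model selection
#   if pipe is not None:
#       # pipe.task_args carries per-task defaults (np output, no EOS early stop)
#       # that the processing pipeline merges into the actual call
#       ...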