1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/scripts/lbm/embedders/base/base_conditioner.py
Vladimir Mandic 2b9056179d add lbm background replace with relightining
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-04 15:33:16 -04:00

59 lines
1.7 KiB
Python

from typing import Any, Dict
from ...base.base_model import BaseModel
from .base_conditioner_config import BaseConditionerConfig
# Lookup from a dimension count to a conditioning output key; the class below
# stores it as `dim2outputkey`. NOTE(review): presumably keyed by the number of
# dimensions of the embedding a conditioner produces — confirm in subclasses.
DIM2CONDITIONING = {2: "vector", 3: "crossattn", 4: "concat"}
class BaseConditioner(BaseModel):
    """This is the base class for all the conditioners. It abstracts the conditioning process.

    Subclasses must override :meth:`forward`; this base class only stores the
    configuration-derived attributes shared by every conditioner.

    Args:
        config (BaseConditionerConfig): The configuration of the conditioner

    Examples
    ########

    To use the conditioner, you can import the class and use it as follows
    (both classes live in this package, in ``base_conditioner`` and
    ``base_conditioner_config``):

    .. code-block:: python

        from .base_conditioner import BaseConditioner
        from .base_conditioner_config import BaseConditionerConfig

        # Create the conditioner config
        config = BaseConditionerConfig(
            input_key="text",  # The key for the input
            unconditional_conditioning_rate=0.3,  # Drops the conditioning with 30% probability during training
        )

        # Create the conditioner
        conditioner = BaseConditioner(config)
    """

    def __init__(self, config: BaseConditionerConfig):
        BaseModel.__init__(self, config)
        self.config = config
        # The key for the input (per the config); presumably used by subclasses
        # to look up their input in the `batch` dict passed to `forward`.
        self.input_key = config.input_key
        # Per-instance copy: a subclass adjusting its own mapping must not
        # mutate the module-level DIM2CONDITIONING shared by all conditioners.
        self.dim2outputkey = dict(DIM2CONDITIONING)
        # Unconditional-conditioning (dropout) rate from the config. It is only
        # stored here; this base class does not apply it.
        self.ucg_rate = config.unconditional_conditioning_rate

    def forward(
        self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs
    ):
        """
        Forward pass of the conditioner.

        Args:
            batch (Dict[str, Any]): A dictionary containing the input data.
            force_zero_embedding (bool): Whether to force zero embedding.
                This will return an embedding with all entries set to 0. Defaults to False.

        Raises:
            NotImplementedError: Always; child classes must override this method.
        """
        raise NotImplementedError("Forward pass must be implemented in child class")