# Mirror of https://github.com/vladmandic/sdnext.git
# (synced 2026-01-29 05:02:09 +03:00; 59 lines, 1.7 KiB, Python)
from typing import Any, Dict
|
|
from ...base.base_model import BaseModel
|
|
from .base_conditioner_config import BaseConditionerConfig
|
|
|
|
|
|
# Lookup from a dimension count to the name of a conditioning output key
# (stored on each conditioner as ``dim2outputkey``).
# NOTE(review): presumably the count is the ndim of a conditioner's output
# (2 -> "vector", 3 -> "crossattn", 4 -> "concat"); confirm against the
# subclasses that consume this mapping.
DIM2CONDITIONING = dict(
    [
        (2, "vector"),
        (3, "crossattn"),
        (4, "concat"),
    ]
)
|
|
|
|
|
|
class BaseConditioner(BaseModel):
    """This is the base class for all the conditioners. This abstracts the conditioning process.

    Args:

        config (BaseConditionerConfig): The configuration of the conditioner

    Examples
    ########

    To use the conditioner, you can import the class and use it as follows:

    .. code-block:: python

        from cr.models.embedders import BaseConditioner, BaseConditionerConfig

        # Create the conditioner config
        config = BaseConditionerConfig(
            input_key="text",  # The key for the input
            unconditional_conditioning_rate=0.3,  # Drops the conditioning with 30% probability during training
        )

        # Create the conditioner
        conditioner = BaseConditioner(config)
    """

    def __init__(self, config: BaseConditionerConfig):
        BaseModel.__init__(self, config)
        self.config = config
        # Name of the entry this conditioner reads from the batch dict.
        self.input_key = config.input_key
        # Shallow per-instance copy: assigning the module-level dict directly
        # would let one conditioner's edits leak into every other instance.
        self.dim2outputkey = dict(DIM2CONDITIONING)
        # Unconditional-conditioning (drop) rate from the config; how it is
        # applied is left to subclasses, since forward() is abstract here.
        self.ucg_rate = config.unconditional_conditioning_rate

    def forward(
        self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs
    ):
        """
        Forward pass of the embedder.

        Args:

            batch (Dict[str, Any]): A dictionary containing the input data.
            force_zero_embedding (bool): Whether to force zero embedding.
                This will return an embedding with all entries set to 0. Defaults to False.
            *args, **kwargs: Extra arguments accepted so subclasses can take
                additional options; unused by this base implementation.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Forward pass must be implemented in child class")
|