diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py index ea9df999dd..c7162c6d18 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -278,6 +278,9 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): truncation=True, padding="max_length", ) + input_ids = ( + input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids + ) input_ids = torch.LongTensor(input_ids) input_ids_batch.append(input_ids)