mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix FLUX.2 context parallel (#12737)
This commit is contained in:
@@ -676,8 +676,8 @@ class Flux2Transformer2DModel(
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
|
||||
"img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user