Mirror of https://github.com/huggingface/diffusers.git

@@ -1597,6 +1597,7 @@ def main(args):
                tokenizers=[None, None],
                text_input_ids_list=[tokens_one, tokens_two],
                max_sequence_length=args.max_sequence_length,
                device=accelerator.device,
                prompt=prompts,
            )
        else:

@@ -1606,6 +1607,7 @@ def main(args):
                tokenizers=[None, None],
                text_input_ids_list=[tokens_one, tokens_two],
                max_sequence_length=args.max_sequence_length,
                device=accelerator.device,
                prompt=args.instance_prompt,
            )
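
Per the hunk headers, each of the two calls grows by one line, most plausibly the prompt keyword: the first branch passes the per-batch prompts, the else branch the fixed args.instance_prompt. Because tokenizers=[None, None] is passed alongside a pre-tokenized text_input_ids_list, the encoding helper presumably falls back on the supplied token ids and only tokenizes prompt itself when a tokenizer is available. A minimal sketch of that fallback pattern under those assumptions (encode_prompt_sketch and the stand-in encoders are hypothetical, not the diffusers implementation):

import torch

def encode_prompt_sketch(text_encoders, tokenizers, prompt, text_input_ids_list,
                         max_sequence_length, device):
    # Hypothetical fallback: tokenize `prompt` only when a tokenizer is given;
    # with tokenizers=[None, None], reuse the pre-computed ids instead.
    embeds = []
    for encoder, tokenizer, ids in zip(text_encoders, tokenizers, text_input_ids_list):
        if tokenizer is not None:
            ids = tokenizer(prompt, max_length=max_sequence_length,
                            truncation=True, return_tensors="pt").input_ids
        embeds.append(encoder(ids.to(device)))
    return embeds

# Usage with stand-in encoders (one-hot "embeddings") and pre-tokenized ids:
dummy_encoder = lambda ids: torch.nn.functional.one_hot(ids, num_classes=8).float()
out = encode_prompt_sketch(
    text_encoders=[dummy_encoder, dummy_encoder],
    tokenizers=[None, None],
    prompt="a photo of sks dog",
    text_input_ids_list=[torch.tensor([[1, 2]]), torch.tensor([[3, 4]])],
    max_sequence_length=77,
    device="cpu",
)
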
@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)
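
This hunk shows the compatibility path for position ids: the transformer now expects txt_ids and img_ids as 2d tensors of shape (seq_len, 3), while a 3d tensor with a leading batch dimension is still accepted but collapsed to its first element after a deprecation warning, before text and image ids are concatenated for the rotary embedding. A self-contained sketch of that shim (function name and shapes are illustrative, not the diffusers API):

import warnings
import torch

def prepare_rotary_ids(txt_ids: torch.Tensor, img_ids: torch.Tensor) -> torch.Tensor:
    if img_ids.ndim == 3:
        warnings.warn(
            "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch Tensor",
            FutureWarning,
        )
        img_ids = img_ids[0]  # drop the leading batch dimension
    if txt_ids.ndim == 3:
        txt_ids = txt_ids[0]
    # Text and image token positions are concatenated along the sequence axis.
    return torch.cat((txt_ids, img_ids), dim=0)

# Example: 512 text positions + 1024 latent-patch positions, 3 id channels each;
# the 3d img_ids still works but triggers the deprecation warning.
ids = prepare_rotary_ids(torch.zeros(512, 3), torch.zeros(1, 1024, 3))
assert ids.shape == (1536, 3)
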
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin

enable_full_determinism()


@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = FluxImg2ImgPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False

    def get_dummy_components(self):
        torch.manual_seed(0)
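
The @unittest.skipIf guard above is evaluated once, at class-definition time: when the detected test device is "mps", the whole class is skipped, since torch's MPS backend lacks float64 support, which Flux hits in at least one op. A stand-alone sketch of the same guard; torch_device here is a local stand-in for the diffusers test utility of the same name:

import unittest
import torch

# Stand-in device probe (diffusers ships its own `torch_device`).
torch_device = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)

@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class Float64SmokeTest(unittest.TestCase):
    def test_float64_tensor(self):
        # A float64 allocation is exactly the kind of op that fails on MPS.
        t = torch.arange(4, dtype=torch.float64, device=torch_device)
        self.assertEqual(t.dtype, torch.float64)

if __name__ == "__main__":
    unittest.main()
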
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin

enable_full_determinism()


@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = FluxInpaintPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False

    def get_dummy_components(self):
        torch.manual_seed(0)
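
Both fast-test classes open get_dummy_components with torch.manual_seed(0): the dummy pipeline components are tiny randomly initialized modules, and seeding the RNG before construction pins their weights so test outputs are reproducible across runs. A minimal sketch of the pattern with a stand-in component:

import torch

def get_dummy_components():
    torch.manual_seed(0)  # fix RNG state before any weight initialization
    return {"proj": torch.nn.Linear(4, 4)}  # stand-in for the tiny test modules

# Every call rebuilds the component with identical weights:
a = get_dummy_components()["proj"].weight
b = get_dummy_components()["proj"].weight
assert torch.equal(a, b)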