1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Commit Graph

907 Commits

Author SHA1 Message Date
Pedro Cuenca
fde9abcbba JAX/Flax safety checker (#558)
* Starting to integrate safety checker.

* Fix initialization of CLIPVisionConfig

* Remove commented lines.

* make style

* Remove unused import

* Pass dtype to modules

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Pass dtype to modules

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Co-authored-by: Suraj Patil <surajp815@gmail.com>
2022-09-19 15:26:49 +02:00
Kashif Rasul
b1182bcf21 [Flax] fix Flax scheduler (#564)
* remove match_shape

* ported fixes from #479 to flax

* remove unused argument

* typo

* remove warnings
2022-09-19 14:48:00 +02:00
ydshieh
0424615a5d revert the accidental commit 2022-09-19 14:16:10 +02:00
ydshieh
8187865aef Fix CrossAttention._sliced_attention 2022-09-19 14:08:29 +02:00
Mishig Davaadorj
0c0c222432 FlaxUNet2DConditionOutput @flax.struct.dataclass (#550) 2022-09-18 19:35:37 +02:00
Younes Belkada
d09bbae515 make fixup support (#546)
* add `get_modified_files.py`

- file copied from https://github.com/huggingface/transformers/blob/main/utils/get_modified_files.py

* make fixup
2022-09-18 19:34:51 +02:00
Patrick von Platen
429dace10a [Configuration] Better logging (#545)
* [Config] improve logging

* finish
2022-09-17 14:09:13 +02:00
Jonatan Kłosko
d7dcba4a13 Unify offset configuration in DDIM and PNDM schedulers (#479)
* Unify offset configuration in DDIM and PNDM schedulers

* Format

Add missing variables

* Fix pipeline test

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Default set_alpha_to_one to false

* Format

* Add tests

* Format

* add deprecation warning

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-09-17 14:07:43 +02:00
Patrick von Platen
9e439d8c60 [Hub] Update hub version (#538) 2022-09-16 20:29:01 +02:00
Patrick von Platen
e5902ed11a [Download] Smart downloading (#512)
* [Download] Smart downloading

* add test

* finish test

* update

* make style
2022-09-16 19:32:40 +02:00
Sid Sahai
a54cfe6828 Add LMSDiscreteSchedulerTest (#467)
* [WIP] add LMSDiscreteSchedulerTest

* fixes for comments

* add torch numpy test

* rebase

* Update tests/test_scheduler.py

* Update tests/test_scheduler.py

* style

* return residuals

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
2022-09-16 19:10:56 +02:00
Patrick von Platen
88972172d8 Revert "adding more typehints to DDIM scheduler" (#533)
Revert "adding more typehints to DDIM scheduler (#456)"

This reverts commit a0558b1146.
2022-09-16 17:48:02 +02:00
V Vishnu Anirudh
a0558b1146 adding more typehints to DDIM scheduler (#456)
* adding more typehints

* resolving mypy issues

* resolving formatting issue

* fixing isort issue

Co-authored-by: V Vishnu Anirudh <git.vva@gmail.com>
Co-authored-by: V Vishnu Anirudh <vvani@kth.se>
2022-09-16 17:41:58 +02:00
Suraj Patil
06924c6a4f [StableDiffusionInpaintPipeline] accept tensors for init and mask image (#439)
* accept tensors

* fix mask handling

* make device placement cleaner

* update doc for mask image
2022-09-16 17:35:41 +02:00
Anton Lozhkov
761f0297b0 [Tests] Fix spatial transformer tests on GPU (#531) 2022-09-16 16:04:37 +02:00
Anton Lozhkov
c1796efd5f Quick fix for the img2img tests (#530)
* Quick fix for the img2img tests

* Remove debug lines
2022-09-16 15:52:26 +02:00
Yuta Hayashibe
76d492ea49 Fix typos and add Typo check GitHub Action (#483)
* Fix typos

* Add a typo check action

* Fix a bug

* Changed to manual typo check currently

Ref: https://github.com/huggingface/diffusers/pull/483#pullrequestreview-1104468010

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Removed a confusing message

* Renamed "nin_shortcut" to "in_shortcut"

* Add memo about NIN

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
2022-09-16 15:36:51 +02:00
Yih-Dar
c0493723f7 Remove the usage of numpy in up/down sample_2d (#503)
* Fix PT up/down sample_2d

* empty commit

* style

* style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2022-09-16 15:15:05 +02:00
Anton Lozhkov
c727a6a5fb Finally fix the image-based SD tests (#509)
* Finally fix the image-based SD tests

* Remove autocast

* Remove autocast in image tests
2022-09-16 14:37:12 +02:00
Sid Sahai
f73ca908e5 [Tests] Test attention.py (#368)
* add test for AttentionBlock, SpatialTransformer

* add context_dim, handle device

* removed dropout test

* fixes, add dropout test
2022-09-16 12:59:42 +02:00
SkyTNT
37c9d789aa Fix is_onnx_available (#440)
* Fix is_onnx_available

Fix: If user install onnxruntime-gpu, is_onnx_available() will return False.

* add more onnxruntime candidates

* Run `make style`

Co-authored-by: anton-l <anton@huggingface.co>
2022-09-16 12:13:22 +02:00
Anton Lozhkov
214520c66a [CI] Add stalebot (#481)
* Add stalebot

* style

* Remove the closing logic

* Make sure not to spam
2022-09-16 12:03:04 +02:00
Suraj Patil
039958eae5 Stable diffusion text2img conversion script. (#154)
* begin text2img conversion script

* add fn to convert config

* create config if not provided

* update imports and use UNet2DConditionModel

* fix imports, layer names

* fix unet coversion

* add function to convert VAE

* fix vae conversion

* update main

* create text model

* update config creating logic for unet

* fix config creation

* update script to create and save pipeline

* remove unused imports

* fix checkpoint loading

* better name

* save progress

* finish

* up

* up

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-09-16 00:07:32 +02:00
Pedro Cuenca
d8b0e4f433 UNet Flax with FlaxModelMixin (#502)
* First UNet Flax modeling blocks.

Mimic the structure of the PyTorch files.
The model classes themselves need work, depending on what we do about
configuration and initialization.

* Remove FlaxUNet2DConfig class.

* ignore_for_config non-config args.

* Implement `FlaxModelMixin`

* Use new mixins for Flax UNet.

For some reason the configuration is not correctly applied; the
signature of the `__init__` method does not contain all the parameters
by the time it's inspected in `extract_init_dict`.

* Import `FlaxUNet2DConditionModel` if flax is available.

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Indicate types in flax.struct.dataclass as pointed out by @mishig25

Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>

* Fix typo in transformer block.

* make style

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Style: remove empty line.

* Apply style (after upgrading black from pinned version)

* Remove some commented code and unused imports.

* Add init_weights (not yet in use until #513).

* Trickle down deterministic to blocks.

* Rename q, k, v according to the latest PyTorch version.

Note that weights were exported with the old names, so we need to be
careful.

* Flax UNet docstrings, default props as in PyTorch.

* Fix minor typos in PyTorch docstrings.

* Use FlaxUNet2DConditionOutput as output from UNet.

* make style

Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-09-15 18:07:15 +02:00
Mishig Davaadorj
fb5468a6aa Add init_weights method to FlaxMixin (#513)
* Add `init_weights` method to `FlaxMixin`

* Rn `random_state` -> `shape_state`

* `PRNGKey(0)` for `jax.eval_shape`

* No allow mismatched sizes

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* docstring diffusers

Co-authored-by: Suraj Patil <surajp815@gmail.com>
2022-09-15 17:01:41 +02:00
Suraj Patil
d144c46a59 [UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks (#442)
* pass norm_num_groups to unet blocs and attention

* fix UNet2DConditionModel

* add norm_num_groups arg in vae

* add tests

* remove comment

* Apply suggestions from code review
2022-09-15 16:35:14 +02:00
Kashif Rasul
b34be039f9 Karras VE, DDIM and DDPM flax schedulers (#508)
* beta never changes removed from state

* fix typos in docs

* removed unused var

* initial ddim flax scheduler

* import

* added dummy objects

* fix style

* fix typo

* docs

* fix typo in comment

* set return type

* added flax ddom

* fix style

* remake

* pass PRNG key as argument and split before use

* fix doc string

* use config

* added flax Karras VE scheduler

* make style

* fix dummy

* fix ndarray type annotation

* replace returns a new state

* added lms_discrete scheduler

* use self.config

* add_noise needs state

* use config

* use config

* docstring

* added flax score sde ve

* fix imports

* fix typos
2022-09-15 15:55:48 +02:00
Mishig Davaadorj
83a7bb2aba Implement FlaxModelMixin (#493)
* Implement `FlaxModelMixin`

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Rm incorrect docstring

* Add FlaxModelMixin to __init__.py

* make fix-copies

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-09-14 16:34:44 +02:00
Suraj Patil
8b45096927 [CrossAttention] add different method for sliced attention (#446)
* add different method for sliced attention

* Update src/diffusers/models/attention.py

* Apply suggestions from code review

* Update src/diffusers/models/attention.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-09-14 16:01:24 +02:00
Pedro Cuenca
1a69c6ff0e Fix MPS scheduler indexing when using mps (#450)
* Fix LMS scheduler indexing in `add_noise` #358.

* Fix DDIM and DDPM indexing with mps device.

* Verify format is PyTorch before using `.to()`
2022-09-14 14:33:37 +02:00
Nicolas Patry
7c4b38baca Removing .float() (autocast in fp16 will discard this (I think)). (#495) 2022-09-14 08:20:27 +02:00
Jithin James
ab7a78e8f1 docs: bocken doc links for relative links (#504)
fix: bocken doc links for relative links
2022-09-14 00:50:02 +02:00
Patrick von Platen
d12e9ebc90 [Docs] Add subfolder docs (#500)
* [Docs] Add subfolder docs

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-09-13 19:18:02 +02:00
Kashif Rasul
da7e3994ad Fix vae tests for cpu and gpu (#480) 2022-09-13 19:14:20 +02:00
Kashif Rasul
55f7ca3bb9 initial flax pndm schedular (#492)
* initial flax pndm

* fix typo

* use state

* return state

* add FlaxSchedulerOutput

* fix style

* add flax imports

* make style

* fix typos

* return created state

* make style

* add torch/flax imports

* docs

* fixed typo

* remove tensor_format

* round instead of cast

* ets is jnp array

* remove copy
2022-09-13 19:11:45 +02:00
Nathan Lambert
b56f102765 Fix scheduler inference steps error with power of 3 (#466)
* initial attempt at solving

* fix pndm power of 3 inference_step

* add power of 3 test

* fix index in pndm test, remove ddim test

* add comments, change to round()
2022-09-13 09:48:33 -06:00
Nathan Lambert
da990633a9 Scheduler docs update (#464)
* update scheduler docs TODOs, fix typos

* fix another typo
2022-09-13 08:34:33 -06:00
Pedro Cuenca
e335f05fb1 Rename test_scheduler_outputs_equivalence in model tests. (#451) 2022-09-13 15:03:36 +02:00
Pedro Cuenca
f7cd6b87e1 Fix disable_attention_slicing in pipelines (#498)
Fix `disable_attention_slicing` in pipelines.
2022-09-13 14:25:22 +02:00
Patrick von Platen
721e017401 [Flax] Make room for more frameworks (#494)
* start

* finish
2022-09-13 13:24:27 +02:00
Kashif Rasul
f4781a0b27 update expected results of slow tests (#268)
* update expected results of slow tests

* relax sum and mean tests

* Print shapes when reporting exception

* formatting

* fix sentence

* relax test_stable_diffusion_fast_ddim for gpu fp16

* relax flakey tests on GPU

* added comment on large tolerences

* black

* format

* set scheduler seed

* added generator

* use np.isclose

* set num_inference_steps to 50

* fix dep. warning

* update expected_slice

* preprocess if image

* updated expected results

* updated expected from CI

* pass generator to VAE

* undo change back to orig

* use orignal

* revert back the expected on cpu

* revert back values for CPU

* more undo

* update result after using gen

* update mean

* set generator for mps

* update expected on CI server

* undo

* use new seed every time

* cpu manual seed

* reduce num_inference_steps

* style

* use generator for randn

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-09-12 15:49:39 +02:00
Nathan Lambert
25a51b63ca fix table formatting for stable diffusion pipeline doc (add blank line) (#471)
fix table formatting (add blank line)
2022-09-12 10:28:27 +02:00
Partho
8eaaa546d8 Docs: fix installation typo (#453)
installation doc typo fix
2022-09-09 15:17:17 -06:00
Partho
58434879e1 Renamed variables from single letter to better naming (#449)
* renamed variable names

q -> query
k -> key
v -> value
b -> batch
c -> channel
h -> height
w -> weight

* rename variable names

missed some in the initial commit

* renamed more variable names

As per  code review suggestions, renamed x -> hidden_states and x_in -> residual

* fixed minor typo
2022-09-09 22:16:44 +05:30
Suraj Patil
5adb0a7bf7 use torch.matmul instead of einsum in attnetion. (#445)
* use torch.matmul instead of einsum

* fix softmax
2022-09-09 17:16:06 +05:30
Patrick von Platen
b2b3b1a8ab [Black] Update black (#433)
* Update black

* update table
2022-09-08 22:10:01 +02:00
Patrick von Platen
44968e4204 [Docs] Correct links (#432) 2022-09-08 21:29:24 +02:00
anton-l
5e71fb7752 Version bump: 0.4.0.dev0 2022-09-08 19:14:29 +02:00
anton-l
3f55d1359f Release: 0.3.0 v0.3.0 2022-09-08 18:20:05 +02:00
Patrick von Platen
195ebe5a02 Mark in painting experimental (#430) 2022-09-08 18:12:46 +02:00