Ryan Russell
80183ca58b
docs: fix Berkeley ref ( #611 )
...
Signed-off-by: Ryan Russell <git@ryanrussell.org >
2022-09-21 22:55:32 +02:00
Anton Lozhkov
6bd005ebbe
[ONNX] Collate the external weights, speed up loading from the hub ( #610 )
2022-09-21 22:26:30 +02:00
Pedro Cuenca
a9fdb3de9e
Return Flax scheduler state ( #601 )
...
* Optionally return state in from_config.
Useful for Flax schedulers.
* has_state is now a property, make check more strict.
I don't check that the class is `SchedulerMixin`, to avoid circular
dependencies. It should be enough that the class name starts with "Flax",
the object declares `has_state`, and `create_state` exists too.
* Use state in pipeline from_pretrained.
* Make style
2022-09-21 22:25:27 +02:00
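A minimal sketch of the `has_state` / `create_state` pattern described in #601, assuming the Flax PNDM scheduler; the exact call sites in the pipeline code may differ:

```python
from diffusers import FlaxPNDMScheduler

# Sketch only: `has_state` is the property mentioned in this PR, and
# `create_state` builds the scheduler state that Flax pipelines pass around.
scheduler = FlaxPNDMScheduler(num_train_timesteps=1000)
if getattr(scheduler, "has_state", False):
    scheduler_state = scheduler.create_state()
```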
Anton Lozhkov
e72f1a8a71
Add torchvision to training deps ( #607 )
2022-09-21 13:54:32 +02:00
Anton Lozhkov
4f1c989ffb
Add smoke tests for the training examples ( #585 )
...
* Add smoke tests for the training examples
* upd
* use a dummy dataset
* mark as slow
* cleanup
* Update test cases
* naming
2022-09-21 13:36:59 +02:00
Younes Belkada
3fc8ef7297
Replace dropout_prob by dropout in vae ( #595 )
...
replace `dropout_prob` by `dropout` in `vae`
2022-09-21 11:43:28 +02:00
Mishig Davaadorj
8685699392
Mv weights name consts to diffusers.utils ( #605 )
2022-09-21 11:30:14 +02:00
Mishig Davaadorj
f810060006
Fix flax from_pretrained pytorch weight check ( #603 )
2022-09-21 11:17:15 +02:00
Pedro Cuenca
fb2fbab10b
Allow dtype to be specified in Flax pipeline ( #600 )
...
* Fix typo in docstring.
* Allow dtype to be overridden on model load.
This may be a temporary solution until #567 is addressed.
* Create latents in float32
The denoising loop always computes the next step in float32, so this
would fail when using `bfloat16`.
2022-09-21 10:57:01 +02:00
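A hedged illustration of the dtype override from #600; the model id is an assumption and the checkpoint is assumed to provide Flax weights. Latents are created in float32 because the denoising loop computes the next step in float32:

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Load parameters in bfloat16, but keep the initial latents in float32.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", dtype=jnp.bfloat16
)
latents = jax.random.normal(
    jax.random.PRNGKey(0), shape=(1, 4, 64, 64), dtype=jnp.float32
)
```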
Pedro Cuenca
fb03aad8b4
Fix params replication when using the dummy checker ( #602 )
...
Fix params replication when using the dummy checker.
2022-09-21 09:38:10 +02:00
Patrick von Platen
2345481c0e
[Flax] Fix unet and ddim scheduler ( #594 )
...
* [Flax] Fix unet and ddim scheduler
* correct
* finish
2022-09-20 23:29:09 +02:00
Mishig Davaadorj
d934d3d795
FlaxDiffusionPipeline & FlaxStableDiffusionPipeline ( #559 )
...
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
* todo comment
* Fix imports
* Fix imports
* add dummies
* Fix empty init
* make pipeline work
* up
* Use Flax schedulers (typing, docstring)
* Wrap model imports inside availability checks.
* more updates
* make sure flax is not broken
* make style
* more fixes
* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com >
Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com >
2022-09-20 21:28:07 +02:00
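A hedged usage sketch for the Flax pipeline introduced in #559; the checkpoint is assumed to ship Flax weights, and pmap/sharding details are omitted:

```python
import jax
from diffusers import FlaxStableDiffusionPipeline

# Single-device sketch: prepare prompt ids, then call the pipeline with explicit params.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
prompt_ids = pipeline.prepare_inputs(["a photograph of an astronaut riding a horse"])
rng = jax.random.PRNGKey(0)
images = pipeline(prompt_ids, params, rng, num_inference_steps=50).images
```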
Suraj Patil
c6629e6f11
[flax safety checker] Use FlaxPreTrainedModel for saving/loading ( #591 )
...
* use FlaxPreTrainedModel for flax safety module
* fix name
* fix one more
* Apply suggestions from code review
2022-09-20 20:11:32 +02:00
Anton Lozhkov
8a6833b85c
Add the K-LMS scheduler to the inpainting pipeline + tests ( #587 )
...
* Add the K-LMS scheduler to the inpainting pipeline + tests
* Remove redundant casts
2022-09-20 19:10:44 +02:00
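A hedged sketch of what #587 enables: pairing the K-LMS scheduler with the inpainting pipeline (model id and scheduler hyperparameters are assumptions):

```python
from diffusers import LMSDiscreteScheduler, StableDiffusionInpaintPipeline

# Swap the default scheduler for K-LMS when loading the inpainting pipeline.
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=lms
)
```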
Anton Lozhkov
a45dca077c
Fix BaseOutput initialization from dict ( #570 )
...
* Fix BaseOutput initialization from dict
* style
* Simplify post-init, add tests
* remove debug
2022-09-20 18:32:16 +02:00
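A minimal sketch of the behaviour exercised by #570: a `BaseOutput` subclass initialized from a plain dict still supports attribute and key access (the import path and field name are assumptions):

```python
from dataclasses import dataclass

import numpy as np

from diffusers.utils import BaseOutput  # import location assumed


@dataclass
class SampleOutput(BaseOutput):
    images: np.ndarray


out = SampleOutput({"images": np.zeros((1, 8, 8, 3))})
print(out.images.shape, out["images"].shape)  # both access styles resolve to the same array
```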
Suraj Patil
c01ec2d119
[FlaxAutoencoderKL] rename weights to align with PT ( #584 )
...
* rename weights to align with PT
* DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution
* fix name
2022-09-20 13:04:16 +02:00
Younes Belkada
0902449ef8
Add from_pt argument in .from_pretrained ( #527 )
...
* first commit:
- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
* small nit
- fix a small nit so that we do not enter the second if condition
* major changes
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
* keys match
- now all keys match
- change module names for correct matching
- upsample module name changed
* working v1
- tests pass with atol and rtol = `4e-02`
* replace unused arg
* make quality
* add small docstring
* add more comments
- add TODO for embedding layers
* small change
- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
* add more conditions on conversion
- add better test to check for keys conversion
* make shapes consistent
- output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
This reverts commit 4cad1aeb4a.
* fix unet shape
- channels first!
2022-09-20 12:39:25 +02:00
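A hedged sketch of the `from_pt` flag added in #527: load a PyTorch checkpoint into a Flax model and convert the weights on the fly (model id and subfolder are assumptions):

```python
from diffusers import FlaxUNet2DConditionModel

# `from_pt=True` converts the PyTorch state dict to Flax params at load time.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", from_pt=True
)
```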
Yuta Hayashibe
ca74951323
Fix typos ( #568 )
...
* Fix a setting bug
* Fix typos
* Reverted params to parms
2022-09-19 21:58:41 +02:00
Yih-Dar
84616b5de5
Fix CrossAttention._sliced_attention ( #563 )
...
* Fix CrossAttention._sliced_attention
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com >
2022-09-19 18:07:32 +02:00
Suraj Patil
8d36d5adb1
Update clip_guided_stable_diffusion.py
2022-09-19 18:03:00 +02:00
Suraj Patil
dc2a1c1d07
[examples/community] add CLIPGuidedStableDiffusion ( #561 )
...
* add CLIPGuidedStableDiffusion
* add credits
* add readme
* style
* add clip prompt
* fix cond_fn
* fix cond fn
* fix cond fn for lms
2022-09-19 17:29:19 +02:00
Anton Lozhkov
9727cda678
[Tests] Mark the ncsnpp model tests as slow ( #575 )
...
* [Tests] Mark the ncsnpp model tests as slow
* style
2022-09-19 17:20:58 +02:00
Anton Lozhkov
0a2c42f3e2
[Tests] Upload custom test artifacts ( #572 )
...
* make_reports
* add test utils
* style
* style
2022-09-19 17:08:29 +02:00
Patrick von Platen
2a8477de5c
[Flax] Solve problem with VAE ( #574 )
2022-09-19 16:50:22 +02:00
Patrick von Platen
bf5ca036fa
[Flax] Add Vae for Stable Diffusion ( #555 )
...
* [Flax] Add Vae
* correct
* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com >
* Finish
Co-authored-by: Suraj Patil <surajp815@gmail.com >
2022-09-19 16:00:54 +02:00
Yih-Dar
b17d49f863
Fix _upsample_2d ( #535 )
...
* Fix _upsample_2d
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com >
2022-09-19 15:52:52 +02:00
Anton Lozhkov
b8d1f2d344
Remove check_tf_utils to avoid an unnecessary TF import for now ( #566 )
2022-09-19 15:37:36 +02:00
Pedro Cuenca
5b3f249659
Flax: ignore dtype for configuration ( #565 )
...
Flax: ignore dtype for configuration.
This makes it possible to save models and configuration files.
2022-09-19 15:37:07 +02:00
Pedro Cuenca
fde9abcbba
JAX/Flax safety checker ( #558 )
...
* Starting to integrate safety checker.
* Fix initialization of CLIPVisionConfig
* Remove commented lines.
* make style
* Remove unused import
* Pass dtype to modules
Co-authored-by: Suraj Patil <surajp815@gmail.com >
* Pass dtype to modules
Co-authored-by: Suraj Patil <surajp815@gmail.com >
Co-authored-by: Suraj Patil <surajp815@gmail.com >
2022-09-19 15:26:49 +02:00
Kashif Rasul
b1182bcf21
[Flax] fix Flax scheduler ( #564 )
...
* remove match_shape
* ported fixes from #479 to flax
* remove unused argument
* typo
* remove warnings
2022-09-19 14:48:00 +02:00
ydshieh
0424615a5d
revert the accidental commit
2022-09-19 14:16:10 +02:00
ydshieh
8187865aef
Fix CrossAttention._sliced_attention
2022-09-19 14:08:29 +02:00
Mishig Davaadorj
0c0c222432
FlaxUNet2DConditionOutput @flax.struct.dataclass ( #550 )
2022-09-18 19:35:37 +02:00
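A short sketch of the pattern named in #550: the UNet output wrapped as a pytree-compatible container via `flax.struct.dataclass` (shown here from memory, not copied from the PR):

```python
import flax
import jax.numpy as jnp


# Registering the output as a flax.struct.dataclass lets JAX treat it as a pytree.
@flax.struct.dataclass
class FlaxUNet2DConditionOutput:
    sample: jnp.ndarray
```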
Younes Belkada
d09bbae515
make fixup support (#546 )
...
* add `get_modified_files.py`
- file copied from https://github.com/huggingface/transformers/blob/main/utils/get_modified_files.py
* make fixup
2022-09-18 19:34:51 +02:00
Patrick von Platen
429dace10a
[Configuration] Better logging ( #545 )
...
* [Config] improve logging
* finish
2022-09-17 14:09:13 +02:00
Jonatan Kłosko
d7dcba4a13
Unify offset configuration in DDIM and PNDM schedulers ( #479 )
...
* Unify offset configuration in DDIM and PNDM schedulers
* Format
Add missing variables
* Fix pipeline test
* Update src/diffusers/schedulers/scheduling_ddim.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com >
* Default set_alpha_to_one to false
* Format
* Add tests
* Format
* add deprecation warning
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com >
2022-09-17 14:07:43 +02:00
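A hedged illustration of the unified offset configuration from #479; the argument names follow the PR description and may not match the final API exactly:

```python
from diffusers import DDIMScheduler, PNDMScheduler

# Both schedulers accept the same `steps_offset` argument after this PR;
# `set_alpha_to_one` defaults to False for DDIM per the commit body above.
ddim = DDIMScheduler(steps_offset=1, set_alpha_to_one=False)
pndm = PNDMScheduler(steps_offset=1)
```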
Patrick von Platen
9e439d8c60
[Hub] Update hub version ( #538 )
2022-09-16 20:29:01 +02:00
Patrick von Platen
e5902ed11a
[Download] Smart downloading ( #512 )
...
* [Download] Smart downloading
* add test
* finish test
* update
* make style
2022-09-16 19:32:40 +02:00
Sid Sahai
a54cfe6828
Add LMSDiscreteSchedulerTest ( #467 )
...
* [WIP] add LMSDiscreteSchedulerTest
* fixes for comments
* add torch numpy test
* rebase
* Update tests/test_scheduler.py
* Update tests/test_scheduler.py
* style
* return residuals
Co-authored-by: Anton Lozhkov <anton@huggingface.co >
2022-09-16 19:10:56 +02:00
Patrick von Platen
88972172d8
Revert "adding more typehints to DDIM scheduler" ( #533 )
...
Revert "adding more typehints to DDIM scheduler (#456 )"
This reverts commit a0558b1146.
2022-09-16 17:48:02 +02:00
V Vishnu Anirudh
a0558b1146
adding more typehints to DDIM scheduler ( #456 )
...
* adding more typehints
* resolving mypy issues
* resolving formatting issue
* fixing isort issue
Co-authored-by: V Vishnu Anirudh <git.vva@gmail.com >
Co-authored-by: V Vishnu Anirudh <vvani@kth.se >
2022-09-16 17:41:58 +02:00
Suraj Patil
06924c6a4f
[StableDiffusionInpaintPipeline] accept tensors for init and mask image ( #439 )
...
* accept tensors
* fix mask handling
* make device placement cleaner
* update doc for mask image
2022-09-16 17:35:41 +02:00
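A hedged usage sketch for #439: `init_image` and `mask_image` may now be passed as tensors instead of PIL images (model id, tensor shapes, and value ranges are assumptions):

```python
import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
init_image = torch.rand(1, 3, 512, 512)  # batched image tensor instead of a PIL image
mask_image = torch.rand(1, 1, 512, 512)  # single-channel mask tensor
result = pipe(prompt="a red sofa", init_image=init_image, mask_image=mask_image)
```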
Anton Lozhkov
761f0297b0
[Tests] Fix spatial transformer tests on GPU ( #531 )
2022-09-16 16:04:37 +02:00
Anton Lozhkov
c1796efd5f
Quick fix for the img2img tests ( #530 )
...
* Quick fix for the img2img tests
* Remove debug lines
2022-09-16 15:52:26 +02:00
Yuta Hayashibe
76d492ea49
Fix typos and add Typo check GitHub Action ( #483 )
...
* Fix typos
* Add a typo check action
* Fix a bug
* Changed to manual typo check currently
Ref: https://github.com/huggingface/diffusers/pull/483#pullrequestreview-1104468010
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com >
* Removed a confusing message
* Renamed "nin_shortcut" to "in_shortcut"
* Add memo about NIN
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com >
2022-09-16 15:36:51 +02:00
Yih-Dar
c0493723f7
Remove the usage of numpy in up/down sample_2d ( #503 )
...
* Fix PT up/down sample_2d
* empty commit
* style
* style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com >
2022-09-16 15:15:05 +02:00
Anton Lozhkov
c727a6a5fb
Finally fix the image-based SD tests ( #509 )
...
* Finally fix the image-based SD tests
* Remove autocast
* Remove autocast in image tests
2022-09-16 14:37:12 +02:00
Sid Sahai
f73ca908e5
[Tests] Test attention.py ( #368 )
...
* add test for AttentionBlock, SpatialTransformer
* add context_dim, handle device
* removed dropout test
* fixes, add dropout test
2022-09-16 12:59:42 +02:00
SkyTNT
37c9d789aa
Fix is_onnx_available ( #440 )
...
* Fix is_onnx_available
Fix: if a user installs onnxruntime-gpu, is_onnx_available() will return False.
* add more onnxruntime candidates
* Run `make style`
Co-authored-by: anton-l <anton@huggingface.co >
2022-09-16 12:13:22 +02:00
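A hedged sketch of the approach described in #440: accept any of several onnxruntime distributions when checking availability (the candidate list here is an assumption, not the exact one in the PR):

```python
import importlib.metadata

# Any of these pip distributions provides the `onnxruntime` module.
_ONNX_CANDIDATES = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml")


def is_onnx_available() -> bool:
    for candidate in _ONNX_CANDIDATES:
        try:
            importlib.metadata.version(candidate)
            return True
        except importlib.metadata.PackageNotFoundError:
            continue
    return False
```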
Anton Lozhkov
214520c66a
[CI] Add stalebot ( #481 )
...
* Add stalebot
* style
* Remove the closing logic
* Make sure not to spam
2022-09-16 12:03:04 +02:00