mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
Latest commit (squashed): re-add the RL model code; unify the 1D UNet time embedding, block, and forward APIs and rename the model to UNet1D; add a value function and example script to the Diffuser implementation (#884); pipeline cleanup, converting the older script to use the pipeline and adding a readme (#947); RL cleanup v2, adding a value-function block and splitting the test file (#965); refactor the RL PR, cleaning imports, examples, and defaults (#1200); changes per Patrik's comments, updating the conversion script and skipping more mps tests (#1285). Co-authored-by: Nathan Lambert <nathan@huggingface.co> and Ben Glickenhaus <benglickenhaus@gmail.com>.
58 lines
1.4 KiB
Python
# d4rl is imported only for its side effect of registering the D4RL
# offline-RL environments (e.g. hopper-medium-v2) with gym.
import d4rl  # noqa
import gym
import tqdm

from diffusers.experimental import ValueGuidedRLPipeline

config = dict(
    n_samples=64,  # number of candidate trajectories sampled per planning step
    horizon=32,  # planning horizon of the diffusion model
    num_inference_steps=20,  # number of denoising steps
    n_guide_steps=2,  # value-gradient guidance steps per denoising step
    scale_grad_by_std=True,
    scale=0.1,  # guidance scale applied to the value gradient
    eta=0.0,
    t_grad_cutoff=2,  # skip guidance for timesteps below this cutoff
    device="cpu",
)
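# NOTE: as written, only `planning_horizon` is passed to the pipeline call
# below; the remaining `config` keys are not consumed elsewhere in this script.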

if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    # load the value function and diffusion model weights from the Hub
    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]

    try:
        for t in tqdm.tqdm(range(T)):
            # call the policy
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)
            score = env.get_normalized_score(total_reward)

            # update return
            total_reward += reward
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )

            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
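
    # A minimal sketch of persisting the rollout for later rendering, per the
    # "save observations for rendering" comment above; numpy and the
    # "rollout.npy" filename are illustrative assumptions, not part of the
    # original script.
    import numpy as np

    np.save("rollout.npy", np.array(rollout))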