* 📄 Renamed file for better understanding: renamed the 'rl' file to 'run_locomotion'. The 'rl' name was ambiguous; 'run_locomotion' describes the file's purpose more clearly. Thanks 🙌
* 📁 [Docs] Renamed directory for better clarity: renamed the 'rl' directory to 'reinforcement_learning', which gives a clearer picture of the directory's purpose and contents.
* Update examples/reinforcement_learning/README.md
* 📝 Update README

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
60 lines · 1.5 KiB · Python
import d4rl  # noqa (imported for its side effect of registering the D4RL environments with gym)
import gym
import tqdm

from diffusers.experimental import ValueGuidedRLPipeline

config = {
    "n_samples": 64,
    "horizon": 32,
    "num_inference_steps": 20,
    "n_guide_steps": 2,  # can be set to 0 for faster sampling; the value network is then unused
    "scale_grad_by_std": True,
    "scale": 0.1,
    "eta": 0.0,
    "t_grad_cutoff": 2,
    "device": "cpu",
}
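
# `config` collects the sampling hyperparameters in one place, but the main block
# below only passes `planning_horizon` explicitly. A minimal sketch of wiring the
# remaining values through, assuming the installed ValueGuidedRLPipeline.__call__
# accepts `batch_size`, `n_guide_steps`, and `scale` keyword arguments (verify
# against your diffusers version):
#
#     denorm_actions = pipeline(
#         obs,
#         batch_size=config["n_samples"],
#         planning_horizon=config["horizon"],
#         n_guide_steps=config["n_guide_steps"],
#         scale=config["scale"],
#     )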


if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )
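
    # The checkpoint loaded above comes from the Hugging Face Hub. For this
    # experimental pipeline, it bundles a diffusion model over state-action
    # trajectories together with the value function used to guide sampling
    # toward high-return plans.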

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]
    try:
        for t in tqdm.tqdm(range(T)):
            # call the policy: plan from the current observation and return the
            # next denormalized action
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute the action in the environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)

            # update the return, then compute its normalized D4RL score so the
            # printed score matches the printed total reward
            total_reward += reward
            score = env.get_normalized_score(total_reward)
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )

            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation

            # stop early if the environment signals the end of the episode
            if terminal:
                break
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
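
    # The observations collected in `rollout` can be saved for offline rendering
    # or analysis. A minimal sketch, assuming NumPy is installed (the output path
    # "hopper_rollout.npy" is hypothetical):
    #
    #     import numpy as np
    #     np.save("hopper_rollout.npy", np.stack(rollout))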