Mirror of https://github.com/huggingface/diffusers.git
Update RL docs for better sharing / adding models (#1563)
* init docs update
* style
* fix bad colab formatting, add pipeline comment
* update todo
@@ -14,7 +14,8 @@ specific language governing permissions and limitations under the License.

Diffusers is in the process of expanding to modalities other than images.

Currently, one example is for [molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation.

Example type | Colab | Pipeline |
:-------------------------:|:-------------------------:|:-------------------------:|
[Molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation | [Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) | ❌

More coming soon!
@@ -13,6 +13,13 @@ specific language governing permissions and limitations under the License.

# Using Diffusers for reinforcement learning

Support for one RL model and related pipelines is included in the `experimental` source of diffusers.
More models and examples coming soon!

To try some of this in Colab, please look at the following example:

* Model-based reinforcement learning on Colab [Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
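If you prefer to run things locally, the pipeline can be loaded directly from the `experimental` namespace. A minimal sketch, condensed from the example script shown further down; it assumes `gym` and `d4rl` are installed, since `d4rl` registers the `hopper-medium-v2` environment:

```python
import d4rl  # noqa: F401  # registers the offline-RL environments used below
import gym

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")

# load the pre-trained value function + diffusion model for Hopper
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",
    env=env,
)
```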

# Diffuser Value-guided Planning

You can run the model from [*Planning with Diffusion for Flexible Behavior Synthesis*](https://arxiv.org/abs/2205.09991) with Diffusers.
The script is located in the [RL Examples](https://github.com/huggingface/diffusers/tree/main/examples/rl) folder.

Or, run this example in Colab: [Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
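For reference, a single planning step with this pipeline looks roughly like the following (condensed from the example script in that folder; it reuses the `pipeline` and `env` objects constructed in the sketch above):

```python
env.seed(0)
obs = env.reset()

# sample a plan conditioned on the current observation, take its first
# (de-normalized) action, and execute it in the environment
denorm_actions = pipeline(obs, planning_horizon=32)
next_observation, reward, terminal, _ = env.step(denorm_actions)
```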

[[autodoc]] diffusers.experimental.ValueGuidedRLPipeline
@@ -1,9 +1,12 @@

# Overview

These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
The script `run_diffuser_locomotion.py` can be used in two ways, controlled by the variable `n_guide_steps`.

When `n_guide_steps=0`, the trajectories are sampled from the diffusion model but not fine-tuned to maximize reward in the environment.
By default, `n_guide_steps=2` to match the original implementation.
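Concretely, the switch is a single entry in the `config` dict at the top of the script; the other values below mirror the example configuration used in these scripts:

```python
config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=2,  # set to 0 to sample without value guidance (faster, no reward fine-tuning)
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,
    t_grad_cutoff=2,
    device="cpu",
)
```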

You will need some RL-specific requirements to run the examples:
@@ -1,57 +0,0 @@
import d4rl  # noqa
import gym
import tqdm

from diffusers.experimental import ValueGuidedRLPipeline


config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=0,
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,
    t_grad_cutoff=2,
    device="cpu",
)


if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]
    try:
        for t in tqdm.tqdm(range(T)):
            # Call the policy
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)
            score = env.get_normalized_score(total_reward)
            # update return
            total_reward += reward
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )
            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
@@ -8,7 +8,7 @@ config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=2,  # can set to 0 for faster sampling, does not use value network
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,

@@ -40,6 +40,7 @@ if __name__ == "__main__":
            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)
            score = env.get_normalized_score(total_reward)

            # update return
            total_reward += reward
            total_score += score

@@ -47,6 +48,7 @@ if __name__ == "__main__":
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )

            # save observations for rendering
            rollout.append(next_observation.copy())
@@ -23,6 +23,22 @@ from ...utils.dummy_pt_objects import DDPMScheduler


class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
@@ -78,21 +94,26 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
            # TODO: set prediction_type when instantiating the model
            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y
@@ -126,5 +147,6 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions