百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

面向强化学习的状态空间建模:RSSM的介绍和PyTorch实现

ztj100 2025-02-11 14:27 15 浏览 0 评论

循环状态空间模型(Recurrent State Space Models, RSSM)最初由 Danijar Hafer 等人在论文《Learning Latent Dynamics for Planning from Pixels》中提出。该模型在现代基于模型的强化学习(Model-Based Reinforcement Learning, MBRL)中发挥着关键作用,其主要目标是构建可靠的环境动态预测模型。通过这些学习得到的模型,智能体能够模拟未来轨迹并进行前瞻性的行为规划。

下面我们就来用一个实际案例来介绍RSSM。

环境配置

环境配置是实现过程中的首要步骤。我们这里用易于使用的 Gym API。为了提高实现效率,设计了多个模块化的包装器(wrapper),用于初始化参数并将观察结果调整为指定格式。

InitialWrapper 的设计允许在不执行任何动作的情况下进行特定数量的观察,同时支持在返回观察结果之前多次重复同一动作。这种设计对于响应具有显著延迟特性的环境特别有效。

PreprocessFrame 包装器负责将观察结果转换为正确的数据类型(本文中使用 numpy 数组),并支持灰度转换功能。

class InitialWrapper(gym.Wrapper): 
def __init__(self, env: gym.Env, no_ops: int = 0, repeat: int = 1): 
super(InitialWrapper, self).__init__(env) 
self.repeat = repeat 
self.no_ops = no_ops 
self.op_counter = 0 

def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: 
if self.op_counter < self.no_ops: 
obs, reward, done, info = self.env.step(0) 
self.op_counter += 1 

total_reward = 0.0 
done = False 
for _ in range(self.repeat): 
obs, reward, done, info = self.env.step(action) 
total_reward += reward 
if done: 
break 

return obs, total_reward, done, info 
class PreprocessFrame(gym.ObservationWrapper): 
def __init__(self, env: gym.Env, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = False): 
super(PreprocessFrame, self).__init__(env) 
self.shape = new_shape 
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=self.shape, dtype=np.float32) 
self.grayscale = grayscale 

if self.grayscale: 
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(*self.shape[:-1], 1), dtype=np.float32) 

def observation(self, obs: torch.Tensor) -> torch.Tensor: 
obs = obs.astype(np.uint8) 
new_frame = cv.resize(obs, self.shape[:-1], interpolation=cv.INTER_AREA) 
if self.grayscale: 
new_frame = cv.cvtColor(new_frame, cv.COLOR_RGB2GRAY) 
new_frame = np.expand_dims(new_frame, -1) 

torch_frame = torch.from_numpy(new_frame).float() 
torch_frame = torch_frame / 255.0 

return torch_frame 

def make_env(env_name: str, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = True, **kwargs): 
env = gym.make(env_name, **kwargs) 
env = PreprocessFrame(env, new_shape, grayscale=grayscale) 
return env

make_env 函数用于创建一个具有指定配置参数的环境实例。

模型架构

RSSM 的实现依赖于多个关键模型组件。具体来说,需要实现以下四个核心模块:

  • 原始观察编码器(Encoder)
  • 动态模型(Dynamics Model):通过确定性状态 h 和随机状态 s 对编码观察的时间依赖性进行建模
  • 解码器(Decoder):将随机状态和确定性状态映射回原始观察空间
  • 奖励模型(Reward Model):将随机状态和确定性状态映射到奖励值

RSSM 模型组件结构图。模型包含随机状态 s 和确定性状态 h。

编码器实现

编码器采用简单的卷积神经网络(CNN)结构,将输入图像降维到一维嵌入表示。实现中使用了 BatchNorm 来提升训练稳定性。

class EncoderCNN(nn.Module): 
def __init__(self, in_channels: int, embedding_dim: int = 2048, input_shape: Tuple[int, int] = (128, 128)): 
super(EncoderCNN, self).__init__() 
# 定义卷积层结构
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1) 
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) 
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 
self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 

self.fc1 = nn.Linear(self._compute_conv_output((in_channels, input_shape[0], input_shape[1])), embedding_dim) 

# 批标准化层
self.bn1 = nn.BatchNorm2d(32) 
self.bn2 = nn.BatchNorm2d(64) 
self.bn3 = nn.BatchNorm2d(128) 
self.bn4 = nn.BatchNorm2d(256) 

def _compute_conv_output(self, shape: Tuple[int, int, int]): 
with torch.no_grad(): 
x = torch.randn(1, shape[0], shape[1], shape[2]) 
x = self.conv1(x) 
x = self.conv2(x) 
x = self.conv3(x) 
x = self.conv4(x) 

return x.shape[1] * x.shape[2] * x.shape[3] 
def forward(self, x): 
x = torch.relu(self.conv1(x)) 
x = self.bn1(x) 
x = torch.relu(self.conv2(x)) 
x = self.bn2(x) 

x = torch.relu(self.conv3(x)) 
x = self.bn3(x) 

x = self.conv4(x) 
x = self.bn4(x) 

x = x.view(x.size(0), -1) 
x = self.fc1(x) 

return x

解码器实现

解码器遵循传统自编码器架构设计,其功能是将编码后的观察结果重建回原始观察空间。

class DecoderCNN(nn.Module): 
def __init__(self, hidden_size: int, state_size: int, embedding_size: int, 
use_bn: bool = True, output_shape: Tuple[int, int] = (3, 128, 128)): 
super(DecoderCNN, self).__init__() 

self.output_shape = output_shape 

self.embedding_size = embedding_size 
# 全连接层进行特征变换
self.fc1 = nn.Linear(hidden_size + state_size, embedding_size) 
self.fc2 = nn.Linear(embedding_size, 256 * (output_shape[1] // 16) * (output_shape[2] // 16)) 

# 反卷积层进行上采样
self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1) # ×2 
self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1) # ×2 
self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1) # ×2 
self.conv4 = nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1) 

# 批标准化层
self.bn1 = nn.BatchNorm2d(128) 
self.bn2 = nn.BatchNorm2d(64) 
self.bn3 = nn.BatchNorm2d(32) 

self.use_bn = use_bn 
def forward(self, h: torch.Tensor, s: torch.Tensor): 
x = torch.cat([h, s], dim=-1) 
x = self.fc1(x) 
x = torch.relu(x) 
x = self.fc2(x) 

x = x.view(-1, 256, self.output_shape[1] // 16, self.output_shape[2] // 16) 

if self.use_bn: 
x = torch.relu(self.bn1(self.conv1(x))) 
x = torch.relu(self.bn2(self.conv2(x))) 
x = torch.relu(self.bn3(self.conv3(x))) 

else: 
x = torch.relu(self.conv1(x)) 
x = torch.relu(self.conv2(x)) 
x = torch.relu(self.conv3(x)) 

x = self.conv4(x) 

return x

奖励模型实现

奖励模型采用了一个三层前馈神经网络结构,用于将随机状态 s 和确定性状态 h 映射到正态分布参数,进而通过采样获得奖励预测。

class RewardModel(nn.Module): 
def __init__(self, hidden_dim: int, state_dim: int): 
super(RewardModel, self).__init__() 

self.fc1 = nn.Linear(hidden_dim + state_dim, hidden_dim) 
self.fc2 = nn.Linear(hidden_dim, hidden_dim) 
self.fc3 = nn.Linear(hidden_dim, 2) 

def forward(self, h: torch.Tensor, s: torch.Tensor): 
x = torch.cat([h, s], dim=-1) 
x = torch.relu(self.fc1(x)) 
x = torch.relu(self.fc2(x)) 
x = self.fc3(x) 

return x

动态模型的实现

动态模型是 RSSM 架构中最复杂的组件,需要同时处理先验和后验状态转移模型:

  1. 后验转移模型:在能够访问真实观察的情况下使用(主要在训练阶段),用于在给定观察和历史状态的条件下近似随机状态的后验分布。
  2. 先验转移模型:用于近似先验状态分布,仅依赖于前一时刻状态,不依赖于观察。这在无法获取后验观察的推理阶段使用。

这两个模型均通过单层前馈网络进行参数化,输出各自正态分布的均值和对数方差,用于状态 s 的采样。该实现采用了简单的网络结构,但可以根据需要扩展为更复杂的架构。

确定性状态采用门控循环单元(GRU)实现。其输入包括:

  • 前一时刻的隐藏状态
  • 独热编码动作
  • 前一时刻随机状态 s(根据是否可以获取观察来选择使用后验或先验状态)

这些输入信息足以让模型了解动作历史和系统状态。以下是具体实现代码:

class DynamicsModel(nn.Module): 
def __init__(self, hidden_dim: int, action_dim: int, state_dim: int, embedding_dim: int, rnn_layer: int = 1): 
super(DynamicsModel, self).__init__() 

self.hidden_dim = hidden_dim 

# 递归层实现,支持多层 GRU
self.rnn = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(rnn_layer)]) 

# 状态动作投影层
self.project_state_action = nn.Linear(action_dim + state_dim, hidden_dim) 

# 先验网络:输出正态分布参数
self.prior = nn.Linear(hidden_dim, state_dim * 2) 
self.project_hidden_action = nn.Linear(hidden_dim + action_dim, hidden_dim) 

# 后验网络:输出正态分布参数
self.posterior = nn.Linear(hidden_dim, state_dim * 2) 
self.project_hidden_obs = nn.Linear(hidden_dim + embedding_dim, hidden_dim) 

self.state_dim = state_dim 
self.act_fn = nn.ReLU() 

def forward(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor, 
obs: torch.Tensor = None, dones: torch.Tensor = None): 
""" 
动态模型的前向传播
参数: 
prev_hidden: RNN的前一隐藏状态,形状 (batch_size, hidden_dim) 
prev_state: 前一随机状态,形状 (batch_size, state_dim) 
actions: 独热编码动作序列,形状 (sequence_length, batch_size, action_dim) 
obs: 编码器输出的观察嵌入,形状 (sequence_length, batch_size, embedding_dim) 
dones: 终止状态标志
""" 
B, T, _ = actions.size() # 用于无观察访问时的推理

# 初始化存储列表
hiddens_list = [] 
posterior_means_list = [] 
posterior_logvars_list = [] 
prior_means_list = [] 
prior_logvars_list = [] 
prior_states_list = [] 
posterior_states_list = [] 

# 存储初始状态
hiddens_list.append(prev_hidden.unsqueeze(1)) 
prior_states_list.append(prev_state.unsqueeze(1)) 
posterior_states_list.append(prev_state.unsqueeze(1)) 

# 时序展开
for t in range(T - 1): 
# 提取当前时刻状态和动作
action_t = actions[:, t, :] 
obs_t = obs[:, t, :] if obs is not None else torch.zeros(B, self.embedding_dim, device=actions.device) 
state_t = posterior_states_list[-1][:, 0, :] if obs is not None else prior_states_list[-1][:, 0, :] 
state_t = state_t if dones is None else state_t * (1 - dones[:, t, :]) 
hidden_t = hiddens_list[-1][:, 0, :] 

# 状态动作组合
state_action = torch.cat([state_t, action_t], dim=-1) 
state_action = self.act_fn(self.project_state_action(state_action)) 

# RNN 状态更新
for i in range(len(self.rnn)): 
hidden_t = self.rnn[i](state_action, hidden_t) 

# 先验分布计算
hidden_action = torch.cat([hidden_t, action_t], dim=-1) 
hidden_action = self.act_fn(self.project_hidden_action(hidden_action)) 
prior_params = self.prior(hidden_action) 
prior_mean, prior_logvar = torch.chunk(prior_params, 2, dim=-1) 

# 从先验分布采样
prior_dist = torch.distributions.Normal(prior_mean, torch.exp(F.softplus(prior_logvar))) 
prior_state_t = prior_dist.rsample() 

# 后验分布计算
if obs is None: 
posterior_mean = prior_mean 
posterior_logvar = prior_logvar 
else: 
hidden_obs = torch.cat([hidden_t, obs_t], dim=-1) 
hidden_obs = self.act_fn(self.project_hidden_obs(hidden_obs)) 
posterior_params = self.posterior(hidden_obs) 
posterior_mean, posterior_logvar = torch.chunk(posterior_params, 2, dim=-1) 

# 从后验分布采样
posterior_dist = torch.distributions.Normal(posterior_mean, torch.exp(F.softplus(posterior_logvar))) 
posterior_state_t = posterior_dist.rsample() 

# 保存状态
posterior_means_list.append(posterior_mean.unsqueeze(1)) 
posterior_logvars_list.append(posterior_logvar.unsqueeze(1)) 
prior_means_list.append(prior_mean.unsqueeze(1)) 
prior_logvars_list.append(prior_logvar.unsqueeze(1)) 
prior_states_list.append(prior_state_t.unsqueeze(1)) 
posterior_states_list.append(posterior_state_t.unsqueeze(1)) 
hiddens_list.append(hidden_t.unsqueeze(1)) 

# 合并时序数据
hiddens = torch.cat(hiddens_list, dim=1) 
prior_states = torch.cat(prior_states_list, dim=1) 
posterior_states = torch.cat(posterior_states_list, dim=1) 
prior_means = torch.cat(prior_means_list, dim=1) 
prior_logvars = torch.cat(prior_logvars_list, dim=1) 
posterior_means = torch.cat(posterior_means_list, dim=1) 
posterior_logvars = torch.cat(posterior_logvars_list, dim=1) 

return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

需要特别注意的是,这里的观察输入并非原始观察数据,而是经过编码器处理后的嵌入表示。这种设计能够有效降低计算复杂度并提升模型的泛化能力。

RSSM 整体架构

将前述组件整合为完整的 RSSM 模型。其核心是 generate_rollout 方法,负责调用动态模型并生成环境动态的潜在表示序列。对于没有历史潜在状态的情况(通常发生在轨迹开始时),该方法会进行必要的初始化。下面是完整的实现代码:

class RSSM: 
def __init__(self, 
encoder: EncoderCNN, 
decoder: DecoderCNN, 
reward_model: RewardModel, 
dynamics_model: nn.Module, 
hidden_dim: int, 
state_dim: int, 
action_dim: int, 
embedding_dim: int, 
device: str = "mps"): 
""" 
循环状态空间模型(RSSM)实现

参数:
encoder: 确定性状态编码器
decoder: 观察重构解码器
reward_model: 奖励预测模型
dynamics_model: 状态动态模型
hidden_dim: RNN 隐藏层维度
state_dim: 随机状态维度
action_dim: 动作空间维度
embedding_dim: 观察嵌入维度
device: 计算设备
""" 
super(RSSM, self).__init__() 

# 模型组件初始化
self.dynamics = dynamics_model 
self.encoder = encoder 
self.decoder = decoder 
self.reward_model = reward_model 

# 维度参数存储
self.hidden_dim = hidden_dim 
self.state_dim = state_dim 
self.action_dim = action_dim 
self.embedding_dim = embedding_dim 

# 模型迁移至指定设备
self.dynamics.to(device) 
self.encoder.to(device) 
self.decoder.to(device) 
self.reward_model.to(device) 

def generate_rollout(self, actions: torch.Tensor, hiddens: torch.Tensor = None, states: torch.Tensor = None, 
obs: torch.Tensor = None, dones: torch.Tensor = None): 
"""
生成状态序列展开

参数:
actions: 动作序列
hiddens: 初始隐藏状态(可选)
states: 初始随机状态(可选)
obs: 观察序列(可选)
dones: 终止标志序列

返回:
完整的状态展开序列
"""
# 状态初始化
if hiddens is None: 
hiddens = torch.zeros(actions.size(0), self.hidden_dim).to(actions.device) 

if states is None: 
states = torch.zeros(actions.size(0), self.state_dim).to(actions.device) 

# 执行动态模型展开
dynamics_result = self.dynamics(hiddens, states, actions, obs, dones) 
hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars = dynamics_result 

return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars 

def train(self): 
"""启用训练模式"""
self.dynamics.train() 
self.encoder.train() 
self.decoder.train() 
self.reward_model.train() 

def eval(self): 
"""启用评估模式"""
self.dynamics.eval() 
self.encoder.eval() 
self.decoder.eval() 
self.reward_model.eval() 

def encode(self, obs: torch.Tensor): 
"""观察编码"""
return self.encoder(obs) 

def decode(self, state: torch.Tensor): 
"""状态解码为观察"""
return self.decoder(state) 

def predict_reward(self, h: torch.Tensor, s: torch.Tensor): 
"""奖励预测"""
return self.reward_model(h, s) 

def parameters(self): 
"""返回所有可训练参数"""
return list(self.dynamics.parameters()) + list(self.encoder.parameters()) + \
list(self.decoder.parameters()) + list(self.reward_model.parameters()) 

def save(self, path: str): 
"""模型状态保存"""
torch.save({ 
"dynamics": self.dynamics.state_dict(), 
"encoder": self.encoder.state_dict(), 
"decoder": self.decoder.state_dict(), 
"reward_model": self.reward_model.state_dict() 
}, path) 

def load(self, path: str): 
"""模型状态加载"""
checkpoint = torch.load(path) 
self.dynamics.load_state_dict(checkpoint["dynamics"]) 
self.encoder.load_state_dict(checkpoint["encoder"]) 
self.decoder.load_state_dict(checkpoint["decoder"]) 
self.reward_model.load_state_dict(checkpoint["reward_model"])

这个实现提供了一个完整的 RSSM 框架,包含了模型的训练、评估、状态保存和加载等基本功能。该框架可以作为基础结构,根据具体应用场景进行扩展和优化。

训练系统设计

RSSM 的训练系统主要包含两个核心组件:经验回放缓冲区(Experience Replay Buffer)和智能体(Agent)。其中,缓冲区负责存储历史经验数据用于训练,而智能体则作为环境与 RSSM 之间的接口,实现数据收集策略。

经验回放缓冲区实现

缓冲区采用循环队列结构,用于存储和管理观察、动作、奖励和终止状态等数据。通过 sample 方法可以随机采样训练序列。

class Buffer: 
def __init__(self, buffer_size: int, obs_shape: tuple, action_shape: tuple, device: torch.device): 
"""
经验回放缓冲区初始化

参数:
buffer_size: 缓冲区容量
obs_shape: 观察数据维度
action_shape: 动作数据维度
device: 计算设备
"""
self.buffer_size = buffer_size 
self.obs_buffer = np.zeros((buffer_size, *obs_shape), dtype=np.float32) 
self.action_buffer = np.zeros((buffer_size, *action_shape), dtype=np.int32) 
self.reward_buffer = np.zeros((buffer_size, 1), dtype=np.float32) 
self.done_buffer = np.zeros((buffer_size, 1), dtype=np.bool_) 

self.device = device 
self.idx = 0 

def add(self, obs: torch.Tensor, action: int, reward: float, done: bool): 
"""
添加单步经验数据
"""
self.obs_buffer[self.idx] = obs 
self.action_buffer[self.idx] = action 
self.reward_buffer[self.idx] = reward 
self.done_buffer[self.idx] = done 
self.idx = (self.idx + 1) % self.buffer_size 
def sample(self, batch_size: int, sequence_length: int): 
"""
随机采样经验序列

参数:
batch_size: 批量大小
sequence_length: 序列长度

返回:
经验数据元组 (observations, actions, rewards, dones)
"""
# 随机选择序列起始位置
starting_idxs = np.random.randint(0, (self.idx % self.buffer_size) - sequence_length, (batch_size,)) 

# 构建完整序列索引
index_tensor = np.stack([np.arange(start, start + sequence_length) for start in starting_idxs]) 

# 提取数据序列
obs_sequence = self.obs_buffer[index_tensor] 
action_sequence = self.action_buffer[index_tensor] 
reward_sequence = self.reward_buffer[index_tensor] 
done_sequence = self.done_buffer[index_tensor] 

return obs_sequence, action_sequence, reward_sequence, done_sequence 
def save(self, path: str): 
"""保存缓冲区数据"""
np.savez(path, obs_buffer=self.obs_buffer, action_buffer=self.action_buffer, 
reward_buffer=self.reward_buffer, done_buffer=self.done_buffer, idx=self.idx) 

def load(self, path: str): 
"""加载缓冲区数据"""
data = np.load(path) 
self.obs_buffer = data["obs_buffer"] 
self.action_buffer = data["action_buffer"] 
self.reward_buffer = data["reward_buffer"] 
self.done_buffer = data["done_buffer"] 
self.idx = data["idx"]

智能体设计

智能体实现了数据收集和规划功能。当前实现采用了简单的随机策略进行数据收集,但该框架支持扩展更复杂的策略。

class Policy(ABC): 
"""策略基类"""
@abstractmethod 
def __call__(self, obs): 
pass 

class RandomPolicy(Policy): 
"""随机采样策略"""
def __init__(self, env: Env): 
self.env = env 

def __call__(self, obs): 
return self.env.action_space.sample() 
class Agent: 
def __init__(self, env: Env, rssm: RSSM, buffer_size: int = 100000, 
collection_policy: str = "random", device="mps"): 
"""
智能体初始化

参数:
env: 环境实例
rssm: RSSM模型实例
buffer_size: 经验缓冲区大小
collection_policy: 数据收集策略类型
device: 计算设备
"""
self.env = env 
# 策略选择
match collection_policy: 
case "random": 
self.rollout_policy = RandomPolicy(env) 
case _: 
raise ValueError("Invalid rollout policy") 

self.buffer = Buffer(buffer_size, env.observation_space.shape, 
env.action_space.shape, device=device) 
self.rssm = rssm 

def data_collection_action(self, obs): 
"""执行数据收集动作"""
return self.rollout_policy(obs) 

def collect_data(self, num_steps: int): 
"""
收集训练数据

参数:
num_steps: 收集步数
"""
obs = self.env.reset() 
done = False 

iterator = tqdm(range(num_steps), desc="Data Collection") 
for _ in iterator: 
action = self.data_collection_action(obs) 
next_obs, reward, done, _, _ = self.env.step(action) 
self.buffer.add(next_obs, action, reward, done) 
obs = next_obs 
if done: 
obs = self.env.reset() 

def imagine_rollout(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, 
actions: torch.Tensor): 
"""
执行想象展开

参数:
prev_hidden: 前一隐藏状态
prev_state: 前一随机状态
actions: 动作序列

返回:
完整的模型输出,包括隐藏状态、先验状态、后验状态等
"""
hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
posterior_means, posterior_logvars = self.rssm.generate_rollout(
actions, prev_hidden, prev_state) 

# 在想象阶段使用先验状态预测奖励
rewards = self.rssm.predict_reward(hiddens, prior_states) 

return hiddens, prior_states, posterior_states, prior_means, \
prior_logvars, posterior_means, posterior_logvars, rewards 

def plan(self, num_steps: int, prev_hidden: torch.Tensor, 
prev_state: torch.Tensor, actions: torch.Tensor): 
"""
执行规划

参数:
num_steps: 规划步数
prev_hidden: 初始隐藏状态
prev_state: 初始随机状态
actions: 动作序列

返回:
规划得到的隐藏状态和先验状态序列
"""
hidden_states = [] 
prior_states = [] 

hiddens = prev_hidden 
states = prev_state 

for _ in range(num_steps): 
hiddens, states, _, _, _, _, _, _ = self.imagine_rollout(
hiddens, states, actions) 
hidden_states.append(hiddens) 
prior_states.append(states) 

hidden_states = torch.stack(hidden_states) 
prior_states = torch.stack(prior_states) 

return hidden_states, prior_states

这部分实现提供了完整的数据管理和智能体交互框架。通过经验回放缓冲区,可以高效地存储和重用历史数据;通过智能体的抽象策略接口,可以方便地扩展不同的数据收集策略。同时智能体还实现了基于模型的想象展开和规划功能,为后续的决策制定提供了基础。

训练器实现与实验

训练器设计

训练器是 RSSM 实现中的最后一个关键组件,负责协调模型训练过程。训练器接收 RSSM 模型、智能体、优化器等组件,并实现具体的训练逻辑。

logging.basicConfig( 
level=logging.INFO, 
format="%(asctime)s - %(levelname)s - %(message)s", 
handlers=[ 
logging.StreamHandler(), # 控制台输出
logging.FileHandler("training.log", mode="w") # 文件输出
] 
) 

logger = logging.getLogger(__name__) 
class Trainer: 
def __init__(self, rssm: RSSM, agent: Agent, optimizer: torch.optim.Optimizer, 
device: torch.device): 
"""
训练器初始化

参数:
rssm: RSSM 模型实例
agent: 智能体实例
optimizer: 优化器实例
device: 计算设备
"""
self.rssm = rssm 
self.optimizer = optimizer 
self.device = device 
self.agent = agent 
self.writer = SummaryWriter() # tensorboard 日志记录器

def train_batch(self, batch_size: int, seq_len: int, iteration: int, 
save_images: bool = False): 
"""
单批次训练

参数:
batch_size: 批量大小
seq_len: 序列长度
iteration: 当前迭代次数
save_images: 是否保存重建图像
"""
# 采样训练数据
obs, actions, rewards, dones = self.agent.buffer.sample(batch_size, seq_len) 

# 数据预处理
actions = torch.tensor(actions).long().to(self.device) 
actions = F.one_hot(actions, self.rssm.action_dim).float() 
obs = torch.tensor(obs, requires_grad=True).float().to(self.device) 
rewards = torch.tensor(rewards, requires_grad=True).float().to(self.device) 
dones = torch.tensor(dones).float().to(self.device) 

# 观察编码
encoded_obs = self.rssm.encoder(obs.reshape(-1, *obs.shape[2:]).permute(0, 3, 1, 2)) 
encoded_obs = encoded_obs.reshape(batch_size, seq_len, -1) 

# 执行 RSSM 展开
rollout = self.rssm.generate_rollout(actions, obs=encoded_obs, dones=dones) 
hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
posterior_means, posterior_logvars = rollout 

# 重构观察
hiddens_reshaped = hiddens.reshape(batch_size * seq_len, -1) 
posterior_states_reshaped = posterior_states.reshape(batch_size * seq_len, -1) 
decoded_obs = self.rssm.decoder(hiddens_reshaped, posterior_states_reshaped) 
decoded_obs = decoded_obs.reshape(batch_size, seq_len, *obs.shape[-3:]) 

# 奖励预测
reward_params = self.rssm.reward_model(hiddens, posterior_states) 
mean, logvar = torch.chunk(reward_params, 2, dim=-1) 
logvar = F.softplus(logvar) 
reward_dist = Normal(mean, torch.exp(logvar)) 
predicted_rewards = reward_dist.rsample() 

# 可视化
if save_images: 
batch_idx = np.random.randint(0, batch_size) 
seq_idx = np.random.randint(0, seq_len - 3) 
fig = self._visualize(obs, decoded_obs, rewards, predicted_rewards, 
batch_idx, seq_idx, iteration, grayscale=True) 
if not os.path.exists("reconstructions"): 
os.makedirs("reconstructions") 
fig.savefig(f"reconstructions/iteration_{iteration}.png") 
self.writer.add_figure("Reconstructions", fig, iteration) 
plt.close(fig) 

# 计算损失
reconstruction_loss = self._reconstruction_loss(decoded_obs, obs) 
kl_loss = self._kl_loss(prior_means, F.softplus(prior_logvars), 
posterior_means, F.softplus(posterior_logvars)) 
reward_loss = self._reward_loss(rewards, predicted_rewards) 

loss = reconstruction_loss + kl_loss + reward_loss 

# 反向传播和优化
self.optimizer.zero_grad() 
loss.backward() 
nn.utils.clip_grad_norm_(self.rssm.parameters(), 1, norm_type=2) 
self.optimizer.step() 

return loss.item(), reconstruction_loss.item(), kl_loss.item(), reward_loss.item() 

def train(self, iterations: int, batch_size: int, seq_len: int): 
"""
执行完整训练过程

参数:
iterations: 迭代总次数
batch_size: 批量大小
seq_len: 序列长度
"""
self.rssm.train() 
iterator = tqdm(range(iterations), desc="Training", total=iterations) 
losses = [] 
infos = [] 
last_loss = float("inf") 

for i in iterator: 
# 执行单批次训练
loss, reconstruction_loss, kl_loss, reward_loss = self.train_batch(
batch_size, seq_len, i, save_images=i % 100 == 0) 

# 记录训练指标
self.writer.add_scalar("Loss", loss, i) 
self.writer.add_scalar("Reconstruction Loss", reconstruction_loss, i) 
self.writer.add_scalar("KL Loss", kl_loss, i) 
self.writer.add_scalar("Reward Loss", reward_loss, i) 

# 保存最佳模型
if loss < last_loss: 
self.rssm.save("rssm.pth") 
last_loss = loss 

# 记录详细信息
info = { 
"Loss": loss, 
"Reconstruction Loss": reconstruction_loss, 
"KL Loss": kl_loss, 
"Reward Loss": reward_loss 
} 
losses.append(loss) 
infos.append(info) 

# 定期输出训练状态
if i % 10 == 0: 
logger.info("\n----------------------------") 
logger.info(f"Iteration: {i}") 
logger.info(f"Loss: {loss:.4f}") 
logger.info(f"Running average last 20 losses: {sum(losses[-20:]) / 20: .4f}") 
logger.info(f"Reconstruction Loss: {reconstruction_loss:.4f}") 
logger.info(f"KL Loss: {kl_loss:.4f}") 
logger.info(f"Reward Loss: {reward_loss:.4f}")
### 实验示例
以下是一个在 CarRacing 环境中训练 RSSM 的完整示例:
```python
# 环境初始化
env = make_env("CarRacing-v2", render_mode="rgb_array", continuous=False, grayscale=True) 
# 模型参数设置
hidden_size = 1024 
embedding_dim = 1024 
state_dim = 512 
# 模型组件实例化
encoder = EncoderCNN(in_channels=1, embedding_dim=embedding_dim) 
decoder = DecoderCNN(hidden_size=hidden_size, state_size=state_dim, 
embedding_size=embedding_dim, output_shape=(1,128,128)) 
reward_model = RewardModel(hidden_dim=hidden_size, state_dim=state_dim) 
dynamics_model = DynamicsModel(hidden_dim=hidden_size, state_dim=state_dim, 
action_dim=5, embedding_dim=embedding_dim) 
# RSSM 模型构建
rssm = RSSM(dynamics_model=dynamics_model, 
encoder=encoder, 
decoder=decoder, 
reward_model=reward_model, 
hidden_dim=hidden_size, 
state_dim=state_dim, 
action_dim=5, 
embedding_dim=embedding_dim) 
# 训练设置
optimizer = torch.optim.Adam(rssm.parameters(), lr=1e-3) 
agent = Agent(env, rssm) 
trainer = Trainer(rssm, agent, optimizer=optimizer, device="cuda") 
# 数据收集和训练
trainer.collect_data(20000) # 收集 20000 步经验数据
trainer.save_buffer("buffer.npz") # 保存经验缓冲区
trainer.train(10000, 32, 20) # 执行 10000 次迭代训练

总结

本文详细介绍了基于 PyTorch 实现 RSSM 的完整过程。RSSM 的架构相比传统的 VAE 或 RNN 更为复杂,这主要源于其混合了随机和确定性状态的特性。通过手动实现这一架构,我们可以深入理解其背后的理论基础及其强大之处。RSSM 能够递归地生成未来潜在状态轨迹,这为智能体的行为规划提供了基础。

实现的优点在于其计算负载适中,可以在单个消费级 GPU 上进行训练,在有充足时间的情况下甚至可以在 CPU 上运行。这一工作基于论文《Learning Latent Dynamics for Planning from Pixels》,该论文为 RSSM 类动态模型奠定了基础。后续的研究工作如《Dream to Control: Learning Behaviors by Latent Imagination》进一步发展了这一架构。这些改进的架构将在未来的研究中深入探讨,因为它们对理解 MBRL 方法提供了重要的见解。

作者:Lukas Bierling

相关推荐

从IDEA开始,迈进GO语言之门(idea got)

前言笔者在学习GO语言编程的时候,GO语言在国内还没有像JAVA/Php/Python那样普及,绕了不少的弯路,要开始入门学习一门编程语言,最好就先从选择一个好的编程语言的开发环境开始,有了这个开发环...

基于SpringBoot+MyBatis的私人影院java网上购票jsp源代码Mysql

本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于SpringBoot...

基于springboot的个人服装管理系统java网上商城jsp源代码mysql

本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于springboot...

基于springboot的美食网站Java食品销售jsp源代码Mysql

本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于springboot...

贸易管理进销存springboot云管货管账分析java jsp源代码mysql

本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述贸易管理进销存spring...

SpringBoot+VUE员工信息管理系统Java人员管理jsp源代码Mysql

本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍SpringBoot+V...

目前见过最牛的一个SpringBoot商城项目(附源码)还有人没用过吗

帮粉丝找了一个基于SpringBoot的天猫商城项目,快速部署运行,所用技术:MySQL,Druid,Log4j2,Maven,Echarts,Bootstrap...免费给大家分享出来前台演示...

SpringBoot+Mysql实现的手机商城附带源码演示导入视频

今天为大家带来的是基于SpringBoot+JPA+Thymeleaf框架的手机商城管理系统,商城系统分为前台和后台、前台用的是Bootstrap框架后台用的是SpringBoot+JPA都是现在主...

全网首发!马士兵内部共享—1658页《Java面试突击核心讲》

又是一年一度的“金九银十”秋招大热门,为助力广大程序员朋友“面试造火箭”,小编今天给大家分享的便是这份马士兵内部的面试神技——1658页《Java面试突击核心讲》!...

SpringBoot数据库操作的应用(springboot与数据库交互)

1.JDBC+HikariDataSource...

SpringBoot 整合 Flink 实时同步 MySQL

1、需求在Flink发布SpringBoot打包的jar包能够实时同步MySQL表,做到原表进行新增、修改、删除的时候目标表都能对应同步。...

SpringBoot + Mybatis + Shiro + mysql + redis智能平台源码分享

后端技术栈基于SpringBoot+Mybatis+Shiro+mysql+redis构建的智慧云智能教育平台基于数据驱动视图的理念封装element-ui,即使没有vue的使...

Springboot+Mysql舞蹈课程在线预约系统源码附带视频运行教程

今天发布的是由【猿来入此】的优秀学员独立做的一个基于springboot脚手架的Springboot+Mysql舞蹈课程在线预约系统,系统项目源代码在【猿来入此】获取!https://www.yuan...

SpringBoot+Mysql在线众筹系统源码+讲解视频+开发文档(参考论文

今天发布的是由【猿来入此】的优秀学员独立做的一个基于springboot脚手架的在线众筹管理系统,主要实现了普通用户在线参与众筹基本操作流程的全部功能,系统分普通用户、超级管理员等角色,除基础脚手架外...

Docker一键部署 SpringBoot 应用的方法,贼快贼好用

这两天发现个Gradle插件,支持一键打包、推送Docker镜像。今天我们来讲讲这个插件,希望对大家有所帮助!GradleDockerPlugin简介...

取消回复欢迎 发表评论: