Table of Contents
- Preface
- 1. def __init__(self, **env_params):
- Purpose
- Code
- 2. use_saved_problems(self, filename, device)
- Purpose
- Code
- 3. load_problems(self, batch_size, aug_factor=1)
- Purpose
- Code
- How use_saved_problems and load_problems relate
- 4. reset(self)
- Purpose
- Code
- 5. pre_step(self)
- Purpose
- Code
- 6. step(self, selected)
- Purpose
- Code
- 7. _get_travel_distance(self)
- Purpose
- Question
- What is "rolling"?
- Code
- Appendix
- Full source: CVRPEnv.py
Preface
Study notes on the CVRPEnv class in CVRPEnv.py.
The code is located at:
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py
1. def __init__(self, **env_params):
Purpose
This is the initialization method of the CVRPEnv class. It sets up the parameters and state variables of the **capacitated vehicle routing problem (CVRP)** environment.
Code
def __init__(self, **env_params):
    # Const @INIT
    ####################################
    self.env_params = env_params
    self.problem_size = env_params['problem_size']  # problem size (number of customer nodes)
    self.pomo_size = env_params['pomo_size']  # number of POMO rollouts
    self.FLAG__use_saved_problems = False  # whether to use pre-saved problem instances
    self.saved_depot_xy = None  # depot coordinates
    self.saved_node_xy = None  # node (customer) coordinates
    self.saved_node_demand = None  # node demands
    self.saved_index = None  # read cursor into the saved instances

    # Const @Load_Problem
    ####################################
    self.batch_size = None
    self.BATCH_IDX = None
    self.POMO_IDX = None
    # IDX.shape: (batch, pomo)
    self.depot_node_xy = None
    # shape: (batch, problem+1, 2)
    self.depot_node_demand = None
    # shape: (batch, problem+1)

    # Dynamic-1
    ####################################
    self.selected_count = None
    self.current_node = None
    # shape: (batch, pomo)
    self.selected_node_list = None
    # shape: (batch, pomo, 0~)

    # Dynamic-2
    ####################################
    self.at_the_depot = None
    # shape: (batch, pomo)
    self.load = None
    # shape: (batch, pomo)
    self.visited_ninf_flag = None
    # shape: (batch, pomo, problem+2)
    self.ninf_mask = None
    # shape: (batch, pomo, problem+2)
    self.finished = None
    # shape: (batch, pomo)

    # states to return
    ####################################
    self.reset_state = Reset_State()
    self.step_state = Step_State()

    # regret
    ####################################
    self.mode = None
    self.last_current_node = None
    self.last_load = None
    self.regret_count = None
    self.regret_mask_matrix = None
    self.add_mask_matrix = None
    self.time_step = 0
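The constructor only stores configuration and placeholders; all tensors are allocated later by load_problems and reset. A minimal construction sketch (the concrete sizes here are illustrative, not values mandated by the repository):

```python
env_params = {
    'problem_size': 20,  # number of customer nodes (illustrative)
    'pomo_size': 20,     # number of parallel POMO rollouts per instance (illustrative)
}
env = CVRPEnv(**env_params)
```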
2. use_saved_problems(self, filename, device)
Purpose
This function loads pre-saved problem instances and stores their data in class attributes. It reads the problem data from the given file, including the depot coordinates (depot_xy), the node coordinates (node_xy) and the node demands (node_demand), and keeps them on the instance for later use.
Code
def use_saved_problems(self, filename, device):
    self.FLAG__use_saved_problems = True
    loaded_dict = torch.load(filename, map_location=device)  # load the saved problem instances
    self.saved_depot_xy = loaded_dict['depot_xy']  # unpack the loaded data
    self.saved_node_xy = loaded_dict['node_xy']
    self.saved_node_demand = loaded_dict['node_demand']
    self.saved_index = 0  # reset the read cursor
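For reference, the file that use_saved_problems reads is just a dict written with torch.save() containing the three keys above. A minimal sketch of producing such a file (the sizes, the demand distribution, and the filename are illustrative assumptions, not the repository's data-generation code):

```python
import torch

batch = 1000         # number of saved instances (assumption)
problem_size = 20    # number of customer nodes (assumption)

saved = {
    'depot_xy': torch.rand(batch, 1, 2),            # depot coordinates in the unit square
    'node_xy': torch.rand(batch, problem_size, 2),  # customer coordinates
    'node_demand': torch.randint(1, 10, (batch, problem_size)).float() / 50.0,  # normalized demands
}
torch.save(saved, 'saved_cvrp20_instances.pt')

# later, before calling load_problems:
# env.use_saved_problems('saved_cvrp20_instances.pt', device='cpu')
```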
3. load_problems(self, batch_size, aug_factor=1)
Purpose
This function loads **capacitated vehicle routing problem (CVRP)** instances. It:
- generates problem instances on the fly, or extracts them from pre-loaded data
- applies data augmentation
- initializes the index and state variables
- stores everything in the environment

How it works:
- If self.FLAG__use_saved_problems is True, it extracts the data loaded earlier by use_saved_problems (self.saved_depot_xy, self.saved_node_xy, self.saved_node_demand) and advances the cursor self.saved_index.
- If self.FLAG__use_saved_problems is False, it generates instances dynamically, calling get_random_problems() with the given batch_size and problem_size.

load_problems also supports data augmentation via aug_factor (currently only aug_factor=8 is implemented): the batch is expanded 8-fold and the instance coordinates are transformed accordingly.
Code
def load_problems(self, batch_size, aug_factor=1):
    self.batch_size = batch_size

    # load the problem instances
    if not self.FLAG__use_saved_problems:
        # dynamic-generation mode
        depot_xy, node_xy, node_demand = get_random_problems(batch_size, self.problem_size)
    else:
        # preloaded mode: slice instances out of the saved data
        depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index+batch_size]
        node_xy = self.saved_node_xy[self.saved_index:self.saved_index+batch_size]
        node_demand = self.saved_node_demand[self.saved_index:self.saved_index+batch_size]
        self.saved_index += batch_size

    # data augmentation
    if aug_factor > 1:
        if aug_factor == 8:
            self.batch_size = self.batch_size * 8
            depot_xy = augment_xy_data_by_8_fold(depot_xy)
            node_xy = augment_xy_data_by_8_fold(node_xy)
            node_demand = node_demand.repeat(8, 1)
        else:
            raise NotImplementedError

    # concatenate depot and node data
    self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
    # shape: (batch, problem+1, 2)
    depot_demand = torch.zeros(size=(self.batch_size, 1))
    # shape: (batch, 1)
    self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
    # shape: (batch, problem+1)

    # initialize batch and POMO indices
    self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
    self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

    # update the reset state and the step state
    self.reset_state.depot_xy = depot_xy
    self.reset_state.node_xy = node_xy
    self.reset_state.node_demand = node_demand
    self.step_state.BATCH_IDX = self.BATCH_IDX
    self.step_state.POMO_IDX = self.POMO_IDX
How use_saved_problems and load_problems relate
- use_saved_problems is the precondition for loading saved data. It reads a previously saved problem file (e.g. one written with torch.save()) and stores its contents in dedicated environment attributes (self.saved_depot_xy, self.saved_node_xy, ...). Once it has run, it sets self.FLAG__use_saved_problems = True, which tells the environment to draw problem instances from the saved data in subsequent operations instead of regenerating them. Note, however, that use_saved_problems does not itself produce the per-batch problem instances; it only stages the data and raises the flag for the subsequent loading step (load_problems).
- load_problems consumes the data that use_saved_problems staged. It is the main entry point for data loading and problem generation, and uses self.FLAG__use_saved_problems to decide whether to extract instances from the saved data or to generate new random instances. When self.FLAG__use_saved_problems = True, it reads from self.saved_depot_xy, self.saved_node_xy and self.saved_node_demand and performs the remaining per-batch processing (advancing the index, data augmentation, and so on). When self.FLAG__use_saved_problems = False, it calls get_random_problems() to generate the problem data dynamically (see the combined sketch below).
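Putting the two modes together, a typical call sequence looks like this (batch size and file name are illustrative; the file format matches the sketch in section 2):

```python
env = CVRPEnv(problem_size=20, pomo_size=20)

# mode 1: generate random instances
env.load_problems(batch_size=64)

# mode 2: consume pre-saved instances
env.use_saved_problems('saved_cvrp20_instances.pt', device='cpu')
env.load_problems(batch_size=64)   # slices instances 0..63
env.load_problems(batch_size=64)   # slices instances 64..127 (saved_index advanced)
```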
4. reset(self)
Purpose
reset restores the environment's state variables to their initial values. It is typically called at the start of each new training episode or experiment, and guarantees that the environment is in a known initial state so that the agent starts its decision-making and learning from a clean slate.
Code
def reset(self):
    # reset the selection count
    self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)
    # reset the current node
    self.current_node = None
    # shape: (batch, pomo)
    # reset the list of selected nodes
    self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
    # shape: (batch, pomo, 0~)

    # initialize the at-the-depot flags
    self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
    # shape: (batch, pomo)
    # initialize the load
    self.load = torch.ones(size=(self.batch_size, self.pomo_size))
    # shape: (batch, pomo)
    # initialize the visited mask
    self.visited_ninf_flag = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
    self.visited_ninf_flag[:, :, self.problem_size+1] = float('-inf')
    # shape: (batch, pomo, problem+2)
    # initialize the -inf mask
    self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
    self.ninf_mask[:, :, self.problem_size+1] = float('-inf')
    # shape: (batch, pomo, problem+2)
    # initialize the finished flags
    self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
    # shape: (batch, pomo)

    # initialize the remaining state variables
    self.regret_count = torch.zeros((self.batch_size, self.pomo_size))
    self.mode = torch.full((self.batch_size, self.pomo_size), 0)
    self.last_current_node = None
    self.last_load = None
    self.time_step = 0

    reward = None
    done = False
    return self.reset_state, reward, done
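Note that both masks are allocated with problem_size + 2 entries rather than problem_size + 1: index 0 is the depot, indices 1..problem_size are the customers, and index problem_size + 1 is the extra "regret" action that LCH-Regret adds. The regret slot starts masked with -inf and is only opened up later in step(). A tiny standalone illustration of this layout:

```python
import torch

batch, pomo, problem_size = 2, 3, 5
mask = torch.zeros(batch, pomo, problem_size + 2)  # depot + customers + regret slot
mask[:, :, problem_size + 1] = float('-inf')       # regret action disabled at the start
print(mask[0, 0])  # tensor([0., 0., 0., 0., 0., 0., -inf])
```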
5. pre_step(self)
Purpose
pre_step is a preprocessing step that fills in the state information needed before each time step.
In a reinforcement-learning environment, each time step updates the state based on the current state and action; pre_step supplies the state that the subsequent decision and learning process consumes.
Code
def pre_step(self):
    # reset selected_count
    self.step_state.selected_count = 0
    # copy the current load
    self.step_state.load = self.load
    # set the current node
    self.step_state.current_node = self.current_node
    # update the mask
    self.step_state.ninf_mask = self.ninf_mask
    # return the step state, reward and done flag
    reward = None
    done = False
    return self.step_state, reward, done
6. step(self, selected)
Purpose
This function updates the agents' state at each time step: it executes the chosen actions, updates loads, records the selected nodes, and finally returns the current state, the reward, and a flag indicating whether the task is complete.
Code
def step(self, selected):
    # selected.shape: (batch, pomo)

    # time-step control
    if self.time_step < 4:
        # advance the time step
        self.time_step = self.time_step + 1
        self.selected_count = self.selected_count + 1
        # are we at the depot?
        self.at_the_depot = (selected == 0)

        # bookkeeping at specific time steps
        if self.time_step == 3:
            self.last_current_node = self.current_node.clone()
            self.last_load = self.load.clone()
        if self.time_step == 4:
            self.last_current_node = self.current_node.clone()
            self.last_load = self.load.clone()
            self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

        # update the current node and the list of selected nodes
        self.current_node = selected
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

        # update demands and load
        demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        gathering_index = selected[:, :, None]
        selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
        self.load -= selected_demand
        self.load[self.at_the_depot] = 1  # refill loaded at the depot

        # update the visited flags (prevents re-selecting visited nodes)
        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
        self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

        # update the -inf mask (blocks nodes whose demand exceeds the remaining load)
        self.ninf_mask = self.visited_ninf_flag.clone()
        round_error_epsilon = 0.00001
        demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
        _2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
        demand_too_large = torch.cat((demand_too_large, _2), dim=2)
        self.ninf_mask[demand_too_large] = float('-inf')

        # sync the updated state into self.step_state
        self.step_state.selected_count = self.time_step
        self.step_state.load = self.load
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask

    # the more involved logic once time_step >= 4
    else:
        # classify actions by mode
        action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
        action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
        action2_bool_index = self.mode == 1
        action3_bool_index = self.mode == 2
        action1_index = torch.nonzero(action1_bool_index)
        action2_index = torch.nonzero(action2_bool_index)
        action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

        # update the selection count
        self.selected_count = self.selected_count + 1
        # regret: roll the count back by two
        self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

        # node updates
        self.last_is_depot = (self.last_current_node == 0)
        _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
        temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
        self.last_current_node = self.current_node.clone()
        self.current_node = selected.clone()
        self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

        # append to the list of selected nodes
        self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

        # update the load
        self.at_the_depot = (selected == 0)
        demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        # shape: (batch, pomo, problem+1)
        _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
        # extend demand_list with a zero-demand column for the regret slot
        demand_list = torch.cat((demand_list, _3), dim=2)
        gathering_index = selected[:, :, None]
        # shape: (batch, pomo, 1)
        selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
        _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
        self.last_load = self.load.clone()
        # shape: (batch, pomo)
        self.load -= selected_demand
        self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
        self.load[self.at_the_depot] = 1  # refill loaded at the depot

        # update the visited flags
        self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
        self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
        self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
        self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
        self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

        # update the -inf mask
        self.ninf_mask = self.visited_ninf_flag.clone()
        round_error_epsilon = 0.00001
        demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
        # shape: (batch, pomo, problem+2)
        self.ninf_mask[demand_too_large] = float('-inf')

        # update completion status: check which rollouts have visited every node,
        # then fold the result into self.finished
        newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size+1].all(dim=2)
        # shape: (batch, pomo)
        self.finished = self.finished + newly_finished
        # shape: (batch, pomo)

        # update modes
        self.mode[action1_bool_index] = 1
        self.mode[action2_bool_index] = 2
        self.mode[action3_bool_index] = 0
        self.mode[self.finished] = 4

        # mask adjustments for finished rollouts
        self.ninf_mask[:, :, 0][self.finished] = 0
        self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

        # sync the step state
        self.step_state.selected_count = self.time_step
        self.step_state.load = self.load
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask

    # returning values
    done = self.finished.all()
    if done:
        reward = -self._get_travel_distance()  # note the minus sign!
    else:
        reward = None
    return self.step_state, reward, done
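Taken together, reset, pre_step and step form the usual rollout loop. Below is a minimal sketch that drives the environment with a uniformly random feasible policy instead of a trained model (the policy is the only thing swapped out; batch and problem sizes are illustrative):

```python
import torch

env = CVRPEnv(problem_size=20, pomo_size=20)
env.load_problems(batch_size=4)
reset_state, reward, done = env.reset()

state, reward, done = env.pre_step()
while not done:
    # a trained policy would score actions from `state`; here we sample
    # uniformly among feasible actions (ninf_mask == 0) for illustration
    probs = torch.softmax(state.ninf_mask, dim=2)  # -inf entries get probability 0
    selected = probs.view(-1, probs.size(2)).multinomial(1)
    selected = selected.view(probs.size(0), probs.size(1))
    state, reward, done = env.step(selected)

print(reward)  # negative tour lengths, shape: (batch, pomo)
```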
7. _get_travel_distance(self)
Purpose
_get_travel_distance computes, for every POMO rollout, the travel distance along the sequence of nodes it selected.
Question
What is "rolling"?
Rolling is an operation on a tensor or array that shifts its elements along a chosen dimension (often the time dimension), producing a new tensor; elements pushed past the end wrap around to the front.
Example: suppose we have a 1-D tensor of node selections over time:
tensor = torch.tensor([1, 2, 3, 4, 5])
Rolling it one step to the right along that dimension,
rolled_tensor = tensor.roll(dims=0, shifts=1)
yields:
tensor([5, 1, 2, 3, 4])
Code
def _get_travel_distance(self):
    # m1: positions where the regret action (problem_size+1) was selected
    m1 = (self.selected_node_list == self.problem_size+1)
    # m2: regret positions plus the positions immediately before them
    m2 = (m1.roll(dims=2, shifts=-1) | m1)
    # m3: positions immediately after a regret action
    m3 = m1.roll(dims=2, shifts=1)
    # m4: ordinary positions, unaffected by any regret action
    m4 = ~(m2 | m3)
    selected_node_list_right = self.selected_node_list.roll(dims=2, shifts=1)
    selected_node_list_right2 = self.selected_node_list.roll(dims=2, shifts=3)
    self.regret_mask_matrix = m1
    self.add_mask_matrix = (~m2)
    travel_distances = torch.zeros((self.batch_size, self.pomo_size))
    for t in range(self.selected_node_list.shape[2]):
        add1_index = (m4[:, :, t].unsqueeze(2)).nonzero()
        add3_index = (m3[:, :, t].unsqueeze(2)).nonzero()
        # ordinary step: add the distance to the immediately preceding node
        cur_xy = self.depot_node_xy[add1_index[:, 0], self.selected_node_list[add1_index[:, 0], add1_index[:, 1], t], :]
        prev_xy = self.depot_node_xy[add1_index[:, 0], selected_node_list_right[add1_index[:, 0], add1_index[:, 1], t], :]
        travel_distances[add1_index[:, 0], add1_index[:, 1]] = travel_distances[add1_index[:, 0], add1_index[:, 1]].clone() + ((cur_xy - prev_xy) ** 2).sum(1).sqrt()
        # step right after a regret: the predecessor sits three positions back,
        # because the undone node and the regret marker are skipped
        cur_xy = self.depot_node_xy[add3_index[:, 0], self.selected_node_list[add3_index[:, 0], add3_index[:, 1], t], :]
        prev_xy = self.depot_node_xy[add3_index[:, 0], selected_node_list_right2[add3_index[:, 0], add3_index[:, 1], t], :]
        travel_distances[add3_index[:, 0], add3_index[:, 1]] = travel_distances[add3_index[:, 0], add3_index[:, 1]].clone() + ((cur_xy - prev_xy) ** 2).sum(1).sqrt()
    return travel_distances
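Without the regret bookkeeping, the same roll trick collapses to the standard one-shot computation of a closed tour's length; a minimal sketch:

```python
import torch

xy = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])  # depot + 3 customers
tour = torch.tensor([0, 1, 2, 3])   # visiting order (wraps back to the start)
prev = tour.roll(shifts=1, dims=0)  # predecessor of every position
length = ((xy[tour] - xy[prev]) ** 2).sum(dim=1).sqrt().sum()
print(length)  # tensor(4.) -- the perimeter of the unit square
```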
Appendix
Full source: CVRPEnv.py
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py
from dataclasses import dataclass
import torch

from CVRProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass
class Reset_State:
    depot_xy: torch.Tensor = None
    # shape: (batch, 1, 2)
    node_xy: torch.Tensor = None
    # shape: (batch, problem, 2)
    node_demand: torch.Tensor = None
    # shape: (batch, problem)


@dataclass
class Step_State:
    BATCH_IDX: torch.Tensor = None  # batch indices
    POMO_IDX: torch.Tensor = None  # POMO rollout indices
    # shape: (batch, pomo)
    selected_count: int = None  # number of nodes selected so far
    load: torch.Tensor = None  # current remaining load
    # shape: (batch, pomo)
    current_node: torch.Tensor = None  # node currently being visited
    # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None  # -inf mask over actions
    # shape: (batch, pomo, problem+2)
class CVRPEnv:
    def __init__(self, **env_params):
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']  # problem size (number of customer nodes)
        self.pomo_size = env_params['pomo_size']  # number of POMO rollouts
        self.FLAG__use_saved_problems = False  # whether to use pre-saved problem instances
        self.saved_depot_xy = None  # depot coordinates
        self.saved_node_xy = None  # node (customer) coordinates
        self.saved_node_demand = None  # node demands
        self.saved_index = None  # read cursor into the saved instances

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+2)
        self.ninf_mask = None
        # shape: (batch, pomo, problem+2)
        self.finished = None
        # shape: (batch, pomo)

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

        # regret
        ####################################
        self.mode = None
        self.last_current_node = None
        self.last_load = None
        self.regret_count = None
        self.regret_mask_matrix = None
        self.add_mask_matrix = None
        self.time_step = 0

    # load saved problem-instance data
    def use_saved_problems(self, filename, device):
        self.FLAG__use_saved_problems = True
        loaded_dict = torch.load(filename, map_location=device)  # load the saved problem instances
        self.saved_depot_xy = loaded_dict['depot_xy']  # unpack the loaded data
        self.saved_node_xy = loaded_dict['node_xy']
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0  # reset the read cursor
    def load_problems(self, batch_size, aug_factor=1):
        self.batch_size = batch_size

        # load the problem instances
        if not self.FLAG__use_saved_problems:
            # dynamic-generation mode
            depot_xy, node_xy, node_demand = get_random_problems(batch_size, self.problem_size)
        else:
            # preloaded mode: slice instances out of the saved data
            depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index+batch_size]
            node_xy = self.saved_node_xy[self.saved_index:self.saved_index+batch_size]
            node_demand = self.saved_node_demand[self.saved_index:self.saved_index+batch_size]
            self.saved_index += batch_size

        # data augmentation
        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError

        # concatenate depot and node data
        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)

        # initialize batch and POMO indices
        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

        # update the reset state and the step state
        self.reset_state.depot_xy = depot_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.node_demand = node_demand
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX
    def reset(self):
        # reset the selection count
        self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)
        # reset the current node
        self.current_node = None
        # shape: (batch, pomo)
        # reset the list of selected nodes
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        # initialize the at-the-depot flags
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)
        # initialize the load
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)
        # initialize the visited mask
        self.visited_ninf_flag = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.visited_ninf_flag[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+2)
        # initialize the -inf mask
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.ninf_mask[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+2)
        # initialize the finished flags
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        # initialize the remaining state variables
        self.regret_count = torch.zeros((self.batch_size, self.pomo_size))
        self.mode = torch.full((self.batch_size, self.pomo_size), 0)
        self.last_current_node = None
        self.last_load = None
        self.time_step = 0

        reward = None
        done = False
        return self.reset_state, reward, done
    def pre_step(self):
        # reset selected_count
        self.step_state.selected_count = 0
        # copy the current load
        self.step_state.load = self.load
        # set the current node
        self.step_state.current_node = self.current_node
        # update the mask
        self.step_state.ninf_mask = self.ninf_mask
        # return the step state, reward and done flag
        reward = None
        done = False
        return self.step_state, reward, done
    def step(self, selected):
        # selected.shape: (batch, pomo)

        # time-step control
        if self.time_step < 4:
            # advance the time step
            self.time_step = self.time_step + 1
            self.selected_count = self.selected_count + 1
            # are we at the depot?
            self.at_the_depot = (selected == 0)

            # bookkeeping at specific time steps
            if self.time_step == 3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

            # update the current node and the list of selected nodes
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            # update demands and load
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            # update the visited flags (prevents re-selecting visited nodes)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            # update the -inf mask (blocks nodes whose demand exceeds the remaining load)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            # sync the updated state into self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask

        # the more involved logic once time_step >= 4
        else:
            # classify actions by mode
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)
            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            # update the selection count
            self.selected_count = self.selected_count + 1
            # regret: roll the count back by two
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            # node updates
            self.last_is_depot = (self.last_current_node == 0)
            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            # append to the list of selected nodes
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            # update the load
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            # extend demand_list with a zero-demand column for the regret slot
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load = self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            # update the visited flags
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

            # update the -inf mask
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+2)
            self.ninf_mask[demand_too_large] = float('-inf')

            # update completion status: check which rollouts have visited every node,
            # then fold the result into self.finished
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            # update modes
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # mask adjustments for finished rollouts
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # sync the step state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask

        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None
        return self.step_state, reward, done
    def _get_travel_distance(self):
        # m1: positions where the regret action (problem_size+1) was selected
        m1 = (self.selected_node_list == self.problem_size+1)
        # m2: regret positions plus the positions immediately before them
        m2 = (m1.roll(dims=2, shifts=-1) | m1)
        # m3: positions immediately after a regret action
        m3 = m1.roll(dims=2, shifts=1)
        # m4: ordinary positions, unaffected by any regret action
        m4 = ~(m2 | m3)
        selected_node_list_right = self.selected_node_list.roll(dims=2, shifts=1)
        selected_node_list_right2 = self.selected_node_list.roll(dims=2, shifts=3)
        self.regret_mask_matrix = m1
        self.add_mask_matrix = (~m2)
        travel_distances = torch.zeros((self.batch_size, self.pomo_size))
        for t in range(self.selected_node_list.shape[2]):
            add1_index = (m4[:, :, t].unsqueeze(2)).nonzero()
            add3_index = (m3[:, :, t].unsqueeze(2)).nonzero()
            # ordinary step: add the distance to the immediately preceding node
            cur_xy = self.depot_node_xy[add1_index[:, 0], self.selected_node_list[add1_index[:, 0], add1_index[:, 1], t], :]
            prev_xy = self.depot_node_xy[add1_index[:, 0], selected_node_list_right[add1_index[:, 0], add1_index[:, 1], t], :]
            travel_distances[add1_index[:, 0], add1_index[:, 1]] = travel_distances[add1_index[:, 0], add1_index[:, 1]].clone() + ((cur_xy - prev_xy) ** 2).sum(1).sqrt()
            # step right after a regret: the predecessor sits three positions back,
            # because the undone node and the regret marker are skipped
            cur_xy = self.depot_node_xy[add3_index[:, 0], self.selected_node_list[add3_index[:, 0], add3_index[:, 1], t], :]
            prev_xy = self.depot_node_xy[add3_index[:, 0], selected_node_list_right2[add3_index[:, 0], add3_index[:, 1], t], :]
            travel_distances[add3_index[:, 0], add3_index[:, 1]] = travel_distances[add3_index[:, 0], add3_index[:, 1]].clone() + ((cur_xy - prev_xy) ** 2).sum(1).sqrt()
        return travel_distances