-
Notifications
You must be signed in to change notification settings - Fork 165
feature(xjy): add multi-task learning pipeline in jericho environment #365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev-multitask-balance-clean
Are you sure you want to change the base?
feature(xjy): add multi-task learning pipeline in jericho environment #365
Conversation
from ding.config import compile_config | ||
from ding.envs import create_env_manager, get_vec_env_setting | ||
from ding.policy import create_policy, Policy | ||
# from ding.rl_utils import get_epsilon_greedy_fn # get_epsilon_greedy_fn 已被弃用,如果需要需要从 ding.exploration 导入 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
英文注释
|
||
return weights | ||
|
||
def train_unizero_multitask_ddp( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
英文注释,上面的工具函数可以移到entry/utils.py中去
def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): | ||
class MoELayer(nn.Module): | ||
""" | ||
Mixture-of-Experts (MoE) 层的实现,参考了如下的设计: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
英文注释
# 若使用 Register Token,则将其拼到序列最前面 | ||
# 训练阶段和推理阶段都统一处理 | ||
if self.use_register_token: | ||
sequences = self.add_register_tokens(sequences, task_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pull最新的opendilab:dev-multitask-balance-clean这个分支,rotary_emb等已实现功能不应该去掉
# self.feed_forward = MoELayer(moe_cfg) | ||
# print("=" * 20) | ||
# print(f"Use MoE feed_forward, num_experts={moe_cfg.num_experts_total}") | ||
# print("=" * 20) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除没用到的注释
@@ -869,10 +858,12 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: | |||
# self.task_num_for_current_rank 作为当前rank的base_index | |||
num_tasks = self.task_num_for_current_rank | |||
# If the number of tasks is provided, extend the monitored variables list with task-specific variables | |||
# TODO xiongjyu: 以下代码感觉有问题,如果num_tasks != 1(例如2), 4个任务的self.task_id分别是0, 1, 2, 3; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个具体的问题在群里截图发一下看看哈
manual_temperature_decay=False, | ||
num_simulations=num_simulations, | ||
n_episode=n_episode, | ||
train_start_after_envsteps=int(0), # TODO: ===== only for debug ===== |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉不用的注释
import_names=['zoo.jericho.envs.jericho_env'], | ||
), | ||
env_manager=dict(type='base'), | ||
# env_manager=dict(type='subprocess'), # subprocess在jericho环境下不支持 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
英文注释
num_heads=24, | ||
obs_type="text", # TODO: Modify as needed. | ||
env_num=max(collector_env_num, evaluator_env_num), | ||
task_embed_option=None, # ==============TODO: none ============== |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
全部改动都检查一下注释格式。删除不用的TODO
max_action_num=max_action_num, | ||
tokenizer_path=model_name, | ||
max_seq_len=512, | ||
game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前最新的结果是用这个跑的吗?
No description provided.