mirror of https://github.com/THUDM/ChatGLM3
Improve compatibility; add reference memory consumption; fix errors
This commit is contained in:
parent
826fed0d9f
commit
0a4281417d
|
@ -4,3 +4,6 @@ __pycache__
|
||||||
finetune/output
|
finetune/output
|
||||||
finetune/data
|
finetune/data
|
||||||
finetune/formatted_data
|
finetune/formatted_data
|
||||||
|
ToolAlpaca/
|
||||||
|
AdvertiseGen/
|
||||||
|
*.gz
|
|
@ -4,6 +4,12 @@
|
||||||
|
|
||||||
如果将模型下载到了本地,本文和代码中的 `THUDM/chatglm3-6b` 字段均应替换为相应地址以从本地加载模型。
|
如果将模型下载到了本地,本文和代码中的 `THUDM/chatglm3-6b` 字段均应替换为相应地址以从本地加载模型。
|
||||||
|
|
||||||
|
运行示例需要 `python>=3.9`,除基础的 `torch` 依赖外,示例代码运行还需要依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install transformers==4.30.2 accelerate sentencepiece
|
||||||
|
```
|
||||||
|
|
||||||
## 多轮对话格式
|
## 多轮对话格式
|
||||||
|
|
||||||
多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。
|
多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。
|
||||||
|
@ -49,12 +55,12 @@
|
||||||
|
|
||||||
- 每种角色可以附带一个 `bool` 类型的 `loss` 字段,表示该字段所预测的内容是否参与 `loss` 计算。若没有该字段,样例实现中默认对 `system`, `user` 不计算 `loss`,其余角色则计算 `loss`。
|
- 每种角色可以附带一个 `bool` 类型的 `loss` 字段,表示该字段所预测的内容是否参与 `loss` 计算。若没有该字段,样例实现中默认对 `system`, `user` 不计算 `loss`,其余角色则计算 `loss`。
|
||||||
|
|
||||||
- `tool` 并不是 ChatGLM3 中的原生角色,这里的 `tool` 在预处理阶段将被自动转化为一个具有工具调用 `metadata` 的 `assistant` 角色和一个表示工具返回值的 `observation` 角色。
|
- `tool` 并不是 ChatGLM3 中的原生角色,这里的 `tool` 在预处理阶段将被自动转化为一个具有工具调用 `metadata` 的 `assistant` 角色(默认计算 `loss`)和一个表示工具返回值的 `observation` 角色(不计算 `loss`)。
|
||||||
|
|
||||||
作为示例,我们使用 ToolAlpaca 数据集来进行微调。首先,克隆 [ToolAlpaca 数据集](https://github.com/tangqiaoyu/ToolAlpaca),并使用
|
作为示例,我们使用 ToolAlpaca 数据集来进行微调。首先,克隆 [ToolAlpaca 数据集](https://github.com/tangqiaoyu/ToolAlpaca),并使用
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./data/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
|
./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
|
||||||
```
|
```
|
||||||
|
|
||||||
将数据集处理成上述格式。在这里,我们有意将工具处理成了了 `list[str]` 这样的自然语言形式,以观察模型在微调前后对工具定义的理解能力。
|
将数据集处理成上述格式。在这里,我们有意将工具处理成了了 `list[str]` 这样的自然语言形式,以观察模型在微调前后对工具定义的理解能力。
|
||||||
|
@ -64,8 +70,8 @@
|
||||||
以下脚本提供了微调模型的参考方式。
|
以下脚本提供了微调模型的参考方式。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/finetune_multiturn.sh # 全量微调
|
./scripts/finetune_ds_multiturn.sh # 全量微调
|
||||||
./scripts/finetune_multiturn_pt.sh # P-Tuning v2 微调
|
./scripts/finetune_pt_multiturn.sh # P-Tuning v2 微调
|
||||||
```
|
```
|
||||||
|
|
||||||
### 部署
|
### 部署
|
||||||
|
@ -94,7 +100,7 @@ MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit r
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"prompt": "<prompt text>",
|
"prompt": "<prompt text>",
|
||||||
"response": "<prompt text>"
|
"response": "<response text>"
|
||||||
}
|
}
|
||||||
// ...
|
// ...
|
||||||
]
|
]
|
||||||
|
@ -105,7 +111,7 @@ MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit r
|
||||||
作为示例,我们使用 AdvertiseGen 数据集来进行微调。从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 AdvertiseGen 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。
|
作为示例,我们使用 AdvertiseGen 数据集来进行微调。从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 AdvertiseGen 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./data/format_advertise_gen.py --path "AdvertiseGen/train.json"
|
./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"
|
||||||
```
|
```
|
||||||
|
|
||||||
来下载和将数据集处理成上述格式。
|
来下载和将数据集处理成上述格式。
|
||||||
|
@ -115,13 +121,13 @@ MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit r
|
||||||
以下脚本提供了微调模型的参考方式。
|
以下脚本提供了微调模型的参考方式。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/finetune_ds_multiturn.sh # 全量微调
|
./scripts/finetune_ds.sh # 全量微调
|
||||||
./scripts/finetune_pt_multiturn.sh # P-Tuning v2 微调
|
./scripts/finetune_pt.sh # P-Tuning v2 微调
|
||||||
```
|
```
|
||||||
|
|
||||||
### 推理验证
|
### 推理验证
|
||||||
|
|
||||||
对于输入输出格式的微调,下列脚本演示了基本的推理方式。
|
对于输入输出格式的微调,可使用 `inference.py` 进行基本的推理验证。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python inference.py \
|
python inference.py \
|
||||||
|
@ -139,38 +145,43 @@ python inference.py \
|
||||||
|
|
||||||
1. 微调代码在开始训练前,会先打印首条训练数据的预处理信息,显示为
|
1. 微调代码在开始训练前,会先打印首条训练数据的预处理信息,显示为
|
||||||
|
|
||||||
```log
|
```log
|
||||||
Sanity Check >>>>>>>>>>>>>
|
Sanity Check >>>>>>>>>>>>>
|
||||||
'[gMASK]': 64790 -> -100
|
'[gMASK]': 64790 -> -100
|
||||||
'sop': 64792 -> -100
|
'sop': 64792 -> -100
|
||||||
'<|system|>': 64794 -> -100
|
'<|system|>': 64794 -> -100
|
||||||
'': 30910 -> -100
|
'': 30910 -> -100
|
||||||
'\n': 13 -> -100
|
'\n': 13 -> -100
|
||||||
'Answer': 20115 -> -100
|
'Answer': 20115 -> -100
|
||||||
'the': 267 -> -100
|
'the': 267 -> -100
|
||||||
'following': 1762 -> -100
|
'following': 1762 -> -100
|
||||||
...
|
...
|
||||||
'know': 683 -> -100
|
'know': 683 -> -100
|
||||||
'the': 267 -> -100
|
'the': 267 -> -100
|
||||||
'response': 3010 -> -100
|
'response': 3010 -> -100
|
||||||
'details': 3296 -> -100
|
'details': 3296 -> -100
|
||||||
'.': 30930 -> -100
|
'.': 30930 -> -100
|
||||||
'<|assistant|>': 64796 -> -100
|
'<|assistant|>': 64796 -> -100
|
||||||
'': 30910 -> 30910
|
'': 30910 -> 30910
|
||||||
'\n': 13 -> 13
|
'\n': 13 -> 13
|
||||||
'I': 307 -> 307
|
'I': 307 -> 307
|
||||||
'need': 720 -> 720
|
'need': 720 -> 720
|
||||||
'to': 289 -> 289
|
'to': 289 -> 289
|
||||||
'use': 792 -> 792
|
'use': 792 -> 792
|
||||||
...
|
...
|
||||||
<<<<<<<<<<<<< Sanity Check
|
<<<<<<<<<<<<< Sanity Check
|
||||||
```
|
```
|
||||||
|
|
||||||
字样,每行依次表示一个 detokenized string, token_id 和 target_id。可在日志中仔细查看这部分的 `loss_mask` 是否符合预期。若不符合,可能需要调整代码或数据。
|
字样,每行依次表示一个 detokenized string, token_id 和 target_id。可在日志中查看这部分的 `loss_mask` 是否符合预期。若不符合,可能需要调整代码或数据。
|
||||||
|
|
||||||
2. 若显存不足,出现 `RuntimeError: CUDA out of memory.`,可以考虑
|
2. P-Tuning V2 参考显存用量
|
||||||
- 尝试降低 `DEV_BATCH_SIZE` 并提升 `GRAD_ACCUMULARION_STEPS`
|
|
||||||
- 尝试添加 `--quantization_bit 8` 或 `--quantization_bit 4`;注意使用量化训练需要安装 `cpm_kernels`。
|
`PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=2048` 配置下约需要 21GB 显存。
|
||||||
|
|
||||||
|
3. 若尝试后发现显存不足,可以考虑
|
||||||
|
- 尝试降低 `DEV_BATCH_SIZE` 并提升 `GRAD_ACCUMULARION_STEPS`
|
||||||
|
- 尝试添加 `--quantization_bit 8` 或 `--quantization_bit 4`。
|
||||||
|
- `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=1024` 配置下,`--quantization_bit 8` 约需 12GB 显存,`--quantization_bit 4` 约需 7.6GB 显存。
|
||||||
|
|
||||||
## 参考文献
|
## 参考文献
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import astunparse
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
# text constants
|
# text constants
|
||||||
FUNCTION_CALL_NAME = 'tool_call'
|
FUNCTION_CALL_NAME = 'tool_call'
|
||||||
|
@ -13,7 +14,7 @@ TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You
|
||||||
CONVERSATOIN_KEY = 'conversations'
|
CONVERSATOIN_KEY = 'conversations'
|
||||||
TOOL_DESC_KEY = 'tools'
|
TOOL_DESC_KEY = 'tools'
|
||||||
|
|
||||||
def format_function_call(function_name: str, parameters: dict[str, str]):
|
def format_function_call(function_name: str, parameters: Dict[str, str]):
|
||||||
function_name = ast.Name(id=function_name)
|
function_name = ast.Name(id=function_name)
|
||||||
keywords = [
|
keywords = [
|
||||||
ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
|
ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
|
||||||
|
@ -22,13 +23,13 @@ def format_function_call(function_name: str, parameters: dict[str, str]):
|
||||||
func_call = ast.Call(func=function_name, args=[], keywords=keywords)
|
func_call = ast.Call(func=function_name, args=[], keywords=keywords)
|
||||||
return astunparse.unparse(func_call).strip()
|
return astunparse.unparse(func_call).strip()
|
||||||
|
|
||||||
def format_conversation(item, tokenizer: "ChatGLMTokenizer", conversation_key: str, tool_key: str):
|
def format_conversation(item, tokenizer, conversation_key: str, tool_key: str):
|
||||||
conversations = deepcopy(item[conversation_key])
|
conversations = deepcopy(item[conversation_key])
|
||||||
|
|
||||||
# Note: `loss_mask` here means whether *the prediction* of the token should take loss
|
# Note: `loss_mask` here means whether *the prediction* of the token should take loss
|
||||||
tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0]
|
tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0]
|
||||||
|
|
||||||
def _update(_tokens: list[int], value: int = 1):
|
def _update(_tokens: List[int], value: int = 1):
|
||||||
value = int(value)
|
value = int(value)
|
||||||
tokens.extend(_tokens)
|
tokens.extend(_tokens)
|
||||||
loss_masks.extend([value] * len(_tokens))
|
loss_masks.extend([value] * len(_tokens))
|
||||||
|
@ -67,7 +68,7 @@ def format_conversation(item, tokenizer: "ChatGLMTokenizer", conversation_key: s
|
||||||
assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
|
assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
|
||||||
return tokens, loss_masks
|
return tokens, loss_masks
|
||||||
|
|
||||||
def sanity_check(tokens: list[int], target: list[int], tokenizer: PreTrainedTokenizer):
|
def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer):
|
||||||
print("Sanity Check >>>>>>>>>>>>>")
|
print("Sanity Check >>>>>>>>>>>>>")
|
||||||
for t, m in zip(tokens, target):
|
for t, m in zip(tokens, target):
|
||||||
decoded = tokenizer.tokenizer.index_special_tokens[t] \
|
decoded = tokenizer.tokenizer.index_special_tokens[t] \
|
||||||
|
@ -79,7 +80,7 @@ def sanity_check(tokens: list[int], target: list[int], tokenizer: PreTrainedToke
|
||||||
assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
|
assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
|
||||||
|
|
||||||
class MultiTurnDataset(Dataset):
|
class MultiTurnDataset(Dataset):
|
||||||
def __init__(self, data: list[dict], tokenizer: PreTrainedTokenizer, max_seq_length: int):
|
def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_seq_length: int):
|
||||||
super(MultiTurnDataset, self).__init__()
|
super(MultiTurnDataset, self).__init__()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_seq_length = max_seq_length
|
self.max_seq_length = max_seq_length
|
||||||
|
@ -109,7 +110,7 @@ class MultiTurnDataset(Dataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
class InputOutputDataset(Dataset):
|
class InputOutputDataset(Dataset):
|
||||||
def __init__(self, data: list[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
|
def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
|
||||||
super(InputOutputDataset, self).__init__()
|
super(InputOutputDataset, self).__init__()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_source_length = max_source_length
|
self.max_source_length = max_source_length
|
||||||
|
|
|
@ -21,7 +21,7 @@ MASTER_PORT=$(shuf -n 1 -i 10000-65535)
|
||||||
|
|
||||||
mkdir -p $OUTPUT_DIR
|
mkdir -p $OUTPUT_DIR
|
||||||
|
|
||||||
deepspeed --num_gpus=$NUM_GPUS --master_port $MASTER_PORT finetune.py \
|
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
|
||||||
--train_format input-output \
|
--train_format input-output \
|
||||||
--train_file $DATASET_PATH \
|
--train_file $DATASET_PATH \
|
||||||
--preprocessing_num_workers 1 \
|
--preprocessing_num_workers 1 \
|
||||||
|
|
|
@ -19,7 +19,7 @@ OUTPUT_DIR=output/${RUN_NAME}-${DATASTR}-${LR}
|
||||||
|
|
||||||
mkdir -p $OUTPUT_DIR
|
mkdir -p $OUTPUT_DIR
|
||||||
|
|
||||||
torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS finetune.py \
|
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
|
||||||
--train_format multi-turn \
|
--train_format multi-turn \
|
||||||
--train_file $DATASET_PATH \
|
--train_file $DATASET_PATH \
|
||||||
--max_seq_length $MAX_SEQ_LEN \
|
--max_seq_length $MAX_SEQ_LEN \
|
||||||
|
|
|
@ -7,8 +7,8 @@ LR=2e-2
|
||||||
NUM_GPUS=1
|
NUM_GPUS=1
|
||||||
MAX_SOURCE_LEN=1024
|
MAX_SOURCE_LEN=1024
|
||||||
MAX_TARGET_LEN=128
|
MAX_TARGET_LEN=128
|
||||||
DEV_BATCH_SIZE=32
|
DEV_BATCH_SIZE=1
|
||||||
GRAD_ACCUMULARION_STEPS=1
|
GRAD_ACCUMULARION_STEPS=32
|
||||||
MAX_STEP=1000
|
MAX_STEP=1000
|
||||||
SAVE_INTERVAL=500
|
SAVE_INTERVAL=500
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ OUTPUT_DIR=output/${RUN_NAME}-${DATASTR}-${PRE_SEQ_LEN}-${LR}
|
||||||
|
|
||||||
mkdir -p $OUTPUT_DIR
|
mkdir -p $OUTPUT_DIR
|
||||||
|
|
||||||
torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS finetune.py \
|
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
|
||||||
--train_format input-output \
|
--train_format input-output \
|
||||||
--train_file $DATASET_PATH \
|
--train_file $DATASET_PATH \
|
||||||
--preprocessing_num_workers 1 \
|
--preprocessing_num_workers 1 \
|
||||||
|
|
|
@ -6,8 +6,8 @@ PRE_SEQ_LEN=128
|
||||||
LR=2e-2
|
LR=2e-2
|
||||||
NUM_GPUS=1
|
NUM_GPUS=1
|
||||||
MAX_SEQ_LEN=2048
|
MAX_SEQ_LEN=2048
|
||||||
DEV_BATCH_SIZE=16
|
DEV_BATCH_SIZE=1
|
||||||
GRAD_ACCUMULARION_STEPS=1
|
GRAD_ACCUMULARION_STEPS=16
|
||||||
MAX_STEP=1000
|
MAX_STEP=1000
|
||||||
SAVE_INTERVAL=500
|
SAVE_INTERVAL=500
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ OUTPUT_DIR=output/${RUN_NAME}-${DATASTR}-${PRE_SEQ_LEN}-${LR}
|
||||||
|
|
||||||
mkdir -p $OUTPUT_DIR
|
mkdir -p $OUTPUT_DIR
|
||||||
|
|
||||||
torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS finetune.py \
|
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
|
||||||
--train_format multi-turn \
|
--train_format multi-turn \
|
||||||
--train_file $DATASET_PATH \
|
--train_file $DATASET_PATH \
|
||||||
--max_seq_length $MAX_SEQ_LEN \
|
--max_seq_length $MAX_SEQ_LEN \
|
||||||
|
|
Loading…
Reference in New Issue