Prompt
Prompt-Tuning Source Code Analysis
Source Code
Our code walkthrough here is based on the Hugging Face PEFT source code.
From the model class structure you can see that Prompt Tuning only adds prompt virtual tokens at the input layer; everything else is left unchanged. The details are in the source of PromptEmbedding.
Pseudocode example
soft_prompt = torch.nn.Parameter(        # Make the tensor trainable
    torch.rand(num_tokens, embed_dim))   # Initialize the soft prompt tensor

def input_with_softprompt(x, soft_prompt):
    # Prepend the soft prompt to the input along the sequence-length dimension
    x = torch.cat([soft_prompt, x], dim=seq_len)
    return x

model(input_with_softprompt(x, soft_prompt))
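To make the pseudocode concrete, below is a minimal runnable sketch (not part of the PEFT source): it embeds the input ids with the base model's word-embedding layer, prepends a trainable soft prompt along the sequence dimension, and passes the result to the model via inputs_embeds. The base model name (bigscience/bloomz-560m), the example text, and num_tokens=20 are assumptions for illustration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloomz-560m"   # assumed base model, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

num_tokens = 20                                            # number of virtual tokens (assumed)
embed_dim = model.get_input_embeddings().weight.shape[1]   # hidden size of the base model

# Trainable soft prompt of shape (num_tokens, embed_dim)
soft_prompt = torch.nn.Parameter(torch.rand(num_tokens, embed_dim))

inputs = tokenizer("Tweet text : I love this! Label :", return_tensors="pt")
token_embeds = model.get_input_embeddings()(inputs["input_ids"])   # (1, seq_len, embed_dim)

# Prepend the soft prompt along the sequence dimension (dim=1)
batch_size = token_embeds.shape[0]
prompt = soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)
inputs_embeds = torch.cat([prompt, token_embeds], dim=1)

# Extend the attention mask to cover the virtual tokens
attention_mask = torch.cat(
    [torch.ones(batch_size, num_tokens, dtype=inputs["attention_mask"].dtype),
     inputs["attention_mask"]], dim=1)

outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
print(outputs.logits.shape)   # (1, num_tokens + seq_len, vocab_size)

During training, only soft_prompt would receive gradient updates; the base model parameters stay frozen.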
PEFT source code
class PromptEmbedding(torch.nn.Module):
    """
    ```py
    >>> from peft import PromptEmbedding, PromptTuningConfig

    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )

    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
    ```

    Input Shape: (`batch_size`, `total_virtual_tokens`)

    Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
    """

    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        # Just get embeddings
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings
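In practice you rarely instantiate PromptEmbedding yourself; get_peft_model builds it from a PromptTuningConfig and wires it in front of the base model. The sketch below shows roughly how an adapter directory like the one listed next could be produced, assuming the standard PEFT API; the base model name, init text, and num_virtual_tokens are illustrative assumptions, and the training loop is omitted.

from transformers import AutoModelForCausalLM
from peft import get_peft_model, PromptTuningConfig, PromptTuningInit, TaskType

model_name_or_path = "bigscience/bloomz-560m"   # assumed base model

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,                        # illustrative value
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=model_name_or_path,
)

model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()   # only the prompt embedding is trainable

# ... training loop omitted ...

# Saving writes adapter_config.json and adapter_model.bin as shown below
model.save_pretrained(f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}")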
The model weight files written out look like this:
/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500] adapter_config.json
├── [ 33K] adapter_model.bin
└── [ 111] README.md

0 directories, 3 files
Here, adapter_config.json is the Prompt Tuning configuration file, and adapter_model.bin holds the Prompt Tuning weights.
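As a quick sanity check you can load adapter_model.bin directly and inspect what was saved. Assuming the usual PEFT layout for prompt-learning methods, it should contain a single prompt-embedding weight of shape (num_virtual_tokens, token_dim); the exact key name may vary across PEFT versions.

import torch

state_dict = torch.load(
    "/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM/adapter_model.bin",
    map_location="cpu",
)
for name, tensor in state_dict.items():
    # Expect something like: prompt_embeddings  (num_virtual_tokens, token_dim)
    print(name, tuple(tensor.shape))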
Inference
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# Load the PEFT configuration
config = PeftConfig.from_pretrained(peft_model_id)
# Load the base model
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# Load the PEFT model
model = PeftModel.from_pretrained(model, peft_model_id)
# Load the tokenizer of the base model
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Tokenizer encoding
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")
# Model inference
outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, eos_token_id=3)
# Tokenizer decoding
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))