Ops notes · admin

Prompt-Tuning Source Code Analysis

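Source code

This walkthrough is based mainly on the Hugging Face PEFT source code.
As the model's class structure shows, Prompt Tuning adds prompt virtual tokens only at the input layer and leaves the rest of the model unchanged. For the details, see the source of PromptEmbedding.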
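The claim that only the input layer changes can be checked on a toy model. The sketch below assumes a hypothetical TinyLM (an embedding table plus a linear head standing in for the real base model; this is not PEFT code) with made-up sizes. Every base weight is frozen, the only trainable tensor is a num_virtual_tokens × token_dim embedding, and a dummy loss sends gradients to that embedding while the frozen weights receive none.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for a pretrained LM: token embedding + linear output head."""

    def __init__(self, vocab_size: int = 100, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # Consumes embeddings directly, so a soft prompt can be prepended beforehand
        return self.head(inputs_embeds)


num_virtual_tokens, token_dim, batch_size = 20, 16, 2
base = TinyLM(dim=token_dim)
for p in base.parameters():
    p.requires_grad_(False)  # freeze every base-model weight

# The only new weights: one trainable embedding row per virtual token
prompt_embedding = nn.Embedding(num_virtual_tokens, token_dim)

trainable = sum(p.numel() for p in prompt_embedding.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in base.parameters())
print(trainable, frozen)  # 320 3300

# Prepend the virtual tokens at the input layer; nothing else in the model changes
input_ids = torch.randint(0, 100, (batch_size, 5))
prompt_ids = torch.arange(num_virtual_tokens).unsqueeze(0).expand(batch_size, -1)
embeds = torch.cat([prompt_embedding(prompt_ids), base.embed(input_ids)], dim=1)
logits = base(embeds)  # shape: (batch_size, 20 + 5, vocab_size)
logits.sum().backward()  # dummy loss, only to see where gradients land
print(prompt_embedding.weight.grad is not None, base.head.weight.grad is None)  # True True
```

The frozen head still passes gradients back to its input, which is why the prompt embedding can learn even though no base weight is updated.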

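Pseudocode example

import torch

num_tokens, embed_dim = 20, 768  # illustrative sizes
# Soft prompt: a trainable tensor (nn.Parameter), randomly initialized
soft_prompt = torch.nn.Parameter(torch.rand(num_tokens, embed_dim))

def input_with_softprompt(x, soft_prompt):
    # x: input embeddings of shape (batch, seq_len, embed_dim)
    # Prepend the soft prompt to every example along the sequence dimension (dim=1)
    prompt = soft_prompt.unsqueeze(0).expand(x.size(0), -1, -1)
    return torch.cat([prompt, x], dim=1)

# x and model come from context; model must accept embeddings (inputs_embeds in transformers)
output = model(inputs_embeds=input_with_softprompt(x, soft_prompt))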

PEFT source code

import math

import torch
from peft import PromptTuningConfig, PromptTuningInit  # inside PEFT this is a relative import from .config


class PromptEmbedding(torch.nn.Module):
    """
    ```py
    >>> from peft import PromptEmbedding, PromptTuningConfig

    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )

    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
    ```

    Input Shape: (`batch_size`, `total_virtual_tokens`)

    Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
    """

    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            # Initialize the virtual tokens from the (detached) embeddings of the init text
            word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        # Just get embeddings
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings
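The TEXT initialization branch above trims the tokenized init text when it is longer than total_virtual_tokens and tiles it when it is shorter. A pure-Python sketch of just that length-matching step (fit_init_token_ids is a hypothetical name and the token ids are made up):

```python
import math
from typing import List


def fit_init_token_ids(init_token_ids: List[int], total_virtual_tokens: int) -> List[int]:
    """Trim or repeat token ids so exactly total_virtual_tokens remain (mirrors the TEXT-init branch)."""
    num_text_tokens = len(init_token_ids)
    if num_text_tokens > total_virtual_tokens:
        # Init text too long: keep only the first total_virtual_tokens ids
        init_token_ids = init_token_ids[:total_virtual_tokens]
    elif num_text_tokens < total_virtual_tokens:
        # Init text too short: repeat it enough times, then cut to length
        num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
        init_token_ids = init_token_ids * num_reps
    return init_token_ids[:total_virtual_tokens]


print(fit_init_token_ids([11, 12, 13, 14, 15], 3))  # [11, 12, 13]
print(fit_init_token_ids([11, 12, 13], 8))          # [11, 12, 13, 11, 12, 13, 11, 12]
```

Each resulting id is then looked up in the base model's word embeddings, so the virtual tokens start from the embeddings of real words instead of random values.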

After training, the saved adapter files look like this:

/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500]  adapter_config.json
├── [ 33K]  adapter_model.bin
└── [ 111]  README.md

0 directories, 3 files

Here, adapter_config.json is the Prompt Tuning configuration file, and adapter_model.bin holds the Prompt Tuning weights.
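adapter_model.bin is small because it stores only the prompt embedding matrix, not a copy of the base model. A rough size check follows. The training config is not shown in this post, so the numbers are assumptions: num_virtual_tokens=8, the 1024-dimensional hidden size of bloomz-560m, float32 weights, and num_transformer_submodules=1 for a causal LM. prompt_weight_bytes is a hypothetical helper.

```python
def prompt_weight_bytes(num_virtual_tokens: int, token_dim: int,
                        num_transformer_submodules: int = 1, bytes_per_param: int = 4) -> int:
    """Raw size of the prompt embedding matrix: total virtual tokens x token_dim x dtype size."""
    total_virtual_tokens = num_virtual_tokens * num_transformer_submodules
    return total_virtual_tokens * token_dim * bytes_per_param


# Assumed values: 8 virtual tokens, hidden size 1024 (bloomz-560m), float32 weights
size = prompt_weight_bytes(num_virtual_tokens=8, token_dim=1024)
print(size, size / 1024)  # 32768 32.0
```

About 32 KiB of raw weights is in line with the 33K listed above; the small remainder is presumably serialization overhead.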

Inference

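from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

# model_name_or_path, peft_config, text_column, dataset and i are defined in the training script
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# Load the PEFT config
config = PeftConfig.from_pretrained(peft_model_id)
# Load the base model
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# Load the PEFT model (base model + trained prompt embeddings)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Encode the input with the tokenizer
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")
# Run inference (token id 3 is <pad> in the BLOOM tokenizer and serves as the stop token here)
outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, eos_token_id=3)
# Decode with the tokenizer
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))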
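PeftModel.generate does not show how the trained prompt reaches the base model. Roughly, PEFT's causal-LM wrapper does the following in its forward pass: it embeds the virtual-token ids 0 to N-1 with the prompt encoder, concatenates them in front of the input embeddings, prepends ones to the attention mask, and prepends -100 to any labels. A minimal re-creation of that step with plain nn.Embedding stand-ins (prepend_virtual_tokens and the toy sizes are hypothetical, not PEFT's API):

```python
import torch
import torch.nn as nn


def prepend_virtual_tokens(input_ids, attention_mask, word_embeddings, prompt_embedding, labels=None):
    """Hypothetical helper: prepend soft-prompt embeddings, mask ones and ignored labels."""
    batch_size = input_ids.size(0)
    num_virtual_tokens = prompt_embedding.num_embeddings
    # Virtual-token ids 0..N-1, repeated for every example in the batch
    prompt_ids = torch.arange(num_virtual_tokens).unsqueeze(0).expand(batch_size, -1)
    prompts = prompt_embedding(prompt_ids)  # (batch, N, dim)
    inputs_embeds = torch.cat([prompts, word_embeddings(input_ids)], dim=1)
    # The virtual tokens are always attended to
    prefix_mask = torch.ones(batch_size, num_virtual_tokens, dtype=attention_mask.dtype)
    attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
    if labels is not None:
        # -100 is CrossEntropyLoss's default ignore_index: no loss on virtual positions
        prefix_labels = torch.full((batch_size, num_virtual_tokens), -100, dtype=labels.dtype)
        labels = torch.cat([prefix_labels, labels], dim=1)
    return inputs_embeds, attention_mask, labels


word_embeddings = nn.Embedding(50, 16)  # stand-in for the frozen base-model embeddings
prompt_embedding = nn.Embedding(8, 16)  # stand-in for PromptEmbedding.embedding
input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
embeds, mask, labels = prepend_virtual_tokens(
    input_ids, attention_mask, word_embeddings, prompt_embedding, labels=input_ids.clone()
)
print(embeds.shape, mask.shape, labels.shape)
```

Because the prompt is injected here rather than through extra text, the decoded output contains only the real input and the generated tokens.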