
machine learning - How to Compute Teacher-Forced Accuracy (TFA) for Hugging Face Models While Handling EOS Tokens? - Stack Overflow


I am trying to compute Teacher-Forced Accuracy (TFA) for Hugging Face models, ensuring the following:

  1. EOS Token Handling: The model should be rewarded for predicting the first EOS token.
  2. Ignoring Padding: Any padding tokens (beyond the first EOS) should be ignored during accuracy calculation.
  3. Right-Shifted Input: The inputs are shifted correctly for teacher-forced training.

Here’s the full code I wrote:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def compute_tfa(model, tokenizer, input_texts):
    """
    Computes Teacher-Forced Accuracy (TFA), rewarding the model for correctly predicting
    the first EOS token while ignoring predictions for padding tokens.

    Parameters:
        model: The language model (Hugging Face CausalLM).
        tokenizer: The tokenizer corresponding to the model.
        input_texts: List of input texts to compute TFA.

    Returns:
        TFA score as a float.
    """
    # Tokenize input texts
    tokenizer.pad_token = tokenizer.eos_token  # Use EOS as the pad token
    inputs = tokenizer(input_texts, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']
    
    # Create right-shifted input by adding the EOS token at the beginning
    eos_token_id = tokenizer.eos_token_id
    right_shifted_input_ids = torch.cat([
        torch.full((input_ids.shape[0], 1), eos_token_id, dtype=torch.long),  # Add EOS token
        input_ids[:, :-1]
    ], dim=1)

    # Perform a forward pass with the right-shifted inputs
    with torch.no_grad():
        outputs = model(input_ids=right_shifted_input_ids)
        logits = outputs.logits  # Shape: (batch_size, sequence_length, vocab_size)

    # Compute predictions
    predicted_token_ids = torch.argmax(logits, dim=-1)  # Shape: (batch_size, sequence_length)

    # Find the first EOS position in each sequence
    eos_positions = (input_ids == eos_token_id).int().argmax(dim=1)  # Shape: (batch_size,)

    # Mask to ignore tokens after the first EOS
    sequence_lengths = input_ids.size(1)
    mask = torch.arange(sequence_lengths).unsqueeze(0).to(input_ids.device)
    mask = mask < eos_positions.unsqueeze(1)

    # Include the first EOS token in the mask
    mask.scatter_(1, eos_positions.unsqueeze(1), 1)

    # Apply the mask to filter predictions and labels
    filtered_predictions = predicted_token_ids[mask]
    filtered_labels = input_ids[mask]

    # Compute accuracy
    correct_predictions = (filtered_predictions == filtered_labels).float()
    accuracy = correct_predictions.mean().item()

    return accuracy

def main():
    # Define models and their URLs
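    models_and_urls = {
        "google/gemma-2-2b": "https://huggingface.co/google/gemma-2-2b",
        "meta-llama/Llama-3.1-8B": "https://huggingface.co/meta-llama/Llama-3.1-8B",
        "gpt2": "https://huggingface.co/gpt2"
    }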

    # Define input texts
    input_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial Intelligence is transforming the world of science."
    ]

    # Test each model
    for model_name, model_url in models_and_urls.items():
        print(f"Testing model: {model_name} ({model_url})")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # Compute TFA
        tfa_score = compute_tfa(model, tokenizer, input_texts)
        print(f"Teacher-Forced Accuracy (TFA) for {model_name}: {tfa_score:.4f}\n")

if __name__ == "__main__":
    main()

What I Need Help With:

  1. EOS Token Masking: Is the masking logic I implemented for ignoring tokens after the first EOS correct? Specifically, I used:

    mask = torch.arange(sequence_lengths).unsqueeze(0).to(input_ids.device)
    mask = mask < eos_positions.unsqueeze(1)
    mask.scatter_(1, eos_positions.unsqueeze(1), 1)
    

    Is this the best way to ensure only tokens up to and including the first EOS are considered?

  2. Right-Shifted Input: I prepend the EOS token to the input like this:

    right_shifted_input_ids = torch.cat([
        torch.full((input_ids.shape[0], 1), eos_token_id, dtype=torch.long),
        input_ids[:, :-1]
    ], dim=1)
    

    Is this a standard way to handle the right-shift for teacher-forced evaluation?

  3. Generalization: The code is designed to evaluate multiple models, such as google/gemma-2-2b, meta-llama/Llama-3.1-8B, and gpt2. Are there any additional considerations or best practices I should follow for TFA computation across diverse models?

  4. Performance Optimization: Is there a more efficient way to compute the mask and apply it to the predictions and labels? My current method seems to work but might be suboptimal for larger datasets.

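  5. Cross-posted on the Hugging Face forum: https://discuss.huggingface.co/t/how-to-compute-teacher-forced-accuracy-tfa-for-hugging-face-models-while-handling-eos-tokens/126403

  6. Cross-posted on the PyTorch forum: https://discuss.pytorch./t/how-to-compute-teacher-forced-accuracy-tfa-for-hugging-face-models-while-handling-eos-tokens/213435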
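For items 1 and 4, a minimal vectorized sketch (the helper name `mask_through_first_eos` is hypothetical, not a library function). It marks every position up to and including the first EOS using an exclusive cumulative sum and, unlike the `argmax` trick, keeps the whole row when a sequence contains no EOS at all. That case matters because `(x == eos).int().argmax(dim=1)` returns 0 for a row with no match, which silently shrinks the mask to a single position.

```python
import torch


def mask_through_first_eos(labels: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Boolean mask that is True up to and including the first EOS in each row.

    Rows with no EOS are kept in full. Everything is vectorized, so there is
    no Python loop over the batch.
    """
    is_eos = (labels == eos_id).long()
    # Number of EOS tokens strictly before each position (exclusive cumsum).
    eos_before = torch.cumsum(is_eos, dim=1) - is_eos
    return eos_before == 0


if __name__ == "__main__":
    labels = torch.tensor([
        [5, 6, 2, 2],  # EOS (id 2) at position 2, then padding that reuses EOS
        [5, 6, 7, 8],  # no EOS, e.g. the tokenizer did not append one
    ])
    print(mask_through_first_eos(labels, eos_id=2))
    # argmax-based "first EOS" positions: the second row reports 0
    # even though it contains no EOS.
    print((labels == 2).int().argmax(dim=1))
```

Note that the mask can only reward an EOS that is actually present in the labels. Some tokenizers do not append one (GPT-2's tokenizer, for instance, adds no special tokens by default), so it is worth checking the tokenized ids and appending `eos_token_id` yourself if needed.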


Bounty: fixing the off-by-one error

My teacher-forced accuracy (TFA) is always zero. Current code:

import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizer,
    PreTrainedModel,
)
from typing import List, Optional
import os


def seed_everything(seed: int = 0) -> None:
    """
    Seed all relevant libraries for reproducibility.
    """
    import random
    import numpy as np
    from transformers import set_seed as hf_set_seed

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    hf_set_seed(seed)


def compute_tfa(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    input_texts: List[str],
    bos_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
) -> float:
    """
    Calculate Teacher-Forced Accuracy (TFA) for a model.
    """
    eos_token_id = eos_token_id or tokenizer.eos_token_id
    bos_token_id = bos_token_id or tokenizer.bos_token_id or eos_token_id
    pad_token_id = pad_token_id or tokenizer.pad_token_id or eos_token_id
    tokenizer.pad_token_id = pad_token_id

    inputs = tokenizer(input_texts, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']

    if input_ids[0, 0].item() != bos_token_id:
        right_shifted_input_ids = torch.cat(
            [torch.full((input_ids.size(0), 1), bos_token_id), input_ids[:, :-1]], dim=1
        )
        labels = input_ids
    else:
        right_shifted_input_ids = input_ids[:, :-1]
        labels = input_ids[:, 1:]

    with torch.no_grad():
        logits = model(input_ids=right_shifted_input_ids).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    eos_positions = (labels == eos_token_id).int().argmax(dim=1)
    seq_len = labels.size(1)
    mask = torch.arange(seq_len).unsqueeze(0).to(input_ids.device) <= eos_positions.unsqueeze(1)

    filtered_preds = predicted_ids[mask]
    filtered_labels = labels[mask]

    accuracy = (filtered_preds == filtered_labels).float().mean().item()
    return accuracy


def main() -> None:
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    seed_everything(seed=123)

    model_configs = [
        {"name": "Meta-Llama-3-8B-Instruct", "repo": "meta-llama/Meta-Llama-3-8B-Instruct"}
    ]

    input_texts = ["The happy dog."]
    code_input = ["<|fim_prefix|>def solution():\n    return True<|fim_suffix|><|fim_middle|>"]

    for config in model_configs:
        name, repo = config["name"], config["repo"]
        print(f"Evaluating {name} from {repo}")
        model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

        bos_id, eos_id, pad_id = tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id

        tfa_score = compute_tfa(model, tokenizer, input_texts, bos_id, eos_id, pad_id)
        assert 0.0 <= tfa_score <= 1.0, f"TFA out of range: {tfa_score}"
        print(f"[{name}] TFA (General): {tfa_score:.4f}")

        if "codegemma-2b" in name.lower():
            tfa_code = compute_tfa(model, tokenizer, code_input, bos_id, eos_id, pad_id)
            assert 0.0 <= tfa_code <= 1.0, f"TFA (Code) out of range: {tfa_code}"
            print(f"[{name}] TFA (Code Input): {tfa_code:.4f}")

        print()


if __name__ == "__main__":
    main()
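A side note on the default-ID handling at the top of `compute_tfa`: `a or b` treats a token id of 0 as missing, so a tokenizer whose pad or BOS id is 0 would silently fall back to EOS. A small helper that skips only `None` avoids this (`first_not_none` is a hypothetical name, not part of transformers):

```python
from typing import Optional


def first_not_none(*candidates: Optional[int]) -> Optional[int]:
    """Return the first candidate that is not None; 0 is a valid token id."""
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return None


if __name__ == "__main__":
    print(0 or None or 7)              # the id 0 is skipped by `or`
    print(first_not_none(0, None, 7))  # the id 0 is kept
```

It would be used as, e.g., `pad_token_id = first_not_none(pad_token_id, tokenizer.pad_token_id, eos_token_id)`.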

What might be wrong? TFA is always zero; I need careful reasoning and suggestions.

I’m trying to compute Teacher-Forced Accuracy (TFA) for a language model. The expected behavior is that the model predicts the input correctly under teacher-forcing, where the inputs are right-shifted by one token (prepended with BOS if necessary). However, the TFA score always comes out as zero, which is highly suspicious.
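Written out, the quantity being computed is the following, where $M$ is the set of scored positions (for example, non-padding tokens up to and including the first EOS), $x_{b,t}$ is token $t$ of sequence $b$, and the prediction for $x_{b,t}$ is read from the logits at position $t-1$:

$$
\mathrm{TFA} \;=\; \frac{1}{|M|} \sum_{(b,t)\in M} \mathbb{1}\!\left[\,\arg\max_{v}\; p_\theta\!\left(v \mid x_{b,<t}\right) = x_{b,t}\,\right]
$$

If $M$ ends up containing only one position per sequence, TFA reduces to first-token accuracy, which can easily be exactly zero.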

My Reasoning:

There are two cases regarding the presence of the BOS token, and my code handles them as follows:

  1. When the tokenizer includes BOS:

    • The output from the tokenizer also includes EOS because my tokenizer code assumes EOS is present (and all models should have EOS).
    • For the model input:
      • The input is right-shifted by one token.
      • It’s fine to remove EOS from the input to avoid interfering with predictions.
    • For evaluation:
      • Remove the BOS token from the model's output to align predictions with labels.
      • Ensure EOS in the predictions matches the labels correctly.
  2. When the tokenizer does not include BOS:

    • In this case, BOS is not set in the tokenizer (and I’ve already verified it’s absent).
    • For the model input:
      • I prepend BOS to the input manually.
      • It’s fine to remove EOS as before.
    • For evaluation:
      • Remove the BOS token from predictions and ensure proper alignment between predictions and labels.

My Questions:

  • Isn’t this logic correct for computing TFA? It ensures alignment between inputs, predictions, and labels in both cases.
  • If the code is logically correct, why is TFA always zero? A score of exactly zero is highly suspicious.
  • Could there be an edge case or implementation detail that I’m missing?

What I Need:

  • Reasoning: Please reason carefully through my explanation and code. Is my reasoning correct, or am I misunderstanding something about teacher-forcing or token alignment?
  • Suggestions: If something seems wrong or could be improved, what tests or changes should I make? I’d also appreciate specific code fixes if necessary.
  • Tests to Debug: What specific cases or checks should I run to pinpoint the issue?
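One check that separates the alignment question from the model itself: feed the scoring code synthetic "oracle" logits that always put all mass on the correct next token. Correctly aligned code must then return exactly 1.0, and logits shifted by one position return 0.0, which reproduces the always-zero symptom. The sketch assumes right padding and the standard causal shift (the logits at position t score token t + 1, with padding excluded via `attention_mask`); `tfa_from_logits` and `oracle_logits` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def tfa_from_logits(logits: torch.Tensor, input_ids: torch.Tensor,
                    attention_mask: torch.Tensor) -> float:
    """Teacher-forced accuracy with the standard causal-LM shift.

    logits[:, t] is the prediction for input_ids[:, t + 1], so the last logit
    and the first token are dropped. Padding positions are excluded.
    """
    preds = logits[:, :-1].argmax(dim=-1)
    labels = input_ids[:, 1:]
    mask = attention_mask[:, 1:].bool()
    correct = (preds == labels) & mask
    return correct.sum().item() / mask.sum().item()


def oracle_logits(input_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Logits that put all mass on the gold *next* token (a perfect model)."""
    # The last column wraps around; it is dropped by tfa_from_logits anyway.
    next_ids = torch.roll(input_ids, shifts=-1, dims=1)
    return F.one_hot(next_ids, vocab_size).float()


if __name__ == "__main__":
    ids = torch.tensor([[1, 5, 6, 7, 2],
                        [1, 8, 9, 2, 0]])  # second row right-padded with id 0
    am = torch.tensor([[1, 1, 1, 1, 1],
                       [1, 1, 1, 1, 0]])
    print(tfa_from_logits(oracle_logits(ids, 10), ids, am))      # 1.0
    # Off-by-one: logits that "predict" the current token instead of the next.
    print(tfa_from_logits(F.one_hot(ids, 10).float(), ids, am))  # 0.0
```

With a real model, `model(input_ids=ids, attention_mask=am).logits` takes the place of the oracle. This version scores only tokens present in `input_ids`, so EOS is rewarded only if it was appended before tokenization.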

Asked Nov 21, 2024 at 0:25 by Charlie Parker; edited Jan 21 at 1:32.
  • ref: gist.github/brando90/66a58ad38f702cf92afa7f7e03877530 – Charlie Parker Commented Nov 21, 2024 at 0:33
  • ref: discuss.pytorch./t/… – Charlie Parker Commented Nov 26, 2024 at 3:37
  • ref: discuss.huggingface.co/t/… – Charlie Parker Commented Nov 26, 2024 at 3:38

1 Answer


Some improvements could still be made (e.g., an example using a chat template), but those can be added later if needed:

import os
import time
import torch
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
from datasets import load_dataset

def seed_everything(seed: int = 42):
    """
    Seed Python, NumPy, and PyTorch for reproducibility.
    """
    import random
    import numpy as np
    from transformers import set_seed as hf_set_seed

    print(f"Seeding everything with seed={seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        hf_set_seed(seed)
    else:
        print("Warning: Full determinism is best on GPU.")

def teacher_forced_accuracy_tfa(
    prompt: str,
    response: str,
    model: PreTrainedModel,
    repo: str,
    device: str = "cuda"
) -> float:
    """
    Computes teacher-forced token-level accuracy for a prompt + response.
    """
    combined_text = prompt + "\n\n" + response
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    enc = tokenizer(combined_text, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    preds = torch.argmax(logits, dim=-1)

    response_enc = tokenizer(response, add_special_tokens=False)
    len_response = len(response_enc["input_ids"])
    prompt_enc = tokenizer(prompt, add_special_tokens=False)
    len_prompt = len(prompt_enc["input_ids"])
    total_seq_len = input_ids.size(1)

    if len_prompt + len_response >= total_seq_len:
        return 0.0

    # Model predicts the next token, so offset labels by +1 in the response area
    pred_slice = preds[:, len_prompt : len_prompt + len_response]
    label_slice = input_ids[:, (len_prompt + 1) : (len_prompt + 1 + len_response)]

    if pred_slice.size(1) == 0 or label_slice.size(1) == 0:
        return 0.0

    correctness = (pred_slice == label_slice).float()
    return correctness.mean().item()

def compute_tfa_for_subds(
    sub_ds,
    model: PreTrainedModel,
    repo: str,
    prompt_format_fn=None,
    device: str = "cuda"
) -> float:
    """
    Computes the average teacher-forced accuracy over a subset of data.
    """
    sum_acc = 0.0
    count = 0

    for i, example in enumerate(sub_ds):
        nl_statement = example["nl_statement"]
        formal_statement = example["formal_statement"]

        if prompt_format_fn is not None:
            prompt = prompt_format_fn(nl_statement)
        else:
            prompt = (
                "Translate the natural language version of the math statement "
                f"to a formal Lean version:\n{nl_statement}\n"
            )

        acc_i = teacher_forced_accuracy_tfa(
            prompt=prompt,
            response=formal_statement,
            model=model,
            repo=repo,
            device=device
        )
        sum_acc += acc_i
        count += 1
        print(f"Example {i}: TFA = {acc_i:.4f}")

    return sum_acc / count if count > 0 else 0.0

def main():
    start_time = time.time()
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    seed_everything()

    ds = load_dataset("hoskinson-center/proofnet", split="validation")
    N = 5
    sub_ds = ds.select(range(min(N, len(ds))))

    model_token_configs = [
        {"name": "internlm2-math-plus-1_8b", "repo": "internlm/internlm2-math-plus-1_8b"},
        {"name": "google/gemma-2-2b", "repo": "google/gemma-2-2b"},
        {"name": "Mistral-7B-v0.1", "repo": "mistralai/Mistral-7B-v0.1"},
        {"name": "google/codegemma-2b", "repo": "google/codegemma-2b"},
        {"name": "Meta-Llama-3-8B", "repo": "meta-llama/Meta-Llama-3-8B"},
        {"name": "Meta-Llama-3-8B-Instruct", "repo": "meta-llama/Meta-Llama-3-8B-Instruct"},
        {"name": "google/gemma-2-2b-it", "repo": "google/gemma-2-2b-it"},
        {"name": "GPT-2 (small)", "repo": "gpt2"},
    ]

    def my_prompt_format(nl_statement: str) -> str:
        return (
            "Translate the natural language version of the mathematical statement "
            f"to a formal Lean version:\n{nl_statement}\n"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    for config in model_token_configs:
        model_name = config["name"]
        repo = config["repo"]

        print(f"\nEvaluating {model_name} on {N} example(s).")
        model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).to(device)

        avg_tfa = compute_tfa_for_subds(
            sub_ds=sub_ds,
            model=model,
            repo=repo,
            prompt_format_fn=my_prompt_format,
            device=device
        )
        print(f" => Average TFA for {model_name} = {avg_tfa:.4f}")

    total_seconds = time.time() - start_time
    print(f"\nDone. Total run time: {total_seconds:.2f} s.")

if __name__ == "__main__":
    main()
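The slice in `teacher_forced_accuracy_tfa` starts the scored window `len_prompt` tokens in, where `len_prompt` is counted without special tokens and without the `"\n\n"` separator. Depending on the tokenizer (a prepended BOS, separator tokens, or merges across the prompt/separator boundary), that window may not start exactly on the first response token. The sketch below instead takes a response start index measured in the combined encoding. One way to compute it is `response_start = len(tokenizer(prompt + "\n\n")["input_ids"])`, which assumes the tokenizer appends no EOS and does not merge tokens across the separator/response boundary. `response_tfa` is a hypothetical name.

```python
import torch


def response_tfa(preds: torch.Tensor, input_ids: torch.Tensor,
                 response_start: int) -> float:
    """Teacher-forced accuracy on the response tokens of a single sequence.

    preds[:, t] (argmax of the logits at t) is the prediction for
    input_ids[:, t + 1], so the prediction for the first response token
    lives at index response_start - 1.
    """
    seq_len = input_ids.size(1)
    if response_start < 1 or response_start >= seq_len:
        return 0.0  # empty response window
    pred_slice = preds[:, response_start - 1 : seq_len - 1]
    label_slice = input_ids[:, response_start:seq_len]
    return (pred_slice == label_slice).float().mean().item()


if __name__ == "__main__":
    # Toy ids: BOS=1, prompt 10-12, separator 99, response 20-22.
    ids = torch.tensor([[1, 10, 11, 12, 99, 20, 21, 22]])
    perfect = torch.roll(ids, shifts=-1, dims=1)  # a model that is always right
    print(response_tfa(perfect, ids, response_start=5))
```

Printing the decoded tokens of `label_slice` for one example is a quick way to confirm that the window covers exactly the response.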

For example, with https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct or https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1, the prompt can be built with a chat template (Mistral-7B-Instruct-v0.1 example):

from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" # the device to load the model onto

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")

model_inputs = encodeds.to(device)
model.to(device)

generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])

The per-example loop could also be vectorized for speed.
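A minimal sketch of that vectorization, assuming the examples have already been tokenized into one right-padded batch and that each example's response start index is known. `make_response_mask` and `batched_response_tfa` are hypothetical names; the model call would be a single `model(input_ids=..., attention_mask=...)` forward pass over the batch.

```python
import torch
import torch.nn.functional as F


def make_response_mask(response_start: torch.Tensor,
                       attention_mask: torch.Tensor) -> torch.Tensor:
    """True where a token belongs to the response (right-padded batch)."""
    positions = torch.arange(attention_mask.size(1),
                             device=attention_mask.device).unsqueeze(0)
    return (positions >= response_start.unsqueeze(1)) & attention_mask.bool()


def batched_response_tfa(logits: torch.Tensor, input_ids: torch.Tensor,
                         response_mask: torch.Tensor) -> torch.Tensor:
    """Per-example teacher-forced accuracy over response tokens, no Python loop."""
    preds = logits[:, :-1].argmax(dim=-1)  # prediction for token t + 1
    labels = input_ids[:, 1:]
    mask = response_mask[:, 1:]
    correct = ((preds == labels) & mask).sum(dim=1).float()
    counts = mask.sum(dim=1).clamp(min=1).float()  # avoid 0/0 on empty responses
    return correct / counts


if __name__ == "__main__":
    ids = torch.tensor([[1, 10, 11, 20, 21],
                        [1, 12, 20, 0, 0]])   # second row right-padded
    am = torch.tensor([[1, 1, 1, 1, 1],
                       [1, 1, 1, 0, 0]])
    starts = torch.tensor([3, 2])             # index of the first response token
    resp = make_response_mask(starts, am)
    oracle = F.one_hot(torch.roll(ids, shifts=-1, dims=1), 30).float()
    per_example = batched_response_tfa(oracle, ids, resp)
    print(per_example, per_example.mean().item())
```

With left padding the start indices would have to be shifted by each row's padding length, which is why the sketch assumes right padding.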

More references:

  • https://github/kaifronsdal/TheoremSense/blob/main/theoremsense/model_generate.py
  • https://github/kaifronsdal/TheoremSense/blob/main/theoremsense/old/compute_TFA.py

Sample output:

(zip_fit) brando9@skampere1~/ZIP-FIT $  cd /lfs/skampere1/0/brando9/ZIP-FIT ; /usr/bin/env /lfs/skampere1/0/brando9/miniconda/envs/zip_fit/bin/python /lfs/skampere1/0/brando9/.vscode-server-insiders/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 56493 -- /lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/tfa.py 
/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/utils/hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
Setting random seed = 42

Evaluating internlm2-math-plus-1_8b from internlm/internlm2-math-plus-1_8b on 5 example(s) of ProofNet validation.
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- configuration_internlm2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- modeling_internlm2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- tokenization_internlm2_fast.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
 Example 0: TFA = 0.5789
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- tokenization_internlm2_fast.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
 Example 1: TFA = 0.7742
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- tokenization_internlm2_fast.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
 Example 2: TFA = 0.7424
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- tokenization_internlm2_fast.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
 Example 3: TFA = 0.8269
A new version of the following files was downloaded from https://huggingface.co/internlm/internlm2-math-plus-1_8b:
- tokenization_internlm2_fast.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
 Example 4: TFA = 0.7500
 => Average TFA for internlm2-math-plus-1_8b on these 5 example(s) = 0.7345
 => Time to compute TFA for internlm2-math-plus-1_8b: 12.18 seconds.

Evaluating google/gemma-2-2b from google/gemma-2-2b on 5 example(s) of ProofNet validation.
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  9.25it/s]
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
 Example 0: TFA = 0.5000
 Example 1: TFA = 0.5185
 Example 2: TFA = 0.4225
 Example 3: TFA = 0.6591
 Example 4: TFA = 0.6571
 => Average TFA for google/gemma-2-2b on these 5 example(s) = 0.5515
 => Time to compute TFA for google/gemma-2-2b: 6.12 seconds.

Evaluating Mistral-7B-v0.1 from mistralai/Mistral-7B-v0.1 on 5 example(s) of ProofNet validation.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 10485.76it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.21it/s]
 Example 0: TFA = 0.6429
 Example 1: TFA = 0.7500
 Example 2: TFA = 0.7125
 Example 3: TFA = 0.7407
 Example 4: TFA = 0.6923
 => Average TFA for Mistral-7B-v0.1 on these 5 example(s) = 0.7077
 => Time to compute TFA for Mistral-7B-v0.1: 6.43 seconds.

Evaluating google/codegemma-2b from google/codegemma-2b on 5 example(s) of ProofNet validation.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 10525.23it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.28it/s]
 Example 0: TFA = 0.3611
 Example 1: TFA = 0.4815
 Example 2: TFA = 0.3521
 Example 3: TFA = 0.4318
 Example 4: TFA = 0.4286
 => Average TFA for google/codegemma-2b on these 5 example(s) = 0.4110
 => Time to compute TFA for google/codegemma-2b: 6.24 seconds.

Evaluating Meta-Llama-3-8B from meta-llama/Meta-Llama-3-8B on 5 example(s) of ProofNet validation.
Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 12965.39it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.70it/s]
 Example 0: TFA = 0.5882
 Example 1: TFA = 0.6923
 Example 2: TFA = 0.7049
 Example 3: TFA = 0.8293
 Example 4: TFA = 0.7353
 => Average TFA for Meta-Llama-3-8B on these 5 example(s) = 0.7100
 => Time to compute TFA for Meta-Llama-3-8B: 8.12 seconds.

Evaluating Meta-Llama-3-8B-Instruct from meta-llama/Meta-Llama-3-8B-Instruct on 5 example(s) of ProofNet validation.
Downloading shards: 100%|██████████| 4/4 [00:00<00:00, 14004.35it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.10it/s]
 Example 0: TFA = 0.5000
 Example 1: TFA = 0.7308
 Example 2: TFA = 0.6721
 Example 3: TFA = 0.8537
 Example 4: TFA = 0.8235
 => Average TFA for Meta-Llama-3-8B-Instruct on these 5 example(s) = 0.7160
 => Time to compute TFA for Meta-Llama-3-8B-Instruct: 7.55 seconds.

Evaluating google/gemma-2-2b-it from google/gemma-2-2b-it on 5 example(s) of ProofNet validation.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 4888.47it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s]
 Example 0: TFA = 0.3611
 Example 1: TFA = 0.4815
 Example 2: TFA = 0.4366
 Example 3: TFA = 0.5000
 Example 4: TFA = 0.5429
 => Average TFA for google/gemma-2-2b-it on these 5 example(s) = 0.4644
 => Time to compute TFA for google/gemma-2-2b-it: 6.26 seconds.

Evaluating GPT-2 (small) from gpt2 on 5 example(s) of ProofNet validation.
 Example 0: TFA = 0.2821
 Example 1: TFA = 0.1852
 Example 2: TFA = 0.2800
 Example 3: TFA = 0.2553
 Example 4: TFA = 0.2162
 => Average TFA for GPT-2 (small) on these 5 example(s) = 0.2438
 => Time to compute TFA for GPT-2 (small): 1.42 seconds.

Done. Total run time for all models: 55.71 seconds.
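Each per-example score in the log is an accuracy over target positions, and each "Average TFA" line is the mean of those scores. Below is a minimal pure-Python sketch of the masking rule described in the question: count positions up to and including the first EOS and ignore everything after it as padding. It is not the author's `compute_tfa`. The function names `teacher_forced_accuracy` and `average_tfa` are hypothetical. The sketch assumes predictions are already aligned with targets. For a Hugging Face causal LM, the logits at position t predict token t+1, so the aligned pair would be `logits[:, :-1].argmax(-1)` against `input_ids[:, 1:]`. It also assumes each target sequence contains an explicit EOS. Some tokenizers, such as GPT-2's, do not append one automatically.

```python
from typing import List, Sequence


def teacher_forced_accuracy(
    target_ids: Sequence[int],
    predicted_ids: Sequence[int],
    eos_token_id: int,
) -> float:
    """Hypothetical helper: accuracy over positions up to and including the first EOS.

    target_ids[i] is the token the model should emit at step i, and
    predicted_ids[i] is the argmax prediction for that same step (already
    shifted/aligned). Positions after the first EOS are treated as padding,
    which matches the setup where pad_token == eos_token.
    """
    targets = list(target_ids)
    preds = list(predicted_ids)
    if len(targets) != len(preds):
        raise ValueError("targets and predictions must be aligned")

    # Keep everything through the first EOS (inclusive) so the model is
    # rewarded for predicting it; if no EOS is present, score every position.
    try:
        n_valid = targets.index(eos_token_id) + 1
    except ValueError:
        n_valid = len(targets)
    if n_valid == 0:
        return 0.0

    correct = sum(1 for t, p in zip(targets[:n_valid], preds[:n_valid]) if t == p)
    return correct / n_valid


def average_tfa(scores: List[float]) -> float:
    """Hypothetical helper: unweighted mean of per-example TFA scores."""
    return sum(scores) / len(scores)


if __name__ == "__main__":
    # EOS id 2; the trailing 2s after the first one are padding and are ignored.
    print(teacher_forced_accuracy([5, 6, 7, 2, 2, 2], [5, 9, 7, 2, 4, 4], eos_token_id=2))
    # Per-example scores printed for internlm2-math-plus-1_8b above.
    print(round(average_tfa([0.5789, 0.7742, 0.7424, 0.8269, 0.7500]), 4))
```

Averaging the rounded per-example values from the log can differ from the reported average in the fourth decimal place. For example, the five google/gemma-2-2b values average to 0.5514, while the log reports 0.5515. This is presumably because the script averages the unrounded scores.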