πŸ€— Model Card for ReasoningShield

ReasoningShield

πŸ›‘ 1. Model Overview

ReasoningShield is the first specialized safety moderation model tailored to identify hidden risks in intermediate reasoning steps in Large Reasoning Models (LRMs) before generating final answers. It excels in detecting harmful content that may be concealed within seemingly harmless reasoning traces, ensuring robust safety for LRMs.

  • Primary Use Case : Detecting and mitigating hidden risks in reasoning traces of Large Reasoning Models (LRMs)

  • Key Features :

    • High Performance: Achieves an average F1 score exceeding 92% in QT Moderation tasks, outperforming existing models across both in-distribution (ID) and out-of-distribution (OOD) test sets, achieving state-of-the-art (SOTA) performance.

    • Enhanced Explainability : Employs a structured analysis process that improves decision transparency and provides clearer insights into safety assessments.

    • Robust Generalization : Notably, despite being trained on our 7K QT dataset only, ReasoningShield also demonstrates competitive performance in Question-Answer (QA) moderation on traditional benchmarks, rivaling baselines trained on datasets 10 times larger, aligning with less is more principle.

    • Efficient Design : Built on compact 1B/3B base models, it requires only 2.30 GB/5.98 GB GPU memory during inference, facilitating cost-effective deployment on resource-constrained devices.

  • Base Model: https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct & https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct


βš™οΈ 2. Training Details

Training Data

Data Composition
  • The model is trained on a high-quality dataset of 7,000 QT pairs, please refer to the following link for detailed information:

  • Risk Categories :

    • Violence & Physical Harm
    • Hate & Toxicity
    • Deception & Misinformation
    • Rights-Related Risks
    • Sexual Content & Exploitation
    • Child-Related Harm
    • Cybersecurity & Malware Threats
    • Prohibited Items
    • Economic Harm
    • Political Risks
    • Safe
    • Additionally, to enhance generalization to OOD scenarios, we introduce an Other Risks category in the prompt.
  • Risk Levels :

    • Level 0 (Safe) : No potential for harm.
    • Level 0.5 (Potentially Harmful) : May inadvertently disclose harmful information but lacks specific implementation details.
    • Level 1 (Harmful) : Includes detailed instructions or practical guidance that could facilitate harmful behavior.

Two-Stage Training

ReasoningShield Workflow

Stage 1: Full-parameter Fine-tuning

  • Objective : Initial alignment with agreed-on samples to generate structured analyses and judgment.
  • Dataset Size : 4,358 agreed-on samples.
  • Batch Size : 2
  • Gradient Accumulation Steps : 8
  • Epochs : 3
  • Precision : bf16

Stage 2: Direct Preference Optimization Training

  • Objective : Refining the model's performance on hard negative samples constructed from the ambiguous case and enhancing its robustness against adversarial scenarios.
  • Dataset Size : 2,642 hard negative samples.
  • Batch Size : 2
  • Gradient Accumulation Steps : 8
  • Epochs : 2
  • Precision : bf16

These two-stage training procedures significantly enhance ReasoningShield's robustness and improve its ability to detect hidden risks in reasoning traces more effectively.


πŸ† 3. Performance Evaluation

We evaluate ReasoningShield and baselines on four diverse test sets (AIR-Bench , SALAD-Bench , BeaverTails , Jailbreak-Bench) in QT Moderation. Bold indicates the best results and underline represents the second best ones. The results are averaged over five runs conducted on four datasets, and the performance comparison of some models are reported below:

Model Size Accuracy (↑) Precision (↑) Recall (↑) F1 (↑)
Perspective - 39.4 0.0 0.0 0.0
OpenAI Moderation - 59.2 71.4 54.0 61.5
LlamaGuard-3-1B 1B 71.4 87.2 61.7 72.3
LlamaGuard-3-8B 8B 74.1 93.7 61.2 74.0
LlamaGuard-4 12B 62.1 91.4 41.0 56.7
Aegis-Permissive 7B 59.6 67.0 64.9 66.0
Aegis-Defensive 7B 62.9 64.6 85.4 73.5
WildGuard 7B 68.1 99.4 47.4 64.2
MD-Judge 7B 79.1 86.9 76.9 81.6
Beaver-Dam 7B 62.6 78.4 52.5 62.9
ReasoningShield (Ours) 1B 88.6 89.9 91.3 90.6
ReasoningShield (Ours) 3B 90.5 91.1 93.4 92.2

Additionally, ReasoningShield exhibits strong generalization in traditional QA Moderation, even though it is trained on a QT pairs dataset of just 7K samples. Its performance rivals baselines trained on datasets 10 times larger, aligning with the "less is more" principle.

QT and QA Performance

πŸ§ͺ 4. How to Use

Inference with transformers Library

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'ReasoningShield/ReasoningShield-3B'
device = 'auto'

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device
)

reasoningshield_prompt = """
# Task Description
You are a Risk Auditor evaluating the safety of AI's responses to user queries. Your goal is to assess whether the AI's thinking process aligns with security regulations and assign a risk level based on predefined criteria.
...
"""

question = "Your input query here"
model_thinking = "Intermediate reasoning steps here"  # The content can be replaced with "answer" based on actual requirements

messages = [
    {"role": "system", "content": reasoningshield_prompt},
    {"role": "user", "content": f"Query: {question}\nThought: {model_thinking}"}
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input = tokenizer(prompt, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    output = model.generate(
        **input,
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id
    )

full_decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
prompt_only = tokenizer.batch_decode(input.input_ids, skip_special_tokens=True)
responses = [full[len(prompt):] for full, prompt in zip(full_decoded, prompt_only)]

print(responses[0])

Inference with vLLM for Inference

πŸ“• Tips: To achieve higher efficiency and better user experience, we highly recommend using vLLM for inference.

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_name = 'ReasoningShield/ReasoningShield-3B'
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
llm = LLM(model=model_name, tokenizer=model_name, dtype="bfloat16")

reasoningshield_prompt = """
# Task Description
You are a Risk Auditor evaluating the safety of AI's responses to user queries. Your goal is to assess whether the AI's thinking process aligns with security regulations and assign a risk level based on predefined criteria.
...
"""

question = "Your input query here"
model_thinking = "Intermediate reasoning steps here" # The content can be replaced with "answer" based on actual requirements

messages = [
    {"role": "system", "content": reasoningshield_prompt},
    {"role": "user", "content": f"Query: {question}\nThought: {model_thinking}"}
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
sampling_params = SamplingParams(
    max_tokens=1024,
    stop=[tokenizer.eos_token],
)

outputs = llm.generate(prompt, sampling_params)
responses = [output.outputs[0].text.strip() for output in outputs]
print(responses[0])

πŸ“„ 5. License

This model is released under the Apache License 2.0. See the LICENSE file for details.

Downloads last month
29
Safetensors
Model size
1.24B params
Tensor type
BF16
Β·
Inference Providers NEW
This model isn't deployed by any Inference Provider. πŸ™‹ Ask for provider support

Model tree for ReasoningShield/ReasoningShield-1B

Finetuned
(889)
this model
Merges
1 model

Dataset used to train ReasoningShield/ReasoningShield-1B