from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import LlamaPreTrainedModel, LlamaModel
from transformers.utils import ModelOutput


@dataclass
class MultiAspectRewardOutput(ModelOutput):
    """
    Output container for multi-aspect predictions plus the aggregated reward.

    Args:
        aspect_scores (torch.FloatTensor): per-aspect scores, shape (batch, num_labels).
        final_reward (torch.FloatTensor): weighted sum of the aspect scores, shape (batch,).
        logits (torch.FloatTensor): same tensor as ``final_reward``, shape (batch,).
        loss (torch.FloatTensor): optional scalar MSE loss (only when labels are given).
        hidden_states (tuple(torch.FloatTensor)): optional backbone hidden states.
        attentions (tuple(torch.FloatTensor)): optional backbone attentions.
    """

    # Field order is part of the public interface (positional construction
    # and tuple conversion), so it is kept exactly as before.
    aspect_scores: Optional[torch.FloatTensor] = None
    final_reward: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class LlamaFixedWeightReward(LlamaPreTrainedModel):
    """
    Llama-based reward model with a fixed-weight aggregation of aspect scores.

    The model:
      1) optionally re-uses an already constructed Llama backbone (``base_llama``),
      2) predicts ``config.num_labels`` aspect scores squashed into (0, 1) by a
         sigmoid, and computes an MSE loss when labels are provided,
      3) aggregates the aspect scores with fixed (non-trainable) weights into
         one scalar reward per example,
      4) returns a ``MultiAspectRewardOutput`` whose ``final_reward`` and
         ``logits`` both have shape (batch,).
    """

    def __init__(self, config, base_llama=None, rule_weights=None):
        """
        Args:
            config: LlamaConfig; ``config.num_labels`` is the number of aspects
                (5 in the intended setup) and ``config.hidden_size`` the width
                of the backbone output.
            base_llama: optional, an already constructed ``LlamaModel`` to use
                as the backbone. If None, a fresh ``LlamaModel(config)`` is built.
            rule_weights: optional list or ``torch.Tensor`` holding
                ``config.num_labels`` aggregation weights. If None, uniform
                weights ``1 / config.num_labels`` are used.

        Raises:
            ValueError: if ``rule_weights`` does not contain exactly
                ``config.num_labels`` values.
        """
        super().__init__(config)

        # 1) Re-use the supplied backbone, otherwise instantiate one from config.
        self.llama = base_llama if base_llama is not None else LlamaModel(config)

        # 2) Linear head mapping the pooled hidden state to the aspect scores.
        self.aspect_head = nn.Linear(config.hidden_size, config.num_labels)

        # 3) Fixed aggregation weights, stored as a persistent buffer so they
        #    follow the module across devices and are saved in the state dict.
        num_aspects = config.num_labels
        if rule_weights is None:
            weights = torch.tensor(
                [1 / num_aspects] * num_aspects, dtype=torch.float
            )
        else:
            # as_tensor + detach/clone accepts both lists and tensors without
            # the copy-construct warning torch.tensor() emits for tensor input,
            # and avoids aliasing the caller's tensor.
            weights = (
                torch.as_tensor(rule_weights, dtype=torch.float)
                .detach()
                .clone()
                .reshape(-1)
            )
            if weights.numel() != num_aspects:
                raise ValueError(
                    f"rule_weights must contain {num_aspects} values "
                    f"(config.num_labels), got {weights.numel()}"
                )
        self.register_buffer("rule_weights", weights.view(1, -1), persistent=True)

        # NOTE(review): post_init() triggers the Hugging Face weight
        # initialisation hooks. Confirm, for the transformers version in use,
        # that a pre-loaded `base_llama` is not re-initialised here.
        self.post_init()

    @staticmethod
    def _pool_last_token(last_hidden, attention_mask=None):
        """Return the hidden state of the last attended token of each sequence.

        Args:
            last_hidden: tensor of shape (batch, seq_len, hidden_size).
            attention_mask: optional (batch, seq_len) mask, non-zero on real
                tokens. Works for both left- and right-padded batches.

        Returns:
            Tensor of shape (batch, hidden_size). Falls back to the final
            sequence position when no usable 2-D mask is available (for
            example when the mask length differs from ``seq_len``) or when a
            row has no attended token at all.
        """
        batch_size, seq_len = last_hidden.shape[0], last_hidden.shape[1]
        if (
            attention_mask is None
            or attention_mask.dim() != 2
            or attention_mask.shape[0] != batch_size
            or attention_mask.shape[1] != seq_len
        ):
            return last_hidden[:, -1, :]

        device = last_hidden.device
        attended = (attention_mask != 0).to(device=device, dtype=torch.long)
        positions = torch.arange(seq_len, device=device)
        # The largest index among attended positions is the last real token,
        # regardless of the padding side.
        last_idx = (attended * positions).argmax(dim=-1)
        # Rows with no attended token keep the original "last position" choice.
        has_token = attended.sum(dim=-1) > 0
        last_idx = torch.where(
            has_token, last_idx, torch.full_like(last_idx, seq_len - 1)
        )
        rows = torch.arange(batch_size, device=device)
        return last_hidden[rows, last_idx]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,  # shape: (batch, num_labels), optional
        **kwargs
    ):
        """
        Run the backbone, score the aspects and aggregate them into a reward.

        Args:
            input_ids: token ids forwarded to the Llama backbone.
            attention_mask: optional padding mask forwarded to the backbone and
                used to locate the last real token of every sequence.
            labels: optional target aspect scores, shape (batch, num_labels);
                when given, an MSE loss against the sigmoid scores is returned.
            **kwargs: forwarded unchanged to the Llama backbone.

        Returns:
            MultiAspectRewardOutput with ``aspect_scores`` (batch, num_labels),
            ``final_reward`` / ``logits`` (batch,), the optional ``loss`` and
            the backbone's ``hidden_states`` / ``attentions``.
        """
        # 1) Forward pass through Llama.
        outputs = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        last_hidden = outputs.last_hidden_state  # [batch, seq_len, hidden_size]

        # 2) Pool with the representation of the last non-padded token. Plain
        #    `[:, -1, :]` would read a PAD position for right-padded rows.
        pooled = self._pool_last_token(last_hidden, attention_mask)

        # 3) Predict the aspect scores and squash them into (0, 1).
        aspect_scores = torch.sigmoid(self.aspect_head(pooled))

        # 4) Optional MSE loss, computed in float32 on the predictions' device
        #    so half-precision models and off-device labels behave predictably.
        loss = None
        if labels is not None:
            targets = labels.to(device=aspect_scores.device, dtype=torch.float)
            loss = nn.functional.mse_loss(aspect_scores.float(), targets)

        # 5) Aggregate via the fixed weights => one scalar per example.
        reward = (aspect_scores * self.rule_weights).sum(dim=-1)

        return MultiAspectRewardOutput(
            loss=loss,
            aspect_scores=aspect_scores,  # shape: [batch, num_labels]
            final_reward=reward,          # shape: [batch]
            logits=reward,                # same as final_reward
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )