NeuronTrainer

The NeuronTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.

The NeuronTrainer class is optimized for 🤗 Transformers models running on AWS Trainium.

Here is an example of how to customize NeuronTrainer to use a weighted loss (useful when you have an unbalanced training set):

import torch
from torch import nn
from optimum.neuron import NeuronTrainer


class CustomNeuronTrainer(NeuronTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights);
        # the class weights must live on the same device as the logits
        weights = torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

Another way to customize the training loop behavior of the PyTorch NeuronTrainer is to use callbacks, which can inspect the training loop state (for progress reporting, logging to TensorBoard or other ML platforms…) and make decisions (such as early stopping).
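
As a rough illustration of the callback approach, here is a minimal sketch of a callback that requests a stop once the evaluation loss falls below a threshold. The class name `LossThresholdCallback`, its `threshold` and `metric_name` parameters, and the default values are hypothetical and chosen only for this example. The `TrainerCallback` base class, the `on_evaluate` hook, and the `control.should_training_stop` flag come from 🤗 Transformers. The `try`/`except` fallback is an addition that only exists so the snippet can be imported without Transformers installed.

```python
try:
    from transformers import TrainerCallback
except ImportError:
    # Assumption for this sketch only: fall back to a plain base class
    # so the example can be imported where transformers is not installed.
    TrainerCallback = object


class LossThresholdCallback(TrainerCallback):
    """Hypothetical callback: stop training once the eval loss is low enough."""

    def __init__(self, threshold=0.5, metric_name="eval_loss"):
        self.threshold = threshold
        self.metric_name = metric_name

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # `metrics` is the dict produced by the evaluation step; it may be missing.
        value = (metrics or {}).get(self.metric_name)
        if value is not None and value < self.threshold:
            # Ask the training loop to stop after the current step.
            control.should_training_stop = True
        return control
```

Assuming NeuronTrainer forwards its arguments to the Transformers Trainer unchanged, such a callback would be passed at construction time through the `callbacks` argument, for example `callbacks=[LossThresholdCallback(threshold=0.3)]`. For patience-based early stopping, Transformers already ships `EarlyStoppingCallback`, which is usually preferable to a hand-written one.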

NeuronTrainer

class optimum.neuron.NeuronTrainer

( *args, **kwargs )

Trainer that is suited for performing training on AWS Trainium instances.

class optimum.neuron.Seq2SeqNeuronTrainer

( *args, **kwargs )

Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.