Custom Model


classy is built on top of PyTorch Lightning. To better understand classy's code infrastructure, we recommend going through the PyTorch Lightning intro guide before proceeding.

Implementing your own model within classy is easy. You just need to:

  • subclass ClassyPLModule and your task mixin (SequenceTask, SentencePairTask, TokensTask, QATask)
  • implement abstract methods
  • (optional) override any other method
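The pattern described by the list above can be illustrated without classy at all. The following is a minimal sketch, assuming hypothetical stand-ins (`BaseModule`, `SequenceMixin`, `IncompleteModel`, `MyModel`) rather than classy's real classes. It shows how combining a task mixin with an abstract base prevents a subclass from being instantiated until every abstract method is implemented.

```python
from abc import ABC, abstractmethod


class BaseModule(ABC):
    """Hypothetical stand-in for ClassyPLModule: declares the abstract API."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    @abstractmethod
    def batch_predict(self, *args, **kwargs):
        ...


class SequenceMixin:
    """Hypothetical stand-in for a task mixin such as SequenceTask."""

    task = "sequence"


class IncompleteModel(SequenceMixin, BaseModule):
    # forward is implemented, batch_predict is not
    def forward(self, x):
        return x.upper()


class MyModel(SequenceMixin, BaseModule):
    def forward(self, x):
        return x.upper()

    def batch_predict(self, xs):
        # wraps forward and emits (input, prediction) pairs
        for x in xs:
            yield x, self.forward(x)


# Instantiating a class with unimplemented abstract methods raises TypeError
try:
    IncompleteModel()
    instantiated = True
except TypeError:
    instantiated = False

model = MyModel()
predictions = list(model.batch_predict(["a", "b"]))
```

The mixin is listed first in the bases, mirroring the `(SequenceTask, ClassyPLModule)` order used throughout this page.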

For instance, for Sequence Classification, you would need to implement the following class:

from typing import Any, Iterator, Tuple

import omegaconf
import torch

# ClassyPLModule, SequenceTask, SequenceSample and Vocabulary are provided by classy


# subclass your task and ClassyPLModule
class MyCustomClassyPLModule(SequenceTask, ClassyPLModule):
    def __init__(
        self,
        param1: Any,
        param2: Any,
        vocabulary: Vocabulary,
        optim_conf: omegaconf.DictConfig,
    ):
        super().__init__(vocabulary=vocabulary, optim_conf=optim_conf)
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        # standard pytorch forward
        raise NotImplementedError

    def batch_predict(self, *args, **kwargs) -> Iterator[Tuple[SequenceSample, str]]:
        # wrapper for your forward method
        # it takes as input the batches produced by your dataset
        # it emits tuples (sequence sample, predicted label)
        # decoding logic, such as converting labels from tensors to strings, goes here
        raise NotImplementedError

    # lightning hooks #

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        raise NotImplementedError

    def validation_step(self, batch: dict, batch_idx: int) -> None:
        raise NotImplementedError

    def test_step(self, batch: dict, batch_idx: int) -> None:
        raise NotImplementedError

A Minimal Example

Practically, imagine you want to build a Sequence Classification model on top of a HuggingFace Transformer model.

class MyCustomClassyPLModule(SequenceTask, ClassyPLModule):

You first implement its constructor:

    def __init__(
        self,
        transformer_model: str,
        vocabulary: Vocabulary,
        optim_conf: omegaconf.DictConfig,
    ):
        super().__init__(vocabulary=vocabulary, optim_conf=optim_conf)
        num_classes = vocabulary.get_size(k="labels")  # number of target classes
        self.classifier = AutoModelForSequenceClassification.from_pretrained(
            transformer_model, num_labels=num_classes
        )  # underlying classifier
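        # assumption: the metric expression is missing here; a torchmetrics accuracy is used
        # (recent torchmetrics releases require task="multiclass", num_classes=num_classes)
        self.accuracy_metric = (
            torchmetrics.Accuracy()
        )  # metric to track your model performance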

Then, you need to implement the PyTorch forward:

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        samples: List[SequenceSample],
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> ClassificationOutput:
        model_input = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids
        if labels is not None:
            model_input["labels"] = labels
        model_output = self.classifier(**model_input)
        return ClassificationOutput(
            logits=model_output.logits,
            probabilities=torch.softmax(model_output.logits, dim=-1),
            predictions=torch.argmax(model_output.logits, dim=-1),
            loss=model_output.loss,
        )

There's nothing special about this forward. ClassificationOutput is just a dataclass that conveniently stores logits, probabilities, predictions and loss. The only important thing is the signature: its parameter names must match the keys of the batches your dataset emits (here, we are using
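The signature requirement follows from the hooks calling `self.forward(**batch)`: every batch key becomes a keyword argument. The sketch below is a hypothetical helper, not part of classy. It uses `inspect.signature` to report which required parameters a batch is missing and which keys the function would reject. The `forward` here is a plain stand-in that only reuses the parameter names from the method above.

```python
import inspect


def forward(input_ids, attention_mask, samples, token_type_ids=None, labels=None):
    """Stand-in with the same parameter names as the forward shown above."""
    return len(samples)


def check_batch(fn, batch):
    """Compare the keys of `batch` with the parameters of `fn`.

    Returns (missing, unexpected): required parameters absent from the batch,
    and batch keys that `fn` does not accept.
    """
    params = inspect.signature(fn).parameters
    required = {
        name for name, p in params.items() if p.default is inspect.Parameter.empty
    }
    missing = sorted(required - set(batch))
    unexpected = sorted(set(batch) - set(params))
    return missing, unexpected


good_batch = {
    "input_ids": [[101, 7592, 102]],
    "attention_mask": [[1, 1, 1]],
    "samples": ["hello"],
    "labels": [0],
}
bad_batch = {"input_ids": [[101, 7592, 102]], "mask": [[1, 1, 1]]}

good_report = check_batch(forward, good_batch)  # ([], [])
bad_report = check_batch(forward, bad_batch)  # (['attention_mask', 'samples'], ['mask'])
n_samples = forward(**good_batch)  # unpacking works because the keys match
```

Calling `forward(**bad_batch)` would raise a `TypeError`, which is the failure you would see at training time if the dataset and the model disagree on names.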

Next comes the batch_predict method, which wraps your forward method to emit classified SequenceSample objects:

    def batch_predict(
        self, *args, **kwargs
    ) -> Iterator[Tuple[Union[SequenceSample, SentencePairSample], str]]:
        samples = kwargs.get("samples")
        classification_output = self.forward(*args, **kwargs)
        for sample, prediction in zip(samples, classification_output.predictions):
            yield sample, self.vocabulary.get_elem(k="labels", idx=prediction.item())

You just invoke the forward method and use the vocabulary to decode each predicted label index into its string.
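The decoding step can be tried in isolation. This is a minimal sketch, assuming a hypothetical `ToyVocabulary` that only imitates the `get_elem(k=..., idx=...)` call used above. Plain integers stand in for the tensor elements on which the real method calls `.item()`.

```python
from typing import Dict, Iterator, List, Tuple


class ToyVocabulary:
    """Hypothetical stand-in for classy's Vocabulary: one list of strings per key."""

    def __init__(self, elems: Dict[str, List[str]]):
        self._elems = elems

    def get_elem(self, k: str, idx: int) -> str:
        # map an integer index back to its string within field `k`
        return self._elems[k][idx]


def decode(
    samples: List[str], predictions: List[int], vocabulary: ToyVocabulary
) -> Iterator[Tuple[str, str]]:
    # pair each input sample with the string form of its predicted label
    for sample, prediction in zip(samples, predictions):
        yield sample, vocabulary.get_elem(k="labels", idx=prediction)


vocabulary = ToyVocabulary({"labels": ["negative", "positive"]})
samples = ["great movie", "boring plot"]
predictions = [1, 0]  # e.g. the argmax over the logits of each sample

decoded = list(decode(samples, predictions, vocabulary))
```

Because `decode` is a generator, nothing is computed until the caller iterates over it, which is also how an `Iterator`-returning `batch_predict` behaves.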

Finally, you have to implement the Lightning hooks:

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        classification_output = self.forward(**batch)
        self.log("loss", classification_output.loss)
        return classification_output.loss

    def validation_step(self, batch: dict, batch_idx: int) -> None:
        classification_output = self.forward(**batch)
        self.accuracy_metric(classification_output.predictions, batch["labels"].squeeze(-1))
        self.log("val_loss", classification_output.loss)
        self.log("val_accuracy", self.accuracy_metric, prog_bar=True)

    def test_step(self, batch: dict, batch_idx: int) -> None:
        classification_output = self.forward(**batch)
        self.accuracy_metric(classification_output.predictions, batch["labels"].squeeze(-1))
        self.log("test_accuracy", self.accuracy_metric)
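The validation and test hooks call the metric once per batch and log the metric object itself rather than a number. Assuming the metric is a stateful object in the style of torchmetrics, it accumulates counts across batches, so the final value is computed over all examples. The toy `RunningAccuracy` class below is a hypothetical pure-Python imitation of that behaviour, not the torchmetrics implementation. It shows why this differs from averaging per-batch accuracies when batch sizes are unequal.

```python
class RunningAccuracy:
    """Toy stateful metric: accumulates counts per batch, reports over all batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, predictions, labels):
        # update the running counts with one batch
        self.correct += sum(int(p == l) for p, l in zip(predictions, labels))
        self.total += len(labels)

    def compute(self):
        # accuracy over every example seen so far
        return self.correct / self.total if self.total else 0.0

    def reset(self):
        self.correct = 0
        self.total = 0


batches = [
    ([0, 1, 1], [0, 1, 0]),  # 2 of 3 correct
    ([1], [1]),              # 1 of 1 correct
]

metric = RunningAccuracy()
per_batch = []
for predictions, labels in batches:
    metric(predictions, labels)
    per_batch.append(sum(p == l for p, l in zip(predictions, labels)) / len(labels))

overall = metric.compute()                       # 3 correct out of 4 examples
mean_of_batches = sum(per_batch) / len(per_batch)  # weights both batches equally
```

Here `overall` is 0.75 while `mean_of_batches` is about 0.83, because the single-example batch counts as much as the three-example batch in the second figure.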

The only missing piece is the configuration file:

_target_: 'classy.pl_modules.custom_model.MyCustomClassyPLModule'
transformer_model: ${transformer_model}
optim_conf:
  _target_: classy.optim.factories.TorchFactory
  optimizer:
    _target_: torch.optim.Adam
    lr: 1e-5

and start the training:

classy train sequence <dataset-path> -c model=sequence-custom
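The `_target_` keys in the configuration above follow the convention used by Hydra-style instantiation: the value is a dotted import path, and the remaining keys are passed as keyword arguments. The following is a deliberately simplified sketch of that idea, with a hypothetical `instantiate` helper written from scratch. It is not Hydra's implementation and does not handle interpolations such as `${transformer_model}`. Standard-library classes are used as targets so that the example runs without classy or torch.

```python
import importlib


def instantiate(conf: dict):
    """Build the object named by conf["_target_"], passing the other keys as kwargs.

    Nested dictionaries that carry their own "_target_" are built first.
    """
    conf = dict(conf)  # do not mutate the caller's config
    module_path, _, name = conf.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), name)
    kwargs = {
        key: instantiate(value)
        if isinstance(value, dict) and "_target_" in value
        else value
        for key, value in conf.items()
    }
    return target(**kwargs)


# flat config: one target plus its keyword arguments
delta = instantiate({"_target_": "datetime.timedelta", "days": 1, "hours": 12})

# nested config, shaped like the model config above (targets are stand-ins)
model_conf = {
    "_target_": "types.SimpleNamespace",
    "transformer_model": "some-model-name",
    "optim_conf": {
        "_target_": "types.SimpleNamespace",
        "lr": 1e-5,
    },
}
model = instantiate(model_conf)
```

In the nested case the inner object is created before the outer one, so the outer constructor receives an already-built `optim_conf`, much as the module constructor on this page receives `optim_conf` as an argument.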