Skip to main content

classy.pl_modules.base

Classes

ClassificationOutput

class ClassificationOutput()

ClassificationOutput(logits, probabilities, predictions, loss)

__init__

def __init__(
    logits: torch.Tensor,
    probabilities: torch.Tensor,
    predictions: torch.Tensor,
    loss: Optional[torch.Tensor] = None,
)

ClassyPLModule

class ClassyPLModule()

Base module for classy models. It declares the model interface: `forward`, `configure_optimizers`, and `load_prediction_params`.

Subclasses (4)

__init__

def __init__(
    vocabulary: Optional[Vocabulary],
    optim_conf: omegaconf.dictconfig.DictConfig,
)

configure_optimizers

def configure_optimizers(
    self,
)

forward

def forward(
    self,
    *args,
    **kwargs,
) -> ClassificationOutput

Same as `torch.nn.Module.forward()`.

Args
*args
Whatever you decide to pass into the forward method.
**kwargs
Keyword arguments are also possible.
Returns

Your model's output

load_prediction_params

def load_prediction_params(
    self,
    prediction_params: Dict[~KT, ~VT],
)