classy.pl_modules.mixins.prediction

Classes

PredictionMixin

class PredictionMixin()

Simple Mixin to model the prediction behavior of a classy.pl_modules.base.ClassyPLModule.

batch_predict

def batch_predict(
    self,
    *args,
    **kwargs,
) -> Iterator[ClassySample]

General method that must be implemented by each classy.pl_modules.base.ClassyPLModule in order to perform batch prediction.

Returns

An iterator over a collection of samples with the predicted annotation updated with the model outputs.
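As a rough illustration of this contract, the sketch below shows a generator-style batch_predict that writes a prediction onto each sample and yields it. ToySample and ToyModule are hypothetical stand-ins written for this example, not the real ClassySample or ClassyPLModule, and the predicted_annotation attribute name and the parity rule are assumptions used only to keep the snippet self-contained.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class ToySample:
    """Hypothetical stand-in for ClassySample; not the real class."""

    text: str
    predicted_annotation: Optional[str] = None  # assumed attribute name


class ToyModule:
    """Hypothetical stand-in for a module that mixes in PredictionMixin."""

    def batch_predict(self, samples: List[ToySample], **kwargs) -> Iterator[ToySample]:
        # A real module would run its forward pass here; this toy rule
        # labels each sample by the parity of its text length.
        for sample in samples:
            sample.predicted_annotation = "even" if len(sample.text) % 2 == 0 else "odd"
            yield sample


batch = [ToySample("ab"), ToySample("abc")]
labels = [s.predicted_annotation for s in ToyModule().batch_predict(batch)]
```

Because the method is a generator, nothing is computed until the caller iterates over the result.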

predict

def predict(
    self,
    samples: Iterator[ClassySample],
    dataset_conf: Union[Dict, omegaconf.dictconfig.DictConfig],
    token_batch_size: int = 1024,
    progress_bar: bool = False,
    **kwargs,
) -> Iterator[ClassySample]

Public entry point of each classy.pl_modules.base.ClassyPLModule, called to annotate a collection of input samples.

Args
samples
iterator over the samples that have to be annotated.
dataset_conf
the dataset configuration used to instantiate the Dataset with hydra.
token_batch_size
the maximum number of tokens in each batch.
progress_bar
whether to show a progress bar while predicting.
**kwargs
additional keyword arguments, accepted for forward compatibility.
Returns

An iterator over the input samples with the predicted annotation updated.
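To make the token_batch_size argument more concrete, here is a minimal sketch of token-budget batching: inputs are grouped greedily so that each batch stays within a maximum token count. This is not the library's actual batching code. The token_batches helper is hypothetical, and counting tokens by whitespace splitting is an assumption made only for this example.

```python
from typing import Iterable, Iterator, List


def token_batches(texts: Iterable[str], token_batch_size: int = 1024) -> Iterator[List[str]]:
    """Greedily group texts so each batch stays within the token budget.

    Tokens are approximated by whitespace splitting (an assumption for
    this sketch). A single text longer than the budget is emitted alone.
    """
    batch: List[str] = []
    used = 0
    for text in texts:
        n_tokens = len(text.split())
        # Close the current batch if adding this text would exceed the budget.
        if batch and used + n_tokens > token_batch_size:
            yield batch
            batch, used = [], 0
        batch.append(text)
        used += n_tokens
    if batch:
        yield batch


batches = list(token_batches(["a b c", "d e", "f g h i", "j"], token_batch_size=5))
```

With a budget of 5 tokens, the four inputs end up in two batches of 5 tokens each. Budgeting by tokens rather than by sample count keeps memory use roughly stable when input lengths vary widely.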