classy.pl_modules.mixins.prediction
Classes
PredictionMixin
A simple mixin that models the prediction behavior of a classy.pl_modules.base.ClassyPLModule.
batch_predict
Method that every classy.pl_modules.base.ClassyPLModule must implement to perform batch prediction.
Returns
An iterator over the samples, each with its predicted annotation updated from the model outputs.
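To show the division of labor this mixin describes, here is a minimal, self-contained sketch in which a public `predict` groups samples and delegates the actual annotation to `batch_predict`. Everything here is hypothetical: `ToySample` stands in for classy's `ClassySample`, the fixed-size grouping is a simplification, and the `batch_predict(batch)` signature is an assumption rather than the real classy signature.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterator, List, Optional


@dataclass
class ToySample:
    """Hypothetical stand-in for classy's ClassySample."""
    tokens: List[str]
    predicted_annotation: Optional[Any] = None


class ToyPredictionMixin(ABC):
    """Hypothetical mixin: predict() groups samples, batch_predict() annotates them."""

    @abstractmethod
    def batch_predict(self, batch: List[ToySample]) -> Iterator[ToySample]:
        """Yield each sample in `batch` with predicted_annotation filled in."""
        raise NotImplementedError

    def predict(self, samples: Iterator[ToySample], batch_size: int = 2) -> Iterator[ToySample]:
        # Group samples into fixed-size batches (a simplification) and delegate.
        batch: List[ToySample] = []
        for sample in samples:
            batch.append(sample)
            if len(batch) == batch_size:
                yield from self.batch_predict(batch)
                batch = []
        # Flush the final, possibly smaller, batch.
        if batch:
            yield from self.batch_predict(batch)


class UppercaseModule(ToyPredictionMixin):
    """Toy 'model' whose prediction is simply the upper-cased tokens."""

    def batch_predict(self, batch: List[ToySample]) -> Iterator[ToySample]:
        for sample in batch:
            sample.predicted_annotation = [t.upper() for t in sample.tokens]
            yield sample


samples = [ToySample(["hello"]), ToySample(["a", "b"]), ToySample(["x"])]
annotated = list(UppercaseModule().predict(iter(samples)))
print([s.predicted_annotation for s in annotated])  # [['HELLO'], ['A', 'B'], ['X']]
```

In this sketch a subclass only has to supply `batch_predict`; batching and iteration stay in the mixin.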
predict
def predict(
    self,
    samples: Iterator[ClassySample],
    dataset_conf: Union[Dict, omegaconf.dictconfig.DictConfig],
    token_batch_size: int = 1024,
    progress_bar: bool = False,
    **kwargs,
) -> Iterator[ClassySample]
Public method of each classy.pl_modules.base.ClassyPLModule that annotates a collection of input samples.
Args
samples
- an iterator over the samples to annotate.
dataset_conf
- the dataset configuration used to instantiate the Dataset with Hydra.
token_batch_size
- the maximum number of tokens in each batch.
progress_bar
- whether or not to show a progress bar of the prediction process.
**kwargs
- additional parameters, currently reserved for future use.
Returns
An iterator over the input samples, each with its predicted annotation updated.
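The token_batch_size argument caps the number of tokens per batch rather than the number of samples. A minimal sketch of one such policy, greedy grouping, follows. The helper `batch_by_tokens` and its `num_tokens` callback are hypothetical; classy's own batching may count tokens differently (for example, via its tokenizer or including padding).

```python
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_by_tokens(
    samples: Iterator[T],
    token_batch_size: int,
    num_tokens: Callable[[T], int],
) -> Iterator[List[T]]:
    """Hypothetical helper: greedily group samples so each batch holds at most
    `token_batch_size` tokens. A sample longer than the budget is emitted alone
    rather than dropped."""
    batch: List[T] = []
    batch_tokens = 0
    for sample in samples:
        n = num_tokens(sample)
        # Close the current batch if adding this sample would exceed the budget.
        if batch and batch_tokens + n > token_batch_size:
            yield batch
            batch, batch_tokens = [], 0
        batch.append(sample)
        batch_tokens += n
    # Emit whatever is left over.
    if batch:
        yield batch


sentences = ["a b c", "d e", "f g h i", "j"]
batches = list(
    batch_by_tokens(iter(sentences), token_batch_size=5, num_tokens=lambda s: len(s.split()))
)
print(batches)  # [['a b c', 'd e'], ['f g h i', 'j']]
```

Because the helper is a generator, batches are produced lazily, which matches the iterator-in, iterator-out shape of predict: a caller can start consuming annotated samples before the whole input has been read.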