Skip to main content

classy.pl_modules.hf.generation

Classes

BartGenerativeModule

class BartGenerativeModule()

Simple Mixin to model the prediction behavior of a classy.pl_modules.base.ClassyPLModule.

Subclasses (1)

__init__

def __init__(
    transformer_model: str,
    decoding_skip_special_tokens: bool,
    decoding_clean_up_tokenization_spaces: bool,
    optim_conf: omegaconf.dictconfig.DictConfig,
    additional_special_tokens: Optional[List[str]] = None,
)

load_prediction_params

def load_prediction_params(
    self,
    prediction_params: Dict[~KT, ~VT],
)

GPT2GenerativeModule

class GPT2GenerativeModule()

Simple Mixin to model the prediction behavior of a classy.pl_modules.base.ClassyPLModule.

__init__

def __init__(
    transformer_model: str,
    decoding_skip_special_tokens: bool,
    decoding_clean_up_tokenization_spaces: bool,
    optim_conf: omegaconf.dictconfig.DictConfig,
    additional_special_tokens: Optional[List[str]] = None,
)

load_prediction_params

def load_prediction_params(
    self,
    prediction_params: Dict[~KT, ~VT],
)

HFGenerationPLModule

class HFGenerationPLModule()

Simple Mixin to model the prediction behavior of a classy.pl_modules.base.ClassyPLModule.

Subclasses (3)

__init__

def __init__(
    vocabulary: Optional[Vocabulary],
    optim_conf: omegaconf.dictconfig.DictConfig,
)

test_step

def test_step(
    self,
    batch: dict,
    batch_idx: int,
) -> None

training_step

def training_step(
    self,
    batch: dict,
    batch_idx: int,
) -> torch.Tensor

validation_step

def validation_step(
    self,
    batch: dict,
    batch_idx: int,
) -> None

MBartGenerativeModule

class MBartGenerativeModule()

Simple Mixin to model the prediction behavior of a classy.pl_modules.base.ClassyPLModule.

__init__

def __init__(
    *args,
    **kwargs,
)

T5GenerativeModule

class T5GenerativeModule()

Simple Mixin to model the prediction behavior of a classy.pl_modules.base.ClassyPLModule.

__init__

def __init__(
    transformer_model: str,
    decoding_skip_special_tokens: bool,
    decoding_clean_up_tokenization_spaces: bool,
    optim_conf: omegaconf.dictconfig.DictConfig,
    additional_special_tokens: Optional[List[str]] = None,
)

load_prediction_params

def load_prediction_params(
    self,
    prediction_params: Dict[~KT, ~VT],
)