PretrainedEncoder
- class trove.modeling.pretrained_encoder.PretrainedEncoder(args, training_args=None, preprocess_only=False, **kwargs)
A wrapper around different encoders.
This class wraps the encoder and takes care of model specific actions like saving and loading checkpoints, pooling, normalization, formatting inputs, etc.
PretrainedEncoder automatically detects and instantiates the correct wrapper subclass that can load a specific model. To support a new type of encoder, inherit from this class and implement the class method cls.can_wrap(). This method takes the model name and its arguments, and returns True if it can wrap the model, and False otherwise. Do not overwrite or modify cls._model_registry.
- classmethod can_wrap(model_name_or_path, args)
Returns True if this wrapper can wrap the specified model with the given arguments.
Subclasses of PretrainedEncoder should implement this method. We use this method to automatically choose the correct subclass to wrap different models.
- Parameters:
model_name_or_path (str) – name of the model to wrap.
args (ModelArguments) – arguments that describe the model to wrap.
- Return type:
bool
- Returns:
True if this class can wrap the model, and False otherwise.
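For example, a minimal sketch of adding support for a new encoder (the wrapper name and the matching rule below are hypothetical; subclassing PretrainedEncoder is presumed to register the wrapper in cls._model_registry, which you should not touch directly):

    from trove.modeling.pretrained_encoder import PretrainedEncoder

    class MyEncoderWrapper(PretrainedEncoder):
        # Hypothetical wrapper for checkpoints whose name contains "my-encoder".
        @classmethod
        def can_wrap(cls, model_name_or_path, args):
            # Return True only for models this wrapper knows how to load.
            return "my-encoder" in model_name_or_path.lower()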
- classmethod find_appropriate_wrapper(model_name_or_path, args)
Find the appropriate wrapper that can wrap and load a specific model.
If args.encoder_class is set, then a subclass of PretrainedEncoder with that name (or alias) is returned. Otherwise, the model arguments are passed to the can_wrap() method of all the registered subclasses of PretrainedEncoder, and the subclass whose can_wrap() method returns True is returned. If the can_wrap() method of every subclass returns False and the checkpoint is a fine-tuned model, this method finds the base model of the given checkpoint and repeats the same process to find the subclass that can load the base model.
- Parameters:
model_name_or_path (str) – name or path of the model that we want to wrap.
args (ModelArguments) – arguments that describe the model that we want to wrap. The cls.can_wrap() method of subclasses might use this in addition to model_name_or_path to determine if they can wrap the model.
- Returns:
A pointer to the subclass that can wrap the given model.
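In pseudocode, the selection logic described above amounts to something like the following sketch (the helper functions are hypothetical, not part of the API):

    def find_appropriate_wrapper(model_name_or_path, args):
        if getattr(args, "encoder_class", None):
            # An explicit class name or alias takes precedence.
            return lookup_registry(args.encoder_class)    # hypothetical helper
        for wrapper_cls in registered_wrappers():         # hypothetical registry view
            if wrapper_cls.can_wrap(model_name_or_path, args):
                return wrapper_cls
        # No subclass matched: fall back to the base model of the checkpoint.
        base_model = resolve_base_model(model_name_or_path)  # hypothetical helper
        return find_appropriate_wrapper(base_model, args)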
- classmethod from_model_args(args, model_name_or_path=None, training_args=None, **kwargs)
Instantiate the correct subclass of PretrainedEncoder that can wrap the specified model.
- Parameters:
args (ModelArguments) – arguments that describe the model that we want to wrap.
model_name_or_path (Optional[str]) – name of the model to wrap. If not provided, use args.model_name_or_path.
training_args (TrainingArguments) – passed to __init__. Used for activating gradient checkpointing.
**kwargs – extra keyword arguments passed to the wrapper constructor.
- Returns:
an instance of the correct wrapper class that wraps the specified model.
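A minimal usage sketch, assuming args is a populated ModelArguments instance (the checkpoint name below is only an example):

    from trove.modeling.pretrained_encoder import PretrainedEncoder

    # Let the registry pick the right wrapper for args.model_name_or_path.
    encoder = PretrainedEncoder.from_model_args(args)

    # Or name the checkpoint explicitly instead of using args.model_name_or_path.
    encoder = PretrainedEncoder.from_model_args(
        args, model_name_or_path="intfloat/e5-mistral-7b-instruct"
    )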
- __init__(args, training_args=None, preprocess_only=False, **kwargs)
Wraps encoder models and provides methods and attributes for pre-processing encoder inputs.
If preprocess_only is True, you are expected to only provide attributes and methods required for preparing the input for the encoder (e.g., append_eos_token, format_query(), etc.). If possible, you should avoid loading the model parameters when preprocess_only is True. This allows us to pre-process the data without loading the model, which leaves more resources (e.g., memory) for preprocessing operations.
- Parameters:
args (ModelArguments) – config for instantiating the model
preprocess_only (bool) – if True, do not load the model parameters and just provide the attributes and methods necessary for pre-processing the input, e.g., append_eos_token, format_query(), etc.
training_args (TrainingArguments) – You can use this to activate gradient checkpointing if needed.
**kwargs – extra kwargs passed to PreTrainedModel.from_pretrained when loading the encoder model.
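For example, a sketch of a preprocessing-only setup (assuming preprocess_only is forwarded to __init__ through from_model_args's **kwargs, per the signatures above):

    # Provide only the input-formatting interface, without loading model weights.
    preprocessor = PretrainedEncoder.from_model_args(args, preprocess_only=True)
    query_text = preprocessor.format_query("what is dense retrieval?")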
- encode(inputs)
Calculate the embeddings for the tokenized input.
- Return type:
Tensor
- format_query(text, **kwargs)
Format the query before passing it to tokenizer.
You can also ask for other parameters, such as dataset_name, if your model uses different formatting for different datasets (e.g., intfloat/e5-mistral-7b-instruct).
- Return type:
str
- format_passage(text, title=None, **kwargs)
Format the passage before passing it to tokenizer.
You can also ask for other parameters, such as dataset_name, if your model uses different formatting for different datasets (e.g., intfloat/e5-mistral-7b-instruct).
- Return type:
str
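For instance, a hypothetical wrapper for an E5-style model might override both methods to add prefixes (a sketch, not the library's built-in behavior):

    class E5StyleWrapper(PretrainedEncoder):  # hypothetical subclass
        def format_query(self, text, **kwargs):
            return f"query: {text.strip()}"

        def format_passage(self, text, title=None, **kwargs):
            # Prepend the title when one is available.
            body = f"{title.strip()} {text.strip()}" if title else text.strip()
            return f"passage: {body}"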
- save_pretrained(*args, **kwargs)
Save model parameters.
It should replicate the signature and behavior of transformers.PreTrainedModel.save_pretrained.
- encode_query(inputs)
Override if queries are encoded differently.
- Return type:
Tensor
- encode_passage(inputs)
Override if passages are encoded differently.
- Return type:
Tensor
- similarity_fn(query, passage)
Similarity between query and passage embeddings.
Override if your encoder uses a different similarity function.
- Parameters:
query (torch.Tensor) – query embeddings. shape: [NUM_QUERIES, EMB_DIM]
passage (torch.Tensor) – passage embeddings. shape: [NUM_PASSAGES, EMB_DIM]
- Return type:
Tensor
- Returns:
query-passage similarities. shape is [NUM_QUERIES, NUM_PASSAGES]
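A common choice is a dot product over the embedding dimension; a sketch of such an override (not necessarily this class's default):

    import torch

    class DotProductEncoder(PretrainedEncoder):  # hypothetical subclass
        def similarity_fn(self, query, passage):
            # [NUM_QUERIES, EMB_DIM] @ [EMB_DIM, NUM_PASSAGES] -> [NUM_QUERIES, NUM_PASSAGES]
            return torch.matmul(query, passage.transpose(0, 1))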
- compute_scores(query, passage)
Compute similarity scores between tokenized queries and passages.
- Parameters:
query (Dict[str, torch.Tensor]) – query tokens
passage (Dict[str, torch.Tensor]) – passage tokens
- Return type:
Tensor
- Returns:
query-passage similarities. shape is [NUM_QUERIES, NUM_PASSAGES]
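Conceptually, compute_scores composes the methods above (a sketch; the actual implementation may differ):

    scores = self.similarity_fn(
        self.encode_query(query),      # [NUM_QUERIES, EMB_DIM]
        self.encode_passage(passage),  # [NUM_PASSAGES, EMB_DIM]
    )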
- gradient_checkpointing_enable(*args, **kwargs)
Enable gradient checkpointing for the wrapped encoder model.