PretrainedEncoder

class trove.modeling.pretrained_encoder.PretrainedEncoder(args, training_args=None, preprocess_only=False, **kwargs)

A wrapper around different encoders.

This class wraps the encoder and takes care of model-specific actions like saving and loading checkpoints, pooling, normalization, formatting inputs, etc.

PretrainedEncoder automatically detects and instantiates the correct wrapper subclass that can load a specific model.

To support a new type of encoder, inherit from this class and implement the classmethod can_wrap(). This method takes the model name and its arguments and returns True if the subclass can wrap that model, and False otherwise.

Do not overwrite or modify cls._model_registry.
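
To illustrate, a minimal subclass might look like the following sketch; the class name and the name-matching heuristic are hypothetical, and subclass registration is assumed to happen automatically on definition:

    from trove.modeling.pretrained_encoder import PretrainedEncoder

    class MyEncoder(PretrainedEncoder):
        @classmethod
        def can_wrap(cls, model_name_or_path, args):
            # Claim only checkpoints this wrapper knows how to load. Matching
            # on the checkpoint name is one heuristic; `args` (ModelArguments)
            # can also be inspected for finer-grained decisions.
            return "my-encoder" in model_name_or_path.lower()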

classmethod can_wrap(model_name_or_path, args)

Returns True if this wrapper can wrap the specified model with the given arguments.

Subclasses of PretrainedEncoder should implement this method. We use this method to automatically choose the correct subclass to wrap different models.

Parameters:
  • model_name_or_path (str) – name of the model to wrap.

  • args (ModelArguments) – arguments that describe the model to wrap.

Return type:

bool

Returns:

True if this class can wrap the model, and False otherwise.

classmethod find_appropriate_wrapper(model_name_or_path, args)

Find the appropriate wrapper that can wrap and load a specific model.

If args.encoder_class is set, the subclass of PretrainedEncoder with that name (or alias) is returned. Otherwise, the model arguments are passed to the can_wrap() method of every registered subclass of PretrainedEncoder, and the subclass whose can_wrap() method returns True is returned. If can_wrap() returns False for all subclasses and the checkpoint is a fine-tuned model, this method finds the base model of the given checkpoint and repeats the same process to find the subclass that can load the base model.

Parameters:
  • model_name_or_path (str) – name or path of the model that we want to wrap.

  • args (ModelArguments) – arguments that describe the model that we want to wrap. The cls.can_wrap() method of subclasses might use this in addition to model_name_or_path to determine whether they can wrap the model.

Returns:

A pointer to the subclass that can wrap the given model.
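
Hypothetical usage, resolving the wrapper class for a checkpoint before instantiating it; args is assumed to be a ModelArguments instance built elsewhere:

    wrapper_cls = PretrainedEncoder.find_appropriate_wrapper(
        "intfloat/e5-mistral-7b-instruct", args
    )
    encoder = wrapper_cls(args)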

classmethod from_model_args(args, model_name_or_path=None, training_args=None, **kwargs)

Instantiate the correct subclass of PretrainedEncoder that can wrap the specified model.

Parameters:
  • args (ModelArguments) – arguments that describe the model that we want to wrap.

  • model_name_or_path (Optional[str]) – name of the model to wrap. If not provided, use args.model_name_or_path.

  • training_args (TrainingArguments) – passed to __init__. Used for activating gradient checkpointing.

  • **kwargs – extra keyword arguments passed to the wrapper constructor.

Returns:

An instance of the correct wrapper class that wraps the specified model.
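
For example (a sketch; the ModelArguments import path is an assumption):

    from trove import ModelArguments  # exact import path is an assumption

    args = ModelArguments(model_name_or_path="intfloat/e5-mistral-7b-instruct")
    # Dispatches to the subclass whose can_wrap() accepts this checkpoint.
    encoder = PretrainedEncoder.from_model_args(args)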

__init__(args, training_args=None, preprocess_only=False, **kwargs)

Wraps encoder models and provides methods and attributes for pre-processing encoder inputs.

If preprocess_only is True, the wrapper is only expected to provide the attributes and methods required for preparing encoder inputs (e.g., append_eos_token, format_query(), etc.). If possible, subclasses should avoid loading the model parameters when preprocess_only is True. This allows the data to be pre-processed without loading the model, which leaves more resources (e.g., memory) for preprocessing operations.

Parameters:
  • args (ModelArguments) – config for instantiating the model

  • preprocess_only (bool) – if True, do not load the model parameters and only provide the attributes and methods necessary for pre-processing the input (e.g., append_eos_token, format_query(), etc.).

  • training_args (TrainingArguments) – You can use this to activate gradient checkpointing if needed.

  • **kwargs – extra kwargs passed to transformers.PreTrainedModel.from_pretrained when loading the encoder model.
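
A lightweight sketch of the preprocess-only path; passing preprocess_only through from_model_args()'s **kwargs matches the documented kwargs forwarding to the wrapper constructor, but is still an assumption:

    preproc = PretrainedEncoder.from_model_args(args, preprocess_only=True)
    formatted = preproc.format_query("what is dense retrieval?")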

encode(inputs)

Calculate the embeddings for tokenized inputs.

Return type:

Tensor
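
For instance, assuming a Hugging Face tokenizer that matches the wrapped checkpoint (an assumption; the library may handle tokenization differently):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    batch = tokenizer(["first passage", "second passage"],
                      return_tensors="pt", padding=True)
    embeddings = encoder.encode(dict(batch))  # shape: [2, EMB_DIM]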

format_query(text, **kwargs)

Format the query before passing it to the tokenizer.

Subclasses can also accept other parameters, such as dataset_name, for models that use different formatting for different datasets (e.g., intfloat/e5-mistral-7b-instruct).

Return type:

str

format_passage(text, title=None, **kwargs)

Format the passage before passing it to the tokenizer.

Subclasses can also accept other parameters, such as dataset_name, for models that use different formatting for different datasets (e.g., intfloat/e5-mistral-7b-instruct).

Return type:

str
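
For example, a subclass might implement instruction-style formatting in the spirit of intfloat/e5-mistral-7b-instruct. This is a sketch; the class name and instruction strings are illustrative, not part of the library:

    class InstructEncoder(PretrainedEncoder):
        def format_query(self, text, dataset_name=None, **kwargs):
            # Hypothetical instruction; a real model card defines the exact
            # instruction for each task/dataset.
            instruction = "Given a web search query, retrieve relevant passages"
            return f"Instruct: {instruction}\nQuery: {text}"

        def format_passage(self, text, title=None, **kwargs):
            # e5-mistral-style recipes typically leave passages uninstructed.
            return f"{title} {text}".strip() if title else text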

save_pretrained(*args, **kwargs)

Save model parameters.

It should replicate the signature and behavior of transformers.PreTrainedModel.save_pretrained.

encode_query(inputs)

Override if queries are encoded differently.

Return type:

Tensor

encode_passage(inputs)

Override if passages are encoded differently.

Return type:

Tensor

similarity_fn(query, passage)

Similarity between query and passage embeddings.

Override if your encoder uses a different similarity function.

Parameters:
  • query (torch.Tensor) – query embeddings. shape: [NUM_QUERIES, EMB_DIM]

  • passage (torch.Tensor) – passage embeddings. shape: [NUM_PASSAGES, EMB_DIM]

Return type:

Tensor

Returns:

Query-passage similarities with shape [NUM_QUERIES, NUM_PASSAGES].
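
If your encoder scores with cosine similarity instead of the default, the override might look like this sketch (class name hypothetical):

    import torch
    import torch.nn.functional as F

    class CosineSimEncoder(PretrainedEncoder):
        def similarity_fn(self, query: torch.Tensor, passage: torch.Tensor) -> torch.Tensor:
            # Normalize, then a matrix product gives pairwise cosine scores.
            query = F.normalize(query, dim=-1)      # [NUM_QUERIES, EMB_DIM]
            passage = F.normalize(passage, dim=-1)  # [NUM_PASSAGES, EMB_DIM]
            return query @ passage.T                # [NUM_QUERIES, NUM_PASSAGES]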

compute_scores(query, passage)

Compute similarity scores between tokenized queries and passages.

Parameters:
  • query (Dict[str, torch.Tensor]) – query tokens

  • passage (Dict[str, torch.Tensor]) – passage tokens

Return type:

Tensor

Returns:

Query-passage similarities with shape [NUM_QUERIES, NUM_PASSAGES].
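
A minimal end-to-end sketch, reusing the tokenizer and encoder set up in the encode() example above; the example texts are illustrative:

    query = tokenizer(["what is dense retrieval?"],
                      return_tensors="pt", padding=True)
    passage = tokenizer(["Dense retrieval maps text to vectors.",
                         "An unrelated passage about cooking."],
                        return_tensors="pt", padding=True)
    scores = encoder.compute_scores(dict(query), dict(passage))  # shape: [1, 2]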

gradient_checkpointing_enable(*args, **kwargs)

Enable gradient checkpointing for the wrapped encoder model.