PretrainedRetriever

class trove.modeling.pretrained_retriever.PretrainedRetriever(model_args, encoder, preprocess_only=False, format_query=None, format_passage=None, append_eos_token=None, loss_extra_kwargs=None)
__init__(model_args, encoder, preprocess_only=False, format_query=None, format_passage=None, append_eos_token=None, loss_extra_kwargs=None)

Base class for training and inference with different retrievers.

Parameters:
  • model_args (ModelArguments) – config specifying the model and loss to use. Currently, we only use model_args to instantiate the loss (it might also be saved by some loggers). To instantiate both model and loss, use from_model_args().

  • encoder (nn.Module) – encoder model to use. It must expose encode_query() and encode_passage() methods. It must also provide a save_pretrained() method, but only if you call save_pretrained() on the retriever.

  • preprocess_only (bool) – if true, do not instantiate the loss module. You should also pass this to from_model_args() if you do not want to load the model parameters. See PretrainedEncoder.__init__() for details.

  • format_query (Optional[Callable]) – Callable similar to PretrainedEncoder.format_query(). If provided, it takes precedence over encoder.format_query(). It is not used internally by this class; it is exposed only as a convenience, so that everything needed to encode a query is available in one place.

  • format_passage (Optional[Callable]) – Callable similar to PretrainedEncoder.format_passage(). If provided, it takes precedence over encoder.format_passage(). It is not used internally by this class; it is exposed only as a convenience, so that everything needed to encode a passage is available in one place.

  • append_eos_token (Optional[bool]) – Similar to PretrainedEncoder.append_eos_token. If provided, it takes precedence over encoder.append_eos_token. It is not used internally by this class; it is exposed only as a convenience, so that everything needed to encode queries and passages is available in one place.

  • loss_extra_kwargs (Optional[Dict]) – If given, these are passed to RetrievalLoss.__init__() as keyword arguments.
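The encoder contract above, and the rule that a retriever-level format_query, format_passage, or append_eos_token takes precedence over the encoder's own, can be sketched in plain Python. ToyEncoder and resolve() are hypothetical stand-ins, not part of trove; real encoders return tensors rather than lists, and the actual precedence logic inside PretrainedRetriever may be written differently.

```python
from typing import Any, List, Optional


class ToyEncoder:
    """Hypothetical encoder exposing the methods PretrainedRetriever expects."""

    append_eos_token = False

    def format_query(self, text: str) -> str:
        return f"query: {text}"

    def format_passage(self, text: str) -> str:
        return f"passage: {text}"

    def encode_query(self, texts: List[str]) -> List[List[float]]:
        # Toy "embedding": [character count, space count]; not a real model.
        return [[float(len(t)), float(t.count(" "))] for t in texts]

    def encode_passage(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), float(t.count(" "))] for t in texts]


def resolve(override: Optional[Any], default: Any) -> Any:
    """Hypothetical helper: prefer the retriever-level value when it is given."""
    # `is not None` (rather than truthiness) lets an explicit False still override.
    return override if override is not None else default


encoder = ToyEncoder()
fmt_q = resolve(None, encoder.format_query)                     # no override: encoder's formatter
fmt_p = resolve(lambda t: f"doc: {t}", encoder.format_passage)  # override wins
eos = resolve(True, encoder.append_eos_token)

print(fmt_q("what is dense retrieval?"))  # query: what is dense retrieval?
print(fmt_p("some passage text"))         # doc: some passage text
print(eos)                                # True
```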

classmethod from_model_args(args, model_name_or_path=None, training_args=None, loss_extra_kwargs=None, **kwargs)

Instantiate the retriever based on the given args.

Parameters:
  • args (ModelArguments) – config used to instantiate the encoder and loss modules.

  • model_name_or_path (Optional[str]) – name of the encoder model to load. If not provided, use args.model_name_or_path. You should almost never need to use this.

  • training_args (TrainingArguments) – passed to PretrainedEncoder.from_model_args(). It is used to enable gradient checkpointing if needed.

  • loss_extra_kwargs (Optional[Dict]) – passed to BiEncoderRetriever.__init__() which then passes them to RetrievalLoss.__init__().

  • **kwargs – extra keyword arguments passed to BiEncoderRetriever.__init__() and PretrainedEncoder.from_model_args(). If you want to avoid loading model parameters and only load the methods and attributes required for pre-processing, pass preprocess_only=True as part of these kwargs. See BiEncoderRetriever.__init__() and PretrainedEncoder.__init__() for details.

Returns:

an instance of one of the PretrainedRetriever subclasses.
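How from_model_args() routes its arguments, as described above, can be sketched as a function that only computes which keyword arguments each constructor would receive. ModelArgsStub and plan_from_model_args() are hypothetical, the "my-encoder" name is made up, and the exact keyword names that PretrainedEncoder.from_model_args() and BiEncoderRetriever.__init__() accept are assumptions; the sketch only mirrors the documented behavior: model_name_or_path falls back to args.model_name_or_path, loss_extra_kwargs goes to the retriever, and extra kwargs such as preprocess_only go to both.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ModelArgsStub:
    """Hypothetical stand-in for ModelArguments; only the field used here."""
    model_name_or_path: str


def plan_from_model_args(
    args: ModelArgsStub,
    model_name_or_path: Optional[str] = None,
    training_args: Any = None,
    loss_extra_kwargs: Optional[Dict] = None,
    **kwargs: Any,
) -> Dict[str, Dict[str, Any]]:
    """Return the kwargs each constructor would receive (illustration only)."""
    # Fall back to the name stored in args when no explicit name is given.
    name = model_name_or_path if model_name_or_path is not None else args.model_name_or_path
    # Encoder side: model name, training args (e.g. for gradient checkpointing), extra kwargs.
    encoder_kwargs = {"model_name_or_path": name, "training_args": training_args, **kwargs}
    # Retriever side: loss_extra_kwargs (later forwarded to the loss) plus the same extra kwargs.
    retriever_kwargs = {"loss_extra_kwargs": loss_extra_kwargs, **kwargs}
    return {"encoder": encoder_kwargs, "retriever": retriever_kwargs}


plan = plan_from_model_args(ModelArgsStub("my-encoder"), preprocess_only=True)
print(plan["encoder"]["model_name_or_path"])  # my-encoder
print(plan["encoder"]["preprocess_only"], plan["retriever"]["preprocess_only"])  # True True
```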

forward(query=None, passage=None, label=None, return_loss=True, **kwargs)

Encode queries and passages and, when possible, calculate similarity scores and the loss.

Parameters:
  • query (Optional[Dict[str, torch.Tensor]]) – tokenized query.

  • passage (Optional[Dict[str, torch.Tensor]]) – tokenized passages.

  • label (Optional[torch.Tensor]) – relevance scores of the corresponding passages for each query. If there are k passages for each query, this is a 2D tensor of shape [NUM_QUERIES, k].

  • return_loss (Optional[bool]) – if true, calculate the loss value.

  • **kwargs – unused keyword arguments are passed to the forward() method of the loss module.

Return type:

RetrieverOutput

Returns:

Query and passage embeddings, plus similarity scores and the loss value when they can be computed.
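The control flow of forward() described above (embed whatever inputs are given, score only when both sides are present, compute the loss only when return_loss is true and labels exist) can be sketched over precomputed embeddings in plain Python. toy_forward() is hypothetical. It assumes dot-product similarity, that passages arrive grouped so that passages [i*k:(i+1)*k] belong to query i (giving logits of shape [NUM_QUERIES, k] to match label), and a listwise softmax cross-entropy against row-normalized labels; trove's RetrievalLoss may compute something different.

```python
import math
from typing import Dict, List, Optional

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def toy_forward(
    query_emb: Optional[Matrix],
    passage_emb: Optional[Matrix],
    label: Optional[Matrix] = None,
    return_loss: bool = True,
) -> Dict[str, object]:
    """Hypothetical forward pass; returns a dict with RetrieverOutput's field names."""
    out: Dict[str, object] = {"query": query_emb, "passage": passage_emb, "logits": None, "loss": None}
    if query_emb is None or passage_emb is None:
        return out  # nothing to compare, so no scores and no loss
    n = len(query_emb)
    k = len(passage_emb) // n  # passages per query
    logits = [[dot(query_emb[i], passage_emb[i * k + j]) for j in range(k)] for i in range(n)]
    out["logits"] = logits
    if return_loss and label is not None:
        total = 0.0
        for row, lab in zip(logits, label):
            m = max(row)
            log_z = m + math.log(sum(math.exp(s - m) for s in row))  # stable log-sum-exp
            z = sum(lab)  # assumes each label row has a positive sum
            total += -sum((l / z) * (s - log_z) for s, l in zip(row, lab))
        out["loss"] = total / n
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
p = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]  # k = 2 passages per query
label = [[1.0, 0.0], [1.0, 0.0]]                       # first passage is relevant
out = toy_forward(q, p, label)
print(out["logits"])           # [[1.0, 0.0], [1.0, 0.0]]
print(round(out["loss"], 4))   # 0.3133
```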

save_pretrained(*args, **kwargs)

class trove.modeling.pretrained_retriever.RetrieverOutput(query=None, passage=None, loss=None, logits=None)

Contains the output of the retriever model.

query: Optional[Tensor] = None

Query embeddings

passage: Optional[Tensor] = None

Passage embeddings

loss: Optional[Tensor] = None

Calculated loss

logits: Optional[Tensor] = None

Similarity scores between queries and passages
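A typical use of the logits field is ranking each query's passages. The sketch below uses RetrieverOutputStub, a hypothetical dataclass that mirrors the four optional fields listed above with plain lists in place of tensors, and a hypothetical rank_passages() helper; with real tensors, torch.topk would do the same job.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class RetrieverOutputStub:
    """Hypothetical stand-in mirroring RetrieverOutput's four optional fields."""
    query: Optional[Any] = None
    passage: Optional[Any] = None
    loss: Optional[Any] = None
    logits: Optional[Any] = None


def rank_passages(output: RetrieverOutputStub, top_k: int) -> List[List[int]]:
    """Return, for each query, the indices of its top_k passages by descending score."""
    if output.logits is None:
        # forward() leaves logits unset when queries or passages were not given.
        raise ValueError("logits were not computed")
    return [
        sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:top_k]
        for row in output.logits
    ]


out = RetrieverOutputStub(logits=[[0.1, 0.9, 0.5], [0.7, 0.2, 0.3]])
print(rank_passages(out, top_k=2))  # [[1, 2], [0, 2]]
```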