PretrainedRetriever
- class trove.modeling.pretrained_retriever.PretrainedRetriever(model_args, encoder, preprocess_only=False, format_query=None, format_passage=None, append_eos_token=None, loss_extra_kwargs=None)
- __init__(model_args, encoder, preprocess_only=False, format_query=None, format_passage=None, append_eos_token=None, loss_extra_kwargs=None)
A base class for training/inference with different retrievers.
- Parameters:
model_args (ModelArguments) – config specifying the model and loss to use. Currently, model_args is only used to instantiate the loss (it might also be saved by some loggers). To instantiate both the model and the loss, use from_model_args().
encoder (nn.Module) – encoder model to use. It must expose encode_query() and encode_passage() methods. encoder is also expected to provide a save_pretrained() method, but only if you call save_pretrained().
preprocess_only (bool) – if true, do not instantiate the loss module. You should also pass this to from_model_args() if you do not want to load the model parameters. See PretrainedEncoder.__init__ for details.
format_query (Optional[Callable]) – Callable similar to PretrainedEncoder.format_query. If provided, it is prioritized over encoder.format_query(). It is not used by this class internally. It is just exposed as a convenience method to keep everything needed to encode a query in one place.
format_passage (Optional[Callable]) – Callable similar to PretrainedEncoder.format_passage(). If provided, it is prioritized over encoder.format_passage. It is not used by this class internally. It is just exposed as a convenience method to keep everything needed to encode a passage in one place.
append_eos_token (Optional[bool]) – Similar to PretrainedEncoder.append_eos_token. If provided, it is prioritized over encoder.append_eos_token. It is not used by this class internally. It is just exposed as a convenience attribute to keep everything needed to encode a query or passage in one place.
loss_extra_kwargs (Optional[Dict]) – If given, these are passed to RetrievalLoss.__init__() as keyword arguments.
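The override rule for format_query and format_passage (an explicitly passed callable takes priority over the encoder's own) can be sketched in plain Python. This is a minimal illustration, not trove's implementation: ToyEncoder and resolve_formatter are hypothetical names, and the single-string formatter signature is an assumption.

```python
from typing import Callable, Optional


class ToyEncoder:
    """Hypothetical stand-in for an encoder; not part of trove."""

    def format_query(self, text: str) -> str:
        return f"query: {text}"

    def format_passage(self, text: str) -> str:
        return f"passage: {text}"


def resolve_formatter(
    override: Optional[Callable[[str], str]],
    encoder_default: Callable[[str], str],
) -> Callable[[str], str]:
    # An explicitly provided callable wins; otherwise fall back to the encoder's.
    return override if override is not None else encoder_default


encoder = ToyEncoder()

# No override given: the encoder's own formatter is used.
fmt_q = resolve_formatter(None, encoder.format_query)
# Override given: it takes priority over the encoder's formatter.
fmt_p = resolve_formatter(lambda t: f"doc: {t}", encoder.format_passage)

default_query = fmt_q("what is dense retrieval?")
custom_passage = fmt_p("Dense retrieval encodes text as vectors.")
```

Keeping the resolved formatters next to the encoder means preprocessing code only needs one object to format and encode inputs, which appears to be the convenience the parameters above describe.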
- classmethod from_model_args(args, model_name_or_path=None, training_args=None, loss_extra_kwargs=None, **kwargs)
Instantiate the retriever based on the given args.
- Parameters:
args (ModelArguments) – config used to instantiate the encoder and loss modules.
model_name_or_path (Optional[str]) – name of the encoder model to load. If not provided, use args.model_name_or_path. You should almost never need to use this.
training_args (TrainingArguments) – passed to PretrainedEncoder.from_model_args(). It is used to enable gradient checkpointing if needed.
loss_extra_kwargs (Optional[Dict]) – passed to BiEncoderRetriever.__init__(), which then passes them to RetrievalLoss.__init__().
**kwargs – extra keyword arguments passed to BiEncoderRetriever.__init__() and PretrainedEncoder.from_model_args(). If you want to avoid loading model parameters and only load the methods and attributes required for pre-processing, pass preprocess_only=True as part of these kwargs. See BiEncoderRetriever.__init__() and PretrainedEncoder.__init__() for details.
- Returns:
An instance of one of the PretrainedRetriever subclasses.
- forward(query=None, passage=None, label=None, return_loss=True, **kwargs)
Encodes queries and passages and, when possible, calculates similarity scores and loss.
- Parameters:
query (Optional[Dict[str, torch.Tensor]]) – tokenized query.
passage (Optional[Dict[str, torch.Tensor]]) – tokenized passages.
label (Optional[torch.Tensor]) – Relevance scores of the corresponding passages for each query. If there are k passages for each query, this is a 2D tensor of shape [NUM_QUERIES, k].
return_loss (Optional[bool]) – if true, calculate the loss value.
**kwargs – unused keyword arguments are passed to the forward() method of the loss module.
- Return type:
RetrieverOutput
- Returns:
If possible, query and passage embeddings as well as similarity and loss scores.
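To make the input shapes and the "if possible" behavior concrete, here is a dependency-free sketch that uses plain lists in place of tensors and already-embedded inputs in place of tokenized ones. It is not trove's implementation: toy_forward is a hypothetical name, and the dot-product scoring, the grouping of k consecutive passages per query, and the softmax cross-entropy loss are all assumptions chosen for illustration.

```python
import math
from typing import List, Optional

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def toy_forward(
    query: Optional[List[Vector]] = None,
    passage: Optional[List[Vector]] = None,
    label: Optional[List[List[float]]] = None,
    return_loss: bool = True,
) -> dict:
    """Hypothetical forward pass over already-embedded inputs."""
    out = {"query": query, "passage": passage, "logits": None, "loss": None}
    if query is None or passage is None:
        # Only one side was given: nothing to score, return embeddings only.
        return out

    k = len(passage) // len(query)  # passages per query (assumed grouping)
    # logits[i][j] scores query i against its j-th passage: shape [NUM_QUERIES, k].
    logits = [
        [dot(q, passage[i * k + j]) for j in range(k)]
        for i, q in enumerate(query)
    ]
    out["logits"] = logits

    if return_loss and label is not None:
        # Assumed loss: cross-entropy between softmax(logits) and normalized labels.
        total = 0.0
        for row_logits, row_label in zip(logits, label):
            log_z = math.log(sum(math.exp(s) for s in row_logits))
            label_sum = sum(row_label)
            total += -sum(
                (t / label_sum) * (s - log_z)
                for s, t in zip(row_logits, row_label)
            )
        out["loss"] = total / len(logits)
    return out


queries = [[1.0, 0.0], [0.0, 1.0]]  # 2 queries
passages = [
    [1.0, 0.0], [0.0, 1.0], [0.5, 0.5],  # k = 3 passages for query 0
    [0.0, 1.0], [1.0, 0.0], [0.5, 0.5],  # k = 3 passages for query 1
]
labels = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]  # shape [NUM_QUERIES, k] = [2, 3]

full = toy_forward(queries, passages, labels)
query_only = toy_forward(query=queries)
```

In this sketch, passing only queries yields embeddings with no logits or loss, while passing queries, passages, and labels fills in all four outputs.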
- save_pretrained(*args, **kwargs)
- class trove.modeling.pretrained_retriever.RetrieverOutput(query=None, passage=None, loss=None, logits=None)
Contains the output of the retriever model.
- query: Optional[Tensor] = None
Query embeddings
- passage: Optional[Tensor] = None
Passage embeddings
- loss: Optional[Tensor] = None
Calculated loss
- logits: Optional[Tensor] = None
Similarity scores between queries and passages
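Because every field defaults to None, calling code should check which fields were populated before using them. The sketch below mirrors the four fields with a hypothetical dataclass (ToyRetrieverOutput and describe are illustrative names, and nested lists stand in for tensors so the example has no dependencies); it is not the library's class.

```python
from dataclasses import dataclass
from typing import List, Optional

Matrix = List[List[float]]  # stand-in for torch.Tensor


@dataclass
class ToyRetrieverOutput:
    """Hypothetical mirror of RetrieverOutput's four optional fields."""

    query: Optional[Matrix] = None
    passage: Optional[Matrix] = None
    loss: Optional[float] = None
    logits: Optional[Matrix] = None


def describe(output: ToyRetrieverOutput) -> List[str]:
    # Report only the fields that were actually populated.
    return [
        name
        for name in ("query", "passage", "loss", "logits")
        if getattr(output, name) is not None
    ]


# An inference-style result that holds query embeddings only.
encoded = ToyRetrieverOutput(query=[[0.1, 0.9]])
populated = describe(encoded)
```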