BiEncoderRetriever
- class trove.modeling.retriever_biencoder.BiEncoderRetriever(*args, **kwargs)
- forward(query=None, passage=None, label=None, return_loss=True, **kwargs)
Encodes the query and passages and, when possible, calculates similarity scores and the loss.
- Parameters:
query (Optional[Dict[str, torch.Tensor]]) – Tokenized query.
passage (Optional[Dict[str, torch.Tensor]]) – Tokenized passages.
label (Optional[torch.Tensor]) – Relevancy scores of the corresponding passages for each query. If there are k passages for each query, this is a 2D tensor of shape [NUM_QUERIES, k].
return_loss (Optional[bool]) – If True, calculate the loss value.
**kwargs – Unused keyword arguments are passed to the forward() method of the loss module.
- Return type:
- Returns:
If possible, the query and passage embeddings, as well as the similarity scores and the loss.
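The sketch below mimics what forward() is documented to do, to make the tensor shapes concrete. It is not trove's implementation: ToyEncoder, biencoder_forward, the "query"/"passage"/"logits"/"loss" keys, the query-major ordering of passages (the k passages of query 0 first, then those of query 1, and so on), and the loss are all assumptions. The loss here is a softmax cross-entropy with in-batch negatives, one common choice for bi-encoders; trove's actual loss module may differ.

```python
import torch
import torch.nn.functional as F


class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in encoder: masked mean-pooling over token embeddings."""

    def __init__(self, vocab_size: int = 100, dim: int = 8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)

    def forward(self, input_ids, attention_mask):
        tok = self.emb(input_ids)                                # [batch, seq, dim]
        mask = attention_mask.unsqueeze(-1).to(tok.dtype)        # [batch, seq, 1]
        return (tok * mask).sum(1) / mask.sum(1).clamp(min=1.0)  # [batch, dim]


def biencoder_forward(encoder, query=None, passage=None, label=None, return_loss=True):
    """Sketch of the documented behavior (assumed), NOT trove's code."""
    out = {}
    if query is not None:
        out["query"] = encoder(**query)      # [NUM_QUERIES, dim]
    if passage is not None:
        out["passage"] = encoder(**passage)  # [NUM_QUERIES * k, dim]
    if "query" not in out or "passage" not in out:
        return out  # similarity needs both sides, so return embeddings only
    # Dot-product similarity of every query with every passage in the batch.
    logits = out["query"] @ out["passage"].T  # [NUM_QUERIES, NUM_QUERIES * k]
    out["logits"] = logits
    if return_loss and label is not None:
        n, k = label.shape
        # Put each query's k relevancy scores on its own passages (query-major
        # order assumed); other queries' passages act as in-batch negatives.
        target = torch.zeros_like(logits)
        for i in range(n):
            target[i, i * k:(i + 1) * k] = label[i]
        target = target / target.sum(1, keepdim=True)  # rows become distributions
        out["loss"] = -(target * F.log_softmax(logits, dim=1)).sum(1).mean()
    return out


torch.manual_seed(0)
enc = ToyEncoder()
n_queries, k, seq = 2, 3, 5
query = {"input_ids": torch.randint(0, 100, (n_queries, seq)),
         "attention_mask": torch.ones(n_queries, seq, dtype=torch.long)}
passage = {"input_ids": torch.randint(0, 100, (n_queries * k, seq)),
           "attention_mask": torch.ones(n_queries * k, seq, dtype=torch.long)}
# One relevant passage followed by two irrelevant ones per query: [NUM_QUERIES, k].
label = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

out = biencoder_forward(enc, query, passage, label)
print(out["logits"].shape)  # torch.Size([2, 6])
```

Passing only query (or only passage) returns just the embeddings, which matches the "if possible" wording of the return value: scores need both sides, and the loss additionally needs label and return_loss=True.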