BiEncoderRetriever

class trove.modeling.retriever_biencoder.BiEncoderRetriever(*args, **kwargs)
forward(query=None, passage=None, label=None, return_loss=True, **kwargs)

Encodes queries and passages and, when the provided inputs allow it, calculates similarity scores and the loss.

Parameters:
  • query (Optional[Dict[str, torch.Tensor]]) – Tokenized queries.

  • passage (Optional[Dict[str, torch.Tensor]]) – Tokenized passages.

  • label (Optional[torch.Tensor]) – Relevance scores of the corresponding passages for each query. If there are k passages for each query, this is a 2D tensor of shape [NUM_QUERIES, k].

  • return_loss (Optional[bool]) – If True, calculate the loss value.

  • **kwargs – Unused keyword arguments are passed to the forward() method of the loss module.
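The shape contract for label can be illustrated with a small pure-Python sketch that uses no trove APIs. The helper label_width and the example data are hypothetical. The sketch also assumes that passages are flattened query-major (passage j of query i sits at index i * k + j in the passage batch); that ordering is not stated above, so treat it as an assumption.

```python
from typing import Sequence


def label_width(num_queries: int, num_passages: int, label: Sequence[Sequence[float]]) -> int:
    """Hypothetical helper: return k after checking that label has shape
    [NUM_QUERIES, k] and that the flattened passage count equals NUM_QUERIES * k."""
    if len(label) != num_queries:
        raise ValueError(f"expected {num_queries} label rows, got {len(label)}")
    k = len(label[0]) if label else 0
    if any(len(row) != k for row in label):
        raise ValueError("every label row must have the same length k")
    if num_passages != num_queries * k:
        raise ValueError(f"expected {num_queries * k} passages, got {num_passages}")
    return k


# Hypothetical example data: 2 queries, k = 3 passages each.
queries = ["what is dense retrieval?", "who wrote Hamlet?"]
passages_per_query = [
    ["dense retrieval uses embeddings", "BM25 is a sparse method", "cats are mammals"],
    ["Shakespeare wrote Hamlet", "Hamlet is a tragedy", "Paris is in France"],
]

# Assumption: passages are flattened query-major before tokenization,
# so passage j of query i ends up at index i * k + j.
flat_passages = [p for group in passages_per_query for p in group]

# Relevance scores: one row per query, one column per passage (shape [2, 3]).
label = [
    [1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
]

k = label_width(len(queries), len(flat_passages), label)
```

With these inputs k is 3, and the first passage of the second query is found at index 1 * 3 + 0 of the flattened list.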

Return type:

RetrieverOutput

Returns:

Query and passage embeddings, similarity scores, and the loss, each included only if it can be computed from the provided inputs.
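The conditional behavior described above (encode what is given, score when both sides are present, compute a loss only when label is given and return_loss is True) can be sketched in plain Python. This is not trove's implementation. SketchOutput is a hypothetical stand-in for RetrieverOutput with illustrative field names, the embeddings are passed in precomputed instead of produced by an encoder, similarity is a plain dot product, and the listwise softmax cross-entropy is an example loss chosen for illustration, not necessarily the loss module trove uses. The query-major passage layout is the same assumption as before.

```python
import math
from dataclasses import dataclass
from typing import List, Optional

Vector = List[float]


@dataclass
class SketchOutput:
    # Hypothetical stand-in for RetrieverOutput; field names are illustrative only.
    query: Optional[List[Vector]] = None
    passage: Optional[List[Vector]] = None
    scores: Optional[List[List[float]]] = None
    loss: Optional[float] = None


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def listwise_softmax_ce(scores: List[List[float]], label: List[List[float]]) -> float:
    """Example loss only: cross-entropy between softmax(scores) and the
    per-query label distribution (each label row normalized to sum to 1)."""
    total = 0.0
    for row_scores, row_label in zip(scores, label):
        z = sum(row_label)
        if z <= 0:
            raise ValueError("each label row needs at least one positive score")
        m = max(row_scores)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(s - m) for s in row_scores))
        total -= sum((l / z) * (s - log_norm) for s, l in zip(row_scores, row_label))
    return total / len(scores)


def forward_sketch(query_emb=None, passage_emb=None, label=None, return_loss=True) -> SketchOutput:
    out = SketchOutput(query=query_emb, passage=passage_emb)
    if query_emb is None or passage_emb is None:
        return out  # only embeddings are available, no scores or loss
    if len(passage_emb) % len(query_emb) != 0:
        raise ValueError("expected the same number k of passages for every query")
    k = len(passage_emb) // len(query_emb)
    # Assumption: passage j of query i is at index i * k + j (query-major order).
    out.scores = [
        [dot(q, passage_emb[i * k + j]) for j in range(k)]
        for i, q in enumerate(query_emb)
    ]
    if return_loss and label is not None:
        out.loss = listwise_softmax_ce(out.scores, label)
    return out


# Usage: 2 queries, k = 2 passages each, first passage relevant for both queries.
q = [[1.0, 0.0], [0.0, 1.0]]
p = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
out = forward_sketch(q, p, label=[[1.0, 0.0], [1.0, 0.0]])
```

Here out.scores is [[1.0, 0.0], [1.0, 0.0]] and out.loss equals log(1 + e^-1). Calling forward_sketch with only query embeddings returns them with scores and loss left as None, and passing return_loss=False skips the loss even when label is given.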