InfoNCELoss

class trove.modeling.losses.InfoNCELoss(args, **kwargs)
__init__(args, **kwargs)

Implements InfoNCE loss.
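For reference, the objective can be written in terms of the logits as below. This is a sketch under stated assumptions, not taken from trove's source. It assumes the loss is the negative log-softmax of the positive passage's score, averaged over the N queries in the batch. Here s_{ij} denotes logits[i, j], and K is the number of passages per query. Following the layout described under forward, the positive for query i sits at column iK. Any temperature scaling, if used, is assumed to be already applied to the logits.

```latex
\mathcal{L} = -\frac{1}{N} \sum_{i=0}^{N-1}
  \log \frac{\exp\left(s_{i,\,iK}\right)}{\sum_{j=0}^{NK-1} \exp\left(s_{ij}\right)}
```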

forward(logits, **kwargs)

Calculates the loss given the similarity scores between queries and passages.

The logits argument contains the similarity scores between queries and passages for the entire batch. Each query in the batch has K corresponding passages: 1 positive and K-1 negatives. For each query, the positive passage comes first, followed by the negative passages, e.g., passages_for_q1 = [pos, neg, neg, neg, neg, ..., neg]
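For a single query, the ordering above means the loss is the negative log-softmax of the score at index 0. The following is a minimal pure-Python sketch of that per-query computation. It is illustrative only and not trove's implementation, and the function name infonce_single_query is hypothetical.

```python
import math


def infonce_single_query(scores: list[float]) -> float:
    """InfoNCE loss for one query whose K scores are ordered [pos, neg, ..., neg]."""
    # Numerically stable log-sum-exp over all K passage scores for this query
    m = max(scores)
    log_denominator = m + math.log(sum(math.exp(s - m) for s in scores))
    # The positive passage is at index 0, so the loss is -log(softmax(scores)[0])
    return log_denominator - scores[0]


# K = 4: one positive followed by three negatives
print(infonce_single_query([2.0, 0.5, -1.0, 0.0]))
```

Subtracting the maximum score before exponentiating keeps the sum from overflowing when scores are large. Adding the maximum back afterwards does not change the result.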

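To use in-batch negatives, the passages for all queries are concatenated into a single list, i.e., all_passages = passages_for_q1 + passages_for_q2 + ... + passages_for_qn. Each query is therefore scored against every passage in the batch, and the passages of other queries act as additional negatives. As a result, the index of the positive passage for query i (counting queries from 0) is i * K, where K is the number of passages per query.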
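The in-batch setup can be sketched in PyTorch as follows. This is a minimal sketch that assumes the loss is the row-wise cross-entropy between the logits and the target indices i * K, averaged over queries. It is not taken from trove's source, and the function name infonce_in_batch is hypothetical.

```python
import torch
import torch.nn.functional as F


def infonce_in_batch(logits: torch.Tensor) -> torch.Tensor:
    """InfoNCE over [NUM_QUERIES, NUM_QUERIES * K] logits with in-batch negatives."""
    num_queries, num_passages = logits.shape
    if num_passages % num_queries != 0:
        raise ValueError("NUM_PASSAGES must be NUM_QUERIES * K")
    k = num_passages // num_queries  # passages per query
    # The positive passage for query i sits at column i * K of the concatenated list
    targets = torch.arange(num_queries, device=logits.device) * k
    # Row-wise cross-entropy = -log softmax(row)[target], averaged over queries;
    # every other column in the row (including other queries' passages) is a negative
    return F.cross_entropy(logits, targets)


logits = torch.randn(3, 3 * 4)  # 3 queries, K = 4 passages each
loss = infonce_in_batch(logits)
```

Passing integer class indices to F.cross_entropy avoids building a one-hot label matrix. It also computes the log-softmax in a numerically stable way.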

Parameters:
  • logits (torch.Tensor) – The similarity scores between queries and passages. shape: [NUM_QUERIES, NUM_PASSAGES], where NUM_PASSAGES = NUM_QUERIES * K.

  • **kwargs – Not used. Just to make the signature compatible with other losses.

Return type:

Tensor

Returns:

the InfoNCE loss value.