InfoNCELoss
- class trove.modeling.losses.InfoNCELoss(args, **kwargs)
- __init__(args, **kwargs)
Implements InfoNCE loss.
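Written out under the batch layout that forward describes, the loss can be sketched as follows. This is the standard InfoNCE form (softmax cross-entropy over every passage in the batch), not a transcription of trove's source. The notation is introduced here: N = NUM_QUERIES, K = passages per query, s_{i,j} = logits[i, j], and the positive for query i (counted from 0) sits at column iK. The formula assumes any temperature scaling is already folded into the logits and assumes the per-query losses are averaged.

```latex
\mathcal{L} = -\frac{1}{N} \sum_{i=0}^{N-1}
  \log \frac{\exp\left(s_{i,\,iK}\right)}{\sum_{j=0}^{NK-1} \exp\left(s_{i,j}\right)}
```

Every column other than iK in row i acts as a negative, including the positives and hard negatives that belong to other queries.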
- forward(logits, **kwargs)
Calculates the loss given the similarity scores between queries and passages.
The logits argument contains the similarity scores between queries and passages for the entire batch. Each query in the batch has K corresponding passages: one positive and K-1 negatives. For each query, the positive passage comes before the negative passages, e.g.,

passages_for_q1 = [pos, neg, neg, neg, neg, ..., neg]

To use in-batch negatives, the passages for all queries are concatenated into a single list, i.e.,

all_passages = passages_for_q1 + passages_for_q2 + ... + passages_for_qn

As a result, the index of the positive passage for query i (counting queries from 0) is i * K, where K is the number of passages per query.
- Parameters:
logits (torch.Tensor) – The similarity scores between queries and passages. Shape: [NUM_QUERIES, NUM_PASSAGES], where NUM_PASSAGES = NUM_QUERIES * K.
**kwargs – Not used; accepted only so the signature is compatible with other losses.
- Return type:
Tensor
- Returns:
The InfoNCE loss value.
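The indexing rule above (positive for query i at column i * K) can be sketched in plain PyTorch. This is a hypothetical re-implementation for illustration only: the function name infonce_loss_sketch is ours and is not part of trove. The sketch assumes the standard InfoNCE form, no temperature scaling, and mean reduction over queries. It shows how K and the target indices follow directly from the logits shape.

```python
import torch
import torch.nn.functional as F


def infonce_loss_sketch(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the documented InfoNCE computation (not trove's code).

    logits: [NUM_QUERIES, NUM_PASSAGES], with NUM_PASSAGES = NUM_QUERIES * K and
    the positive passage for query i (counted from 0) at column i * K.
    """
    num_queries, num_passages = logits.shape
    if num_passages % num_queries != 0:
        raise ValueError("NUM_PASSAGES must be a multiple of NUM_QUERIES")
    k = num_passages // num_queries  # passages per query (1 positive + K-1 negatives)
    # Target column for each query: 0, K, 2K, ..., (N-1)K
    targets = torch.arange(num_queries, device=logits.device) * k
    # Softmax cross-entropy over every passage in the row, averaged over queries
    # (mean reduction and the absence of a temperature are assumptions of this sketch).
    return F.cross_entropy(logits, targets)


if __name__ == "__main__":
    torch.manual_seed(0)
    num_queries, k, dim = 3, 4, 8
    query_emb = torch.randn(num_queries, dim)
    # Rows ordered as all_passages = passages_for_q1 + passages_for_q2 + passages_for_q3,
    # each group of K rows starting with that query's positive passage.
    passage_emb = torch.randn(num_queries * k, dim)
    logits = query_emb @ passage_emb.T  # shape [3, 12]
    print(infonce_loss_sketch(logits))
```

Passing class indices to cross-entropy makes every other column in a row a negative for that query. This is exactly what concatenating all_passages into one list achieves: the other queries' passages become extra in-batch negatives at no additional encoding cost.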