RetrievalLoss

class trove.modeling.loss_base.RetrievalLoss(*args, **kwargs)

Base class for loss functions that can automatically detect the correct subclass to instantiate.

To add a new loss function, create a new class that inherits from RetrievalLoss. Users can then instantiate the new loss function by its name. For example, if you define:

class MyLoss(RetrievalLoss):
    _alias = 'foo_loss'
    ...

Then users can select MyLoss by setting either ModelArguments.loss == "MyLoss" or ModelArguments.loss == "foo_loss".
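The alias mechanism described above can be sketched with a minimal registry built on __init_subclass__. This is a hypothetical illustration of the pattern, not trove's actual implementation; the registry structure and names here are assumptions.

```python
# Hypothetical sketch of the alias-registry pattern; trove's actual
# implementation may differ in details.
class RetrievalLoss:
    _loss_registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register every subclass under its class name, plus its
        # optional '_alias' if one is defined.
        RetrievalLoss._loss_registry[cls.__name__] = cls
        alias = getattr(cls, "_alias", None)
        if alias is not None:
            RetrievalLoss._loss_registry[alias] = cls


class MyLoss(RetrievalLoss):
    _alias = "foo_loss"


# Both the class name and the alias resolve to the same subclass.
assert RetrievalLoss._loss_registry["MyLoss"] is MyLoss
assert RetrievalLoss._loss_registry["foo_loss"] is MyLoss
```

Because registration happens at class-definition time, simply importing a module that defines a subclass is enough to make it discoverable by name.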

Do not overwrite or modify the cls._loss_registry class attribute.

classmethod available_losses()

Prints a list of all available loss functions and their aliases.

Return type:

None

classmethod from_model_args(args, **kwargs)

Instantiate the correct subclass of RetrievalLoss based on args.loss.

Parameters:
  • args (ModelArguments) – args.loss is used to detect the correct loss function subclass. args is also passed to the constructor of the target subclass.

  • **kwargs – is passed to the constructor of the target subclass.

Returns:

An instance of the specified loss subclass, or None if args.loss is None.

forward(logits, label, **kwargs)

Calculates the loss given the similarity scores between queries and passages.

The logits argument contains the similarity scores between queries and passages (including in-batch negatives) for the entire batch, with shape [NUM_QUERIES, NUM_QUERIES * DOCS_PER_QUERY]. The documents are organized sequentially: the documents related to the i-th query are docs[i * DOCS_PER_QUERY: (i + 1) * DOCS_PER_QUERY], and the rest are in-batch negatives.
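The sequential layout above can be made concrete with a small helper. This is an illustration of the indexing scheme only; the function name and sizes below are hypothetical, not part of trove.

```python
# Hypothetical helper illustrating the column layout of `logits`:
# for query i, columns i*D .. (i+1)*D - 1 score its own documents,
# and every other column is an in-batch negative.
DOCS_PER_QUERY = 2

def own_doc_columns(i, docs_per_query=DOCS_PER_QUERY):
    return list(range(i * docs_per_query, (i + 1) * docs_per_query))

# With D=2: query 0 owns columns [0, 1], query 2 owns columns [4, 5].
assert own_doc_columns(0) == [0, 1]
assert own_doc_columns(2) == [4, 5]
```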

The label argument contains the groundtruth relevancy levels between queries and documents for the entire batch. It does NOT include in-batch negatives and has shape [NUM_QUERIES, DOCS_PER_QUERY]. In-batch negatives are assigned a relevancy level of 0, and the label tensor must be extended accordingly.

You can use the following snippet to make sure logits and label have the same shape, assigning a label of zero (0) to in-batch negatives.

if label.size(1) != logits.size(1):
    label = torch.block_diag(*torch.chunk(label, label.shape[0]))

Parameters:
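To see what the block-diagonal expansion above produces, here is a pure-Python sketch of the same transformation (using nested lists instead of tensors, so it is self-contained); the helper name is hypothetical, not trove's API.

```python
# Pure-Python illustration of the torch.block_diag expansion above:
# extend a [Q, D] label matrix to [Q, Q*D], with 0 for in-batch negatives.
def expand_labels(label):
    q = len(label)       # NUM_QUERIES
    d = len(label[0])    # DOCS_PER_QUERY
    out = []
    for i, row in enumerate(label):
        # Zeros everywhere except the block of query i's own documents.
        out.append([0] * (i * d) + list(row) + [0] * ((q - 1 - i) * d))
    return out

label = [[2, 1],   # query 0: relevancy levels for its own 2 docs
         [1, 1]]   # query 1
assert expand_labels(label) == [[2, 1, 0, 0],
                                [0, 0, 1, 1]]
```

Each query's original labels land in its own column block, matching the sequential document layout of logits, and every in-batch negative gets relevancy 0.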
  • logits (torch.Tensor) – The similarity scores between queries and passages (including in-batch negatives). shape: [NUM_QUERIES, NUM_QUERIES * DOCS_PER_QUERY]

  • label (torch.Tensor) – The groundtruth relevancy level between queries and passages (excluding in-batch negatives). shape: [NUM_QUERIES, DOCS_PER_QUERY].

  • **kwargs – Not used. Just to make the signature compatible with other losses.

Return type:

Tensor

Returns:

The loss value.