RetrievalLoss
- class trove.modeling.loss_base.RetrievalLoss(*args, **kwargs)
Base class for loss functions that can automatically detect the correct subclass to instantiate.
To add a new loss function, create a new class that inherits from RetrievalLoss. Users can then instantiate the new loss function by its class name or its alias. For example, if you define:

    class MyLoss(RetrievalLoss):
        _alias = 'foo_loss'
        ...

then users can use MyLoss if ModelArguments.loss == "MyLoss" or ModelArguments.loss == "foo_loss".

Do not overwrite or modify the cls._loss_registry class attribute.
- classmethod available_losses()
Prints a list of all available loss functions and their aliases.
- Return type:
None
- classmethod from_model_args(args, **kwargs)
Instantiate the correct subclass of RetrievalLoss based on args.loss.
- Parameters:
args (ModelArguments) – args.loss is used to detect the correct loss function subclass. args is also passed to the constructor of the target subclass.
**kwargs – Passed to the constructor of the target subclass.
- Returns:
An instance of the specified loss subclass, or None if args.loss is None.
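The name/alias dispatch described above can be illustrated with a minimal, self-contained sketch of the registry pattern. This is a hypothetical reimplementation for illustration only (class and method names here are made up); trove's actual internals may differ:

```python
# Illustrative sketch of a subclass registry keyed by class name and alias.
# Hypothetical reimplementation; NOT trove's actual code.
class RetrievalLossSketch:
    _loss_registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register each subclass under its class name and, if present, its alias.
        RetrievalLossSketch._loss_registry[cls.__name__] = cls
        alias = getattr(cls, "_alias", None)
        if alias is not None:
            RetrievalLossSketch._loss_registry[alias] = cls

    @classmethod
    def from_name(cls, name):
        # Mirrors the from_model_args behavior: None in, None out.
        if name is None:
            return None
        return cls._loss_registry[name]()


class MyLoss(RetrievalLossSketch):
    _alias = "foo_loss"


# Both the class name and the alias resolve to the same subclass:
assert isinstance(RetrievalLossSketch.from_name("MyLoss"), MyLoss)
assert isinstance(RetrievalLossSketch.from_name("foo_loss"), MyLoss)
assert RetrievalLossSketch.from_name(None) is None
```

Registering in `__init_subclass__` is one common way to get this behavior automatically at class-definition time, which is why subclasses must not touch the registry attribute directly.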
- forward(logits, label, **kwargs)
Calculates the loss given the similarity scores between queries and passages.

The logits argument contains the similarity scores between queries and passages (including in-batch negatives) for the entire batch. It has shape [NUM_QUERIES, NUM_QUERIES * DOCS_PER_QUERY]. The documents are organized sequentially: the related docs for the i-th query are docs[i * DOCS_PER_QUERY: (i + 1) * DOCS_PER_QUERY], and the rest are in-batch negatives.

The label argument contains the ground-truth relevancy level between queries and documents for the entire batch. It does NOT include in-batch negatives and has shape [NUM_QUERIES, DOCS_PER_QUERY]. We assign in-batch negatives a relevancy level of 0 and extend the label argument accordingly.

You can use the following snippet to make sure logits and label are of the same shape, giving a label of zero (0) to in-batch negatives:

    if label.size(1) != logits.size(1):
        label = torch.block_diag(*torch.chunk(label, label.shape[0]))
- Parameters:
logits (torch.Tensor) – The similarity scores between queries and passages (including in-batch negatives). Shape: [NUM_QUERIES, NUM_QUERIES * DOCS_PER_QUERY].
label (torch.Tensor) – The ground-truth relevancy level between queries and passages (excluding in-batch negatives). Shape: [NUM_QUERIES, DOCS_PER_QUERY].
**kwargs – Not used. Just to make the signature compatible with other losses.
- Return type:
Tensor
- Returns:
The loss value.
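To make the label-extension step concrete, here is a small, self-contained PyTorch sketch (the tensor values are made up for illustration): row i of the extended label keeps its DOCS_PER_QUERY labels in columns i * DOCS_PER_QUERY to (i + 1) * DOCS_PER_QUERY, and every in-batch negative gets a relevancy level of 0.

```python
import torch

# A toy batch: NUM_QUERIES = 2 queries, DOCS_PER_QUERY = 2 labeled docs each.
# logits covers all NUM_QUERIES * DOCS_PER_QUERY = 4 docs per query.
logits = torch.randn(2, 4)
label = torch.tensor([[1, 0],
                      [2, 1]])  # shape [NUM_QUERIES, DOCS_PER_QUERY]

# Extend label to match logits: chunk it into one row per query, then lay
# the rows out block-diagonally so in-batch negatives receive label 0.
if label.size(1) != logits.size(1):
    label = torch.block_diag(*torch.chunk(label, label.shape[0]))

print(label.tolist())
# [[1, 0, 0, 0],
#  [0, 0, 2, 1]]
```

After the extension, label has the same shape as logits, so element-wise comparison of scores and relevancy levels is straightforward inside forward().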