Modeling

Our goal is to be compatible with the existing huggingface transformers ecosystem (e.g., PEFT and distributed training) and to maintain this compatibility in the future with minimal changes to Trove. Trove's goal is not to support and cover everything right out of the box. Instead, we want to keep the code simple and flexible so users can easily adapt it to their use case.

To achieve this, Trove models rely on an encoder object that encapsulates the most dynamic aspects of modeling, such as supporting different PEFT techniques or implementing new retrieval methods (e.g., retrieval using task instructions). We provide an optional abstraction and some helper functions to help with creating the encoder object. For maximum flexibility, Trove also accepts any arbitrary encoder object provided by the user, with minimal limitations to remain compatible with huggingface transformers.

Trove Models

Retriever variants (e.g., BiEncoderRetriever) are the main classes in Trove, i.e., the retriever is what we use as the model in training/inference scripts. Each retriever has an encoder attribute that is responsible for everything related to the backbone transformers model (e.g., Contriever). For example, the encoder object should save/load checkpoints and provide the logic for calculating the embedding vectors (e.g., pooling, normalization).

Trove provides three options for using a transformers model as the encoder.

Arbitrary torch Module

We can create a retriever with an instance of torch.nn.Module as the encoder, as long as it provides certain methods (see below).

my_custom_encoder: torch.nn.Module = ...
args = trove.ModelArguments(loss='infonce')
model = trove.BiEncoderRetriever(encoder=my_custom_encoder, model_args=args)

Warning

For training, you must make sure your custom encoder is compatible with huggingface transformers.Trainer.

Trove expects each encoder to provide several methods (see the sketch after this list):

  • encode_query(inputs) -> torch.Tensor: function that takes batched query tokens as inputs and returns the embedding vectors.

  • encode_passage(inputs) -> torch.Tensor: function that takes batched passage tokens as inputs and returns the embedding vectors.

  • save_pretrained(): the encoder must provide this method if we need to save checkpoints. The signature is the same as that of huggingface transformers models.

  • similarity_fn(query: torch.Tensor, passage: torch.Tensor) -> torch.Tensor: optional. The default is the dot product between query and passage embeddings.
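
For illustration, below is a minimal sketch of such a custom encoder that wraps a huggingface AutoModel with mean pooling. The class name, the backbone checkpoint, and the pooling strategy are assumptions made for this example and are not part of Trove; only the method names (encode_query, encode_passage, similarity_fn, save_pretrained) follow the expectations listed above.

import torch
import torch.nn as nn
from transformers import AutoModel

class MyMeanPoolingEncoder(nn.Module):
    """Hypothetical encoder: mean pooling over non-padding tokens."""

    def __init__(self, model_name_or_path: str = "facebook/contriever"):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name_or_path)

    def _encode(self, inputs: dict) -> torch.Tensor:
        outputs = self.backbone(**inputs)
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        # average the last hidden states over non-padding tokens
        return (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)

    def encode_query(self, inputs: dict) -> torch.Tensor:
        return self._encode(inputs)

    def encode_passage(self, inputs: dict) -> torch.Tensor:
        return self._encode(inputs)

    def similarity_fn(self, query: torch.Tensor, passage: torch.Tensor) -> torch.Tensor:
        # optional; the dot product is also Trove's default
        return query @ passage.T

    def save_pretrained(self, save_directory: str, **kwargs):
        # delegate to the huggingface backbone so checkpoints stay standard
        self.backbone.save_pretrained(save_directory, **kwargs)

You can then pass an instance of this class as the encoder argument, as in the example above.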

Warning

Trove retrievers like BiEncoderRetriever provide other methods like format_query and format_passage and attributes like append_eos_token. These are valid only if your encoder provides methods and attributes with the same names, or if you pass them as arguments to the retriever's __init__ method. Otherwise, they are set to default values, which might not be valid for your specific encoder.

Loss Functions

If a loss function is already registered with Trove, you just need to specify its name in your model arguments. The retriever class automatically instantiates the corresponding loss class.

Tip

Use trove.RetrievalLoss.available_losses() to see the names of all available loss functions.

For example, to use infonce, you can do this:

args = trove.ModelArguments(loss='infonce', ...)
model = trove.BiEncoderRetriever.from_model_args(args=args)
# how you instantiate your encoder does not impact the loss function
# this instantiates the same loss class
encoder = MyEncoder(args)
model = trove.BiEncoderRetriever(encoder=encoder, model_args=args)

Tip

If a loss function supports or expects extra keyword arguments in its __init__ method, you can pass them via the loss_extra_kwargs argument of the retriever, e.g., trove.BiEncoderRetriever.from_model_args(args=args, loss_extra_kwargs={...}).

Attention

When using Trove’s builtin InfoNCE loss (InfoNCELoss), you must use an instance of BinaryDataset for training. InfoNCELoss ignores the given labels. Instead, it assumes the positive is the very first item in the list of passages for each query.

Custom Loss Functions

You can easily implement and register a new loss function with Trove. You just need to create a subclass of RetrievalLoss that implements your loss function.

Let’s go through an example that implements the KL divergence loss. Note that KL loss is already implemented in Trove and you can use it by setting model_args.loss="kl".

First, we inherit from RetrievalLoss and parse the arguments.

import torch
import torch.nn.functional as F
from trove import ModelArguments, RetrievalLoss

class MyCustomKLLoss(RetrievalLoss):
    _alias = "custom_kl"  # name used to select this loss in ModelArguments(loss=...)

    def __init__(self, args: ModelArguments, **kwargs) -> None:
        super().__init__()
        self.temperature = args.temperature

Next, we implement the forward method that calculates the loss value.

def forward(self, logits: torch.Tensor, label: torch.Tensor, **kwargs) -> torch.Tensor:
    if label.size(1) != logits.size(1):
        # assign label zero to in-batch negatives by block-diagonal expansion (see below)
        label = torch.block_diag(*torch.chunk(label, label.shape[0]))

    preds = F.log_softmax(logits / self.temperature, dim=1)
    targets = F.log_softmax(label.double(), dim=1)
    loss = F.kl_div(
        input=preds, target=targets, log_target=True, reduction="batchmean"
    )
    return loss

logits are the similarity scores between all queries and all documents. In a distributed environment with multiple processes, logits includes the similarity scores even for in-batch negatives. But label only has entries for labeled documents (e.g., positives and mined hard negatives), not for in-batch negatives. So the shapes of label and logits diverge. To make the sizes of label and logits match, we assign a label of zero (0) to all in-batch negatives and expand the label matrix with:

label = torch.block_diag(*torch.chunk(label, label.shape[0]))
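
As a concrete (hypothetical) illustration of this expansion, suppose a batch has 2 queries with 3 labeled passages each; after the expansion, each query also gets zero labels for the other query's passages:

import torch

label = torch.tensor([[2, 1, 0],
                      [3, 0, 1]])  # shape (2, 3): 2 queries, 3 labeled passages each
expanded = torch.block_diag(*torch.chunk(label, label.shape[0]))
print(expanded)
# tensor([[2, 1, 0, 0, 0, 0],
#         [0, 0, 0, 3, 0, 1]])
# shape (2, 6): zeros fill the positions of the other query's passages,
# i.e., the in-batch negatives, so the shape now matches logits.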

To use the new loss function, we just need to specify its name in model arguments:

model_args = ModelArguments(loss="custom_kl", ...)
# or
model_args = ModelArguments(loss="MyCustomKLLoss", ...)

Retrieval Logic

As you have seen so far, the retriever class is the main model class used in training and evaluation scripts. Trove implements the bi-encoder retrieval logic (BiEncoderRetriever), which encodes the query and the document separately and then calculates their similarity with a metric like the dot product.

Here is an example that shows how to use the retriever class. See PretrainedRetriever and BiEncoderRetriever documentation for more details.

model = trove.BiEncoderRetriever.from_model_args(...)
# embed queries
query_embs = model.encode_query(query_tokens)
# embed passages
passage_embs = model.encode_passage(passage_tokens)
# full forward pass
output = model(query=query_tokens, passage=passage_tokens, label=labels) # label is optional
print(output.query.shape) # query embeddings
print(output.passage.shape) # passage embeddings
print(output.logits.shape) # query-passage similarity scores
# if labels are given and the retriever is instantiated with a loss module
print(output.loss)

Custom Retrieval Logic

To implement a new retrieval logic, you need to create a subclass of PretrainedRetriever and implement the forward() method. See the PretrainedRetriever documentation for the signature of the forward() method. You can follow the BiEncoderRetriever code as an example.
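
As a rough illustration (not Trove's actual code), here is what a custom retriever might look like. The exact forward() signature, the name of the loss attribute, and the output container used below are assumptions made for this sketch; consult the PretrainedRetriever documentation and the BiEncoderRetriever source for the real interface.

from dataclasses import dataclass
from typing import Optional

import torch

import trove


@dataclass
class MyRetrieverOutput:
    # hypothetical output container for this sketch; Trove defines its own
    query: Optional[torch.Tensor] = None
    passage: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    loss: Optional[torch.Tensor] = None


class MyCustomRetriever(trove.PretrainedRetriever):
    def forward(self, query=None, passage=None, label=None, **kwargs):
        # encode each side with the encoder object, like BiEncoderRetriever does
        query_embs = self.encoder.encode_query(query) if query is not None else None
        passage_embs = self.encoder.encode_passage(passage) if passage is not None else None

        logits = None
        loss = None
        if query_embs is not None and passage_embs is not None:
            # put your custom retrieval logic here; this sketch simply reuses
            # the encoder's similarity function
            logits = self.encoder.similarity_fn(query_embs, passage_embs)
            # 'self.loss' is an assumed attribute name for the loss module
            # instantiated from model_args
            if label is not None and getattr(self, "loss", None) is not None:
                loss = self.loss(logits=logits, label=label)

        return MyRetrieverOutput(
            query=query_embs, passage=passage_embs, logits=logits, loss=loss
        )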