Training
Trove uses Hugging Face transformers for training. Use trove.RetrievalTrainer instead of transformers.Trainer; it makes small modifications so that checkpoints of PretrainedRetriever subclasses are saved correctly. It also scales up the loss value for aggregation in distributed environments (the effective loss value remains the same). Everything else is identical to transformers.Trainer.
Workflow
Here we explain Trove's training workflow. The example described below is roughly equivalent to this script.
Loading the Model
First, we create an instance of RetrievalTrainingArguments. This is identical to transformers.TrainingArguments and just adds one extra option (trove_logging_mode) to control Trove's logging mode. You still have access to all arguments of transformers.TrainingArguments, such as learning rate, save frequency, etc.
from trove import RetrievalTrainingArguments
train_args = RetrievalTrainingArguments(output_dir="./my_model", learning_rate=1e-5, ...)
Next, we create an instance of ModelArguments, which determines how the model should be loaded and used.
from trove import ModelArguments
model_args = ModelArguments(
model_name_or_path="facebook/contriever",
encoder_class="default",
pooling="mean",
normalize=False,
loss="infonce"
)
We need to specify which wrapper should be used to load the encoder. encoder_class="default" means that the model checkpoint (model_name_or_path) should be loaded with the PretrainedEncoder subclass whose name (or alias) is "default".
Tip
If the specified encoder wrapper supports it, we can use ModelArguments options such as use_peft and load_in_4bit to ask the wrapper to add LoRA adapters or quantize the encoder. The default wrapper supports both LoRA and quantization.
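For instance, assuming the default wrapper, a config along the following lines would enable both; use_peft and load_in_4bit are the options mentioned above, and the remaining fields simply repeat the earlier example.
# A sketch: same config as before, but asking the default wrapper to
# add LoRA adapters and load the base encoder in 4-bit precision.
peft_model_args = ModelArguments(
    model_name_or_path="facebook/contriever",
    encoder_class="default",
    pooling="mean",
    normalize=False,
    loss="infonce",
    use_peft=True,
    load_in_4bit=True,
)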
Similarly, loss="infonce" specifies the name (or alias) of the loss function that should be instantiated.
Tip
Use trove.RetrievalLoss.available_losses() to see the names of all available loss functions.
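For example:
from trove import RetrievalLoss

# print the names/aliases of all registered loss functions
print(RetrievalLoss.available_losses())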
See Modeling for how you can add custom encoder wrappers and loss functions.
Next, we create a bi-encoder retriever (BiEncoderRetriever) using this config. The retriever instantiates the encoder and loss function based on the given model_args.
from trove import BiEncoderRetriever
model = BiEncoderRetriever.from_model_args(args=model_args)
Creating the Training Dataset
First, we create an instance of DataArguments to specify how the data should be processed. dataset_name is helpful if your encoder expects the inputs to be processed differently based on the dataset (e.g., using different task instructions for each dataset). group_size is the number of documents used for each query. In this example, we will create a binary dataset, which means each query will have one positive and 15 negatives.
from trove import DataArguments
data_args = DataArguments(
dataset_name="msmarco",
group_size=16,
query_max_len=32,
passage_max_len=128
)
Next, we create two instances of MaterializedQRelConfig, one for positives and one for negatives, to specify where to find the data and how to load and process it.
from trove import MaterializedQRelConfig
pos_conf = MaterializedQRelConfig(
qrel_path="train_qrel.tsv",
corpus_path="corpus.jsonl",
query_path="queries.jsonl",
min_score=1,
)
neg_conf = MaterializedQRelConfig(
qrel_path="train_qrel.tsv",
corpus_path="corpus.jsonl",
query_path="queries.jsonl",
max_score=1,
)
Let's consider positives first. The above snippet says that queries should be loaded from the queries.jsonl file, documents from the corpus.jsonl file, and annotations from the train_qrel.tsv file. Importantly, we do not use all annotated documents as positives: we restrict positives to documents whose label is greater than or equal to one (min_score=1). Similarly for negatives, we load the queries, corpus, and annotations from the same set of files as the positives, but this time we restrict negatives to documents whose label is less than one (max_score=1), i.e., anything with label zero. Note that you can also pass a list of filenames instead of a single file, in which case the results are merged.
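For example, a sketch that merges annotations from several qrel files (the extra file names here are illustrative):
# Sketch: annotations from both qrel files are merged into one config.
merged_pos_conf = MaterializedQRelConfig(
    qrel_path=["train_qrel_part1.tsv", "train_qrel_part2.tsv"],
    corpus_path="corpus.jsonl",
    query_path="queries.jsonl",
    min_score=1,
)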
Tip
You can apply more complex data processing pipelines like filtering with arbitrary functions, transforming the scores, etc. See Data for more information.
Now, we create a binary training dataset (BinaryDataset) using these positive and negative documents.
from trove import BinaryDataset
dataset = BinaryDataset(
data_args=data_args,
positive_configs=pos_conf,
negative_configs=neg_conf,
format_query=model.format_query,
format_passage=model.format_passage,
)
The format_query and format_passage methods of the model take a raw query or passage and apply whatever processing the encoder expects. For example, for models like E5-mistral, these functions are expected to add a task instruction to the query.
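As an illustration only (not Trove's actual implementation), an E5-mistral-style query formatter does something roughly like the following; the instruction text is just an example.
# Hypothetical sketch of what an encoder's format_query may do for
# instruction-tuned models such as E5-mistral. The real logic lives in
# the encoder wrapper, and its exact signature may differ.
def format_query_example(query: str) -> str:
    instruction = "Given a web search query, retrieve relevant passages"
    return f"Instruct: {instruction}\nQuery: {query}"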
Tip
You can pass a list of config objects for each of the positive_configs and negative_configs arguments to create more complex data pipelines. This allows you to combine your positives and negatives each from multiple sources, and even process each data source differently before merging, as sketched below.
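For example, here is a rough sketch that draws negatives from two sources; the mined_negatives.tsv file is hypothetical.
# Hypothetical sketch: negatives come from both the annotated qrels and a
# separate file of mined negatives, each loaded with its own config.
neg_confs = [
    MaterializedQRelConfig(
        qrel_path="train_qrel.tsv",
        corpus_path="corpus.jsonl",
        query_path="queries.jsonl",
        max_score=1,
    ),
    MaterializedQRelConfig(
        qrel_path="mined_negatives.tsv",
        corpus_path="corpus.jsonl",
        query_path="queries.jsonl",
    ),
]
multi_source_dataset = BinaryDataset(
    data_args=data_args,
    positive_configs=pos_conf,
    negative_configs=neg_confs,
    format_query=model.format_query,
    format_passage=model.format_passage,
)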
Tip
BinaryDataset is suitable for training with binary relevance labels using InfoNCE loss. If you want to train with multiple levels of relevance (e.g., labels from {0, 1, 2, 3}), use MultiLevelDataset instead of BinaryDataset. The process is very similar: you just need one set of MaterializedQRelConfig objects instead of the two sets (positives and negatives) expected by the binary dataset, as sketched below.
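A rough sketch under these assumptions (a qrel file that carries graded labels, and the qrel_config argument also used in the evaluation example later on this page):
from trove import MultiLevelDataset

# Sketch of a multi-level training dataset: one set of qrel configs with
# graded labels instead of separate positive/negative configs.
graded_conf = MaterializedQRelConfig(
    qrel_path="train_qrel.tsv",  # assumed to contain graded labels
    corpus_path="corpus.jsonl",
    query_path="queries.jsonl",
)
multilevel_dataset = MultiLevelDataset(
    data_args=data_args,
    qrel_config=graded_conf,
    format_query=model.format_query,
    format_passage=model.format_passage,
)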
We also create a data collator (RetrievalCollator) that takes care of tokenization, truncation, padding, etc.
from transformers import AutoTokenizer
from trove import RetrievalCollator
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
# some tokenizers (e.g., for decoder-only models) have no pad token; fall back to eos
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"
data_collator = RetrievalCollator(
data_args=data_args,
tokenizer=tokenizer,
append_eos=model.append_eos_token,
)
If append_eos is set to True, the collator makes sure all input sequences end with an eos token. This is helpful when using last-token pooling. Ideally, encoders specify whether they expect an eos token, so we can use model.append_eos_token to correctly configure the data pipeline without any manual effort.
Trainer
Finally, we create an instance of RetrievalTrainer for training. RetrievalTrainer is almost identical to transformers.Trainer; its small changes do not impact how it is used. In your training script, everything works the same as with transformers.Trainer.
from trove import RetrievalTrainer
trainer = RetrievalTrainer(
args=train_args,
model=model,
tokenizer=tokenizer,
data_collator=data_collator,
train_dataset=dataset,
)
trainer.train()
trainer.save_model()
Distributed Training
Trove is fully compatible with the Hugging Face transformers ecosystem. You can simply launch your script with a distributed launcher; that is all you need for distributed training across multiple nodes and GPUs.
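For example, with the standard torchrun launcher on a single node with two GPUs (these are generic PyTorch launcher flags, nothing Trove-specific):
torchrun --nproc_per_node 2 my_script.py
# rest of your script arguments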
Similarly, you can use DeepSpeed just as you would with any transformers training script.
deepspeed --include localhost:0,1 my_script.py \
--deepspeed 'my_deepspeed_config.json'
# rest of your script arguments
Note
Trove retrievers automatically collect in-batch negatives from across devices and nodes.
IR Metrics During Training
During training, Trove can report IR metrics like nDCG on a dev set. Since calculating exact metrics multiple times over the entire corpus is too expensive, we approximate IR metrics on a small subset of annotated documents for each query, which is similar to a reranking task. For instance, given a dev set with a reasonable number of annotations (~100) per query, we can rank only these annotated documents (rather than the entire corpus) for each query and calculate IR metrics based on that ranking.
Most of the above code remains the same; only a few changes are needed. First, we update the training arguments.
train_args = RetrievalTrainingArguments(
...
batch_eval_metrics=True,
label_names=["label"],
eval_strategy="steps",
eval_steps=1000,
)
Next, you need to create an evaluation dataset.
eval_mqrel = MaterializedQRelConfig(
qrel_path="dev_qrel.tsv",
corpus_path="corpus.jsonl",
query_path="queries.jsonl",
)
Tip
Often such annotated dev sets are not available and we only have a few positives for each query. In these cases, we can mine a limited number of negatives (~100) for each query in the dev set. We then combine these mined negatives with annotated positives to create a dev set for approximating IR metrics during training.
eval_mqrel = [
MaterializedQRelConfig(
qrel_path="dev_qrel_positives.tsv",
corpus_path="corpus.jsonl",
query_path="queries.jsonl",
score_transform=1
),
MaterializedQRelConfig(
qrel_path="dev_mined_negs.tsv",
corpus_path="corpus.jsonl",
query_path="queries.jsonl",
score_transform=0
)
]
For evaluation, we must use MultiLevelDataset even if our labels are binary.
arg_overrides = {"group_size": 100, "passage_selection_strategy": "most_relevant"}
eval_dataset = MultiLevelDataset(
data_args=data_args,
format_query=model.format_query,
format_passage=model.format_passage,
qrel_config=eval_mqrel,
data_args_overrides=arg_overrides,
num_proc=8,
)
We reuse the same data_args object that we used for the training dataset, but we override some of its attributes with arg_overrides. To calculate approximate metrics, Trove expects all queries to have the same number of documents (for easier batching), so we set "group_size": 100 to make sure every query has 100 annotated documents. If a query has more than 100 annotated documents, MultiLevelDataset uses a subset of them; to make sure the positive documents are included in this subset, we set "passage_selection_strategy": "most_relevant". If a query has fewer than 100 annotated documents, MultiLevelDataset duplicates some documents.
We also need to create a stateful callback (IRMetrics) to compute the metrics for each batch of eval data.
from trove import IRMetrics
# k_values are cutoff values for IR metrics
metric_callback = IRMetrics(k_values=[10, 100])
Finally, we add a few extra arguments when instantiating and using the trainer.
trainer = RetrievalTrainer(
...
eval_dataset=eval_dataset,
compute_metrics=metric_callback,
)
trainer.train(ignore_keys_for_eval=["query", "passage"])