Training
==============
.. role:: raw-html(raw)
:format: html
.. py:currentmodule:: trove
.. |RetrievalTrainingArguments| replace:: :py:class:`~trove.trainer.RetrievalTrainingArguments`
.. |ModelArguments| replace:: :py:class:`~trove.modeling.model_args.ModelArguments`
.. |BiEncoderRetriever| replace:: :py:class:`~trove.modeling.retriever_biencoder.BiEncoderRetriever`
.. |DataArguments| replace:: :py:class:`~trove.data.data_args.DataArguments`
.. |MaterializedQRelConfig| replace:: :py:class:`~trove.containers.materialized_qrel_config.MaterializedQRelConfig`
.. |PretrainedEncoder| replace:: :py:class:`~trove.modeling.pretrained_encoder.PretrainedEncoder`
.. |BinaryDataset| replace:: :py:class:`~trove.data.ir_dataset_binary.BinaryDataset`
.. |MultiLevelDataset| replace:: :py:class:`~trove.data.ir_dataset_multilevel.MultiLevelDataset`
.. |RetrievalCollator| replace:: :py:class:`~trove.data.collator.RetrievalCollator`
.. |RetrievalTrainer| replace:: :py:class:`~trove.trainer.RetrievalTrainer`
.. |IRMetrics| replace:: :py:class:`~trove.evaluation.metrics.IRMetrics`
Trove uses Hugging Face transformers for training.
Use ``trove.RetrievalTrainer`` instead of ``transformers.Trainer``. It makes small modifications that allow saving checkpoints for ``PretrainedRetriever`` subclasses.
It also scales up the loss value for aggregation in distributed environments (the effective loss value remains the same).
Everything else remains the same as ``transformers.Trainer``.
Workflow
---------------------
This section walks through Trove's training workflow.
The example is roughly equivalent to `this script `_.
Loading the Model
~~~~~~~~~~~~~~~~~~~~~~
First, we create an instance of |RetrievalTrainingArguments|. It behaves exactly like ``transformers.TrainingArguments`` and adds one extra option (``trove_logging_mode``) to control the logging mode for Trove.
You have access to all arguments in `transformers.TrainingArguments `_ like learning rate, save frequency, etc.
.. code-block:: python

    from trove import RetrievalTrainingArguments

    train_args = RetrievalTrainingArguments(output_dir="./my_model", learning_rate=1e-5, ...)

Next, we create an instance of |ModelArguments| which determines how the model should be loaded and used.
.. code-block:: python

    from trove import ModelArguments

    model_args = ModelArguments(
        model_name_or_path="facebook/contriever",
        encoder_class="default",
        pooling="mean",
        normalize=False,
        loss="infonce",
    )

We need to specify which wrapper should be used to load the encoder.
``encoder_class="default"`` means that the model checkpoint (``model_name_or_path``) should be loaded with the subclass of |PretrainedEncoder| whose name (or alias) is ``"default"``.
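The name-based lookup can be pictured as a small registry. The sketch below is not Trove's implementation: ``EncoderRegistry``, ``register``, ``resolve``, and ``GenericEncoder`` are hypothetical names used only to illustrate how an alias such as ``"default"`` could map to a wrapper class.

```python
from typing import Callable, Dict, Type


class EncoderRegistry:
    """Hypothetical name -> class registry (illustration only, not Trove's API)."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, *names: str) -> Callable[[Type], Type]:
        # Decorator that records a class under one or more names/aliases.
        def decorator(wrapper_cls: Type) -> Type:
            for name in names:
                cls._registry[name] = wrapper_cls
            return wrapper_cls

        return decorator

    @classmethod
    def resolve(cls, name: str) -> Type:
        # Look up the class registered under the given name or alias.
        try:
            return cls._registry[name]
        except KeyError:
            raise ValueError(f"Unknown encoder class: {name!r}") from None


@EncoderRegistry.register("default", "generic")
class GenericEncoder:
    """Stand-in for an encoder wrapper."""


# Both the name and its alias resolve to the same class.
resolved = EncoderRegistry.resolve("default")
```

With a design like this, a configuration object only needs to carry a string, and new wrappers become available as soon as they are registered.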
.. tip::
If the specified encoder wrapper supports it, we can use |ModelArguments| to ask the wrapper to quantize the encoder or add LoRA adapters using options like ``use_peft`` and ``load_in_4bit``.
The default wrapper supports both LoRA and quantization.
Similarly, ``loss="infonce"`` specifies the name (or alias) of the loss function that should be instantiated.
.. tip::
Use ``trove.RetrievalLoss.available_losses()`` to see the names of all available loss functions.
See :doc:`Modeling ` for how you can add custom encoder wrappers and loss functions.
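For intuition about what ``loss="infonce"`` optimizes, the sketch below writes out the textbook InfoNCE objective for a single query in plain Python: a cross-entropy over the query's group of documents, with the positive as the target class. The function name, the ``positive_index`` convention, and the ``temperature`` parameter are assumptions for illustration; Trove's own loss class may differ in its details.

```python
import math
from typing import Sequence


def infonce_loss(scores: Sequence[float], positive_index: int = 0, temperature: float = 1.0) -> float:
    """Cross-entropy of the positive document against all documents in the group."""
    scaled = [s / temperature for s in scores]
    # Subtract the max before exponentiating for numerical stability.
    max_score = max(scaled)
    log_sum_exp = max_score + math.log(sum(math.exp(s - max_score) for s in scaled))
    return log_sum_exp - scaled[positive_index]


# One positive (index 0) and three negatives.
uniform = infonce_loss([0.0, 0.0, 0.0, 0.0])     # model cannot tell documents apart
confident = infonce_loss([10.0, 0.0, 0.0, 0.0])  # positive scored far above negatives
```

The loss is ``log(4)`` when all four documents receive the same score and approaches zero as the positive is scored further above the negatives.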
Next, we create a bi-encoder retriever (|BiEncoderRetriever|) using this config.
The retriever instantiates the encoder and loss function based on the given ``model_args``.
.. code-block:: python

    from trove import BiEncoderRetriever

    model = BiEncoderRetriever.from_model_args(args=model_args)

Creating Training Dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~
First, we create an instance of |DataArguments| to specify how the data should be processed.
The ``dataset_name`` is helpful if your encoder expects the inputs to be processed differently based on the dataset (e.g., use different task instructions for each dataset).
``group_size`` is the number of documents used for each query.
In this example, we will create a binary dataset, which means we will have one positive and 15 negatives for each query.
.. code-block:: python

    from trove import DataArguments

    data_args = DataArguments(
        dataset_name="msmarco",
        group_size=16,
        query_max_len=32,
        passage_max_len=128,
    )

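To make ``group_size=16`` concrete, the sketch below assembles one training group of one positive and 15 negatives. ``build_group`` is a hypothetical helper, and the sampling strategy (one random positive, negatives sampled without replacement) is an assumption for illustration rather than a description of what the dataset classes do internally.

```python
import random
from typing import List, Sequence


def build_group(positives: Sequence[str], negatives: Sequence[str], group_size: int, seed: int = 0) -> List[str]:
    """Return one positive followed by ``group_size - 1`` negatives (hypothetical helper)."""
    if len(negatives) < group_size - 1:
        raise ValueError("not enough negatives to fill the group")
    rng = random.Random(seed)
    positive = rng.choice(list(positives))
    sampled_negatives = rng.sample(list(negatives), group_size - 1)
    return [positive] + sampled_negatives


group = build_group(
    positives=["d_pos_1", "d_pos_2"],
    negatives=[f"d_neg_{i}" for i in range(50)],
    group_size=16,
)
```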
Next, we create two instances of |MaterializedQRelConfig| for negatives and positives to specify where to find the data and how to load and process it.
.. code-block:: python

    from trove import MaterializedQRelConfig

    pos_conf = MaterializedQRelConfig(
        qrel_path="train_qrel.tsv",
        corpus_path="corpus.jsonl",
        query_path="queries.jsonl",
        min_score=1,
    )
    neg_conf = MaterializedQRelConfig(
        qrel_path="train_qrel.tsv",
        corpus_path="corpus.jsonl",
        query_path="queries.jsonl",
        max_score=1,
    )

Consider the positives first.
The snippet above loads queries from ``queries.jsonl``, documents from ``corpus.jsonl``, and annotations from ``train_qrel.tsv``.
Importantly, not every annotated document is used as a positive: positives are restricted to documents whose label is greater than or equal to one (``min_score=1``).
The negatives use the same query, corpus, and annotation files.
This time, however, we keep only documents whose label is less than one (``max_score=1``), which is effectively anything with label zero.
Note that you can also pass a list of filenames instead of a single file; the results are merged.
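The effect of the two thresholds can be sketched in plain Python. The snippet follows the description above, where ``min_score`` is inclusive and ``max_score`` is exclusive. The in-memory tuples stand in for the rows of a qrel file, and ``filter_qrels`` is a hypothetical helper, not part of Trove.

```python
from typing import List, Optional, Tuple

QRel = Tuple[str, str, int]  # (query_id, doc_id, score)


def filter_qrels(qrels: List[QRel], min_score: Optional[int] = None, max_score: Optional[int] = None) -> List[QRel]:
    """Keep records with ``min_score <= score < max_score`` (unset bounds are ignored)."""
    kept = []
    for qid, docid, score in qrels:
        if min_score is not None and score < min_score:
            continue
        if max_score is not None and score >= max_score:
            continue
        kept.append((qid, docid, score))
    return kept


qrels = [("q1", "d1", 1), ("q1", "d2", 0), ("q1", "d3", 2), ("q2", "d4", 0)]
positives = filter_qrels(qrels, min_score=1)  # labels 1 and 2
negatives = filter_qrels(qrels, max_score=1)  # label 0 only
```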
.. tip::
You can apply more complex data processing pipelines like filtering with arbitrary functions, transforming the scores, etc.
See :doc:`data` for more information.
Now, we create a binary training dataset (|BinaryDataset|) using these negative and positive documents.
.. code-block:: python

    from trove import BinaryDataset

    dataset = BinaryDataset(
        data_args=data_args,
        positive_configs=pos_conf,
        negative_configs=neg_conf,
        format_query=model.format_query,
        format_passage=model.format_passage,
    )

The ``format_query`` and ``format_passage`` methods of the model take a raw query and passage and apply whatever processing the encoder expects.
For example, for models like E5-mistral, these functions are expected to add a task instruction to the query.
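As a rough illustration of such formatting functions, the sketch below prepends a task instruction to queries and leaves passages mostly untouched. The function signatures, the ``TASK_INSTRUCTIONS`` table, and the exact instruction template are assumptions modeled on instruction-tuned embedding models; check your encoder wrapper for the format it actually uses.

```python
from typing import Optional

# Hypothetical task instructions keyed by dataset name (illustration only).
TASK_INSTRUCTIONS = {
    "msmarco": "Given a web search query, retrieve relevant passages that answer the query",
}


def format_query(text: str, dataset: Optional[str] = None) -> str:
    """Prefix the query with a task instruction when one is known for the dataset."""
    instruction = TASK_INSTRUCTIONS.get(dataset or "")
    if instruction is None:
        return text
    return f"Instruct: {instruction}\nQuery: {text}"


def format_passage(text: str, title: Optional[str] = None, dataset: Optional[str] = None) -> str:
    """Passages get no instruction; a title, if present, is prepended."""
    return f"{title} {text}".strip() if title else text


formatted = format_query("what is dense retrieval", dataset="msmarco")
```

Because the instruction depends on the dataset, this is also where a value like ``dataset_name`` becomes useful.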
.. tip::
You can use a list of config objects for each of the ``positive_configs`` and ``negative_configs`` arguments to create more complex data pipelines.
This allows you to combine your positives and negatives each from multiple sources.
You can even process each data source differently before merging.
.. tip::
``BinaryDataset`` is suitable for training with binary relevance labels using InfoNCE loss.
If you want to train with multiple levels of relevance (e.g., labels from ``{0, 1, 2, 3}``), you need
to use |MultiLevelDataset| instead of |BinaryDataset|.
The process is very similar.
You just need one set of ``MaterializedQRelConfig`` objects instead of the two sets expected by the binary dataset for positives and negatives.
We also create a data collator (|RetrievalCollator|) that handles tokenization, truncation, padding, etc.
.. code-block:: python

    from transformers import AutoTokenizer
    from trove import RetrievalCollator

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"
    data_collator = RetrievalCollator(
        data_args=data_args,
        tokenizer=tokenizer,
        append_eos=model.append_eos_token,
    )

If ``append_eos`` is set to ``True``, the collator makes sure all input sequences end with an ``eos`` token.
This is helpful when using last-token pooling.
Ideally, encoders should specify whether they expect an ``eos`` token. We can then use ``model.append_eos_token`` to configure the data pipeline correctly without any manual effort.
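The interaction between ``append_eos``, right padding, and last-token pooling can be sketched on raw token ids. This is a simplified stand-in for what a collator does, not the |RetrievalCollator| implementation; the token ids and the helper names are hypothetical.

```python
from typing import List, Tuple

EOS_ID = 2  # hypothetical token id
PAD_ID = 2  # mirrors the snippet above, where the eos token doubles as the pad token


def append_eos_and_pad(batch: List[List[int]], max_len: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Truncate to leave room for eos, append eos, then right-pad to the longest sequence."""
    with_eos = [ids[: max_len - 1] + [EOS_ID] for ids in batch]
    longest = max(len(ids) for ids in with_eos)
    input_ids = [ids + [PAD_ID] * (longest - len(ids)) for ids in with_eos]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in with_eos]
    return input_ids, attention_mask


def last_token_index(mask: List[int]) -> int:
    """With right padding, the last real token sits at position sum(mask) - 1."""
    return sum(mask) - 1


input_ids, attention_mask = append_eos_and_pad([[11, 12, 13], [21, 22, 23, 24, 25, 26]], max_len=5)
pooled_tokens = [row[last_token_index(mask)] for row, mask in zip(input_ids, attention_mask)]
```

In this sketch, the token selected by last-token pooling is always the ``eos`` token, regardless of truncation or padding.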
Trainer
~~~~~~~~~~~~~~~~~~~~~~
Finally, we create an instance of |RetrievalTrainer| for training.
|RetrievalTrainer| is almost identical to ``transformers.Trainer``; its few small changes do not affect how it is used.
In your training scripts, use it exactly as you would use ``transformers.Trainer``.
.. code-block:: python

    from trove import RetrievalTrainer

    trainer = RetrievalTrainer(
        args=train_args,
        model=model,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()
    trainer.save_model()

Distributed Training
---------------------
Trove is fully compatible with the Hugging Face transformers ecosystem.
To train across multiple nodes and GPUs, launch your script with a distributed launcher; no other changes are needed.
Similarly, you can use DeepSpeed just as you would with any transformers training script.
.. code-block:: bash

    deepspeed --include localhost:0,1 my_script.py \
        --deepspeed 'my_deepspeed_config.json'
        # rest of your script arguments

.. note::
Trove retrievers automatically collect in-batch negatives from across devices and nodes.
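To illustrate what collecting in-batch negatives across devices means, the sketch below simulates it in plain Python without any distributed backend. It assumes each device holds one positive passage per query and that gathered passages are concatenated in rank order; the function names and this layout are assumptions for illustration, not Trove internals.

```python
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def all_gather(per_device: List[List[Vector]]) -> List[Vector]:
    """Simulate an all-gather: concatenate every device's passages in rank order."""
    return [vec for device in per_device for vec in device]


def scores_and_targets(
    rank: int, queries: List[Vector], per_device_passages: List[List[Vector]]
) -> Tuple[List[List[float]], List[int]]:
    """Score local queries against passages from all devices.

    The positive for local query ``i`` sits at global index ``rank * local_bs + i``;
    every other gathered passage acts as a negative.
    """
    all_passages = all_gather(per_device_passages)
    local_bs = len(queries)
    scores = [[dot(q, p) for p in all_passages] for q in queries]
    targets = [rank * local_bs + i for i in range(local_bs)]
    return scores, targets


# Two simulated devices with two query/passage pairs each (toy 2-d embeddings).
passages = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [-1.0, 0.0]]]
queries_rank1 = [[1.0, 1.0], [-1.0, 0.0]]
scores, targets = scores_and_targets(rank=1, queries=queries_rank1, per_device_passages=passages)
```

Each query on rank 1 is scored against four passages instead of two, so the passages held by rank 0 serve as additional negatives.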
IR Metrics During Training
-----------------------------
During training, Trove can report IR metrics like nDCG for a dev set.
Because computing exact metrics over the entire corpus multiple times is too expensive, Trove approximates IR metrics on a small subset of annotated documents for each query.
This is similar to a reranking task.
For instance, given a dev set that provides a reasonable number of annotations (~100) per query, we can rank *only* these annotated documents (and not the entire corpus) for each query and calculate IR metrics based on that.
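The approximation can be sketched for a single query: rank only the annotated documents by model score and compute nDCG over that ranking. The snippet uses a common linear-gain nDCG formulation in plain Python and is a standalone illustration, not the code behind |IRMetrics|.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """Discounted cumulative gain with linear gain and a log2 rank discount."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))


def ndcg_at_k(scores: Sequence[float], labels: Sequence[float], k: int) -> float:
    """Rank the annotated documents by model score and compare against the ideal order."""
    ranked = [label for _, label in sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)]
    ideal_dcg = dcg_at_k(sorted(labels, reverse=True), k)
    return dcg_at_k(ranked, k) / ideal_dcg if ideal_dcg > 0 else 0.0


# Four annotated documents for one query; two of them are relevant.
labels = [1, 0, 0, 1]
perfect = ndcg_at_k(scores=[0.9, 0.1, 0.2, 0.8], labels=labels, k=10)
imperfect = ndcg_at_k(scores=[0.9, 0.8, 0.2, 0.1], labels=labels, k=10)
```

The first scoring places both relevant documents on top and reaches an nDCG of 1, while the second ranks one relevant document last and scores lower.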
Most of the above code remains the same except for a few changes.
First we need to update the training arguments.
.. code-block:: python

    train_args = RetrievalTrainingArguments(
        ...
        batch_eval_metrics=True,
        label_names=["label"],
        eval_strategy="steps",
        eval_steps=1000,
    )

Next, we specify the data for the evaluation dataset.
.. code-block:: python

    eval_mqrel = MaterializedQRelConfig(
        qrel_path="dev_qrel.tsv",
        corpus_path="corpus.jsonl",
        query_path="queries.jsonl",
    )

.. tip::
Often such annotated dev sets are not available and we only have a few positives for each query.
In these cases, we can mine a limited number of negatives (~100) for each query in the dev set.
We then combine these mined negatives with annotated positives to create a dev set for approximating IR metrics during training.
.. code-block:: python

    eval_mqrel = [
        MaterializedQRelConfig(
            qrel_path="dev_qrel_positives.tsv",
            corpus_path="corpus.jsonl",
            query_path="queries.jsonl",
            score_transform=1,
        ),
        MaterializedQRelConfig(
            qrel_path="dev_mined_negs.tsv",
            corpus_path="corpus.jsonl",
            query_path="queries.jsonl",
            score_transform=0,
        ),
    ]

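Conceptually, the two configs above yield one labeled pool per query. The sketch below mimics that merge in plain Python, assuming a constant ``score_transform`` assigns that label to every record of its source. ``merge_with_constant_scores`` is a hypothetical helper, and letting the higher label win when a document appears in both sources is an assumption of this sketch.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

Pair = Tuple[str, str]  # (query_id, doc_id)


def merge_with_constant_scores(sources: List[Tuple[Iterable[Pair], int]]) -> Dict[str, Dict[str, int]]:
    """Merge (query_id, doc_id) pairs from several sources, labeling each source with a constant."""
    merged: Dict[str, Dict[str, int]] = defaultdict(dict)
    for pairs, constant_score in sources:
        for qid, docid in pairs:
            # Assumption: if a document shows up in several sources, keep the highest label.
            previous = merged[qid].get(docid, constant_score)
            merged[qid][docid] = max(previous, constant_score)
    return dict(merged)


dev_positives = [("q1", "d1"), ("q2", "d5")]
mined_negatives = [("q1", "d7"), ("q1", "d1"), ("q2", "d8")]
eval_labels = merge_with_constant_scores([(dev_positives, 1), (mined_negatives, 0)])
```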
For evaluations, we must use |MultiLevelDataset| even if our labels are binary.
.. code-block:: python

    arg_overrides = {"group_size": 100, "passage_selection_strategy": "most_relevant"}
    eval_dataset = MultiLevelDataset(
        data_args=data_args,
        format_query=model.format_query,
        format_passage=model.format_passage,
        qrel_config=eval_mqrel,
        data_args_overrides=arg_overrides,
        num_proc=8,
    )

We reuse the same ``data_args`` object that we used for the training dataset; but we override the value of some attributes with ``arg_overrides``.
To calculate approximate metrics, Trove expects all queries to have the same number of documents (for easier batching).
So, we set ``"group_size":100`` to make sure all queries have 100 annotated documents.
If a query has more than 100 documents, ``MultiLevelDataset`` uses a subset of annotated documents.
To make sure the positive documents are included in the subset, we set ``"passage_selection_strategy": "most_relevant"``.
If a query has fewer than 100 annotated documents, ``MultiLevelDataset`` duplicates some documents.
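The fixed-size selection described above can be sketched as follows. ``select_group`` is a hypothetical helper: it keeps the most relevant documents when there are too many and repeats documents in order when there are too few. The exact duplication scheme used by ``MultiLevelDataset`` is not specified here, so treat that part as an assumption.

```python
from itertools import cycle, islice
from typing import List, Tuple

Doc = Tuple[str, int]  # (doc_id, relevance label)


def select_group(docs: List[Doc], group_size: int) -> List[Doc]:
    """Return exactly ``group_size`` documents, most relevant first.

    Larger pools are truncated after sorting by label; smaller pools are
    padded by cycling through the sorted documents again.
    """
    if not docs:
        raise ValueError("cannot build a group from an empty list")
    ranked = sorted(docs, key=lambda doc: doc[1], reverse=True)
    return list(islice(cycle(ranked), group_size))


pool = [("d1", 0), ("d2", 1), ("d3", 0), ("d4", 2)]
truncated = select_group(pool, group_size=2)  # keeps the two most relevant documents
padded = select_group(pool, group_size=6)     # repeats documents to reach six
```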
We also need to create a stateful callback function (|IRMetrics|) to compute the metrics for each batch of eval data.
.. code-block:: python

    from trove import IRMetrics

    # k_values are cutoff values for IR metrics
    metric_callback = IRMetrics(k_values=[10, 100])

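The callback needs to be stateful because, with ``batch_eval_metrics=True``, transformers calls ``compute_metrics`` once per evaluation batch and passes ``compute_result=True`` only on the last one. The toy metric below illustrates that calling pattern. ``RunningSuccessAtK`` is hypothetical, and ``SimpleNamespace`` stands in for the ``EvalPrediction`` object the trainer would pass.

```python
from types import SimpleNamespace
from typing import Dict, Optional


class RunningSuccessAtK:
    """Toy stateful metric: fraction of queries with a relevant document in the top k."""

    def __init__(self, k: int) -> None:
        self.k = k
        self._hits = 0
        self._count = 0

    def __call__(self, eval_pred, compute_result: bool = False) -> Optional[Dict[str, float]]:
        # Accumulate statistics for the current batch only.
        for scores, labels in zip(eval_pred.predictions, eval_pred.label_ids):
            top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[: self.k]
            self._hits += int(any(labels[i] > 0 for i in top_k))
            self._count += 1
        if not compute_result:
            return None
        # Last batch: report the aggregate and reset for the next evaluation round.
        result = {f"success@{self.k}": self._hits / max(self._count, 1)}
        self._hits, self._count = 0, 0
        return result


metric = RunningSuccessAtK(k=1)
batch_1 = SimpleNamespace(predictions=[[0.9, 0.1], [0.2, 0.7]], label_ids=[[1, 0], [1, 0]])
batch_2 = SimpleNamespace(predictions=[[0.3, 0.6]], label_ids=[[0, 1]])
first = metric(batch_1)
final = metric(batch_2, compute_result=True)
```

Accumulating per batch avoids keeping the scores for the whole dev set in memory.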
Finally, we add a few extra arguments when instantiating and using the trainer.
.. code-block:: python

    trainer = RetrievalTrainer(
        ...
        eval_dataset=eval_dataset,
        compute_metrics=metric_callback,
    )
    trainer.train(ignore_keys_for_eval=["query", "passage"])
