RetrievalEvaluatorUtilsMixin

class trove.evaluation.evaluator_mixin_utils.RetrievalEvaluatorUtilsMixin(args=None, tracker_init_kwargs=None, tracker_extra_configs=None, tracker_callbacks=None)
__init__(args=None, tracker_init_kwargs=None, tracker_extra_configs=None, tracker_callbacks=None)

A simple Mixin for RetrievalEvaluator that provides convenient utilities.

Anything that is not part of the main algorithm or is not one of your design decisions should go here. For example:

  • dealing with distributed env attributes: is it main process, local process, etc.

  • providing progress bar in distributed environment

  • creating a temp file that is shared among processes

  • saving/loading files and checkpoints (what you save and where you save it should be decided in the main class; here you only define a function that takes those as arguments and makes sure the file is written correctly, e.g., only on the main process)

  • logging things to wandb, etc.

But even small details that impact your design should not be here. For example, if a temp file should be deleted only under specific circumstances, the check for whether to delete it belongs in the main class.

Parameters:
  • args (Optional[EvaluationArguments]) – Evaluation arguments. Same as what is passed to RetrievalEvaluator.

  • tracker_init_kwargs (Optional[Dict]) – extra kwargs for initializing experiment trackers. Directly passed to tracker init method with dict comprehension.

  • tracker_extra_configs (Optional[Union[List[Dict], Dict]]) – extra configs to log with experiment trackers. These configs are merged with the evaluator and model configs if present. You can either pass one config object or a list of config objects. Config objects must be instances of python dictionary.

  • tracker_callbacks (Optional[Any]) – One or multiple custom experiment tracker callbacks. Callbacks must have a .log() method for logging the evaluation metrics and a .setup() method that is called in the Evaluator constructor to initialize the tracker. In a distributed environment, these methods are only called from the global main process.
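A custom tracker callback only needs the two methods named above. The following is a minimal sketch, not the library's own tracker: the class name InMemoryTracker is hypothetical, and the exact signatures (setup() taking no arguments, log() receiving a dictionary of metrics) are assumptions, since the parameter description only states that the methods must exist.

```python
from typing import Any, Dict, List


class InMemoryTracker:
    """Hypothetical tracker callback that keeps metrics in a Python list."""

    def __init__(self) -> None:
        self.is_ready = False
        self.history: List[Dict[str, Any]] = []

    def setup(self) -> None:
        # Assumption: called once, with no arguments, when the evaluator is built.
        self.is_ready = True

    def log(self, metrics: Dict[str, Any]) -> None:
        # Assumption: receives the evaluation metrics as a dictionary.
        if not self.is_ready:
            raise RuntimeError("setup() must be called before log()")
        self.history.append(dict(metrics))


# Exercise the callback by hand, the way an evaluator might drive it.
tracker = InMemoryTracker()
tracker.setup()
tracker.log({"ndcg@10": 0.5})
```

An instance like this could then be supplied through tracker_callbacks.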

update_distributed_state()

Set the attributes related to the distributed environment.

Return type:

None

init_infinite_barrier()

Create an instance of InfiniteBarrier as an alternative to dist.barrier() without a timeout.

Return type:

None

allowed_on_this_process(mode)

Parse the user-provided mode argument to determine whether certain actions are allowed on this process.

For example, to check whether we should use a progress bar or print results to stdout. If not in a distributed environment, always returns True. Valid modes are:

  • None : all processes are allowed

  • none : no process is allowed

  • main : only main process is allowed

  • local_main : only local main process is allowed

  • all : all processes are allowed

Parameters:

mode (Optional[str]) – mode of operation for a specific action.

Return type:

bool

Returns:

A boolean indicating whether the current process is allowed to perform the given action.
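The mode rules above can be sketched as a standalone function. This is an illustration of the documented behavior, not the method's actual source: the function name allowed_sketch and its explicit boolean arguments are hypothetical, the no-process mode is assumed to be spelled "none", and raising ValueError for an unknown mode is also an assumption.

```python
from typing import Optional


def allowed_sketch(
    mode: Optional[str],
    is_distributed: bool,
    is_main_process: bool,
    is_local_main_process: bool,
) -> bool:
    # Outside a distributed environment every action is allowed.
    if not is_distributed:
        return True
    if mode is None or mode == "all":
        return True
    if mode == "none":  # assumed spelling of the "no process" mode
        return False
    if mode == "main":
        return is_main_process
    if mode == "local_main":
        return is_local_main_process
    # Assumption: unknown modes are rejected rather than silently ignored.
    raise ValueError(f"Unknown mode: {mode!r}")
```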

all_gather_object(obj)

Gather a pickleable object from all processes in a distributed environment.

If not in a distributed environment, it simulates a distributed environment with only one process and returns a list with only one item, which is the object from the current (and only) process.

Parameters:

obj (Any) – The pickleable object to gather from all processes

Return type:

List[Any]

Returns:

A list of objects, where len(objects) == self.world_size and objects[rank] is the data collected from the process with that rank.
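The single-process fallback described above can be sketched with torch.distributed. This is a hedged illustration rather than the mixin's implementation: the function name and the dist argument (expected to be the torch.distributed module, or None when PyTorch distributed is not in use) are assumptions made so the sketch runs without PyTorch installed.

```python
from typing import Any, List, Optional


def all_gather_object_sketch(obj: Any, dist: Optional[Any] = None) -> List[Any]:
    # `dist` is expected to be the torch.distributed module, or None.
    if dist is None or not (dist.is_available() and dist.is_initialized()):
        # Single-process fallback: behave like a group of size one.
        return [obj]
    gathered: List[Any] = [None] * dist.get_world_size()
    # Fills gathered[rank] with the object sent by each rank.
    dist.all_gather_object(gathered, obj)
    return gathered
```

Without an initialized process group, all_gather_object_sketch({"hits": 3}) returns a one-element list holding the input object.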

gather_object(obj, dst=0)

Same as all_gather_object() but only gathers the objects on the process with rank == dst.

Parameters:
  • obj (Any) – see all_gather_object()

  • dst (int) – rank of the process that collects the objects from the entire group.

Return type:

Optional[List[Any]]

Returns:

None if dst != self.rank. Otherwise, a list of objects, where len(objects) == self.world_size and objects[rank] is the data collected from the process with that rank.
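A comparable sketch for gathering onto a single rank, again assuming a hypothetical dist argument that is either the torch.distributed module or None. In the single-process case the only rank is 0, so the sketch returns a list only when dst is 0.

```python
from typing import Any, List, Optional


def gather_object_sketch(
    obj: Any, dst: int = 0, dist: Optional[Any] = None
) -> Optional[List[Any]]:
    # `dist` is expected to be the torch.distributed module, or None.
    if dist is None or not (dist.is_available() and dist.is_initialized()):
        # Single process: this process has rank 0.
        return [obj] if dst == 0 else None
    is_dst = dist.get_rank() == dst
    # Only the destination rank provides an output list.
    out = [None] * dist.get_world_size() if is_dst else None
    dist.gather_object(obj, out, dst=dst)
    return out
```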

broadcast_obj(obj, src=0)

Broadcast a pickleable object to all processes.

Broadcast obj from the process with rank == src to all processes and return the broadcast object. That is, on every process, the return value is the object from the process with rank == src.

Parameters:
  • obj (Any) – pickleable object to broadcast to all processes.

  • src (int) – rank of the process that broadcasts the object.

Return type:

Any

Returns:

the broadcasted object from process with rank == src.
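The broadcast semantics can be sketched with torch.distributed.broadcast_object_list. The function name and the dist argument (the torch.distributed module, or None) are assumptions for illustration; this is not necessarily how the mixin implements it.

```python
from typing import Any, Optional


def broadcast_obj_sketch(obj: Any, src: int = 0, dist: Optional[Any] = None) -> Any:
    # `dist` is expected to be the torch.distributed module, or None.
    if dist is None or not (dist.is_available() and dist.is_initialized()):
        # Single process: the local object is already the broadcast result.
        return obj
    # Only the source rank contributes a real payload; others send a placeholder.
    payload = [obj if dist.get_rank() == src else None]
    dist.broadcast_object_list(payload, src=src)
    return payload[0]
```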

barrier(infinite=False)

Similar to dist.barrier().

In the future, it should call the appropriate method for different distributed computing environments.

Parameters:

infinite (bool) – If True, use InfiniteBarrier which can wait forever without a timeout. Otherwise, use regular dist.barrier().

Return type:

None

property device: torch.device

Current device that model and data should be moved to.

property is_distributed: bool

Returns true if running in a distributed environment.

property rank: int

Global rank of current process.

Returns 0 if not running in distributed environment.

property world_size: int

The total number of processes.

Returns 1 if not running in a distributed environment.

property is_main_process: bool

Returns true if this is the global main process.

property is_local_main_process: bool

Returns true if this is the local main process.
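The defaults documented for these properties (rank 0 and world size 1 outside a distributed run) can be illustrated with a small sketch. It assumes a torchrun-style launcher that sets the RANK, LOCAL_RANK and WORLD_SIZE environment variables, and it assumes that rank 0 is the main process; the names DistState and read_dist_state are hypothetical and the mixin may obtain these values differently.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistState:
    rank: int
    local_rank: int
    world_size: int

    @property
    def is_distributed(self) -> bool:
        # Assumption: more than one process means a distributed run.
        return self.world_size > 1

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0

    @property
    def is_local_main_process(self) -> bool:
        return self.local_rank == 0


def read_dist_state(env: Optional[Mapping[str, str]] = None) -> DistState:
    # Fall back to rank 0 / world size 1 when the launcher variables are absent.
    env = os.environ if env is None else env
    return DistState(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


single = read_dist_state({})
worker = read_dist_state({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
```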

all_devices_are_similar()

Returns True if all devices available in a distributed environment are the same, and False otherwise.

Return type:

bool

get_free_port_on_master()

Finds a free TCP port on the master host and shares it with all processes.

Return type:

int
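One common way to find a free TCP port is to bind a socket to port 0 and let the operating system choose. The sketch below shows only that step, under the assumption that it runs on the master host; sharing the result with the other processes would then be a broadcast. The function name find_free_port is hypothetical.

```python
import socket


def find_free_port(host: str = "") -> int:
    # Binding to port 0 asks the OS to pick any currently unused port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


port = find_free_port()
```

The port is released when the socket closes, so another program could claim it before it is reused; callers usually accept that small race.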

get_shared_uuid()

Generates a random UUID that is the same across processes in a distributed environment.

Return type:

str
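A shared identifier can be obtained by generating the UUID on one process and broadcasting it. The sketch below assumes a hypothetical broadcast callable that returns the main process's value on every process; the identity function stands in for it in a single-process run.

```python
import uuid
from typing import Callable, Optional


def shared_uuid_sketch(
    is_main_process: bool,
    broadcast: Callable[[Optional[str]], Optional[str]],
) -> str:
    # Only the main process generates a value; the rest send a placeholder.
    candidate = str(uuid.uuid4()) if is_main_process else None
    shared = broadcast(candidate)
    if shared is None:
        raise RuntimeError("broadcast did not deliver a UUID")
    return shared


# Single-process stand-in: broadcasting returns the local value unchanged.
run_id = shared_uuid_sketch(True, lambda value: value)
```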

pbar(*args, **kwargs)

Returns a tqdm instance.

The progress bar is disabled if progress bars are not allowed on this process.

Return type:

tqdm

initialize_trackers()

Initialize experiment trackers.

Return type:

None

log_metrics(metrics)

Log metrics to experiment trackers.

Return type:

None

write_json(obj, path, **kwargs)

Write json file only from the main process.

Parameters:
  • obj (Any) – json serializable object to write to json file.

  • path (PathLike) – path to destination file.

  • **kwargs – keyword arguments passed to json.dump().

Return type:

None
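The "write only from the main process" behavior can be sketched as follows. The function name and the explicit is_main_process argument are assumptions made so the example is self-contained; the mixin presumably reads that flag from its own state.

```python
import json
import os
import tempfile
from typing import Any


def write_json_sketch(
    obj: Any, path: str, is_main_process: bool = True, **kwargs: Any
) -> None:
    # Non-main processes return without touching the file system.
    if not is_main_process:
        return
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, **kwargs)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "metrics.json")
    # A non-main process writes nothing.
    write_json_sketch({"recall@10": 0.8}, target, is_main_process=False)
    skipped = not os.path.exists(target)
    # The main process writes the file; extra kwargs go to json.dump().
    write_json_sketch({"recall@10": 0.8}, target, indent=2)
    with open(target, encoding="utf-8") as f:
        loaded = json.load(f)
```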

write_json_lines(obj_list, path, **kwargs)

Write json lines file only from the main process.

Parameters:
  • obj_list (Iterable[Any]) – iterable of JSON-serializable objects to write to the JSON lines file.

  • path (PathLike) – path to destination file.

  • chunk_size – if provided, force flush the file buffer every chunk_size records.

  • **kwargs – keyword arguments passed to json.dumps().

Return type:

None
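A minimal sketch of a JSON lines writer with periodic flushing, under the same assumptions as before: the function name and the is_main_process argument are hypothetical, and chunk_size is shown as an explicit keyword argument even though the signature above lists only **kwargs.

```python
import json
import os
import tempfile
from typing import Any, Iterable, Optional


def write_json_lines_sketch(
    obj_list: Iterable[Any],
    path: str,
    chunk_size: Optional[int] = None,
    is_main_process: bool = True,
    **kwargs: Any,
) -> None:
    # Non-main processes return without touching the file system.
    if not is_main_process:
        return
    with open(path, "w", encoding="utf-8") as f:
        for count, obj in enumerate(obj_list, start=1):
            # One JSON document per line; extra kwargs go to json.dumps().
            f.write(json.dumps(obj, **kwargs) + "\n")
            # Force a flush after every chunk_size records, if requested.
            if chunk_size and count % chunk_size == 0:
                f.flush()


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "results.jsonl")
    write_json_lines_sketch(({"qid": i} for i in range(3)), target, chunk_size=2)
    with open(target, encoding="utf-8") as f:
        records = [json.loads(line) for line in f]
```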