RetrievalEvaluatorUtilsMixin
- class trove.evaluation.evaluator_mixin_utils.RetrievalEvaluatorUtilsMixin(args=None, tracker_init_kwargs=None, tracker_extra_configs=None, tracker_callbacks=None)
- __init__(args=None, tracker_init_kwargs=None, tracker_extra_configs=None, tracker_callbacks=None)
A simple mixin for RetrievalEvaluator that provides convenient utilities. Anything that is neither part of the main algorithm nor one of your design decisions should go here. For example:
- dealing with distributed environment attributes: is it the main process, the local main process, etc.
- providing a progress bar in a distributed environment
- creating a temp file that is shared among processes
- saving/loading files and checkpoints (what you save and where you save it should be decided in the main class; here you only define a function that takes those as arguments and makes sure the write is done correctly, e.g., only on the main process)
- logging to wandb, etc.
However, even small details that impact your design should not go here. For example, if there is a temp file that you only delete under specific circumstances, the check for whether to delete it belongs in the main class.
- Parameters:
args (Optional[EvaluationArguments]) – Evaluation arguments. Same as what is passed to RetrievalEvaluator.
tracker_init_kwargs (Optional[Dict]) – extra kwargs for initializing experiment trackers, passed directly to the tracker init method as keyword arguments.
tracker_extra_configs (Optional[Union[List[Dict], Dict]]) – extra configs to log with experiment trackers. These configs are merged with the evaluator and model configs if present. You can pass either one config object or a list of config objects. Config objects must be Python dictionaries.
tracker_callbacks (Optional[Any]) – One or more custom experiment tracker callbacks. Callbacks must have a .log() method for logging the evaluation metrics and a .setup() method that is called in the Evaluator constructor to initialize the tracker. In a distributed environment, these methods are only called from the global main process. See the sketch after this parameter list.
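For illustration, a callback satisfying this interface could look like the following sketch. Only the .setup() and .log() methods are documented; the file path and the extra positional/keyword arguments accepted here are assumptions.

```python
import json
from typing import Any, Dict


class JsonFileTracker:
    """Hypothetical tracker callback that appends metric dicts to a JSONL file."""

    def __init__(self, path: str = "metrics.jsonl") -> None:
        self.path = path
        self._fh = None

    def setup(self, *args: Any, **kwargs: Any) -> None:
        # Called from the Evaluator constructor (only on the global main
        # process in distributed runs) to initialize the tracker.
        self._fh = open(self.path, "a")

    def log(self, metrics: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
        # Called with the evaluation metrics; write one JSON record per call.
        self._fh.write(json.dumps(metrics) + "\n")
        self._fh.flush()
```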
- update_distributed_state()
Set the attributes related to the distributed environment.
- Return type:
None
- init_infinite_barrier()
Create an instance of InfiniteBarrier to use as a dist.barrier() alternative without a timeout.
- Return type:
None
- allowed_on_this_process(mode)
Parse the user-provided mode argument to determine whether we are allowed to perform certain actions on this process, for example, whether to show a progress bar or print results to stdout. If not in a distributed environment, it always returns True. Valid modes are:
None: all processes are allowed
one: no process is allowed
main: only the main process is allowed
local_main: only the local main process is allowed
all: all processes are allowed
- Parameters:
mode (Optional[str]) – mode of operation for a specific action.
- Return type:
bool
- Returns:
A boolean indicating whether the specific action can be performed from the current process.
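A minimal usage sketch, assuming `evaluator` is an instance of a class that mixes in RetrievalEvaluatorUtilsMixin (e.g., RetrievalEvaluator); the helper function name is hypothetical:

```python
def maybe_print_results(evaluator, results: dict, mode: str = "main") -> None:
    # With mode="main" only the global main process prints; with mode=None
    # or mode="all" every process prints.
    if evaluator.allowed_on_this_process(mode):
        print(results)
```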
- all_gather_object(obj)
Gather a pickleable object from all processes in a distributed environment.
If not in a distributed environment, it simulates a distributed environment with only one process and returns a list with only one item, which is the object from the current (and only) process.
- Parameters:
obj (Any) – The pickleable object to gather from all processes.
- Return type:
List[Any]
- Returns:
A list of objects where len(objects) == self.world_size and objects[rank] is the data collected from the process with that rank.
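A minimal sketch of merging per-process results, assuming `evaluator` uses this mixin and each process holds a partial dict of scores:

```python
from typing import Any, Dict


def merge_partial_scores(evaluator, local_scores: Dict[str, Any]) -> Dict[str, Any]:
    gathered = evaluator.all_gather_object(local_scores)  # len == world_size
    merged: Dict[str, Any] = {}
    for rank, scores in enumerate(gathered):
        merged.update(scores)  # gathered[rank] came from the process with that rank
    return merged
```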
- gather_object(obj, dst=0)
Same as all_gather_object() but only gathers objects on the process with rank == dst.
- Parameters:
obj (Any) – see all_gather_object()
dst (int) – rank of the process that collects the objects from the entire group.
- Return type:
Optional[List[Any]]
- Returns:
None if dst != self.rank. Otherwise, a list of objects where len(objects) == self.world_size and objects[rank] is the data collected from the process with that rank.
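A minimal sketch that collects results only on the destination rank; the function name and record layout are assumptions:

```python
from typing import Any, List, Optional


def collect_on_dst(evaluator, local_results: List[Any], dst: int = 0) -> Optional[List[Any]]:
    gathered = evaluator.gather_object(local_results, dst=dst)
    if gathered is None:
        # We are not the destination process.
        return None
    # Flatten the per-rank lists into one list on the destination process.
    return [item for per_rank in gathered for item in per_rank]
```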
- broadcast_obj(obj, src=0)
Broadcast a pickleable object to all processes.
Broadcast the obj from the process with rank == src to all processes and return the broadcasted object. I.e., in all processes, returns the object from the process with rank == src.
- Parameters:
obj (Any) – pickleable object to broadcast to all processes.
src (int) – rank of the process that broadcasts the object.
- Return type:
Any
- Returns:
The broadcasted object from the process with rank == src.
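A minimal sketch of sharing a value chosen on the main process; the temp-directory use case is an assumption:

```python
import tempfile


def shared_temp_dir(evaluator) -> str:
    # Only the main process creates the directory; everyone receives its path.
    path = tempfile.mkdtemp() if evaluator.is_main_process else None
    return evaluator.broadcast_obj(path, src=0)
```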
- barrier(infinite=False)
Similar to dist.barrier().
In the future, it should call the appropriate method for different distributed computing environments.
- Parameters:
infinite (bool) – If True, use InfiniteBarrier, which can wait forever without a timeout. Otherwise, use the regular dist.barrier().
- Return type:
None
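A minimal sketch of synchronizing around a potentially slow write on the main process; the file path and workflow are assumptions:

```python
def write_then_sync(evaluator, results: dict, path: str) -> None:
    evaluator.write_json(results, path)   # only the main process writes
    evaluator.barrier(infinite=True)      # wait without hitting the default timeout
    # After the barrier, every process can safely read `path`.
```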
- property device: torch.device
The current device that the model and data should be moved to.
- property is_distributed: bool
Returns true if running in a distributed environment.
- property rank: int
Global rank of current process.
Returns 0 if not running in a distributed environment.
- property world_size: int
The total number of processes.
Returns 1 if not running in a distributed environment.
- property is_main_process: bool
Returns true if this is the global main process.
- property is_local_main_process: bool
Returns true if this is the local main process.
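A minimal sketch that exercises these properties, assuming `evaluator` uses this mixin:

```python
import torch


def describe_process(evaluator) -> torch.Tensor:
    print(f"rank {evaluator.rank} of {evaluator.world_size}, "
          f"distributed={evaluator.is_distributed}, "
          f"global_main={evaluator.is_main_process}, "
          f"local_main={evaluator.is_local_main_process}")
    # Move data to the device this process should use.
    return torch.zeros(2, 3).to(evaluator.device)
```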
- all_devices_are_similar()
Returns true if all devices available in a distributed environment are the same, and false otherwise.
- Return type:
bool
- get_free_port_on_master()
Finds a free TCP port on the master host and shares it with all processes.
- Return type:
int
Generates a random UUID that is the same across processes in a distributed environment.
- Return type:
str
- pbar(*args, **kwargs)
Returns a tqdm instance. It is disabled if progress bars are not allowed on this process.
- Return type:
tqdm
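A minimal sketch, assuming `evaluator` uses this mixin and the extra arguments are forwarded to tqdm; the bar is silently disabled on processes where progress bars are not allowed:

```python
def iterate_with_progress(evaluator, dataloader):
    outputs = []
    for batch in evaluator.pbar(dataloader, desc="encoding"):
        outputs.append(batch)  # placeholder for the real per-batch work
    return outputs
```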
- initialize_trackers()
Initialize experiment trackers.
- Return type:
None
- log_metrics(metrics)
Log metrics to experiment trackers.
- Return type:
None
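A minimal sketch of the tracker workflow; the metric names and values are placeholders:

```python
def log_final_metrics(evaluator) -> None:
    evaluator.initialize_trackers()
    metrics = {"ndcg@10": 0.42, "recall@100": 0.87}  # placeholder values
    evaluator.log_metrics(metrics)
```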
- write_json(obj, path, **kwargs)
Write a JSON file; the write only happens on the main process.
- Parameters:
obj (Any) – JSON-serializable object to write to the JSON file.
path (PathLike) – path to the destination file.
**kwargs – keyword arguments passed to json.dump().
- Return type:
None
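A minimal sketch; because only the main process writes, this is safe to call from every rank. The output path and indent value are assumptions.

```python
def save_summary(evaluator, summary: dict) -> None:
    # `indent=2` is forwarded to json.dump() via **kwargs.
    evaluator.write_json(summary, "eval_summary.json", indent=2)
```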
- write_json_lines(obj_list, path, **kwargs)
Write a JSON lines file; the write only happens on the main process.
- Parameters:
obj_list (Iterable[Any]) – list of JSON-serializable objects to write to the JSON lines file.
path (PathLike) – path to the destination file.
chunk_size – if provided, force flush the file buffer every chunk_size records.
**kwargs – keyword arguments passed to json.dumps().
- Return type:
None
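A minimal sketch; the record layout, output path, and chunk_size value are assumptions:

```python
from typing import Any, Dict, Iterable


def save_records(evaluator, records: Iterable[Dict[str, Any]]) -> None:
    # Flush the file buffer every 1000 records; extra kwargs go to json.dumps().
    evaluator.write_json_lines(records, "results.jsonl", chunk_size=1000)
```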