"""This module provides functions for gathering data."""# Created by Wenjie Du <wenjay.du@gmail.com># License: BSD-3-Clauseimportnumpyasnpimporttorch
def gather_listed_dicts(dict_list: list) -> dict:
    """Gather batched dict output from model forward.

    Parameters
    ----------
    dict_list:
        A list of dict output from model forward.
        Each dict should have the same keys, and each value should be a
        ``torch.Tensor`` or a ``np.ndarray``.

    Returns
    -------
    gathered_dict:
        A dict whose values are the inputs concatenated along the first axis
        (the batch dimension). Tensor values are detached, moved to CPU and
        converted to ``np.ndarray``. Keys whose value in the first dict is
        zero-dimensional (a scalar) cannot be concatenated and are left out
        of the result. An empty ``dict_list`` yields an empty dict.

    Raises
    ------
    ValueError
        If the dicts do not all share the same keys, or if a value in the
        first dict is neither a ``torch.Tensor`` nor a ``np.ndarray``.

    """
    # Nothing to gather; avoid an opaque IndexError on dict_list[0].
    if not dict_list:
        return {}

    first = dict_list[0]
    keys = first.keys()

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if any(d.keys() != keys for d in dict_list[1:]):
        raise ValueError("All dicts should have the same keys")

    gathered_dict = {}
    for k in keys:
        # The first dict's value decides how this key is gathered.
        sample = first[k]
        if isinstance(sample, torch.Tensor):
            # 0-dim tensors have no batch axis to concatenate along; skip them.
            if sample.dim() > 0:
                stacked = torch.cat([d[k] for d in dict_list], dim=0)
                # Detach before the device copy so the copy is not tracked by autograd.
                gathered_dict[k] = stacked.detach().cpu().numpy()
        elif isinstance(sample, np.ndarray):
            # Same rule for 0-dim arrays.
            if sample.ndim > 0:
                gathered_dict[k] = np.concatenate([d[k] for d in dict_list], axis=0)
        else:
            raise ValueError("Only support torch.Tensor and np.ndarray")

    return gathered_dict