def compute_cache(args, model, subset_ds, cache_shape):
    """Extract features for every sample of ``subset_ds`` and cache them.

    The resulting cache of image features is what the mining step uses to
    find the best positive and the hardest negatives.

    Args:
        args: options object; ``num_workers``, ``infer_batch_size`` and
            ``device`` are read from it.
        model: network applied to each batch of images; it is switched to
            eval mode before inference.
        subset_ds: dataset whose batches unpack as ``(images, indexes)``.
        cache_shape: shape handed to ``RAMEfficient2DMatrix``.

    Returns:
        The float32 ``RAMEfficient2DMatrix`` in which each batch of features
        has been written at the positions given by that batch's indexes.
    """
    loader_options = {
        "dataset": subset_ds,
        "num_workers": args.num_workers,
        "batch_size": args.infer_batch_size,
        "shuffle": False,
        # Pinned host memory is only requested when running on CUDA.
        "pin_memory": (args.device == "cuda"),
    }
    loader = DataLoader(**loader_options)
    model = model.eval()

    # RAMEfficient2DMatrix can be replaced by np.zeros, but using
    # RAMEfficient2DMatrix is RAM efficient for full database mining.
    cache = RAMEfficient2DMatrix(cache_shape, dtype=np.float32)

    # Inference only: no gradients are needed while filling the cache.
    with torch.no_grad():
        for batch, batch_indexes in tqdm(loader, ncols=100):
            batch_features = model(batch.to(args.device))
            # Move the features back to the CPU before storing them as numpy.
            cache[batch_indexes.numpy()] = batch_features.cpu().numpy()

    return cache
def __getitem__(self, index):
    """Return the images and index tensors for the triplets of one query.

    When ``self.is_inference`` is set, the call is delegated to the parent
    class's ``__getitem__``.

    Otherwise the row ``self.triplets_global_indexes[index]`` is split into
    one query index, one best-positive index and
    ``self.negs_num_per_query`` negative indexes.

    Returns (training mode):
        A tuple ``(images, triplets_local_indexes, global_indexes)`` where
        ``images`` stacks the query, the positive and the negatives (in that
        order) along a new first dimension, ``triplets_local_indexes`` has
        one row ``[0, 1, 2 + k]`` per negative ``k`` (positions inside
        ``images``), and ``global_indexes`` is the row of
        ``self.triplets_global_indexes`` selected by ``index``.
    """
    if self.is_inference:
        # At inference time return the single image. This is used for caching or computing NetVLAD's clusters
        return super().__getitem__(index)

    global_indexes = self.triplets_global_indexes[index]
    query_index, best_positive_index, neg_indexes = torch.split(
        global_indexes, (1, 1, self.negs_num_per_query)
    )

    def load_database_image(db_index):
        # Database images (positive and negatives) share the same transform.
        return self.resized_transform(path_to_pil_img(self.database_paths[db_index]))

    query = self.query_transform(path_to_pil_img(self.queries_paths[query_index]))
    positive = load_database_image(best_positive_index)
    negatives = [load_database_image(neg_index) for neg_index in neg_indexes]
    images = torch.stack([query, positive] + negatives, 0)

    # One (query, positive, negative) row per negative, expressed as
    # positions inside ``images``: query is 0, positive is 1, negatives
    # start at 2.
    local_rows = [
        torch.tensor([0, 1, 2 + neg_num]).reshape(1, 3)
        for neg_num in range(len(neg_indexes))
    ]
    # The empty (0, 3) seed keeps the result well-defined when there are
    # no negatives.
    triplets_local_indexes = torch.cat(
        (torch.empty((0, 3), dtype=torch.int), *local_rows)
    )

    return images, triplets_local_indexes, global_indexes