# Demo: how to use this tool together with PyTorch
# on an object-detection dataset.
import dfs
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class DataSegment(Dataset):
    """Adapt a dfs segment to the official ``torch.utils.data.Dataset`` API.

    Parameters
    ----------
    data_dl : sequence of ``(image_url, labels)`` pairs, where ``labels``
        maps a class key to a list of bounding boxes (list form).
        Assumed from usage below — TODO confirm against the dfs docs.
    transform : optional callable applied to the decoded PIL image.
        When ``None``, a plain ``ToTensor`` conversion is used instead.
    """

    def __init__(self, data_dl, transform=None):
        super().__init__()
        self.data_dl = data_dl
        self.length = len(self.data_dl)
        # Kept for backward compatibility with the original demo;
        # not used anywhere inside this class.
        self.candidate_index = list(range(self.length))
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image_url, labels = self.data_dl[idx]
        # Flatten {class_key: [box, ...]} into rows of box + [class_key].
        labels_list = [box + [key] for key, boxes in labels.items() for box in boxes]
        labels_array = np.asarray(labels_list)
        # backend may be 'pil' or 'cv2'
        image = dfs.decode_img(image_url, backend='pil')
        # Bug fix: the original called self.transform(image) unconditionally,
        # crashing when the documented default transform=None was used.
        if self.transform is not None:
            image_tensor = self.transform(image)
        else:
            image_tensor = transforms.ToTensor()(image)
        # Leading batch dim so dataset_collate can torch.cat along dim 0.
        image_tensor = image_tensor.unsqueeze(0)
        return image_tensor, labels_array
def dataset_collate(batch):
    """Collate function for the DataLoader.

    Concatenates the per-sample image tensors along dim 0 (each sample
    already carries a leading batch dim) and keeps the variable-length
    box annotations as a plain Python list.
    """
    image_parts = [sample[0] for sample in batch]
    box_lists = [sample[1] for sample in batch]
    return torch.cat(image_parts, dim=0), box_lists
# Create the dfs client; replace the angle-bracket placeholders with your
# real access key and dataset id before running.
client = dfs.Client(
    access_token=<YOUR_ACCESSKEY>,dataset_id=<DATASET_ID>)
to_tensor = transforms.ToTensor()
# NOTE(review): single-element mean/std implies 1-channel (grayscale)
# images — confirm; RGB input would need 3 values for each.
normalization = transforms.Normalize(mean=[0.485], std=[0.229])
my_transforms = transforms.Compose([to_tensor, normalization])
train_set = dfs.LMDataset(client, segment_name='train')
# Enable the local cache to avoid re-downloading images.
train_set.enable_cache()
train_dataset = DataSegment(train_set, my_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=dataset_collate)
for epoch in range(10):
    for index, (image, label) in enumerate(train_dataloader):
        print("index:{}:\timage_shape:{}\tlabel:{}".format(index, image.shape, label))
# Remove the cache created by this run.
train_set.remove_cache_alone()