# Import the common packages
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
Hyperparameters can be defined and initialized in one place:
# Hyperparameter definitions
# Batch size
batch_size = 16  # 32, 64 or 128 are also common choices
# Learning rate for the optimizer
lr = 1e-4
# Number of training epochs
max_epochs = 10
# Option 1: specify the GPUs through an environment variable
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # make GPUs 0 and 1 visible
# Option 2: use a "device" object; anything that should run on the GPU is later moved with .to(device)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")  # use GPU 1
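With either option, moving the model and the data onto the GPU afterwards only requires calling .to(device). A minimal sketch (the linear model and random batch below are hypothetical placeholders, not part of the original example):
# Sketch: move a model and a batch of data onto the selected device
demo_model = torch.nn.Linear(3 * 32 * 32, 10).to(device)      # hypothetical model
demo_batch = torch.randn(batch_size, 3 * 32 * 32).to(device)  # hypothetical input batch
demo_output = demo_model(demo_batch)
print(demo_output.device)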
The Dataset class mainly involves three functions: __init__, __getitem__ and __len__.
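As a bare-bones skeleton (the class name and placeholder data here are hypothetical; a complete custom Dataset is given further below):
# Sketch: the three methods every map-style Dataset implements
class SkeletonDataset(Dataset):
    def __init__(self):
        # load or index the underlying data here
        self.samples = [0, 1, 2]  # placeholder data
    def __getitem__(self, index):
        # return one sample (and, typically, its label) by index
        return self.samples[index]
    def __len__(self):
        # return the total number of samples
        return len(self.samples)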
# Data loading
# Using the CIFAR-10 dataset as an example of how to build a Dataset
from torchvision import datasets
# "data_transform" applies transformations to the images, such as flips, crops and
# normalization; it can be customized as needed
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_cifar_dataset = datasets.CIFAR10('cifar10', train=True, download=False, transform=data_transform)
test_cifar_dataset = datasets.CIFAR10('cifar10', train=False, download=False, transform=data_transform)
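If you do want the flips and crops mentioned above at training time, the training set can get its own transform; a sketch (the augmentation choices and values are illustrative, not prescribed by the original):
# Sketch: a training-time transform with simple data augmentation
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # random crop with 4-pixel padding
    transforms.RandomHorizontalFlip(),     # flip half of the images horizontally
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# train_cifar_dataset = datasets.CIFAR10('cifar10', train=True, download=False, transform=train_transform)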
# Inspect the dataset
print(test_cifar_dataset)              # prints the dataset summary
print(len(test_cifar_dataset))         # number of samples (calls __len__)
image_demo = test_cifar_dataset[1][0]  # indexing calls __getitem__; [0] is the image tensor
print(image_demo)
print(image_demo.size())
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: cifar10
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )
10000
tensor([[[ 0.8431, 0.8118, 0.8196, ..., 0.8275, 0.8275, 0.8196],
[ 0.8667, 0.8431, 0.8431, ..., 0.8510, 0.8510, 0.8431],
[ 0.8588, 0.8353, 0.8353, ..., 0.8431, 0.8431, 0.8353],
# Browse the dataset
import matplotlib.pyplot as plt
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
dataiter = iter(test_cifar_dataset)
for i in range(10):
    images, labels = next(dataiter)
    print(images.size())
    print(classes[labels])
    # images = images.numpy().transpose(1, 2, 0)  # move the channel dimension to the end
    # plt.title(classes[labels])
    # plt.imshow(images)
    # plt.show()
torch.Size([3, 32, 32])
cat
torch.Size([3, 32, 32])
ship
torch.Size([3, 32, 32])
ship
torch.Size([3, 32, 32])
plane
torch.Size([3, 32, 32])
frog
torch.Size([3, 32, 32])
frog
torch.Size([3, 32, 32])
car
torch.Size([3, 32, 32])
frog
# Once the Dataset is built, a DataLoader reads the data in batches
train_loader = torch.utils.data.DataLoader(train_cifar_dataset,
                                           batch_size=batch_size, num_workers=4,
                                           shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(test_cifar_dataset,
                                         batch_size=batch_size, num_workers=4,
                                         shuffle=False)
Parameter notes: batch_size sets how many samples each batch contains; num_workers sets how many subprocesses are used to load the data; shuffle controls whether the samples are reshuffled at every epoch (usually True for training and False for validation); drop_last drops the final incomplete batch when the dataset size is not divisible by batch_size.
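With these settings, every iteration of the loader yields one full batch; a quick sanity check (the shapes correspond to the CIFAR-10 example above with batch_size = 16):
# Sketch: fetch a single batch from the DataLoader and check its shape
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([16, 3, 32, 32])
print(labels.shape)   # torch.Size([16])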
# Custom Dataset class
import pandas as pd
from PIL import Image

class MyDataset(Dataset):
    def __init__(self, data_dir, info_csv, image_list, transform=None):
        """
        Args:
            data_dir: path to the image directory.
            info_csv: path to the csv file containing image indexes
                with the corresponding labels.
            image_list: path to the txt file containing the image names of the
                training/validation set.
            transform: optional transform to be applied on a sample.
        """
        label_info = pd.read_csv(info_csv)
        image_file = open(image_list).readlines()
        self.data_dir = data_dir
        self.image_file = image_file
        self.label_info = label_info
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of the item.
        Returns:
            the image and its label.
        """
        image_name = self.image_file[index].strip('\n')
        raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
        label = raw_label.iloc[:, 0]
        image_name = os.path.join(self.data_dir, image_name)
        image = Image.open(image_name).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_file)
# Demo of the custom Dataset
data_dir = ''
info_csv = ''
image_list = ''
my_dataset = MyDataset(data_dir, info_csv, image_list)
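Once the three paths above point to real files, the custom dataset plugs into the same pipeline as the built-in one; a sketch reusing the data_transform and batch_size defined earlier (the paths are still placeholders you need to supply):
# Sketch: wrap the custom dataset with a transform and a DataLoader
# (fill in data_dir, info_csv and image_list with real paths first)
my_dataset = MyDataset(data_dir, info_csv, image_list, transform=data_transform)
my_loader = DataLoader(my_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
# for image, label in my_loader:
#     ...  # training / validation code goes here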