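In this article I'll share how to use a CNN in PyTorch to recognize handwritten digits, loading the data from a self-built image dataset instead of the ready-made MNIST loader. Many readers may not be familiar with this yet, so I hope the article serves as a useful reference. Let's take a look.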
The details are as follows:
# third-party libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image

# torch.manual_seed(1)    # uncomment for reproducible runs

# Hyper Parameters
EPOCH = 1               # number of passes over the training data; 1 epoch to save time
BATCH_SIZE = 50
LR = 0.001              # learning rate

root = "./mnist/raw/"


def default_loader(path):
    # return Image.open(path).convert('RGB')
    return Image.open(path)


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        # each line of the index file is "<image path> <integer label>"
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:               # skip blank lines
                    continue
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn).convert('L')  # force single-channel grayscale
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


train_data = MyDataset(txt=root + 'train.txt', transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = MyDataset(txt=root + 'test.txt', transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # input channels (grayscale)
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # to keep width and height unchanged after Conv2d, padding=(kernel_size-1)/2 if stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x                    # return x for visualization


cnn = CNN()
print(cnn)  # net architecture

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                       # the target label is a class index, not one-hot

# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):    # ToTensor has already scaled pixels to [0, 1]
        output = cnn(b_x)[0]            # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            cnn.eval()
            eval_loss = 0.
            eval_acc = 0.
            with torch.no_grad():       # no gradients needed during evaluation
                for t_x, t_y in test_loader:
                    output = cnn(t_x)[0]
                    loss = loss_func(output, t_y)
                    eval_loss += loss.item() * t_y.size(0)   # batch mean loss -> sum over samples
                    pred = torch.max(output, 1)[1]
                    eval_acc += (pred == t_y).sum().item()
            acc_rate = eval_acc / len(test_data)
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_data), acc_rate))
            cnn.train()                 # back to training mode
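The shape comments in the network (28 to 14 to 7, hence the `32 * 7 * 7` input size of the final linear layer) follow from the usual output-size formula for convolution and pooling layers. The following is a minimal sketch in plain Python that reproduces that arithmetic; the helper names `conv2d_out` and `maxpool2d_out` are hypothetical and are not part of PyTorch.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Output height/width of a convolution:
    floor((size + 2 * padding - kernel_size) / stride) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def maxpool2d_out(size, kernel_size, stride=None):
    """Output height/width of max pooling without padding.
    The stride defaults to the kernel size, as in nn.MaxPool2d."""
    stride = kernel_size if stride is None else stride
    return (size - kernel_size) // stride + 1


size = 28                                                     # MNIST images are 28 x 28
after_conv1 = conv2d_out(size, kernel_size=5, stride=1, padding=2)
after_pool1 = maxpool2d_out(after_conv1, kernel_size=2)
after_conv2 = conv2d_out(after_pool1, kernel_size=5, stride=1, padding=2)
after_pool2 = maxpool2d_out(after_conv2, kernel_size=2)

# conv2 produces 32 channels, so the flattened feature vector has
# 32 * after_pool2 * after_pool2 entries.
flat_features = 32 * after_pool2 * after_pool2

print(after_conv1, after_pool1, after_conv2, after_pool2, flat_features)
```

If the input images were a different size, or the kernel, stride or padding changed, the same two helpers would give the value that has to replace `32 * 7 * 7` in `nn.Linear`.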
For the images and labels, see the previous article, "Converting the MNIST dataset to images and txt with PyTorch".
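`MyDataset` reads each line of `train.txt` / `test.txt` as a whitespace-separated pair: an image path followed by an integer label. The sketch below writes a tiny index file in that format and parses it back the same way. The sample paths, labels and the helper names `write_index` and `read_index` are made up for illustration; the files produced by the previous article may be laid out differently.

```python
import os
import tempfile


def write_index(txt_path, samples):
    """Write one '<image path> <label>' pair per line."""
    with open(txt_path, "w") as fh:
        for img_path, label in samples:
            fh.write(f"{img_path} {label}\n")


def read_index(txt_path):
    """Parse the index file the same way MyDataset.__init__ does."""
    imgs = []
    with open(txt_path, "r") as fh:
        for line in fh:
            words = line.split()
            if not words:          # skip blank lines
                continue
            imgs.append((words[0], int(words[1])))
    return imgs


# Hypothetical entries; real paths come from the image-export step.
samples = [
    ("./mnist/raw/train/0.jpg", 5),
    ("./mnist/raw/train/1.jpg", 0),
]

with tempfile.TemporaryDirectory() as tmp_dir:
    txt_path = os.path.join(tmp_dir, "train.txt")
    write_index(txt_path, samples)
    parsed = read_index(txt_path)

print(parsed)
```

Because the line is split on whitespace, an image path that contains spaces would be parsed incorrectly, so paths in the index file should avoid them.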
The results are as follows: [result screenshot not included]
That is everything in this article on how a CNN in PyTorch recognizes handwritten digits from a self-built image dataset. Thank you for reading! I hope what was shared here is helpful.