這篇文章主要介紹了Pytorch中mean和std調(diào)查的示例分析,具有一定借鑒價(jià)值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。
創(chuàng)新互聯(lián)公司主營(yíng)羅山網(wǎng)站建設(shè)的網(wǎng)絡(luò)公司,主營(yíng)網(wǎng)站建設(shè)方案,重慶APP開(kāi)發(fā)公司,羅山h5成都小程序開(kāi)發(fā)搭建,羅山網(wǎng)站營(yíng)銷推廣歡迎羅山等地區(qū)企業(yè)咨詢如下所示:
# coding: utf-8 from __future__ import print_function import copy import click import cv2 import numpy as np import torch from torch.autograd import Variable from torchvision import models, transforms import matplotlib.pyplot as plt import load_caffemodel import scipy.io as sio # if model has LSTM # torch.backends.cudnn.enabled = False imgpath = 'D:/ck/files_detected_face224/' imgname = 'S055_002_00000025.png' # anger image_path = imgpath + imgname mean_file = [0.485, 0.456, 0.406] std_file = [0.229, 0.224, 0.225] raw_image = cv2.imread(image_path)[..., ::-1] print(raw_image.shape) raw_image = cv2.resize(raw_image, (224, ) * 2) image = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( mean=mean_file, std =std_file, #mean = mean_file, #std = std_file, ) ])(raw_image).unsqueeze(0) print(image.shape) convert_image1 = image.numpy() convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1)) convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C print(convert_image1.shape) convert_image1 = convert_image1 * 255 diff = raw_image - convert_image1 err = np.max(diff) print(err) plt.imshow(np.uint8(convert_image1)) plt.show()
結(jié)論:
input_image = (raw_image / 255 - mean) ./ std
下面調(diào)查均值文件和方差文件是如何生成的:
mean_file = [0.485, 0.456, 0.406] std_file = [0.229, 0.224, 0.225]
# coding: utf-8 import matplotlib.pyplot as plt import argparse import os import numpy as np import torchvision import torchvision.transforms as transforms dataset_names = ('cifar10','cifar100','mnist') parser = argparse.ArgumentParser(description='PyTorchLab') parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names, help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)') args = parser.parse_args() data_dir = os.path.join('.', args.dataset) print(args.dataset) args.dataset = 'cifar10' if args.dataset == "cifar10": train_transform = transforms.Compose([transforms.ToTensor()]) train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform) #print(vars(train_set)) print(train_set.train_data.shape) print(train_set.train_data.mean(axis=(0,1,2))/255) print(train_set.train_data.std(axis=(0,1,2))/255) # imshow image train_data = train_set.train_data ind = 100 img0 = train_data[ind,...] ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe # error produce #b,g,r=cv2.split(img0) #img0=cv2.merge([r,g,b]) print(img0.shape) print(type(img0)) plt.imshow(img0) plt.show() # in ship in sea #img0 = cv2.resize(img0,(224,224)) #cv2.imshow('img0',img0) #cv2.waitKey() elif args.dataset == "cifar100": train_transform = transforms.Compose([transforms.ToTensor()]) train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform) #print(vars(train_set)) print(train_set.train_data.shape) print(np.mean(train_set.train_data, axis=(0,1,2))/255) print(np.std(train_set.train_data, axis=(0,1,2))/255) elif args.dataset == "mnist": train_transform = transforms.Compose([transforms.ToTensor()]) train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform) #print(vars(train_set)) print(list(train_set.train_data.size())) print(train_set.train_data.float().mean()/255) print(train_set.train_data.float().std()/255)
結(jié)果:
cifar10 Files already downloaded and verified (50000, 32, 32, 3) [ 0.49139968 0.48215841 0.44653091] [ 0.24703223 0.24348513 0.26158784] (32, 32, 3) <class 'numpy.ndarray'>
使用matlab檢測(cè)是如何計(jì)算mean_file和std_file的:
% load cifar10 dataset data = load('cifar10_train_data.mat'); train_data = data.train_data; disp(size(train_data)); temp = mean(train_data,1); disp(size(temp)); train_data = double(train_data); % compute mean_file mean_val = mean(mean(mean(train_data,1),2),3)/255; % compute std_file temp1 = train_data(:,:,:,1); std_val1 = std(temp1(:))/255; temp2 = train_data(:,:,:,2); std_val2 = std(temp2(:))/255; temp3 = train_data(:,:,:,3); std_val3 = std(temp3(:))/255; mean_val = squeeze(mean_val); std_val = [std_val1, std_val2, std_val3]; disp(mean_val); disp(std_val); % result: mean_val: [0.4914, 0.4822, 0.4465] % std_val: [0.2470, 0.2435, 0.2616]
均值計(jì)算的過(guò)程也可以遵循標(biāo)準(zhǔn)差的計(jì)算過(guò)程。為 了簡(jiǎn)單,例如對(duì)于一個(gè)矩陣,所有元素的均值,等于兩個(gè)方向上先后均值。所以會(huì)直接采用如下的形式:
mean_val = mean(mean(mean(train_data,1),2),3)/255;
標(biāo)準(zhǔn)差的計(jì)算是每一個(gè)通道的對(duì)所有樣本的求標(biāo)準(zhǔn)差。然后再除以255。
感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享的“Pytorch中mean和std調(diào)查的示例分析”這篇文章對(duì)大家有幫助,同時(shí)也希望大家多多支持創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司,關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司行業(yè)資訊頻道,更多相關(guān)知識(shí)等著你來(lái)學(xué)習(xí)!
另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無(wú)理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、網(wǎng)站設(shè)計(jì)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場(chǎng)景需求。
網(wǎng)站題目:Pytorch中mean和std調(diào)查的示例分析-創(chuàng)新互聯(lián)
分享地址:http://vcdvsql.cn/article0/phpoo.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供靜態(tài)網(wǎng)站、軟件開(kāi)發(fā)、網(wǎng)站設(shè)計(jì)、做網(wǎng)站、營(yíng)銷型網(wǎng)站建設(shè)、微信公眾號(hào)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容