from time import time
from os import listdir
from os.path import basename
from PIL import Image
from sklearn import svm
from sklearn.model_selection import GridSearchCV
# Image dimensions: number of pixel columns (width) and rows (height).
width = 30
height = 60
def loadDigits(dstDir='datasets'):
# 获取所有图像文件名
digitsFile = [dstDir+'\\'+fn for fn in listdir(dstDir)
if fn.endswith('.jpg')]
# 按编号排序
digitsFile.sort(key=lambda fn: int(basename(fn)[:-4]))
# digitsData用于存放读取的图片中数字信息
# 每个图片中所有像素值存放于digitsData中的一行数据
digitsData = []
for fn in digitsFile:
with Image.open(fn) as im:
data = [sum(im.getpixel((w,h)))/len(im.getpixel((w,h)))
for w in range(width)
for h in range(height)]
digitsData.append(data)
# digitsLabel用于存放图片中数字的标准分类
with open(dstDir+'\\digits.txt') as fp:
digitsLabel = fp.readlines()
digitsLabel = [label.strip() for label in digitsLabel]
return (digitsData, digitsLabel)
# Load the per-image pixel features and their labels.
data = loadDigits()
print('数据加载完成。')

# Base estimator whose hyper-parameters are tuned below.
svcClassifier = svm.SVC()

# Candidate hyper-parameter values to try.
parameters = {
    'kernel': ('linear', 'rbf'),
    'C': (0.001, 1, 1000),
    'gamma': (0.001, 1, 10),
}

# Grid search; the timer spans fitting, reporting and scoring.
start = time()
clf = GridSearchCV(svcClassifier, parameters)
clf.fit(*data)

# Uncomment to inspect the detailed search results.
# print(clf.cv_results_)
print(clf.best_params_)
# The score is computed on the same data that was passed to fit().
print('得分:', clf.score(*data))
print('用时(秒):', time()-start)