Pythonのライブラリ【scikit-learn】で機械学習をはじめる
昨今めざましい技術革新が行われているAI(人工知能)の分野ですが、その中でもプログラミング言語の「Python」は機械学習においてかなり注目されています。
今回はPythonを用いて実際に機械学習を行ってみましたので、その解説とコードの紹介をしていきます。
機械学習とは
現在、AIが世界的に注目されています。
AIを搭載した無人で運転を行う車や、その日の夕飯のレシピを提案してくれるAI搭載の冷蔵庫など、様々な場所で【AI】という言葉を耳にすることかと思います。
そのAIですが、一言にAIといっても様々なものがあり、それぞれ学習方法によって性質や得意分野の違いなどが多く存在します。
機械学習の定義
今回紹介する【機械学習】とは、【莫大な数のデータ集合を解析し、特徴を捉え、未知のデータについての予測を行うことができるよう学習させるもの】を意味します。
AIについてのニュースでよく【ディープラーニング】という単語について話しているのを聞いたことはないでしょうか。
【ディープラーニング】もまた、今回紹介する【機械学習】というカテゴリの一分野です。
【ディープラーニング】は人間の脳をモデルにした学習方法を採用している点がその他の機械学習とは異なる点ですが、【何らかの方法でデータを分析し学習を行う】という点では共通しています。
Pythonの機械学習ライブラリ【scikit-learn】
機械学習を行うのに特化したライブラリのひとつに、Pythonの【scikit-learn】(サイキットラーン)があります。【scikit-learn】は機械学習を行うにあたって有用なアルゴリズムが数多く実装されています。
極めて幅の広いアルゴリズムが実装されている上に、機械学習のチュートリアル用に用意されている様々なデータセットが付属しているため、はじめて機械学習を学ぶプログラマにはうってつけのライブラリとなっています。
Pythonとscikit-learnで機械学習を実際に行う
scikit-learnはオープンソースとして公開されている機械学習ライブラリであるため、無料で導入することが出来るようになっています。
今回は、scikit-learnに付属する【手書き数字の画像データ】をもとに、【手書き数字の画像が実際に数字の何番であるかを学習し、判定・分類することが出来るよう機械学習を行う】アルゴリズムを、scikit-learnのチュートリアルとして実装しました。
Pythonとscikit-learnによる機械学習の方法
まず、必要なライブラリやデータセットなどをインポートします。インポートは次のコードで行いました。
# matplotlibライブラリのpyplotをインポートする
import matplotlib.pyplot as plt
# データセット、分類機能、評価機能をインポートする
from sklearn import datasets, svm, metrics
なお、最初の行でインポートした【matplotlib】ライブラリとは、Python用の【グラフ描画ライブラリ】です。このmatplotlibライブラリの中から、pyplotモジュールをインポートしました。
次に、インポートしたデータセットからチュートリアル用に用意された【手書き数字の画像】を取得します。
# データセットに用意された【手書き数字の画像データ】を読み込む
digits = datasets.load_digits()
# digits.images:手書き数字の画像データ
# digits.target:画像データが数字の何番を示すものであるかの正解ラベル
images_and_labels = list(zip(digits.images, digits.target))
# 取得した画像データを、実際の画像として出力する
for index, (image, label) in enumerate(images_and_labels[:4]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Training: %i' % label)
データを取得したら、次に分類を行うため、取得したデータの平坦化を行います。
# 【手書き数字の画像】を平坦化し、データを行列に変換する
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
ここで変換に用いているreshapeメソッドは、【配列を形状変換する】ことが出来ます。これを用いてデータを行列に変換します。
ここまでが機械学習を行うための分析対象の準備です。機械学習の準備がほぼ出来ましたので、次に平坦化を行ったデータをもとに、機械学習を行います。
# 分類の作成
classifier = svm.SVC(gamma=0.001)
# 分類のトレーニング用にデータの半分を用い、機械学習を行う
# (fitメソッドは学習を行う場合に使用する)
# fit_第一引数:学習用データ, 第二引数:結果
# データと結果から、学習を行う。
classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])
scikit-learnが提供する【fitメソッド】は、機械学習を実行するメソッドです。第一引数に分類対象のテストデータを、第二引数に分類した結果の正しい答えを指定することで、機械学習を行うことが出来ます。
fitメソッド完了時点で、学習は完了しています。そのため【機械学習をするだけ】であれば、この時点でもう目的は達成しています。
学習した結果、どのように手書き数字を判定・分類をできるようになったかについては次のコードでテストを行うことが出来ます。
predicted = classifier.predict(data[n_samples // 2:])
ここで行ったのは、機械学習の完了した状態で、未知の手書き数字データに対しての分類(数字が何番であるかの判定)処理の実行です。
この結果の正答率が高ければ【適切な機械学習に成功した】と言えますし、正答率が低い場合は【さらなる学習が必要】と言えるでしょう。
まとめ
このようにPythonの機械学習ライブラリ【scikit-learn】を用いると、様々な機械学習用のアルゴリズムを活用できます。
今回使用したチュートリアル用データセットは【手書き数字の画像】データでしたが、この他にもscikit-learnには様々なデータセットと、それを用いたチュートリアルが用意されています。
これから機械学習を学んでいきたいという方は、ぜひ他のデータセットも触れてみることをおすすめします。