Pythonでscikit-learnを使ってランダムフォレストの分類器を書いて、学習結果を保存してみた

2015-10-01
このエントリーをはてなブックマークに追加

分類するだけじゃなくて学習したデータを保存もしています。学習データがあったらデータを読み込んで利用します。

import os

from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib


class RandomForest(object):
    def __init__(self):
        if os.path.exists('./.supervised/data.bin'):
            self._m = joblib.load('./.supervised/data.bin')

        else:
            _training_data = [
                [2, 0, 0, 0],
                [0, 2, 0, 0],
                [0, 0, 2, 0],
                [0, 0, 0, 2],
            ]
            _training_label = [1, 2, 3, 4]

            self._m = RandomForestClassifier()
            self._m.fit(_training_data, _training_label)

            joblib.dump(self._m, './.supervised/data.bin')

    def start(self):
        _test_data = [
            [0, 4, 0, 0],
            [0, 0, 0, 2]
        ]
        _out = self._m.predict(_test_data)

        print(_out)

if __name__ == "__main__":
    _obj = RandomForest()
    _obj.start()

結果出力
[2 4]