分類するだけじゃなくて学習したデータを保存もしています。学習データがあったらデータを読み込んで利用します。
import os from sklearn.ensemble import RandomForestClassifier from sklearn.externals import joblib class RandomForest(object): def __init__(self): if os.path.exists('./.supervised/data.bin'): self._m = joblib.load('./.supervised/data.bin') else: _training_data = [ [2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2], ] _training_label = [1, 2, 3, 4] self._m = RandomForestClassifier() self._m.fit(_training_data, _training_label) joblib.dump(self._m, './.supervised/data.bin') def start(self): _test_data = [ [0, 4, 0, 0], [0, 0, 0, 2] ] _out = self._m.predict(_test_data) print(_out) if __name__ == "__main__": _obj = RandomForest() _obj.start()
結果出力
[2 4]