AdaBoost:理論と実装
メニューを表示するにはスワイプしてください
AdaBoost(Adaptive Boosting)は、複数の弱学習器を組み合わせて強力な分類器を構築するアンサンブル学習手法。各弱学習器の学習後、AdaBoostは誤分類されたサンプルの重みを増やし、正しく分類されたサンプルの重みを減らすことで適応的に重み付けを行う。この適応的な重み付けにより、新しい学習器は前の学習器が苦手とした難しいケースにより注目し、繰り返しの中で誤りを修正することでアンサンブル全体の精度向上を実現する。
AdaBoostの主要な数式:重み付き誤差と学習器の重み
AdaBoostは各弱学習器ごとに2つの主要な計算を行う:
- 重み付き誤差(εt):現在の学習器が重み付けされた訓練データをどれだけ誤分類したかを測定;
- 学習器の重み(αt):現在の学習器が最終予測にどれだけ影響を与えるかを決定。
数式は以下の通り:
- 重み付き誤差:
- wi:サンプルiの重み;
- yi:サンプルiの正解ラベル;
- ht(xi):t番目の弱学習器によるサンプルiの予測;
- [yi=ht(xi)]:予測が誤りなら1、正解なら0。
学習器の重み:
αt=21ln(εt1−εt)- αt:t番目の学習器が最終予測で持つ発言力;
- 誤差(εt)が小さい学習器ほどαtが大きくなり、投票の重みが増す。
これらの計算により、新しい学習器は最も難しいサンプルに注目し、より正確な学習器がアンサンブルの出力に大きな影響を与える。
1234567891011121314151617181920212223242526272829303132333435363738from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_classification from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt # Generate a synthetic binary classification dataset X, y = make_classification(n_samples=500, n_features=10, n_informative=7, n_classes=2, random_state=42) # Split into training and test sets X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # Create a weak learner: a decision tree with max_depth=1 (stump) base_estimator = DecisionTreeClassifier(max_depth=1, random_state=42) # Fit AdaBoost with 50 weak learners and a learning rate of 1.0 ada = AdaBoostClassifier( estimator=base_estimator, n_estimators=50, learning_rate=1.0, random_state=42 ) ada.fit(X_train, y_train) # Predict and evaluate y_pred = ada.predict(X_test) # Plot staged learning curve (test accuracy at each boosting iteration) test_accuracies = [accuracy_score(y_test, y_pred_stage) for y_pred_stage in ada.staged_predict(X_test)] plt.figure(figsize=(8, 4)) plt.plot(range(1, len(test_accuracies) + 1), test_accuracies, marker='o') plt.title(f"AdaBoost Test Accuracy over Iterations (Average accuracy: {accuracy_score(y_test, y_pred):.2f})") plt.xlabel("Number of Weak Learners") plt.ylabel("Test Accuracy") plt.grid(True) plt.tight_layout() plt.show()
すべて明確でしたか?
フィードバックありがとうございます!
セクション 1. 章 9
AIに質問する
AIに質問する
何でも質問するか、提案された質問の1つを試してチャットを始めてください
セクション 1. 章 9