Statsmodelsを使用した線形回帰の構築
メニューを表示するにはスワイプしてください
線形回帰モデルの構築
statsmodels では、OLS クラスを使用して線形回帰モデルを作成可能。
まず、OLS クラスのオブジェクトを初期化する必要がある。
sm.OLS(y, X_tilde)。
その後、fit() メソッドで学習を行う。
model = sm.OLS(y, X_tilde)
model = model.fit()
これは次のコードと同等。
model = sm.OLS(y, X_tilde).fit()
注意
OLS クラスのコンストラクタは、特定の配列 X_tilde を入力として要求する(これは正規方程式で確認済み)。そのため、X 配列を X_tilde に変換する必要がある。これは sm.add_constant() 関数を使うことで実現可能。
パラメータの取得
モデルが学習された後、params属性を使用してパラメータに簡単にアクセス可能。
123456789import statsmodels.api as sm import pandas as pd df = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv') X, y = df['Father'], df['Height'] X_tilde = sm.add_constant(X) model = sm.OLS(y, X_tilde).fit() beta_0, beta_1 = model.params print(beta_0, beta_1)
予測の実行
新しいインスタンスの予測はpredict()メソッドで簡単に実行可能。ただし、入力データも事前処理が必要。
12345import numpy as np X_new = np.array([65, 70, 75]) X_new_tilde = sm.add_constant(X_new) print(model.predict(X_new_tilde))
サマリーの取得
OLS クラスの使用は、polyfit() 関数ほど簡単ではないことにお気付きかもしれません。しかし、OLS を使用することで多くの利点があります。トレーニング中に多くの統計情報が計算されます。これらの情報には summary() メソッドを使ってアクセスできます。
1print(model.summary())
多くの統計情報が表示されています。表の最も重要な部分については、後のセクションで説明します。
すべて明確でしたか?
フィードバックありがとうございます!
セクション 1. 章 4
AIに質問する
AIに質問する
何でも質問するか、提案された質問の1つを試してチャットを始めてください
セクション 1. 章 4