Зміст курсу
Лінійна Регресія з Python
Лінійна Регресія з Python
Побудова Лінійної Регресії за Допомогою Statsmodels
У попередньому розділі ми використовували функцію з NumPy для обчислення параметрів.
Тепер ми використаємо клас-об'єкт замість функції для представлення лінійної регресії. Цей підхід потребує більше рядків коду для знаходження параметрів, але зберігає багато корисної інформації всередині об'єкта та робить прогнозування більш зрозумілим.
Побудова моделі лінійної регресії
У statsmodels для створення моделі лінійної регресії можна використовувати клас OLS
.
Спочатку потрібно ініціалізувати об'єкт класу OLS
за допомогою
sm.OLS(y, X_tilde)
.
Потім навчити його за допомогою методу fit()
.
model = sm.OLS(y, X_tilde)
model = model.fit()
Що еквівалентно:
model = sm.OLS(y, X_tilde).fit()
Конструктор класу OLS
очікує на вхід певний масив X_tilde
, який ми розглядали у нормальному рівнянні. Тому необхідно перетворити масив X
на X_tilde
. Це можна зробити за допомогою функції sm.add_constant()
.
Знаходження параметрів
Після навчання моделі параметри можна легко отримати за допомогою атрибута params
.
import statsmodels.api as sm # import statsmodels import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables # Get the correct form of input for OLS X_tilde = sm.add_constant(X) # Initialize an OLS object regression_model = sm.OLS(y, X_tilde) # Train the object regression_model = regression_model.fit() # Get the paramters beta_0, beta_1 = regression_model.params print('beta_0 is: ', beta_0) print('beta_1 is: ', beta_1)
Виконання прогнозів
Нові екземпляри можна легко передбачити за допомогою методу predict()
, але для них також потрібно виконати попередню обробку вхідних даних:
import statsmodels.api as sm import pandas as pd import numpy as np file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables X_tilde = sm.add_constant(X) # Preprocess regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Predict new values X_new = np.array([65,70,75]) # Feature values of new instances X_new_tilde = sm.add_constant(X_new) # Preprocess X_new y_pred = regression_model.predict(X_new_tilde) # Predict the target print(y_pred)
Отримання підсумкової інформації
Як ви, ймовірно, помітили, використання класу OLS
не таке просте, як функції polyfit()
. Проте використання OLS
має свої переваги. Під час навчання він обчислює багато статистичної інформації. Доступ до цієї інформації можна отримати за допомогою методу summary()
.
import statsmodels.api as sm import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Read the file X, y = df['Father'], df['Height'] X_tilde = sm.add_constant(X) # Preprocess X regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Print the summary print(regression_model.summary())
Це велика кількість статистичних даних. Найважливіші частини цієї таблиці будуть розглянуті у наступних розділах.
Дякуємо за ваш відгук!