Kursinnehåll
Linjär Regression med Python
Linjär Regression med Python
Bygga Linjär Regression med Statsmodels
I föregående kapitel använde vi en funktion från NumPy för att beräkna parametrarna.
Nu kommer vi att använda klassobjektet istället för funktionen för att representera den linjära regressionen. Detta tillvägagångssätt kräver fler kodrader för att hitta parametrarna, men det lagrar mycket användbar information i objektet och gör prediktionen mer direkt.
Bygga en linjär regressionsmodell
I statsmodels kan klassen OLS
användas för att skapa en linjär regressionsmodell.
Vi behöver först initiera ett OLS
-klassobjekt med hjälp av
sm.OLS(y, X_tilde)
.
Träna sedan modellen med metoden fit()
.
model = sm.OLS(y, X_tilde)
model = model.fit()
Vilket är ekvivalent med:
model = sm.OLS(y, X_tilde).fit()
Konstruktorn för klassen OLS
förväntar sig en specifik array X_tilde
som indata, vilket vi såg i Normalekvationen. Därför behöver du konvertera din X
-array till X_tilde
. Detta kan göras med funktionen sm.add_constant()
.
Hitta parametrar
När modellen är tränad kan du enkelt komma åt parametrarna med attributet params
.
import statsmodels.api as sm # import statsmodels import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables # Get the correct form of input for OLS X_tilde = sm.add_constant(X) # Initialize an OLS object regression_model = sm.OLS(y, X_tilde) # Train the object regression_model = regression_model.fit() # Get the paramters beta_0, beta_1 = regression_model.params print('beta_0 is: ', beta_0) print('beta_1 is: ', beta_1)
Göra förutsägelser
Nya instanser kan enkelt förutsägas med hjälp av metoden predict()
, men du måste även förbehandla indata för dessa:
import statsmodels.api as sm import pandas as pd import numpy as np file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables X_tilde = sm.add_constant(X) # Preprocess regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Predict new values X_new = np.array([65,70,75]) # Feature values of new instances X_new_tilde = sm.add_constant(X_new) # Preprocess X_new y_pred = regression_model.predict(X_new_tilde) # Predict the target print(y_pred)
Hämta sammanfattningen
Som du förmodligen har märkt är det inte lika enkelt att använda klassen OLS
som funktionen polyfit()
. Men att använda OLS
har sina fördelar. Under träningen beräknas mycket statistisk information. Du kan komma åt informationen med hjälp av metoden summary()
.
import statsmodels.api as sm import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Read the file X, y = df['Father'], df['Height'] X_tilde = sm.add_constant(X) # Preprocess X regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Print the summary print(regression_model.summary())
Det är mycket statistik. Vi kommer att diskutera tabellens viktigaste delar i senare avsnitt.
Tack för dina kommentarer!