Bygga Linjär Regression med Statsmodels
I föregående kapitel använde vi en funktion från NumPy för att beräkna parametrarna.
Nu kommer vi att använda klassobjektet istället för funktionen för att representera den linjära regressionen. Detta tillvägagångssätt kräver fler kodrader för att hitta parametrarna, men det lagrar mycket användbar information i objektet och gör prediktionen mer direkt.
Bygga en linjär regressionsmodell
I statsmodels kan klassen OLS
användas för att skapa en linjär regressionsmodell.
Vi behöver först initiera ett OLS
-klassobjekt med hjälp av
sm.OLS(y, X_tilde)
.
Träna sedan modellen med metoden fit()
.
model = sm.OLS(y, X_tilde)
model = model.fit()
Vilket är ekvivalent med:
model = sm.OLS(y, X_tilde).fit()
Konstruktorn för klassen OLS
förväntar sig en specifik array X_tilde
som indata, vilket vi såg i Normalekvationen. Därför behöver du konvertera din X
-array till X_tilde
. Detta kan göras med funktionen sm.add_constant()
.
Hitta parametrar
När modellen är tränad kan du enkelt komma åt parametrarna med attributet params
.
12345678910111213141516import statsmodels.api as sm # import statsmodels import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables # Get the correct form of input for OLS X_tilde = sm.add_constant(X) # Initialize an OLS object regression_model = sm.OLS(y, X_tilde) # Train the object regression_model = regression_model.fit() # Get the paramters beta_0, beta_1 = regression_model.params print('beta_0 is: ', beta_0) print('beta_1 is: ', beta_1)
Göra förutsägelser
Nya instanser kan enkelt förutsägas med hjälp av metoden predict()
, men du måste även förbehandla indata för dessa:
1234567891011121314151617import statsmodels.api as sm import pandas as pd import numpy as np file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables X_tilde = sm.add_constant(X) # Preprocess regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Predict new values X_new = np.array([65,70,75]) # Feature values of new instances X_new_tilde = sm.add_constant(X_new) # Preprocess X_new y_pred = regression_model.predict(X_new_tilde) # Predict the target print(y_pred)
Hämta sammanfattningen
Som du förmodligen har märkt är det inte lika enkelt att använda klassen OLS
som funktionen polyfit()
. Men att använda OLS
har sina fördelar. Under träningen beräknas mycket statistisk information. Du kan komma åt informationen med hjälp av metoden summary()
.
123456789101112import statsmodels.api as sm import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Read the file X, y = df['Father'], df['Height'] X_tilde = sm.add_constant(X) # Preprocess X regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Print the summary print(regression_model.summary())
Det är mycket statistik. Vi kommer att diskutera tabellens viktigaste delar i senare avsnitt.
Tack för dina kommentarer!
Fråga AI
Fråga AI
Fråga vad du vill eller prova någon av de föreslagna frågorna för att starta vårt samtal
Awesome!
Completion rate improved to 5.26
Bygga Linjär Regression med Statsmodels
Svep för att visa menyn
I föregående kapitel använde vi en funktion från NumPy för att beräkna parametrarna.
Nu kommer vi att använda klassobjektet istället för funktionen för att representera den linjära regressionen. Detta tillvägagångssätt kräver fler kodrader för att hitta parametrarna, men det lagrar mycket användbar information i objektet och gör prediktionen mer direkt.
Bygga en linjär regressionsmodell
I statsmodels kan klassen OLS
användas för att skapa en linjär regressionsmodell.
Vi behöver först initiera ett OLS
-klassobjekt med hjälp av
sm.OLS(y, X_tilde)
.
Träna sedan modellen med metoden fit()
.
model = sm.OLS(y, X_tilde)
model = model.fit()
Vilket är ekvivalent med:
model = sm.OLS(y, X_tilde).fit()
Konstruktorn för klassen OLS
förväntar sig en specifik array X_tilde
som indata, vilket vi såg i Normalekvationen. Därför behöver du konvertera din X
-array till X_tilde
. Detta kan göras med funktionen sm.add_constant()
.
Hitta parametrar
När modellen är tränad kan du enkelt komma åt parametrarna med attributet params
.
12345678910111213141516import statsmodels.api as sm # import statsmodels import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables # Get the correct form of input for OLS X_tilde = sm.add_constant(X) # Initialize an OLS object regression_model = sm.OLS(y, X_tilde) # Train the object regression_model = regression_model.fit() # Get the paramters beta_0, beta_1 = regression_model.params print('beta_0 is: ', beta_0) print('beta_1 is: ', beta_1)
Göra förutsägelser
Nya instanser kan enkelt förutsägas med hjälp av metoden predict()
, men du måste även förbehandla indata för dessa:
1234567891011121314151617import statsmodels.api as sm import pandas as pd import numpy as np file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Open the file X, y = df['Father'], df['Height'] # Assign the variables X_tilde = sm.add_constant(X) # Preprocess regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Predict new values X_new = np.array([65,70,75]) # Feature values of new instances X_new_tilde = sm.add_constant(X_new) # Preprocess X_new y_pred = regression_model.predict(X_new_tilde) # Predict the target print(y_pred)
Hämta sammanfattningen
Som du förmodligen har märkt är det inte lika enkelt att använda klassen OLS
som funktionen polyfit()
. Men att använda OLS
har sina fördelar. Under träningen beräknas mycket statistisk information. Du kan komma åt informationen med hjälp av metoden summary()
.
123456789101112import statsmodels.api as sm import pandas as pd file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv' df = pd.read_csv(file_link) # Read the file X, y = df['Father'], df['Height'] X_tilde = sm.add_constant(X) # Preprocess X regression_model = sm.OLS(y, X_tilde) # Initialize an OLS object regression_model = regression_model.fit() # Train the object # Print the summary print(regression_model.summary())
Det är mycket statistik. Vi kommer att diskutera tabellens viktigaste delar i senare avsnitt.
Tack för dina kommentarer!