Linear Regression with Python
Building Linear Regression Using Statsmodels
In the previous chapter, we used a function from NumPy to calculate the parameters.
Now we will use a class instead of a function to build the linear regression. This approach takes more lines of code to find the parameters, but the fitted object stores a lot of helpful information and makes prediction more straightforward.
Building a Linear Regression model
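In statsmodels, the OLS class can be used to create a linear regression model.
We first need to initialize an OLS class object using sm.OLS(y, X_tilde). Then we train it using the fit() method, which returns a results object holding the fitted model.
The two steps can also be chained into a single call: sm.OLS(y, X_tilde).fit().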
Note
The constructor of the OLS class expects the array X_tilde as input, which we saw in the Normal Equation: the feature array X with a column of ones prepended for the intercept. So you need to convert your X array to X_tilde. You can do this with the sm.add_constant() function.
Finding parameters
When the model is trained, you can easily access the parameters using the params attribute.
import statsmodels.api as sm  # import statsmodels
import pandas as pd

file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv'
df = pd.read_csv(file_link)  # Open the file
X, y = df['Father'], df['Height']  # Assign the variables
# Get the correct form of input for OLS
X_tilde = sm.add_constant(X)
# Initialize an OLS object
regression_model = sm.OLS(y, X_tilde)
# Train the object
regression_model = regression_model.fit()
# Get the parameters
beta_0, beta_1 = regression_model.params
print('beta_0 is: ', beta_0)
print('beta_1 is: ', beta_1)
Making the predictions
New instances can easily be predicted using the predict() method, but their input must also be preprocessed with sm.add_constant():
import statsmodels.api as sm
import pandas as pd
import numpy as np

file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv'
df = pd.read_csv(file_link)  # Open the file
X, y = df['Father'], df['Height']  # Assign the variables
X_tilde = sm.add_constant(X)  # Preprocess
regression_model = sm.OLS(y, X_tilde)  # Initialize an OLS object
regression_model = regression_model.fit()  # Train the object
# Predict new values
X_new = np.array([65, 70, 75])  # Feature values of new instances
X_new_tilde = sm.add_constant(X_new)  # Preprocess X_new
y_pred = regression_model.predict(X_new_tilde)  # Predict the target
print(y_pred)
Getting the summary
As you probably noticed, using the OLS class takes more steps than the polyfit() function. But using OLS has its benefits: while training, it calculates a lot of statistical information, which you can view using the summary() method.
import statsmodels.api as sm
import pandas as pd

file_link = 'https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b22d1166-efda-45e8-979e-6c3ecfc566fc/simple_height_data.csv'
df = pd.read_csv(file_link)  # Read the file
X, y = df['Father'], df['Height']
X_tilde = sm.add_constant(X)  # Preprocess X
regression_model = sm.OLS(y, X_tilde)  # Initialize an OLS object
regression_model = regression_model.fit()  # Train the object
# Print the summary
print(regression_model.summary())
That's a lot of statistics. We will discuss the table's most important parts in later sections.