- Predicting whether a student will be admitted or not
- Logistic regression predicts the probability of an event occurring
- Predicting categorical outcomes through logistic regression:
    - yes or no
    - 0 or 1
    - will be admitted or not
- Create a scatter plot (with a logistic regression curve)
Note: the dependent variable is 'Admitted' and the independent variable is 'SAT'
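For reference, the curve drawn through the scatter plot is the logistic function of a fitted intercept and slope. A sketch of the model, writing b_0 for the constant and b_1 for the SAT coefficient (the same roles as `b0` and `b1` in the logit function defined further down):

```latex
P(\text{Admitted}=1 \mid \text{SAT}=x) = \frac{e^{b_0 + b_1 x}}{1 + e^{b_0 + b_1 x}},
\qquad
\log\frac{P}{1-P} = b_0 + b_1 x
```

The second form shows that the log-odds of admission are linear in the SAT score, which is why the probability itself follows an S-shaped curve bounded by 0 and 1.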
- Importing the relevant libraries
- Loading data
- Dummy Variable
- Declaring the dependent and independent variables
- Adding a Constant
- Creating a Logit Regression
- Fitting the Model
- Logistic Regression Curve
- Creating a Logit Function
- Sorting the y and x to Plot the Curve
- Plotting the Logistic Regression Curve
- Scatter Plot with the Logistic Regression Curve
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
url = 'https://datascienceschools.github.io/Machine_Learning/StatsModel/admission.csv'
df = pd.read_csv(url)
df.head()
- Replace all No entries with 0, and all Yes entries with 1
data = df.copy()
data['Admitted'] = data['Admitted'].map({'Yes': 1, 'No': 0})
data.head()
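One thing worth checking after the mapping: `map` turns any value that is not a key in the dictionary into NaN rather than raising an error, and missing values would cause trouble when fitting. A minimal sketch on a small made-up frame (the rows are hypothetical, not taken from admission.csv):

```python
import pandas as pd

# Hypothetical rows for illustration only; 'yes' (lowercase) is a deliberate typo
toy = pd.DataFrame({'SAT': [1400, 1700, 1900],
                    'Admitted': ['No', 'Yes', 'yes']})

mapped = toy['Admitted'].map({'Yes': 1, 'No': 0})

# Values missing from the dictionary become NaN instead of raising an error
n_unmapped = mapped.isna().sum()
print(mapped.tolist())
print('unmapped entries:', n_unmapped)
```

Running the same `isna().sum()` check on `data['Admitted']` would confirm that every entry was either Yes or No.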
y = data['Admitted']
x = data['SAT']
x_constant = sm.add_constant(x)
model = sm.Logit(y,x_constant)
results = model.fit()
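def f(x,b0,b1):
    # Logistic function: probability of admission for SAT score x
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
# .iloc gives positional access: params.iloc[0] is the constant, params.iloc[1] the SAT coefficient
f_sorted = np.sort(f(x,results.params.iloc[0],results.params.iloc[1]))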
x_sorted = np.sort(np.array(x))
plt.plot(x_sorted,f_sorted,color='red')
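Sorting the probabilities and the scores independently lines the points up correctly only because the fitted curve increases with SAT. A more general sketch sorts once with `np.argsort` and reuses that order for both arrays; the scores and coefficients below are made up for illustration, with a negative slope chosen to expose the difference:

```python
import numpy as np

def logistic(x, b0, b1):
    # Probability from the logistic function for intercept b0 and slope b1
    return 1 / (1 + np.exp(-(b0 + b1 * x)))

# Hypothetical scores and a hypothetical decreasing curve (b1 < 0)
x_demo = np.array([1800.0, 1400.0, 1650.0, 1550.0])
p_demo = logistic(x_demo, 33.0, -0.02)

# Sort once, then apply the same order to both arrays so pairs stay aligned
order = np.argsort(x_demo)
x_ordered = x_demo[order]
p_ordered = p_demo[order]

# Independent sorting pairs the smallest x with the smallest probability,
# which is wrong for a decreasing curve
p_independent = np.sort(p_demo)
print(np.allclose(p_ordered, p_independent))
```

For the admission model the two approaches agree as long as the SAT coefficient is positive, so the independent sort above is fine there.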
- Plotting the Logistic Regression Curve: plt.plot(x_sorted,f_sorted,color='red')
- when the score is relatively low -> the probability of getting admitted is close to 0
- when the score is relatively high -> the probability of getting admitted is close to 1
- a score between 1600 and 1750 is uncertain
- score 1650 -> 0.5, or a 50 percent chance of getting admitted
- score 1700 -> 0.8, or an 80 percent chance of getting admitted
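The 50 percent point can also be read off the coefficients directly: the logistic function equals 0.5 where b0 + b1*x = 0, that is at x = -b0/b1. The sketch below uses placeholder coefficients (assumed for illustration, not the fitted values); with the fitted model they would come from `results.params`.

```python
import numpy as np

def admission_probability(sat, b0, b1):
    # Logistic function: probability of admission for a given SAT score
    return 1 / (1 + np.exp(-(b0 + b1 * sat)))

# Placeholder coefficients, NOT the fitted values from the model above
b0_example, b1_example = -66.0, 0.04

# Score at which the predicted probability is 0.5
midpoint = -b0_example / b1_example
print('50% score:', midpoint)

# Probabilities at a few hypothetical scores around the midpoint
for score in (1500, 1650, 1800):
    prob = admission_probability(score, b0_example, b1_example)
    print(score, round(float(prob), 3))
```

The further a score lies from the midpoint, the closer the predicted probability gets to 0 or 1, which matches the flat ends of the curve.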
plt.scatter(x,y,color='C0')
plt.xlabel('SAT', fontsize = 20)
plt.ylabel('Admitted', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='red')
plt.show()