Logistic Regression in Python


Logistic regression is a statistical technique for predicting a binary outcome. It's a well-known method, widely used in disciplines ranging from credit and finance to medicine, criminology, and other social sciences. Logistic regression is fairly intuitive and very effective; you're likely to find it among the first few chapters of a machine learning or applied statistics book, and its usage is covered in many stats courses.
It's not hard to find quality logistic regression examples using R. This tutorial, for example, published by UCLA, is a great resource and one that I've consulted many times. Python is one of the most popular languages for machine learning, and while there are bountiful resources covering topics like Support Vector Machines and text classification using Python, there's far less material on logistic regression.
This is a post about using logistic regression in Python.
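To make "predicting a binary outcome" a little more concrete, here is a minimal sketch of the logistic function that gives the technique its name. It turns a weighted sum of predictors into a number between 0 and 1, which can be read as a probability and then thresholded into a yes/no prediction. The function names, the coefficient values, and the 0.5 cutoff are illustrative assumptions, not fitted estimates from any dataset.

```python
import math


def logistic(z):
    """Map a real number z to a value between 0 and 1."""
    # Branch on the sign of z so math.exp never receives a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_probability(intercept, coefs, features):
    """Compute the linear predictor and pass it through the logistic function."""
    z = intercept + sum(c * x for c, x in zip(coefs, features))
    return logistic(z)


# Hypothetical intercept and coefficient, chosen only so the arithmetic is easy:
# z = -1.0 + 0.5 * 2.0 = 0.0, and logistic(0.0) is exactly 0.5.
p = predict_probability(-1.0, [0.5], [2.0])

# Turn the probability into a binary prediction with an assumed 0.5 cutoff.
label = 1 if p >= 0.5 else 0
print(p, label)
```

Fitting a logistic regression amounts to estimating the intercept and coefficients from data; a library such as statsmodels does that part for you.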

Introduction

We'll use a few libraries in the code samples. Make sure you have these installed before you run through the code on your machine.
  • numpy: a library that defines the numerical array and matrix types
  • pandas: the primary package for handling and operating directly on data
  • statsmodels: a statistics & econometrics package with useful tools for parameter estimation & statistical testing
  • pylab: for generating plots (it ships as part of matplotlib)
Check out our post on Setting Up Scientific Python if you're missing one or more of these.
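If you're not sure what is already installed, a quick check along these lines may help. It is only a sketch: the helper name find_missing is made up for this example, and it uses the standard library's importlib.util.find_spec to see whether each package can be located, without actually importing it.

```python
import importlib.util

# The libraries used in this post. pylab is provided by matplotlib.
REQUIRED = ["numpy", "pandas", "statsmodels", "pylab"]


def find_missing(packages):
    """Return the names in `packages` that Python cannot locate for import."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


missing = find_missing(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are available.")
```

Anything reported as missing needs to be installed before the rest of the code will run.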

Example Use Case for Logistic Regression

We'll be using the same dataset as UCLA's Logit Regression in R tutorial to explore logistic regression in Python. Our goal will be to identify the various factors that may influence admission into graduate school.
The dataset contains several columns which we can use as predictor variables:
  • gpa
  • gre score
  • rank or prestige of an applicant's undergraduate alma mater
The fourth column, admit, is our binary target variable. It indicates whether or not a candidate was admitted.

Load the data

Load the data using pandas.read_csv. We now have a DataFrame and can explore the data.
import pandas as pd
import statsmodels.api as sm
import pylab as pl
import numpy as np

# read the data in
df = pd.read_csv("http://www.ats.ucla.edu/stat/data/binary.csv")

# take a look at the dataset
print(df.head())
#    admit  gre   gpa  rank
# 0      0  380  3.61     3
# 1      1  660  3.67     3
# 2      1  800  4.00     1
# 3      1  640  3.19     4
# 4      0  520  2.93     4

# rename the 'rank' column because there is also a DataFrame method called 'rank'
df.columns = ["admit", "gre", "gpa", "prestige"]
print(df.columns.tolist())
# ['admit', 'gre', 'gpa', 'prestige']
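With the data in a DataFrame, a few one-liners go a long way toward exploring it. The sketch below is one way to do that, under a couple of assumptions: so that it runs without a network connection, it rebuilds a tiny DataFrame from only the five rows shown in the df.head() output (the names SAMPLE_CSV and sample are made up for this example), and the summaries it prints describe those five rows, not the full dataset. Swap in the full df to get meaningful numbers.

```python
import io

import pandas as pd

# The five rows from the df.head() output, embedded so the example runs offline.
SAMPLE_CSV = """admit,gre,gpa,rank
0,380,3.61,3
1,660,3.67,3
1,800,4.00,1
1,640,3.19,4
0,520,2.93,4
"""

sample = pd.read_csv(io.StringIO(SAMPLE_CSV))

# Rename 'rank' to 'prestige', leaving the other column names untouched.
sample = sample.rename(columns={"rank": "prestige"})

# Summary statistics (count, mean, std, quartiles) for each numeric column.
print(sample.describe())

# How many applicants fall into each class of the binary target.
print(sample["admit"].value_counts())

# Admissions broken down by the prestige of the undergraduate school.
print(pd.crosstab(sample["admit"], sample["prestige"]))
```

Using rename with a mapping, rather than assigning to df.columns, is a design choice: it keeps working if the column order ever changes.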
