Python seaborn - Regression Plots

# Seaborn — Regression Plots
# Regression plot = scatter plot with a trend line drawn through the data
# Trend line shows the overall direction/pattern in the data
# Used to answer: "as X increases, what happens to Y?"
#
# Two main functions:
# regplot → single plot, simple to use
# lmplot → grid support, can split by category using col= row= hue=

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
# ── Sample data used throughout ───────────────────────────────────────────────
# Ten rows, four columns: three numeric measures plus a department label.
# The same frame feeds every plot below.
df = pd.DataFrame(
    dict(
        age=[22, 25, 28, 30, 33, 35, 38, 40, 43, 45],
        salary=[
            30000, 35000, 42000, 50000, 55000,
            62000, 70000, 75000, 82000, 90000,
        ],
        score=[90, 85, 80, 78, 74, 70, 65, 60, 55, 50],
        dept=[
            "HR", "HR", "IT", "IT", "HR",
            "Finance", "Finance", "IT", "HR", "Finance",
        ],
    )
)

# ══════════════════════════════════════════════════════════════════════════════
# ── 1. regplot — Scatter + Trend Line ────────────────────────────────────────
# regplot draws:
# dots → each data point
# line → best fit trend line through the data
# shaded band around line → confidence interval (how reliable the line is)
# wider band = less confident, narrower = more confident
# ══════════════════════════════════════════════════════════════════════════════

# Default call: regplot draws onto the current Axes and returns it, so the
# title can be set directly on the returned object.
sns.regplot(data=df, x="age", y="salary").set_title("Age vs Salary with Trend Line")
plt.show()  # Output: dots going up-right with a rising trend line — salary increases with age

# ── ci= — confidence interval band ───────────────────────────────────────────
# ci is the size (in percent) of the shaded confidence band around the line
# ci=95 (default) → 95% confidence band for the fitted line
# ci=None         → no band is drawn
sns.regplot(data=df, x="age", y="salary", ci=95).set_title(
    "regplot with 95% Confidence Band"
)
plt.show()  # Output: trend line with shaded band — shows uncertainty of the line

sns.regplot(data=df, x="age", y="salary", ci=None).set_title(
    "regplot without Confidence Band"
)
plt.show()  # Output: clean trend line and dots, no shaded area

# ── color and marker styling ──────────────────────────────────────────────────
# color= sets ONE color for both dots and line, but a "color" entry inside
# scatter_kws / line_kws takes precedence for that element. Both dicts set
# their own color here, so a top-level color= would have no visible effect
# and is deliberately left out.
sns.regplot(
    x="age",
    y="salary",
    data=df,
    scatter_kws={"color": "blue", "s": 80},  # scatter_kws — dot color and marker size
    line_kws={"color": "red", "linewidth": 2},  # line_kws — line color and width
)
plt.title("Styled regplot")
plt.show()  # Output: blue dots, red trend line

# ── negative trend — score decreases as age increases ────────────────────────
# Same call as before with a different y column; the fitted slope is negative
# because score falls as age rises in the sample data.
sns.regplot(data=df, x="age", y="score").set_title("Age vs Score — Negative Trend")
plt.show()  # Output: dots going down-right with a falling trend line — score drops with age

# ══════════════════════════════════════════════════════════════════════════════
# ── 2. lmplot — regplot with grid support ────────────────────────────────────
# lm = linear model
# lmplot works exactly like regplot BUT supports:
# hue= → separate trend line per category (different colors)
# col= → one plot per category, side by side
# row= → one plot per category, stacked
# ══════════════════════════════════════════════════════════════════════════════

# ── basic lmplot — same output as regplot ─────────────────────────────────────
# With no hue/col/row the grid holds a single facet, so plt.title() lands on it.
sns.lmplot(data=df, x="age", y="salary")
plt.title("Basic lmplot")
plt.show()  # Output: scatter + trend line — same as regplot

# ── hue= — separate trend line per department ────────────────────────────────
# one fitted line per unique value of the hue column, each in its own color,
# all drawn on the same axes
sns.lmplot(data=df, x="age", y="salary", hue="dept")
plt.title("Trend Line per Department")
plt.show()  # Output: dots and lines colored by dept — HR/IT/Finance each get their own line

# ── col= — one plot per department, side by side ─────────────────────────────
# suptitle() titles the whole figure; y=1.02 lifts it above the facet titles
sns.lmplot(data=df, x="age", y="salary", col="dept")
plt.suptitle("Age vs Salary — one plot per Dept", y=1.02)
plt.show()  # Output: 3 separate plots side by side — one trend line per dept

# ── col= + hue= — split by dept, color by dept ───────────────────────────────
# using the same column for col and hue gives every facet its own color
sns.lmplot(data=df, x="age", y="salary", hue="dept", col="dept")
plt.suptitle("Dept split + colored", y=1.02)
plt.show()  # Output: 3 plots side by side, each dept's line and dots in its own color

# ── height= and aspect= — control plot size ───────────────────────────────────
# height = height of EACH facet in inches; aspect = width / height of each facet.
# lmplot defaults are height=5, aspect=1, so values above those are needed to
# actually make the facets taller and wider than in the previous examples.
sns.lmplot(x="age", y="salary", data=df, col="dept", height=6, aspect=1.2)
plt.suptitle("Larger plots per Dept", y=1.02)
plt.show()  # Output: same 3 plots but each is taller and wider than the default
# ══════════════════════════════════════════════════════════════════════════════
# ── 3. residplot — Shows errors of the trend line ────────────────────────────
# residual = how far each actual data point is from the trend line
# dot above zero line → actual value is HIGHER than what line predicted
# dot below zero line → actual value is LOWER than what line predicted
# If dots are randomly scattered → trend line is a good fit
# If dots show a pattern → trend line is missing something
# ══════════════════════════════════════════════════════════════════════════════

# residplot fits a regression of y on x and plots the residuals; it returns
# the Axes it drew on, so the remaining decoration is applied to that object.
resid_ax = sns.residplot(data=df, x="age", y="salary", color="purple")
resid_ax.axhline(0, color="gray", linestyle="--")  # axhline — draws a horizontal line at 0
resid_ax.set_title("Residual Plot — How far each point is from the trend line")
resid_ax.set_xlabel("Age")
resid_ax.set_ylabel("Residual (actual − predicted)")
plt.show()  # Output: dots scattered above and below zero line — random = good fit
# ══════════════════════════════════════════════════════════════════════════════
# ── regplot vs lmplot ─────────────────────────────────────────────────────────
# ┌──────────────┬────────────────────────────────┬──────────────────────────────┐
# │ │ regplot │ lmplot │
# ├──────────────┼────────────────────────────────┼──────────────────────────────┤
# │ Basic use │ single scatter + trend line │ same as regplot │
# │ hue= │ not supported │ separate line per category │
# │ col= / row= │ not supported │ grid of plots by category │
# │ Best for │ quick single plot │ comparing groups / categories│
# └──────────────┴────────────────────────────────┴──────────────────────────────┘
# ══════════════════════════════════════════════════════════════════════════════

# ══════════════════════════════════════════════════════════════════════════════
# ── Quick Reference ───────────────────────────────────────────────────────────
# ══════════════════════════════════════════════════════════════════════════════

# sns.regplot(x=, y=, data=) → scatter + single trend line
# sns.regplot(..., ci=95) → confidence band (default 95%)
# sns.regplot(..., ci=None) → hide confidence band
# sns.regplot(..., scatter_kws={}) → style the dots (color, size)
# sns.regplot(..., line_kws={}) → style the line (color, width)
#
# sns.lmplot(x=, y=, data=) → same as regplot
# sns.lmplot(..., hue="col") → separate trend line per category
# sns.lmplot(..., col="col") → one plot per category (side by side)
# sns.lmplot(..., row="col") → one plot per category (stacked)
# sns.lmplot(..., height=4, aspect=1) → control size of each facet (height in inches, width = height × aspect; defaults: height=5, aspect=1)
#
# sns.residplot(x=, y=, data=) → shows how far each point is from trend line
# plt.axhline(0, linestyle="--") → draws zero reference line
#
# plt.show() → display the plot


No comments:

Post a Comment

Please comment below to give feedback or ask questions.