Insurance — линейная регрессия charges

Классическая учебная задача: по возрасту, BMI, курению и другим признакам предсказать charges — сумму медицинских расходов. Теория: Scikit-learn — регрессия.

Зависимости: pip install pandas scikit-learn matplotlib


1. Загрузка

import pandas as pd

url = "https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv"
df = pd.read_csv(url)
print(df.head())
print(df.dtypes)

2. Подготовка признаков

df = pd.get_dummies(df, columns=["sex", "smoker", "region"], drop_first=True)
X = df.drop(columns="charges")
y = df["charges"]
print(X.columns.tolist())

get_dummies превращает категории в 0/1 — см. кодирование признаков.


3. Train / test и модель

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

reg = LinearRegression()
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)

print("intercept_:", reg.intercept_)
for name, coef in zip(X.columns, reg.coef_):
    print(f"  {name:25s} {coef:+.2f}")

print("MAE:", mean_absolute_error(y_test, y_pred))
print("R²:", r2_score(y_test, y_pred))

Самопроверка: какой коэффициент по модулю больше всего влияет на charges? Ожидаемо сильный вклад smoker_yes.


4. График — charges vs age

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 5))
for smoker, group in df.groupby("smoker_yes"):
    label = "smoker" if smoker else "non-smoker"
    plt.scatter(group["age"], group["charges"], alpha=0.5, label=label, s=20)
plt.xlabel("age")
plt.ylabel("charges")
plt.legend()
plt.grid(alpha=0.3)
plt.title("Расходы vs возраст")
plt.show()

5. Сравнение с SGDRegressor (опционально)

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDRegressor

sgd = Pipeline([
    ("scaler", StandardScaler()),
    ("reg", SGDRegressor(max_iter=5000, tol=1e-3, random_state=42)),
])
sgd.fit(X_train, y_train)
print("SGD MAE:", mean_absolute_error(y_test, sgd.predict(X_test)))

Сравнение МНК и SGD — в 6-02/10.


Дальше