Insurance — линейная регрессия charges
Классическая учебная задача: по возрасту, BMI, курению и другим признакам предсказать charges — сумму медицинских расходов. Теория: Scikit-learn — регрессия.
Зависимости: pip install pandas scikit-learn matplotlib
1. Загрузка
import pandas as pd
url = "https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv"
df = pd.read_csv(url)
print(df.head())
print(df.dtypes)
2. Подготовка признаков
df = pd.get_dummies(df, columns=["sex", "smoker", "region"], drop_first=True)
X = df.drop(columns="charges")
y = df["charges"]
print(X.columns.tolist())
get_dummies превращает категории в 0/1 — см. кодирование признаков.
3. Train / test и модель
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
reg = LinearRegression()
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
print("intercept_:", reg.intercept_)
for name, coef in zip(X.columns, reg.coef_):
print(f" {name:25s} {coef:+.2f}")
print("MAE:", mean_absolute_error(y_test, y_pred))
print("R²:", r2_score(y_test, y_pred))
Самопроверка: какой коэффициент по модулю больше всего влияет на charges? Ожидаемо сильный вклад smoker_yes.
4. График — charges vs age
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 5))
for smoker, group in df.groupby("smoker_yes"):
label = "smoker" if smoker else "non-smoker"
plt.scatter(group["age"], group["charges"], alpha=0.5, label=label, s=20)
plt.xlabel("age")
plt.ylabel("charges")
plt.legend()
plt.grid(alpha=0.3)
plt.title("Расходы vs возраст")
plt.show()
5. Сравнение с SGDRegressor (опционально)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDRegressor
sgd = Pipeline([
("scaler", StandardScaler()),
("reg", SGDRegressor(max_iter=5000, tol=1e-3, random_state=42)),
])
sgd.fit(X_train, y_train)
print("SGD MAE:", mean_absolute_error(y_test, sgd.predict(X_test)))
Сравнение МНК и SGD — в 6-02/10.