Визуализация регрессионных моделей

Давайте сначала выполним импортирование всех пакетов, которые могут нам понадобиться в дальнейшей работе:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

И сразу загрузим данные, на основе которых будем делать примеры:

tips = sns.load_dataset("tips")
tips.head()
total_bill tip sex smoker day time size
0 16.99 1.01 Female No Sun Dinner 2
1 10.34 1.66 Male No Sun Dinner 3
2 21.01 3.50 Male No Sun Dinner 3
3 23.68 3.31 Male No Sun Dinner 2
4 24.59 3.61 Female No Sun Dinner 4

Это данные о размере чаевых в ресторане, подробное описание которых можно посмотреть здесь.

Для визуализации регрессионных моделей в Seaborn есть две функции: regplot() и lmplot(). Первая строит графики на уровне области Axes и может принимать в качестве данных любые форматы последовательности чисел (списки и кортежи Python, массивы NumPy, серии Pandas):

sns.regplot(x = list(range(100)),
            y = 0.01*np.arange(100) + np.random.rand(100));

График линейной регрессии, построенный в Seaborn с помощью функции regplot

Эта функция может оказаться очень удобной, если нужно быстро взглянуть на какие-то данные, которые часто получаются в некоторых промежуточных вычислениях. А вот функция lmplot() принимает данные только в виде объектов DataFrame, а в качестве переменных x и y строки имен соответствующих столбцов:

sns.lmplot(x='total_bill',
           y='tip',
           data=tips);

График линейной регрессии, построенный в Seaborn с помощью функции lmplot

Так же, функция lmplot() обладает большей функциональностью (простите за тафтологию), которую мы продемонстрируем в дальнейшем. А пока давайте двигаться дальше.

Иногда, данные по одной из осей могут принимать дискретное значение, в этих случаях можно так же построить линейную регрессию, но точки будут сливаться вместе:

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x='size',
            y='total_bill',
            data=tips);

График линейной регрессии для дискретных данных, построенный в Seaborn с помощью функции regplot

Как видите, чем больше людей за столиком, тем больше общая стоимость заказа. Но улучшить восприятие данного графика можно с помощью параметра x_jitter который добавляет точкам небольшое горизонтальное рассеяние:

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x='size',
            y='total_bill',
            x_jitter=0.05,
            data=tips);

График линейной регрессии для дискретных данных с параметром x_jitter, построенный в Seaborn с помощью функции regplot

Еще, вместо добавления горизонтального "дрожания" ко всем точкам внутри каждой категории, можно отметить одной точкой их среднее значение вместе с доверительным интервалом этого значения:

sns.set_style('whitegrid')
f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x='size',
            y='total_bill',
            x_estimator=np.mean,
            data=tips);

График линейной регрессии для дискретных данных с параметром x_estimator, построенный в Seaborn с помощью функции regplot

В данном случае мы еще добавили команду sns.set_style('whitegrid') благодаря которой улучшилось восприятие графика.

Если в данных наблюдается полиномиальная тенденция, то построить соответствующую линию регрессии можно с помощью параметра order:

x_data = np.linspace(0, 10, 100)
y_data = -0.5*x_data**2 +3*x_data + 2 + 10*np.random.rand(100)

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=x_data,
            y=y_data,
            order=2,
            scatter_kws={'s': 20});

График полиномиальной регрессии, построенный в Seaborn с помощью функции regplot

Еще бывает так, что среди данных попадается несколько аномально больших или аномально маленьких значений, которые могут довольно сильно влиять на регрессионую модель:

x_data = np.linspace(0, 10, 20)
y_data = -0.5*x_data +3 + np.random.rand(20)

# добавим несколько аномальных значений:
y_data[[-3, -4]] += 5

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=x_data,
            y=y_data,
            scatter_kws={'s': 20});

График линейной регрессии для данных с выбросами без параметра robust, построенный в Seaborn с помощью функции regplot

В таком случае, с помощью параметра robust можно установить устойчивую к выбросам модель регрессии (необходим пакет statsmodels):

x_data = np.linspace(0, 10, 20)
y_data = -0.5*x_data +3 + np.random.rand(20)

# добавим несколько аномальных значений:
y_data[[-3, -4]] += 5

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=x_data,
            y=y_data,
            robust=True,
            scatter_kws={'s': 20});

График линейной регрессии для данных с выбросами с параметром robust, построенный в Seaborn с помощью функции regplot

Линия регрессии теперь выглядит лучше, но вот доверительный интервал по прежнему учитывает выбросы. К сожалению, привести в порядок доверительный интервал только лишь силами seaborn не получится, но его всегда можно убрать, приравняв параметр ci значению None:

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=x_data,
            y=y_data,
            robust=True,
            ci=None,
            scatter_kws={'s': 20});

График линейной регрессии для данных с выбросами с параметром robust без доверительного интервала, построенный в Seaborn с помощью функции regplot

Довольно часто значения могут разделяться на два подмножества (две категории, два класса), в этих случаях линейная регрессия так же работает и по сути пытается разделить эти множества. Давайте создадим какие-нибудь искусственные данные для примера:

v = np.random.randint(0, 2, 100)
t = np.array([np.random.normal(10, 3) if i == 0 
              else np.random.normal(20, 6) for i in v])

А теперь взглянем как это выглядит:

График линейной регрессии для бинарных данных, построенный в Seaborn с помощью функции regplot

Такие данные могут иметь реальный контекст, например, по оси \(x\) может быть отмечено время пребывания пользователя на страницах интернет-магазина, а по оси \(y\) может быть отмечен факт совершения покупки: 1 - покупка совершена, 0 - нет. В таком случае лучше воспользоваться логистической регрессией, которая позволит оценить вероятность совершения покупки:

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=t,
            y=v,
            logistic=True);

График логистической регрессии для бинарных данных, построенный в Seaborn с помощью функции regplot

На графике видно, что если продолжительность посещения интернет-магазина больше 17 минут, то вероятность совершения покупки больше 0.8

Наверное вы обратили внимание, на то что построение логистической регрессии и регрессии устойчивой к выбросом выполняется гораздо дольше чем для обычной линейной. Это связано с более сложными вычислениями, к тому же, доверительный интервал строится bootstrap-методом, что тоже занимает довольно много времени. Если график строится очень долго, то можно вообще отключить вычисление и нанесение на график доверительного итервала, используя ci=None.

Еще один подход, заключается в построении непараметрической регрессии, которая позволяет выявлять более сложные регрессионные зависимости:

fig = plt.figure()

ax_1 = fig.add_subplot(1, 2, 1)
ax_2 = fig.add_subplot(1, 2, 2)

sns.regplot(x='total_bill', 
           y='tip',
           data=tips,
           lowess=False,
           scatter_kws={"color": "0.4"},
           line_kws={"color": "red"},
           ax=ax_1);

sns.regplot(x='total_bill',
            y='tip',
            data=tips,
            lowess=True,
            scatter_kws={"color": "0.4"},
            line_kws={"color": "red"},
            ax=ax_2);

ax_1.set_title('Линейная регрессия',
               fontsize=15)
ax_2.set_title('Непараметрическая регрессия',
               fontsize=15)

fig.set_figwidth(14)
fig.set_figheight(7)

plt.show()

График непараметрической регрессии, построенный в Seaborn с помощью функции regplot

Однако, вычисление непараметрической регрессии требует еще больше вычислительных ресурсов, поэтому доверительные интервалы для нее вообще не вычисляются.

Еще одним полезным инструментом является функция residplot(), которая позволяет визуально оценить приемлемость простой линейной регрессии. Для примера давайте вернемся к вышесгенирированным данным:

x_data = np.linspace(0, 10, 20)
y_data = -0.5*x_data +3 + np.random.rand(20)


f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=x_data,
            y=y_data,
            scatter_kws={'s': 20});

График линейной регрессии, который является приемлемым для данных, построенный в Seaborn с помощью функции regplot

Функция residplot() изображает линию регрессии в горизонтальном виде, а остатки в виде точек и если эти точки равномерно сгруппированы по обе стороны прямой, то линейную модель можно считать приемлемой:

f, ax = plt.subplots(figsize=(14, 7))

sns.residplot(x=x_data,
             y=y_data);

График residplot(), показывающий, что линейная регрессия является приемлемой для некоторых данных, построенный в Seaborn

А теперь давайте сгенерируем другие данные:

x_data = np.linspace(0, 10, 100)
y_data = -0.5*x_data**2 +3*x_data + 2 + 10*np.random.rand(100)

f, ax = plt.subplots(figsize=(14, 7))

sns.regplot(x=x_data,
            y=y_data,
            scatter_kws={'s': 20});

График линейной регрессии, который является неприемлемым для данных, построенный в Seaborn с помощью функции regplot

И взглянем на остатки:

f, ax = plt.subplots(figsize=(14, 7))

sns.residplot(x=x_data,
              y=y_data);

График residplot(), показывающий, что линейная регрессия является неприемлемой для данных, построенный в Seaborn

Если в остатках видна структура (а она видна), то линейная модель не является приемлемой. Например, для зависимости размера чаевых от общей стоимости заказа мы можем заключить, что линейная регрессия является вполне приемлемой:

f, ax = plt.subplots(figsize=(14, 7))

sns.residplot(x=tips['total_bill'],
              y=tips['tip']);

График residplot() для реальных данных, построенный в Seaborn

Однако, на этом же графике видно, что множество точек выстраиваются в наклонные линии, поэтому приемлемость линейной регрессии под вопросом.


Выделение подмножеств значений

Теперь давайте рассмотрим функцию lmplot(), которая позволяет визульно выделять подмножества значений переменных и строить многосегментные (решетчатые) графики.

Мы уже видели, как строилась линейная регрессия для зависимости размера чаевых от общей стоимости заказа:

sns.regplot(x='total_bill', 
           y='tip',
           data=tips,);

График линейной регрессии, построенный с помощью функции regplot() в Seaborn

А вот как данный график может выглядеть если мы воспользуемся функцией lmplot() с параметром hue:

sns.lmplot(x='total_bill',
           y='tip',
           hue='sex',
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() с параметром hue в Seaborn

В данном случае мы выделили цветом два подмножества: мужчин и женщин оплачивавших заказ и решавших сколько оставить чаевых. Как видите и те и другие, похоже руководствуются одними и теми же принципами при определении размера чаевых. Различие подмножеств можно сделать еще сильнее если изменить вид маркеров и установить необходимую палитру цветов:

sns.lmplot(x='total_bill',
           y='tip',
           hue='sex',
           markers=['o', 'x'],
           palette='Dark2',
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() с параметрами markers и palette в Seaborn

Но самый надежный способ улучшения восприятия информации о подмножествах - это построение многосегментных графиков. Например, если имя столбца с информацией о подмножествах передать параметру col, то графики для каждого подмножества будут разделены по разным столбцам:

sns.lmplot(x='total_bill',
           y='tip',
           col='sex',
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() с параметром col в Seaborn

Более того, параметр row позволяет добавить еще один уровень, для другой переменной. Например вот так можно посмотреть на зависимость размера чаевых от пола и времени дня:

sns.lmplot(x='total_bill',
           y='tip',
           col='sex',
           row='time',
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() с параметрами col и row в Seaborn

При использовании lmplot() может возникнуть необходимость в настройке размеров графика, а так же более компактного размещения подграфиков. Например, вот так будет выглядеть график для размера чаевых в зависимости от заказа и количества людей за столиком, без каких-либо настроек:

sns.lmplot(x='total_bill',
           y='tip',
           col='size',
           ci=None,
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() без настроек размера в Seaborn

Управлять размером многосегментного графика можно с помощью параметров height и aspect:

sns.lmplot(x='total_bill',
           y='tip',
           col='size',
           ci=None,
           height=4,
           aspect=0.5,
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() с настройками размера в Seaborn

Но если количество категорий слишком велико и получается много колонок, то лучше организовать их перенос на следующую строку с помощью параметра col_wrap:

sns.lmplot(x='total_bill',
           y='tip',
           col='size',
           ci=None,
           col_wrap=3,
           height=4,
           aspect=1.1,
           data=tips);

График линейной регрессии, построенный с помощью функции lmplot() с настройками размера и переносом графиков на следующую строку в Seaborn

Если воспользоваться функцией jointplot() то можно изобразить регрессию с дополнительной информацией для каждой переменной:

sns.jointplot(x='size',
              y='tip',
              data=tips,
              kind='reg');

График линейной регрессии, построенный с помощью функции jointplot() в Seaborn