Як використовувати розповсюдження міток Scikit-Learn для структурованих даних графіків?


11

У рамках мого дослідження мені цікаво розповсюдження міток на графіку. Мене особливо цікавлять ці два методи:

  • Сяоїн Чжу та Зубін Гахрамані. Навчання з мічених та не маркованих даних із розповсюдженням міток. Технічний звіт CMU-CALD-02-107, Університет Карнегі Меллон, 2002 http://pages.cs.wisc.edu/~jerryzhu/pub/CMU-CALD-02-107.pdf
  • Денджонг Чжоу, Олів'є Буске, Томас Навін Лал, Джейсон Вестон, Бернхард Шелкопфф. Навчання з місцевою та глобальною послідовністю (2004) http://citeseer.ist.psu.edu/viewdoc/summary?doi=10.1.1.115.3219

Я бачив, що scikit-learn пропонують модель зробити це. Однак ця модель повинна застосовуватися для векторних структурованих даних ( тобто точок даних).

Модель будує матрицю спорідненості з точок даних за допомогою ядра, а потім запускає алгоритм на побудованій матриці. Я хотів би мати можливість безпосередньо ввести матрицю суміжності мого графіка замість матриці подібності.

Будь-яка ідея, як цього досягти? Або вам відома будь-яка бібліотека Python, яка дозволить запускати розповсюдження міток безпосередньо на структурованих графіком даних для двох вищезгаданих методів?

Заздалегідь дякую за вашу допомогу!


Ви перевірили вихідний код Scikit-learn, щоб побачити, що він робить після обчислення матриці спорідненості? Можливо, ви могли б "скопіювати" код після цієї частини, щоб застосувати його безпосередньо до вашої матриці суміжності.
Тасос

Дякуємо за Ваш коментар! Тож насправді це я зараз роблю, але деякі частини коду, які мені потрібно змінити, щоб відповідати моїм потребам, є дещо виразними. Боюся, що перезапис цих частин призведе до помилок. Я сподівався, що існує більш простий метод.
Тібо Мартінес

1
Вихідний код на веб- сайті github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/… - говорить, що реалізація повинна замінити метод _build_graph. Тому практично слід спробувати створити похідний клас, який приймає попередньо обчислену матрицю.
mikolai

Відповіді:


2

Відповідаючи тут на власне запитання, я сподіваюся, що він буде корисний деяким читачам.

Scikit-learn призначений насамперед для роботи з векторними структурованими даними. Отже, якщо ви хочете виконати розповсюдження міток / розповсюдження міток на даних, структурованих за графами, вам, ймовірно, краще самостійно реалізовувати метод, а не використовувати інтерфейс Scikit.

Ось реалізація розповсюдження міток та розповсюдження міток у PyTorch.

Дві методи в цілому дотримуються одних і тих же алгоритмічних кроків, з варіаціями того, як нормалізується матриця суміжності та як розповсюджуються мітки на кожному кроці. Отже, давайте створимо базовий клас для наших двох моделей.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

Модель бере на вхід матрицю примикання графіка, а також мітки вузлів. Мітки мають форму вектора цілого числа, що вказує номер класу кожного вузла з -1 у положенні немечених вузлів.

Алгоритм розповсюдження міток представлений нижче.

W: adjacency matrix of the graph Compute the diagonal degree matrix D by DiijWij Initialize Y^(0)(y1,,yl,0,0,,0) Iterate  1. Y^(t+1)D1WY^(t) 2. Y^l(t+1)Yl until convergence to Y^() Label point xi by the sign of y^i()

Від Сяоїна Чжу та Зубіна Гахрамані. Навчання з мічених та не маркованих даних із розповсюдженням міток. Технічний звіт CMU-CALD-02-107, Університет Карнегі Меллон, 2002

Ми отримуємо таку реалізацію.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

Алгоритм розповсюдження міток:

W: adjacency matrix of the graph Compute the diagonal degree matrix D by DiijWij Compute the normalized graph Laplacian LD1/2WD1/2 Initialize Y^(0)(y1,,yl,0,0,,0) Choose a parameter α[0,1) Iterate Y^(t+1)αLY^(t)+(1α)Y^(0) until convergence to Y^() Label point xi by the sign of y^i()

З Дендьонг Чжоу, Олів'є Буске, Томаса Навіна Лал, Джейсона Вестона, Бернгарда Шелкопфа. Навчання з місцевою та глобальною послідовністю (2004)

Отже, реалізація полягає в наступному.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Тепер перевіримо наші моделі поширення на синтетичних даних. Для цього ми обираємо графік печерного людини .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

Реалізовані моделі працюють правильно і дозволяють виявляти спільноти на графіку.

Реалізація розповсюдження міток та розповсюдження міток, протестовані на графіку печерних людей

Примітка. Представлені методи поширення мають використовуватися на ненаправлених графіках.

Код доступний як інтерактивний ноутбук Юпітера тут .

Біндер

Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.