Semi-Supervised Learning


Semi-supervised learning is a technique that combines a small labeled dataset with a large unlabeled dataset to improve model performance. It infers labels for the unlabeled data from the way the labeled classes are structured in the feature space.

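Semi-supervised learning assumes that the label classes exhibit clustering or some other recognizable structure in the feature space, so that points lying close together (or in the same cluster) are likely to share a label.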

img source: https://www.altexsoft.com/blog/semi-supervised-learning/

Key Principle

Instead of labeling the entire dataset, you hand-label only a small part of it, train a model on that labeled subset, and then apply the model to the ocean of unlabeled data.

Techniques:

Self-Training

img source: https://www.altexsoft.com/blog/semi-supervised-learning/

  1. Train a base model on the small labeled dataset using a standard supervised method.
  2. Apply pseudo-labeling: use the partially trained model to predict labels for the unlabeled data.
  3. Keep only the predictions whose confidence exceeds a chosen threshold, and add those examples, with their pseudo-labels, to the labeled dataset.
  4. Retrain the model on the combined labeled and pseudo-labeled data.
  5. Repeat steps 2 to 4, adding more pseudo-labels at each iteration, until no remaining prediction passes the threshold or a fixed number of iterations is reached.
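The self-training loop above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not a reference implementation: the base model is a hypothetical nearest-centroid classifier (`fit_centroids`), the confidence score is an assumed softmax over negative distances to the class centroids (`predict_with_confidence`), and the threshold of 0.9 and the toy points are arbitrary choices for the example.

```python
import math


def fit_centroids(points, labels):
    """Hypothetical stand-in model: the mean feature vector of each class."""
    groups = {}
    for x, y in zip(points, labels):
        groups.setdefault(y, []).append(x)
    return {y: [sum(col) / len(xs) for col in zip(*xs)] for y, xs in groups.items()}


def predict_with_confidence(model, x):
    """Return (label, confidence); confidence is an assumed softmax over negative distances."""
    weights = {
        y: math.exp(-math.sqrt(sum((a - b) ** 2 for a, b in zip(x, c))))
        for y, c in model.items()
    }
    label = max(weights, key=weights.get)
    return label, weights[label] / sum(weights.values())


def self_train(labeled_x, labeled_y, unlabeled_x, threshold=0.9, max_iter=10):
    xs, ys = list(labeled_x), list(labeled_y)
    pool = list(unlabeled_x)
    model = fit_centroids(xs, ys)  # Step 1: supervised base model
    for _ in range(max_iter):
        # Step 2: pseudo-label every remaining unlabeled point.
        preds = [predict_with_confidence(model, x) for x in pool]
        # Step 3: keep only predictions at or above the confidence threshold.
        confident = {i for i, (_, conf) in enumerate(preds) if conf >= threshold}
        if not confident:
            break  # nothing new passes the threshold (also covers an empty pool)
        for i in sorted(confident):
            xs.append(pool[i])
            ys.append(preds[i][0])
        pool = [x for i, x in enumerate(pool) if i not in confident]
        # Step 4: retrain on labeled + pseudo-labeled data; the loop is step 5.
        model = fit_centroids(xs, ys)
    return model, ys, pool


model, ys, remaining = self_train(
    labeled_x=[(0, 0), (10, 10)],
    labeled_y=["a", "b"],
    unlabeled_x=[(1, 0), (0, 1), (9, 10), (10, 9), (5, 5)],
)
print(ys)         # ['a', 'b', 'a', 'a', 'b', 'b']
print(remaining)  # [(5, 5)]
```

The point (5, 5) sits exactly between the two classes, so its confidence stays near 0.5 and it is never pseudo-labeled; this is the threshold doing its job of filtering out uncertain predictions.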

Co-Training

img source: https://www.altexsoft.com/blog/semi-supervised-learning/

  1. Split the features into two views: distinct feature subsets that are each ideally sufficient to predict the label and, in the original formulation by Blum & Mitchell, conditionally independent given the class. Train a separate classifier on each view using the small labeled dataset.
  2. Use each classifier to generate pseudo-labels for the larger pool of unlabeled data.
  3. Keep only the pseudo-labels with the highest confidence.
  4. Update each classifier using the confident pseudo-labels assigned by the other classifier, so that each view teaches the other.
  5. Repeat steps 2 to 4, growing the labeled dataset with pseudo-labeled examples drawn from the unlabeled data.
  6. Combine the predictions of the two updated classifiers to obtain the final classification result.
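A minimal sketch of the co-training steps, under several assumptions: each example is a `(view1, view2)` pair of feature vectors, each view's classifier is a hypothetical nearest-centroid model (`fit`, `proba`), confidence is an assumed softmax over negative distances, and the final combination (`predict`) simply sums the two views' class probabilities. It follows the "each classifier learns from the other's confident labels" reading of step 4; Blum & Mitchell's paper instead grows a shared labeled set, so treat this as one illustrative variant.

```python
import math


def fit(points, labels):
    """Hypothetical per-view model: mean feature vector per class."""
    groups = {}
    for x, y in zip(points, labels):
        groups.setdefault(y, []).append(x)
    return {y: [sum(col) / len(xs) for col in zip(*xs)] for y, xs in groups.items()}


def proba(model, x):
    """Assumed confidence: softmax over negative distances to the class centroids."""
    w = {y: math.exp(-math.sqrt(sum((a - b) ** 2 for a, b in zip(x, c))))
         for y, c in model.items()}
    total = sum(w.values())
    return {y: v / total for y, v in w.items()}


def co_train(labeled, labels, unlabeled, threshold=0.9, max_iter=10):
    """Each item is a (view1, view2) pair of feature vectors."""
    # Separate training sets per view, so pseudo-labels can cross between views.
    train = [([item[v] for item in labeled], list(labels)) for v in (0, 1)]
    models = [fit(*train[0]), fit(*train[1])]
    pool = list(unlabeled)
    for _ in range(max_iter):
        used = set()
        for v in (0, 1):
            other = 1 - v
            for i, item in enumerate(pool):
                if i in used:
                    continue
                p = proba(models[v], item[v])
                label = max(p, key=p.get)
                if p[label] >= threshold:
                    # A confident label from view v becomes training data for the other view.
                    train[other][0].append(item[other])
                    train[other][1].append(label)
                    used.add(i)
        if not used:
            break
        pool = [item for i, item in enumerate(pool) if i not in used]
        models = [fit(*train[0]), fit(*train[1])]  # retrain both classifiers
    return models, pool


def predict(models, item):
    """Final result: sum the two views' class probabilities and take the best class."""
    p0, p1 = proba(models[0], item[0]), proba(models[1], item[1])
    return max(p0, key=lambda y: p0[y] + p1[y])


models, remaining = co_train(
    labeled=[((0, 0), (0, 0)), ((10, 10), (10, 10))],
    labels=["a", "b"],
    unlabeled=[((1, 0), (5, 5)), ((5, 5), (10, 9))],
)
print(models[0]["b"], models[1]["a"])     # [7.5, 7.5] [2.5, 2.5]
print(predict(models, ((1, 1), (1, 1))))  # a
```

In this toy run, each unlabeled item is ambiguous in one view but clear in the other, which is exactly the case co-training exploits: the confident view supplies the label that the uncertain view could not infer on its own.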

Label Propagation

Graph-based label propagation

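  1. Represent the data as a graph in which each node is a data point and edges connect similar points. Most nodes are unlabeled; only a few carry known labels.
  2. Propagate the known labels through the network along the paths (edges) connecting the data points.
  3. Assign each unlabeled point the label it is most strongly connected to, for example by counting the paths that lead from it to nodes of each label. In the iterative algorithm described by Zhu & Ghahramani, labels spread to neighbors in proportion to edge weights, and the labeled nodes are reset to their known labels after every step.
  4. Repeat this process until every point on the graph has a label (in the iterative form, until the label distributions stop changing).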
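The propagation steps can be sketched with a simplified iterative scheme. This is an assumption-laden illustration, not a line-for-line implementation of Zhu & Ghahramani: every unlabeled node repeatedly replaces its label distribution with the weighted average of its neighbors' distributions, while the seed (labeled) nodes stay clamped. The function name `propagate_labels`, the adjacency-dict graph format, and the four-node example graph are all hypothetical.

```python
def propagate_labels(edges, seeds, classes, iterations=100, tol=1e-6):
    """edges: {node: {neighbor: weight}}, assumed symmetric;
    seeds: {node: label} for the few labeled nodes."""
    nodes = list(edges)
    # Start from one-hot distributions for seeds and uniform ones elsewhere.
    dist = {
        n: {c: (1.0 if c == seeds[n] else 0.0) if n in seeds else 1.0 / len(classes)
            for c in classes}
        for n in nodes
    }
    for _ in range(iterations):
        new = {}
        for n in nodes:
            total = sum(edges[n].values())
            if n in seeds or total == 0:
                new[n] = dist[n]  # clamp labeled nodes; isolated nodes keep their value
                continue
            # Weighted average of the neighbors' label distributions.
            new[n] = {
                c: sum(w * dist[m][c] for m, w in edges[n].items()) / total
                for c in classes
            }
        change = max(abs(new[n][c] - dist[n][c]) for n in nodes for c in classes)
        dist = new
        if change < tol:
            break  # distributions have stopped changing
    # Each node takes the label with the highest propagated probability.
    return {n: max(d, key=d.get) for n, d in dist.items()}, dist


# A chain A - B - C - D with A labeled "red" and D labeled "blue".
edges = {
    "A": {"B": 1.0},
    "B": {"A": 1.0, "C": 1.0},
    "C": {"B": 1.0, "D": 1.0},
    "D": {"C": 1.0},
}
labels, dist = propagate_labels(edges, seeds={"A": "red", "D": "blue"}, classes=["red", "blue"])
print(labels)  # {'A': 'red', 'B': 'red', 'C': 'blue', 'D': 'blue'}
```

On this chain, B converges to roughly 2/3 "red" and C to roughly 1/3 "red": each unlabeled node leans toward the labeled node it is closer to, which matches the intuition of counting paths to differently labeled nodes.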

Graph-based label propagation is commonly used in personalization and recommender systems to predict customer interests based on connections between users.


References

  1. https://www.coursera.org/learn/machine-learning-data-lifecycle-in-production/home/week/4
  2. https://www.altexsoft.com/blog/semi-supervised-learning/
  3. https://pages.cs.wisc.edu/~jerryzhu/pub/CMU-CALD-02-107.pdf
  4. https://www.cs.cmu.edu/~avrim/Papers/cotrain.pdf