
How to Troubleshoot Embeddings Without Eye-balling t-SNE or UMAP Plots

Production data systems are often centered around structured tabular data - that is, data that can be largely organized into rows and columns of primitive data types. Columns in tabular data may be related to one another, but each column can largely be understood independently. This makes it easy to create and interpret relevant metrics that describe the data in aggregate, such as means, quantiles, and cardinality.

Treating columns independently isn't ideal for categorical encodings, GPS coordinates, personally identifiable information (PII), or more complex data types such as images, text, and audio. For this kind of data, you need to manipulate and aggregate multiple columns at once to make sense of it, so we turn to vectors and embeddings. High-dimensional embeddings are also what power advanced machine learning models such as ChatGPT and Stable Diffusion.

Embeddings are heavily used in machine learning for a variety of data types and tasks as inputs, intermediate products, and outputs. Here are just a few:

Natural language understanding and text analysis

  • Sentiment analysis
  • Document classification
  • Text generation

Computer vision and image processing

  • Manufacturing quality assurance
  • Autonomous driving

Audio processing

  • Text-to-speech models
  • Speaker identification

Tabular machine learning

  • Product recommendation
  • Anonymization and privacy

WhyLabs recently released features that make it even easier to profile and monitor high-dimensional embeddings in both whylogs and the WhyLabs Observability Platform. And it doesn't require you to explore data by hand.

Want to try it out? Sign up for a free WhyLabs account or check out our Jupyter notebook for profiling embeddings in Python.

What do embeddings typically look like?

In short, your data should look like a numerical array of some fixed size (or dimension). For example, [0, 4.5, -1.2, 7.9] is a four-dimensional vector that could be treated as a dense embedding. Embeddings need to be structure-preserving, meaning that relationships and distances between the embeddings should be meaningful. In machine learning, we often choose embeddings in vector space to satisfy our requirements.

Embeddings are either sparse or dense. Sparse embeddings often represent booleans or counts. One-hot encoding, for example, translates a categorical feature into a vector of 0s and 1s. Most ML algorithms are designed for numerical data, so this is a critical step even in tabular settings.
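
As a quick sketch of one-hot encoding (plain Python and numpy, with a made-up categorical feature for illustration):

import numpy as np

# Made-up categorical feature with three possible values
categories = ["red", "green", "blue"]
values = ["blue", "red", "green", "blue"]

# Each value becomes a 0/1 vector with a single 1 at its category's index
index = {cat: i for i, cat in enumerate(categories)}
one_hot = np.zeros((len(values), len(categories)), dtype=int)
for row, value in enumerate(values):
    one_hot[row, index[value]] = 1

print(one_hot)
# [[0 0 1]
#  [1 0 0]
#  [0 1 0]
#  [0 0 1]]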

Another popular sparse embedding, common in text analysis, is the bag of words. Here, the dimensionality of the embedding matches the size of the vocabulary of all words used, and each entry counts the occurrences of one word within the document.
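
A minimal sketch of a bag-of-words encoding, using a toy vocabulary and documents purely for illustration:

from collections import Counter

# Toy vocabulary and documents, purely for illustration
vocabulary = ["the", "cat", "sat", "mat", "dog"]
documents = ["the cat sat on the mat", "the dog sat"]

# Each document becomes a count vector with one entry per vocabulary word
def bag_of_words(doc):
    counts = Counter(doc.split())
    return [counts[word] for word in vocabulary]

print([bag_of_words(doc) for doc in documents])
# [[2, 1, 1, 1, 0], [1, 0, 1, 0, 1]]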

But most commonly discussed today are dense embeddings, which are often trained using statistical or machine learning models themselves. Dense embeddings solve a shortcoming of sparse embeddings, which treat all concepts as completely separate. For example, the words bicycle, banana, and car may be encoded as [1, 0, 0], [0, 1, 0], and [0, 0, 1] respectively, which doesn't capture the closer relationship between the concepts of cars and bicycles as compared to bananas. Dense embeddings do. The same words might be described by the two-dimensional embeddings [4.53, 7.5], [-0.91, 2.3], and [5.17, 7.12]. If we use a distance measure, for example Euclidean distance, we can see that the similarity between bicycle and car is preserved in the dense embedding.
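
To make the comparison concrete, here is a minimal sketch (plain numpy, using the example vectors above) that computes those Euclidean distances:

import numpy as np

# Two-dimensional dense embeddings from the example above
embeddings = {
    "bicycle": np.array([4.53, 7.5]),
    "banana": np.array([-0.91, 2.3]),
    "car": np.array([5.17, 7.12]),
}

# Euclidean distance between a pair of embeddings
def distance(a, b):
    return float(np.linalg.norm(embeddings[a] - embeddings[b]))

print(distance("bicycle", "car"))     # ~0.74, close together
print(distance("bicycle", "banana"))  # ~7.53, far apart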

Exploring embeddings by hand

So how do data scientists typically explore a set of embeddings? With lots of manual effort. Most often, embeddings analysis tools are essentially visualization tools that translate high-dimensional embeddings to 2D or 3D to be plotted. But this still requires data scientists to hover over thousands of data points and manually find trends in the data.

Caption: Two-dimensional t-SNE plot of 784-dimensional MNIST, colored by label.

This visualization looks very cool, but it's difficult to extract insights from it just by eyeballing it. While you may spot a pattern or a change, this is neither a scalable nor a precise way for data scientists and stakeholders to detect issues in your organization's data.
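
For context, a plot like the one captioned above is typically produced with something like scikit-learn's t-SNE. A minimal sketch (using the small 8 x 8 digits dataset bundled with scikit-learn rather than full MNIST, so the details differ from the figure above):

import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

# Project high-dimensional digit images down to 2D purely for visualization
digits = load_digits()  # 8 x 8 digits here; full MNIST would be 28 x 28 = 784 dimensions
projected = TSNE(n_components=2, random_state=0).fit_transform(digits.data)

# Color each point by its label and look for structure by eye
plt.scatter(projected[:, 0], projected[:, 1], c=digits.target, cmap="tab10", s=4)
plt.colorbar()
plt.show()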

Exploring embeddings at scale

At WhyLabs, we’ve developed techniques to explore embeddings at scale without the manual work of exploring individual data points. This is done by comparing each data point to several meaningful reference points within your embeddings vector space.

For example, take the popular MNIST dataset. This data is high-dimensional: 28 x 28 pixels = 784 dimensions for our vectors. While it is common to reduce such high-dimensional data into dense vectors, we work with the full sparse embedding in this post.
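
To make the shape concrete, here is a minimal sketch (plain numpy, with randomly generated stand-ins for the real images) of flattening a batch of 28 x 28 images into 784-dimensional vectors, which is the kind of `data` array the logging example below expects:

import numpy as np

# Hypothetical stand-in for a batch of MNIST images: n images of 28 x 28 grayscale pixels
images = np.random.randint(0, 256, size=(100, 28, 28), dtype=np.uint8)

# Flatten each image into a 784-dimensional vector (28 * 28 = 784)
data = images.reshape(len(images), -1).astype(np.float64)
print(data.shape)  # (100, 784)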

There are several key references that you may want to compare your data to. We'd like a canonical example of each of the ten digits in our dataset for comparison with incoming images. We've packaged functions in the `preprocess` module of whylogs for finding references for both supervised and unsupervised data. But for some applications, you may have specific references of interest, such as popular variations of the digits, non-numeric characters, and more.

Comparing production embeddings to predetermined references allows us to better measure how incoming data shifts relative to the data points that matter most to you. This approach is highly flexible and customizable, capturing distributional information that detects real issues. References give actionable insights into which types of data are most related to issues seen in production, without hovering over individual data points by hand.

Some useful measurements and metrics we’ve found for embeddings include:

  • Cosine similarity across batches of embeddings
  • Minimum, maximum, and distribution across the individual values within the embeddings
  • Distribution of distance to reference embeddings (pictured below)
  • Distribution of the closest reference embeddings
  • Cluster analysis measures of the clusters formed by reference embeddings, e.g., Silhouette score, Calinski-Harabasz index

By utilizing many of these measures, we are able to programmatically measure and diagnose drift. Using only one or two metrics may help you to detect drift, but it is difficult to distinguish between potential causes.
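
As a rough illustration of how a couple of these measures can be computed directly (plain numpy, separate from the whylogs implementation; `batch` and `references` are hypothetical arrays of embeddings):

import numpy as np

# Hypothetical arrays: a batch of production embeddings and a set of reference embeddings
batch = np.random.rand(1000, 784)      # incoming embeddings
references = np.random.rand(10, 784)   # e.g., one canonical embedding per digit

# Pairwise Euclidean distance from every batch embedding to every reference embedding
distances = np.linalg.norm(batch[:, None, :] - references[None, :, :], axis=-1)

# Distribution of distance to the closest reference, and which reference is closest
closest_distance = distances.min(axis=1)
closest_reference = distances.argmin(axis=1)
print(np.percentile(closest_distance, [5, 50, 95]))               # distance distribution summary
print(np.bincount(closest_reference, minlength=len(references)))  # counts per closest reference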

Getting started with embeddings in whylogs

You can get started with this feature in our open source Python library, whylogs. Check out our Jupyter notebook for profiling embeddings in Python for an end-to-end example using MNIST as well as our notebook for submitting data profiles to the WhyLabs Observatory Platform to monitor your production data over time.

Assuming you have some embeddings in a numpy array, `data`, you can log embeddings with the following:

import whylogs as why
from whylogs.core.resolvers import MetricSpec, ResolverSpec
from whylogs.core.schema import DeclarativeSchema
from whylogs.experimental.extras.embedding_metric import (
    DistanceFunction,
    EmbeddingConfig,
    EmbeddingMetric,
)
from whylogs.experimental.preprocess.embeddings.selectors import PCAKMeansSelector

# Automatically choose reference embeddings (and labels for them) via PCA + k-means
references, labels = PCAKMeansSelector(n_components=20).calculate_references(data)

# Configure the embedding metric: which references to compare against and which distance to use
config = EmbeddingConfig(
    references=references,
    labels=labels,
    distance_fn=DistanceFunction.euclidean,
)

# Attach the embedding metric to the "embeddings" column
schema = DeclarativeSchema(
    [ResolverSpec(column_name="embeddings", metrics=[MetricSpec(EmbeddingMetric, config)])]
)

# Profile the embeddings; the resulting profile can be inspected locally or written to WhyLabs
results = why.log(row={"embeddings": data}, schema=schema)

To enable drift detection, send your data profiles (not the raw data) to the WhyLabs Observatory Platform. It's as easy as calling `.write("whylabs")` on your profile after you've signed up and set up your credentials.

import os

# Credentials from your WhyLabs account
os.environ["WHYLABS_DEFAULT_ORG_ID"] = "YOUR ORG_ID HERE"
os.environ["WHYLABS_DEFAULT_DATASET_ID"] = "YOUR DATASET_ID HERE"
os.environ["WHYLABS_API_KEY"] = "YOUR API_KEY HERE"

# Upload the profile (not the raw data) to WhyLabs
results.write("whylabs")

Caption: Distance to the closest reference is collected for each profiled data point, shown here on the Profiles tab. This shows that there's a drift between our reference and production data.

Feedback

This feature has been released in beta and we’d love feedback on metrics and use cases for your organization. Join our Slack channel or start a GitHub issue to discuss improvements to embeddings with the WhyLabs team.

Get started by creating a free WhyLabs account or contact us to learn more about embeddings.
