Detecting Semantic Drift within Image Data: Monitoring Context-Full Data with whylogs
- Image Data
- Whylogs
- ML Monitoring
- Open Source
Aug 7, 2021
This article was originally published on Towards Data Science on July 29, 2021.
Your machine learning model sees the world through the lens of its training data. That means that your model gets more and more myopic as the real world gets further from what your training data represents. However, when operating machine learning applications, upgrading your glasses (retraining your model) is not our only concern. We also control the data pipeline that feeds information into our model during production, and thus have the responsibility to ensure its quality.
Concept drift in data can have different sources and can originate at different stages of your data pipeline, even before data collection itself. To take the correct measures, one must be able to pinpoint its source. However, whatever the source of the problem and whatever corrective measures are needed, it all starts with one basic requirement: we need to be aware that an issue exists in the first place, which is a challenge in itself.
In this article, we’ll show how whylogs can help you monitor your machine learning system’s data ingestion pipeline by enabling concept drift detection, specifically for image data. To do this, we’ll work with a couple of demo use cases.
For the first case, we will source images from two distinct datasets: Landscape Pictures and IEEE Signal Processing Society Camera Model Identification. We’ll create a scenario where the semantic meaning changes drastically, and monitor specific features that allow us to detect unexpected data changes. We’ll also monitor image metadata, which describes either the image content or the environment in which the image was created, and therefore gives us context.
In the second scenario, we’ll demonstrate how to create more generalized semantic metrics. Combined with whylogs’ support for custom transformations, these metrics let us monitor specialized semantic information directly from our dataset in a consistent, versatile manner. In this case, we’ll generate image embeddings from a pre-trained DenseNet model and use them as our semantic-full data. We’ll simulate deployments with datasets of different CIFAR-10 classes as examples of how to capture and potentially detect these shifts in our data and models.
Metadata and Features
Even though we don’t have an actual model for prediction, let’s assume that our model input is expected to consist mainly of landscape images. Using a simulated production stage, we can test whether it’s possible to detect the presence of out-of-domain images by logging some basic image features. These can include metadata and properties such as hue, saturation, and brightness.
We’ll store the images in folders as follows:
The landscape_baseline/baseline subfolder contains 1800 JPEG images of landscapes. We’ll use these images as the baseline for comparison. The remaining three folders contain the image data to be monitored during production:
- landscape_images has 200 JPG images of landscapes.
- camera_images has 200 TIFF images of no particular category, extracted from the Camera Model Identification dataset.
- mixed_images has a combination of the previous two datasets in a 50/50 proportion.
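If you want to build a folder like mixed_images yourself, a minimal standard-library sketch could look like the following. The helper name make_mixed_folder, the n_per_source parameter, and the assumption that both source folders are flat folders of image files are illustrative choices, not part of the original setup.

```python
import random
import shutil
from pathlib import Path


def make_mixed_folder(source_a, source_b, target, n_per_source, seed=0):
    """Copy n_per_source randomly chosen files from each source folder into target.

    Hypothetical helper: assumes both sources are flat folders of image files.
    """
    rng = random.Random(seed)  # fixed seed so the same mix can be rebuilt
    target = Path(target)
    target.mkdir(parents=True, exist_ok=True)
    copied = []
    for source in (Path(source_a), Path(source_b)):
        files = sorted(p for p in source.iterdir() if p.is_file())
        for path in rng.sample(files, n_per_source):
            # Prefix with the source folder name to avoid filename collisions
            destination = target / f"{source.name}_{path.name}"
            shutil.copy2(path, destination)
            copied.append(destination)
    return copied
```

For example, make_mixed_folder("landscape_images", "camera_images", "mixed_images", n_per_source=100) would yield an even split between the two sources.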
Image features
One way to detect data changes is to create and compare histograms of image properties such as hue, saturation, and brightness (HSB). We can use whylogs to automatically log these properties and then use the profiled information to plot the histograms. Note that the HSB color space is more useful for understanding content than the RGB color space and its associated histograms, since these properties relate directly to how we visually interpret images.
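To make the HSB idea concrete, here is a small NumPy sketch that computes the HSV saturation of every pixel in an RGB array and bins it into a normalized histogram. It uses the standard HSV definition, S = (max - min) / max per pixel, scaled to 0..255. It illustrates the concept only and is not whylogs’ internal Saturation() implementation.

```python
import numpy as np


def saturation(rgb):
    """Per-pixel HSV saturation of an (H, W, 3) uint8 RGB image, scaled to 0..255."""
    rgb = rgb.astype(np.float64)
    c_max = rgb.max(axis=-1)
    c_min = rgb.min(axis=-1)
    # S = (max - min) / max, defined as 0 for black pixels (max == 0)
    sat = np.divide(c_max - c_min, c_max, out=np.zeros_like(c_max), where=c_max > 0)
    return (sat * 255).reshape(-1)


def saturation_histogram(rgb, n_bins=10):
    """Normalized saturation histogram over fixed 0..255 bins, so batches are comparable."""
    counts, edges = np.histogram(saturation(rgb), bins=n_bins, range=(0, 255))
    return counts / counts.sum(), edges
```

Fixing the bin range, rather than deriving it from each batch, keeps histograms from different batches directly comparable.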
For this case we’ll create 3 different datasets:
- The first contains information about our baseline dataset
- The second will simulate a batch of images that is expected: only landscape images
- The third will contain only “unseen” images, unrelated to landscapes
The idea is that the data distribution of batches received during production should be similar to the distribution of our baseline set. If we detect that the two distributions begin to drift apart, that may be a signal that warrants further inspection.
Let’s begin by creating a session, which is how our application interacts with whylogs:
from whylogs import get_or_create_session
session = get_or_create_session()
Then, we log the contents of each folder using log_local_dataset():
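from whylogs.features.transforms import ComposeTransforms, Resize, Saturation
with session.logger(dataset_name="baseline") as logger:
    logger.log_local_dataset(root_dir="landscape_baseline",
                             image_feature_transforms=[ComposeTransforms([Resize(50), Saturation()])],
                             show_progress=True)
    # Keep a handle to the resulting profile for plotting later
    profile_baseline = logger.profile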
As you can see in the code above, we can apply image transformations prior to logging through the image_feature_transforms parameter. We can also mix and match functions to create our own feature transformation pipeline with ComposeTransforms(). In this case, we resize our images to 50x50 pixels and then compute the saturation value of every pixel.
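Conceptually, composing transforms means chaining callables, with each one feeding its output to the next. The sketch below is a plain-Python illustration of that idea, not the source of whylogs’ ComposeTransforms. The helpers compose, resize_nearest, and brightness are hypothetical, and the nearest-neighbour resize stands in for a proper image resize.

```python
from functools import reduce

import numpy as np


def compose(*transforms):
    """Return a callable that applies transforms from left to right."""
    return lambda x: reduce(lambda acc, fn: fn(acc), transforms, x)


def resize_nearest(size):
    """Nearest-neighbour resize of an (H, W, C) array to (size, size, C)."""
    def _resize(img):
        rows = np.linspace(0, img.shape[0] - 1, size).round().astype(int)
        cols = np.linspace(0, img.shape[1] - 1, size).round().astype(int)
        return img[rows][:, cols]
    return _resize


def brightness(img):
    """HSV value channel (max over RGB) for every pixel, as a column vector."""
    return img.max(axis=-1).reshape(-1, 1)


pipeline = compose(resize_nearest(50), brightness)
```

Because the resize runs first, every image contributes exactly 2,500 values to the logged feature, regardless of its original resolution.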
We also have to log the images in landscape_images and camera_images. Since the code is exactly the same as above, we won’t repeat it here. The only differences are the dataset_name and root_dir arguments and the name of the resulting profile variable (profile_landscapes and profile_cameras instead of profile_baseline).
Then, we can define our functions to plot our histograms:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def get_custom_histogram_info(profiles, variable, n_bins):
    """Build one bin grid for `variable` shared by all profiles, plus each profile's PMF on it."""
    summaries = [profile.flat_summary()["summary"] for profile in profiles]
    # Overall min/max of the feature across every profile in the list
    min_range = int(min(s[s["column"] == variable]["min"].values[0] for s in summaries))
    max_range = int(max(s[s["column"] == variable]["max"].values[0] for s in summaries))
    # Note: n_bins acts as a step size here; the grid has (max_range - min_range) / n_bins points
    bins = np.linspace(min_range, max_range, int((max_range - min_range) / n_bins))
    # get_pmf(split_points) returns len(split_points) + 1 masses, i.e. one weight per grid point
    counts = [profile.columns[variable].number_tracker.histogram.get_pmf(bins[:-1]) for profile in profiles]
    return bins, counts
def plot_distribution_shift(profiles, variable, n_bins, baseline=None):
    """Visualization for distribution shift"""
    bins, counts = get_custom_histogram_info(profiles, variable, n_bins)
    fig, ax = plt.subplots(figsize=(10, 3))
    if baseline:
        # Draw the baseline first so it sits in the background
        bins_b, counts_b = get_custom_histogram_info(baseline, variable, n_bins)
        for idx, profile in enumerate(baseline):
            sns.histplot(x=bins_b, weights=counts_b[idx], bins=n_bins,
                         label="baseline", color="teal", alpha=0.7, ax=ax)
    for idx, profile in enumerate(profiles):
        sns.histplot(x=bins, weights=counts[idx], bins=n_bins,
                     label=profile.name, color="gold", alpha=0.7, ax=ax)
    ax.legend()
    plt.show()
These functions plot the histogram of the chosen feature by combining information from every profile in the input list. If a baseline is defined, the histogram for the baseline dataset is also plotted in the background for comparison.
Let’s plot their distributions:
plot_distribution_shift([profile_landscapes],"Saturation(Resize(IMG))",10,baseline=[profile_baseline])
plot_distribution_shift([profile_cameras],"Saturation(Resize(IMG))",10,baseline=[profile_baseline])
As we can see, the saturation distribution of the landscape set is much more similar to the baseline than that of our “drifted” batch. If we know what to expect, plots such as the second one should certainly draw our attention.
Image metadata
A second way to obtain useful information is to monitor the logged images’ metadata. Not every file will have metadata, but we can benefit from whatever is available, which whylogs logs automatically. We can also log and search for specific tags. EXIF tags are a standard used by most camera and image software manufacturers. Unfortunately, as one moves into more obscure data formats, one tends to encounter custom metadata; thankfully, these are generally TIFF tags. Below is an example of some of the TIFF tags found in one of my personal camera images.
╔═══════════════════════════╦══════════════════════╗
║ Tag                       ║ Value                ║
╠═══════════════════════════╬══════════════════════╣
║ Manufacturer              ║ Canon                ║
║ Model                     ║ EOS-600D             ║
║ Orientation               ║ top-left             ║
║ Software                  ║ Ver2.21              ║
║ Date and time             ║ 2020:08:11 16:55:22  ║
║ YCbCr positioning         ║ centered             ║
║ Compression               ║ JPEG compression     ║
║ file_format               ║ jpg                  ║
║ X resolution              ║ 72.00                ║
║ Y resolution              ║ 72.00                ║
║ Resolution unit           ║ Inch                 ║
║ Compressed bits per pixel ║ 4                    ║
║ Exposure bias             ║ 0.0                  ║
║ Max. aperture value       ║ 2.00                 ║
║ Metering mode             ║ Pattern              ║
║ Flash                     ║ Flash did not fire   ║
║ Focal length              ║ 16.1 mm              ║
║ Color space               ║ sRGB                 ║
║ Pixel X dimension         ║ 2240                 ║
║ Pixel Y dimension         ║ 1680                 ║
╚═══════════════════════════╩══════════════════════╝
Tags like Resolution unit, along with X resolution and Y resolution, give us a view into potential object-scaling problems. If the physical size of each pixel changes while the camera’s location relative to the objects of interest does not, objects may appear smaller or larger than previously encountered. This is especially likely if no scale augmentation was applied during training. Meanwhile, Compressed bits per pixel could inform us of changes in pixel intensity scaling, or of new image context such as compression artifacts that were not present in the baseline set.
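One simple way to act on tags like these is to compare the values seen in a new batch against those seen in the baseline. The sketch below assumes the metadata has already been extracted into one dict per image, for example from EXIF/TIFF tags. The function unseen_tag_values and the sample tag values in the usage note are hypothetical.

```python
from collections import defaultdict


def unseen_tag_values(baseline, batch, tags):
    """Return {tag: values} for tag values that appear in batch but never in baseline.

    baseline and batch are lists of dicts mapping tag name -> (hashable) value.
    Missing tags are recorded as None, so a tag disappearing is also flagged.
    """
    seen = defaultdict(set)
    for meta in baseline:
        for tag in tags:
            seen[tag].add(meta.get(tag))
    new_values = {}
    for tag in tags:
        values = {meta.get(tag) for meta in batch} - seen[tag]
        if values:
            new_values[tag] = values
    return new_values
```

For example, a batch containing an image with Resolution unit set to "Centimeter" would be flagged against a baseline in which every image reports "Inch".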
In this simple case, we’ll exploit the file_format information. As mentioned before, the landscape images are expected to be in JPEG format, while our “noise” images are all in TIFF format. Therefore, we can monitor the count of unique values across the dataset to ensure consistency between batches.
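The same uniqueness check is easy to reason about outside whylogs: count the distinct file formats in each batch and flag batches whose count differs from the baseline’s. In this sketch, the format is inferred from the file extension, which is a simplifying assumption. The helper names file_formats and flag_format_drift are hypothetical.

```python
from pathlib import Path

# Treat common aliases as the same format
_ALIASES = {".jpeg": ".jpg", ".tiff": ".tif"}


def file_formats(paths):
    """Set of normalized, lower-case file extensions in a batch."""
    formats = set()
    for p in paths:
        suffix = Path(p).suffix.lower()
        formats.add(_ALIASES.get(suffix, suffix))
    return formats


def flag_format_drift(baseline_paths, batches):
    """Indices of batches whose number of unique formats differs from the baseline's."""
    expected = len(file_formats(baseline_paths))
    return [i for i, batch in enumerate(batches) if len(file_formats(batch)) != expected]
```

Comparing only the count mirrors the uniqueness plot. Comparing the format sets themselves would also catch a batch that switched entirely to a new format while keeping the same count.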
To demonstrate, we’ll log a series of five batches in the same dataset. Four of the batches consist only of landscape images, and the fifth is a half-and-half mix of landscape and out-of-domain images:
from whylogs.viz import ProfileVisualizer
# profiles_mixed: list holding the profiles of the five batches (logging code not shown)
viz = ProfileVisualizer()
viz.set_profiles(profiles_mixed)
viz.plot_uniqueness("file_format")
As expected, the plot reflects the number of unique file formats for each batch of data. Even though this is a very simple case, it demonstrates how changes in the number of file formats can indicate the presence of data issues.
Metadata lets us quickly detect potential issues with dates, instrumentation, image size, encoding, and authorship, which in turn allows image-monitoring pipelines to act as a quick fail-safe. This might not matter much to someone working with a static Kaggle dataset, but a large company that must keep, say, medical images consistent and experiments reproducible will be incredibly grateful.
Additionally, metadata can provide contextual information that is valuable when debugging faults during a retrospective.
Semantic Drifts
In a typical image classification task, it is important to be aware of changes in the distribution across classes. Although ground-truth labels are often not readily available, we can use transfer learning to generate feature embeddings that give us some semantic insight into our data. In this scenario, we’ll use a subset of the CIFAR-10 dataset, namely the bird, cat, deer, and horse classes, to form an “animals” baseline. We’ll create a separate, skewed “vehicles” set made of samples from the ship and automobile classes.
Custom Features — Distances from Cluster Centers
In this case, we want to generate and log features that semantically represent the distance from the logged images to the “ideal” representation of each class. To do this, we can use whylogs’ ability to apply custom functions that generate specific image features. That way, we can log the distance from each image to each class representation and plot a histogram to assess the distance distribution for each batch.
First, we need to embed the original pixel values into an N-dimensional vector. In this example, the feature extractor is based on the DenseNet-121 architecture and generates 1024-dimensional vectors. These embeddings can be further normalized to unit length. This places each vector on the surface of a hypersphere and bounds the dot product between any two of them to the range [-1, 1].
Check out this post by Apple’s ML team on how similar embeddings are used in their face recognition pipeline.
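The normalization step can be sketched in NumPy as follows. Random vectors stand in for DenseNet-121 features here, purely for illustration. Once each row has unit L2 norm, a plain dot product equals the cosine similarity and is bounded to [-1, 1].

```python
import numpy as np


def l2_normalize(embeddings, eps=1e-12):
    """Scale each row to unit L2 norm, placing it on the unit hypersphere."""
    norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
    return embeddings / np.maximum(norms, eps)  # eps guards against all-zero rows


rng = np.random.default_rng(0)
features = rng.normal(size=(8, 1024))  # stand-in for 8 DenseNet-121 embeddings
unit = l2_normalize(features)
similarities = unit @ unit.T  # pairwise dot products = cosine similarities
```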
Since we have the labels for each image in the training stage, we can calculate the coordinates of each class centroid beforehand. With the cluster centers at hand, we can choose the images (or objects) closest to these centers as our semantic centers. These images give us anchors to compute and monitor the distribution of our embeddings. We can create a custom function to log the distance from each image to the centers by calculating the dot product between them. For unit-length embeddings, this is their cosine similarity, so larger values mean the image is closer to a center:
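from typing import Callable, List
import numpy as np
class SemanticCenterDistance:
    """Custom image transform: dot product between an image's embedding and each semantic center."""
    def __init__(self, semantic_centers: List[np.ndarray], embedding_function: Callable):
        self.semantic_centers = semantic_centers
        self.embedding_function = embedding_function
    def __call__(self, x):
        embedding = self.embedding_function(x)  # embed once, reuse for every center
        # One value per semantic center, stacked into a single column
        return np.array([np.dot(embedding, center) for center in self.semantic_centers]).reshape(-1, 1)
    def __repr__(self):
        return self.__class__.__name__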
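Choosing the semantic centers themselves could look like the following NumPy sketch. It averages each class’s unit-normalized embeddings into a centroid, then keeps the real sample with the highest dot product with that centroid. The function name semantic_centers and its input layout are assumptions for illustration.

```python
import numpy as np


def semantic_centers(embeddings, labels):
    """Map each label to the embedding of the sample closest to its class centroid.

    embeddings: (N, D) array of unit-normalized vectors; labels: length-N sequence.
    """
    labels = np.asarray(labels)
    centers = {}
    for label in np.unique(labels):
        members = embeddings[labels == label]
        centroid = members.mean(axis=0)
        # For unit vectors, the largest dot product with the centroid is also the
        # smallest Euclidean distance to it, so either criterion picks the same sample.
        centers[label] = members[np.argmax(members @ centroid)]
    return centers
```

With this sketch, list(centers.values()) could be passed as the semantic_centers argument of SemanticCenterDistance.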
As in the previous case, we can apply our custom transformation to log two datasets. The first is made up of four randomly chosen animal classes (bird, cat, deer, horse) from the CIFAR-10 dataset. The unknown image set of semantic-full objects is made up of two vehicle categories (ship, automobile), which introduces the drift in our data.
In the image above, we use a UMAP projection to visualize the embeddings. The position of each category label marks the location of that category’s semantic center, which was computed in the original embedding dimensions. Note that since we are using a plain UMAP embedding, the spatial ordering is not necessarily preserved in the projection. Nonetheless, the projection still shows how embeddings cluster around semantically similar objects.
We compute our semantic-full metrics using only the original baseline centers (bird, cat, deer, horse), since we know these classes are part of our original baseline. In the figure below, we plot the distributions of the distances from each semantic center for the baseline (left chart) and for the vehicle-only set (right chart).
We can then combine these distances into a single metric, as defined in the custom function above. In the figure below, we show the distribution of distances from the baseline semantic centers for vehicle-only embeddings. It is shown alongside the distributions for a mixed set of animal and vehicle embeddings and for animal-only embeddings, as a way to visualize the semantic drift.
To quantify this drift, we use the Kolmogorov-Smirnov (KS) statistic, comparing each distribution to the baseline (animal-only) distribution. For example, comparing the baseline to the vehicle embeddings yields a KS statistic of 0.25. Given the large number of embeddings used to compute the distributions above, the corresponding p-value is very small (p < 0.0001), so we can reject the null hypothesis that both samples come from the same distribution. In this simple experiment, it can be shown that we would need approximately 200 embeddings from each of the baseline and the shifted vehicle distributions to significantly reject the null hypothesis with this embedding model and these classes. This number gives us a practical minimum batch size for detecting such changes in our pipeline in this particular example.
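To reproduce this kind of test, here is a self-contained NumPy sketch of the two-sample KS statistic and its standard asymptotic p-value (the Kolmogorov distribution series). In practice, scipy.stats.ks_2samp computes the same statistic and also offers exact p-values. Feeding it the per-image center distances as raw arrays is an assumption about how the numbers above were obtained.

```python
import math

import numpy as np


def ks_statistic(a, b):
    """Two-sample KS statistic: largest gap between the two empirical CDFs."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    grid = np.concatenate([a, b])  # the ECDF gap can only peak at observed values
    cdf_a = np.searchsorted(a, grid, side="right") / a.size
    cdf_b = np.searchsorted(b, grid, side="right") / b.size
    return float(np.max(np.abs(cdf_a - cdf_b)))


def ks_pvalue_asymptotic(d, n, m, terms=100):
    """Asymptotic p-value Q(lam) = 2 * sum_k (-1)^(k-1) * exp(-2 k^2 lam^2)."""
    lam = d * math.sqrt(n * m / (n + m))
    if lam < 0.2:
        return 1.0  # Q(lam) equals 1 to many decimals here, and the series converges slowly
    total = sum((-1) ** (k - 1) * math.exp(-2 * k * k * lam * lam) for k in range(1, terms + 1))
    return min(1.0, max(0.0, 2.0 * total))
```

The asymptotic approximation is accurate for large samples. For small batches, an exact method is preferable.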
Conclusion
Real-world data can behave very differently from what you might expect. Distributional drift is one of the most common problems encountered in AI applications, so you should have tools that effectively detect these issues as soon as they appear.
In this article we covered some common use cases to demonstrate how whylogs can help you detect data drift issues in images, ranging from using default image features and metadata information to applying custom transformations to generate specific features.
That said, the examples covered in this article can be expanded further. For instance, in the saturation case, we could continue our analysis by quantifying the histogram shift with the KS statistic or other distribution metrics. Furthermore, these metrics and their associated thresholds could be included in your CI/CD pipeline. For example, check out this post on how to include whylogs constraint checks in your GitHub Actions.
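Because profiles hold histograms rather than raw values, one way to quantify the saturation shift is to compute a KS distance directly from two PMFs defined on the same bins. Such PMFs could come, for example, from passing both profiles to get_custom_histogram_info in a single call so they share one bin grid. The sketch below wraps this in a pass/fail check that a CI job could run. The function names and the 0.1 threshold are placeholders, not recommended values.

```python
import numpy as np


def ks_from_pmfs(p, q):
    """KS distance between two histograms defined on the same bins."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Normalize defensively, then compare cumulative sums (binned CDFs)
    cdf_p = np.cumsum(p / p.sum())
    cdf_q = np.cumsum(q / q.sum())
    return float(np.max(np.abs(cdf_p - cdf_q)))


def check_drift(baseline_pmf, batch_pmf, threshold=0.1):
    """Return (passed, statistic); passed is False when the shift exceeds threshold."""
    statistic = ks_from_pmfs(baseline_pmf, batch_pmf)
    return statistic <= threshold, statistic
```

The binned CDFs are only compared at bin edges, so coarse bins can hide differences that fall within a single bin.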
When available, other types of metadata could also provide valuable insights, such as aspect ratio and resolution, or object detection and segmentation data that further describe the subject matter and give it context. Or, in the semantic distribution case, we could gradually increase the proportion of the unknown classes to define a viable threshold value for the KS statistic or for other metrics such as the KL divergence or the Hellinger distance. What’s more, these approaches, and others, can be integrated into different stages of the data pipeline to provide full observability of your machine learning application.
Here is one such recent method, which uses transformer-based models for few-shot out-of-distribution (OOD) detection. Its authors also describe some common metrics used to evaluate performance on this task.
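Of the distribution metrics mentioned above, the KL divergence and the Hellinger distance are both short to write down for histograms on shared bins. This NumPy sketch assumes p and q are PMFs over the same bins. The small epsilon added before taking logs is an illustrative choice that keeps the KL divergence finite when a bin is empty.

```python
import numpy as np


def kl_divergence(p, q, eps=1e-10):
    """KL(p || q) in nats; eps keeps the log finite for empty bins."""
    p = np.asarray(p, dtype=float) + eps
    q = np.asarray(q, dtype=float) + eps
    p, q = p / p.sum(), q / q.sum()
    return float(np.sum(p * np.log(p / q)))


def hellinger_distance(p, q):
    """Hellinger distance in [0, 1]: sqrt(1 - sum(sqrt(p * q)))."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = p / p.sum(), q / q.sum()
    bc = np.sum(np.sqrt(p * q))  # Bhattacharyya coefficient
    return float(np.sqrt(max(0.0, 1.0 - bc)))
```

KL divergence is asymmetric and unbounded, while the Hellinger distance is symmetric and bounded. This makes a fixed Hellinger threshold easier to reason about across batches.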
If you are interested in contributing to whylogs or have any questions, please feel free to reach out to @lalmei, or head over to our open-source repo at https://github.com/whylabs/whylogs. You can also check out our documentation page at https://docs.whylabs.ai/docs/ for more information on whylogs.