Detecting Semantic Drift within Image Data: Monitoring Context-Full Data with whylogs
- Data Analytics
- Data Logging
- Image Data
- ML Monitoring
Aug 7, 2021
This article was originally published on Towards Data Science, on July 29, 2021
Your machine learning model sees the world through the lens of its training data. That means that your model gets more and more myopic as the real world gets further from what your training data represents. However, when operating machine learning applications, upgrading your glasses (retraining your model) is not our only concern. We also control the data pipeline that feeds information into our model during production, and thus have the responsibility to ensure its quality.
Concept drift can have different sources and can originate at different stages of your data pipeline, even before the data collection itself. To take the correct measures, one must be able to pinpoint its source. However, whatever the source of the problem and whatever corrective measures are needed, it all starts with one basic requirement: we need to be aware that an issue exists in the first place, which is a challenge in itself.
In this article, we’ll show how whylogs can help you monitor your machine learning system’s data ingestion pipeline by enabling concept drift detection, specifically for image data. To do this, we’ll work with a couple of demo use cases.
For the first case, we will source images from two distinct datasets: Landscape Pictures and IEEE Signal Processing Society Camera Model Identification. We’ll create a scenario where the semantic meaning changes drastically and monitor specific features that let us detect unexpected data changes. We’ll also monitor the image metadata, which describes the content, the environment in which the image was created, and its context.
In the second scenario, we’ll demonstrate how to create more generalized semantic metrics. Together with the ability to define custom transformations, this lets us monitor specialized semantic information directly from our dataset in a consistent, versatile manner. In this case, we’ll generate image embeddings from a pre-trained DenseNet model and use them as our semantic-full data. We’ll simulate deployments with datasets of different CIFAR-10 classes as examples of how to capture and potentially detect these shifts in our data and models.
Metadata and Features
Even though we don’t have an actual model for prediction, let’s assume that our model input is expected to consist mainly of landscape images. Using a simulated production stage, we can test if it’s possible to detect the presence of out-of-domain images by logging some basic image features. These can include metadata information and properties such as hue, saturation, and brightness.
We’ll store the images in folders as follows:
- landscape_baseline/baseline: 1800 JPEG images of landscapes, which we’ll use as the baseline for comparison.
The remaining three folders contain the image data to be monitored during production:
- landscape_images: 200 JPEG images of landscapes
- camera_images: 200 TIFF images of no particular category, extracted from the Camera Model Identification dataset
- mixed_images: a 50/50 combination of the previous two datasets
One way to detect data changes is to create and compare histograms of image properties such as hue, saturation, or brightness (HSB). We can use whylogs to automatically log these properties, and then use the profiled information to plot the histograms. Note that the HSB color space is more useful in understanding content than the RGB color space — and its associated histograms — since these properties are directly related to visual interpretation of the images.
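To make the HSB idea concrete, here is a minimal NumPy sketch of a per-pixel saturation histogram. It assumes the image is already an RGB array of shape (H, W, 3) with values in 0-255 and uses the HSV definition S = (max - min) / max, rescaled to 0-255. This is illustrative only and is not the implementation behind whylogs’ Saturation transform; `saturation_channel` and `saturation_histogram` are hypothetical helper names.

```python
import numpy as np


def saturation_channel(rgb: np.ndarray) -> np.ndarray:
    """Per-pixel HSV saturation scaled to 0-255 (hypothetical helper)."""
    rgb = rgb.astype(float)
    c_max = rgb.max(axis=-1)
    c_min = rgb.min(axis=-1)
    # S = (max - min) / max, defined as 0 for black pixels (max == 0)
    sat = np.divide(c_max - c_min, c_max, out=np.zeros_like(c_max), where=c_max > 0)
    return sat * 255.0


def saturation_histogram(rgb: np.ndarray, n_bins: int = 16):
    """Normalized histogram of saturation values over a fixed 0-255 range."""
    sat = saturation_channel(rgb).ravel()
    counts, edges = np.histogram(sat, bins=n_bins, range=(0.0, 255.0))
    return counts / counts.sum(), edges


# Toy 1x2 "image": one pure red pixel and one gray pixel
image = np.array([[[255, 0, 0], [128, 128, 128]]], dtype=np.uint8)
pmf, edges = saturation_histogram(image, n_bins=4)
print(pmf)  # half the mass in the lowest bin, half in the highest
```

Using a fixed range of 0-255 means histograms from different batches share bin edges and can be compared bin by bin.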
For this case we’ll create 3 different datasets:
- The first contains information about our baseline dataset
- The second will simulate a batch of images that is expected: only landscape images
- The third will contain only “unseen” images, unrelated to landscapes
The idea is that the data distribution of batches received during production should be similar to the distribution of our baseline set. If we can detect that the distributions between the two begin to drift apart, it might be an alert for further inspection.
Let’s begin by creating a session, which is how our application interacts with whylogs:
```python
from whylogs import get_or_create_session

session = get_or_create_session()
```
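We then log the content of each folder, starting with the baseline:
```python
# whylogs v0.x API (the version this article was written against)
from whylogs.features.transforms import ComposeTransforms, Resize, Saturation

with session.logger(dataset_name="baseline") as logger:
    # Resize each image to 50x50 pixels, then extract per-pixel saturation
    logger.log_local_dataset(
        root_dir="landscape_baseline",
        image_feature_transforms=[ComposeTransforms([Resize(50), Saturation()])],
        show_progress=True,
    )
    profile_baseline = logger.profile
```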
As you can see in the code above, we can apply image transformations prior to logging through the image_feature_transforms argument. We can also mix and match functions with ComposeTransforms() to build a customized feature transformation pipeline. In this case, we resize our images to 50x50 pixels and then compute the saturation value of every pixel.
We also have to log the images in landscape_images and camera_images. Since the code is exactly the same as above, we’ll avoid repeating it here; the only differences are the dataset_name and root_dir arguments.
Then, we can define our functions to plot our histograms:
```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def get_custom_histogram_info(profiles, variable, n_bins):
    """Shared bin edges and per-profile probability masses for `variable`."""
    summaries = [profile.flat_summary()["summary"] for profile in profiles]
    # Global min/max of the feature across all profiles, so every histogram shares the same bins
    min_range = int(min(s.loc[s["column"] == variable, "min"].iloc[0] for s in summaries))
    max_range = int(max(s.loc[s["column"] == variable, "max"].iloc[0] for s in summaries))
    # n_bins equal-width bins between the global min and max
    bins = np.linspace(min_range, max_range, n_bins + 1)
    # The KLL sketch returns one probability mass per interval defined by the split points
    counts = [
        profile.columns[variable].number_tracker.histogram.get_pmf(bins[:-1].tolist())
        for profile in profiles
    ]
    return bins, counts
```
```python
def plot_distribution_shift(profiles, variable, n_bins, baseline=None):
    """Visualization for distribution shift"""
    bins, counts = get_custom_histogram_info(profiles, variable, n_bins)
    fig, ax = plt.subplots(figsize=(10, 3))
    if baseline:
        # Baseline histograms are drawn first, in the background
        bins_b, counts_b = get_custom_histogram_info(baseline, variable, n_bins)
        for idx, profile in enumerate(baseline):
            sns.histplot(x=bins_b, weights=counts_b[idx], bins=n_bins,
                         label="baseline", color="teal", alpha=0.7, ax=ax)
    for idx, profile in enumerate(profiles):
        sns.histplot(x=bins, weights=counts[idx], bins=n_bins,
                     label=profile.name, color="gold", alpha=0.7, ax=ax)
    ax.legend()
    plt.show()
```
These functions plot a histogram of the chosen feature for every profile in the input list, using bin edges computed over all of those profiles. If a baseline is provided, the histogram for the baseline dataset is also plotted in the background for comparison.
Let’s plot their distributions:
As we can see, the resulting saturation distribution for the landscape set is much more similar to the baseline set than our “drifted” batch. If we know what to expect, plots such as the second one will certainly be reason for attention.
A second way to obtain useful information is to monitor the logged images’ metadata. Not every file will have metadata, but whatever is available is automatically logged by whylogs. We can also log and search for specific tags.
EXIF tags are a standard used by most camera and image software manufacturers. Unfortunately, as one moves to more obscure data formats, one tends to encounter custom metadata; thankfully, these are generally TIFF tags. Below is an example of some of the TIFF tags found in one of my personal camera images.
╔═══════════════════════════╦══════════════════════╗
║ Tag                       ║ Value                ║
╠═══════════════════════════╬══════════════════════╣
║ Manufacturer              ║ Canon                ║
║ Model                     ║ EOS-600D             ║
║ Orientation               ║ top-left             ║
║ Software                  ║ Ver2.21              ║
║ Date and time             ║ 2020:08:11 16:55:22  ║
║ YCbCr positioning         ║ centered             ║
║ Compression               ║ JPEG compression     ║
║ file_format               ║ jpg                  ║
║ X resolution              ║ 72.00                ║
║ Y resolution              ║ 72.00                ║
║ Resolution unit           ║ Inch                 ║
║ Compressed bits per pixel ║ 4                    ║
║ Exposure bias             ║ 0.0                  ║
║ Max. aperture value       ║ 2.00                 ║
║ Metering mode             ║ Pattern              ║
║ Flash                     ║ Flash did not fire   ║
║ Focal length              ║ 16.1 mm              ║
║ Color space               ║ sRGB                 ║
║ Pixel X dimension         ║ 2240                 ║
║ Pixel Y dimension         ║ 1680                 ║
╚═══════════════════════════╩══════════════════════╝
Resolution unit, along with X and Y resolution, gives us a view into potential object scaling problems: if the physical size of each pixel changes while the camera location (relative to the objects of interest) has not, we could run into objects that are smaller or larger than previously encountered, especially if no augmentation step accounted for this. Meanwhile, compressed bits per pixel could inform us of pixel intensity scaling or of new image context, such as compression artifacts that were not present in the baseline set.
In this simple case, we’ll exploit the file_format information. As shown before, the landscape images are expected to be in JPEG format, while our “noise” images are all in TIFF format. Therefore, we can monitor the count of unique values across the dataset to ensure consistency between batches.
To demonstrate, we’ll log a series of five batches to the same dataset. Four of the batches consist only of landscape images, and the other is a half-and-half mix of landscape and out-of-domain images. We can then plot the number of unique file formats in each batch:
```python
from whylogs.viz import ProfileVisualizer

viz = ProfileVisualizer()
# profiles_mixed: the list of profiles from the five logged batches
viz.set_profiles(profiles_mixed)
viz.plot_uniqueness("file_format")
```
As expected, the plot reflects the number of unique file formats for each batch of data. Even though this is a very simple case, it demonstrates how changes in the number of file formats can indicate the presence of data issues.
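Outside of whylogs, the same uniqueness check reduces to counting distinct values per batch. The following is a minimal pure-Python sketch, assuming each batch is given as a list of file paths and that the format can be read from the file extension (whylogs records it in the file_format feature instead). `unique_formats_per_batch` and `flag_batches` are hypothetical names.

```python
from pathlib import Path
from typing import Dict, List


def unique_formats_per_batch(batches: Dict[str, List[str]]) -> Dict[str, int]:
    """Number of distinct file extensions per batch ('jpeg' folded into 'jpg')."""
    counts = {}
    for name, paths in batches.items():
        formats = {Path(p).suffix.lower().lstrip(".").replace("jpeg", "jpg") for p in paths}
        counts[name] = len(formats)
    return counts


def flag_batches(counts: Dict[str, int], expected: int = 1) -> List[str]:
    """Batches whose number of distinct formats differs from the expected value."""
    return [name for name, n in counts.items() if n != expected]


batches = {
    "batch_1": ["a.jpg", "b.jpg"],
    "batch_2": ["c.jpg", "d.JPG"],
    "batch_3": ["e.jpg", "f.tiff"],  # mixed batch
}
counts = unique_formats_per_batch(batches)
print(counts)                # {'batch_1': 1, 'batch_2': 1, 'batch_3': 2}
print(flag_batches(counts))  # ['batch_3']
```

Here the expected count of one format per batch is an assumption that matches the landscape-only scenario; a pipeline that legitimately mixes formats would set `expected` accordingly.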
Metadata gives us the power to quickly detect potential issues with dates, instrumentation, image size, encoding, and authorship. This, in turn, lets an image-monitoring pipeline act as a quick fail-safe. This might not be as important for someone working with a static Kaggle dataset, but a large company trying to keep, say, medical images consistent and experiments reproducible will be incredibly grateful.
Additionally, metadata can provide contextual information that is valuable when debugging faults during a retrospective.
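A simple way to turn metadata into such a fail-safe is to record, for each tag, the set of values seen in the baseline and report any batch value outside that set. Below is a minimal sketch, assuming each image’s metadata has already been extracted into a flat dict of tag to value; `unseen_metadata_values` is a hypothetical helper, and exact matching only makes sense for categorical tags such as Model or Software, not for timestamps.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set


def unseen_metadata_values(
    baseline: Iterable[Dict[str, str]], batch: Iterable[Dict[str, str]]
) -> Dict[str, Set[str]]:
    """Per tag, the values present in `batch` that never occurred in `baseline`."""
    seen = defaultdict(set)
    for meta in baseline:
        for tag, value in meta.items():
            seen[tag].add(value)
    new_values = defaultdict(set)
    for meta in batch:
        for tag, value in meta.items():
            if value not in seen[tag]:  # also catches tags absent from the baseline
                new_values[tag].add(value)
    return dict(new_values)


baseline = [{"Model": "EOS-600D", "file_format": "jpg"}] * 3
batch = [
    {"Model": "EOS-600D", "file_format": "jpg"},
    {"Model": "OtherCam", "file_format": "tiff"},  # toy out-of-domain image
]
print(unseen_metadata_values(baseline, batch))
```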
In a typical image classification task, it is important to be aware of changes in distributions across classes. While ground truth labels are often not readily available, we can use transfer learning to generate feature embeddings and gain some semantic insight into our data. In this scenario, we’ll use a subset of the CIFAR-10 dataset, namely the bird, cat, deer, and horse classes, to form an “animals” baseline. We’ll create a separate, skewed “vehicles” set made of samples from two of the CIFAR-10 vehicle classes.
Custom Features — Distances from Cluster Centers
In this case, we want to generate and log features that semantically represent the distance from each logged image to the “ideal” representation of each class. To do this, we can use whylogs’ ability to apply custom functions that generate specific image features. That way, we can log the distance from each image to both classes and plot a histogram to assess the distance distributions for each batch.
First, we need to embed the original pixel values into an N-dimensional vector. In this example, the feature extractor is based on the DenseNet-121 architecture and generates 1024-dimensional vectors. These embeddings can be further normalized to unit length, so that the dot product between any two of them equals their cosine similarity and lies in the range [-1, 1]; this places each vector on the surface of a unit hypersphere.
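The effect of normalization is easy to check numerically: after dividing each embedding by its L2 norm, the dot product of two embeddings equals their cosine similarity and is bounded by -1 and 1. In this minimal NumPy sketch, random vectors stand in for the 1024-dimensional DenseNet-121 embeddings (no model is loaded), and `l2_normalize` is a hypothetical helper.

```python
import numpy as np


def l2_normalize(embeddings: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit length so it lies on the unit hypersphere."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.maximum(norms, eps)


rng = np.random.default_rng(0)
raw = rng.normal(size=(5, 1024))  # stand-ins for DenseNet-121 feature vectors
unit = l2_normalize(raw)

# Pairwise dot products of unit vectors are cosine similarities in [-1, 1]
similarity = unit @ unit.T
print(np.round(np.diag(similarity), 6))  # each vector has similarity 1 with itself
```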
Check out this post by Apple’s ML team on how similar embeddings are used in their face recognition pipeline.
Since we have the labels for each image in the training stage, we can calculate the centroid coordinates for each class beforehand. With the cluster centers at hand, we can choose the images (or objects) closest to these centers as our semantic centers. These images give us anchors to compute and monitor the distribution of our embeddings. We can create a custom function to log the distance from each image to the centers by calculating the dot product between them (for normalized embeddings, a higher dot product means a smaller distance):
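The center-selection step described above might look like the following NumPy sketch. It assumes the labeled training embeddings are already L2-normalized and stored as an (N, D) array with one integer label per row; each class’s semantic center is taken to be the training embedding with the highest dot product with the normalized class centroid. `semantic_centers` is a hypothetical name, and the 2-D vectors are toy data.

```python
from typing import Dict

import numpy as np


def semantic_centers(embeddings: np.ndarray, labels: np.ndarray) -> Dict[int, np.ndarray]:
    """Map each class label to the embedding of its most central example."""
    centers = {}
    for label in np.unique(labels):
        class_emb = embeddings[labels == label]
        centroid = class_emb.mean(axis=0)
        centroid /= np.linalg.norm(centroid)
        # For unit vectors, the highest dot product is the smallest Euclidean distance
        best = np.argmax(class_emb @ centroid)
        centers[int(label)] = class_emb[best]
    return centers


emb = np.array([
    [1.0, 0.0], [0.96, 0.28], [0.8, 0.6],   # class 0
    [0.0, 1.0], [0.6, 0.8], [0.28, 0.96],   # class 1
])
lab = np.array([0, 0, 0, 1, 1, 1])
print(semantic_centers(emb, lab))
```

Picking an actual training image rather than the raw centroid keeps each anchor a real, inspectable example, which matches the article’s idea of choosing the images closest to the cluster centers.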
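```python
from typing import Callable, List

import numpy as np


class SemanticCenterDistance:
    """Custom transform: dot product between an image's embedding and each semantic center."""

    def __init__(self, semantic_centers: List[np.ndarray], embedding_function: Callable):
        self.semantic_centers = semantic_centers
        self.embedding_function = embedding_function

    def __call__(self, x):
        # Embed the image once, then score it against every center
        embedding = self.embedding_function(x)
        return np.array(
            [np.dot(embedding, center) for center in self.semantic_centers]
        ).reshape(-1, 1)

    def __repr__(self):
        return self.__class__.__name__
```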
As in the previous case, we can apply our custom transformation to log two datasets. The first is made up of four randomly chosen animal classes (bird, cat, deer, horse) from the CIFAR-10 dataset. The second, an unknown set of semantic-full objects, is made up of two vehicle categories (including automobile), which introduces drift into our datasets.
In the image above, we use a UMAP projection to visualize the embeddings, with each category’s semantic center in the original embedding space marked by the location of its category label. Note that since we are using a plain UMAP embedding, the global spatial arrangement is not preserved in the projection. Nonetheless, it still shows how embeddings cluster around semantically similar objects.
We compute our semantic-full metrics using only the original baseline centers (bird, cat, deer, horse), since we know these are part of our original baseline. In the figure below, we plot the distributions of the distances from each semantic center for the baseline (left chart) and for the vehicle-only set (right chart).
We can then combine these distances into a single metric, as defined in the custom function above. In the figure below, we show the distribution for vehicle-only embeddings relative to the baseline semantic centers, along with the distributions for a mixed set of animal and vehicle embeddings and for animal-only embeddings, as a way to visualize the semantic drift.
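The custom function returns its per-center scores as a single column, so every image contributes one value per center to the same metric. For a whole batch, a comparable pooled sample can be sketched in NumPy as follows, assuming the batch embeddings and the centers are L2-normalized row vectors; `pooled_center_scores` is a hypothetical helper and the 4-D vectors are toy data.

```python
import numpy as np


def pooled_center_scores(batch_embeddings: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Flatten the (n_images, n_centers) dot-product matrix into one 1-D sample."""
    return (batch_embeddings @ centers.T).ravel()


centers = np.eye(4)[:2]  # two toy unit-length "semantic centers" in 4-D
batch = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],  # orthogonal to both centers
    [0.6, 0.8, 0.0, 0.0],
])
scores = pooled_center_scores(batch, centers)
print(scores)
```

The resulting 1-D sample is what gets compared against the baseline’s pooled sample when quantifying drift.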
To quantify this drift, we use the KS statistic for each distribution compared to the baseline (animal-only) distribution. For example, comparing the baseline to the vehicle embeddings leads to a KS statistic of 0.25. Given the large number of embeddings used to compute the distributions above, the corresponding p-value is very small (p < 0.0001), so we reject the null hypothesis that both samples come from the same distribution. In this simple experiment, it can be shown that we would need approximately 200 embeddings from the baseline and shifted (vehicle) distributions to significantly reject the null hypothesis using this embedding model and these classes. This number gives us a practical minimum batch size for detecting such changes in our pipeline for this particular example.
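For reference, the two-sample KS statistic D is the largest vertical gap between the two empirical CDFs, and for large samples the null hypothesis is rejected at level α when D exceeds c(α)·sqrt((n + m)/(n·m)), with c(α) = sqrt(-ln(α/2)/2). The plain NumPy sketch below implements both; in practice `scipy.stats.ks_2samp` computes the statistic together with a p-value. The normal samples are synthetic stand-ins, not the article’s embedding scores.

```python
import numpy as np


def ks_statistic(a, b) -> float:
    """Two-sample Kolmogorov-Smirnov statistic: max gap between empirical CDFs."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, grid, side="right") / a.size
    cdf_b = np.searchsorted(b, grid, side="right") / b.size
    return float(np.max(np.abs(cdf_a - cdf_b)))


def ks_critical_value(n: int, m: int, alpha: float = 0.05) -> float:
    """Large-sample critical value: reject 'same distribution' if D exceeds it."""
    c_alpha = np.sqrt(-0.5 * np.log(alpha / 2.0))
    return float(c_alpha * np.sqrt((n + m) / (n * m)))


rng = np.random.default_rng(42)
baseline = rng.normal(0.0, 1.0, size=200)
shifted = rng.normal(0.5, 1.0, size=200)  # synthetic stand-in for drifted scores
d = ks_statistic(baseline, shifted)
print(d, ks_critical_value(200, 200), d > ks_critical_value(200, 200))
```

The critical value shrinks roughly as 1/sqrt(n), which is why larger batches can detect smaller shifts.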
Real world data can behave a lot differently than you might expect. Distributional drift is one of the most common problems encountered in AI applications and, as such, you should have tools to effectively detect those issues the moment they appear.
In this article we covered some common use cases to demonstrate how whylogs can help you detect data drift issues in images, ranging from using default image features and metadata information to applying custom transformations to generate specific features.
That said, the examples covered in this article can be further expanded. For instance, in the saturation case, we could continue our analysis by quantifying the histogram shift through metrics such as the KS statistic or other distribution metrics. Furthermore, these metrics and their associated thresholds could be included in your CI/CD pipeline; for example, check out this post on how you can include whylogs constraint checks in your GitHub Actions.
When available, other types of metadata could also provide valuable insights, such as aspect ratio and resolution, or object detection and segmentation data that further contextualize the subject matter. In the semantic distribution case, we could gradually increase the proportion of the unknown classes to define a viable threshold value for the KS statistic, or for other metrics such as the KL divergence or Hellinger distance. What’s more, these and other approaches can be integrated into different stages of the data pipeline to provide full observability of your machine learning application.
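The KL divergence and Hellinger distance both operate on binned distributions. Below is a minimal NumPy sketch, assuming both histograms were computed over the same bin edges; the function names and toy counts are illustrative. KL divergence is asymmetric and unbounded (a small epsilon keeps empty bins from producing infinities), while the Hellinger distance is symmetric and bounded in [0, 1], which can make a fixed alert threshold easier to choose.

```python
import numpy as np


def _to_pmf(counts, eps: float = 0.0) -> np.ndarray:
    """Turn bin counts into probabilities, optionally smoothing empty bins."""
    p = np.asarray(counts, dtype=float) + eps
    return p / p.sum()


def kl_divergence(p_counts, q_counts, eps: float = 1e-10) -> float:
    """KL(P || Q) in nats; eps smooths empty bins so the log stays finite."""
    p, q = _to_pmf(p_counts, eps), _to_pmf(q_counts, eps)
    return float(np.sum(p * np.log(p / q)))


def hellinger_distance(p_counts, q_counts) -> float:
    """Hellinger distance in [0, 1]: 0 for identical, 1 for disjoint histograms."""
    p, q = _to_pmf(p_counts), _to_pmf(q_counts)
    return float(np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)))


baseline_hist = [10, 30, 40, 20]  # toy bin counts over shared bin edges
batch_hist = [25, 25, 25, 25]
print(kl_divergence(batch_hist, baseline_hist), hellinger_distance(batch_hist, baseline_hist))
```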
Here is one such recent method that uses transformer-based models for few-shot OOD detection. The authors also describe some common metrics used to evaluate this task.
If you are interested in contributing to whylogs or have any questions, please feel free to reach out to @lalmei, or head over to our open-source repository at https://github.com/whylabs/whylogs. You can also check out our documentation page at https://docs.whylabs.ai for more information on whylogs.