Source code for lensai_profiler.metrics

# lensai/metrics.py

import numpy as np
import datasketches

[docs]class Metrics: def __init__(self, framework='tf'): """ Initialize the Metrics class. Args: framework: Specify the framework ('tf' for TensorFlow, 'pt' for PyTorch). """ self.framework = framework if self.framework == 'tf': global tf try: import tensorflow as tf except ImportError: raise ImportError("TensorFlow is not installed. Please install it to use TensorFlow metrics.") elif self.framework == 'pt': global torch try: import torch except ImportError: raise ImportError("PyTorch is not installed. Please install it to use PyTorch metrics.") else: raise ValueError("Unsupported framework. Use 'tf' for TensorFlow or 'pt' for PyTorch.")
[docs] def calculate_brightness(self, image): """ Calculate the brightness of an image. Args: image: A TensorFlow tensor representing an RGB image. Returns: A TensorFlow tensor containing the mean brightness of the image. """ if self.framework == 'tf': return self._calculate_brightness_tf(image) elif self.framework == 'pt': return self._calculate_brightness_pt(image)
[docs] def calculate_sharpness_laplacian(self, image): """ Calculate the sharpness of an image using the Laplacian operator. Args: image: A TensorFlow tensor representing an RGB image. Returns: A TensorFlow tensor containing the sharpness of the image. """ if self.framework == 'tf': return self._calculate_sharpness_laplacian_tf(image) elif self.framework == 'pt': return self._calculate_sharpness_laplacian_pt(image)
[docs] def calculate_channel_mean(self, image): """ Calculate the mean of each channel of an image. Args: image: A TensorFlow tensor representing an image. Returns: A TensorFlow tensor containing the mean of each channel. """ if self.framework == 'tf': return self._calculate_channel_mean_tf(image) elif self.framework == 'pt': return self._calculate_channel_mean_pt(image)
[docs] def calculate_snr(self, image): """ Calculate the Signal-to-Noise Ratio (SNR) of an image. Args: image_tensor: A TensorFlow tensor representing an RGB or RGBA image. Returns: A TensorFlow tensor containing the SNR of the image. """ if self.framework == 'tf': return self._calculate_snr_tf(image) elif self.framework == 'pt': return self._calculate_snr_pt(image)
[docs] def calculate_channel_histogram(self, image): """ Calculate the histogram of the channels of an image. Args: image: A TensorFlow tensor representing an image. Returns: A TensorFlow tensor containing the histogram of the image channels. """ if self.framework == 'tf': return self._calculate_channel_histogram_tf(image) elif self.framework == 'pt': return self._calculate_channel_histogram_pt(image)
[docs] def process_batch(self, images): """ Process a batch of images and calculate various metrics. Args: images: A TensorFlow tensor representing a batch of images. Returns: A tuple containing the brightness, sharpness, channel mean, SNR, and channel histogram of the batch. """ if self.framework == 'tf': return self._process_batch_tf(images) elif self.framework == 'pt': return self._process_batch_pt(images)
# TensorFlow implementations def _calculate_brightness_tf(self, image): grayscale = tf.image.rgb_to_grayscale(image) return tf.reduce_mean(grayscale) def _calculate_sharpness_laplacian_tf(self, image): kernel = tf.constant([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=tf.float32) kernel = tf.reshape(kernel, [3, 3, 1, 1]) grayscale = tf.image.rgb_to_grayscale(image) grayscale = tf.expand_dims(grayscale, axis=0) sharpness = tf.nn.conv2d(grayscale, kernel, strides=[1, 1, 1, 1], padding='SAME') return tf.reduce_mean(tf.abs(sharpness)) def _calculate_channel_mean_tf(self, image): return tf.reduce_mean(image, axis=[0, 1]) def _calculate_snr_tf(self, image_tensor): grayscale = tf.image.rgb_to_grayscale(image_tensor) mean, variance = tf.nn.moments(grayscale, axes=[0, 1]) sigma = tf.sqrt(variance) snr = tf.where(sigma == 0, np.inf, 20 * tf.math.log(mean / sigma + 1e-7) / tf.math.log(10.0)) return snr def _calculate_channel_histogram_tf(self, image): num_channels = image.shape[-1] channel_pixels = tf.reshape(image, [-1, num_channels]) return channel_pixels def _calculate_channel_histogram_pt(self, image): num_channels = image.shape[0] channel_pixels = image.view(num_channels, -1).permute(1, 0) # Flatten the image and switch dimensions to (num_pixels, num_channels) return channel_pixels def _process_batch_tf(self, images): brightness = tf.map_fn(self._calculate_brightness_tf, images, dtype=tf.float32) sharpness = tf.map_fn(self._calculate_sharpness_laplacian_tf, images, dtype=tf.float32) channel_mean = tf.map_fn(self._calculate_channel_mean_tf, images, dtype=tf.float32) snr = tf.map_fn(self._calculate_snr_tf, images, dtype=tf.float32) channel_pixels = tf.map_fn(self._calculate_channel_histogram_tf, images, dtype=tf.float32) return brightness, sharpness, channel_mean, snr, channel_pixels # PyTorch implementations def _calculate_brightness_pt(self, image): grayscale = torch.mean(image, dim=0, keepdim=True) return torch.mean(grayscale) def _calculate_sharpness_laplacian_pt(self, image): # Define the Laplacian kernel kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32) kernel = kernel.view(1, 1, 3, 3) # Reshape to (out_channels, in_channels, height, width) # Convert the image to grayscale if it has more than one channel if image.size(1) != 1: # Check if the image has more than one channel grayscale = torch.mean(image, dim=0, keepdim=True) # Convert to grayscale else: grayscale = image # Image is already grayscale grayscale = grayscale.unsqueeze(0) # Add a batch dimension if necessary, shape becomes [1, 1, H, W] #Apply the Laplacian kernel using conv2d sharpness = torch.nn.functional.conv2d(grayscale, kernel, stride=1, padding=1) return torch.mean(torch.abs(sharpness)) # Calculate mean absolute sharpness value def _calculate_channel_mean_pt(self, image): return torch.mean(image, dim=[1, 2]) def _calculate_snr_pt(self, image_tensor): grayscale = torch.mean(image_tensor, dim=0, keepdim=True) mean = torch.mean(grayscale) sigma = torch.std(grayscale) snr = torch.where(sigma == 0, torch.tensor(float('inf')), 20 * torch.log10(mean / (sigma + 1e-7))) return snr def _process_batch_pt(self, images): brightness = torch.stack([self._calculate_brightness_pt(img) for img in images]) sharpness = torch.stack([self._calculate_sharpness_laplacian_pt(img) for img in images]) channel_mean = torch.stack([self._calculate_channel_mean_pt(img) for img in images]) snr = torch.stack([self._calculate_snr_pt(img) for img in images]) channel_pixels = torch.stack([self._calculate_channel_histogram_pt(img) for img in images]) return brightness, sharpness, channel_mean, snr, channel_pixels
# Utility functions for both TensorFlow and PyTorch
[docs]def calculate_percentiles(x, probabilities, lower_percentile=0.01, upper_percentile=0.99): """ Calculates percentiles from a PMF (Probability Mass Function) represented as two separate lists. Args: x: List containing the x-values (possible values) in the distribution. probabilities: List containing the probabilities corresponding to the x-values. lower_percentile: Float between 0 and 1 (inclusive) specifying the lower percentile (default 0.01). upper_percentile: Float between 0 and 1 (inclusive) specifying the upper percentile (default 0.99). Returns: A tuple containing the lower and upper percentiles (x-values, float). """ # Ensure lists have the same length if len(x) != len(probabilities): raise ValueError("x and probabilities lists must have the same length") # Ensure PMF is a valid probability distribution (sums to 1) if not np.isclose(sum(probabilities), 1): raise ValueError("PMF must sum to 1") # Combine x-values and probabilities into a single list of tuples pmf = list(zip(x, probabilities)) # Sort PMF based on x-values (ascending order) pmf.sort(key=lambda item: item[0]) # Calculate cumulative sum of probabilities cdf = np.cumsum([p for _, p in pmf]) # Calculate percentile indices with edge case handling lower_percentile_idx = np.searchsorted(cdf, lower_percentile, side='right') upper_percentile_idx = np.searchsorted(cdf, upper_percentile, side='right') # Access corresponding x-values from the sorted PMF lower_percentile_value = pmf[lower_percentile_idx][0] if lower_percentile_idx < len(pmf) else pmf[-1][0] upper_percentile_value = pmf[upper_percentile_idx][0] if upper_percentile_idx < len(pmf) else pmf[-1][0] return lower_percentile_value, upper_percentile_value
[docs]def get_histogram_sketch(sketch, num_splits=30): """ Reads a binary file, deserializes the content, and extracts the PMF. Args: sketch: A probabilistic data structure representing the sketch of the distribution. num_splits: Number of splits for the PMF (default: 30). Returns: A tuple containing x-axis values and the PMF. """ if sketch.is_empty(): return None, None xmin = sketch.get_min_value() try: step = (sketch.get_max_value() - xmin) / num_splits except ZeroDivisionError: print(f"Error: num_splits should be non-zero") return None, None if step == 0: step = 0.01 splits = [xmin + (i * step) for i in range(0, num_splits)] pmf = sketch.get_pmf(splits) x = splits + [sketch.get_max_value()] # Append max value for x-axis return x, pmf