import unittest
import numpy as np
from lensai_profiler.metrics import Metrics, calculate_percentiles, get_histogram_sketch
class TestLensaiMetrics(unittest.TestCase):
    """Check ``Metrics`` against reference values for TensorFlow and PyTorch.

    Every test iterates over both framework keys (``'tf'`` and ``'pt'``),
    calls the corresponding ``Metrics`` method, and compares the result with
    a value recomputed in the test using that framework's own ops.

    The same nested-list pixel data (shape ``(3, 2, 3)``) is used for both
    frameworks. The TensorFlow reference code reduces over the last axis as
    the channel axis, while the PyTorch reference code reduces over axis 0.
    """

    # Pixel data shared by both frameworks so the fixtures cannot drift apart.
    _RGB_VALUES = [
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    ]
    _RGBA_VALUES = [
        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
        [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
        [[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]],
    ]

    def setUp(self):
        """Build float32 test images for each framework, keyed by framework name."""
        self.frameworks = ['tf', 'pt']
        self.images_rgb = {}
        self.images_rgba = {}
        for framework in self.frameworks:
            # Frameworks are imported lazily, only inside the branch that uses them.
            if framework == 'tf':
                import tensorflow as tf
                self.images_rgb[framework] = tf.constant(self._RGB_VALUES, dtype=tf.float32)
                self.images_rgba[framework] = tf.constant(self._RGBA_VALUES, dtype=tf.float32)
            elif framework == 'pt':
                import torch
                self.images_rgb[framework] = torch.tensor(self._RGB_VALUES, dtype=torch.float32)
                self.images_rgba[framework] = torch.tensor(self._RGBA_VALUES, dtype=torch.float32)

    def test_calculate_brightness(self):
        """Brightness equals the mean of the grayscale image."""
        for framework in self.frameworks:
            # subTest reports each framework separately instead of stopping at the first failure.
            with self.subTest(framework=framework):
                metrics = Metrics(framework=framework)
                image_rgb = self.images_rgb[framework]
                brightness = metrics.calculate_brightness(image_rgb)
                if framework == 'tf':
                    import tensorflow as tf
                    expected_brightness = tf.reduce_mean(tf.image.rgb_to_grayscale(image_rgb))
                    self.assertTrue(np.isclose(brightness.numpy(), expected_brightness.numpy()))
                elif framework == 'pt':
                    import torch
                    expected_brightness = torch.mean(torch.mean(image_rgb, dim=0))
                    self.assertTrue(np.isclose(brightness.item(), expected_brightness.item()))

    def test_calculate_snr(self):
        """SNR equals 20*log10(mean / sigma) of the grayscale image (inf when sigma is 0)."""
        for framework in self.frameworks:
            with self.subTest(framework=framework):
                metrics = Metrics(framework=framework)
                image_rgb = self.images_rgb[framework]
                snr = metrics.calculate_snr(image_rgb)
                if framework == 'tf':
                    import tensorflow as tf
                    grayscale = tf.image.rgb_to_grayscale(image_rgb)
                    mean, variance = tf.nn.moments(grayscale, axes=[0, 1])
                    sigma = tf.sqrt(variance)
                    # The small epsilon keeps the division finite; the where() handles sigma == 0.
                    expected_snr = tf.where(
                        sigma == 0,
                        np.inf,
                        20 * tf.math.log(mean / (sigma + 1e-7)) / tf.math.log(10.0),
                    )
                    self.assertTrue(np.isclose(snr.numpy(), expected_snr.numpy()).all())
                elif framework == 'pt':
                    import torch
                    grayscale = torch.mean(image_rgb, dim=0, keepdim=True)
                    mean = torch.mean(grayscale)
                    sigma = torch.std(grayscale)
                    expected_snr = torch.where(
                        sigma == 0,
                        torch.tensor(float('inf')),
                        20 * torch.log10(mean / (sigma + 1e-7)),
                    )
                    self.assertTrue(np.isclose(snr.item(), expected_snr.item()))

    def test_calculate_sharpness_laplacian(self):
        """Sharpness equals the mean absolute response of a 3x3 Laplacian kernel."""
        for framework in self.frameworks:
            with self.subTest(framework=framework):
                metrics = Metrics(framework=framework)
                image_rgb = self.images_rgb[framework]
                sharpness = metrics.calculate_sharpness_laplacian(image_rgb)
                if framework == 'tf':
                    import tensorflow as tf
                    kernel = tf.constant([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=tf.float32)
                    # TensorFlow kernel layout: [height, width, in_channels, out_channels].
                    kernel = tf.reshape(kernel, [3, 3, 1, 1])
                    if image_rgb.shape[-1] != 1:  # not already single-channel
                        grayscale = tf.image.rgb_to_grayscale(image_rgb)
                    else:
                        grayscale = image_rgb
                    grayscale = tf.expand_dims(grayscale, axis=0)  # add batch dimension
                    expected_sharpness = tf.nn.conv2d(
                        grayscale, kernel, strides=[1, 1, 1, 1], padding='SAME'
                    )
                    expected_sharpness = tf.reduce_mean(tf.abs(expected_sharpness))
                    self.assertTrue(np.isclose(sharpness.numpy(), expected_sharpness.numpy()))
                elif framework == 'pt':
                    import torch
                    # PyTorch kernel layout: [out_channels, in_channels, height, width].
                    kernel = torch.tensor(
                        [[-1, -1, -1],
                         [-1, 8, -1],
                         [-1, -1, -1]],
                        dtype=torch.float32,
                    ).unsqueeze(0).unsqueeze(0)
                    # The grayscale conversion below reduces over dim 0, so dim 0 is the
                    # channel axis here; test that axis (the original tested dim 1).
                    if image_rgb.size(0) != 1:
                        grayscale = torch.mean(image_rgb, dim=0, keepdim=True)
                    else:
                        grayscale = image_rgb
                    grayscale = grayscale.unsqueeze(0)  # add batch dimension -> [1, 1, H, W]
                    expected_sharpness = torch.nn.functional.conv2d(
                        grayscale, kernel, stride=1, padding=1
                    )
                    expected_sharpness = torch.mean(torch.abs(expected_sharpness))
                    self.assertTrue(np.isclose(sharpness.item(), expected_sharpness.item()))

    def test_calculate_channel_mean(self):
        """Channel mean equals the per-channel average over the spatial axes."""
        for framework in self.frameworks:
            with self.subTest(framework=framework):
                metrics = Metrics(framework=framework)
                image_rgb = self.images_rgb[framework]
                channel_mean = metrics.calculate_channel_mean(image_rgb)
                if framework == 'tf':
                    import tensorflow as tf
                    expected_mean = tf.reduce_mean(image_rgb, axis=[0, 1])
                    self.assertTrue(np.allclose(channel_mean.numpy(), expected_mean.numpy()))
                elif framework == 'pt':
                    import torch
                    expected_mean = torch.mean(image_rgb, dim=[1, 2])
                    self.assertTrue(np.allclose(channel_mean.numpy(), expected_mean.numpy()))

    def test_calculate_channel_histogram(self):
        """The channel histogram output has the expected per-image shape."""
        for framework in self.frameworks:
            with self.subTest(framework=framework):
                metrics = Metrics(framework=framework)
                image_rgb = self.images_rgb[framework]
                channel_histogram = metrics.calculate_channel_histogram(image_rgb)
                # The original used assertTrue(shape, expected), which treats the second
                # argument as a failure message and therefore always passed.
                # NOTE(review): the expected shape is chosen to agree with the (2, 6, 3)
                # batch assertion in test_process_batch (6 pixels x 3 channels per image);
                # the previous, never-checked value was (3, 256). Confirm against
                # Metrics.calculate_channel_histogram.
                expected_histogram_shape = (6, 3)
                self.assertEqual(tuple(channel_histogram.shape), expected_histogram_shape)

    def test_process_batch(self):
        """process_batch returns five per-image results stacked along a batch axis of 2."""
        for framework in self.frameworks:
            with self.subTest(framework=framework):
                metrics = Metrics(framework=framework)
                image_rgb = self.images_rgb[framework]
                if framework == 'tf':
                    import tensorflow as tf
                    images = tf.stack([image_rgb, image_rgb], axis=0)
                elif framework == 'pt':
                    import torch
                    images = torch.stack([image_rgb, image_rgb], dim=0)
                brightness, sharpness, channel_mean, snr, channel_pixels = metrics.process_batch(images)
                # Check the shapes of the results.
                self.assertEqual(brightness.shape[0], 2)
                self.assertEqual(sharpness.shape[0], 2)
                self.assertEqual(channel_mean.shape, (2, 3))
                self.assertEqual(snr.shape[0], 2)
                # Each fixture image holds 18 values: 6 pixels with 3 channels each.
                self.assertEqual(channel_pixels.shape, (2, 6, 3))
class TestCalculatePercentiles(unittest.TestCase):
    """Input-validation and edge-case tests for ``calculate_percentiles``."""

    def test_mismatched_lengths(self):
        """Three values with only two probabilities must raise ValueError."""
        x = [1, 2, 3]
        probabilities = [0.1, 0.2]
        with self.assertRaises(ValueError):
            calculate_percentiles(x, probabilities)

    def test_probabilities_not_summing_to_one(self):
        """Probabilities that sum to 0.6 rather than 1 must raise ValueError."""
        x = [1, 2, 3]
        probabilities = [0.1, 0.2, 0.3]
        with self.assertRaises(ValueError):
            calculate_percentiles(x, probabilities)

    def test_edge_case_single_nonzero_probability(self):
        """With all mass on one value, both requested percentiles equal that value."""
        x = [1, 2, 3]
        probabilities = [0, 1, 0]
        lower, upper = calculate_percentiles(x, probabilities, 0.01, 0.99)
        self.assertEqual(lower, 2)
        self.assertEqual(upper, 2)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()