import unittest
import os
import numpy as np
import tensorflow as tf
import torch
from lensai_profiler import Sketches # Replace with the actual import path
[docs]class TestSketches(unittest.TestCase):
[docs] def setUp(self):
"""Setup common test data and instances."""
self.num_channels = 3
self.sketches = Sketches() # Initialize the Sketches class
self.save_path = "./sketches_test"
os.makedirs(self.save_path, exist_ok=True)
[docs] def tearDown(self):
# Clean up any resources after each test
if os.path.exists(self.save_path):
for root, dirs, files in os.walk(self.save_path, topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))
os.rmdir(self.save_path)
[docs] def test_initialization(self):
"""Test that the Sketches class initializes correctly."""
self.sketches.register_metric('test_metric', num_channels=self.num_channels)
self.assertIn('test_metric', self.sketches.sketch_registry)
self.assertEqual(len(self.sketches.sketch_registry['test_metric']), self.num_channels)
[docs] def test_update_kll_sketch_tensorflow(self):
"""Test updating KLL sketch with TensorFlow tensor data."""
self.sketches.register_metric('brightness')
tf_values = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32)
self.sketches.update_kll_sketch(self.sketches.sketch_registry['brightness'], tf_values)
# Ensure the sketch has been updated
self.assertGreater(self.sketches.sketch_registry['brightness'].n, 0)
[docs] def test_update_kll_sketch_pytorch(self):
"""Test updating KLL sketch with PyTorch tensor data."""
self.sketches.register_metric('brightness')
torch_values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
self.sketches.update_kll_sketch(self.sketches.sketch_registry['brightness'], torch_values)
# Ensure the sketch has been updated
self.assertGreater(self.sketches.sketch_registry['brightness'].n, 0)
[docs] def test_update_sketches_tensorflow(self):
"""Test updating all sketches with TensorFlow tensor data."""
self.sketches.register_metric('brightness')
self.sketches.register_metric('sharpness')
self.sketches.register_metric('channel_mean', num_channels=self.num_channels)
self.sketches.register_metric('snr')
self.sketches.register_metric('channel_pixels', num_channels=self.num_channels)
brightness = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32)
sharpness = tf.constant([4.0, 5.0, 6.0], dtype=tf.float32)
channel_mean = tf.constant([[1.0, 2.0, 3.0]], dtype=tf.float32)
snr = tf.constant([7.0, 8.0, 9.0], dtype=tf.float32)
channel_pixels = tf.constant([[10.0, 11.0, 12.0]], dtype=tf.float32)
print(self.sketches.sketch_registry['channel_mean'][0].n)
self.sketches.update_sketches(brightness=brightness, sharpness=sharpness, channel_mean=channel_mean, snr=snr, channel_pixels=channel_pixels)
self.assertGreater(self.sketches.sketch_registry['brightness'].n, 0)
self.assertGreater(self.sketches.sketch_registry['sharpness'].n, 0)
self.assertGreater(self.sketches.sketch_registry['snr'].n, 0)
self.assertGreater(self.sketches.sketch_registry['channel_mean'][0].n, 0)
self.assertGreater(self.sketches.sketch_registry['channel_pixels'][0].n, 0)
[docs] def test_save_and_load_sketches(self):
"""Test saving and loading sketches."""
# Register and update a sketch with some dummy data
self.sketches.register_metric('brightness')
brightness = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32)
self.sketches.update_kll_sketch(self.sketches.sketch_registry['brightness'], brightness)
# Save sketches to disk
self.sketches.save_sketches(self.save_path)
# Create a new Sketches instance and load from disk
new_sketches = Sketches()
new_sketches.load_sketches(self.save_path)
# Check if the loaded sketch is the same as the saved one
self.assertTrue('brightness' in new_sketches.sketch_registry)
original_sketch = self.sketches.sketch_registry['brightness']
loaded_sketch = new_sketches.sketch_registry['brightness']
self.assertEqual(original_sketch.n, loaded_sketch.n)
[docs] def test_compute_thresholds(self):
"""Test the computation of thresholds."""
self.sketches.register_metric('brightness')
brightness = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], dtype=tf.float32)
self.sketches.update_kll_sketch(self.sketches.sketch_registry['brightness'], brightness)
thresholds = self.sketches.compute_thresholds(lower_percentile=0.2, upper_percentile=0.8)
self.assertIn('brightness', thresholds)
self.assertEqual(len(thresholds['brightness']), 2)
[docs] def test_custom_metric_logging(self):
"""Test logging a custom metric like embedding vectors."""
self.sketches.register_metric('embeddings')
embeddings = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32)
self.sketches.update_kll_sketch(self.sketches.sketch_registry['embeddings'], embeddings)
self.assertGreater(self.sketches.sketch_registry['embeddings'].n, 0)
if __name__ == '__main__':
unittest.main()