Source code for tests.test_sketches

import os
import shutil
import unittest

import numpy as np
import tensorflow as tf
import torch

from lensai_profiler import Sketches  # Replace with the actual import path

[docs]class TestSketches(unittest.TestCase):
[docs] def setUp(self): """Setup common test data and instances.""" self.envs = ["pt", "tf"] # Environments: 'pt' for PyTorch, 'tf' for TensorFlow self.num_channels = int(os.getenv("CHANNELS", 3)) # Number of channels self.save_path = "./sketches_test" os.makedirs(self.save_path, exist_ok=True)
[docs] def tearDown(self): """Clean up any resources after each test.""" if os.path.exists(self.save_path): for root, dirs, files in os.walk(self.save_path, topdown=False): for name in files: os.remove(os.path.join(root, name)) for name in dirs: os.rmdir(os.path.join(root, name)) os.rmdir(self.save_path)
[docs] def test_environments(self): """Test all functionalities in both environments.""" for env in self.envs: with self.subTest(env=env): os.environ["ENV"] = env sketches = Sketches() # Initialize the Sketches class # Test Initialization sketches.register_metric('test_metric', num_channels=self.num_channels) self.assertIn('test_metric', sketches.sketch_registry) self.assertEqual(len(sketches.sketch_registry['test_metric']), self.num_channels) del sketches.sketch_registry['test_metric'] # Test Update KLL Sketch sketches.register_metric('brightness') if env == "tf": values = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32) elif env == "pt": values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) sketches.update_kll_sketch(sketches.sketch_registry['brightness'], values) self.assertGreater(sketches.sketch_registry['brightness'].n, 0, "Sketch should be updated.") # Test Update Sketches sketches.register_metric('sharpness') sketches.register_metric('channel_mean', num_channels=self.num_channels) sketches.register_metric('snr') sketches.register_metric('channel_pixels', num_channels=self.num_channels) if env == "tf": brightness = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32) sharpness = tf.constant([4.0, 5.0, 6.0], dtype=tf.float32) channel_mean = tf.constant([[1.0, 2.0, 3.0]], dtype=tf.float32) snr = tf.constant([7.0, 8.0, 9.0], dtype=tf.float32) channel_pixels = tf.constant([[10.0, 11.0, 12.0]], dtype=tf.float32) elif env == "pt": brightness = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) sharpness = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32) channel_mean = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) snr = torch.tensor([7.0, 8.0, 9.0], dtype=torch.float32) channel_pixels = torch.tensor([[10.0, 11.0, 12.0]], dtype=torch.float32) sketches.update_sketches(brightness=brightness, sharpness=sharpness, channel_mean=channel_mean, snr=snr, channel_pixels=channel_pixels) self.assertGreater(sketches.sketch_registry['brightness'].n, 0) 
self.assertGreater(sketches.sketch_registry['sharpness'].n, 0) self.assertGreater(sketches.sketch_registry['snr'].n, 0) self.assertGreater(sketches.sketch_registry['channel_mean'][0].n, 0) self.assertGreater(sketches.sketch_registry['channel_pixels'][0].n, 0) # Test Save and Load Sketches sketches.save_sketches(self.save_path) new_sketches = Sketches() new_sketches.load_sketches(self.save_path) self.assertIn('brightness', new_sketches.sketch_registry) original_sketch = sketches.sketch_registry['brightness'] loaded_sketch = new_sketches.sketch_registry['brightness'] self.assertEqual(original_sketch.n, loaded_sketch.n) # Test Compute Thresholds thresholds = sketches.compute_thresholds(lower_percentile=0.2, upper_percentile=0.8) self.assertIn('brightness', thresholds) self.assertEqual(len(thresholds['brightness']), 2) # Test Custom Metric Logging sketches.register_metric('embeddings') embeddings = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32) if env == "pt" else tf.constant([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=tf.float32) sketches.update_kll_sketch(sketches.sketch_registry['embeddings'], embeddings) self.assertGreater(sketches.sketch_registry['embeddings'].n, 0)
if __name__ == "__main__":
    # Run the test suite only when this file is executed as a script.
    unittest.main()