diff --git a/tests/test_cache_types.py b/tests/test_cache_types.py new file mode 100644 index 00000000..43c0aeb2 --- /dev/null +++ b/tests/test_cache_types.py @@ -0,0 +1,184 @@ +""" +Tests for gas.cache.types module. + +Covers PromptEntry, MediaEntry, VerificationResult dataclasses +and their serialization behavior. +""" + +import time +import unittest + +from gas.cache.types import PromptEntry, MediaEntry, VerificationResult, Media +from gas.types import Modality, MediaType, SourceType + + +class TestPromptEntry(unittest.TestCase): + """Tests for PromptEntry dataclass.""" + + def test_to_dict(self): + entry = PromptEntry( + id="p1", + content="a cat", + content_type="prompt", + created_at=1000.0, + ) + d = entry.to_dict() + self.assertEqual(d["id"], "p1") + self.assertEqual(d["content"], "a cat") + self.assertEqual(d["used_count"], 0) + self.assertIsNone(d["last_used"]) + self.assertIsNone(d["source_media_id"]) + + def test_optional_fields(self): + entry = PromptEntry( + id="p2", + content="test", + content_type="search_query", + created_at=2000.0, + used_count=5, + last_used=2500.0, + source_media_id="m1", + modality="image", + ) + self.assertEqual(entry.used_count, 5) + self.assertEqual(entry.modality, "image") + + +class TestMediaEntry(unittest.TestCase): + """Tests for MediaEntry dataclass.""" + + def test_created_at_auto_set(self): + before = time.time() + entry = MediaEntry( + id="m1", + prompt_id="p1", + file_path="/tmp/test.png", + modality=Modality.IMAGE, + media_type=MediaType.SYNTHETIC, + ) + after = time.time() + self.assertGreaterEqual(entry.created_at, before) + self.assertLessEqual(entry.created_at, after) + + def test_explicit_created_at_preserved(self): + entry = MediaEntry( + id="m1", + prompt_id="p1", + file_path="/tmp/test.png", + modality=Modality.IMAGE, + media_type=MediaType.REAL, + created_at=12345.0, + ) + self.assertEqual(entry.created_at, 12345.0) + + def test_to_dict_converts_enums(self): + entry = MediaEntry( + id="m1", + prompt_id="p1", + file_path="/tmp/test.png", + modality=Modality.IMAGE, + media_type=MediaType.SYNTHETIC, + source_type=SourceType.GENERATED, + ) + d = entry.to_dict() + self.assertEqual(d["modality"], "image") + self.assertEqual(d["media_type"], "synthetic") + self.assertEqual(d["source_type"], "generated") + + def test_default_source_type(self): + entry = MediaEntry( + id="m1", + prompt_id="p1", + file_path="/tmp/test.png", + modality=Modality.IMAGE, + media_type=MediaType.REAL, + ) + self.assertEqual(entry.source_type, SourceType.GENERATED) + + def test_miner_fields(self): + entry = MediaEntry( + id="m1", + prompt_id="p1", + file_path="/tmp/test.png", + modality=Modality.IMAGE, + media_type=MediaType.SYNTHETIC, + source_type=SourceType.MINER, + uid=42, + hotkey="hk_abc", + verified=True, + ) + self.assertEqual(entry.uid, 42) + self.assertEqual(entry.hotkey, "hk_abc") + self.assertTrue(entry.verified) + + def test_c2pa_fields(self): + entry = MediaEntry( + id="m1", + prompt_id="p1", + file_path="/tmp/test.png", + modality=Modality.VIDEO, + media_type=MediaType.REAL, + c2pa_verified=True, + c2pa_issuer="Adobe", + ) + self.assertTrue(entry.c2pa_verified) + self.assertEqual(entry.c2pa_issuer, "Adobe") + + +class TestVerificationResult(unittest.TestCase): + """Tests for VerificationResult.""" + + def test_default_passed_false(self): + entry = MediaEntry( + id="m1", prompt_id="p1", file_path="/tmp/t.png", + modality=Modality.IMAGE, media_type=MediaType.SYNTHETIC, + ) + result = VerificationResult(media_entry=entry) + self.assertFalse(result.passed) + self.assertIsNone(result.verification_score) + self.assertIsNone(result.original_prompt) + + def test_with_scores(self): + entry = MediaEntry( + id="m1", prompt_id="p1", file_path="/tmp/t.png", + modality=Modality.IMAGE, media_type=MediaType.SYNTHETIC, + ) + result = VerificationResult( + media_entry=entry, + original_prompt="a dog", + generated_caption="a dog running", + verification_score={"clip": 0.85}, + passed=True, + ) + self.assertTrue(result.passed) + self.assertEqual(result.verification_score["clip"], 0.85) + + +class TestMedia(unittest.TestCase): + """Tests for Media dataclass.""" + + def test_basic_construction(self): + m = Media( + modality=Modality.IMAGE, + media_type=MediaType.SYNTHETIC, + media_content="placeholder", + format="JPEG", + ) + self.assertEqual(m.modality, Modality.IMAGE) + self.assertIsNone(m.prompt_id) + self.assertIsNone(m.model_name) + + def test_with_generation_args(self): + m = Media( + modality=Modality.VIDEO, + media_type=MediaType.SYNTHETIC, + media_content=None, + format="MP4", + model_name="test-model", + generation_args={"steps": 50}, + ) + self.assertEqual(m.generation_args["steps"], 50) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_duplicate_detection.py b/tests/test_duplicate_detection.py new file mode 100644 index 00000000..97516915 --- /dev/null +++ b/tests/test_duplicate_detection.py @@ -0,0 +1,188 @@ +""" +Comprehensive tests for gas.verification.duplicate_detection module. + +Tests cover: +- Hash extraction utilities (extract_phash, extract_crop_segments) +- Hamming distance calculation +- Duplicate detection logic (pHash and crop-resistant) +- find_duplicates with various scenarios +- Edge cases: empty strings, missing segments, malformed hashes +""" + +import unittest +from unittest.mock import patch, MagicMock + +from gas.verification.duplicate_detection import ( + extract_phash, + extract_crop_segments, + hamming_distance, + count_crop_segment_matches, + is_duplicate, + find_duplicates, + compute_media_hash, + DEFAULT_HAMMING_THRESHOLD, + DEFAULT_CROP_RESISTANT_MATCH_THRESHOLD, +) + + +class TestExtractPhash(unittest.TestCase): + """Tests for extract_phash utility.""" + + def test_plain_hash(self): + self.assertEqual(extract_phash("abcdef1234567890"), "abcdef1234567890") + + def test_hash_with_crop_segments(self): + self.assertEqual(extract_phash("abcdef|seg1;seg2"), "abcdef") + + def test_video_hash_with_frame_count(self): + self.assertEqual(extract_phash("abcdef_4"), "abcdef") + + def test_video_hash_with_crop_and_frames(self): + self.assertEqual(extract_phash("abcdef|seg1;seg2_4"), "abcdef") + + def test_empty_string(self): + self.assertEqual(extract_phash(""), "") + + def test_hash_with_multiple_pipes(self): + # Edge case: should take first part + result = extract_phash("abc|def|ghi") + self.assertEqual(result, "abc") + + +class TestExtractCropSegments(unittest.TestCase): + """Tests for extract_crop_segments utility.""" + + def test_no_segments(self): + self.assertEqual(extract_crop_segments("abcdef"), []) + + def test_single_segment(self): + self.assertEqual(extract_crop_segments("abcdef|seg1"), ["seg1"]) + + def test_multiple_segments(self): + result = extract_crop_segments("abcdef|seg1;seg2;seg3") + self.assertEqual(result, ["seg1", "seg2", "seg3"]) + + def test_empty_crop_part(self): + self.assertEqual(extract_crop_segments("abcdef|"), [""]) + + def test_no_pipe_returns_empty(self): + self.assertEqual(extract_crop_segments("just_a_hash"), []) + + +class TestHammingDistance(unittest.TestCase): + """Tests for hamming_distance function.""" + + def test_identical_hashes(self): + h = "a" * 64 # 256-bit hash as hex + dist = hamming_distance(h, h) + self.assertEqual(dist, 0) + + def test_different_hashes(self): + h1 = "0" * 64 + h2 = "f" * 64 + dist = hamming_distance(h1, h2) + self.assertGreater(dist, 0) + + def test_extracts_phash_before_comparing(self): + """Should strip crop segments before comparing.""" + h = "0" * 64 + h_with_crop = f"{'0' * 64}|seg1;seg2" + dist = hamming_distance(h, h_with_crop) + self.assertEqual(dist, 0) + + def test_strips_frame_count(self): + h = "0" * 64 + h_video = f"{'0' * 64}_4" + dist = hamming_distance(h, h_video) + self.assertEqual(dist, 0) + + +class TestIsDuplicate(unittest.TestCase): + """Tests for is_duplicate function.""" + + def test_identical_hashes_are_duplicate(self): + h = "0" * 64 + self.assertTrue(is_duplicate(h, h)) + + def test_very_different_hashes_not_duplicate(self): + h1 = "0" * 64 + h2 = "f" * 64 + self.assertFalse(is_duplicate(h1, h2, threshold=5)) + + def test_custom_threshold(self): + h = "0" * 64 + self.assertTrue(is_duplicate(h, h, threshold=0)) + + +class TestFindDuplicates(unittest.TestCase): + """Tests for find_duplicates function.""" + + def test_empty_existing_hashes(self): + result = find_duplicates("0" * 64, []) + self.assertEqual(result, []) + + def test_finds_exact_match(self): + h = "0" * 64 + existing = [("media1", h), ("media2", "f" * 64)] + result = find_duplicates(h, existing) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "media1") + self.assertEqual(result[0][1], 0) + + def test_skips_none_hashes(self): + h = "0" * 64 + existing = [("media1", None), ("media2", "")] + result = find_duplicates(h, existing) + self.assertEqual(result, []) + + def test_sorted_by_distance(self): + h = "0" * 64 + existing = [("media1", h), ("media2", h)] + result = find_duplicates(h, existing) + # Both should match with distance 0 + self.assertEqual(len(result), 2) + for _, dist in result: + self.assertEqual(dist, 0) + + +class TestComputeMediaHash(unittest.TestCase): + """Tests for compute_media_hash function.""" + + def test_unsupported_modality_returns_none(self): + result = compute_media_hash(b"data", modality="audio") + self.assertIsNone(result) + + @patch("gas.verification.duplicate_detection.IMAGEHASH_AVAILABLE", False) + def test_imagehash_unavailable_returns_none(self): + from gas.verification.duplicate_detection import compute_image_hash + result = compute_image_hash(b"data") + self.assertIsNone(result) + + +class TestCountCropSegmentMatches(unittest.TestCase): + """Tests for count_crop_segment_matches.""" + + def test_no_segments_returns_zero(self): + result = count_crop_segment_matches("abc", "def") + self.assertEqual(result, 0) + + def test_one_hash_without_segments(self): + result = count_crop_segment_matches("abc|seg1", "def") + self.assertEqual(result, 0) + + +class TestDefaultConstants(unittest.TestCase): + """Verify default constants are reasonable.""" + + def test_default_hamming_threshold(self): + self.assertIsInstance(DEFAULT_HAMMING_THRESHOLD, int) + self.assertGreater(DEFAULT_HAMMING_THRESHOLD, 0) + self.assertLessEqual(DEFAULT_HAMMING_THRESHOLD, 20) + + def test_default_crop_match_threshold(self): + self.assertIsInstance(DEFAULT_CROP_RESISTANT_MATCH_THRESHOLD, int) + self.assertGreater(DEFAULT_CROP_RESISTANT_MATCH_THRESHOLD, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rewards.py b/tests/test_rewards.py new file mode 100644 index 00000000..824fb8bd --- /dev/null +++ b/tests/test_rewards.py @@ -0,0 +1,428 @@ +""" +Comprehensive tests for gas.evaluation.rewards module. + +Tests cover: +- Discriminator reward calculation with various inputs +- Generator base reward computation from verification stats +- Generator reward multipliers with fool rate and sample size scaling +- Edge cases: empty inputs, invalid data types, missing fields +- Liveness filtering for generator rewards +""" + +import math +import time +import unittest +from unittest.mock import MagicMock, patch + + +from gas.evaluation.rewards import ( + get_discriminator_rewards, + get_generator_base_rewards, + get_generator_reward_multipliers, +) + + +def _make_metagraph(hotkeys): + """Create a mock metagraph with the given hotkey list.""" + mg = MagicMock() + mg.hotkeys = list(hotkeys) + return mg + + +class TestGetDiscriminatorRewards(unittest.TestCase): + """Tests for get_discriminator_rewards.""" + + def test_empty_runs_returns_empty(self): + mg = _make_metagraph(["hk0", "hk1"]) + result = get_discriminator_rewards([], mg) + self.assertEqual(result, {}) + + def test_none_runs_returns_empty(self): + mg = _make_metagraph(["hk0"]) + result = get_discriminator_rewards(None, mg) + self.assertEqual(result, {}) + + def test_single_image_result(self): + mg = _make_metagraph(["hk0", "hk1"]) + runs = [ + { + "discriminator_address": "hk1", + "modality": "image", + "binary_mcc": 0.8, + "multiclass_mcc": 0.6, + } + ] + result = get_discriminator_rewards(runs, mg) + # image reward = 0.5 * (0.5*0.8 + 0.5*0.6) = 0.5 * 0.7 = 0.35 + self.assertIn(1, result) + self.assertAlmostEqual(result[1], 0.35, places=5) + + def test_image_and_video_combined(self): + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "hk0", + "modality": "image", + "binary_mcc": 1.0, + "multiclass_mcc": 1.0, + }, + { + "discriminator_address": "hk0", + "modality": "video", + "binary_mcc": 0.5, + "multiclass_mcc": 0.5, + }, + ] + result = get_discriminator_rewards(runs, mg) + # image_reward = 0.5*1.0 + 0.5*1.0 = 1.0 + # video_reward = 0.5*0.5 + 0.5*0.5 = 0.5 + # final = 0.5*1.0 + 0.5*0.5 = 0.75 + self.assertAlmostEqual(result[0], 0.75, places=5) + + def test_negative_mcc_clamped_to_zero(self): + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "hk0", + "modality": "image", + "binary_mcc": -0.5, + "multiclass_mcc": -1.0, + } + ] + result = get_discriminator_rewards(runs, mg) + # max(0, -0.5) = 0, max(0, -1.0) = 0 → reward = 0 + self.assertAlmostEqual(result[0], 0.0, places=5) + + def test_none_mcc_values_default_to_zero(self): + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "hk0", + "modality": "image", + "binary_mcc": None, + "multiclass_mcc": None, + } + ] + result = get_discriminator_rewards(runs, mg) + self.assertAlmostEqual(result[0], 0.0, places=5) + + def test_unknown_hotkey_ignored(self): + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "unknown_key", + "modality": "image", + "binary_mcc": 0.9, + "multiclass_mcc": 0.9, + } + ] + result = get_discriminator_rewards(runs, mg) + self.assertEqual(result, {}) + + def test_missing_modality_skipped(self): + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "hk0", + "binary_mcc": 0.9, + "multiclass_mcc": 0.9, + } + ] + result = get_discriminator_rewards(runs, mg) + self.assertEqual(result, {}) + + def test_invalid_result_type_skipped(self): + mg = _make_metagraph(["hk0"]) + runs = ["not_a_dict", 42, None] + result = get_discriminator_rewards(runs, mg) + self.assertEqual(result, {}) + + def test_string_mcc_values_converted(self): + """MCC values that come as strings should be converted to float.""" + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "hk0", + "modality": "image", + "binary_mcc": "0.7", + "multiclass_mcc": "0.3", + } + ] + result = get_discriminator_rewards(runs, mg) + expected = 0.5 * (0.5 * 0.7 + 0.5 * 0.3) + self.assertAlmostEqual(result[0], expected, places=5) + + def test_custom_weights(self): + mg = _make_metagraph(["hk0"]) + runs = [ + { + "discriminator_address": "hk0", + "modality": "image", + "binary_mcc": 1.0, + "multiclass_mcc": 0.0, + }, + { + "discriminator_address": "hk0", + "modality": "video", + "binary_mcc": 0.0, + "multiclass_mcc": 1.0, + }, + ] + result = get_discriminator_rewards( + runs, + mg, + image_score_weight=0.7, + video_score_weight=0.3, + binary_score_weight=1.0, + multiclass_score_weight=0.0, + ) + # image: 1.0*1.0 + 0.0*0.0 = 1.0 + # video: 1.0*0.0 + 0.0*1.0 = 0.0 + # final = 0.7*1.0 + 0.3*0.0 = 0.7 + self.assertAlmostEqual(result[0], 0.7, places=5) + + def test_multiple_miners(self): + mg = _make_metagraph(["hk0", "hk1", "hk2"]) + runs = [ + {"discriminator_address": "hk0", "modality": "image", "binary_mcc": 0.5, "multiclass_mcc": 0.5}, + {"discriminator_address": "hk2", "modality": "image", "binary_mcc": 1.0, "multiclass_mcc": 1.0}, + ] + result = get_discriminator_rewards(runs, mg) + self.assertIn(0, result) + self.assertIn(2, result) + self.assertNotIn(1, result) + + +class TestGetGeneratorBaseRewards(unittest.TestCase): + """Tests for get_generator_base_rewards.""" + + def test_empty_stats_returns_empty(self): + rewards, media_ids = get_generator_base_rewards({}) + self.assertEqual(rewards, {}) + self.assertEqual(media_ids, []) + + def test_none_stats_returns_empty(self): + rewards, media_ids = get_generator_base_rewards(None) + self.assertEqual(rewards, {}) + self.assertEqual(media_ids, []) + + def test_single_miner_perfect_pass_rate(self): + stats = { + "hotkey_a": { + "uid": 5, + "total_verified": 8, + "total_failed": 0, + "total_evaluated": 8, + "pass_rate": 1.0, + "media_ids": ["m1", "m2"], + } + } + rewards, media_ids = get_generator_base_rewards(stats) + # reward = 1.0 * min(8, 10) = 8.0 + self.assertEqual(rewards[5], 8.0) + self.assertEqual(media_ids, ["m1", "m2"]) + + def test_volume_bonus_capped_at_ten(self): + stats = { + "hk": { + "uid": 0, + "total_verified": 50, + "total_failed": 0, + "total_evaluated": 50, + "pass_rate": 1.0, + "media_ids": ["m1"], + } + } + rewards, _ = get_generator_base_rewards(stats) + # Capped: 1.0 * min(50, 10) = 10.0 + self.assertEqual(rewards[0], 10.0) + + def test_zero_pass_rate(self): + stats = { + "hk": { + "uid": 1, + "total_verified": 0, + "total_failed": 5, + "total_evaluated": 5, + "pass_rate": 0.0, + "media_ids": ["m1"], + } + } + rewards, _ = get_generator_base_rewards(stats) + self.assertEqual(rewards[1], 0.0) + + def test_multiple_miners_media_ids_aggregated(self): + stats = { + "hk0": { + "uid": 0, + "total_verified": 3, + "total_failed": 1, + "total_evaluated": 4, + "pass_rate": 0.75, + "media_ids": ["a", "b"], + }, + "hk1": { + "uid": 1, + "total_verified": 5, + "total_failed": 0, + "total_evaluated": 5, + "pass_rate": 1.0, + "media_ids": ["c"], + }, + } + rewards, media_ids = get_generator_base_rewards(stats) + self.assertEqual(len(rewards), 2) + self.assertIn("a", media_ids) + self.assertIn("c", media_ids) + + +class TestGetGeneratorRewardMultipliers(unittest.TestCase): + """Tests for get_generator_reward_multipliers.""" + + def test_empty_results_returns_empty(self): + mg = _make_metagraph(["hk0"]) + result = get_generator_reward_multipliers([], mg) + self.assertEqual(result, {}) + + def test_none_results_returns_empty(self): + mg = _make_metagraph(["hk0"]) + result = get_generator_reward_multipliers(None, mg) + self.assertEqual(result, {}) + + def test_basic_fool_rate(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 15, "not_fooled_count": 5}, + ] + rewards = get_generator_reward_multipliers(results, mg) + # fool_rate = 15/20 = 0.75, total=20 >= ref=20 → multiplier = 1.0 + # reward = 0.75 * 1.0 = 0.75 + self.assertIn(0, rewards) + self.assertAlmostEqual(rewards[0], 0.75, places=3) + + def test_small_sample_penalized(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 5, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers(results, mg) + # total=5, fool_rate=1.0, multiplier = max(0.5, 5/20) = 0.5 → but 5/20=0.25 < 0.5 + # reward = 1.0 * 0.5 = 0.5 + self.assertAlmostEqual(rewards[0], 0.5, places=3) + + def test_large_sample_bonus(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 100, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers(results, mg) + # total=100, fool_rate=1.0 + # multiplier = min(2.0, 1.0 + log(100/20)) = min(2.0, 1.0+1.609) = min(2.0, 2.609) = 2.0 + # reward = 1.0 * 2.0 = 2.0 + self.assertAlmostEqual(rewards[0], 2.0, places=3) + + def test_reward_clamped_to_two(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 10000, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers(results, mg) + self.assertLessEqual(rewards[0], 2.0) + + def test_zero_total_count(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 0, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers(results, mg) + self.assertNotIn(0, rewards) + + def test_aggregation_across_multiple_results(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 10, "not_fooled_count": 5}, + {"ss58_address": "hk0", "fooled_count": 5, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers(results, mg) + # total_fooled=15, total_not_fooled=5, total=20 + # fool_rate=0.75, multiplier=1.0 → reward=0.75 + self.assertAlmostEqual(rewards[0], 0.75, places=3) + + def test_liveness_filter_excludes_inactive(self): + mg = _make_metagraph(["hk0", "hk1"]) + current = time.time() + results = [ + {"ss58_address": "hk0", "fooled_count": 20, "not_fooled_count": 0}, + {"ss58_address": "hk1", "fooled_count": 20, "not_fooled_count": 0}, + ] + liveness = { + "hk0": current - 3600, # active (1h ago) + "hk1": current - 100 * 3600, # inactive (100h ago) + } + rewards = get_generator_reward_multipliers( + results, mg, generator_liveness=liveness, max_inactive_hours=24 + ) + self.assertIn(0, rewards) + self.assertNotIn(1, rewards) + + def test_liveness_filter_none_means_no_filtering(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": 20, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers( + results, mg, generator_liveness=None + ) + self.assertIn(0, rewards) + + def test_invalid_counts_default_to_zero(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "hk0", "fooled_count": "abc", "not_fooled_count": None}, + ] + rewards = get_generator_reward_multipliers(results, mg) + # Both default to 0 → total=0 → no reward + self.assertNotIn(0, rewards) + + def test_unknown_address_ignored(self): + mg = _make_metagraph(["hk0"]) + results = [ + {"ss58_address": "unknown", "fooled_count": 20, "not_fooled_count": 0}, + ] + rewards = get_generator_reward_multipliers(results, mg) + self.assertEqual(rewards, {}) + + +class TestRewardEdgeCases(unittest.TestCase): + """Edge cases and integration-style tests.""" + + def test_discriminator_last_modality_wins_on_overwrite(self): + """When same uid has multiple results for same modality, last one wins.""" + mg = _make_metagraph(["hk0"]) + runs = [ + {"discriminator_address": "hk0", "modality": "image", "binary_mcc": 0.1, "multiclass_mcc": 0.1}, + {"discriminator_address": "hk0", "modality": "image", "binary_mcc": 0.9, "multiclass_mcc": 0.9}, + ] + result = get_discriminator_rewards(runs, mg) + # Last image result overwrites: 0.5*0.9 + 0.5*0.9 = 0.9 + # final = 0.5 * 0.9 = 0.45 + self.assertAlmostEqual(result[0], 0.45, places=5) + + def test_generator_base_rewards_uid_as_string(self): + """uid field as string should be converted to int.""" + stats = { + "hk": { + "uid": "3", + "total_verified": 5, + "total_failed": 0, + "total_evaluated": 5, + "pass_rate": 0.8, + "media_ids": [], + } + } + rewards, _ = get_generator_base_rewards(stats) + self.assertIn(3, rewards) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 00000000..1d881dbb --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,158 @@ +""" +Tests for gas.types module. + +Covers enum behavior, MediaType int_value mapping, DatasetConfig validation, +ModelConfig construction, and DiscriminatorModelId serialization roundtrip. +""" + +import json +import unittest + +from gas.types import ( + Modality, + MediaType, + SourceType, + DatasetConfig, + ModelConfig, + ModelTask, + DiscriminatorModelId, + DiscriminatorModelMetadata, + SOURCE_TYPE_TO_DB_NAME_FIELD, + SOURCE_TYPE_TO_NAME, +) + + +class TestModality(unittest.TestCase): + def test_values(self): + self.assertEqual(Modality.IMAGE.value, "image") + self.assertEqual(Modality.VIDEO.value, "video") + + def test_from_string(self): + self.assertEqual(Modality("image"), Modality.IMAGE) + self.assertEqual(Modality("video"), Modality.VIDEO) + + def test_invalid_raises(self): + with self.assertRaises(ValueError): + Modality("audio") + + +class TestMediaType(unittest.TestCase): + def test_int_values(self): + self.assertEqual(MediaType.REAL.int_value, 0) + self.assertEqual(MediaType.SYNTHETIC.int_value, 1) + self.assertEqual(MediaType.SEMISYNTHETIC.int_value, 2) + + def test_from_string(self): + self.assertEqual(MediaType("real"), MediaType.REAL) + self.assertEqual(MediaType("synthetic"), MediaType.SYNTHETIC) + + +class TestSourceType(unittest.TestCase): + def test_all_source_types_in_mappings(self): + for st in SourceType: + self.assertIn(st, SOURCE_TYPE_TO_NAME) + self.assertIn(st, SOURCE_TYPE_TO_DB_NAME_FIELD) + + +class TestDatasetConfig(unittest.TestCase): + def test_default_source_format_image(self): + cfg = DatasetConfig(path="test/path", modality=Modality.IMAGE, media_type=MediaType.REAL) + self.assertEqual(cfg.source_format, "parquet") + + def test_default_source_format_video(self): + cfg = DatasetConfig(path="test/path", modality=Modality.VIDEO, media_type=MediaType.REAL) + self.assertEqual(cfg.source_format, "zip") + + def test_string_modality_converted(self): + cfg = DatasetConfig(path="test/path", modality="image", media_type="real") + self.assertIsInstance(cfg.modality, Modality) + self.assertIsInstance(cfg.media_type, MediaType) + + def test_custom_source_format_preserved(self): + cfg = DatasetConfig( + path="test/path", + modality=Modality.IMAGE, + media_type=MediaType.REAL, + source_format="custom", + ) + self.assertEqual(cfg.source_format, "custom") + + +class TestModelConfig(unittest.TestCase): + def test_default_media_type_t2i(self): + cfg = ModelConfig(path="model", task=ModelTask.TEXT_TO_IMAGE, pipeline_cls="cls") + self.assertEqual(cfg.media_type, MediaType.SYNTHETIC) + + def test_default_media_type_i2i(self): + cfg = ModelConfig(path="model", task=ModelTask.IMAGE_TO_IMAGE, pipeline_cls="cls") + self.assertEqual(cfg.media_type, MediaType.SEMISYNTHETIC) + + def test_to_dict_keys(self): + cfg = ModelConfig(path="model", task=ModelTask.TEXT_TO_IMAGE, pipeline_cls="cls") + d = cfg.to_dict() + expected_keys = { + "pipeline_cls", "from_pretrained_args", "generation_args", + "use_autocast", "enable_model_cpu_offload", + "enable_sequential_cpu_offload", "vae_enable_slicing", + "vae_enable_tiling", "scheduler", "save_args", + "pipeline_stages", "clear_memory_on_stage_end", + } + self.assertEqual(set(d.keys()), expected_keys) + + def test_custom_generation_args(self): + cfg = ModelConfig( + path="model", + task=ModelTask.TEXT_TO_VIDEO, + pipeline_cls="cls", + generation_args={"steps": 50}, + ) + self.assertEqual(cfg.generation_args["steps"], 50) + + +class TestDiscriminatorModelId(unittest.TestCase): + def test_hash_truncated_to_16(self): + mid = DiscriminatorModelId(key="mykey", hash="a" * 32) + self.assertEqual(len(mid.hash), 16) + + def test_short_hash_preserved(self): + mid = DiscriminatorModelId(key="mykey", hash="abc") + self.assertEqual(mid.hash, "abc") + + def test_compressed_str_roundtrip(self): + original = DiscriminatorModelId(key="test/model", hash="abcdef1234567890") + compressed = original.to_compressed_str() + restored = DiscriminatorModelId.from_compressed_str(compressed) + self.assertEqual(original, restored) + + def test_equality(self): + a = DiscriminatorModelId(key="k", hash="h") + b = DiscriminatorModelId(key="k", hash="h") + c = DiscriminatorModelId(key="k", hash="x") + self.assertEqual(a, b) + self.assertNotEqual(a, c) + + def test_not_equal_to_other_types(self): + mid = DiscriminatorModelId(key="k", hash="h") + self.assertNotEqual(mid, "not_a_model_id") + + +class TestDiscriminatorModelMetadata(unittest.TestCase): + def test_dict_roundtrip(self): + mid = DiscriminatorModelId(key="k", hash="h") + meta = DiscriminatorModelMetadata(id=mid, block=100) + d = meta.to_dict() + restored = DiscriminatorModelMetadata.from_dict(d) + self.assertEqual(restored.id, mid) + self.assertEqual(restored.block, 100) + + +class TestModelTask(unittest.TestCase): + def test_task_values(self): + self.assertEqual(ModelTask.TEXT_TO_IMAGE.value, "t2i") + self.assertEqual(ModelTask.TEXT_TO_VIDEO.value, "t2v") + self.assertEqual(ModelTask.IMAGE_TO_IMAGE.value, "i2i") + self.assertEqual(ModelTask.IMAGE_TO_VIDEO.value, "i2v") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..615b8bdc --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,140 @@ +""" +Tests for gas.utils.utils module. + +Covers utility functions: get_file_modality, get_metadata, ExitContext, +run_in_thread, and the fail_with_none decorator. +""" + +import json +import os +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +from gas.utils.utils import ( + get_file_modality, + get_metadata, + ExitContext, + fail_with_none, + run_in_thread, +) + + +class TestGetFileModality(unittest.TestCase): + """Tests for get_file_modality.""" + + def test_image_extensions(self): + for ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]: + self.assertEqual(get_file_modality(f"test{ext}"), "image", f"Failed for {ext}") + + def test_video_extensions(self): + for ext in [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]: + self.assertEqual(get_file_modality(f"test{ext}"), "video", f"Failed for {ext}") + + def test_unknown_extension(self): + self.assertEqual(get_file_modality("test.txt"), "file") + self.assertEqual(get_file_modality("test.pdf"), "file") + + def test_case_insensitive(self): + self.assertEqual(get_file_modality("test.JPG"), "image") + self.assertEqual(get_file_modality("test.MP4"), "video") + + def test_path_with_directories(self): + self.assertEqual(get_file_modality("/path/to/image.png"), "image") + + def test_no_extension(self): + self.assertEqual(get_file_modality("noext"), "file") + + +class TestGetMetadata(unittest.TestCase): + """Tests for get_metadata.""" + + def test_returns_empty_when_no_json(self): + result = get_metadata("/nonexistent/path/file.png") + self.assertEqual(result, {}) + + def test_reads_existing_json(self): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img: + img_path = img.name + json_path = os.path.splitext(img_path)[0] + ".json" + try: + with open(json_path, "w") as f: + json.dump({"key": "value"}, f) + result = get_metadata(img_path) + self.assertEqual(result, {"key": "value"}) + finally: + os.unlink(img_path) + os.unlink(json_path) + + def test_invalid_json_returns_empty(self): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img: + img_path = img.name + json_path = os.path.splitext(img_path)[0] + ".json" + try: + with open(json_path, "w") as f: + f.write("not valid json{{{") + result = get_metadata(img_path) + self.assertEqual(result, {}) + finally: + os.unlink(img_path) + os.unlink(json_path) + + +class TestExitContext(unittest.TestCase): + """Tests for ExitContext.""" + + def test_initial_state(self): + ctx = ExitContext() + self.assertFalse(ctx.isExiting) + self.assertFalse(bool(ctx)) + + def test_start_exit(self): + ctx = ExitContext() + ctx.startExit() + self.assertTrue(ctx.isExiting) + self.assertTrue(bool(ctx)) + + def test_double_exit_raises_system_exit(self): + ctx = ExitContext() + ctx.startExit() + with self.assertRaises(SystemExit): + ctx.startExit() + + +class TestFailWithNone(unittest.TestCase): + """Tests for fail_with_none decorator.""" + + def test_successful_function(self): + @fail_with_none("error msg") + def good_func(): + return 42 + self.assertEqual(good_func(), 42) + + def test_failing_function_returns_none(self): + @fail_with_none("error msg") + def bad_func(): + raise ValueError("boom") + self.assertIsNone(bad_func()) + + def test_preserves_arguments(self): + @fail_with_none() + def add(a, b): + return a + b + self.assertEqual(add(1, 2), 3) + + +class TestRunInThread(unittest.TestCase): + """Tests for run_in_thread.""" + + def test_returns_result(self): + result = run_in_thread(lambda: 42, timeout=5) + self.assertEqual(result, 42) + + def test_timeout_raises(self): + import time + with self.assertRaises(TimeoutError): + run_in_thread(lambda: time.sleep(10), timeout=1) + + +if __name__ == "__main__": + unittest.main()