diff --git a/src/python.rs b/src/python.rs index ff50dc4..2eaab65 100644 --- a/src/python.rs +++ b/src/python.rs @@ -576,6 +576,10 @@ pub struct TaxonomyIterator { #[pymethods] impl TaxonomyIterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + fn __next__(mut slf: PyRefMut, py: Python<'_>) -> PyResult> { let traverse_preorder = true; loop { diff --git a/test_python.py b/test_python.py index c032d76..7f0c1e2 100644 --- a/test_python.py +++ b/test_python.py @@ -150,8 +150,16 @@ def test_json_internal_index(json_tax: Taxonomy): ] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +def test_taxonomy_iterator_is_iterable(newick_tax): + it = iter(newick_tax) + assert iter(it) is it + + def test_json_find_all_by_name(json_tax: Taxonomy): - assert sorted([n.id for n in json_tax.find_all_by_name("species 1.1")]) == ["10", "12"] + assert sorted([n.id for n in json_tax.find_all_by_name("species 1.1")]) == [ + "10", + "12", + ] def test_json_edit_node_parent_updates_children(json_tax: Taxonomy): @@ -475,7 +483,8 @@ def test_ncbi_edit_node_parent(): def test_ncbi_repr(): tax = Taxonomy.from_ncbi("tests/data/") assert ( - tax["562"].__repr__() == '' + tax["562"].__repr__() + == '' ) @@ -511,7 +520,8 @@ def test_gtdb_invalid_format(): @pytest.mark.skipif( - not os.getenv("TAXONOMY_TEST_NCBI"), reason="Define TAXONOMY_TEST_NCBI to run NCBI test" + not os.getenv("TAXONOMY_TEST_NCBI"), + reason="Define TAXONOMY_TEST_NCBI to run NCBI test", ) def test_latestncbi_load_latest_ncbi_taxonomy(): download("https://ftp.ncbi.nih.gov/pub/taxonomy/taxdump.tar.gz")