Skip to content
2 changes: 2 additions & 0 deletions native/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def_data_pipeline(py::module_ &data_module)

.def_property_readonly("is_broken", &data_pipeline::is_broken)

.def_property_readonly("warning_count", &data_pipeline::warning_count)

// state_dict
.def(
"state_dict",
Expand Down
14 changes: 11 additions & 3 deletions native/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,16 @@ data_pipeline::next()
if (ex.recoverable() && warning_count_ < max_num_warnings_) {
warning_count_++;

// TODO: log exception
// Log the exception with the current warning count
(void) fprintf(stderr, "Data pipeline warning (%zu/%zu): %s\n",
warning_count_, max_num_warnings_, ex.what());

// Continue to the next example
continue;
} else {
if (max_num_warnings_ > 0) {
// TODO: log max number of warnings reached.
if (max_num_warnings_ > 0 && warning_count_ >= max_num_warnings_) {
(void) fprintf(stderr, "Data pipeline error: Maximum number of warnings (%zu) reached.\n",
max_num_warnings_);
}

// If the error is not recoverable, any further attempt to read
Expand All @@ -87,6 +93,8 @@ data_pipeline::reset(bool reset_rng)

try {
source_->reset(reset_rng);
// Reset warning counter when pipeline is reset
warning_count_ = 0;
} catch (const std::exception &) {
is_broken_ = true;

Expand Down
6 changes: 6 additions & 0 deletions native/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class FAIRSEQ2_API data_pipeline {
return is_broken_;
}

std::size_t
warning_count() const noexcept
{
return warning_count_;
}

private:
bool
is_initialized() const noexcept;
Expand Down
10 changes: 9 additions & 1 deletion native/src/fairseq2n/data/filter_data_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ filter_data_source::finitude_type() const noexcept
/// Evaluates the filter predicate on `example`.
///
/// @param example  The example handed to `predicate_fn_`.
/// @return         Whatever `predicate_fn_` returns for `example`.
/// @throws data_pipeline_error  If the predicate throws any `std::exception`.
///     The error carries the original message prefixed with
///     "Error in filter function: ", the offending example, and is marked
///     recoverable.
bool
filter_data_source::invoke_function(data &example)
{
    // The predicate must be called only from inside the try block. A bare
    // `return predicate_fn_(example);` ahead of it would return (or throw)
    // first and leave the handler below unreachable.
    try {
        return predicate_fn_(example);
    } catch (const std::exception &ex) {
        // Convert any exception into a recoverable data_pipeline_error.
        throw data_pipeline_error(
            std::string("Error in filter function: ") + ex.what(),
            example,  // Pass the example that caused the error
            true);    // Mark as recoverable
    }
}

} // fairseq2n::detail
4 changes: 1 addition & 3 deletions tests/unit/data/data_pipeline/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def fn(d: int) -> bool:

assert output == [1, 3, 5]

@pytest.mark.skip("need additional work in data_pipeline::next")
def test_next_does_not_raise_error_when_num_errors_is_less_than_max_num_warnings(
self,
) -> None:
Expand All @@ -89,7 +88,6 @@ def fn(d: int) -> bool:
# TODO: assert log warning

@pytest.mark.parametrize("max_num_warnings", [0, 1, 2])
@pytest.mark.skip("need additional work in data_pipeline::next")
def test_next_raises_error_when_num_errors_exceed_max_num_warnings(
self, max_num_warnings: int
) -> None:
Expand All @@ -109,7 +107,7 @@ def fn(d: int) -> bool:

pipeline = read_sequence(seq).filter(fn).and_return(max_num_warnings)

with pytest.raises(ValueError):
with pytest.raises(DataPipelineError):
for _ in pipeline:
pass

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/data/data_pipeline/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from fairseq2.data import read_sequence
from fairseq2.data import DataPipelineError, read_sequence


class TestFilterOp:
Expand All @@ -34,8 +34,9 @@ def fn(d: int) -> bool:

pipeline = read_sequence([1, 2, 3, 4]).filter(fn).and_return()

with pytest.raises(ValueError) as exc_info:
with pytest.raises(DataPipelineError) as exc_info:
for d in pipeline:
pass

assert str(exc_info.value) == "filter error"
# Check that the original error message is included in the DataPipelineError
assert "filter error" in str(exc_info.value)
Loading