@staticmethod
def _has_valid_content(value) -> bool:
    """Return True when a dataframe cell holds non-empty content.

    Used as the per-cell predicate for ``Series.apply`` when dropping rows
    whose input column is empty before evaluation.

    Args:
        value: A single cell value. May be ``None``, a string, a sized
            container (list, tuple, set, dict, or an array-like such as
            ``numpy.ndarray`` / ``pandas.Series``), or a scalar.

    Returns:
        ``False`` for ``None``, whitespace-only strings, empty sized
        values, and missing scalars (``NaN``, ``pd.NA``, ``NaT``);
        otherwise the truthiness of ``value``.
    """
    if value is None:
        return False
    if isinstance(value, str):
        return value.strip() != ""
    # Sized values are judged by length alone. This must run before
    # pd.isna(): for array-likes pd.isna() returns an element-wise array
    # whose truth value is ambiguous and raises ValueError.
    try:
        return len(value) > 0
    except TypeError:
        pass  # unsized value: fall through to the missing-value check
    try:
        if pd.isna(value):
            return False
    except (TypeError, ValueError):
        # Exotic objects that pd.isna cannot reduce to a single boolean
        # are not treated as missing.
        pass
    return bool(value)
output['qa_pairs'][:self.num_q] if len(output['qa_pairs']) >= self.num_q else output['qa_pairs'] for output in outputs ] + metadata_column = [output['metadata'] for output in outputs] + + dataframe = dataframe.copy() + dataframe[self.output_key] = qa_pairs_column + dataframe[self.output_meta_key] = metadata_column + + valid_mask = dataframe[self.output_key].apply(lambda qa_pairs: isinstance(qa_pairs, list) and len(qa_pairs) > 0) + filtered_count = int((~valid_mask).sum()) + if filtered_count: + self.logger.info(f"Filtering out {filtered_count} rows with empty '{self.output_key}'") + dataframe = dataframe[valid_mask].reset_index(drop=True) - dataframe[self.output_meta_key] = [output['metadata'] for output in outputs] output_file = storage.write(dataframe) self.logger.info(f"Results saved to {output_file}")