diff --git a/docs/images/.DS_Store b/docs/images/.DS_Store deleted file mode 100644 index 606ad25a5..000000000 Binary files a/docs/images/.DS_Store and /dev/null differ diff --git a/funasr/models/fun_asr_nano/model.py b/funasr/models/fun_asr_nano/model.py index a4eba71b5..67267fe23 100644 --- a/funasr/models/fun_asr_nano/model.py +++ b/funasr/models/fun_asr_nano/model.py @@ -54,7 +54,9 @@ def __init__( else -1 ) audio_encoder = ( - model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder + model.model.model.encoder + if hasattr(model.model, "model") + else model.model.encoder ) else: encoder_class = tables.encoder_classes.get(audio_encoder) @@ -135,7 +137,9 @@ def __init__( if init_param_path is not None: src_state = torch.load(init_param_path, map_location="cpu") flag = self.ctc_decoder.load_state_dict(src_state, strict=False) - logging.info(f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}") + logging.info( + f"Loading ctc_decoder ckpt: {init_param_path}, status: {flag}" + ) freeze = ctc_decoder_conf.get("freeze", False) if freeze: for _, param in self.ctc_decoder.named_parameters(): @@ -189,7 +193,9 @@ def forward( encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) # audio_adaptor - encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) + encoder_out, encoder_out_lens = self.audio_adaptor( + encoder_out, encoder_out_lens + ) batch_size, token_num, dims = inputs_embeds.shape fake_token_len = kwargs.get("fake_token_len") @@ -228,7 +234,9 @@ def forward( stats["batch_size_speech"] = batch_size_speech stats["batch_size_x_frames"] = frames * batch_size_speech stats["batch_size_real_frames"] = speech_lengths.sum().item() - stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] + stats["padding_frames"] = ( + stats["batch_size_x_frames"] - stats["batch_size_real_frames"] + ) device_type = next(self.parameters()).device.type with torch.autocast( @@ -247,7 +255,9 @@ def forward( with torch.no_grad(): preds = torch.argmax(model_outputs.logits, -1) - acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) + acc_att = compute_accuracy( + preds[:, :-1], labels_ids[:, 1:], ignore_label=-100 + ) stats["acc"] = acc_att stats["loss"] = torch.clone(loss.detach()) @@ -255,7 +265,9 @@ def forward( stats["batch_size_x_tokens"] = token_num * batch_size stats["batch_size_real_tokens"] = attention_mask.sum().item() - stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] + stats["padding_tokens"] = ( + stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] + ) dialog_turns = (fbank_beg > 0).sum(-1) dialog_turns_max = torch.max(dialog_turns).int().item() @@ -305,7 +317,9 @@ def data_template(self, data): return contents - def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs): + def data_load_speech( + self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs + ): system = contents["system"] user = contents["user"] assistant = contents["assistant"] @@ -326,7 +340,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** [], ) input_source_ids = [] - for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): + for i, (system_prompt, user_prompt, target_out) in enumerate( + zip(system, user, assistant) + ): if i >= kwargs.get("multiturn_num_max", 5): break if len(input_ids) > kwargs.get("max_token_length", 1500): @@ -341,16 +357,12 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** else: source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" if not sys_prompt: - source_input = ( - f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" - ) + source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" else: if kwargs.get("infer_with_assistant_input", False): source_input = f"<|im_start|>user\n{user_prompt}" else: - source_input = ( - f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" - ) + source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" if not do_think: source_input += "\n\n\n\n" if kwargs.get("prev_text", None) is not None: @@ -383,7 +395,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** time2 = time.perf_counter() meta_data["load_data"] = f"{time2 - time1:0.3f}" except Exception as e: - logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") + logging.error( + f"Loading wav failed! {str(e)}, {traceback.format_exc()}" + ) speech, speech_lengths = extract_fbank( data_src, @@ -425,7 +439,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** fbank.append(speech[0, :, :]) fbank_lens.append(speech_lengths) - input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] + input_ids = torch.tensor( + input_ids, dtype=torch.int64 + ) # [: self.max_token_length] attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] @@ -436,7 +452,9 @@ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, ** target_ids = torch.tensor(target_ids, dtype=torch.int64) if len(fbank) > 0: - speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) + speech = torch.nn.utils.rnn.pad_sequence( + fbank, batch_first=True, padding_value=0.0 + ) speech_lengths = torch.nn.utils.rnn.pad_sequence( fbank_lens, batch_first=True, padding_value=-1 ) @@ -469,11 +487,10 @@ def inference_prepare( ): meta_data = {} - if kwargs.get("batch_size", 1) > 1: - raise NotImplementedError("batch decoding is not implemented") - contents = self.data_template(data_in[0]) - output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs) + output = self.data_load_speech( + contents, tokenizer, frontend, meta_data=meta_data, **kwargs + ) batch = to_device(output, kwargs["device"]) # audio encoder @@ -494,7 +511,9 @@ def inference_prepare( encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) # audio_adaptor - adaptor_out, adaptor_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) + adaptor_out, adaptor_out_lens = self.audio_adaptor( + encoder_out, encoder_out_lens + ) meta_data["encoder_out"] = encoder_out meta_data["encoder_out_lens"] = encoder_out_lens meta_data["audio_adaptor_out"] = adaptor_out @@ -566,7 +585,10 @@ def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]): if isinstance(data, str): return [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"}, + { + "role": "user", + "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>", + }, {"role": "assistant", "content": "null"}, ] elif isinstance(data, torch.Tensor): @@ -590,7 +612,9 @@ def inference( **kwargs, ): prompt = self.get_prompt( - kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True) + kwargs.get("hotwords", []), + kwargs.get("language", None), + kwargs.get("itn", True), ) data_in = [self.generate_chatml(prompt, data) for data in data_in] @@ -598,7 +622,35 @@ def inference( key = [] for _ in data_in: chars = string.ascii_letters + string.digits - key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13))) + key.append( + "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + ) + + # 批量推理:LLM 自回归解码不支持跨样本 padding 批处理, + # 对每条音频独立推理后聚合结果,实现对 batch_size_s > 0 的支持。 + if len(data_in) > 1: + all_results = [] + last_meta = {} + for i, single_data in enumerate(data_in): + single_key = [key[i]] if i < len(key) else None + try: + res, meta = self.inference_llm( + [single_data], + data_lengths=None, + key=single_key, + tokenizer=tokenizer, + frontend=frontend, + **kwargs, + ) + all_results.extend(res) + last_meta = meta + except Exception as e: + logging.error( + f"batch item {i} inference failed: {str(e)}, {traceback.format_exc()}" + ) + if single_key: + all_results.append({"key": single_key[0], "text": ""}) + return all_results, last_meta return self.inference_llm( data_in, @@ -626,7 +678,9 @@ def inference_llm( if self.ctc_decoder is not None: encoder_out = meta_data["encoder_out"] encoder_out_lens = meta_data["encoder_out_lens"] - decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens) + decoder_out, decoder_out_lens = self.ctc_decoder( + encoder_out, encoder_out_lens + ) ctc_logits = self.ctc.log_softmax(decoder_out) b, n, d = encoder_out.size() @@ -665,7 +719,8 @@ def inference_llm( inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=kwargs.get("max_length", 512), - pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id, + pad_token_id=self.llm.config.pad_token_id + or self.llm.config.eos_token_id, **llm_kwargs, ) @@ -683,7 +738,8 @@ def inference_llm( inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids, - pad_token_id=self.llm.config.pad_token_id or self.llm.config.eos_token_id, + pad_token_id=self.llm.config.pad_token_id + or self.llm.config.eos_token_id, **llm_kwargs, ) @@ -722,8 +778,12 @@ def inference_llm( result["ctc_timestamps"] = forced_align( ctc_result["ctc_logits"], target_ids, self.blank_id ) - target_ids = torch.tensor(self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64) - result["timestamps"] = forced_align(ctc_result["ctc_logits"], target_ids, self.blank_id) + target_ids = torch.tensor( + self.ctc_tokenizer.encode(result["text"]), dtype=torch.int64 + ) + result["timestamps"] = forced_align( + ctc_result["ctc_logits"], target_ids, self.blank_id + ) for timestamps in [result["timestamps"], result["ctc_timestamps"]]: for timestamp in timestamps: timestamp["token"] = self.ctc_tokenizer.decode([timestamp["token"]]) @@ -741,6 +801,8 @@ def inference_llm( def from_pretrained(model: str = None, **kwargs): from funasr import AutoModel - model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs) + model, kwargs = AutoModel.build_model( + model=model, trust_remote_code=True, **kwargs + ) return model, kwargs