From c69c07e92ac914c8d7826be4a08320065cc3292d Mon Sep 17 00:00:00 2001
From: linmy <657894692@qq.com>
Date: Tue, 10 Mar 2026 13:02:44 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20BaseDT=20=E6=B7=BB=E5=8A=A0=20NPZ=20?=
 =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86=E5=88=B6=E4=BD=9C=E5=8A=9F=E8=83=BD?=
 =?UTF-8?q?=EF=BC=88=E8=A7=86=E9=A2=91/CSV=EF=BC=8C=E6=94=AF=E6=8C=81=20XE?=
 =?UTF-8?q?du=20=E4=B8=8E=20MediaPipe=EF=BC=89=E5=8F=8A=20demo?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Made-with: Cursor
---
 BaseDT/BaseDT/__init__.py |   6 +-
 BaseDT/BaseDT/dataset.py  | 523 ++++++++++++++++++++++++++++++++++++++
 BaseDT/BaseDT/npz.py      |   6 +
 demo/npz_demo.py          |  14 ++
 4 files changed, 548 insertions(+), 1 deletion(-)
 create mode 100644 BaseDT/BaseDT/npz.py
 create mode 100644 demo/npz_demo.py

diff --git a/BaseDT/BaseDT/__init__.py b/BaseDT/BaseDT/__init__.py
index 7152555..ddcbbc7 100644
--- a/BaseDT/BaseDT/__init__.py
+++ b/BaseDT/BaseDT/__init__.py
@@ -1 +1,5 @@
-from .version import __version__
\ No newline at end of file
+from .version import __version__
+from .dataset import make_npz_dataset, NPZGenerator
+from . import npz
+
+__all__ = ['make_npz_dataset', 'NPZGenerator', 'npz']
\ No newline at end of file
diff --git a/BaseDT/BaseDT/dataset.py b/BaseDT/BaseDT/dataset.py
index 913172f..64cd549 100644
--- a/BaseDT/BaseDT/dataset.py
+++ b/BaseDT/BaseDT/dataset.py
@@ -882,6 +882,529 @@ def split_tab_dataset_class(data_path, data_column,label_column=[-1],train_val_r
     return train_x,train_y, val_x, val_y
 
 
+
+# ============= NPZ sequence dataset creation (video/CSV, for RNN / BaseNN load_npz_data) =============
+
+def _detect_npz_data_type(source):
+    """Guess whether ``source`` is a 'video' directory or a 'csv' file.
+
+    Returns 'video' for an existing directory, 'csv' for an existing path
+    ending in .csv/.txt, and None when the type cannot be determined.
+    """
+    if source is None or (isinstance(source, str) and not os.path.exists(source)):
+        return None
+    if os.path.isdir(source):
+        return 'video'
+    if isinstance(source, str) and source.lower().endswith(('.csv', '.txt')):
+        return 'csv'
+    return None
+
+
+def _make_npz_from_csv(
+    source,
+    output_path,
+    sequence_length=30,
+    data_column=None,
+    label_column=-1,
+    delimiter=',',
+    skiprows=1,
+    one_hot=True,
+):
+    """Build an NPZ sequence dataset from a CSV/tabular file.
+
+    Rows are grouped by label; within each group, rows are cut in file order
+    into non-overlapping sequences of ``sequence_length`` rows; a trailing
+    remainder shorter than that is dropped.
+
+    Args:
+        source: Path of the delimited text file.
+        output_path: Destination .npz path; its parent directory is created.
+        sequence_length: Number of rows per sequence.
+        data_column: Column index/indices (or, with pandas, names) used as
+            features. None means every column except the label column.
+        label_column: Index (or, with pandas, name) of the label column.
+            Labels are cast to int.
+        delimiter: Field separator.
+        skiprows: Number of leading non-data lines. With pandas the last of
+            them is used as the header row; earlier ones are skipped.
+        one_hot: Store one-hot float32 labels instead of int64 class indices.
+
+    Returns:
+        (output_path, label_map, label_list). label_map maps class index to
+        the original integer label; label_list holds the labels as strings.
+
+    Raises:
+        ValueError: If no class has at least ``sequence_length`` rows.
+    """
+    try:
+        import pandas as pd
+    except ImportError:
+        pd = None
+    if pd is None:
+        data = np.loadtxt(source, dtype=float, delimiter=delimiter, skiprows=skiprows, ndmin=2)
+        n_cols = data.shape[1]
+        # Normalise negative indices: the default -1 would otherwise never
+        # match np.arange(n_cols) and the label column would leak into the features.
+        lcol = np.atleast_1d(np.asarray(label_column)) % n_cols
+        if data_column is not None:
+            data_arr = data[:, np.atleast_1d(np.asarray(data_column))]
+        else:
+            use = np.setdiff1d(np.arange(n_cols), lcol)
+            data_arr = data[:, use]
+        labels = data[:, lcol].ravel().astype(int)
+    else:
+        if skiprows:
+            # Treat the last skipped line as the header. Skipping it as well
+            # would make pandas consume the first data row as the header.
+            df = pd.read_csv(source, delimiter=delimiter, skiprows=skiprows - 1, header=0)
+        else:
+            df = pd.read_csv(source, delimiter=delimiter, header=None)
+
+        def _col(c):
+            # Integers, including numpy ints produced by np.asarray, are positions.
+            return df.columns[c] if isinstance(c, (int, np.integer)) else c
+
+        lcol = _col(label_column)
+        if data_column is not None:
+            cols = [_col(c) for c in np.atleast_1d(np.asarray(data_column))]
+        else:
+            cols = [c for c in df.columns if c != lcol]
+        data_arr = df[cols].values.astype(np.float64)
+        labels = df[lcol].values.ravel().astype(int)
+
+    unique_labels = np.unique(labels)
+    num_classes = len(unique_labels)
+    label_to_idx = {u: i for i, u in enumerate(unique_labels)}
+    samples_data = []
+    samples_label = []
+
+    for lab in unique_labels:
+        mask = labels == lab
+        rows = data_arr[mask]
+        n_chunks = len(rows) // sequence_length
+        for i in range(n_chunks):
+            chunk = rows[i * sequence_length:(i + 1) * sequence_length]
+            samples_data.append(chunk)
+            idx = label_to_idx[lab]
+            if one_hot:
+                one_hot_label = np.zeros(num_classes, dtype=np.float32)
+                one_hot_label[idx] = 1.0
+                samples_label.append(one_hot_label)
+            else:
+                samples_label.append(idx)
+
+    if len(samples_data) == 0:
+        raise ValueError("CSV 中每个类别至少需要 sequence_length={} 行连续数据,当前不足。".format(sequence_length))
+
+    data_np = np.array(samples_data, dtype=np.float32)
+    label_np = np.array(samples_label, dtype=np.float32 if one_hot else np.int64)
+    d = os.path.dirname(output_path)
+    if d:
+        os.makedirs(d, exist_ok=True)
+    np.savez(output_path, data=data_np, label=label_np)
+    label_map = {i: int(k) for k, i in label_to_idx.items()}
+    label_list = [str(unique_labels[i]) for i in range(num_classes)]
+    return output_path, label_map, label_list
+
+
+def _make_npz_from_video(source, output_path, sequence_length=30, model_path=None, pose_source='xedu'):
+    """Build an NPZ dataset from a directory of videos, one sub-directory per class.
+
+    pose_source='xedu' uses XEdu det_body + pose_body26; 'mediapipe' needs
+    ``model_path``. Returns (output_path, label_map, label_list).
+    """
+    # NPZGenerator lower-cases pose_source, so compare case-insensitively here too.
+    if pose_source.lower() == 'mediapipe' and (model_path is None or not os.path.isfile(model_path)):
+        raise ValueError("pose_source='mediapipe' 时需提供有效的 MediaPipe 模型路径 model_path(如 pose_landmarker_full.task)。")
+    gen = NPZGenerator(dataset_path=source, model_path=model_path, sequence_length=sequence_length, pose_source=pose_source)
+    gen.generate_dataset(output_path)
+    return output_path, gen.get_label_map(), gen.get_label_map_list()
+
+
+def make_npz_dataset(
+    source,
+    output_path,
+    data_type='auto',
+    sequence_length=30,
+    pose_source='xedu',
+    model_path=None,
+    data_column=None,
+    label_column=-1,
+    delimiter=',',
+    skiprows=1,
+    one_hot=True,
+):
+    """
+    Create an NPZ dataset for RNN / sequence classification from videos or a CSV file.
+
+    The output is meant for BaseNN load_npz_data: 'data' has shape
+    (n, sequence_length, n_features) and 'label' is one-hot with shape
+    (n, n_classes) unless CSV mode is used with one_hot=False.
+
+    Args:
+        source: Path of the data source (video directory or CSV file).
+        output_path: Output path; must end with '.npz'.
+        data_type: 'auto', 'video' or 'csv'.
+        sequence_length: Sequence length, default 30.
+        pose_source: Video mode. 'xedu' (default; XEdu det_body +
+            pose_body26, 52 values per frame, no model_path needed) or
+            'mediapipe' (132 values per frame, needs model_path).
+        model_path: MediaPipe .task model file, for pose_source='mediapipe'.
+        data_column, label_column, delimiter, skiprows, one_hot: CSV mode options.
+
+    Returns:
+        (output_path, label_map, label_list)
+
+    Raises:
+        ValueError: For a non-.npz output_path, a source whose type cannot
+            be detected, or an unknown data_type.
+    """
+    if not output_path.endswith('.npz'):
+        raise ValueError("output_path 须为 .npz 文件路径。")
+    if data_type == 'auto':
+        data_type = _detect_npz_data_type(source)
+        if data_type is None:
+            raise ValueError("无法从 source 自动判断类型,请显式指定 data_type='video' 或 'csv'。")
+    if data_type == 'video':
+        return _make_npz_from_video(
+            source, output_path, sequence_length=sequence_length,
+            model_path=model_path, pose_source=pose_source,
+        )
+    elif data_type == 'csv':
+        return _make_npz_from_csv(
+            source, output_path, sequence_length=sequence_length,
+            data_column=data_column, label_column=label_column, delimiter=delimiter,
+            skiprows=skiprows, one_hot=one_hot,
+        )
+    else:
+        raise ValueError("data_type 须为 'auto'、'video' 或 'csv'。")
+
+
+class NPZGenerator(object):
+    """
+    NPZ generator for video datasets.
+
+    pose_source='xedu' (default): XEdu det_body + pose_body26, 52 values per
+    frame, no model_path needed.
+    pose_source='mediapipe': MediaPipe pose landmarks, 132 values per frame,
+    needs model_path (a .task file).
+    """
+    def __init__(self, dataset_path, model_path=None, sequence_length=30, pose_source='xedu'):
+        self.sequence_length = sequence_length
+        self.dataset_path = dataset_path
+        self.model_path = model_path
+        self.pose_source = pose_source.lower()
+        self.label_map = {}
+        self.label_map_list = []
+        self.video_path = []
+        self.results = []
+        self.file_name = '尚未保存'
+        self.feat_dim = 52 if self.pose_source == 'xedu' else 132
+        if self.pose_source == 'xedu':
+            self._init_xedu()
+        elif self.pose_source == 'mediapipe':
+            self._init_mediapipe()
+        else:
+            raise ValueError("pose_source 须为 'xedu' 或 'mediapipe'。")
+
+    def _init_xedu(self):
+        """Load the XEdu Workflow models det_body and pose_body26 (no extra model file)."""
+        try:
+            from .. import Workflow
+        except (ImportError, ValueError):
+            # ValueError: some Python versions raise it for a relative import
+            # beyond the top-level package; fall back to XEdu.hub either way.
+            try:
+                from XEdu.hub import Workflow
+            except ImportError:
+                raise ImportError("pose_source='xedu' 需要 XEdu-python,请安装: pip install xedu-python")
+        self._det_model = Workflow(task='det_body')
+        self._pose_model = Workflow(task='pose_body26')
+
+    def _get_keypoints_from_frame_xedu(self, frame):
+        """Return min-max normalised keypoints of the first detected body, or False.
+
+        The flattened pose output is expected to hold 26 * 2 = 52 values.
+        """
+        bboxs = self._det_model.inference(data=frame)
+        if len(bboxs) > 0:
+            keypoints = self._pose_model.inference(data=frame, bbox=bboxs[0])
+            x = np.ravel(keypoints)
+            min_val, max_val = np.min(x), np.max(x)
+            if max_val - min_val > 1e-8:
+                x = (x - min_val) / (max_val - min_val)
+            return list(x.astype(np.float32))
+        return False
+
+    def _init_mediapipe(self):
+        """Import mediapipe and prepare PoseLandmarker options in VIDEO running mode."""
+        try:
+            import mediapipe as mp
+        except ImportError:
+            raise ImportError("视频模式需要安装 mediapipe: pip install mediapipe")
+        self._mp = mp
+        self.BaseOptions = mp.tasks.BaseOptions
+        self.PoseLandmarker = mp.tasks.vision.PoseLandmarker
+        self.PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
+        self.VisionRunningMode = mp.tasks.vision.RunningMode
+        self.options = self.PoseLandmarkerOptions(
+            base_options=self.BaseOptions(model_asset_path=self.model_path),
+            running_mode=self.VisionRunningMode.VIDEO,
+        )
+
+    def _read_video_path(self):
+        """Scan dataset_path: each sub-directory is a class, each file in it a video."""
+        # Start clean so a repeated generate_dataset() does not list videos twice.
+        self.video_path = []
+        # os.listdir() order is arbitrary; sort so class indices are reproducible.
+        for action in sorted(os.listdir(self.dataset_path)):
+            action_full = os.path.join(self.dataset_path, action)
+            if not os.path.isdir(action_full):
+                continue
+            if action not in self.label_map:
+                self.label_map[action] = len(self.label_map)
+            for name in sorted(os.listdir(action_full)):
+                path = os.path.join(action_full, name)
+                if os.path.isfile(path):
+                    self.video_path.append(path)
+        self._build_label_map_list()
+
+    def _build_label_map_list(self):
+        """Rebuild label_map_list so that index i holds the class name with label i."""
+        self.label_map_list = [''] * len(self.label_map)
+        for name, idx in self.label_map.items():
+            self.label_map_list[idx] = name
+
+    def _get_landmark_list_mediapipe(self, pose):
+        """Flatten the first detected pose to [x, y, z, visibility] * 33 = 132 values.
+
+        Returns an empty list when no pose was detected.
+        """
+        out = []
+        if len(pose.pose_landmarks) > 0:
+            for p in pose.pose_landmarks[0]:
+                out.extend([p.x, p.y, p.z, p.visibility])
+        return out
+
+    def _get_features(self):
+        """Extract per-video features with the configured pose backend."""
+        # Drop results of an earlier run; they would otherwise be saved twice.
+        self.results = []
+        if self.pose_source == 'xedu':
+            self._get_features_xedu()
+        else:
+            self._get_features_mediapipe()
+
+    def _get_features_xedu(self):
+        """Collect runs of sequence_length consecutive detected frames per video."""
+        for path in self.video_path:
+            cap = cv2.VideoCapture(path)
+            all_poses = []
+            print("-----开始处理 " + path + " -----")
+            seq = self.sequence_length
+            keypoints_list = []
+            while cap.isOpened():
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                kp = self._get_keypoints_from_frame_xedu(frame)
+                if kp is False:
+                    # A frame without a person breaks continuity: restart the window.
+                    keypoints_list = []
+                    continue
+                keypoints_list.append(kp)
+                if len(keypoints_list) == seq:
+                    all_poses.append(keypoints_list[:])
+                    keypoints_list = []
+            cap.release()
+            self.results.append((all_poses, path))
+            print("-----结束处理 " + path + " -----")
+
+    def _get_features_mediapipe(self):
+        """Run the MediaPipe pose landmarker on every frame of every video."""
+        mp = self._mp
+        for path in self.video_path:
+            cap = cv2.VideoCapture(path)
+            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+            poses = []
+            print("-----开始处理 " + path + " -----")
+            try:
+                with self.PoseLandmarker.create_from_options(self.options) as landmarker:
+                    while cap.isOpened():
+                        ret, frame = cap.read()
+                        if not ret:
+                            break
+                        frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
+                        ts_ms = int(frame_id * (1000.0 / fps))
+                        # OpenCV decodes to BGR, ImageFormat.SRGB expects RGB.
+                        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                        img = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
+                        res = landmarker.detect_for_video(img, ts_ms)
+                        poses.append(res)
+            finally:
+                # Always release the capture, even if detection raises.
+                cap.release()
+            self.results.append((poses, path))
+            print("-----结束处理 " + path + " -----")
+
+    def _split_and_save(self):
+        """Assemble self.results into self._datas and one-hot self._labels."""
+        total = 0
+        for poses, _ in self.results:
+            if self.pose_source == 'xedu':
+                total += len(poses)
+            else:
+                total += len(poses) // self.sequence_length
+        feat_dim = self.feat_dim
+        datas = np.zeros((total, self.sequence_length, feat_dim), dtype=np.float32)
+        labels_list = []
+        cur = 0
+        for poses, path in self.results:
+            cls_name = os.path.basename(os.path.dirname(path))
+            lab = np.zeros(len(self.label_map), dtype=np.float32)
+            lab[self.label_map[cls_name]] = 1.0
+            if self.pose_source == 'xedu':
+                for chunk in poses:
+                    block = np.array(chunk, dtype=np.float32)
+                    if block.shape == (self.sequence_length, feat_dim):
+                        datas[cur, :, :] = block
+                        labels_list.append(lab.copy())
+                        cur += 1
+            else:
+                times = len(poses) // self.sequence_length
+                for t in range(times):
+                    block = np.zeros((self.sequence_length, feat_dim), dtype=np.float32)
+                    for k in range(self.sequence_length):
+                        vec = self._get_landmark_list_mediapipe(poses[t * self.sequence_length + k])
+                        # Frames without a detected pose stay all-zero.
+                        if len(vec) == feat_dim:
+                            block[k, :] = vec
+                    datas[cur, :, :] = block
+                    labels_list.append(lab.copy())
+                    cur += 1
+        # Chunks with an unexpected shape are skipped above; trim the unused
+        # rows so data and label always have the same length.
+        self._datas = datas[:cur]
+        self._labels = np.array(labels_list, dtype=np.float32).reshape(cur, len(self.label_map))
+
+    def generate_dataset(self, file_name):
+        """Process every video under dataset_path and save 'data'/'label' to file_name.
+
+        Prints a hint and returns without writing when file_name does not end
+        with '.npz'.
+        """
+        if not file_name.endswith('.npz'):
+            print('请指定 .npz 输出路径')
+            return
+        self.file_name = file_name
+        self._read_video_path()
+        self._get_features()
+        self._split_and_save()
+        # Create the parent directory, as the CSV path does.
+        out_dir = os.path.dirname(self.file_name)
+        if out_dir:
+            os.makedirs(out_dir, exist_ok=True)
+        with open(self.file_name, 'wb') as f:
+            np.savez(f, data=self._datas, label=self._labels)
+
+    def set_sequence_length(self, sequence_length):
+        self.sequence_length = sequence_length
+
+    def get_video_path(self):
+        return self.video_path
+
+    def get_label_map(self):
+        return self.label_map
+
+    def get_label_map_list(self):
+        return self.label_map_list
+
+    def get_results(self):
+        return self.results
+
+    def get_sequence_length(self):
+        return self.sequence_length
+
+    def set_label_map_list(self, label_map_list):
+        self.label_map_list = list(label_map_list)
+
+    def generate_for_inference(self, data_path):
+        """Build sliding-window inference input for one video.
+
+        Frames without a detected pose are skipped. Returns a float32 array of
+        shape (n_windows, sequence_length, feat_dim) using stride-1 windows, or
+        None, after printing a message, when fewer than sequence_length
+        usable frames were found.
+        """
+        feat_dim = self.feat_dim
+        if self.pose_source == 'xedu':
+            results = []
+            cap = cv2.VideoCapture(data_path)
+            while cap.isOpened():
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                kp = self._get_keypoints_from_frame_xedu(frame)
+                if kp is not False:
+                    results.append(kp)
+            cap.release()
+        else:
+            mp = self._mp
+            results = []
+            cap = cv2.VideoCapture(data_path)
+            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+            with self.PoseLandmarker.create_from_options(self.options) as landmarker:
+                while cap.isOpened():
+                    ret, frame = cap.read()
+                    if not ret:
+                        break
+                    frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
+                    ts_ms = int(frame_id * (1000.0 / fps))
+                    # OpenCV decodes to BGR, ImageFormat.SRGB expects RGB.
+                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    img = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
+                    res = landmarker.detect_for_video(img, ts_ms)
+                    if len(res.pose_landmarks) > 0:
+                        results.append(res)
+            cap.release()
+        if len(results) < self.sequence_length:
+            print("视频帧数过少,无法生成至少 {} 帧的序列。".format(self.sequence_length))
+            return None
+        n_windows = len(results) - self.sequence_length + 1
+        content = np.zeros((n_windows, self.sequence_length, feat_dim), dtype=np.float32)
+        for start in range(n_windows):
+            for i in range(self.sequence_length):
+                if self.pose_source == 'xedu':
+                    vec = results[start + i]
+                else:
+                    vec = self._get_landmark_list_mediapipe(results[start + i])
+                if len(vec) == feat_dim:
+                    content[start, i, :] = vec
+        return content
+
+    def see_result(self, result):
+        """Average multi-window predictions and print the most confident class.
+
+        Returns [probabilities, class_name] when label_map_list is non-empty and
+        covered by the prediction; otherwise prints and returns the probabilities.
+        """
+        result = np.asarray(result)
+        if result.ndim == 2:
+            avg = np.mean(result, axis=0)
+        else:
+            avg = result
+        probs = np.asarray(avg).ravel()
+        if len(self.label_map_list) and len(probs) >= len(self.label_map_list):
+            idx = int(np.argmax(probs[: len(self.label_map_list)]))
+            conf = float(probs[idx])
+            name = self.label_map_list[idx] if idx < len(self.label_map_list) else str(idx)
+            print('推理结果为:' + str(name))
+            print('置信度为:' + str(conf))
+            return [list(probs), name]
+        print('推理结果:', probs)
+        return probs
+
+
 if __name__=="__main__":
     path = "../iris/iris.csv"
     train_x,train_y,val_x,val_y = split_tab_dataset(path,data_column=range(0,4),label_column=4,normalize=False)
diff --git a/BaseDT/BaseDT/npz.py b/BaseDT/BaseDT/npz.py
new file mode 100644
index 0000000..2a7bf0e
--- /dev/null
+++ b/BaseDT/BaseDT/npz.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+"""NPZ dataset creation (video/CSV). Usage: ``from BaseDT import npz``, then npz.make_npz_dataset / npz.NPZGenerator."""
+
+from .dataset import make_npz_dataset, NPZGenerator
+
+__all__ = ['make_npz_dataset', 'NPZGenerator']
diff --git a/demo/npz_demo.py b/demo/npz_demo.py
new file mode 100644
index 0000000..b6d8299
--- /dev/null
+++ b/demo/npz_demo.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+"""NPZ dataset demo: convert videos or a CSV file to NPZ (for RNN / sequence classification)."""
+
+from BaseDT import npz
+
+# Option 1: video directory -> NPZ (layout: class_name/video_file, e.g. video/挥手/*.mp4, video/走路/*.mp4)
+npz.make_npz_dataset('./video', 'dataset_video.npz', data_type='video', sequence_length=30)
+
+# Option 2: CSV file -> NPZ (the last column is the label)
+npz.make_npz_dataset('data.csv', 'dataset_csv.npz', data_type='csv', sequence_length=10, label_column=-1)
+
+# Option 3: drive the generator class directly
+# gen = npz.NPZGenerator(dataset_path='./video', sequence_length=30)
+# gen.generate_dataset('out.npz')