From d1d3733c73ba66b37530b58b36ab8d4db97cdcd8 Mon Sep 17 00:00:00 2001 From: linmy <657894692@qq.com> Date: Fri, 20 Feb 2026 10:02:31 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20BaseDT=20=E6=B7=BB=E5=8A=A0=20NPZ?= =?UTF-8?q?=20=E6=95=B0=E6=8D=AE=E9=9B=86=E5=88=B6=E4=BD=9C=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=88=E8=A7=86=E9=A2=91/CSV=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20XEdu=20=E4=B8=8E=20MediaPipe=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- XEdu/examples/basedt_demo.py | 200 +++++++++++++++++ XEdu/hub/BaseDT/__init__.py | 6 + XEdu/hub/BaseDT/dataset.py | 408 +++++++++++++++++++++++++++++++++++ 3 files changed, 614 insertions(+) create mode 100644 XEdu/examples/basedt_demo.py diff --git a/XEdu/examples/basedt_demo.py b/XEdu/examples/basedt_demo.py new file mode 100644 index 0000000..8834c4c --- /dev/null +++ b/XEdu/examples/basedt_demo.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- +""" +BaseDT NPZ 数据集制作使用示例。 +演示如何从视频或 CSV 生成适用于 RNN/序列分类的 NPZ 数据集。 +""" + +import os +import sys +import shutil +import tempfile + +# 优先加载本地 XEdu-python-main,避免使用 site-packages 的旧版本 +_examples_dir = os.path.dirname(os.path.abspath(__file__)) +_project_root = os.path.dirname(os.path.dirname(_examples_dir)) # XEdu-python-main +if _project_root not in sys.path: + sys.path.insert(0, _project_root) + +from XEdu.hub.BaseDT import make_npz_dataset, NPZGenerator + + +def demo_video_to_npz(video_path, output_path='dataset_from_video.npz'): + """ + 将单个视频制作为 NPZ 数据集。 + 视频模式需要按类别分目录,单视频时自动创建临时目录结构(单类)。 + """ + video_path = os.path.abspath(video_path) + if not os.path.isfile(video_path): + print('视频不存在:', video_path) + return + tmp_dir = tempfile.mkdtemp(prefix='npz_video_') + class_dir = os.path.join(tmp_dir, 'single') + os.makedirs(class_dir) + dst = os.path.join(class_dir, os.path.basename(video_path)) + shutil.copy(video_path, dst) + print('临时目录:', tmp_dir) + make_npz_dataset( + tmp_dir, + output_path, + data_type='video', + 
pose_source='xedu', + sequence_length=30, + ) + shutil.rmtree(tmp_dir, ignore_errors=True) + print('视频 -> NPZ 完成,输出:', output_path) + import numpy as np + data = np.load(output_path) + print('data.shape:', data['data'].shape, 'label.shape:', data['label'].shape) + + +def demo_iris_csv_to_npz(csv_path, output_path='dataset_from_iris.npz'): + """将 iris.csv 制作为 NPZ 数据集。""" + csv_path = os.path.abspath(csv_path) + if not os.path.isfile(csv_path): + print('CSV 不存在:', csv_path) + return + make_npz_dataset( + csv_path, + output_path, + data_type='csv', + sequence_length=10, + label_column=-1, + delimiter=',', + skiprows=1, + ) + print('CSV -> NPZ 完成,输出:', output_path) + import numpy as np + data = np.load(output_path) + print('data.shape:', data['data'].shape, 'label.shape:', data['label'].shape) + + +def demo_make_npz_from_csv(): + """ + 示例1:从 CSV 表格数据生成 NPZ 数据集。 + CSV 格式:每行一个时间步,最后一列为类别标签。按类别分组后切成 sequence_length 长度的序列。 + """ + import numpy as np + + # 创建示例 CSV 数据(若 data.csv 不存在) + csv_path = 'data_for_npz_demo.csv' + if not os.path.exists(csv_path): + # 3 类,每类 100 行,5 个特征列 + 1 个标签列 + np.random.seed(42) + rows = [] + header = 'f1,f2,f3,f4,f5,label' + for label in range(3): + for _ in range(100): + feat = np.random.randn(5).astype(float) + rows.append(','.join(map(str, feat)) + ',' + str(label)) + with open(csv_path, 'w') as f: + f.write(header + '\n' + '\n'.join(rows)) + print('已创建示例 CSV:', csv_path) + + output_path = 'dataset_from_csv.npz' + make_npz_dataset( + csv_path, + output_path, + data_type='csv', + sequence_length=10, + label_column=-1, + delimiter=',', + skiprows=1, + ) + print('CSV -> NPZ 完成,输出:', output_path) + + # 验证 + data = np.load(output_path) + print('data.shape:', data['data'].shape, 'label.shape:', data['label'].shape) + + +def demo_make_npz_from_video(): + """ + 示例2:从视频目录生成 NPZ 数据集。 + 目录结构:dataset_path/类别名/视频文件 + 默认使用 XEdu det_body+pose_body26(52维/帧),不需额外模型。也可用 pose_source='mediapipe' 需 .task 模型。 + """ + video_dir = './video' # 
视频目录:video/waving/*.mp4, video/walking/*.mp4 等 + + if not os.path.isdir(video_dir): + print('视频目录不存在:', video_dir, '请准备按类别分文件夹的视频目录。') + return + + # 默认 pose_source='xedu',无需 model_path + make_npz_dataset( + video_dir, + 'dataset_from_video.npz', + data_type='video', + pose_source='xedu', # 默认,XEdu det_body+pose_body26 + sequence_length=30, + ) + print('视频 -> NPZ 完成,输出: dataset_from_video.npz') + + # 若使用 MediaPipe(132维/帧),需传入 model_path + # make_npz_dataset(video_dir, 'dataset_mediapipe.npz', data_type='video', + # pose_source='mediapipe', model_path='pose_landmarker_full.task') + + +def demo_npz_generator_class(): + """ + 示例3:使用 NPZGenerator 类(兼容原 npz_generator 接口)。 + 默认 pose_source='xedu' 不需 model_path;pose_source='mediapipe' 需 model_path。 + """ + video_dir = './video' + + if not os.path.isdir(video_dir): + print('请准备视频目录。') + return + + # 默认使用 XEdu,不需 model_path + gen = NPZGenerator( + dataset_path=video_dir, + sequence_length=30, + pose_source='xedu', # 默认 + ) + gen.generate_dataset('dataset.npz') + print('标签映射:', gen.get_label_map()) + print('标签列表:', gen.get_label_map_list()) + + # 推理:对单个视频生成推理用数组 + # inf_data = gen.generate_for_inference('test.mp4') + # 设置标签名用于解析推理结果 + gen.set_label_map_list(['waving', 'walking', 'stretching']) + # gen.see_result(model_output) # 解析并打印推理结果 + + +def demo_auto_detect(): + """ + 示例4:自动检测数据类型(目录->视频,.csv->CSV)。 + """ + # CSV 示例 + csv_path = 'data_for_npz_demo.csv' + if os.path.exists(csv_path): + make_npz_dataset(csv_path, 'auto_csv.npz', data_type='auto', sequence_length=10) + print('自动检测 CSV 完成') + + # 视频目录示例(默认 xedu 不需 model_path) + if os.path.isdir('./video'): + make_npz_dataset('./video', 'auto_video.npz', data_type='auto') + print('自动检测视频完成') + + +if __name__ == '__main__': + print('=' * 50) + print('BaseDT NPZ 数据集制作示例') + print('=' * 50) + + # 检查效果:指定路径制作 NPZ + video_path = r'D:\Download\test (1).mp4' + iris_csv_path = r'D:\XEdu\datasets\baseml\iris\iris.csv' + + print('\n--- 1. 
视频转 NPZ ---') + demo_video_to_npz(video_path, 'dataset_video.npz') + + print('\n--- 2. Iris CSV 转 NPZ ---') + demo_iris_csv_to_npz(iris_csv_path, 'dataset_iris.npz') + + # 以下为其他示例(可选) + # print('\n--- 示例: 从 CSV 生成 NPZ ---') + # demo_make_npz_from_csv() + # print('\n--- 示例: 从视频目录生成 NPZ ---') + # demo_make_npz_from_video() diff --git a/XEdu/hub/BaseDT/__init__.py b/XEdu/hub/BaseDT/__init__.py index e69de29..3012b37 100644 --- a/XEdu/hub/BaseDT/__init__.py +++ b/XEdu/hub/BaseDT/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""BaseDT 数据工具模块。""" + +from .dataset import make_npz_dataset, NPZGenerator + +__all__ = ['make_npz_dataset', 'NPZGenerator'] diff --git a/XEdu/hub/BaseDT/dataset.py b/XEdu/hub/BaseDT/dataset.py index 237dba3..acef9d2 100644 --- a/XEdu/hub/BaseDT/dataset.py +++ b/XEdu/hub/BaseDT/dataset.py @@ -877,6 +877,414 @@ def split_tab_dataset_class(data_path, data_column,label_column=[-1],train_val_r return train_x,train_y, val_x, val_y + +# ============= NPZ 序列数据集制作(视频/CSV,适用于 RNN/BaseNN load_npz_data) ============= + +def _detect_npz_data_type(source): + """根据 source 自动检测为 video 或 csv。""" + if source is None or (isinstance(source, str) and not os.path.exists(source)): + return None + if os.path.isdir(source): + return 'video' + if isinstance(source, str) and source.lower().endswith(('.csv', '.txt')): + return 'csv' + return None + + +def _make_npz_from_csv( + source, + output_path, + sequence_length=30, + data_column=None, + label_column=-1, + delimiter=',', + skiprows=1, + one_hot=True, +): + """从 CSV/表格文件生成 NPZ 序列数据集。按标签分组,每组内按行顺序切成长度为 sequence_length 的序列。""" + try: + import pandas as pd + except ImportError: + pd = None + if pd is None: + data = np.loadtxt(source, dtype=float, delimiter=delimiter, skiprows=skiprows) + n_cols = data.shape[1] + lcol = np.atleast_1d(np.asarray(label_column)) + if data_column is not None: + data_arr = data[:, np.atleast_1d(np.asarray(data_column))] + else: + use = np.setdiff1d(np.arange(n_cols), lcol) + data_arr = 
data[:, use] + labels = data[:, lcol].ravel().astype(int) + else: + df = pd.read_csv(source, delimiter=delimiter, skiprows=range(skiprows) if skiprows else None, header=0 if skiprows else None) + lcol = df.columns[label_column] if isinstance(label_column, int) else label_column + if data_column is not None: + cols = [df.columns[i] if isinstance(i, int) else i for i in np.atleast_1d(np.asarray(data_column))] + data_arr = df[cols].values.astype(np.float64) + else: + cols = [c for c in df.columns if c != lcol] + data_arr = df[cols].values.astype(np.float64) + labels = df[lcol].values.ravel().astype(int) + + unique_labels = np.unique(labels) + num_classes = len(unique_labels) + label_to_idx = {u: i for i, u in enumerate(unique_labels)} + samples_data = [] + samples_label = [] + + for lab in unique_labels: + mask = labels == lab + rows = data_arr[mask] + n_chunks = len(rows) // sequence_length + for i in range(n_chunks): + chunk = rows[i * sequence_length:(i + 1) * sequence_length] + samples_data.append(chunk) + idx = label_to_idx[lab] + if one_hot: + one_hot_label = np.zeros(num_classes, dtype=np.float32) + one_hot_label[idx] = 1.0 + samples_label.append(one_hot_label) + else: + samples_label.append(idx) + + if len(samples_data) == 0: + raise ValueError("CSV 中每个类别至少需要 sequence_length={} 行连续数据,当前不足。".format(sequence_length)) + + data_np = np.array(samples_data, dtype=np.float32) + label_np = np.array(samples_label, dtype=np.float32 if one_hot else np.int64) + d = os.path.dirname(output_path) + if d: + os.makedirs(d, exist_ok=True) + np.savez(output_path, data=data_np, label=label_np) + label_map = {i: int(k) for k, i in label_to_idx.items()} + label_list = [str(unique_labels[i]) for i in range(num_classes)] + return output_path, label_map, label_list + + +def _make_npz_from_video(source, output_path, sequence_length=30, model_path=None, pose_source='xedu'): + """从视频目录(按类别分子目录)生成 NPZ。pose_source='xedu' 用 XEdu det_body+pose_body26;'mediapipe' 需 model_path。""" + if 
class NPZGenerator(object):
    """Build NPZ pose-sequence datasets from a directory of labelled videos.

    ``dataset_path`` must contain one sub-directory per class, each holding the
    video files of that class. Every video is turned into per-frame pose
    vectors, which are cut into windows of ``sequence_length`` frames.

    pose_source:
        ``'xedu'`` (default): XEdu ``det_body`` + ``pose_body26`` workflows,
        26 keypoints * 2 = 52 values per frame, no ``model_path`` needed.
        ``'mediapipe'``: MediaPipe pose landmarker, 33 landmarks * 4 = 132
        values per frame, requires ``model_path`` (a ``.task`` file).

    The saved archive holds ``data`` with shape
    ``(n, sequence_length, feat_dim)`` and one-hot ``label`` with shape
    ``(n, n_classes)``.
    """

    def __init__(self, dataset_path, model_path=None, sequence_length=30, pose_source='xedu'):
        """Store the configuration and initialise the selected pose backend.

        Raises:
            ValueError: Unknown ``pose_source`` or, for ``'mediapipe'``, a
                missing / non-existent ``model_path``.
            ImportError: The selected backend is not installed.
        """
        self.sequence_length = sequence_length
        self.dataset_path = dataset_path
        self.model_path = model_path
        self.pose_source = pose_source.lower()
        self.label_map = {}        # class name -> class index
        self.label_map_list = []   # class index -> class name
        self.video_path = []       # every video file found under dataset_path
        self.results = []          # (per-video features, video path) tuples
        self.file_name = '尚未保存'
        self.feat_dim = 52 if self.pose_source == 'xedu' else 132
        if self.pose_source == 'xedu':
            self._init_xedu()
        elif self.pose_source == 'mediapipe':
            self._init_mediapipe()
        else:
            raise ValueError("pose_source 须为 'xedu' 或 'mediapipe'。")

    # ------------------------------------------------------------------ #
    # Backend initialisation
    # ------------------------------------------------------------------ #
    def _init_xedu(self):
        """Create the XEdu ``det_body`` and ``pose_body26`` workflows."""
        try:
            from .. import Workflow
        except ImportError:
            try:
                from XEdu.hub import Workflow
            except ImportError:
                raise ImportError("pose_source='xedu' 需要 XEdu-python,请安装: pip install xedu-python")
        self._det_model = Workflow(task='det_body')
        self._pose_model = Workflow(task='pose_body26')

    def _init_mediapipe(self):
        """Import MediaPipe and prepare VIDEO-mode pose landmarker options."""
        # Fail early with a clear message instead of an obscure backend error.
        if not self.model_path or not os.path.isfile(self.model_path):
            raise ValueError("pose_source='mediapipe' 时需提供有效的 MediaPipe 模型路径 model_path(如 pose_landmarker_full.task)。")
        try:
            import mediapipe as mp
        except ImportError:
            raise ImportError("视频模式需要安装 mediapipe: pip install mediapipe")
        self._mp = mp
        self.BaseOptions = mp.tasks.BaseOptions
        self.PoseLandmarker = mp.tasks.vision.PoseLandmarker
        self.PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
        self.VisionRunningMode = mp.tasks.vision.RunningMode
        self.options = self.PoseLandmarkerOptions(
            base_options=self.BaseOptions(model_asset_path=self.model_path),
            running_mode=self.VisionRunningMode.VIDEO,
        )

    # ------------------------------------------------------------------ #
    # Per-frame feature extraction
    # ------------------------------------------------------------------ #
    def _get_keypoints_from_frame_xedu(self, frame):
        """Return the min-max normalised keypoints of the first detected body.

        Returns ``False`` when the detector finds no body in ``frame``.
        Normalisation is applied over all coordinates jointly (one shared
        min / max for the whole flattened vector).
        """
        bboxs = self._det_model.inference(data=frame)
        if len(bboxs) == 0:
            return False
        keypoints = self._pose_model.inference(data=frame, bbox=bboxs[0])
        flat = np.ravel(keypoints)
        lo, hi = np.min(flat), np.max(flat)
        if hi - lo > 1e-8:  # skip degenerate (constant) vectors
            flat = (flat - lo) / (hi - lo)
        return list(flat.astype(np.float32))

    def _get_landmark_list_mediapipe(self, pose):
        """Flatten the first detected pose to ``[x, y, z, visibility] * 33``.

        Returns an empty list when the result contains no pose.
        """
        out = []
        if len(pose.pose_landmarks) > 0:
            for p in pose.pose_landmarks[0]:
                out.extend([p.x, p.y, p.z, p.visibility])
        return out

    def _detect_video_mediapipe(self, path):
        """Run the pose landmarker on every frame of ``path``.

        Returns the list of per-frame MediaPipe results. The capture is
        released even if detection raises.
        """
        mp = self._mp
        detections = []
        cap = cv2.VideoCapture(path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some containers report 0
            with self.PoseLandmarker.create_from_options(self.options) as landmarker:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame_id = cap.get(cv2.CAP_PROP_POS_FRAMES)
                    ts_ms = int(frame_id * (1000.0 / fps))
                    # OpenCV decodes to BGR; the image is declared as SRGB.
                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    img = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
                    detections.append(landmarker.detect_for_video(img, ts_ms))
        finally:
            cap.release()
        return detections

    # ------------------------------------------------------------------ #
    # Dataset construction
    # ------------------------------------------------------------------ #
    def _read_video_path(self):
        """Scan ``dataset_path`` and (re)build the video list and label maps.

        Directory entries are visited in sorted order so class indices do not
        depend on the file system's listing order. Existing ``label_map``
        entries are kept; the video list is rebuilt from scratch so calling
        this twice does not duplicate videos.
        """
        self.video_path = []
        for action in sorted(os.listdir(self.dataset_path)):
            action_full = os.path.join(self.dataset_path, action)
            if not os.path.isdir(action_full):
                continue
            if action not in self.label_map:
                self.label_map[action] = len(self.label_map)
            for name in sorted(os.listdir(action_full)):
                path = os.path.join(action_full, name)
                if os.path.isfile(path):
                    self.video_path.append(path)
        self._build_label_map_list()

    def _build_label_map_list(self):
        """Derive the index -> name list from ``label_map``."""
        self.label_map_list = [''] * len(self.label_map)
        for name, idx in self.label_map.items():
            self.label_map_list[idx] = name

    def _get_features(self):
        """Extract features for every listed video into ``self.results``."""
        self.results = []  # start clean so a second run does not duplicate samples
        if self.pose_source == 'xedu':
            self._get_features_xedu()
        else:
            self._get_features_mediapipe()

    def _get_features_xedu(self):
        """Collect complete ``sequence_length`` windows of XEdu keypoints.

        A frame without a detected body discards the partially built window,
        so every stored window consists of consecutive detected frames.
        """
        seq = self.sequence_length
        for path in self.video_path:
            print("-----开始处理 " + path + " -----")
            windows = []
            current = []
            cap = cv2.VideoCapture(path)
            try:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    kp = self._get_keypoints_from_frame_xedu(frame)
                    if kp is False:
                        current = []
                        continue
                    current.append(kp)
                    if len(current) == seq:
                        windows.append(current)
                        current = []
            finally:
                cap.release()
            self.results.append((windows, path))
            print("-----结束处理 " + path + " -----")

    def _get_features_mediapipe(self):
        """Store the raw per-frame MediaPipe results of every video."""
        for path in self.video_path:
            print("-----开始处理 " + path + " -----")
            self.results.append((self._detect_video_mediapipe(path), path))
            print("-----结束处理 " + path + " -----")

    def _split_and_save(self):
        """Turn ``self.results`` into ``self._datas`` / ``self._labels``.

        The class of a video is the name of its parent directory. Windows whose
        shape is not ``(sequence_length, feat_dim)`` are skipped; data and label
        arrays always have the same number of rows.
        """
        seq = self.sequence_length
        feat_dim = self.feat_dim
        blocks = []
        class_ids = []
        for poses, path in self.results:
            cls_name = os.path.basename(os.path.dirname(path))
            class_id = self.label_map[cls_name]
            if self.pose_source == 'xedu':
                for chunk in poses:
                    try:
                        block = np.asarray(chunk, dtype=np.float32)
                    except ValueError:
                        continue  # ragged window: frames with differing lengths
                    if block.shape == (seq, feat_dim):
                        blocks.append(block)
                        class_ids.append(class_id)
            else:
                for t in range(len(poses) // seq):
                    block = np.zeros((seq, feat_dim), dtype=np.float32)
                    for k in range(seq):
                        vec = self._get_landmark_list_mediapipe(poses[t * seq + k])
                        # Frames without a pose stay as an all-zero row.
                        if len(vec) == feat_dim:
                            block[k, :] = vec
                    blocks.append(block)
                    class_ids.append(class_id)

        datas = np.zeros((len(blocks), seq, feat_dim), dtype=np.float32)
        labels = np.zeros((len(blocks), len(self.label_map)), dtype=np.float32)
        for row, (block, class_id) in enumerate(zip(blocks, class_ids)):
            datas[row] = block
            labels[row, class_id] = 1.0
        self._datas = datas
        self._labels = labels

    def generate_dataset(self, file_name):
        """Process every video under ``dataset_path`` and save ``file_name``.

        Prints a hint and returns without doing anything when ``file_name``
        does not end in ``.npz``. Missing parent directories are created.
        """
        if not file_name.endswith('.npz'):
            print('请指定 .npz 输出路径')
            return
        self.file_name = file_name
        self._read_video_path()
        self._get_features()
        self._split_and_save()
        out_dir = os.path.dirname(self.file_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.file_name, 'wb') as f:
            np.savez(f, data=self._datas, label=self._labels)

    # ------------------------------------------------------------------ #
    # Accessors (kept for compatibility with the original generator API)
    # ------------------------------------------------------------------ #
    def set_sequence_length(self, sequence_length):
        self.sequence_length = sequence_length

    def get_video_path(self):
        return self.video_path

    def get_label_map(self):
        return self.label_map

    def get_label_map_list(self):
        return self.label_map_list

    def get_results(self):
        return self.results

    def get_sequence_length(self):
        return self.sequence_length

    def set_label_map_list(self, label_map_list):
        self.label_map_list = list(label_map_list)

    # ------------------------------------------------------------------ #
    # Inference helpers
    # ------------------------------------------------------------------ #
    def generate_for_inference(self, data_path):
        """Build sliding windows (stride 1) from a single video.

        Frames without a detected pose are dropped before windowing.

        Returns:
            float32 array of shape ``(n_windows, sequence_length, feat_dim)``,
            or ``None`` (after printing a message) when fewer than
            ``sequence_length`` usable frames were found.
        """
        seq = self.sequence_length
        feat_dim = self.feat_dim
        if self.pose_source == 'xedu':
            vectors = []
            cap = cv2.VideoCapture(data_path)
            try:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    kp = self._get_keypoints_from_frame_xedu(frame)
                    if kp is not False:
                        vectors.append(kp)
            finally:
                cap.release()
        else:
            vectors = [
                self._get_landmark_list_mediapipe(res)
                for res in self._detect_video_mediapipe(data_path)
                if len(res.pose_landmarks) > 0
            ]

        if len(vectors) < seq:
            print("视频帧数过少,无法生成至少 {} 帧的序列。".format(seq))
            return None

        n_windows = len(vectors) - seq + 1
        content = np.zeros((n_windows, seq, feat_dim), dtype=np.float32)
        for start in range(n_windows):
            for i in range(seq):
                vec = vectors[start + i]
                # A vector of unexpected length leaves its row zero-filled.
                if len(vec) == feat_dim:
                    content[start, i, :] = vec
        return content

    def see_result(self, result):
        """Print and return the winning class of a model output.

        A 2-D ``result`` (one row per window) is averaged over the windows
        first. When ``label_map_list`` is set and the output has at least that
        many entries, returns ``[probabilities, class_name]``; otherwise prints
        and returns the probability vector itself.
        """
        result = np.asarray(result)
        averaged = np.mean(result, axis=0) if result.ndim == 2 else result
        probs = np.asarray(averaged).ravel()
        n_labels = len(self.label_map_list)
        if n_labels and len(probs) >= n_labels:
            idx = int(np.argmax(probs[:n_labels]))
            conf = float(probs[idx])
            name = self.label_map_list[idx]
            print('推理结果为:' + str(name))
            print('置信度为:' + str(conf))
            return [list(probs), name]
        print('推理结果:', probs)
        return probs
demo_video_to_npz(video_path, output_path='dataset_from_video.npz'): - """ - 将单个视频制作为 NPZ 数据集。 - 视频模式需要按类别分目录,单视频时自动创建临时目录结构(单类)。 - """ - video_path = os.path.abspath(video_path) - if not os.path.isfile(video_path): - print('视频不存在:', video_path) - return - tmp_dir = tempfile.mkdtemp(prefix='npz_video_') - class_dir = os.path.join(tmp_dir, 'single') - os.makedirs(class_dir) - dst = os.path.join(class_dir, os.path.basename(video_path)) - shutil.copy(video_path, dst) - print('临时目录:', tmp_dir) - make_npz_dataset( - tmp_dir, - output_path, - data_type='video', - pose_source='xedu', - sequence_length=30, - ) - shutil.rmtree(tmp_dir, ignore_errors=True) - print('视频 -> NPZ 完成,输出:', output_path) - import numpy as np - data = np.load(output_path) - print('data.shape:', data['data'].shape, 'label.shape:', data['label'].shape) - - -def demo_iris_csv_to_npz(csv_path, output_path='dataset_from_iris.npz'): - """将 iris.csv 制作为 NPZ 数据集。""" - csv_path = os.path.abspath(csv_path) - if not os.path.isfile(csv_path): - print('CSV 不存在:', csv_path) - return - make_npz_dataset( - csv_path, - output_path, - data_type='csv', - sequence_length=10, - label_column=-1, - delimiter=',', - skiprows=1, - ) - print('CSV -> NPZ 完成,输出:', output_path) - import numpy as np - data = np.load(output_path) - print('data.shape:', data['data'].shape, 'label.shape:', data['label'].shape) - - -def demo_make_npz_from_csv(): - """ - 示例1:从 CSV 表格数据生成 NPZ 数据集。 - CSV 格式:每行一个时间步,最后一列为类别标签。按类别分组后切成 sequence_length 长度的序列。 - """ - import numpy as np - - # 创建示例 CSV 数据(若 data.csv 不存在) - csv_path = 'data_for_npz_demo.csv' - if not os.path.exists(csv_path): - # 3 类,每类 100 行,5 个特征列 + 1 个标签列 - np.random.seed(42) - rows = [] - header = 'f1,f2,f3,f4,f5,label' - for label in range(3): - for _ in range(100): - feat = np.random.randn(5).astype(float) - rows.append(','.join(map(str, feat)) + ',' + str(label)) - with open(csv_path, 'w') as f: - f.write(header + '\n' + '\n'.join(rows)) - print('已创建示例 CSV:', csv_path) - - output_path 
= 'dataset_from_csv.npz' - make_npz_dataset( - csv_path, - output_path, - data_type='csv', - sequence_length=10, - label_column=-1, - delimiter=',', - skiprows=1, - ) - print('CSV -> NPZ 完成,输出:', output_path) - - # 验证 - data = np.load(output_path) - print('data.shape:', data['data'].shape, 'label.shape:', data['label'].shape) - - -def demo_make_npz_from_video(): - """ - 示例2:从视频目录生成 NPZ 数据集。 - 目录结构:dataset_path/类别名/视频文件 - 默认使用 XEdu det_body+pose_body26(52维/帧),不需额外模型。也可用 pose_source='mediapipe' 需 .task 模型。 - """ - video_dir = './video' # 视频目录:video/waving/*.mp4, video/walking/*.mp4 等 - - if not os.path.isdir(video_dir): - print('视频目录不存在:', video_dir, '请准备按类别分文件夹的视频目录。') - return - - # 默认 pose_source='xedu',无需 model_path - make_npz_dataset( - video_dir, - 'dataset_from_video.npz', - data_type='video', - pose_source='xedu', # 默认,XEdu det_body+pose_body26 - sequence_length=30, - ) - print('视频 -> NPZ 完成,输出: dataset_from_video.npz') - - # 若使用 MediaPipe(132维/帧),需传入 model_path - # make_npz_dataset(video_dir, 'dataset_mediapipe.npz', data_type='video', - # pose_source='mediapipe', model_path='pose_landmarker_full.task') - - -def demo_npz_generator_class(): - """ - 示例3:使用 NPZGenerator 类(兼容原 npz_generator 接口)。 - 默认 pose_source='xedu' 不需 model_path;pose_source='mediapipe' 需 model_path。 - """ - video_dir = './video' - - if not os.path.isdir(video_dir): - print('请准备视频目录。') - return - - # 默认使用 XEdu,不需 model_path - gen = NPZGenerator( - dataset_path=video_dir, - sequence_length=30, - pose_source='xedu', # 默认 - ) - gen.generate_dataset('dataset.npz') - print('标签映射:', gen.get_label_map()) - print('标签列表:', gen.get_label_map_list()) - - # 推理:对单个视频生成推理用数组 - # inf_data = gen.generate_for_inference('test.mp4') - # 设置标签名用于解析推理结果 - gen.set_label_map_list(['waving', 'walking', 'stretching']) - # gen.see_result(model_output) # 解析并打印推理结果 - - -def demo_auto_detect(): - """ - 示例4:自动检测数据类型(目录->视频,.csv->CSV)。 - """ - # CSV 示例 - csv_path = 'data_for_npz_demo.csv' - if os.path.exists(csv_path): - 
make_npz_dataset(csv_path, 'auto_csv.npz', data_type='auto', sequence_length=10) - print('自动检测 CSV 完成') - - # 视频目录示例(默认 xedu 不需 model_path) - if os.path.isdir('./video'): - make_npz_dataset('./video', 'auto_video.npz', data_type='auto') - print('自动检测视频完成') - - -if __name__ == '__main__': - print('=' * 50) - print('BaseDT NPZ 数据集制作示例') - print('=' * 50) - - # 检查效果:指定路径制作 NPZ - video_path = r'D:\Download\test (1).mp4' - iris_csv_path = r'D:\XEdu\datasets\baseml\iris\iris.csv' - - print('\n--- 1. 视频转 NPZ ---') - demo_video_to_npz(video_path, 'dataset_video.npz') - - print('\n--- 2. Iris CSV 转 NPZ ---') - demo_iris_csv_to_npz(iris_csv_path, 'dataset_iris.npz') - - # 以下为其他示例(可选) - # print('\n--- 示例: 从 CSV 生成 NPZ ---') - # demo_make_npz_from_csv() - # print('\n--- 示例: 从视频目录生成 NPZ ---') - # demo_make_npz_from_video() diff --git a/XEdu/examples/npz_demo.py b/XEdu/examples/npz_demo.py new file mode 100644 index 0000000..026b114 --- /dev/null +++ b/XEdu/examples/npz_demo.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +"""NPZ 数据集制作示例:视频或 CSV 转成 NPZ(供 RNN/序列分类用)。""" + +from XEdu.hub.BaseDT import npz + +# 方式一:视频目录 → NPZ(目录结构:类别名/视频文件,如 video/挥手/*.mp4, video/走路/*.mp4) +npz.make_npz_dataset('./video', 'dataset_video.npz', data_type='video', sequence_length=30) + +# 方式二:CSV 文件 → NPZ(最后一列为标签) +npz.make_npz_dataset('data.csv', 'dataset_csv.npz', data_type='csv', sequence_length=10, label_column=-1) + +# 方式三:用类逐视频控制 +# gen = npz.NPZGenerator(dataset_path='./video', sequence_length=30) +# gen.generate_dataset('out.npz') diff --git a/XEdu/hub/BaseDT/__init__.py b/XEdu/hub/BaseDT/__init__.py index 3012b37..1b57e17 100644 --- a/XEdu/hub/BaseDT/__init__.py +++ b/XEdu/hub/BaseDT/__init__.py @@ -2,5 +2,6 @@ """BaseDT 数据工具模块。""" from .dataset import make_npz_dataset, NPZGenerator +from . 
import npz -__all__ = ['make_npz_dataset', 'NPZGenerator'] +__all__ = ['make_npz_dataset', 'NPZGenerator', 'npz'] diff --git a/XEdu/hub/BaseDT/npz.py b/XEdu/hub/BaseDT/npz.py new file mode 100644 index 0000000..42f0f45 --- /dev/null +++ b/XEdu/hub/BaseDT/npz.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""NPZ 数据集制作(视频/CSV)。用法: from XEdu.hub.BaseDT import npz 后使用 npz.make_npz_dataset、npz.NPZGenerator。""" + +from .dataset import make_npz_dataset, NPZGenerator + +__all__ = ['make_npz_dataset', 'NPZGenerator']