diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index c72e2d9..c80b519 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -40,7 +40,9 @@ func wireApp(configConfig *config.Config) (*app.App, func(), error) { sysHandler := handler.NewSysHandler(sysService) cacheJobService := service.NewCacheJobService(fileDao, metaDao, downloaderDao, schedulerDao) cacheJobHandler := handler.NewCacheJobHandler(cacheJobService) - httpRouter := router.NewHttpRouter(echo, fileHandler, metaHandler, sysHandler, cacheJobHandler) + modelscopeService := service.NewModelscopeService() + modelscopeHandler := handler.NewModelscopeHandler(modelscopeService) + httpRouter := router.NewHttpRouter(echo, fileHandler, metaHandler, sysHandler, cacheJobHandler, modelscopeHandler) httpServer := server.NewServer(configConfig, echo, httpRouter) schedulerService := service.NewSchedulerService(schedulerDao) schedulerServer := server.NewSchedulerServer(schedulerService, sysService, localOperationService) diff --git a/config/config.yaml b/config/config.yaml index 0814d64..4595dc2 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -75,3 +75,9 @@ dynamicProxy: timePeriod: 60 #定期检测代理是否可用时间周期,单位秒(S) maxContinuousFails: 5 #连续失败次数超过该值,则认为代理不可用 webhook: https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=73662ac1-1055-48a7-8c89-37964b5f4fdc111 # 企业微信机器人Webhook地址 + +modelscope: + officialBaseURL: https://www.modelscope.cn # ModelScope官方基础地址 + chunkSize: 8388608 # 8MB分块,16*1024*1024的数值结果 + maxRetry: 5 # 超时重试次数 + retryDelay: 3 # 重试间隔,单位秒(S)(原配置为5*time.Second,YAML中简化为数值+注释) \ No newline at end of file diff --git a/go.mod b/go.mod index ebf41d5..55d7ae0 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/bytedance/sonic v1.13.2 github.com/go-playground/validator/v10 v10.27.0 github.com/go-redsync/redsync/v4 v4.13.0 + github.com/gofrs/flock v0.8.1 github.com/google/uuid v1.6.0 github.com/google/wire v0.6.0 github.com/klauspost/compress v1.18.0 diff --git a/go.sum b/go.sum index 1efaa85..d7383ad 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,10 @@ github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redsync/redsync/v4 v4.13.0 h1:49X6GJfnbLGaIpBBREM/zA4uIMDXKAh1NDkvQ1EkZKA= github.com/go-redsync/redsync/v4 v4.13.0/go.mod h1:HMW4Q224GZQz6x1Xc7040Yfgacukdzu7ifTDAKiyErQ= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 4dafbe9..7726302 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -18,4 +18,4 @@ import ( "github.com/google/wire" ) -var HandlerProvider = wire.NewSet(NewFileHandler, NewMetaHandler, NewSysHandler, NewCacheJobHandler) +var HandlerProvider = wire.NewSet(NewFileHandler, NewMetaHandler, NewSysHandler, NewCacheJobHandler, NewModelscopeHandler) diff --git a/internal/handler/modelscope_handler.go b/internal/handler/modelscope_handler.go new file mode 100644 index 0000000..3f3facb --- /dev/null +++ b/internal/handler/modelscope_handler.go @@ -0,0 +1,119 @@ +package handler + +import ( + "fmt" + "strings" + + "dingospeed/internal/service" + "dingospeed/pkg/util" + + "github.com/labstack/echo/v4" +) + +// ModelscopeHandler 模型代理请求处理器 +type ModelscopeHandler struct { + modelscopeService *service.ModelscopeService +} + +// NewModelscopeHandler 创建模型代理处理器实例 +func NewModelscopeHandler(modelscopeService *service.ModelscopeService) *ModelscopeHandler { + return &ModelscopeHandler{ + modelscopeService: modelscopeService, + } +} + +// ModelInfoHandler 处理模型信息查询请求 +func (m *ModelscopeHandler) ModelInfoHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + if len(parts) < 5 { + err := fmt.Errorf("请求路径格式非法") + return util.ResponseError(c, err) + } + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := m.modelscopeService.ForwardModelInfo(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// RevisionsHandler 处理模型版本查询请求 +func (m *ModelscopeHandler) RevisionsHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + if len(parts) < 5 { + err := fmt.Errorf("请求路径格式非法") + return util.ResponseError(c, err) + } + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := m.modelscopeService.ForwardRevisions(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// FileListHandler 处理模型文件列表请求 +func (m *ModelscopeHandler) FileListHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + if len(parts) < 5 { + err := fmt.Errorf("请求路径格式非法") + return util.ResponseError(c, err) + } + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := m.modelscopeService.ForwardFileList(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// FileDownloadHandler 处理模型文件下载请求(支持续传) +func (m *ModelscopeHandler) FileDownloadHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + if len(parts) < 5 { + err := fmt.Errorf("请求路径格式非法") + return util.ResponseError(c, err) + } + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := m.modelscopeService.HandleFileDownload(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// FileTreeHandler 处理数据集文件列表请求 +func (m *ModelscopeHandler) FileTreeHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + if len(parts) < 5 { + err := fmt.Errorf("请求路径格式非法") + return util.ResponseError(c, err) + } + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := m.modelscopeService.ForwardRepoTree(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// DatasetFileTreeHandler 处理数据集文件列表请求 +func (m *ModelscopeHandler) DatasetFileTreeHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + if len(parts) < 4 { + err := fmt.Errorf("请求路径格式非法") + return util.ResponseError(c, err) + } + + datasetId := parts[3] + if err := m.modelscopeService.ForwardRepoTreeByDatasetId(c, datasetId); err != nil { + return util.ResponseError(c, err) + } + return nil +} diff --git a/internal/router/http_router.go b/internal/router/http_router.go index 360fd45..eb2b178 100644 --- a/internal/router/http_router.go +++ b/internal/router/http_router.go @@ -23,21 +23,23 @@ import ( ) type HttpRouter struct { - echo *echo.Echo - fileHandler *handler.FileHandler - metaHandler *handler.MetaHandler - sysHandler *handler.SysHandler - cacheJobHandler *handler.CacheJobHandler + echo *echo.Echo + fileHandler *handler.FileHandler + metaHandler *handler.MetaHandler + sysHandler *handler.SysHandler + cacheJobHandler *handler.CacheJobHandler + modelscopeHandler *handler.ModelscopeHandler } func NewHttpRouter(echo *echo.Echo, fileHandler *handler.FileHandler, metaHandler *handler.MetaHandler, - sysHandler *handler.SysHandler, cacheJobHandler *handler.CacheJobHandler) *HttpRouter { + sysHandler *handler.SysHandler, cacheJobHandler *handler.CacheJobHandler, modelscopeHandler *handler.ModelscopeHandler) *HttpRouter { r := &HttpRouter{ - echo: echo, - fileHandler: fileHandler, - metaHandler: metaHandler, - sysHandler: sysHandler, - cacheJobHandler: cacheJobHandler, + echo: echo, + fileHandler: fileHandler, + metaHandler: metaHandler, + sysHandler: sysHandler, + cacheJobHandler: cacheJobHandler, + modelscopeHandler: modelscopeHandler, } r.initRouter() return r @@ -54,6 +56,7 @@ func (r *HttpRouter) initRouter() { r.routerForCacheJob() r.routerForSpeed() + r.routerForModelscope() } func (r *HttpRouter) routerForSpeed() { // alayanew @@ -91,3 +94,12 @@ func (r *HttpRouter) routerForCacheJob() { // alayanew r.echo.POST("/api/cacheJob/resume", r.cacheJobHandler.ResumeCacheJobHandler) r.echo.POST("/api/cacheJob/realtime", r.cacheJobHandler.RealtimeCacheJobHandler) } + +func (r *HttpRouter) routerForModelscope() { // modelscope + r.echo.GET("/api/v1/:repoType/:org/:repo", r.modelscopeHandler.ModelInfoHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/revisions", r.modelscopeHandler.RevisionsHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/repo/files", r.modelscopeHandler.FileListHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/repo", r.modelscopeHandler.FileDownloadHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/repo/tree", r.modelscopeHandler.FileTreeHandler) + r.echo.GET("/api/v1/datasets/:datasetId/repo/tree", r.modelscopeHandler.DatasetFileTreeHandler) +} diff --git a/internal/server/http.go b/internal/server/http.go index 4b8f290..9f72025 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -80,6 +80,7 @@ func NewEngine() *echo.Echo { r := echo.New() middleware.InitMiddlewareConfig() r.Use(middleware.QueueLimitMiddleware) + r.Use(middleware.CORSMiddleware()) t := &Template{ templates: template.Must(template.ParseFS(templatesFS, "templates/*.html")), diff --git a/internal/service/modelscope_service.go b/internal/service/modelscope_service.go new file mode 100644 index 0000000..e81eb39 --- /dev/null +++ b/internal/service/modelscope_service.go @@ -0,0 +1,783 @@ +package service + +import ( + "context" + "dingospeed/pkg/config" + "dingospeed/pkg/util" + "fmt" + "github.com/gofrs/flock" + "github.com/labstack/echo/v4" + "go.uber.org/zap" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +type ModelscopeService struct { + downloadingMap sync.Map + mapMu sync.Mutex +} + +func NewModelscopeService() *ModelscopeService { + return &ModelscopeService{ + downloadingMap: sync.Map{}, + mapMu: sync.Mutex{}, + } +} + +// CalculateBlockInfo 计算块信息:块编号、块起始偏移、块结束偏移 +func (m *ModelscopeService) CalculateBlockInfo(offset int64, chunkSize int64) (blockNum int64, blockStart int64, blockEnd int64) { + blockNum = offset / chunkSize + blockStart = blockNum * chunkSize + blockEnd = blockStart + chunkSize - 1 + return +} + +// IsBlockComplete 检查指定块是否完整 +func (m *ModelscopeService) IsBlockComplete(cachePath string, blockNum int64, chunkSize int64) (bool, error) { + fileInfo, err := os.Stat(cachePath) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("stat cache file failed: %w", err) + } + + blockStart := blockNum * chunkSize + blockEnd := blockStart + chunkSize - 1 + fileSize := fileInfo.Size() + + if fileSize <= blockStart { + return false, nil // 块为空 + } + if fileSize > blockEnd { + return true, nil // 块完整 + } + return false, nil +} + +// getCacheLock 获取缓存文件对应的锁文件 +func (m *ModelscopeService) getCacheLock(cachePath string) *flock.Flock { + lockPath := cachePath + ".lock" + return flock.New(lockPath) +} + +func (m *ModelscopeService) openCacheFile(cachePath string, c echo.Context) (*os.File, *flock.Flock, error) { + lockPath := cachePath + ".lock" + fileLock := flock.New(lockPath) + + ctx, cancel := context.WithTimeout(c.Request().Context(), 3*time.Second) + defer cancel() + locked, lockErr := fileLock.TryLockContext(ctx, time.Duration(0)) + if lockErr != nil { + zap.S().Errorf("获取缓存文件锁失败: %s, err: %v", lockPath, lockErr) + return nil, nil, c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "get cache file lock failed", + "msg": lockErr.Error(), + }) + } + if !locked { + zap.S().Warnf("缓存文件锁超时,无法独占写入: %s", cachePath) + return nil, nil, c.JSON(http.StatusTooManyRequests, map[string]string{ + "code": "429", + "error": "cache file is being written by another request", + "msg": "please try again later", + }) + } + + var cacheFile *os.File + var err error + if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) { + cacheFile, err = os.OpenFile(cachePath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0664) + } else { + cacheFile, err = os.OpenFile(cachePath, os.O_RDWR, 0664) + } + + if err != nil { + zap.S().Errorf("打开缓存文件失败: %s, err: %v", cachePath, err) + _ = fileLock.Unlock() + return nil, nil, c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "open cache file failed", + "msg": err.Error(), + }) + } + + return cacheFile, fileLock, nil +} + +func (m *ModelscopeService) copyResponseToCacheAndClient(c echo.Context, resp *http.Response, cacheFile *os.File, cachePath string, totalFileSize int64) error { + chunkSize := config.SysConfig.Modelscope.ChunkSize + if chunkSize <= 0 { + chunkSize = 1024 * 1024 * 8 + } + + currentOffset := int64(0) + contentRange := resp.Header.Get("Content-Range") + if contentRange != "" { + parts := strings.Split(contentRange, " ") + if len(parts) >= 2 { + rangeParts := strings.Split(parts[1], "-") + if len(rangeParts) >= 1 { + parsedOffset, err := strconv.ParseInt(rangeParts[0], 10, 64) + if err == nil { + currentOffset = parsedOffset + } + } + } + } + + buf := make([]byte, chunkSize) + written := int64(0) + + for { + if c.Request().Context().Err() != nil { + zap.S().Warnf("客户端断开连接,停止续传: %s", cachePath) + if cacheFile != nil { + _ = cacheFile.Sync() + closeErr := cacheFile.Close() + if closeErr != nil { + zap.S().Errorf("关闭缓存文件失败: %s, err: %v", cachePath, closeErr) + } + } + return nil + } + + currentWriteOffset := currentOffset + written + blockNum, _, blockEnd := m.CalculateBlockInfo(currentWriteOffset, chunkSize) + + isComplete, err := m.IsBlockComplete(cachePath, blockNum, chunkSize) + if err != nil { + zap.S().Errorf("检查块%d完整性失败: %v", blockNum, err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "check block complete failed", + "msg": err.Error(), + }) + } + + if isComplete { + skipBytes := blockEnd - currentWriteOffset + 1 + written += skipBytes + continue + } + + readSize := blockEnd - currentWriteOffset + 1 + if readSize > int64(len(buf)) { + readSize = int64(len(buf)) + } + + n, err := resp.Body.Read(buf[:readSize]) + if n > 0 { + if _, seekErr := cacheFile.Seek(currentWriteOffset, io.SeekStart); seekErr != nil { + zap.S().Errorf("定位缓存文件到偏移%d失败: %v", currentWriteOffset, seekErr) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "seek cache file failed", + "msg": seekErr.Error(), + }) + } + + if _, writeErr := cacheFile.Write(buf[:n]); writeErr != nil { + zap.S().Errorf("写入缓存块%d失败: %s, err: %v", blockNum, cachePath, writeErr) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "write cache block failed", + "msg": writeErr.Error(), + }) + } + if syncErr := cacheFile.Sync(); syncErr != nil { + zap.S().Warnf("缓存块%d刷盘失败: %s, err: %v", blockNum, cachePath, syncErr) + } + + if _, writeErr := c.Response().Write(buf[:n]); writeErr != nil { + if strings.Contains(writeErr.Error(), "http2: stream closed") || + strings.Contains(writeErr.Error(), "broken pipe") || + strings.Contains(writeErr.Error(), "connection reset by peer") { + zap.S().Warnf("客户端断开连接,停止返回续传数据: %s, err: %v", cachePath, writeErr) + return nil + } + zap.S().Errorf("返回续传数据失败: %s, err: %v", cachePath, writeErr) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "write response failed", + "msg": writeErr.Error(), + }) + } + + written += int64(n) + if f, ok := c.Response().Writer.(http.Flusher); ok { + f.Flush() + } + + if written%(100*1024*1024) == 0 { + zap.S().Infof("续传进度: %dMB, 文件: %s", written/(1024*1024), cachePath) + } + } + + if err == io.EOF { + if syncErr := cacheFile.Sync(); syncErr != nil { + zap.S().Warnf("缓存文件最终刷盘失败: %s, err: %v", cachePath, syncErr) + } + zap.S().Infof("续传完成: %s, 共下载%d字节,完整文件%d字节", cachePath, written, totalFileSize) + break + } + if err != nil { + zap.S().Errorf("续传中断: %s, err: %v", cachePath, err) + return c.JSON(http.StatusBadGateway, map[string]string{ + "code": "502", + "error": "download interrupted", + "msg": err.Error(), + }) + } + } + + return nil +} + +func (m *ModelscopeService) ForwardModelInfo(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + zap.S().Infof("转发%s信息请求到官方: %s", apiPrefix, officialURL) + return m.forwardRequest(c, officialURL) +} + +func (m *ModelscopeService) ForwardRevisions(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/revisions?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + zap.S().Infof("转发%s版本请求到官方: %s", apiPrefix, officialURL) + return m.forwardRequest(c, officialURL) +} + +func (m *ModelscopeService) ForwardFileList(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/repo/files?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + zap.S().Infof("转发%s文件列表请求到官方: %s", apiPrefix, officialURL) + return m.forwardRequest(c, officialURL) +} + +func (m *ModelscopeService) ForwardRepoTree(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/repo/tree?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + zap.S().Infof("转发%s文件树请求到官方: %s", apiPrefix, officialURL) + return m.forwardRequest(c, officialURL) +} + +func (m *ModelscopeService) ForwardRepoTreeByDatasetId(c echo.Context, datasetId string) error { + officialURL := fmt.Sprintf("%s/api/v1/datasets/%s/repo/tree?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + url.PathEscape(datasetId), + c.Request().URL.RawQuery) + zap.S().Infof("转发文件树请求到官方: %s", officialURL) + return m.forwardRequest(c, officialURL) +} + +// forwardRequest 通用请求转发逻辑 +func (m *ModelscopeService) forwardRequest(c echo.Context, officialURL string) error { + req, err := http.NewRequest(http.MethodGet, officialURL, nil) + if err != nil { + zap.S().Errorf("构建请求失败: %v", err) + return err + } + + util.AddCLIHeaders(req.Header, c.Request().Header.Get("User-Agent")) + + for k, v := range c.Request().Header { + req.Header[k] = v + } + + resp, err := util.DoRequestWithRetry(req) + if err != nil { + zap.S().Errorf("转发请求失败: %v", err) + return err + } + defer resp.Body.Close() + + for k, v := range resp.Header { + c.Response().Header()[k] = v + } + c.Response().WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Response(), resp.Body) + if err != nil { + zap.S().Errorf("复制响应体失败: %v", err) + return err + } + return nil +} + +// HandleFileDownload 处理ModelScope文件下载请求 +func (m *ModelscopeService) HandleFileDownload(c echo.Context, owner, repo, repoType string) error { + repoId := fmt.Sprintf("%s/%s", owner, repo) + revision := c.Request().URL.Query().Get("Revision") + filePath := c.Request().URL.Query().Get("FilePath") + + if revision == "" { + revision = "master" + } + if filePath == "" { + zap.S().Error("请求参数缺失: FilePath为空") + return c.JSON(http.StatusBadRequest, map[string]string{ + "code": "400", + "error": "missing FilePath parameter", + }) + } + + if err := util.EnsureDir(filepath.Join(config.SysConfig.GetModelCacheRoot(), "dummy")); err != nil { + zap.S().Errorf("初始化模型缓存根目录失败: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "init model cache root dir failed", + "msg": err.Error(), + }) + } + + cachePath, cacheExists := util.GetCachePath(repoType, repoId, revision, filePath) + if cachePath == "" { + zap.S().Errorf("生成缓存路径失败: 无效的repoId格式 %s", repoId) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "get cache path failed", + "msg": "invalid repoId format, require org/repo", + }) + } + zap.S().Infof("生成缓存路径: %s (缓存文件是否存在: %t)", cachePath, cacheExists) + + if err := util.EnsureDir(filepath.Dir(cachePath)); err != nil { + zap.S().Errorf("初始化缓存文件上级目录失败: %s, err: %v", filepath.Dir(cachePath), err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "init cache file parent dir failed", + "msg": err.Error(), + }) + } + + cachedSize := util.GetFileSize(cachePath) + zap.S().Infof("缓存文件状态: %s (已下载: %d字节)", cachePath, cachedSize) + if cacheExists && cachedSize == 0 { + zap.S().Warnf("缓存文件存在但大小为0,视为无效缓存: %s", cachePath) + if err := os.Remove(cachePath); err != nil { + zap.S().Errorf("删除空缓存文件失败: %s, err: %v", cachePath, err) + } + cachedSize = 0 + cacheExists = false + } + + clientStart, clientEnd, err := util.ParseRangeHeader(c.Request()) + if err != nil { + zap.S().Errorf("解析Range失败: %v", err) + return c.JSON(http.StatusBadRequest, map[string]string{ + "code": "400", + "error": "parse Range header failed", + "msg": err.Error(), + }) + } + + actualStart := clientStart + if cachedSize > 0 && actualStart < cachedSize { + actualStart = cachedSize + } + zap.S().Infof("续传起始位置: 客户端请求=%d, 缓存末尾=%d, 实际起始=%d", clientStart, cachedSize, actualStart) + + headerWritten := false + c.Response().Header().Set("Transfer-Encoding", "chunked") + c.Response().Header().Set("Content-Type", "application/octet-stream") + c.Response().Header().Set("Access-Control-Expose-Headers", "Content-Range, Content-Type") + + var cacheWritten int64 = 0 + var needRemoteDownload bool + if cachedSize > 0 && clientStart < cachedSize { + cacheWritten, headerWritten, needRemoteDownload, err = m.writeCacheData(c, cachePath, clientStart, clientEnd, cachedSize, headerWritten) + if err != nil { + zap.S().Errorf("读取缓存数据失败: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "write cache data failed", + "msg": err.Error(), + }) + } + + if clientEnd != -1 && (clientStart+cacheWritten-1) >= clientEnd { + zap.S().Infof("缓存数据已满足客户端Range请求,无需续传") + return nil + } + + if needRemoteDownload { + m.mapMu.Lock() + defer m.mapMu.Unlock() + + isDownloading := false + loaded := m.downloadingMap.CompareAndSwap(cachePath, false, true) + if !loaded { + val, exists := m.downloadingMap.Load(cachePath) + if exists && val.(bool) { + isDownloading = true + } else { + m.downloadingMap.Store(cachePath, true) + } + } + + if !isDownloading { + zap.S().Infof("触发远程下载: 连续10次空轮询,当前偏移=%d, 文件=%s (标记为已发起远程下载)", clientStart+cacheWritten, cachePath) + actualStart = clientStart + cacheWritten + } else { + zap.S().Infof("文件%s已发起远程下载,不重复触发,继续轮询缓存", cachePath) + cacheWritten, headerWritten, needRemoteDownload, err = m.writeCacheData(c, cachePath, clientStart+cacheWritten, clientEnd, cachedSize, headerWritten) + if err != nil { + zap.S().Errorf("再次轮询缓存数据失败: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "retry write cache data failed", + "msg": err.Error(), + }) + } + actualStart = clientStart + cacheWritten + if needRemoteDownload { + zap.S().Warnf("文件%s已标记为下载中,但仍触发远程下载请求,终止流程", cachePath) + return nil + } + } + } + } + + if err := m.downloadAndWriteRemaining(c, owner, repo, actualStart, clientEnd, cachePath, headerWritten, repoType); err != nil { + zap.S().Errorf("远程下载并写入剩余数据失败: %v", err) + m.downloadingMap.Delete(cachePath) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "download and write remaining data failed", + "msg": err.Error(), + }) + } + + // 下载完成后清除标识 + m.downloadingMap.Delete(cachePath) + zap.S().Infof("文件%s远程下载完成,清除下载标识", cachePath) + + return nil +} + +func (m *ModelscopeService) writeCacheData(c echo.Context, cachePath string, clientStart, clientEnd, cachedSize int64, headerWritten bool) (int64, bool, bool, error) { + pollInterval := 1 * time.Second // 每秒轮询一次缓存 + maxPollCount := 30 // 最大连续空轮询次数 + + totalWritten := int64(0) + currentCacheOffset := clientStart // 当前读取到的缓存偏移量 + emptyPollCount := 0 // 连续空轮询计数器 + + for { + // 客户端断开连接则直接返回 + if c.Request().Context().Err() != nil { + zap.S().Warnf("客户端断开连接,停止返回缓存数据: %s", cachePath) + return totalWritten, headerWritten, false, nil + } + + cacheFile, err := os.Open(cachePath) + if err != nil { + zap.S().Errorf("打开缓存文件失败: %s, err: %v", cachePath, err) + return totalWritten, headerWritten, false, fmt.Errorf("open cache file failed: %w", err) + } + + fileInfo, err := cacheFile.Stat() + if err != nil { + _ = cacheFile.Close() + zap.S().Errorf("获取缓存文件信息失败: %s, err: %v", cachePath, err) + return totalWritten, headerWritten, false, fmt.Errorf("stat cache file failed: %w", err) + } + latestCacheSize := fileInfo.Size() + + cacheEnd := latestCacheSize - 1 + if clientEnd != -1 && clientEnd < cacheEnd { + cacheEnd = clientEnd + } + + if currentCacheOffset > cacheEnd { + emptyPollCount++ + _ = cacheFile.Close() + + isDownloading := false + if downloadingVal, exists := m.downloadingMap.Load(cachePath); exists { + isDownloading = downloadingVal.(bool) + } + + zap.S().Debugf("缓存暂未更新,第%d次空轮询(最大%d次): 当前偏移=%d, 缓存末尾=%d, 文件=%s, 有客户端下载中: %t", + emptyPollCount, maxPollCount, currentCacheOffset, cacheEnd, cachePath, isDownloading) + + if emptyPollCount >= maxPollCount && !isDownloading { + zap.S().Warnf("连续%d次空轮询且无客户端下载中,触发当前客户端发起远程下载: %s", maxPollCount, cachePath) + return totalWritten, headerWritten, true, nil + } + + zap.S().Debugf("继续轮询缓存: 文件=%s, 原因: %s", cachePath, + func() string { + if isDownloading { + return "有客户端正在远程下载" + } + return fmt.Sprintf("空轮询次数未到上限(当前%d次,最大%d次)", emptyPollCount, maxPollCount) + }()) + time.Sleep(pollInterval) + continue + } + + emptyPollCount = 0 + if !headerWritten { + contentRange := fmt.Sprintf("bytes %d-%d/%d", clientStart, cacheEnd, latestCacheSize) + c.Response().Header().Set("Content-Range", contentRange) + c.Response().WriteHeader(http.StatusPartialContent) + headerWritten = true + zap.S().Infof("设置缓存响应头: Content-Range=%s", contentRange) + } + + if _, err := cacheFile.Seek(currentCacheOffset, io.SeekStart); err != nil { + _ = cacheFile.Close() + zap.S().Errorf("定位缓存文件失败: %s, err: %v", cachePath, err) + return totalWritten, headerWritten, false, fmt.Errorf("seek cache file failed: %w", err) + } + + buf := make([]byte, config.SysConfig.Modelscope.ChunkSize) + readableSize := cacheEnd - currentCacheOffset + 1 + writtenInRound := int64(0) + + for writtenInRound < readableSize { + if c.Request().Context().Err() != nil { + _ = cacheFile.Close() + zap.S().Warnf("客户端断开连接,停止返回缓存数据: %s", cachePath) + return totalWritten, headerWritten, false, nil + } + + readSize := readableSize - writtenInRound + if readSize > int64(len(buf)) { + readSize = int64(len(buf)) + } + + n, err := cacheFile.Read(buf[:readSize]) + if n > 0 { + if _, writeErr := c.Response().Write(buf[:n]); writeErr != nil { + _ = cacheFile.Close() + zap.S().Errorf("返回缓存数据失败: %s, err: %v", cachePath, writeErr) + return totalWritten, headerWritten, false, fmt.Errorf("write cache data to response failed: %w", writeErr) + } + + writtenInRound += int64(n) + totalWritten += int64(n) + currentCacheOffset += int64(n) + + if f, ok := c.Response().Writer.(http.Flusher); ok { + f.Flush() + } + } + + if err == io.EOF { + break + } + if err != nil { + _ = cacheFile.Close() + zap.S().Errorf("读取缓存数据失败: %s, err: %v", cachePath, err) + return totalWritten, headerWritten, false, fmt.Errorf("read cache file failed: %w", err) + } + } + _ = cacheFile.Close() + zap.S().Debugf("本轮读取缓存完成: 文件=%s, 读取字节=%d, 累计读取=%d, 当前偏移=%d", + cachePath, writtenInRound, totalWritten, currentCacheOffset) + + if clientEnd == -1 || currentCacheOffset > clientEnd { + zap.S().Infof("缓存数据已满足客户端Range请求: 文件=%s, 累计读取=%d字节", cachePath, totalWritten) + return totalWritten, headerWritten, false, nil + } + + time.Sleep(pollInterval) + } +} + +// downloadAndWriteRemaining 下载剩余部分并写入响应+缓存 +func (m *ModelscopeService) downloadAndWriteRemaining(c echo.Context, owner, repo string, actualStart, clientEnd int64, cachePath string, headerWritten bool, repoType string) error { + req, err := m.buildDownloadRequest(c, owner, repo, actualStart, clientEnd, repoType) + if err != nil { + return err + } + + resp, err := m.sendDownloadRequest(req, c, owner, repo, repoType) + if err != nil { + return err + } + defer resp.Body.Close() + + totalFileSize := m.parseTotalFileSize(resp) + + cacheFile, fileLock, err := m.openCacheFile(cachePath, c) + if err != nil { + return err + } + defer func() { + if cacheFile != nil { + if closeErr := cacheFile.Close(); closeErr != nil { + zap.S().Errorf("关闭缓存文件失败: %s, err: %v", cachePath, closeErr) + } + } + if fileLock != nil { + if unlockErr := fileLock.Unlock(); unlockErr != nil { + zap.S().Errorf("解锁缓存文件锁失败: %s, err: %v", cachePath+".lock", unlockErr) + } + } + }() + + headerWritten = m.writeResponseHeader(c, resp, headerWritten) + + err = m.copyResponseToCacheAndClient(c, resp, cacheFile, cachePath, totalFileSize) + if err != nil { + return err + } + + return nil +} + +// buildDownloadRequest 构建下载剩余部分的HTTP请求 +func (m *ModelscopeService) buildDownloadRequest(c echo.Context, owner, repo string, actualStart, clientEnd int64, repoType string) (*http.Request, error) { + apiPrefix := util.GetAPIPathPrefix(repoType) + query := c.Request().URL.RawQuery + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/repo?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + query, + ) + zap.S().Infof("请求ModelScope官方地址: %s", officialURL) + + req, err := http.NewRequest(http.MethodGet, officialURL, nil) + if err != nil { + zap.S().Errorf("构建请求失败: %v", err) + return nil, c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "build request failed", + "msg": err.Error(), + }) + } + + skipHeaders := map[string]bool{ + "Range": true, + "User-Agent": true, + "Host": true, + } + for k, v := range c.Request().Header { + key := strings.ToLower(k) + if !skipHeaders[key] { + req.Header[k] = v + } + } + + util.AddCLIHeaders(req.Header, c.Request().Header.Get("User-Agent")) + + rangeHeader := fmt.Sprintf("bytes=%d-", actualStart) + if clientEnd != -1 { + rangeHeader = fmt.Sprintf("bytes=%d-%d", actualStart, clientEnd) + } + req.Header.Set("Range", rangeHeader) + zap.S().Infof("向官方请求剩余部分: %s", rangeHeader) + + return req, nil +} + +// sendDownloadRequest 发送下载请求并校验响应状态码 +func (m *ModelscopeService) sendDownloadRequest(req *http.Request, c echo.Context, owner, repo, repoType string) (*http.Response, error) { + resp, err := util.DoRequestWithRetry(req) + if err != nil { + zap.S().Errorf("下载剩余部分失败: %v", err) + return nil, c.JSON(http.StatusBadGateway, map[string]string{ + "code": "502", + "error": "download remaining failed", + "msg": err.Error(), + }) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + officialURL := req.URL.String() + zap.S().Errorf("ModelScope返回错误状态码: %d, URL: %s", resp.StatusCode, officialURL) + errorMsg := fmt.Sprintf("modelscope server return status code: %d", resp.StatusCode) + + var respJSON error + switch resp.StatusCode { + case http.StatusNotFound: + respJSON = c.JSON(http.StatusNotFound, map[string]string{ + "code": "404", + "error": "resource not found", + "msg": "model or file does not exist on ModelScope", + }) + case http.StatusForbidden: + respJSON = c.JSON(http.StatusForbidden, map[string]string{ + "code": "403", + "error": "forbidden", + "msg": "no permission to access the resource", + }) + default: + respJSON = c.JSON(http.StatusBadGateway, map[string]string{ + "code": "502", + "error": "modelscope server error", + "msg": errorMsg, + }) + } + return nil, respJSON + } + + return resp, nil +} + +// parseTotalFileSize 从响应头解析Content-Range获取总文件大小 +func (m *ModelscopeService) parseTotalFileSize(resp *http.Response) int64 { + totalFileSize := int64(-1) + contentRange := resp.Header.Get("Content-Range") + if contentRange != "" { + parts := strings.Split(contentRange, "/") + if len(parts) == 2 { + parsedSize, err := strconv.ParseInt(parts[1], 10, 64) + if err == nil { + totalFileSize = parsedSize + } else { + zap.S().Warnf("解析Content-Range失败: %s, err: %v", contentRange, err) + } + } + } + return totalFileSize +} + +// writeResponseHeader 写入续传响应头 +func (m *ModelscopeService) writeResponseHeader(c echo.Context, resp *http.Response, headerWritten bool) bool { + if !headerWritten { + if resp.StatusCode == http.StatusPartialContent { + c.Response().Header().Set("Content-Range", resp.Header.Get("Content-Range")) + c.Response().WriteHeader(http.StatusPartialContent) + } else { + c.Response().WriteHeader(http.StatusOK) + } + headerWritten = true + zap.S().Infof("设置续传响应头,状态码: %d", resp.StatusCode) + } + return headerWritten +} diff --git a/internal/service/service.go b/internal/service/service.go index d26f4bd..bbb41a7 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -16,4 +16,4 @@ package service import "github.com/google/wire" -var ServiceProvider = wire.NewSet(NewFileService, NewMetaService, NewSysService, NewSchedulerService, NewCacheJobService, NewLocalOperationService) +var ServiceProvider = wire.NewSet(NewFileService, NewMetaService, NewSysService, NewSchedulerService, NewCacheJobService, NewLocalOperationService, NewModelscopeService) diff --git a/pkg/config/config.go b/pkg/config/config.go index 8fdb763..0a9f137 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "sync" "time" @@ -46,6 +47,7 @@ type Config struct { DynamicProxy DynamicProxy `json:"dynamicProxy" yaml:"dynamicProxy"` Scheduler Scheduler `json:"scheduler" yaml:"scheduler"` mu sync.RWMutex + Modelscope Modelscope `yaml:"modelscope"` } type ServerConfig struct { @@ -154,6 +156,13 @@ type DynamicProxy struct { Webhook string `json:"webhook " yaml:"webhook"` } +type Modelscope struct { + OfficialBaseURL string `yaml:"officialBaseURL"` + ChunkSize int64 `yaml:"chunkSize"` + MaxRetry int `yaml:"maxRetry"` + RetryDelay int `yaml:"retryDelay"` +} + func (c *Config) GetHFURLBase() string { return fmt.Sprintf("%s://%s", c.GetHfScheme(), c.GetHfNetLoc()) } @@ -323,6 +332,14 @@ func (c *Config) GetOriginSchedulerModel() string { return c.Scheduler.OriginMode } +func (c *Config) GetModelCacheRoot() string { + return filepath.Join(c.Server.Repos, consts.ModelCacheRoot) +} + +func (c *Config) GetDatasetCacheRoot() string { + return filepath.Join(c.Server.Repos, consts.DatasetCacheRoot) +} + func (c *Config) SetDefaults() { if c.Server.Port == 0 { c.Server.Port = 8090 diff --git a/pkg/consts/const.go b/pkg/consts/const.go index 38dd991..0abafc2 100644 --- a/pkg/consts/const.go +++ b/pkg/consts/const.go @@ -104,3 +104,8 @@ const ( const ( TaskMoreErrMsg = "当前缓存任务较多导致启动失败,请稍后再启动。" ) + +const ( + ModelCacheRoot = "modelscope/models" + DatasetCacheRoot = "modelscope/datasets" +) diff --git a/pkg/middleware/queue_limit.go b/pkg/middleware/queue_limit.go index d6a001a..2aa7894 100644 --- a/pkg/middleware/queue_limit.go +++ b/pkg/middleware/queue_limit.go @@ -2,6 +2,7 @@ package middleware import ( "net" + "net/http" "strings" "dingospeed/pkg/config" @@ -75,3 +76,23 @@ func nextRequest(c echo.Context, next echo.HandlerFunc) error { return util.ErrorTooManyRequest(c) } } + +// CORSMiddleware 跨域中间件(适配Echo框架) +func CORSMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // 设置跨域头 + c.Response().Header().Set("Access-Control-Allow-Origin", "*") + c.Response().Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS, HEAD") + c.Response().Header().Set("Access-Control-Allow-Headers", "*") + c.Response().Header().Set("Access-Control-Expose-Headers", "*") + + // 处理OPTIONS预检请求 + if c.Request().Method == http.MethodOptions { + return c.NoContent(http.StatusOK) + } + + return next(c) + } + } +} diff --git a/pkg/util/modelscope_util.go b/pkg/util/modelscope_util.go new file mode 100644 index 0000000..7c96fbd --- /dev/null +++ b/pkg/util/modelscope_util.go @@ -0,0 +1,211 @@ +package util + +import ( + "crypto/tls" + "fmt" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "dingospeed/pkg/config" + + "go.uber.org/zap" +) + +// 提取ModelScope版本和Python版本的正则表达式 +var ( + msVersionRegex = regexp.MustCompile(`modelscope/(\d+\.\d+\.\d+)`) + pyVersionRegex = regexp.MustCompile(`python/(\d+\.\d+\.\d+)`) +) + +// ParseClientEnv 增强:从客户端UA提取版本+返回真实系统架构 +func ParseClientEnv(clientUA string) (msVersion, system, arch, pythonVer string) { + if msMatch := msVersionRegex.FindStringSubmatch(clientUA); len(msMatch) > 1 { + msVersion = msMatch[1] + } else { + msVersion = "1.33.0" + } + + if pyMatch := pyVersionRegex.FindStringSubmatch(clientUA); len(pyMatch) > 1 { + pythonVer = pyMatch[1] + } else { + pythonVer = "3.13.2" + } + + system = runtime.GOOS + switch system { + case "darwin": + system = "macOS" + case "windows": + system = "Windows" + case "linux": + system = "Linux" + } + + arch = runtime.GOARCH + switch arch { + case "amd64": + arch = "x86_64" + case "arm64": + arch = "aarch64" + } + + return +} + +func AddCLIHeaders(header http.Header, clientUA string) { + msVersion, system, arch, pythonVer := ParseClientEnv(clientUA) + + userAgent := fmt.Sprintf("modelscope/%s (%s; %s) Python/%s", msVersion, system, arch, pythonVer) + header.Set("User-Agent", userAgent) + zap.S().Infof("构建兼容的 User-Agent: %s (客户端原始 UA: %s)", userAgent, clientUA) + + header.Set("Accept-Encoding", "identity") +} + +func EnsureDir(path string) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + zap.S().Errorf("创建目录失败: %s, 错误: %v", filepath.Dir(path), err) + return err + } + return nil +} + +func GetCachePath(repoType, repoId, revision, filePath string) (string, bool) { + parts := strings.Split(repoId, "/") + if len(parts) != 2 { + zap.S().Errorf("无效的repoId格式: %s,需为 org/repo 格式", repoId) + return "", false + } + + var cacheRoot string + switch repoType { + case "datasets": + cacheRoot = config.SysConfig.GetDatasetCacheRoot() + case "models", "": + cacheRoot = config.SysConfig.GetModelCacheRoot() + default: + zap.S().Warnf("未知的repoType: %s,默认使用models缓存目录", repoType) + cacheRoot = config.SysConfig.GetModelCacheRoot() + } + + targetCachePath := filepath.Join(cacheRoot, parts[0], parts[1], revision, filepath.Clean(filePath)) + fileInfo, err := os.Stat(targetCachePath) + if err == nil { + zap.S().Debugf("缓存文件存在: %s, 大小: %d字节", targetCachePath, fileInfo.Size()) + return targetCachePath, true + } + + _ = EnsureDir(targetCachePath) + return targetCachePath, false +} + +var ( + httpClientOnce sync.Once + globalHTTPClient *http.Client +) + +// CreateHTTPClient 单例创建HTTP客户端 +func CreateHTTPClient() *http.Client { + httpClientOnce.Do(func() { + transport := &http.Transport{ + MaxConnsPerHost: 0, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 50, + IdleConnTimeout: 5 * time.Minute, + DisableCompression: true, + DisableKeepAlives: false, + TLSHandshakeTimeout: 2 * time.Minute, + ResponseHeaderTimeout: 5 * time.Minute, + ExpectContinueTimeout: 1 * time.Minute, + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + InsecureSkipVerify: true, // 跳过证书校验,适配内网/代理场景,保留 + Renegotiation: tls.RenegotiateFreelyAsClient, + }, + } + + globalHTTPClient = &http.Client{ + Timeout: 30 * time.Minute, // 30分钟超时,完全适配7-8GB大文件下载,保留 + Transport: transport, + } + }) + + return globalHTTPClient +} + +// DoRequestWithRetry 带重试的HTTP请求 +func DoRequestWithRetry(req *http.Request) (*http.Response, error) { + client := CreateHTTPClient() + var resp *http.Response + var err error + + for i := 0; i < config.SysConfig.Modelscope.MaxRetry; i++ { + resp, err = client.Do(req) + if err == nil { + return resp, nil + } + + if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded") { + zap.S().Warnf("⚠️ Retry %d/%d: request timeout - %v", i+1, config.SysConfig.Modelscope.MaxRetry, err) + time.Sleep(time.Duration(config.SysConfig.Modelscope.RetryDelay) * time.Duration(i+1)) + continue + } + + return nil, err + } + + return nil, fmt.Errorf("failed after %d retries: %v", config.SysConfig.Modelscope.MaxRetry, err) +} + +// ParseRangeHeader 解析Range请求头,返回起始字节和结束字节(-1表示到末尾) +func ParseRangeHeader(r *http.Request) (start int64, end int64, err error) { + rangeHeader := r.Header.Get("Range") + if rangeHeader == "" { + return 0, -1, nil + } + + // 解析Range头格式:bytes=start-end + parts := strings.SplitN(rangeHeader, "=", 2) + if len(parts) != 2 || parts[0] != "bytes" { + return 0, -1, fmt.Errorf("invalid Range header: %s", rangeHeader) + } + + rangeParts := strings.SplitN(parts[1], "-", 2) + start, err = strconv.ParseInt(rangeParts[0], 10, 64) + if err != nil { + return 0, -1, fmt.Errorf("invalid start byte: %s, err: %v", rangeParts[0], err) + } + + if len(rangeParts) == 2 && rangeParts[1] != "" { + end, err = strconv.ParseInt(rangeParts[1], 10, 64) + if err != nil { + return 0, -1, fmt.Errorf("invalid end byte: %s, err: %v", rangeParts[1], err) + } + } else { + end = -1 + } + + return start, end, nil +} + +func GetAPIPathPrefix(repoType string) string { + repoType = strings.TrimSpace(strings.ToLower(repoType)) + switch repoType { + case "dataset", "datasets": + return "datasets" + case "model", "models": + return "models" + default: + zap.S().Warnf("无效的repoType: %s,默认使用models", repoType) + return "models" + } +} diff --git a/pkg/util/repo_util.go b/pkg/util/repo_util.go index 3c65585..cc18f02 100644 --- a/pkg/util/repo_util.go +++ b/pkg/util/repo_util.go @@ -30,6 +30,7 @@ import ( "dingospeed/pkg/common" "github.com/bytedance/sonic" + "go.uber.org/zap" "golang.org/x/sys/unix" ) @@ -141,11 +142,16 @@ func IsFile(path string) bool { // GetFileSize 获取文件大小 func GetFileSize(path string) int64 { - fh, err := os.Stat(path) + fileInfo, err := os.Stat(path) if err != nil { - fmt.Printf("读取文件%s失败, err: %s\n", path, err) + if os.IsNotExist(err) { + zap.S().Infof("文件不存在: %s", path) + return 0 + } + zap.S().Errorf("读取文件大小失败: %s, err: %v", path, err) + return 0 } - return fh.Size() + return fileInfo.Size() } func ReadDir(dir string) ([]string, error) {