-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhfd.sh
More file actions
executable file
·193 lines (168 loc) · 7.63 KB
/
hfd.sh
File metadata and controls
executable file
·193 lines (168 loc) · 7.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env bash
# hfd: download a Hugging Face model or dataset by cloning the repository with
# git-lfs smudging disabled, then fetching the LFS files with aria2c or wget.

# ANSI color escapes, stored as literal backslash sequences. They take effect
# only where they are placed in a printf format string or an `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Ctrl-C handler: remind the user that re-running resumes the download, then abort.
on_interrupt() {
  printf "${YELLOW}\nDownload interrupted. If you re-run the command, you can resume the download from the breakpoint.\n${NC}"
  exit 1
}
trap on_interrupt INT
# Print usage information to stdout and exit with status 1.
# Glob examples are quoted so the user's shell passes them to hfd verbatim
# instead of expanding them against files in the current directory.
display_help() {
cat << EOF
Usage:
hfd <repo_id> [--include include_pattern1 include_pattern2 ...] [--exclude exclude_pattern1 exclude_pattern2 ...] [--hf_username username] [--hf_token token] [--tool aria2c|wget] [-x threads] [--dataset] [--local-dir path]
Description:
Downloads a model or dataset from Hugging Face using the provided repo ID.
Parameters:
repo_id The Hugging Face repo ID in the format 'org/repo_name'.
--include (Optional) Flag to specify string patterns to include files for downloading. Supports multiple patterns.
--exclude (Optional) Flag to specify string patterns to exclude files from downloading. Supports multiple patterns.
include/exclude_pattern The patterns to match against filenames, supports wildcard characters. Quote them so your shell does not expand them, e.g., --exclude '*.safetensors' '*.txt', --include 'vae/*'.
--hf_username (Optional) Hugging Face username for authentication. **NOT EMAIL**.
--hf_token (Optional) Hugging Face token for authentication.
--tool (Optional) Download tool to use. Can be aria2c (default) or wget.
-x (Optional) Number of download threads for aria2c. Defaults to 4.
--dataset (Optional) Flag to indicate downloading a dataset.
--local-dir (Optional) Local directory path where the model or dataset will be stored.
Example:
hfd bigscience/bloom-560m --exclude '*.safetensors'
hfd meta-llama/Llama-2-7b --hf_username myuser --hf_token mytoken -x 4
hfd lavita/medical-qa-shared-task-v1-toy --dataset
EOF
exit 1
}
# The first argument is the repo ID ("org/repo_name", or -h for help); all
# remaining arguments are options.
MODEL_ID=$1
shift
# Default values
TOOL="aria2c"
THREADS=4
HF_ENDPOINT=${HF_ENDPOINT:-"https://huggingface.co"}
INCLUDE_PATTERNS=()
EXCLUDE_PATTERNS=()

# Exit with an error if option $1 was given without a value.
# Arguments: $1 - option name; $2 - number of remaining arguments, counting
#            the option itself.
# Without this guard, `shift 2` fails without shifting anything when the
# value is missing, so $1 never changes and the loop below never terminates.
require_value() {
  if (( $2 < 2 )); then
    printf "${RED}Missing value for option %s.\n${NC}" "$1" >&2
    exit 1
  fi
}

while [[ $# -gt 0 ]]; do
  case $1 in
    --include)
      shift
      # Collect patterns up to the next option. Stop at any leading "-",
      # not only "--", so a following "-x N" is parsed as an option instead
      # of being swallowed as two patterns.
      while [[ $# -gt 0 && $1 != -* ]]; do
        INCLUDE_PATTERNS+=("$1")
        shift
      done
      ;;
    --exclude)
      shift
      while [[ $# -gt 0 && $1 != -* ]]; do
        EXCLUDE_PATTERNS+=("$1")
        shift
      done
      ;;
    --hf_username) require_value "$1" $#; HF_USERNAME="$2"; shift 2 ;;
    --hf_token) require_value "$1" $#; HF_TOKEN="$2"; shift 2 ;;
    --tool) require_value "$1" $#; TOOL="$2"; shift 2 ;;
    -x) require_value "$1" $#; THREADS="$2"; shift 2 ;;
    --dataset) DATASET=1; shift ;;
    --local-dir) require_value "$1" $#; LOCAL_DIR="$2"; shift 2 ;;
    -h|--help) display_help ;;
    *)
      # Unknown arguments are still skipped, but no longer silently.
      printf "${YELLOW}Ignoring unrecognized argument: %s\n${NC}" "$1" >&2
      shift
      ;;
  esac
done
# Exit with status 1 unless command $1 can be found (anything that
# `command -v` resolves counts, including builtins and functions).
# Used to check for aria2c/wget, curl, git, and git-lfs.
# Arguments: $1 - command name to look for
# Outputs:   error message on stderr when the command is missing
check_command() {
  if ! command -v "$1" &>/dev/null; then
    printf "${RED}%s is not installed. Please install it first.${NC}\n" "$1" >&2
    exit 1
  fi
}
# Mark the current directory as a git safe.directory when `git status`
# reports "detected dubious ownership" (seen on shared file systems such as
# Samba or NFS, where the checkout may be owned by a different user).
# Globals:      PWD, YELLOW, NC (read)
# Side effects: may append a safe.directory entry to the user's ~/.gitconfig
# Returns:      0
ensure_ownership() {
  if git status 2>&1 | grep -q "fatal: detected dubious ownership in repository at"; then
    git config --global --add safe.directory "${PWD}"
    # PWD is passed as an argument, not spliced into the format string, so a
    # '%' in the path is printed literally.
    printf "${YELLOW}Detected dubious ownership in repository, mark %s safe using git, edit ~/.gitconfig if you want to reverse this.\n${NC}" "${PWD}"
  fi
}
# Show usage before probing for external tools, so a bare `hfd` or `hfd -h`
# prints help even on a machine that is missing curl or git.
[[ -z "$MODEL_ID" || "$MODEL_ID" =~ ^-h ]] && display_help
# Verify the selected downloader plus the tools every run needs. Reject an
# unknown --tool up front: the download step treats anything other than
# "wget" as aria2c, which would otherwise run without having been checked.
case "$TOOL" in
  aria2c|wget) check_command "$TOOL" ;;
  *)
    printf "${RED}Unsupported --tool '%s'. Use aria2c or wget.\n${NC}" "$TOOL" >&2
    exit 1
    ;;
esac
check_command curl; check_command git; check_command git-lfs
# Default the target directory to the repo name (the part after "org/").
if [[ -z "$LOCAL_DIR" ]]; then
  LOCAL_DIR="${MODEL_ID#*/}"
fi
# The "datasets/" URL prefix is added only after LOCAL_DIR has been derived,
# so it does not end up in the directory name.
if [[ "$DATASET" == 1 ]]; then
  MODEL_ID="datasets/$MODEL_ID"
fi
echo "Downloading to $LOCAL_DIR"
if [[ -d "$LOCAL_DIR/.git" ]]; then
  # Existing clone: update it without fetching LFS content.
  printf "${YELLOW}%s exists, Skip Clone.\n${NC}" "$LOCAL_DIR"
  cd "$LOCAL_DIR" && ensure_ownership && GIT_LFS_SKIP_SMUDGE=1 git pull || { printf "${RED}Git pull failed.${NC}\n"; exit 1; }
else
  REPO_URL="$HF_ENDPOINT/$MODEL_ID"
  # Printable copy of REPO_URL; differs only when credentials are embedded below.
  REPO_URL_DISPLAY="$REPO_URL"
  # Probe git's upload-pack refs endpoint to learn whether authentication is needed.
  GIT_REFS_URL="${REPO_URL}/info/refs?service=git-upload-pack"
  echo "Testing GIT_REFS_URL: $GIT_REFS_URL"
  response=$(curl -s -o /dev/null -w "%{http_code}" "$GIT_REFS_URL")
  if [[ "$response" == "401" || "$response" == "403" ]]; then
    if [[ -z "$HF_USERNAME" || -z "$HF_TOKEN" ]]; then
      printf "${RED}HTTP Status Code: %s.\nThe repository requires authentication, but --hf_username and --hf_token is not passed. Please get token from https://huggingface.co/settings/tokens.\nExiting.\n${NC}" "$response"
      exit 1
    fi
    # Credentials are embedded in the clone URL. Note that git keeps the
    # clone URL (token included) as the remote in $LOCAL_DIR/.git/config.
    REPO_URL="https://$HF_USERNAME:$HF_TOKEN@${HF_ENDPOINT#https://}/$MODEL_ID"
    REPO_URL_DISPLAY="https://$HF_USERNAME:***@${HF_ENDPOINT#https://}/$MODEL_ID"
  elif [[ "$response" != "200" ]]; then
    printf "${RED}Unexpected HTTP Status Code: %s\n${NC}" "$response"
    printf "${YELLOW}Executing debug command: curl -v %s\nOutput:${NC}\n" "$GIT_REFS_URL"
    curl -v "$GIT_REFS_URL"; printf "\n${RED}Git clone failed.\n${NC}"; exit 1
  fi
  # Echo the masked URL: the real one would leak the token into the
  # terminal scrollback and any captured logs.
  echo "GIT_LFS_SKIP_SMUDGE=1 git clone $REPO_URL_DISPLAY $LOCAL_DIR"
  GIT_LFS_SKIP_SMUDGE=1 git clone "$REPO_URL" "$LOCAL_DIR" && cd "$LOCAL_DIR" || { printf "${RED}Git clone failed.\n${NC}"; exit 1; }
  ensure_ownership
  # Empty every LFS-tracked file, presumably so the resumable (-c) downloads
  # later start from byte 0 rather than continuing after the pointer text.
  # Process substitution keeps one path per line; blank lines (e.g. a repo
  # with no LFS files) are skipped instead of running `truncate ""`.
  while IFS= read -r file; do
    [[ -n "$file" ]] || continue
    truncate -s 0 "$file"
  done < <(git lfs ls-files | cut -d ' ' -f 3-)
fi
printf "\nStart Downloading lfs files, bash script:\ncd %s\n" "$LOCAL_DIR"
# One repo-relative path per line (assumes `git lfs ls-files` prints
# "<oid> <marker> <path>", so fields 3- are the path).
files=$(git lfs ls-files | cut -d ' ' -f 3-)
declare -a urls
# Return 0 if path $1 matches at least one glob in INCLUDE_PATTERNS, else 1
# (also 1 when INCLUDE_PATTERNS is empty).
# Globals:   INCLUDE_PATTERNS (read)
# Arguments: $1 - repo-relative file path
file_matches_include_patterns() {
  local file="$1"
  local pattern
  for pattern in "${INCLUDE_PATTERNS[@]}"; do
    # $pattern is deliberately unquoted so [[ == ]] treats it as a glob;
    # inside [[ ]], '*' also matches across '/'.
    if [[ "$file" == $pattern ]]; then
      return 0
    fi
  done
  return 1
}
# Return 0 if path $1 matches at least one glob in EXCLUDE_PATTERNS, else 1
# (also 1 when EXCLUDE_PATTERNS is empty).
# Globals:   EXCLUDE_PATTERNS (read)
# Arguments: $1 - repo-relative file path
file_matches_exclude_patterns() {
  local file="$1"
  local pattern
  for pattern in "${EXCLUDE_PATTERNS[@]}"; do
    # $pattern is deliberately unquoted so [[ == ]] treats it as a glob;
    # inside [[ ]], '*' also matches across '/'.
    if [[ "$file" == $pattern ]]; then
      return 0
    fi
  done
  return 1
}
# Select the LFS files to download (honoring --include/--exclude) and print an
# equivalent shell command for each one, commented out when it is filtered.
# The printed commands reference ${HF_TOKEN} symbolically instead of
# embedding the secret; `export HF_TOKEN=...` before reusing them.
if [[ -n "$HF_TOKEN" ]]; then
  # shellcheck disable=SC2016  # the literal text ${HF_TOKEN} is intended
  auth_display='--header="Authorization: Bearer ${HF_TOKEN}" '
  auth_args=(--header="Authorization: Bearer ${HF_TOKEN}")
else
  auth_display=''
  auth_args=()
fi
# Parallel arrays: urls[i] is saved to url_files[i]. They are kept separate,
# rather than packed as "url|file" strings, so paths containing '|' survive.
urls=()
url_files=()
while IFS= read -r file; do
  # A repo with no LFS files yields a single empty line here; skip it rather
  # than trying to download ".../resolve/main/" to an empty path.
  [[ -n "$file" ]] || continue
  url="$HF_ENDPOINT/$MODEL_ID/resolve/main/$file"
  file_dir=$(dirname "$file")
  mkdir -p "$file_dir"
  if [[ "$TOOL" == "wget" ]]; then
    download_cmd="wget ${auth_display}-c \"$url\" -O \"$file\""
  else
    download_cmd="aria2c ${auth_display}--console-log-level=error --file-allocation=none -x $THREADS -s $THREADS -k 1M -c \"$url\" -d \"$file_dir\" -o \"$(basename "$file")\""
  fi
  if [[ ${#INCLUDE_PATTERNS[@]} -gt 0 ]]; then
    file_matches_include_patterns "$file" || { printf "# %s\n" "$download_cmd"; continue; }
  fi
  if [[ ${#EXCLUDE_PATTERNS[@]} -gt 0 ]]; then
    file_matches_exclude_patterns "$file" && { printf "# %s\n" "$download_cmd"; continue; }
  fi
  printf "%s\n" "$download_cmd"
  urls+=("$url")
  url_files+=("$file")
done <<< "$files"

# Download each selected file. The token, if any, is passed as a single
# argument-array element. A failed authenticated download is no longer
# retried without credentials (the old `A && B || C` form did that by
# accident); the script stops at the first failure.
for i in "${!urls[@]}"; do
  url=${urls[i]}
  file=${url_files[i]}
  file_dir=$(dirname "$file")
  printf "${YELLOW}Start downloading %s.\n${NC}" "$file"
  if [[ "$TOOL" == "wget" ]]; then
    wget "${auth_args[@]}" -c "$url" -O "$file"
  else
    aria2c "${auth_args[@]}" --console-log-level=error --file-allocation=none -x "$THREADS" -s "$THREADS" -k 1M -c "$url" -d "$file_dir" -o "$(basename "$file")"
  fi
  status=$?
  if (( status != 0 )); then
    printf "${RED}Failed to download %s.\n${NC}" "$url"
    exit 1
  fi
  printf "Downloaded %s successfully.\n" "$url"
done
printf "${GREEN}Download completed successfully.\n${NC}"