To set up the environment for training, please install the required dependencies by running the following command:
pip install -r requirements.txt

We use two image datasets and train on features extracted from them by pretrained models. To obtain the features, follow the instructions below:
- Glint360K:
  - Download the webdataset version of the Glint360K dataset from here
  - Download the pretrained ResNet-50 model from here
  - Extract the features using the pretrained ResNet-50 model:
    python extract_feat.py \
    --data 'glint360k' \
    --data-url 'path/to/glint360k-0{000..416}.tar.gz' \
    --batch-size 512 \
    --num-samples 18000000 \
    --data-size 17091657 \
    --out-dir 'dir/to/output' \
    --pretrained-path 'path/to/pretrained/16backbone_r50.pth'
- TreeOfLife-10M:
  - Follow the instructions in the 'Reproduce TreeOfLife-10M' section on this page to download the webdataset version of the TreeOfLife-10M dataset
  - In the same environment as the previous step, extract the features using the pretrained CLIP model:
    python extract_feat.py \
    --data 'treeoflife10m' \
    --data-url 'path/to/treeoflife10m-000{000..152}.tar' \
    --batch-size 512 \
    --num-samples 10000000 \
    --data-size 9533174 \
    --out-dir 'dir/to/output' \
    --metadata-path 'path/to/metadata/catalog.csv'
- After these steps, we obtain two tensor files under `dir/to/output`: `features.pt` and `labels.pt`. The `features.pt` file contains the extracted features, and the `labels.pt` file contains the corresponding labels.
- Then split the features and labels into training, validation and testing sets by selecting disjoint sets of rows from the two tensor files, using the same row indices for both files so that each feature stays paired with its label.
- Finally, store the split features and labels as `dir/to/features/{split}/features.pt` and `dir/to/features/{split}/labels.pt`, where `{split}` is `train`, `val`, or `test`.
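The splitting step above is described but not scripted. Below is a minimal sketch of one way to do it, assuming a random split with placeholder 5% validation and 5% test fractions and a fixed seed; the instructions do not say which rows or proportions were actually used. `split_sizes` and `save_splits` are hypothetical helper names, not part of this repository.

```python
import os


def split_sizes(n, val_frac=0.05, test_frac=0.05):
    """Return (n_train, n_val, n_test).

    Assumption: the 5%/5% defaults are placeholders, not the original setup.
    """
    n_val = int(n * val_frac)
    n_test = int(n * test_frac)
    n_train = n - n_val - n_test  # every row not held out goes to training
    if n_train <= 0 or n_val < 0 or n_test < 0:
        raise ValueError("invalid split fractions for n=%d" % n)
    return n_train, n_val, n_test


def save_splits(in_dir, out_dir, val_frac=0.05, test_frac=0.05, seed=0):
    """Split in_dir/{features,labels}.pt into out_dir/{train,val,test}/."""
    import torch  # imported lazily so split_sizes() runs without PyTorch

    features = torch.load(os.path.join(in_dir, "features.pt"))
    labels = torch.load(os.path.join(in_dir, "labels.pt"))
    if features.shape[0] != labels.shape[0]:
        raise ValueError("features.pt and labels.pt have different row counts")

    n_train, n_val, _ = split_sizes(features.shape[0], val_frac, test_frac)
    # Assumption: rows are assigned by one shared random permutation, which
    # keeps feature row i paired with label row i in every split.
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(features.shape[0], generator=gen)
    rows = {
        "train": perm[:n_train],
        "val": perm[n_train:n_train + n_val],
        "test": perm[n_train + n_val:],
    }
    for split, idx in rows.items():
        split_dir = os.path.join(out_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        # Indexing with an index tensor copies the selected rows.
        torch.save(features[idx], os.path.join(split_dir, "features.pt"))
        torch.save(labels[idx], os.path.join(split_dir, "labels.pt"))
```

For example, `save_splits("dir/to/output", "dir/to/features")` would produce the `dir/to/features/{split}/` layout that the training commands read through `--data-dir`. Shuffling before splitting avoids putting whole shards into one split; if you need a split that matches published numbers, substitute the exact row indices used there.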
Sample script to run SCENT on Glint360K
In the command, `--gamma` denotes the learning rate for the dual variable.
python -u train.py \
--algorithm scent \
--data-dir 'dir/to/features/' \
--data-size 17091657 \
--epochs 50 \
--gamma 12.0 \
--lr 5.0 \
--name glint360k_scent \
--num-classes 360232

Sample script to run BSGD on Glint360K
python -u train.py \
--algorithm bsgd \
--data-dir 'dir/to/features/' \
--data-size 17091657 \
--epochs 50 \
--lr 1.0 \
--name glint360k_bsgd \
--num-classes 360232

Sample script to run ASGD on Glint360K
In the command, `--gamma` denotes the initial learning rate for the dual variable.
python -u train.py \
--algorithm asgd \
--data-dir 'dir/to/features/' \
--data-size 17091657 \
--epochs 50 \
--gamma 1.0 \
--lr 0.5 \
--name glint360k_asgd \
--num-classes 360232

Sample script to run ASGD (Softplus) on Glint360K
In the command, `--gamma` denotes the initial learning rate for the dual variable.
python -u train.py \
--algorithm softplus \
--data-dir 'dir/to/features/' \
--data-size 17091657 \
--epochs 50 \
--gamma 1.0 \
--lr 0.5 \
--name glint360k_softplus \
--num-classes 360232 \
--softplus-rho 1e-3

Sample script to run U-max on Glint360K
In the command, `--gamma` denotes the initial learning rate for the dual variable.
python -u train.py \
--algorithm umax \
--data-dir 'dir/to/features/' \
--data-size 17091657 \
--epochs 50 \
--gamma 1.0 \
--lr 0.5 \
--name glint360k_umax \
--num-classes 360232 \
--umax-delta 1.0

Sample script to run SOX on Glint360K
python -u train.py \
--algorithm sox \
--data-dir 'dir/to/features/' \
--data-size 17091657 \
--epochs 50 \
--gamma 1e-5 \
--lr 5.0 \
--name glint360k_sox \
--num-classes 360232
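The six sample scripts share most of their flags and differ only in a few hyperparameters. The sketch below keeps the Glint360K values in one place and prints the corresponding commands. It is not part of the repository: `COMMON`, `RUNS`, and `build_argv` are hypothetical names, the values are copied from the commands above, and it assumes `train.py` accepts its flags in any order, as argparse-based scripts do.

```python
import shlex

# Flags shared by every Glint360K sample script above.
COMMON = {
    "data-dir": "dir/to/features/",
    "data-size": 17091657,
    "epochs": 50,
    "num-classes": 360232,
}

# Per-algorithm flags, copied from the sample scripts (BSGD has no --gamma).
RUNS = {
    "scent": {"gamma": 12.0, "lr": 5.0},
    "bsgd": {"lr": 1.0},
    "asgd": {"gamma": 1.0, "lr": 0.5},
    "softplus": {"gamma": 1.0, "lr": 0.5, "softplus-rho": 1e-3},
    "umax": {"gamma": 1.0, "lr": 0.5, "umax-delta": 1.0},
    "sox": {"gamma": 1e-5, "lr": 5.0},
}


def build_argv(algorithm, dataset="glint360k"):
    """Return the train.py argument list for one algorithm (hypothetical helper)."""
    argv = ["python", "-u", "train.py", "--algorithm", algorithm]
    opts = {**COMMON, **RUNS[algorithm], "name": f"{dataset}_{algorithm}"}
    for key, value in opts.items():
        argv += [f"--{key}", str(value)]
    return argv


if __name__ == "__main__":
    for algorithm in RUNS:
        # Print a copy-pasteable shell command for each algorithm.
        print(shlex.join(build_argv(algorithm)))
```

Printing the commands lets you check them against the scripts above before launching anything; to start a run directly, pass the list to `subprocess.run(build_argv("scent"), check=True)`. Note that `str()` renders `1e-3` as `0.001` and `1e-5` as `1e-05`, which parse to the same floats.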