#!/usr/bin/env bash
# CTRL_CCNET_ANY_ACQ_TXT.sh
echo "Run start_script.sh -> src/bash/start_script.sh"
sh src/bash/start_script.sh
echo "Create datalist -> src/python/preprocessing/create_datalist.py"
output_path="project/outputs/ids/"
#root_path="/home/richard/Desktop/software/data/ldm_data_all_phases_cropped_split1602"
root_path="../data/ldm_data_all_phases_cropped_split1602"
python3 src/python/preprocessing/create_datalist.py \
--output_path ${output_path} \
--root_path ${root_path}
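# create_datalist.py is expected to write train.tsv and validation.tsv into
# ${output_path}; those files are consumed below as training_ids and
# validation_ids.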
echo "Run train_controlnet.py -> src/python/training/train_controlnet.py"
seed=42
run_dir="CTRL_CCNET_ANY_ACQ_TXT_v1"
# Set artifact URIs
#stage1_uri="project/outputs/runs/aekl_v1/checkpoint_best_val_loss_e10.pth"
#stage1_uri= # TODO: Set AE_KL artifact url "/project/mlruns/837816334068618022/39336906f86c4cdc96fb6464b88c8c06/artifacts/final_model"
#ddpm_uri= # TODO: Set ldm artifact url "/project/mlruns/102676348294480761/a53f700f40184ff49f5f7e27fafece97/artifacts/final_model"
#stage1_uri="mlruns/809882896951469465/b32e98be004e4fdc92b11c546a9059e7/artifacts/final_model"
#ddpm_uri="mlruns/.../.../artifacts/final_model"
ddpm_uri="project/outputs/runs/LDM_CCNET_ANY_ACQ_TXT_v17/ldm_final_model_val_loss_06331433683314406_epoch100.pth"
stage1_uri=""
# TODO: Do we init_from_unet?
training_ids="project/outputs/ids/train.tsv"
validation_ids="project/outputs/ids/validation.tsv"
config_file="configs/controlnet/controlnet_v1.yaml" # TODO LR 5e-5
scale_factor=0.1 #0.18215 #0.1 #0.01 #0.3
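# scale_factor rescales the VAE latents before diffusion; 0.18215 is the
# value used by Stable Diffusion's autoencoder, a smaller 0.1 is used here.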
batch_size=4 #4 #2 #8 #16 #384 #8 #16 #64 #256 #512
n_epochs=100 #150
eval_freq=1
num_workers=4 #16 #64
experiment=${run_dir}
is_resumed=true #false #true
use_pretrained=1 # load only the VAE (not the LDM) as a pretrained model from source_model
source_model="stabilityai/stable-diffusion-2-1-base"
torch_detect_anomaly=0 # whether to use torch.autograd.detect_anomaly() (0: no, 1: yes)
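# (detect_anomaly makes autograd raise as soon as a NaN/Inf appears in the
# backward pass, at a significant slowdown, so it is normally left off)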
#early_stopping_after_num_epochs=20
#--early_stopping_after_num_epochs ${early_stopping_after_num_epochs} \
img_width=224 #512
img_height=224 #512
#clip_grad_norm_by=15.0
#clip_grad_norm_or_value='value'
controlnet_conditioning_scale=1.0
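# controlnet_conditioning_scale scales the ControlNet residuals before they
# are added to the UNet (1.0 = full-strength conditioning, as in diffusers).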
if $is_resumed ; then
python3 src/python/training/train_controlnet.py \
--seed ${seed} \
--run_dir ${run_dir} \
--training_ids ${training_ids} \
--validation_ids ${validation_ids} \
--stage1_uri=${stage1_uri} \
--ddpm_uri ${ddpm_uri} \
--config_file ${config_file} \
--batch_size ${batch_size} \
--n_epochs ${n_epochs} \
--eval_freq ${eval_freq} \
--num_workers ${num_workers} \
--experiment ${experiment} \
--use_pretrained ${use_pretrained} \
--source_model ${source_model} \
--torch_detect_anomaly ${torch_detect_anomaly} \
--img_width ${img_width} \
--img_height ${img_height} \
--scale_factor=${scale_factor} \
--controlnet_conditioning_scale=${controlnet_conditioning_scale} \
--"cond_on_acq_times" \
--"is_resumed"
#--"is_ldm_fine_tuned"
#--"init_from_unet"
#--"use_default_report_text" \
#--clip_grad_norm_by ${clip_grad_norm_by} \
#--clip_grad_norm_or_value ${clip_grad_norm_or_value} \
else
python3 src/python/training/train_controlnet.py \
--seed ${seed} \
--run_dir ${run_dir} \
--training_ids ${training_ids} \
--validation_ids ${validation_ids} \
--stage1_uri=${stage1_uri} \
--ddpm_uri ${ddpm_uri} \
--config_file ${config_file} \
--batch_size ${batch_size} \
--n_epochs ${n_epochs} \
--eval_freq ${eval_freq} \
--num_workers ${num_workers} \
--experiment ${experiment} \
--use_pretrained ${use_pretrained} \
--source_model ${source_model} \
--torch_detect_anomaly ${torch_detect_anomaly} \
--img_width ${img_width} \
--img_height ${img_height} \
--scale_factor=${scale_factor} \
--controlnet_conditioning_scale=${controlnet_conditioning_scale} \
--"cond_on_acq_times"
#--"is_ldm_fine_tuned"
#--"init_from_unet"
#--"use_default_report_text" \
#--clip_grad_norm_by ${clip_grad_norm_by} \
#--clip_grad_norm_or_value ${clip_grad_norm_or_value} \
fi
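# The two branches above differ only in the trailing --is_resumed flag.
# A minimal equivalent sketch (an alternative rewrite, not part of the
# original run) would collect the shared flags in a bash array and append
# the resume flag conditionally:
#
#   args=(
#     --seed "${seed}" --run_dir "${run_dir}"
#     --training_ids "${training_ids}" --validation_ids "${validation_ids}"
#     --stage1_uri="${stage1_uri}" --ddpm_uri "${ddpm_uri}"
#     --config_file "${config_file}" --batch_size "${batch_size}"
#     --n_epochs "${n_epochs}" --eval_freq "${eval_freq}"
#     --num_workers "${num_workers}" --experiment "${experiment}"
#     --use_pretrained "${use_pretrained}" --source_model "${source_model}"
#     --torch_detect_anomaly "${torch_detect_anomaly}"
#     --img_width "${img_width}" --img_height "${img_height}"
#     --scale_factor="${scale_factor}"
#     --controlnet_conditioning_scale="${controlnet_conditioning_scale}"
#     --cond_on_acq_times
#   )
#   $is_resumed && args+=(--is_resumed)
#   python3 src/python/training/train_controlnet.py "${args[@]}"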