-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_engine.py
61 lines (45 loc) · 1.58 KB
/
train_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
'''
@author: Bappy Ahmed
Email: [email protected]
Date: 06-sep-2021
'''
import math
import os

import tensorflow as tf

from utils import model
from utils import data_manager as dm
from utils.config import configureModel
from utils import callbacks
# Project configuration, loaded once when this module is imported.
# train() reads its 'EPOCHS' and 'MODEL_NAME' entries by string key.
config_model = configureModel()
def train():
    """Train the pretrained model and save it as an HDF5 file.

    Steps:
      1. Load the model with ``model.load_pretrain_model()``.
      2. Build the training/validation inputs with ``dm.train_valid_generator()``.
      3. Fit for ``config_model['EPOCHS']`` epochs with a TensorBoard callback
         (logging to ``callbacks.get_log_path()``) and the project's
         ``callbacks.checkpoint()`` callback.
      4. Save the trained model to ``New_trained_model/new<MODEL_NAME>.h5``
         and print that path.

    Returns:
        None. The result is written to disk.
    """
    my_model = model.load_pretrain_model()
    train_data, valid_data = dm.train_valid_generator()

    # Callbacks: TensorBoard logging plus the project's checkpoint callback.
    log_dir = callbacks.get_log_path()
    tb_cb = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
    ckp = callbacks.checkpoint()
    call = [tb_cb, ckp]

    # Round up so every sample is visited once per epoch (the last batch may
    # be partial). A split smaller than one batch still gets one step; floor
    # division would have produced 0 steps there.
    # NOTE(review): assumes the generators expose ``samples``/``batch_size``
    # like a Keras DirectoryIterator -- confirm against data_manager.
    steps_per_epoch = math.ceil(train_data.samples / train_data.batch_size)
    validation_steps = math.ceil(valid_data.samples / valid_data.batch_size)

    my_model.fit(
        train_data,
        validation_data=valid_data,
        epochs=config_model['EPOCHS'],
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        callbacks=call,
    )

    new_path = f"New_trained_model/new{config_model['MODEL_NAME']}.h5"
    # Not every Keras version creates missing parent directories when saving
    # to .h5, so make sure the output directory exists first.
    os.makedirs(os.path.dirname(new_path), exist_ok=True)
    my_model.save(new_path)
    print(f"Model saved at the following location : {new_path}")