Keras fit, fit_generator, train_on_batch

The Keras deep learning library provides three different methods for training deep learning models. Each method has its own characteristics that suit it to particular training scenarios.

  • .fit
  • .fit_generator
  • .train_on_batch

 

All three methods achieve the same goal, training a deep learning model, but they work in very different ways.

In this tutorial, we will learn about these model training functions with examples. We will also look at the differences between them and when to use each one.

Let's explore how each of these training methods behaves and see which scenarios each one suits.

Keras.fit() method

fit(x=train_x, y=train_y, batch_size=64, epochs=20)

This method trains the model on the input data, with train_x as the features and train_y as the target variable. When you call it with in-memory arrays, the entire training dataset must be held in RAM while the model trains.

Using the .fit method is one of the easiest ways to train a neural network model. However, it is not feasible when you are working with a huge dataset, because holding the entire training set in RAM requires a lot of resources. Hence, it is a poor fit for real-life problems where the dataset is larger than the available memory.
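To get a feel for how quickly this adds up, the following back-of-the-envelope sketch estimates the memory needed to hold an image dataset as float32 arrays. The dataset dimensions are hypothetical and chosen only for illustration.

```python
# Hypothetical dataset: 100,000 RGB images of 224x224 pixels stored as float32.
num_images = 100_000
height, width, channels = 224, 224, 3
bytes_per_value = 4  # one float32 value occupies 4 bytes

bytes_per_image = height * width * channels * bytes_per_value
total_bytes = num_images * bytes_per_image
total_gib = total_bytes / 2**30  # convert bytes to GiB

print(f"{bytes_per_image} bytes per image")           # 602112 bytes per image
print(f"{total_gib:.1f} GiB for the whole dataset")   # 56.1 GiB for the whole dataset
```

At roughly 56 GiB, a dataset of this size would not fit in the RAM of many machines, which is the situation the generator-based approach is meant to address.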

Keras' .fit method treats the training data as static. Once you feed in the data, you cannot manipulate it during training, for example with data augmentation. Please refer to this tutorial to learn about data augmentation.

The two primary assumptions of Keras.fit are:

  • The entire training dataset fits into RAM while the model trains
  • No real-time data augmentation is performed on the images

 

Keras.fit() accepts the following parameters:

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None, validation_freq=1, max_queue_size=10, workers=1, use_multiprocessing=False)

Keras.fit_generator()

It is perfectly fine to use the Keras.fit() function when you are training a model on a small, simple dataset. But real-life problems often involve datasets too large to load into memory. They also often require data augmentation to reduce overfitting and make the model generalize better. In these situations, we should use Keras' fit_generator() function to train the model. The .fit_generator method does not accept X and Y directly; the data must be passed through a generator.

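Keras.fit_generator() trains the model on data generated batch-by-batch by a Python generator. It is a dynamic method that takes its input training data from a Python generator function. Note that in TensorFlow 2.1 and later, fit_generator is deprecated because fit() accepts generators directly; the description here applies to versions that still provide fit_generator.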

This generator function applies random data augmentation, such as rotation, flipping, and resizing, to the training data during the training phase. Data augmentation makes the model more robust and helps it generalize. The generator function returns batches of data whose size is set by the batch size parameter.
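To make the idea concrete before turning to Keras' own utilities, here is a minimal plain-Python sketch of such a generator. It is not Keras code: the function name, the toy data, and the choice of a random horizontal flip as the only augmentation are assumptions made for illustration.

```python
import itertools
import random

def augmenting_batch_generator(images, labels, batch_size, seed=None):
    """Yield (batch_images, batch_labels) forever, flipping images at random.

    Plain-Python stand-in for a data generator: `images` is a list of 2-D
    lists (rows of pixel values), and a horizontal flip reverses each row.
    """
    rng = random.Random(seed)
    indices = list(range(len(images)))
    while True:  # never stops on its own; the consumer decides when to stop
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            batch_images = []
            for i in batch_idx:
                image = images[i]
                if rng.random() < 0.5:
                    image = [row[::-1] for row in image]  # horizontal flip
                batch_images.append(image)
            yield batch_images, [labels[i] for i in batch_idx]

# Hypothetical toy data: five 2x3 "images" with integer labels.
images = [[[i, i + 1, i + 2], [i + 3, i + 4, i + 5]] for i in range(5)]
labels = [0, 1, 0, 1, 0]

generator = augmenting_batch_generator(images, labels, batch_size=2, seed=0)
# Take exactly three batches: one pass over five samples with batch_size=2.
first_epoch = list(itertools.islice(generator, 3))
batch_sizes = [len(batch_images) for batch_images, _ in first_epoch]
print(batch_sizes)  # [2, 2, 1]
```

Because the augmentation happens when each batch is drawn, the model can see a differently transformed version of the same image on every pass, and only one batch needs to be materialized at a time.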

Let’s define the generator function using Keras’ ImageDataGenerator class. Please refer to this tutorial to understand the various methods of Keras’ ImageDataGenerator class.

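# Import path assumes TensorFlow's bundled Keras; standalone Keras uses keras.preprocessing.image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rotation_range=30, zoom_range=0.15,
                                   width_shift_range=0.2, height_shift_range=0.2,
                                   shear_range=0.15, horizontal_flip=True,
                                   fill_mode="nearest")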

In the next step, we need to feed the generator to the Keras.fit_generator method. On each iteration, .fit_generator gets a batch of data from the generator and performs forward propagation and backpropagation on it. This is how the model is trained.

To feed data from Keras' ImageDataGenerator class to the .fit_generator method, three methods exist: flow(), flow_from_directory(), and flow_from_dataframe(). This example uses flow():

batch_size = 32

train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size)

Here, the generator function runs forever, so something else has to tell Keras when to stop drawing batches. The steps_per_epoch parameter provides that control.

The steps_per_epoch parameter should typically equal ceil(num_samples / batch_size). An epoch is considered finished when the model has seen steps_per_epoch batches.
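The difference between rounding up and rounding down is easy to overlook, so here is a small worked sketch of the arithmetic. The sample count is hypothetical.

```python
import math

num_samples = 1000  # hypothetical dataset size
batch_size = 32

steps_ceil = math.ceil(num_samples / batch_size)  # counts the final partial batch
steps_floor = num_samples // batch_size           # counts full batches only

# Samples left uncovered in one pass when only full batches are counted.
leftover = num_samples - steps_floor * batch_size

print(steps_ceil, steps_floor, leftover)  # 32 31 8
```

With floor division, a single pass of 31 steps covers 992 samples and leaves 8 out; rounding up adds one shorter step so that every sample is counted.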

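import math

# Round up so the final partial batch is counted, matching ceil(num_samples / batch_size)
steps_per_epoch = math.ceil(len(trainX) / batch_size)
fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=10, verbose=1)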

Keras.train_on_batch

Keras' train_on_batch function accepts a single batch of data, performs backpropagation on it, and then updates the model parameters. The batch can be any size; you do not need to define it explicitly.

When you want to train the model with your own custom rules and take full control over model training, you can use the Keras.train_on_batch() function.

train_on_batch(x, y, sample_weight=None, class_weight=None, reset_metrics=True)

If you want to make custom changes after each batch is trained, Keras.train_on_batch() is the best choice. However, it is more complex to use than the traditional training methods.
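The shape of such a hand-written training loop can be sketched without any framework. The example below is a plain-Python stand-in, not Keras code: train_on_batch_sketch is a hypothetical function that performs one gradient-descent update of a tiny linear model on a single batch, and the toy data, learning rate, and stopping threshold are assumptions chosen for illustration.

```python
import random

def train_on_batch_sketch(params, xs, ys, lr):
    """One gradient-descent update of y = w*x + b on a single batch (MSE loss).

    Mirrors the steps described above: forward pass, loss, gradients,
    parameter update. Returns the batch loss measured before the update.
    """
    w, b = params["w"], params["b"]
    n = len(xs)
    errors = [(w * x + b) - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errors) / n
    grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
    grad_b = 2.0 * sum(errors) / n
    params["w"] = w - lr * grad_w
    params["b"] = b - lr * grad_b
    return loss

# Hypothetical toy data following y = 3x + 1 with x in [0, 1).
rng = random.Random(0)
data_x = [rng.random() for _ in range(256)]
data_y = [3.0 * x + 1.0 for x in data_x]

params = {"w": 0.0, "b": 0.0}
lr = 0.3
batch_size = 32
loss_history = []
converged = False

for epoch in range(200):
    for start in range(0, len(data_x), batch_size):
        xs = data_x[start:start + batch_size]
        ys = data_y[start:start + batch_size]
        loss = train_on_batch_sketch(params, xs, ys, lr)
        loss_history.append(loss)
        # Custom per-batch rule: stop as soon as a batch loss is tiny.
        if loss < 1e-8:
            converged = True
            break
    if converged:
        break

print(round(params["w"], 2), round(params["b"], 2))
```

With a compiled Keras model, the inner call would be model.train_on_batch(xs, ys) instead of the hypothetical helper, while the surrounding loops, the batching, and any per-batch rule remain your own responsibility. That extra bookkeeping is the added complexity mentioned above.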

Summary

In this tutorial, you have discovered Keras' different model training functions: .fit, .fit_generator, and .train_on_batch. You have also learned the differences between these functions and when to use each training method.
