TensorFlow Hub is a library of reusable machine learning modules for TensorFlow. A module consists of a model architecture along with its weights, trained on very large datasets; this is also known as a pre-trained model. Fine-tuning a pre-trained model across related tasks is known as transfer learning.
A shortage of training data is often the biggest problem in real-life scenarios: to achieve good performance, a sufficient number of training samples is essential.
TF Hub provides a variety of pre-trained models. Using a pre-trained model, we can build a customized model for a particular task without needing the high computing power or the data that was used to train the original model; we do not have to train the model from scratch.
Training such models from scratch is very expensive; it can take thousands of GPU-hours on a huge dataset. By reusing a pre-trained model, a developer can build a model with a much smaller dataset.
This tutorial illustrates how to use a pre-trained model from TF Hub.
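Before walking through the full pipeline, here is a minimal sketch (TF 1.x style, assuming tensorflow and tensorflow_hub are installed) of what reusing a module looks like: the nnlm-en-dim128 module used later in this tutorial embeds a batch of sentences into 128-dimensional vectors.

import tensorflow as tf
import tensorflow_hub as hub

# Load the pre-trained sentence-embedding module from TF Hub
embed = hub.Module("https://tfhub.dev/google/nnlm-en-dim128/1")
embeddings = embed(tf.constant(["TensorFlow Hub makes transfer learning easy.",
                                "A module bundles a graph and its trained weights."]))

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    print(sess.run(embeddings).shape)  # (2, 128)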
Data –
We will use the data from Kaggle’s Quora Insincere Questions Classification task for the demonstration. The training file contains the question text along with a binary target (1 for an insincere question).
In [1]:
# Load the required packages
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split
Prepare Training Data
In [2]:
train_df = pd.read_csv('input/train.csv')
test_df = pd.read_csv('input/test.csv')
train_df['target'] = train_df['target'].astype('float')
train_df.head()
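Before splitting the data, a quick sanity check on its size and class balance is useful, since this dataset is heavily skewed towards sincere questions. A minimal sketch (nothing assumed beyond the DataFrame loaded above):

print(train_df.shape)
print(train_df['target'].value_counts(normalize=True))  # fraction of sincere (0.0) vs insincere (1.0) questions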
In [3]:
train, val = train_test_split(train_df, test_size=0.1, random_state=42)
Prepare Model –
Specify the input functions that wrap the Pandas DataFrames.
In [4]:
# Training input: shuffle and cycle through the data indefinitely.
# target_column='' avoids a name clash with the existing 'target' column in the DataFrame.
train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train, train["target"], num_epochs=None, shuffle=True, target_column='')

# Validation and test inputs: a single, unshuffled pass
val_input_fn = tf.estimator.inputs.pandas_input_fn(
    val, val["target"], shuffle=False, target_column='')
test_input_fn = tf.estimator.inputs.pandas_input_fn(
    test_df, shuffle=False, target_column='')
Define a feature column that applies a module to the given text feature. The module takes a batch of sentences as a 1-D tensor of strings and performs the text pre-processing itself, such as removing special symbols and splitting the text into tokens.
In this tutorial, we will use the nnlm-en-dim128 module as the pre-trained text embedding.
In [5]:
embedded_text_feature_column = hub.text_embedding_column(
    key="question_text",
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1")
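To see what this column produces, here is a minimal standalone sketch (TF 1.x; the example sentence and the use of tf.feature_column.input_layer are ours, not part of the original pipeline). The column turns a batch of raw strings into a dense batch of 128-dimensional embeddings:

with tf.Graph().as_default():
    features = {"question_text": tf.constant(["Why is the sky blue?"])}
    # input_layer applies the embedding column, i.e. runs the TF Hub module on the strings
    dense = tf.feature_column.input_layer(features, [embedded_text_feature_column])
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        print(sess.run(dense).shape)  # (1, 128)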
Define Classifier –
Here, we will use a DNNClassifier estimator for the text classification task.
In [6]:
estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],                           # two hidden layers on top of the 128-d embedding
    feature_columns=[embedded_text_feature_column],
    n_classes=2,                                       # sincere vs insincere
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))
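By default, the estimator writes its checkpoints and summaries to a temporary directory. If you want to keep them (for TensorBoard or for resuming training), you can also pass model_dir; the path below is purely illustrative, not from the original:

estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003),
    model_dir="quora_model")  # illustrative path; omit to use a temporary directory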
Model Training
In [7]:
# 1,000 steps with the input function's default batch size of 128,
# i.e. roughly 128,000 training examples are seen
estimator.train(input_fn=train_input_fn, steps=1000)
Model Evaluation
In [8]:
eval_val_result = estimator.evaluate(input_fn=val_input_fn)
print("Validation set accuracy: {accuracy}".format(**eval_val_result))

Out[8]: Validation set accuracy: 0.9473023414611816
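Accuracy alone can look flattering on a dataset this imbalanced, since predicting the majority class for everything already gives a high accuracy. The dictionary returned by evaluate() also carries the other metrics of the binary classification head; a minimal sketch of inspecting them (the key names are the standard TF 1.x head metrics, guarded in case they differ):

for key in ("accuracy_baseline", "auc", "precision", "recall"):
    if key in eval_val_result:
        print(key, eval_val_result[key])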
Make Predictions on Test Data
In [9]:
test_prediction = estimator.predict(input_fn=test_input_fn, predict_keys="probabilities")
In [10]:
prob = []
for e in test_prediction:
    y_prob = e['probabilities']
    y_classes = y_prob.argmax(axis=-1)   # pick the most likely class (0 = sincere, 1 = insincere)
    prob.append(y_classes)
test_df['target'] = prob
In [11]:
test_df.head()
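If you want to turn these predictions into a Kaggle-style submission file, a minimal sketch (the qid column and the 'prediction' column name are assumptions based on the competition's usual format, not shown in the original):

submission = test_df[['qid']].copy()        # qid column assumed from the competition data
submission['prediction'] = prob
submission.to_csv('submission.csv', index=False)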
Training the classifier together with the module –
So far the module's weights have been frozen and only the classifier on top has been trained. Setting trainable=True fine-tunes the pre-trained embedding weights along with the classifier, which usually improves accuracy at the cost of slower training and a higher risk of overfitting.
In [12]:
embedded_text_feature_column = hub.text_embedding_column(
    key="question_text",
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1",
    trainable=True)  # whether the module's weights are updated during training (default: False)

estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))

# Train the model
estimator.train(input_fn=train_input_fn, steps=1000)

# Evaluate the model
eval_val_result = estimator.evaluate(input_fn=val_input_fn)
print("Validation set accuracy: {}".format(eval_val_result["accuracy"]))

Out[12]: Validation set accuracy: 0.9511151313781738