import tensorflow as tf
import tensorflow_hub as hub

# Load the BiT model published at model_url from TF Hub as a Keras layer.
# hub.KerasLayer defaults to trainable=False, which would freeze the backbone
# so that fitting a model built on it only trains the new head. The backbone
# is meant to be fine-tuned, so its weights are made trainable here.
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
module = hub.KerasLayer(model_url, trainable=True)
# Use the model: run the ImageNet classifier on the image and display the
# resulting predictions.
# NOTE(review): `imagenet_module`, `image` and `show_preds` are not defined in
# this view. `imagenet_module` presumably is a BiT model with an ImageNet
# classification head (the feature extractor loaded from model_url returns
# embeddings, not class logits), and `image` presumably carries a leading
# batch axis, given the `image[0]` indexing below. Confirm both.
logits = imagenet_module(image)
# Pass the logits computed above; the previous `preds` name was never defined.
# NOTE(review): confirm show_preds expects logits rather than probabilities.
show_preds(logits, image[0])
class MyBiTModel(tf.keras.Model):
    """BiT backbone topped with a freshly initialised dense classification head.

    Args:
        num_classes: Number of output units of the new dense head.
        module: Callable mapping a batch of images to embeddings (here the
            TF Hub KerasLayer loaded above). It is used as-is; nothing is
            stripped from it.
    """

    def __init__(self, num_classes, module):
        super().__init__()
        self.num_classes = num_classes
        # Zero kernel (and Keras' default zero bias): every class starts out
        # with the same output.
        self.head = tf.keras.layers.Dense(
            num_classes,
            kernel_initializer='zeros',
        )
        self.bit_model = module

    def call(self, images):
        """Return the head's linear (no activation) outputs for ``images``."""
        # The backbone is a feature-extractor model, so there is no
        # pretrained classification head to cut off before adding ours.
        return self.head(self.bit_model(images))


# NOTE(review): 5 presumably matches the number of labels in the training
# data — confirm.
model = MyBiTModel(num_classes=5, module=module)
# Define the optimiser: SGD with momentum over a step-decayed learning rate.
lr = 0.003
SCHEDULE_BOUNDARIES = [200, 300, 400, 500]
# Decay the learning rate by a factor of 10 at each step in
# SCHEDULE_BOUNDARIES. PiecewiseConstantDecay requires exactly one more value
# than there are boundaries (the value before the first boundary, then one per
# boundary) and raises ValueError otherwise. The values are therefore derived
# from the boundaries rather than listed by hand.
# NOTE(review): the last boundary (500) only takes effect if training runs
# for more than 500 steps — confirm against the fit() step count.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES,
    values=[lr / 10**i for i in range(len(SCHEDULE_BOUNDARIES) + 1)],
)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
# Sparse (integer-label) cross-entropy. The model's outputs are treated as
# unnormalised logits, hence from_logits=True.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=['accuracy'],
)

# Fine-tune the model: steps_per_epoch=10 x epochs=50 gives up to 500
# training steps in total.
# NOTE(review): if pipeline_train is a (batched) tf.data.Dataset, Keras takes
# the batch size from the dataset and batch_size=512 has no effect — confirm
# how pipeline_train is built.
model.fit(
    pipeline_train,
    validation_data=pipeline_test,
    steps_per_epoch=10,
    epochs=50,
    batch_size=512,
)
# Export the fine-tuned model to disk in SavedModel format.
export_module_dir = '/tmp/my_saved_bit_model/'
tf.saved_model.save(obj=model, export_dir=export_module_dir)

# Reload the export as a Keras layer whose weights remain trainable.
saved_module = hub.KerasLayer(handle=export_module_dir, trainable=True)