Customize trainer to control the training process¶
In this tutorial, you will learn how to create and customize your own trainer to control the training process, make predictions, and aggregate results to meet your specific needs. We will first introduce you to the interfaces of the TrainerBase class that you need to implement. Then, we will provide a toy example of the FedProx algorithm (please note that this is just a toy example and should not be used in production) to help you better understand the concept of trainer customization.
TrainerBase Class¶
Basics¶
The TrainerBase class is the base class for all Homo-NN trainers in FATE. To create a custom trainer, you need to subclass the TrainerBase class located in federatedml.nn.homo.trainer.trainer_base. There are two required functions that you must implement:
The 'train()' function: This function takes five parameters: a training dataset instance (must be a subclass of Dataset), a validation dataset instance (also a subclass of Dataset), an optimizer instance with initialized training parameters, a loss function, and an extra data dictionary that may contain additional data for a warmstart task. In this function, you can define the process of client-side training and federation for a Homo-NN task.
The 'server_aggregate_procedure()' function: This function takes one parameter, an extra data dictionary that may contain additional data for a warmstart task. It is called by the server and is where you can define the aggregation process.
There is also an optional 'predict()' function that takes one parameter, a dataset, and allows you to define how your trainer makes predictions. If you want to use the FATE framework, you need to ensure that your return data is formatted correctly so that FATE can display it correctly (we will cover this in a later tutorial).
In the Homo-NN client component, the 'set_model()' function is used to set the initialized model to the trainer. When developing your trainer, you can use 'set_model()' to set your model, and then access it using 'self.model' within the trainer.
Here is the source code of these interfaces:
class TrainerBase(object):

    def __init__(self, **kwargs):
        ...
        self._model = None
        ...

    ...

    @property
    def model(self):
        if not hasattr(self, '_model'):
            raise AttributeError(
                'model variable is not initialized, remember to call'
                ' super(your_class, self).__init__()')
        if self._model is None:
            raise AttributeError(
                'model is not set, use set_model() function to set training model')
        return self._model

    @model.setter
    def model(self, val):
        self._model = val

    @abc.abstractmethod
    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
        """
        train_set : A Dataset instance, must be an instance of a subclass of Dataset
            (federatedml.nn.dataset.base), for example, TableDataset() (from federatedml.nn.dataset.table)
        validate_set : A Dataset instance, optional; must be an instance of a subclass of Dataset
            (federatedml.nn.dataset.base), for example, TableDataset() (from federatedml.nn.dataset.table)
        optimizer : A pytorch optimizer class instance, for example, t.optim.Adam(), t.optim.SGD()
        loss : A pytorch loss class instance, for example, nn.BCELoss(), nn.CrossEntropyLoss()
        """
        pass

    @abc.abstractmethod
    def predict(self, dataset):
        pass

    @abc.abstractmethod
    def server_aggregate_procedure(self, extra_data={}):
        pass
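Putting these interfaces together, the subclassing pattern looks like the sketch below. To keep it runnable without a FATE deployment, TrainerBase is replaced here by a minimal stand-in that mirrors the model property and set_model() behavior shown above; in a real project you would subclass federatedml.nn.homo.trainer.trainer_base.TrainerBase instead, and the trainer name MySimpleTrainer is hypothetical.

```python
# Minimal stand-in for TrainerBase so this sketch runs without FATE installed;
# it mirrors the model property / set_model() behavior shown above.
class TrainerBaseStub:

    def __init__(self, **kwargs):
        self._model = None

    @property
    def model(self):
        if self._model is None:
            raise AttributeError(
                'model is not set, use set_model() function to set training model')
        return self._model

    def set_model(self, model):
        self._model = model


class MySimpleTrainer(TrainerBaseStub):  # hypothetical trainer name

    def __init__(self, epochs=1):
        super().__init__()
        self.epochs = epochs

    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
        # the model set by the Homo-NN component is available as self.model
        return type(self.model).__name__

    def server_aggregate_procedure(self, extra_data={}):
        pass  # server-side aggregation would go here


trainer = MySimpleTrainer(epochs=3)
trainer.set_model(object())           # any model object
print(trainer.train(train_set=None))  # object
```

Note that accessing self.model before calling set_model() raises an AttributeError, exactly as in the TrainerBase source above.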
Fed mode/ local mode¶
The Trainer has an attribute 'self.fed_mode' which is set to True when running a federated task. You can use this variable to determine whether your trainer is running in federated mode or in local debug mode. If you want to test the trainer locally, you can use the 'local_mode()' function to set 'self.fed_mode' to False.
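The pattern above can be sketched as follows. SecureAggregatorClient is replaced by a placeholder string so the snippet runs standalone, and local_mode() is modeled as simply setting fed_mode to False, matching the behavior described above; the class name is hypothetical.

```python
class FedModeTrainerSketch:  # hypothetical stand-in, not part of FATE
    def __init__(self):
        self.fed_mode = True  # True when running a real federated task

    def local_mode(self):
        # disable federation so the trainer can be debugged locally
        self.fed_mode = False

    def train(self):
        aggregator = None
        if self.fed_mode:
            # in FATE this would be a SecureAggregatorClient instance
            aggregator = 'aggregator placeholder'
        # ... local training loop would run here ...
        return aggregator  # None in local mode, so aggregation steps are skipped


trainer = FedModeTrainerSketch()
trainer.local_mode()
print(trainer.train())  # None
```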
Example: Develop A Toy FedProx¶
To help you understand how to implement these functions, we will provide a concrete example by demonstrating a toy implementation of the FedProx algorithm from https://arxiv.org/abs/1812.06127. In FedProx, the training process is slightly different from the standard FedAVG algorithm: it requires computing a proximal term between the current model and the global model when calculating the loss. We will walk you through the code with comments step by step.
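As a quick numeric illustration of the proximal term (a sketch, assuming PyTorch is available; the layer shape and weights below are arbitrary choices for demonstration):

```python
import torch as t

# Two toy "models" with known, fixed weights: a local model and a global model
model_local = t.nn.Linear(2, 1, bias=False)
model_global = t.nn.Linear(2, 1, bias=False)
with t.no_grad():
    model_local.weight.copy_(t.tensor([[1.0, 2.0]]))
    model_global.weight.copy_(t.tensor([[0.0, 0.0]]))

def proximal_term(model_a, model_b):
    # sum of squared parameter differences; model_b is detached so the
    # global model receives no gradient
    diff = 0
    for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
        diff += ((p1 - p2.detach()) ** 2).sum()
    return diff

u = 0.1
prox = (u / 2) * proximal_term(model_local, model_global)
# ||w_local - w_global||^2 = 1^2 + 2^2 = 5, so (u/2) * 5 = 0.25
print(float(prox))  # 0.25
```

This is exactly the quantity the trainer below adds to the task loss each batch.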
Toy FedProx¶
Here is the code for the trainer, which is saved in the federatedml.nn.homo.trainer module. This trainer implements two functions: train and server_aggregate_procedure. These functions enable the completion of a simple training task. The code includes comments to provide further details.
from pipeline.component.nn import save_to_fate
%%save_to_fate trainer fedprox.py
import copy
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from torch.utils.data import DataLoader
# We need to use the aggregator client & server classes for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGER


class ToyFedProxTrainer(TrainerBase):

    def __init__(self, epochs, batch_size, u):
        super(ToyFedProxTrainer, self).__init__()
        # trainer parameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.u = u

    # Given two models, compute the proximal term
    def _proximal_term(self, model_a, model_b):
        diff_ = 0
        for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
            diff_ += ((p1 - p2.detach()) ** 2).sum()
        return diff_

    # Implement the train function; this function will be called by the client side.
    # It contains the local training process and the federation part.
    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):

        sample_num = len(train_set)
        aggregator = None
        if self.fed_mode:
            aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num,
                                                communicate_match_suffix='fedprox')  # initialize aggregator

        # set up the dataloader
        dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)

        for epoch in range(self.epochs):

            # the local training process
            LOGGER.debug('running epoch {}'.format(epoch))
            global_model = copy.deepcopy(self.model)
            loss_sum = 0

            # batch training process
            for batch_data, label in dl:
                optimizer.zero_grad()
                pred = self.model(batch_data)
                loss_term_a = loss(pred, label)
                loss_term_b = self._proximal_term(self.model, global_model)
                loss_ = loss_term_a + (self.u / 2) * loss_term_b
                loss_.backward()
                loss_sum += float(loss_.detach().numpy())
                optimizer.step()

            # print loss
            LOGGER.debug('epoch loss is {}'.format(loss_sum))

            # the aggregation process
            if aggregator is not None:
                self.model = aggregator.model_aggregation(self.model)
                converge_status = aggregator.loss_aggregation(loss_sum)

    # Implement the aggregation function; this function will be called by the server side
    def server_aggregate_procedure(self, extra_data={}):

        # initialize aggregator
        if self.fed_mode:
            aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')
            # the aggregation process is simple: every epoch the server aggregates model and loss once
            for i in range(self.epochs):
                aggregator.model_aggregation()
                merge_loss, _ = aggregator.loss_aggregation()
Local Test¶
We can use local_mode() to locally test our new FedProx trainer.
import torch as t
from federatedml.nn.dataset.table import TableDataset
model = t.nn.Sequential(
    t.nn.Linear(30, 1),
    t.nn.Sigmoid()
)
ds = TableDataset()
ds.load('../../../../examples/data/breast_homo_guest.csv')
trainer = ToyFedProxTrainer(10, 128, u=0.1)
trainer.set_model(model)
opt = t.optim.Adam(model.parameters(), lr=0.01)
loss = t.nn.BCELoss()
trainer.local_mode()
trainer.train(ds, None, opt, loss)
running epoch 0
epoch loss is 1.0665020644664764
running epoch 1
epoch loss is 0.9155551195144653
running epoch 2
epoch loss is 0.8021544218063354
running epoch 3
epoch loss is 0.7173515558242798
running epoch 4
epoch loss is 0.6532197296619415
running epoch 5
epoch loss is 0.6034933030605316
running epoch 6
epoch loss is 0.5636875331401825
running epoch 7
epoch loss is 0.5307579338550568
running epoch 8
epoch loss is 0.5026698857545853
running epoch 9
epoch loss is 0.47806812822818756
Great, it works! Next, we will submit a federated task to see if our trainer works correctly.
Submit a New Task to Test ToyFedProx¶
# torch
import torch as t
from torch import nn
from pipeline import fate_torch_hook
fate_torch_hook(t)
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam # HomoNN Component, TrainerParam for setting trainer parameter
from pipeline.backend.pipeline import PipeLine # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation
from pipeline.interface import Data # Data Interfaces for defining data flow
# create a pipeline for submitting the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)
# read uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)
# The transform component converts the uploaded data to the FATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(
    role='guest', party_id=guest).component_param(
    with_label=True, output_format="dense")
data_transform_0.get_party_instance(
    role='host', party_id=host).component_param(
    with_label=True, output_format="dense")
"""
Define Pytorch model/ optimizer and loss
"""
model = nn.Sequential(
    nn.Linear(30, 1),
    nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)
"""
Create Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',
                      model=model,  # set model
                      loss=loss,  # set loss
                      optimizer=optimizer,  # set optimizer
                      # Here we use our toy FedProx trainer;
                      # TrainerParam passes parameters to it, and trainer_name='fedprox' matches the saved fedprox.py module
                      trainer=TrainerParam(trainer_name='fedprox', epochs=3, batch_size=128, u=0.5),
                      torch_seed=100  # random seed
                      )
# define work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.compile()
pipeline.fit()
2022-12-26 12:22:09.789 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212261222090031360
2022-12-26 12:22:09.821 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2022-12-26 12:22:11.892 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2022-12-26 12:22:22.750 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component data_transform_0, time elapse: 0:00:12
2022-12-26 12:22:37.788 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:27
2022-12-26 12:22:59.842 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212261222090031360
2022-12-26 12:22:59.847 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:50
Yes! This trainer can work correctly. In the next tutorial, we will show you how to use the trainer user interfaces to improve this trainer. These interfaces allow you to return formatted prediction results, evaluate your model, save your model, and display loss curves and performance scores on the fateboard. By using these interfaces, you can enhance the functionality of the trainer and make it more user-friendly.