Using FATE-interfaces¶
In this tutorial, we will demonstrate how to use the trainer user interfaces to return formatted prediction results, evaluate the performance of your model, save your model, and display loss curves and performance scores on the dashboard. These interfaces allow your trainer to integrate with the FATE framework and make it easier to work with.
In this tutorial, we will continue to develop our toy FedProx trainer.
The Toy Implementation of FedProx¶
In the last tutorial, we provided a concrete example by demonstrating a toy implementation of the FedProx algorithm from https://arxiv.org/abs/1812.06127. In FedProx, the training process differs slightly from the standard FedAVG algorithm, as it requires computing a proximal term from the current model and the global model when calculating the loss. The code is here:
from pipeline.component.nn import save_to_fate
%%save_to_fate trainer fedprox.py
import copy
import torch as t
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGER


class ToyFedProxTrainer(TrainerBase):

    def __init__(self, epochs, batch_size, u):
        super(ToyFedProxTrainer, self).__init__()
        # trainer parameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.u = u

    # Given two models, we compute the proximal term
    def _proximal_term(self, model_a, model_b):
        diff_ = 0
        for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
            diff_ += (p1 - p2.detach()).sum()
        return diff_

    # implement the train function; this function will be called by the client side
    # and contains the local training process and the federation part
    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
        sample_num = len(train_set)
        aggregator = None
        if self.fed_mode:
            aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num,
                                                communicate_match_suffix='fedprox')  # initialize aggregator

        # set up the dataloader
        dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)
        for epoch in range(self.epochs):
            # the local training process
            LOGGER.debug('running epoch {}'.format(epoch))
            global_model = copy.deepcopy(self.model)
            loss_sum = 0
            # batch training process
            for batch_data, label in dl:
                optimizer.zero_grad()
                pred = self.model(batch_data)
                loss_term_a = loss(pred, label)
                loss_term_b = self._proximal_term(self.model, global_model)
                loss_ = loss_term_a + (self.u / 2) * loss_term_b
                loss_.backward()
                loss_sum += float(loss_.detach().numpy())
                optimizer.step()
            # print loss
            LOGGER.debug('epoch loss is {}'.format(loss_sum))

            # the aggregation process
            if aggregator is not None:
                self.model = aggregator.model_aggregation(self.model)
                converge_status = aggregator.loss_aggregation(loss_sum)

    # implement the aggregation function; this function will be called by the server side
    def server_aggregate_procedure(self, extra_data={}):
        # initialize aggregator
        if self.fed_mode:
            aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')

        # the aggregation process is simple: every epoch the server aggregates the model and loss once
        for i in range(self.epochs):
            aggregator.model_aggregation()
            merge_loss, _ = aggregator.loss_aggregation()
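To make the objective concrete, here is a minimal pure-Python sketch of how FedProx combines the task loss with the proximal penalty, with plain lists standing in for model parameters (all names here are illustrative, not FATE APIs). Note that the toy `_proximal_term` above sums raw parameter differences, while the canonical FedProx penalty is the squared L2 distance, which this sketch uses:

```python
# Sketch of the FedProx objective: loss = task_loss + (u / 2) * prox,
# where prox penalizes the distance between the local and the global model.

def proximal_term(local_weights, global_weights):
    # squared L2 distance between the local and the (frozen) global weights
    return sum((a - b) ** 2 for a, b in zip(local_weights, global_weights))

def fedprox_loss(task_loss, local_weights, global_weights, u):
    return task_loss + (u / 2) * proximal_term(local_weights, global_weights)

local_w = [1.0, 2.0, 3.0]
global_w = [1.0, 1.0, 1.0]
# prox = 0 + 1 + 4 = 5; with u = 0.5 the penalty is 0.25 * 5 = 1.25
print(fedprox_loss(0.7, local_w, global_w, u=0.5))  # 1.95
```

A larger `u` pulls the local model more strongly toward the global model, which is the knob FedProx offers against client drift.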
User interfaces¶
Now we introduce the user interfaces offered by the TrainerBase class; we will use these functions to improve our trainer.
format_predict_result¶
This function organizes your prediction results and returns a StdReturnFormat object, which wraps the results. You can use this function at the end of your prediction function to return a standardized format that the FATE framework can parse and display on the fateboard. This standardized format also allows downstream components, such as the evaluation component, to use the prediction results.
This function takes four arguments:
- sample_ids: a list of sample IDs
- predict_result: a tensor of prediction scores
- true_label: a tensor of true labels
- task_type: the type of task being performed. The default is 'auto', which infers the task type automatically; other options are 'binary', 'multi', and 'regression'. Currently, the fateboard only supports displaying binary/multi classification and regression tasks.
We will implement a predict function in our FedProx trainer later.
import torch as t
from typing import List
def format_predict_result(self, sample_ids: List, predict_result: t.Tensor,
                          true_label: t.Tensor, task_type: str = None):
    ...
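The essential job of the interface is to align the three inputs sample by sample. The sketch below is a hypothetical stand-in (the real StdReturnFormat internals are not shown in this tutorial) that only illustrates the per-sample alignment the arguments imply:

```python
# Hypothetical helper: align sample ids, scores and labels row by row,
# the way format_predict_result's arguments suggest. Not the real FATE format.
def to_rows(sample_ids, predict_scores, true_labels, threshold=0.5):
    rows = []
    for sid, score, label in zip(sample_ids, predict_scores, true_labels):
        rows.append({
            'sample_id': sid,
            'predict_score': score,
            'predict_label': 1 if score > threshold else 0,  # binary task
            'true_label': label,
        })
    return rows

rows = to_rows([0, 1], [0.9, 0.2], [1, 0])
print(rows[0]['predict_label'], rows[1]['predict_label'])  # 1 0
```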
callback_metric & callback_loss¶
As the names suggest, these two functions enable you to save data points and display custom evaluation metrics and loss curves on the fateboard.
When using the callback metric function, you need to provide the metric name, a float value, and specify the metric type ('train' or 'validate') and the epoch index. When using the callback loss function, you need to provide a float loss value and the epoch index. Your data will be displayed on the fateboard.
def callback_metric(self, metric_name: str, value: float, metric_type='train', epoch_idx=0):
    ...

def callback_loss(self, loss: float, epoch_idx: int):
    ...
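Conceptually, each callback just appends a data point to a named series that the fateboard can plot. A minimal mock (not the real TrainerBase, purely illustrative) of that pattern:

```python
# Mock of the callback pattern: every call records one (epoch, value) point
# in a named series; a dashboard would then plot each series as a curve.
class CallbackRecorder:
    def __init__(self):
        self.metrics = {}     # (metric_name, metric_type) -> [(epoch, value)]
        self.loss_curve = []  # [(epoch, loss)]

    def callback_metric(self, metric_name, value, metric_type='train', epoch_idx=0):
        self.metrics.setdefault((metric_name, metric_type), []).append((epoch_idx, value))

    def callback_loss(self, loss, epoch_idx):
        self.loss_curve.append((epoch_idx, loss))

rec = CallbackRecorder()
for epoch, (loss, acc) in enumerate([(0.9, 0.6), (0.5, 0.8)]):
    rec.callback_loss(loss, epoch)
    rec.callback_metric('my_accuracy', acc, epoch_idx=epoch)
print(rec.loss_curve)  # [(0, 0.9), (1, 0.5)]
```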
summary¶
This function allows you to save a summary of the training process, such as the loss history and the best epoch, in a dictionary. You can retrieve this summary from the pipeline once the task is completed.
def summary(self, summary_dict: dict):
    ...
save & checkpoint¶
You can save your models using the 'save' function and set model checkpoints using the 'checkpoint' function. It is important to note that:
- 'save' only stores the model in memory, so the model that is eventually exported is the one from the last call to 'save'.
- 'checkpoint' writes the model directly to disk.
- 'save' should only be called on the client side (in the 'train' function), while 'checkpoint' should be called on both the client and server side (in the 'train' and 'server_aggregate_procedure' functions) so that the checkpoint mechanism works correctly.
The 'extra_data' parameter in the function allows you to save additional data in a dictionary. This can be useful when warm-starting a model, as you can retrieve the saved data using the 'extra_data' parameter in the 'train' and 'server_aggregate_procedure' functions.
def save(self,
         model=None,
         epoch_idx=-1,
         optimizer=None,
         converge_status=False,
         loss_history=None,
         best_epoch=-1,
         extra_data={}):
    ...

def checkpoint(self,
               epoch_idx,
               model=None,
               optimizer=None,
               converge_status=False,
               loss_history=None,
               best_epoch=-1,
               extra_data={}):
    ...
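The difference between the two can be sketched with a toy stand-in: an in-memory slot that every call to 'save' overwrites, versus one file per epoch written by 'checkpoint'. This mock is only an illustration of the semantics, not the real FATE persistence layer:

```python
import json
import os
import tempfile

class ToyPersistence:
    def __init__(self, ckpt_dir):
        self.saved_model = None  # in-memory slot; only the last save survives
        self.ckpt_dir = ckpt_dir

    def save(self, model, epoch_idx=-1):
        self.saved_model = (epoch_idx, model)  # overwrites the previous save

    def checkpoint(self, epoch_idx, model):
        path = os.path.join(self.ckpt_dir, 'epoch_{}.json'.format(epoch_idx))
        with open(path, 'w') as f:  # each checkpoint is its own file on disk
            json.dump(model, f)
        return path

with tempfile.TemporaryDirectory() as d:
    p = ToyPersistence(d)
    for epoch in range(3):
        model = {'w': epoch * 0.1}
        p.save(model, epoch)
        p.checkpoint(epoch, model)
    print(p.saved_model[0])    # 2: only the last save remains in memory
    print(len(os.listdir(d)))  # 3: one checkpoint file per epoch
```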
local_save and local_checkpoint¶
In the FATE framework, the standard format for saving models is protobuf. However, local_save and local_checkpoint are designed for situations where a model is too large to fit in a protobuf. In such cases, the model is saved to a local path, and that path is recorded in the model protobuf. The 'local_save' and 'local_checkpoint' functions save the model and other related information to the local file system. The file path is structured as follows:
'fateflow/jobs/${jobid}/${party}/${party_id}/${your_nn_component}'
where 'jobid' is the unique identifier for the training job, 'party' is the role of the party (e.g., host or guest), 'party_id' is the ID of the party, and 'your_nn_component' is the name of your neural network component.
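As an illustration, the path pattern can be filled in like so, reusing the job id from the run at the end of this tutorial and the guest party defined in the pipeline below:

```python
# Filling in the local-save path pattern; the concrete values are examples.
path_pattern = 'fateflow/jobs/{jobid}/{party}/{party_id}/{component}'
path = path_pattern.format(jobid='202212261613368298170', party='guest',
                           party_id=9999, component='nn_0')
print(path)  # fateflow/jobs/202212261613368298170/guest/9999/nn_0
```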
It is important to note that 'local_save' is intended for use only within the client-side code and should not be used in the 'server_aggregate_procedure' function. The 'local_checkpoint' function can be used within both the client- and server-side code.
def local_save(self,
               model=None,
               epoch_idx=-1,
               optimizer=None,
               converge_status=False,
               loss_history=None,
               best_epoch=-1,
               extra_data={}):
    ...

def local_checkpoint(self,
                     model=None,
                     epoch_idx=-1,
                     optimizer=None,
                     converge_status=False,
                     loss_history=None,
                     best_epoch=-1,
                     extra_data={}):
    ...
evaluation¶
This interface allows you to evaluate your model by automatically computing various performance metrics. The metrics that are computed depend on the type of your dataset and task:
- Binary classification: 'AUC' and 'ks'
- Multi-class classification: 'accuracy', 'precision', and 'recall'
- Regression: 'rmse' and 'mae'
You can specify the type of your dataset ('train' or 'validate') and the task type ('binary', 'multi', or 'regression') in the parameters. If no task type is specified, it will be automatically inferred from your scores and labels.
def evaluation(self, sample_ids: list, pred_scores: t.Tensor, label: t.Tensor, dataset_type='train',
               epoch_idx=0, task_type=None):
    ...
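To illustrate what two of these metrics measure, here is a minimal pure-Python computation of a binary accuracy-style score and an 'rmse'. The real interface computes its metrics internally; this is only a sketch of the underlying arithmetic:

```python
import math

def binary_accuracy(scores, labels, threshold=0.5):
    # threshold the scores into hard predictions, then count matches
    preds = [1 if s > threshold else 0 for s in scores]
    return sum(p == l for p, l in zip(preds, labels)) / len(labels)

def rmse(preds, labels):
    # root mean squared error, the regression metric
    return math.sqrt(sum((p - l) ** 2 for p, l in zip(preds, labels)) / len(labels))

print(binary_accuracy([0.9, 0.3, 0.6, 0.2], [1, 0, 0, 0]))  # 0.75
print(rmse([1.0, 2.0], [1.0, 4.0]))  # sqrt(2) ~= 1.414
```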
Improved FedProx Trainer¶
In this section, we will use the interfaces introduced earlier to improve our FedProx trainer and make it a more comprehensive training tool. We:
- implement the predict function so that it returns formatted prediction results
- add an evaluation step
- save the model at the end of training
- call back the loss to display loss curves
- compute accuracy scores and display them using callback_metric
from pipeline.component.nn import save_to_fate
%%save_to_fate trainer fedprox_v2.py
import copy
import torch as t
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.nn.dataset.base import Dataset
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGER


class ToyFedProxTrainer(TrainerBase):

    def __init__(self, epochs, batch_size, u):
        super(ToyFedProxTrainer, self).__init__()
        # trainer parameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.u = u

    # Given two models, we compute the proximal term
    def _proximal_term(self, model_a, model_b):
        diff_ = 0
        for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
            diff_ += t.sqrt((p1 - p2.detach()) ** 2).sum()
        return diff_

    # implement the train function; this function will be called by the client side
    # and contains the local training process and the federation part
    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
        sample_num = len(train_set)
        aggregator = None
        if self.fed_mode:
            aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num,
                                                communicate_match_suffix='fedprox')  # initialize aggregator

        # set up the dataloader
        dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)
        loss_history = []
        for epoch in range(self.epochs):
            # the local training process
            LOGGER.debug('running epoch {}'.format(epoch))
            global_model = copy.deepcopy(self.model)
            loss_sum = 0
            # batch training process
            for batch_data, label in dl:
                optimizer.zero_grad()
                pred = self.model(batch_data)
                loss_term_a = loss(pred, label)
                loss_term_b = self._proximal_term(self.model, global_model)
                loss_ = loss_term_a + (self.u / 2) * loss_term_b
                LOGGER.debug('loss is {} loss a is {} loss b is {}'.format(loss_, loss_term_a, loss_term_b))
                loss_.backward()
                loss_sum += float(loss_.detach().numpy())
                optimizer.step()
            # print loss
            LOGGER.debug('epoch loss is {}'.format(loss_sum))
            loss_history.append(loss_sum)
            # we callback the loss here so that it is displayed on the fateboard
            self.callback_loss(loss_sum, epoch)

            # we evaluate our model here
            sample_ids, preds, labels = self._predict(train_set)
            self.evaluation(sample_ids, preds, labels, 'train', task_type='binary', epoch_idx=epoch)
            # we manually compute accuracy:
            acc = (((preds > 0.5) + 0) == labels).sum() / len(labels)
            acc = float(acc.detach().numpy())
            self.callback_metric('my_accuracy', acc, epoch_idx=epoch)

            # the aggregation process
            if aggregator is not None:
                self.model = aggregator.model_aggregation(self.model)
                converge_status = aggregator.loss_aggregation(loss_sum)

        # we save the model at the end of the training
        self.save(self.model, epoch, optimizer)
        # we save the model summary
        self.summary({'loss_history': loss_history})

    # implement the aggregation function; this function will be called by the server side
    def server_aggregate_procedure(self, extra_data={}):
        # initialize aggregator
        if self.fed_mode:
            aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')

        # the aggregation process is simple: every epoch the server aggregates the model and loss once
        for i in range(self.epochs):
            aggregator.model_aggregation()
            merge_loss, _ = aggregator.loss_aggregation()

    def _predict(self, dataset: Dataset):
        len_ = len(dataset)
        dl = DataLoader(dataset, batch_size=len_)
        preds, labels = None, None
        for data, l in dl:
            preds = self.model(data)
            labels = l
        sample_ids = dataset.get_sample_ids()
        return sample_ids, preds, labels

    # we implement the predict function here; it returns formatted prediction results
    def predict(self, dataset):
        sample_ids, preds, labels = self._predict(dataset)
        return self.format_predict_result(sample_ids, preds, labels, 'binary')
Submit a Pipeline¶
Here we submit a new pipeline to test our new trainer.
# torch
import torch as t
from torch import nn
from pipeline import fate_torch_hook
fate_torch_hook(t)
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam  # HomoNN Component, TrainerParam for setting trainer parameters
from pipeline.backend.pipeline import PipeLine  # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation  # Data I/O and Evaluation
from pipeline.interface import Data  # Data Interfaces for defining data flow

# create a pipeline for submitting the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)

# read the uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)

# The transform component converts the uploaded data to the FATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(
    role='guest', party_id=guest).component_param(
    with_label=True, output_format="dense")
data_transform_0.get_party_instance(
    role='host', party_id=host).component_param(
    with_label=True, output_format="dense")

"""
Define the Pytorch model, optimizer and loss
"""
model = nn.Sequential(
    nn.Linear(30, 1),
    nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)

"""
Create the Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',
                      model=model,  # set model
                      loss=loss,  # set loss
                      optimizer=optimizer,  # set optimizer
                      # Here we use our fedprox_v2 trainer
                      # TrainerParam passes parameters to the trainer
                      trainer=TrainerParam(trainer_name='fedprox_v2', epochs=3, batch_size=128, u=0.5),
                      torch_seed=100  # random seed
                      )

# define the work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.compile()
pipeline.fit()
2022-12-26 16:13:37.650 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212261613368298170
2022-12-26 16:13:38.700 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:01
2022-12-26 16:13:48.160 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component data_transform_0, time elapse: 0:00:10
2022-12-26 16:13:58.791 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:21
2022-12-26 16:14:17.720 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212261613368298170
2022-12-26 16:14:17.721 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:40