Homo-NN Customize Model
Build a model
In FATE 1.10, you can use the pipeline to submit a PyTorch Sequential model. However, the Sequential model, in combination with PyTorch's built-in layers, may not be sufficient for representing more complex models. For instance, when constructing a residual module similar to those found in ResNet, the output of some modules needs to be reused, which may not be possible using the Sequential model.
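The limitation can be seen without any framework. The following is a minimal plain-Python sketch, not PyTorch code: `sequential` and `residual` are hypothetical helper names, and layers are modeled as functions on lists of floats. A sequential container only hands each output to the next layer, while a residual block has to keep its input so it can add it back.

```python
from functools import reduce


def sequential(*layers):
    """Chain layers so each one receives only the previous layer's output."""
    def run(x):
        return reduce(lambda acc, layer: layer(acc), layers, x)
    return run


def residual(inner):
    """Wrap `inner` so its input is reused: output = x + inner(x)."""
    def run(x):
        return [a + b for a, b in zip(x, inner(x))]
    return run


def double(x):
    return [2.0 * v for v in x]


def add_one(x):
    return [v + 1.0 for v in x]


plain = sequential(double, add_one)                # add_one(double(x))
with_skip = sequential(residual(double), add_one)  # add_one(x + double(x))

plain_out = plain([1.0, 2.0])
skip_out = with_skip([1.0, 2.0])
```

Packaging the skip connection as its own unit is what makes it usable inside a sequential chain; a custom `torch.nn.Module` with its own `forward` plays the same role in PyTorch.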
To address this issue, FATE 1.10 introduced the model_zoo module, located under federatedml.nn.model_zoo. It lets you define your own PyTorch model, provided that the model subclasses torch.nn.Module and implements the forward method. For more information on custom modules, see the torch.nn.Module section of the PyTorch documentation. To use a custom model in a federated task, place its source file in the federatedml/nn/model_zoo directory and specify the module name and model class name through the interface when submitting the task. Homo-NN will then locate and import the model automatically.
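Looking a class up by module name and class name can be sketched with the standard library. This illustrates the general mechanism only and is not FATE's actual loader; `load_class` is a hypothetical helper, and the demo resolves a standard-library class so that it runs without FATE installed.

```python
import importlib


def load_class(module_name, class_name, package_prefix=""):
    """Import `package_prefix + module_name` and return the named class."""
    module = importlib.import_module(package_prefix + module_name)
    try:
        return getattr(module, class_name)
    except AttributeError:
        raise ImportError(f"{class_name!r} not found in module {module.__name__!r}")


# Demo with a standard-library class. For a file saved in model_zoo the call
# would presumably look like (assumption, not verified against FATE):
# load_class("image_net", "ImgNet", "federatedml.nn.model_zoo.")
cls = load_class("collections", "OrderedDict")
instance = cls(a=1)  # keyword arguments play the role of init parameters
```

Because the class is resolved by name at run time, a typo in either name only surfaces when the job starts, so it is worth checking both strings against the file you saved.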
As an example, consider the task of MNIST handwriting recognition. We can first write a simple neural network module with residual connections locally, and then use it in a federated task.
A Customized Model
Name the model file image_net.py. You can either place it directly under federatedml/nn/model_zoo, or use the save_to_fate shortcut in a Jupyter notebook to save it to federatedml/nn/model_zoo.
from pipeline.component.nn import save_to_fate
%%save_to_fate model image_net.py
import torch as t
from torch import nn
from torch.nn import Module

# the residual component
class Residual(Module):

    def __init__(self, ch, kernel_size=3, padding=1):
        super(Residual, self).__init__()
        self.convs = t.nn.ModuleList([nn.Conv2d(ch, ch, kernel_size=kernel_size, padding=padding) for i in range(2)])
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.convs[0](x))
        x_ = self.convs[1](x)
        return self.act(x + x_)


# we call it image net
class ImgNet(nn.Module):

    def __init__(self, class_num=10):
        super(ImgNet, self).__init__()
        self.seq = t.nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5),
            Residual(12),
            nn.MaxPool2d(kernel_size=3),
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3),
            Residual(12),
            nn.AvgPool2d(kernel_size=3)
        )

        self.fc = t.nn.Sequential(
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, class_num)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.seq(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        # return raw logits while training, probabilities otherwise
        if self.training:
            return x
        else:
            return self.softmax(x)
img_model = ImgNet(10)
img_model
ImgNet(
  (seq): Sequential(
    (0): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
    (1): Residual(
      (convs): ModuleList(
        (0): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (act): ReLU()
    )
    (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1))
    (4): Residual(
      (convs): ModuleList(
        (0): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (act): ReLU()
    )
    (5): AvgPool2d(kernel_size=3, stride=3, padding=0)
  )
  (fc): Sequential(
    (0): Linear(in_features=48, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=10, bias=True)
  )
  (softmax): Softmax(dim=1)
)
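The `in_features=48` of the first linear layer follows from the spatial size left after the convolution and pooling stack. The arithmetic can be checked by hand, assuming 28×28 inputs (the usual MNIST size; the input resolution is not stated in this tutorial) and PyTorch's default strides (1 for convolutions, the kernel size for pooling). `conv_out` and `pool_out` are hypothetical helper names.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output length of a pooling layer whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


size = 28                                   # assumed input height and width
size = conv_out(size, kernel=5)             # Conv2d(3, 12, 5): 28 -> 24
size = conv_out(size, kernel=3, padding=1)  # Residual(12): each 3x3 conv with padding 1 keeps 24
size = pool_out(size, kernel=3)             # MaxPool2d(3): 24 -> 8
size = conv_out(size, kernel=3)             # Conv2d(12, 12, 3): 8 -> 6
size = conv_out(size, kernel=3, padding=1)  # Residual(12): stays 6
size = pool_out(size, kernel=3)             # AvgPool2d(3): 6 -> 2

channels = 12
flat_features = channels * size * size      # 12 * 2 * 2 = 48
```

If the input resolution or any kernel size changes, the first argument of `nn.Linear(48, 32)` has to be recomputed the same way.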
from federatedml.nn.dataset.image import ImageDataset
ds = ImageDataset()
ds.load('../../../../examples/data/mnist/')
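# `i` is not defined in the original snippet; it is assumed here to be one batch of 8 samples, which matches the 8-row output
from torch.utils.data import DataLoader
i = next(iter(DataLoader(ds, batch_size=8)))
img_model(i[0])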
tensor([[ 1.3241e-01,  1.3432e-01,  3.6705e-02,  3.9092e-02, -1.2944e-01,  5.4261e-02, -1.8467e-01,  1.0478e-01,  1.0396e-03,  4.6396e-02],
        [ 1.3575e-01,  1.3287e-01,  3.7010e-02,  3.5438e-02, -1.3169e-01,  4.9747e-02, -1.8520e-01,  1.0215e-01,  3.3909e-03,  4.6577e-02],
        [ 1.3680e-01,  1.3542e-01,  3.6674e-02,  3.4830e-02, -1.3046e-01,  4.8866e-02, -1.8568e-01,  1.0199e-01,  4.7719e-03,  4.7090e-02],
        [ 1.3564e-01,  1.3297e-01,  3.6487e-02,  3.5213e-02, -1.3040e-01,  5.0300e-02, -1.8406e-01,  1.0286e-01,  3.6997e-03,  4.4414e-02],
        [ 1.3091e-01,  1.3101e-01,  3.5820e-02,  3.9637e-02, -1.3302e-01,  5.2289e-02, -1.8336e-01,  1.0439e-01,  2.8879e-03,  4.4465e-02],
        [ 1.3206e-01,  1.3344e-01,  3.7300e-02,  3.8817e-02, -1.3155e-01,  5.3004e-02, -1.8556e-01,  1.0341e-01,  7.9196e-05,  4.6511e-02],
        [ 1.3058e-01,  1.3162e-01,  3.5691e-02,  4.0402e-02, -1.3395e-01,  5.1268e-02, -1.8198e-01,  1.0670e-01,  3.6078e-03,  4.4348e-02],
        [ 1.3416e-01,  1.3208e-01,  3.6845e-02,  3.6941e-02, -1.3210e-01,  5.2559e-02, -1.8635e-01,  1.0151e-01,  1.1148e-03,  4.7174e-02]],
       grad_fn=<AddmmBackward0>)
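The values above are raw logits rather than probabilities: a newly constructed `nn.Module` is in training mode, so `forward` skips the softmax branch. Calling `img_model.eval()` first would return probabilities instead. As a plain-Python sketch of what that branch computes, the first row of the output (rounded to four decimals) can be passed through softmax by hand; `softmax` here is a local helper, not the PyTorch function.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# First row of the logits printed above, rounded to four decimals.
row = [0.1324, 0.1343, 0.0367, 0.0391, -0.1294,
       0.0543, -0.1847, 0.1048, 0.0010, 0.0464]
probs = softmax(row)
predicted_class = probs.index(max(probs))
```

Softmax preserves the ordering of the logits, so the predicted class is the same either way. Returning logits during training fits `CrossEntropyLoss`, which expects unnormalized scores.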
Run a Local Test
We can use our dataset, custom model, and trainer to debug locally and check that the program runs end to end. In local mode, all federation steps are skipped and the model does not take part in federated averaging.
import torch as t
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
trainer = FedAVGTrainer(epochs=3, batch_size=256, shuffle=True, data_loader_worker=8, pin_memory=False)
trainer.set_model(img_model) # set model
trainer.local_mode() # !! use local mode to skip federation process !!
optimizer = t.optim.Adam(img_model.parameters(), lr=0.01)
loss = t.nn.CrossEntropyLoss()
trainer.train(train_set=ds, optimizer=optimizer, loss=loss)
epoch is 0
100%|██████████| 6/6 [00:00<00:00, 7.00it/s]
epoch loss is 1.732767325125185
epoch is 1
100%|██████████| 6/6 [00:01<00:00, 4.28it/s]
epoch loss is 0.9436628721978848
epoch is 2
100%|██████████| 6/6 [00:00<00:00, 6.72it/s]
epoch loss is 0.6457311573421982
It works! Now we can submit a federated task.
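Unlike the local test, a federated job aggregates the models trained by the parties. The aggregation step that local mode skips can be sketched as a weighted average of the parties' parameters. This is a simplified plain-Python illustration of federated averaging, not FATE's implementation; `fed_avg` is a hypothetical helper, and weighting by local sample count is an assumption taken from the standard FedAvg formulation.

```python
def fed_avg(party_params, sample_counts):
    """Average parameter vectors, weighting each party by its sample count."""
    total = sum(sample_counts)
    dim = len(party_params[0])
    averaged = [0.0] * dim
    for params, count in zip(party_params, sample_counts):
        weight = count / total
        for j, value in enumerate(params):
            averaged[j] += weight * value
    return averaged


# Toy parameter vectors standing in for the guest and host models.
guest_params = [1.0, 2.0]
host_params = [3.0, 6.0]
global_params = fed_avg([guest_params, host_params], sample_counts=[100, 300])
```

With these toy numbers the host holds three quarters of the samples, so the averaged parameters sit three quarters of the way from the guest's values to the host's.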
Submit a Homo-NN Task with Custom Model
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
t = fate_torch_hook(t)
import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('../../../../')
host = 10000
guest = 9999
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)
data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}
data_path_0 = fate_project_path + '/examples/data/mnist'
data_path_1 = fate_project_path + '/examples/data/mnist'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)
{'namespace': 'experiment', 'table_name': 'mnist_host'}
# define the reader
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)
nn.CustModel
After calling fate_torch_hook, we can use t.nn.CustModel to specify the model. Pass the module name and the class name here; model initialization parameters can be set as additional keyword arguments. The initialization parameters must be JSON-serializable; otherwise the pipeline cannot be submitted.
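Whether a set of initialization parameters is JSON-serializable can be checked locally before building the pipeline. A minimal sketch using the standard `json` module; `check_json_serializable` is a hypothetical helper and not part of the FATE API, and the parameter names other than `class_num` are made up for illustration.

```python
import json


def check_json_serializable(**init_params):
    """Return the names of parameters that json.dumps cannot encode."""
    bad = []
    for name, value in init_params.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            bad.append(name)
    return bad


# Numbers, strings, lists and dicts of those are fine.
ok = check_json_serializable(class_num=10, hidden=[32, 16])
# Functions and sets are not JSON-serializable.
not_ok = check_json_serializable(class_num=10, activation=abs, dims={1, 2})
```

In practice this means passing plain values such as a layer width or an activation name as a string, and constructing the actual objects inside the model's `__init__`.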
from pipeline.component.homo_nn import DatasetParam, TrainerParam
model = t.nn.Sequential(
    # class_num=10 is the initialization parameter for your model
    t.nn.CustModel(module_name='image_net', class_name='ImgNet', class_num=10)
)

nn_component = HomoNN(name='nn_0',
                      model=model,  # your custom model
                      loss=t.nn.CrossEntropyLoss(),
                      optimizer=t.optim.Adam(model.parameters(), lr=0.01),
                      dataset=DatasetParam(dataset_name='image'),  # use the image dataset
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=3, batch_size=1024, validation_freqs=1),
                      torch_seed=100  # global random seed
                      )
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))
<pipeline.backend.pipeline.PipeLine at 0x7fafb4d99370>
pipeline.compile()
pipeline.fit()
2022-12-19 22:07:14.965 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212192207123770130
2022-12-19 22:07:14.974 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2022-12-19 22:07:15.990 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:01
2022-12-19 22:07:17.016 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2022-12-19 22:07:17.019 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2022-12-19 22:07:18.047 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:03
2022-12-19 22:07:19.066 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:04
2022-12-19 22:07:20.085 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:05
2022-12-19 22:07:21.117 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:06
2022-12-19 22:07:22.145 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:07
2022-12-19 22:07:24.547 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2022-12-19 22:07:24.552 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:09
2022-12-19 22:07:25.598 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:10
2022-12-19 22:07:26.654 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:11
2022-12-19 22:07:27.679 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:12
2022-12-19 22:07:28.736 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:13
2022-12-19 22:07:29.762 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:14
2022-12-19 22:07:30.795 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:15
2022-12-19 22:07:31.877 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:16
2022-12-19 22:07:33.008 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:18
2022-12-19 22:07:34.045 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:19
2022-12-19 22:07:35.074 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:20
2022-12-19 22:07:36.170 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:21
2022-12-19 22:07:37.332 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:22
2022-12-19 22:07:38.620 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:23
2022-12-19 22:07:39.759 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:24
2022-12-19 22:07:40.844 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:25
2022-12-19 22:07:41.969 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:27
2022-12-19 22:07:42.992 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:28
2022-12-19 22:07:44.132 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:29
2022-12-19 22:07:45.206 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:30
2022-12-19 22:07:46.239 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:31
2022-12-19 22:07:47.350 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:32
2022-12-19 22:07:48.424 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:33
2022-12-19 22:07:49.509 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:34
2022-12-19 22:07:50.618 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:35
2022-12-19 22:07:51.685 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:36
2022-12-19 22:07:52.744 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:37
2022-12-19 22:07:53.842 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:38
2022-12-19 22:07:54.920 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:39
2022-12-19 22:07:56.194 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:41
2022-12-19 22:07:57.318 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:42
2022-12-19 22:07:58.388 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:43
2022-12-19 22:07:59.449 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:44
2022-12-19 22:08:00.494 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:45
2022-12-19 22:08:01.567 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:46
2022-12-19 22:08:02.670 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:47
2022-12-19 22:08:03.754 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:48
2022-12-19 22:08:04.836 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:49
2022-12-19 22:08:05.866 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:50
2022-12-19 22:08:06.887 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:51
2022-12-19 22:08:07.954 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2022-12-19 22:08:07.956 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:52
2022-12-19 22:08:09.001 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:54
2022-12-19 22:08:10.025 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:55
2022-12-19 22:08:11.050 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:56
2022-12-19 22:08:12.074 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:57
2022-12-19 22:08:13.124 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:58
2022-12-19 22:08:14.149 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:59
2022-12-19 22:08:15.190 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:01:00
2022-12-19 22:08:16.211 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:01:01
2022-12-19 22:08:18.281 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212192207123770130
2022-12-19 22:08:18.282 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:01:03