# Hetero-NN Customize Models & Loss

In Hetero-NN, both the top model and the bottom model can be customized.
The model_zoo module was introduced in FATE 1.10 and is located under federatedml.nn.model_zoo. It lets you use your own PyTorch model, provided that the model subclasses torch.nn.Module and implements the forward interface. For more information on custom modules, see the PyTorch documentation: [PyTorch Module](https://pytorch.org/docs/stable/notes/modules.html#a-simple-custom-module). To use your custom model in a federated task, place it in the federatedml/nn/model_zoo directory and specify the module and model class through the interface when submitting the task. Hetero-NN components will automatically search for and import the model you have implemented.
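
To make the "specify the module and model class" step more concrete, here is a minimal sketch of how a class can be looked up from a file name and a class name using Python's importlib. It is only an illustration of the idea, not FATE's actual implementation: `load_class`, `my_net.py`, `MyNet`, and the temporary directory standing in for model_zoo are all hypothetical.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_class(zoo_dir, module_name, class_name):
    """Import the file `module_name` from `zoo_dir` and return its `class_name` attribute."""
    path = Path(zoo_dir) / module_name
    spec = importlib.util.spec_from_file_location(path.stem, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file so its classes are defined
    return getattr(module, class_name)


# A throwaway directory plays the role of federatedml/nn/model_zoo (assumption).
with tempfile.TemporaryDirectory() as zoo:
    (Path(zoo) / "my_net.py").write_text(
        "class MyNet:\n"
        "    def forward(self, x):\n"
        "        return x\n"
    )
    net_cls = load_class(zoo, "my_net.py", "MyNet")
    model = net_cls()
    result = model.forward(3)
```

The point of the sketch is that the file name and the class name together are enough to locate a model, which is why those two values are all you pass when submitting the task.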

You can also define your own loss class in a similar way: place it under the loss module, located under federatedml.nn.loss.
We recommend reading these two tutorials before this one: [Customize loss function](Homo-NN-Customize-Loss.ipynb), [Customize Model](Homo-NN-Customize-Model.ipynb)

As an example, we reuse the MNIST handwriting recognition task from the previous Hetero-NN tutorial.

## Prepare MNIST Data

Please download the guest/host MNIST datasets from the links below and place them in the project's examples/data folder:

- guest data: https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist_guest.zip

- host data: https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist_host.zip
  
The mnist_guest dataset is a simplified version of MNIST with ten categories; the images are sorted into ten folders, named 0 to 9, according to their labels. The mnist_host dataset contains the same images as mnist_guest, but they are not labeled.

In [3]:
! ls ../../../../examples/data/mnist_guest

0  1  2  3  4  5  6  7	8  9


In [4]:
! ls ../../../../examples/data/mnist_host

not_labeled


## Dataset

In version FATE-1.10, FATE introduces a new base class for datasets called Dataset, which is based on PyTorch's Dataset class. This class allows users to create custom datasets according to their specific needs. The usage is similar to that of PyTorch's Dataset class, with the added requirement of implementing two additional interfaces when using FATE-NN for data reading and training: load() and get_sample_ids().

To create a custom dataset in Hetero-NN, users need to:

- Develop a new dataset class that inherits from the Dataset class
- Implement the \_\_len\_\_() and \_\_getitem\_\_() methods, which are consistent with PyTorch's Dataset usage. The \_\_len\_\_() method should return the length of the dataset, while the \_\_getitem\_\_() method should return the data at the specified index. **However, note that \_\_getitem\_\_() behaves differently for different parties. In the guest party (the party with labels), \_\_getitem\_\_() returns features and labels, while in the host parties (the parties without labels), \_\_getitem\_\_() returns features only.**
- Implement the load(), get_sample_ids(), get_classes() methods
  
For those unfamiliar with PyTorch's Dataset class, more information can be found in the PyTorch documentation: [Pytorch Dataset Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
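
The steps above can be sketched as follows. This is a plain-Python stand-in so that it runs without FATE or PyTorch installed; a real dataset would inherit from FATE's Dataset base class instead. The class name `ToyPartyDataset`, the in-memory `records` argument of `load()`, and the sample values are assumptions made for illustration. The `return_label` flag models the guest/host difference in \_\_getitem\_\_().

```python
class ToyPartyDataset:
    """Illustrative stand-in for a custom dataset (not FATE's actual base class)."""

    def __init__(self, return_label=True):
        # Guest (has labels): return_label=True. Host (no labels): return_label=False.
        self.return_label = return_label
        self.ids, self.features, self.labels = [], [], []

    def load(self, records):
        # Assumption: `records` is an in-memory list of (id, features, label).
        # A real load() would read the data from storage instead.
        for sample_id, feats, label in records:
            self.ids.append(sample_id)
            self.features.append(feats)
            self.labels.append(label)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        if self.return_label:
            return self.features[idx], self.labels[idx]  # guest: features and label
        return self.features[idx]                        # host: features only

    def get_sample_ids(self):
        return list(self.ids)

    def get_classes(self):
        return sorted(set(self.labels))


records = [("img_1", [0.1, 0.2], 0), ("img_3", [0.3, 0.4], 4)]

guest_ds = ToyPartyDataset(return_label=True)
guest_ds.load(records)

host_ds = ToyPartyDataset(return_label=False)
host_ds.load(records)
```

Both parties index the same samples in the same order, but only the guest's dataset hands labels to the training loop.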

## Customize Bottom/Top Model

Name the model code bottom_net.py. You can put it directly under federatedml/nn/model_zoo, or use the Jupyter notebook shortcut save_to_fate to save it there. This is the bottom model structure we define for feature extraction.

In [2]:
from pipeline.component.nn import save_to_fate

In [17]:
%%save_to_fate model bottom_net.py
import torch as t
from torch import nn
from torch.nn import Module

class BottomNet(nn.Module):

    def __init__(self):
        super(BottomNet, self).__init__()
        self.seq = t.nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5),
            nn.MaxPool2d(kernel_size=3),
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3),
            nn.AvgPool2d(kernel_size=3)
        )
        
        self.fc = t.nn.Sequential(   # the extracted feature is an 8-dim embedding
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, 8),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.seq(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x
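
The input size of the first linear layer, 48, follows from the convolution and pooling arithmetic. The sketch below traces it, assuming 28x28 input images (the standard MNIST size, which the tutorial does not state explicitly) and PyTorch's defaults: stride 1 and no padding for Conv2d, and a pooling stride equal to the kernel size. `out_size` is a hypothetical helper, not part of PyTorch.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Output side length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                  # assumed input height/width
side = out_size(side, kernel=5)            # Conv2d(3 -> 12, kernel_size=5)
side = out_size(side, kernel=3, stride=3)  # MaxPool2d(kernel_size=3)
side = out_size(side, kernel=3)            # Conv2d(12 -> 12, kernel_size=3)
side = out_size(side, kernel=3, stride=3)  # AvgPool2d(kernel_size=3)

channels = 12
flat_features = channels * side * side     # size after flatten(start_dim=1)
print(side, flat_features)
```

Under these assumptions the spatial size goes 28, 24, 8, 6, 2, so flattening 12 channels of 2x2 gives the 48 features expected by nn.Linear(48, 32). A different input size would require changing that layer.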

And this is the top model we define for classification, saved as top_net.py.

In [36]:
%%save_to_fate model top_net.py
import torch as t
from torch import nn
from torch.nn import Module

class TopNet(nn.Module):

    def __init__(self):
        super(TopNet, self).__init__()
        self.fc = t.nn.Sequential(   
            nn.Linear(8, 10)
        )
        self.softmax = t.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc(x)
        return self.softmax(x)

## Use Custom Loss

Using a custom loss works exactly as in Homo-NN; see [Customize loss function](Homo-NN-Customize-Loss.ipynb). Here we define a new CrossEntropyLoss.

In [37]:
%%save_to_fate loss ce.py
import torch as t
from federatedml.util import consts
from torch.nn.functional import one_hot


def cross_entropy(p2, p1, reduction='mean'):
    p2 = p2 + consts.FLOAT_ZERO  # to avoid nan
    assert p2.shape == p1.shape
    if reduction == 'sum':
        return -t.sum(p1 * t.log(p2))
    elif reduction == 'mean':
        return -t.mean(t.sum(p1 * t.log(p2), dim=1))
    elif reduction == 'none':
        return -t.sum(p1 * t.log(p2), dim=1)
    else:
        raise ValueError('unknown reduction')


class CrossEntropyLoss(t.nn.Module):

    """
    A CrossEntropy Loss that will not compute Softmax
    """

    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self.reduction = reduction

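    def forward(self, pred, label):

        # fix the one-hot width to the number of predicted classes, so the shapes
        # still match when a batch does not contain the largest label
        one_hot_label = one_hot(label.flatten(), num_classes=pred.shape[1])
        loss_ = cross_entropy(pred, one_hot_label, self.reduction)

        return loss_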
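
Because TopNet already applies softmax, this loss works directly on probabilities rather than on logits. A small worked example in plain Python shows what it computes for a single sample: with a one-hot label, the sum reduces to the negative log of the probability assigned to the true class. The function name, the sample probabilities, and the `eps` value (standing in for consts.FLOAT_ZERO) are assumptions for illustration.

```python
import math


def cross_entropy_from_probs(probs, label, eps=1e-8):
    """Cross entropy of one sample, given class probabilities (no softmax applied)."""
    one_hot_label = [1.0 if i == label else 0.0 for i in range(len(probs))]
    # eps keeps log() finite when a probability is exactly zero
    return -sum(y * math.log(p + eps) for y, p in zip(one_hot_label, probs))


probs = [0.7, 0.2, 0.1]   # hypothetical softmax output over three classes
loss_value = cross_entropy_from_probs(probs, label=0)
reference = -math.log(0.7)
```

Here `loss_value` agrees with `reference` up to the small `eps` term. Passing raw logits instead of probabilities would give meaningless values, which is why the top model ends with a softmax layer.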

Now we can use our models and loss in the Hetero-NN MNIST task. The usage is the same as in Homo-NN: we specify our models and loss through the nn.CustModel and nn.CustLoss interfaces.

## Pipeline Initialization

Here we define the pipeline to run a hetero task.

In [38]:
import os
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HeteroNN
from pipeline.component.hetero_nn import DatasetParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
from pipeline.component.nn import save_to_fate

fate_torch_hook(t)

# bind path to fate name&namespace
fate_project_path = os.path.abspath('../../../../')
guest = 10000
host = 9999

pipeline_img = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)

guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}

guest_data_path = fate_project_path + '/examples/data/mnist_guest/'
host_data_path = fate_project_path + '/examples/data/mnist_host/'
pipeline_img.bind_table(name='mnist_guest', namespace='experiment', path=guest_data_path)
pipeline_img.bind_table(name='mnist_host', namespace='experiment', path=host_data_path)

{'namespace': 'experiment', 'table_name': 'mnist_host'}

In [39]:
guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)

In [40]:
hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=3,
                       interactive_layer_lr=0.01, batch_size=512, task_type='classification', seed=100
                       )
guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)

# define models
# guest bottom model: custom BottomNet from bottom_net.py
guest_bottom = t.nn.CustModel(module_name='bottom_net.py', class_name='BottomNet')

# host bottom model: the same custom BottomNet
host_bottom = t.nn.CustModel(module_name='bottom_net.py', class_name='BottomNet')

# guest top model: custom TopNet from top_net.py
guest_top = t.nn.CustModel(module_name='top_net.py', class_name='TopNet')

# define the interactive layer
interactive_layer = t.nn.InteractiveLayer(out_dim=8, guest_dim=8, host_dim=8)

# add models
guest_nn_0.add_top_model(guest_top)
guest_nn_0.add_bottom_model(guest_bottom)
host_nn_0.add_bottom_model(host_bottom)

# opt, loss
optimizer = t.optim.Adam(lr=0.01) 
loss = t.nn.CustLoss(loss_module_name='ce', class_name='CrossEntropyLoss')

# use DatasetParam to specify dataset and pass parameters
guest_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=True))
host_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=False))

hetero_nn_0.set_interactive_layer(interactive_layer)
hetero_nn_0.compile(optimizer=optimizer, loss=loss)

In [41]:
pipeline_img.add_component(reader_0)
pipeline_img.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))
pipeline_img.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.compile()

<pipeline.backend.pipeline.PipeLine at 0x7f3f363f5970>

In [42]:
pipeline_img.fit()

2022-12-24 21:26:06.571 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212242126060352040
2022-12-24 21:26:06.584 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2022-12-24 21:26:07.598 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:01
2022-12-24 21:26:08.622 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2022-12-24 21:26:08.624 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2022-12-24 21:26:09.648 | INFO

In [43]:
pipeline_img.get_component('hetero_nn_0').get_output_data()  # get result

| | id | label | predict_result | predict_score | predict_detail | type |
|---|---|---|---|---|---|---|
| 0 | img_1 | 0 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 1 | img_3 | 4 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 2 | img_4 | 0 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 3 | img_5 | 0 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 4 | img_6 | 7 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| ... | ... | ... | ... | ... | ... | ... |
| 1304 | img_32537 | 1 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 1305 | img_32558 | 1 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 1306 | img_32563 | 1 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
| 1307 | img_32565 | 1 | 1 | 0.12622389197349548 | {'0': 0.07662956416606903, '1': 0.126223891973... | train |
