# Customize loss function

When PyTorch's built-in loss functions do not meet your needs, you can train your model with a custom loss.

## A little problem with the MNIST example

You might notice that in the MNIST example from the previous tutorial, [Customize your Dataset](Homo-NN-Customize-your-Dataset.ipynb), the classifier's output scores come from a Softmax layer, yet we used PyTorch's built-in CrossEntropyLoss to compute the loss. According to the documentation ([CrossEntropyLoss Doc](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss)), CrossEntropyLoss expects its input to contain the unnormalized logits for each class and applies softmax internally. In other words, that example computes Softmax twice.
To fix this, we can use a customized CrossEntropyLoss.
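The double softmax can be reproduced in a few lines of plain PyTorch. This sketch uses random logits in place of a real model's output. It shows that CrossEntropyLoss on logits equals NLL loss on log-probabilities, while CrossEntropyLoss on probabilities gives a different value. Because probabilities lie in [0, 1], the second softmax squashes the per-sample loss into a narrow band (about 1.46 to 2.46 for 10 classes), which tends to weaken the training signal.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 10)          # raw, unnormalized scores: 4 samples, 10 classes
labels = torch.tensor([0, 3, 7, 9])  # integer class indices

probs = torch.softmax(logits, dim=1)  # what a model ending in Softmax outputs

# Intended usage: CrossEntropyLoss applies log-softmax to the logits internally.
loss_on_logits = F.cross_entropy(logits, labels)

# Feeding probabilities makes CrossEntropyLoss apply softmax a second time.
loss_on_probs = F.cross_entropy(probs, labels)

# The correct loss for probabilities: take the log yourself and use NLL loss.
loss_nll = F.nll_loss(torch.log(probs), labels)

print(loss_on_logits.item(), loss_on_probs.item(), loss_nll.item())
```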

## Develop a Custom loss

A custom loss is a class that subclasses torch.nn.Module and implements the forward function. FATE's trainer calls the loss function with two arguments, the predicted scores and the labels (loss_fn(pred, label)), so when you use FATE's trainer, your loss function must accept these two inputs in that order. If you use your own trainer with your own training process, you are free to use the loss function however you like.
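A minimal sketch of the interface FATE's trainer expects: a torch.nn.Module whose forward takes the predictions first and the labels second. MyL1Loss is a hypothetical example (mean absolute error), not part of FATE, and the label reshape assumes one prediction per sample.

```python
import torch as t


class MyL1Loss(t.nn.Module):
    """Hypothetical example: mean absolute error between predictions and labels."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction  # a plain string, so it stays JSON-serializable

    def forward(self, pred, label):
        # The trainer calls loss_fn(pred, label): predictions first, labels second.
        label = label.view_as(pred).to(pred.dtype)
        diff = t.abs(pred - label)
        if self.reduction == 'mean':
            return diff.mean()
        if self.reduction == 'sum':
            return diff.sum()
        if self.reduction == 'none':
            return diff
        raise ValueError(f'unknown reduction: {self.reduction}')


loss_fn = MyL1Loss()
pred = t.tensor([[0.5], [2.0], [1.0]])
label = t.tensor([1.0, 2.0, 0.0])
print(loss_fn(pred, label))  # tensor(0.5000)
```

Keeping the constructor arguments to plain Python values matters later, when the loss is referenced from a pipeline.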

### A New CrossEntropy Loss

Here we implement a new CrossEntropyLoss that skips the softmax computation. You can use the Jupyter interface save_to_fate to save the code to federatedml.nn.loss as cross_entropy.py, or copy the code file into that directory manually. The file name without the .py extension is the module name you pass to CustLoss later.

```python
import torch as t
from federatedml.util import consts
from torch.nn.functional import one_hot


def cross_entropy(pred, target, reduction='mean'):
    """Cross entropy between predicted probabilities `pred` and one-hot `target`.

    `pred` must already be normalized (e.g. the output of Softmax);
    no softmax is applied here.
    """
    assert pred.shape == target.shape
    log_pred = t.log(pred + consts.FLOAT_ZERO)  # small offset avoids log(0) = -inf (nan loss)
    if reduction == 'sum':
        return -t.sum(target * log_pred)
    elif reduction == 'mean':
        return -t.mean(t.sum(target * log_pred, dim=1))
    elif reduction == 'none':
        return -t.sum(target * log_pred, dim=1)
    else:
        raise ValueError('unknown reduction')


class CrossEntropyLoss(t.nn.Module):

    """
    A CrossEntropy Loss that does not compute Softmax
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, label):
        # Take the number of classes from the model output, so a batch that is
        # missing the highest class still yields a one-hot matrix of matching shape.
        one_hot_label = one_hot(label.flatten().long(), num_classes=pred.shape[1])
        return cross_entropy(pred, one_hot_label, self.reduction)
```
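To sanity-check the loss outside FATE, the sketch below re-implements the same softmax-free cross entropy without the federatedml dependency, replacing consts.FLOAT_ZERO with an assumed EPS = 1e-8 (the exact FATE value is not taken from the source). Fed with softmax probabilities, it should match PyTorch's F.cross_entropy on the raw logits. The labels are drawn from classes 0 to 8 only, to show why num_classes is passed to one_hot: without it, one_hot sizes the matrix by the largest label in the batch, which would not match the 10-column prediction.

```python
import torch as t
import torch.nn.functional as F

EPS = 1e-8  # assumed stand-in for consts.FLOAT_ZERO


def cross_entropy_on_probs(pred, label, reduction='mean'):
    """Cross entropy for inputs that are already probabilities (no softmax inside)."""
    target = F.one_hot(label.flatten().long(), num_classes=pred.shape[1]).to(pred.dtype)
    per_sample = -t.sum(target * t.log(pred + EPS), dim=1)
    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    if reduction == 'none':
        return per_sample
    raise ValueError('unknown reduction')


t.manual_seed(0)
logits = t.randn(8, 10)
labels = t.randint(0, 9, (8,))  # classes 0..8 only: class 9 never appears in this batch
probs = t.softmax(logits, dim=1)

ours = cross_entropy_on_probs(probs, labels)
reference = F.cross_entropy(logits, labels)  # PyTorch's loss computed on the raw logits
print(ours.item(), reference.item())
```

The two values agree up to the tiny EPS offset, which confirms that pairing a Softmax output layer with this loss is equivalent to pairing raw logits with the built-in CrossEntropyLoss.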

## Train with New Loss

### Import Components

```python
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

t = fate_torch_hook(t)
```


### Bind data path to name & namespace

```python
import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('../../../../')
arbiter = 10000
host = 10000
guest = 9999
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
                                                                            arbiter=arbiter)

data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}

data_path_0 = fate_project_path + '/examples/data/mnist'
data_path_1 = fate_project_path + '/examples/data/mnist'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)
```


```python
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)
```

## Use CustLoss

After calling fate_torch_hook, you can use t.nn.CustLoss to specify your own loss. Pass the module name (loss_module_name) and the class name (class_name) of your loss, followed by the keyword arguments used to initialize your loss class (here, reduction='mean'). **The initialization parameters must be JSON-serializable; otherwise the pipeline cannot be submitted.**
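Since the CustLoss initialization parameters must be JSON-serializable, it can help to check them before submitting. The helper below is hypothetical (not a FATE API); it simply tries json.dumps on the keyword arguments you plan to pass, so values such as sets, tensors, or custom objects are caught early.

```python
import json


def check_loss_kwargs(**kwargs):
    """Hypothetical helper: fail early if loss init parameters are not JSON-serializable."""
    try:
        json.dumps(kwargs)
    except TypeError as err:
        raise ValueError(f'loss init parameters must be JSON-serializable: {err}') from err
    return kwargs


check_loss_kwargs(reduction='mean')         # fine: a plain string
check_loss_kwargs(weights=[1.0, 2.0, 0.5])  # fine: a list of floats

try:
    check_loss_kwargs(weights={1.0, 2.0})   # a set is not a JSON type
except ValueError as err:
    print(err)
```

If a loss needs a tensor (for example, class weights), one option is to pass a plain list and convert it with torch.tensor inside the loss's __init__.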

```python
from pipeline.component.homo_nn import TrainerParam, DatasetParam  # Interface

# your loss class
loss = t.nn.CustLoss(loss_module_name='cross_entropy', class_name='CrossEntropyLoss', reduction='mean')

# our simple classification model:
model = t.nn.Sequential(
    t.nn.Linear(784, 32),
    t.nn.ReLU(),
    t.nn.Linear(32, 10),
    t.nn.Softmax(dim=1)
)

nn_component = HomoNN(name='nn_0',
                      model=model, # model
                      loss=loss,  # loss
                      optimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizer
                      dataset=DatasetParam(dataset_name='mnist_dataset', flatten_feature=True),  # dataset
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),
                      torch_seed=100 # random seed
                      )
```

```python
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))
```


```python
pipeline.compile()
pipeline.fit()
```

```python
pipeline.get_component('nn_0').get_output_data()
```

| | id | label | predict_result | predict_score | predict_detail | type |
|---|---|---|---|---|---|---|
| 0 | img_1 | 0 | 0 | 0.9070178270339966 | {'0': 0.9070178270339966, '1': 0.0023874549660... | train |
| 1 | img_3 | 4 | 6 | 0.19601570069789886 | {'0': 0.19484134018421173, '1': 0.044997252523... | train |
| 2 | img_4 | 0 | 0 | 0.9618675112724304 | {'0': 0.9618675112724304, '1': 0.0010393995326... | train |
| 3 | img_5 | 0 | 0 | 0.33044907450675964 | {'0': 0.33044907450675964, '1': 0.033256266266... | train |
| 4 | img_6 | 7 | 7 | 0.3145765960216522 | {'0': 0.05851678550243378, '1': 0.075524508953... | train |
| ... | ... | ... | ... | ... | ... | ... |
| 1304 | img_32537 | 1 | 8 | 0.20599651336669922 | {'0': 0.080563984811306, '1': 0.12380836158990... | train |
| 1305 | img_32558 | 1 | 8 | 0.20311488211154938 | {'0': 0.07224143296480179, '1': 0.130610913038... | train |
| 1306 | img_32563 | 1 | 8 | 0.2071550488471985 | {'0': 0.06843454390764236, '1': 0.129064396023... | train |
| 1307 | img_32565 | 1 | 5 | 0.29367145895957947 | {'0': 0.05658009275794029, '1': 0.086584843695... | train |
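The output table can also be inspected locally. Assuming get_output_data() returns a pandas DataFrame with the columns shown above (an assumption about the return type, not confirmed here), accuracy is the fraction of rows where predict_result equals label. The sketch rebuilds the first five rows by hand, with predict_score rounded to four decimals.

```python
import pandas as pd

# First five rows of the output table above, re-typed for illustration.
df = pd.DataFrame({
    'id': ['img_1', 'img_3', 'img_4', 'img_5', 'img_6'],
    'label': [0, 4, 0, 0, 7],
    'predict_result': [0, 6, 0, 0, 7],
    'predict_score': [0.9070, 0.1960, 0.9619, 0.3304, 0.3146],
    'type': ['train'] * 5,
})

# Accuracy: share of rows whose predicted class equals the true label.
accuracy = (df['label'] == df['predict_result']).mean()

# Misclassified rows, lowest confidence first.
mistakes = df[df['label'] != df['predict_result']].sort_values('predict_score')

print(accuracy)  # 0.8
print(mistakes[['id', 'label', 'predict_result']])
```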


```python
pipeline.get_component('nn_0').get_summary()
```

{'best_epoch': 1,
 'loss_history': [3.58235876026547, 3.4448592824914055],
 'metrics_summary': {'train': {'accuracy': [0.25668449197860965,
    0.4950343773873186],
   'precision': [0.3708616690797323, 0.5928620913124757],
   'recall': [0.21817632850241547, 0.4855654369784805]}},
 'need_stop': False}