Customize loss function¶
When PyTorch's built-in loss functions cannot meet your needs, you can train your model with a custom loss.
A little problem with the MNIST example¶
You might notice that in the MNIST example from the last tutorial, Customize your Dataset, the classifier's output scores are the result of a Softmax function, while we use the torch built-in CrossEntropyLoss to compute the loss. However, the documentation (CrossEntropyLoss Doc) states that the input is expected to contain the unnormalized logits for each class. That is to say, in that example, we compute Softmax twice. To tackle this problem, we can use a customized CrossEntropyLoss.
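To see why the double Softmax matters, here is a small plain-PyTorch sketch with hypothetical logits for a single 3-class sample: applying Softmax a second time flattens the distribution, so the loss would be computed on distorted probabilities rather than the model's actual confidence.

```python
import torch

# Illustrative logits for one 3-class sample (hypothetical values)
logits = torch.tensor([[2.0, 1.0, 0.1]])

once = torch.softmax(logits, dim=1)   # what the model outputs
twice = torch.softmax(once, dim=1)    # what effectively happens in the old example

# The second Softmax pulls the probabilities toward uniform
print(once.max().item(), twice.max().item())
```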
Develop a Custom loss¶
A customized loss is a class that subclasses torch.nn.Module and implements the forward function. In the FATE trainer, the loss function is called with two parameters: the predicted scores and the labels (loss_fn(pred, label)). So when you are using FATE's trainer, your loss function needs to take these two parameters as input (predicted scores & labels). However, if you are using your own trainer and have defined your own training process, you are not restricted in how you call the loss function.
A New CrossEntropy Loss¶
Here we implement a new CrossEntropyLoss that skips the Softmax computation. We can use the Jupyter interface save_to_fate to save the code to federatedml.nn.loss as ce.py; of course, you can also copy the code file into that directory manually.
import torch as t
from federatedml.util import consts
from torch.nn.functional import one_hot


def cross_entropy(p2, p1, reduction='mean'):
    p2 = p2 + consts.FLOAT_ZERO  # to avoid nan when taking log
    assert p2.shape == p1.shape
    if reduction == 'sum':
        return -t.sum(p1 * t.log(p2))
    elif reduction == 'mean':
        return -t.mean(t.sum(p1 * t.log(p2), dim=1))
    elif reduction == 'none':
        return -t.sum(p1 * t.log(p2), dim=1)
    else:
        raise ValueError('unknown reduction')


class CrossEntropyLoss(t.nn.Module):
    """
    A CrossEntropy Loss that will not compute Softmax
    """

    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, label):
        one_hot_label = one_hot(label.flatten())
        loss_ = cross_entropy(pred, one_hot_label, self.reduction)
        return loss_
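As a quick local sanity check (plain PyTorch, outside FATE, with a small epsilon standing in for consts.FLOAT_ZERO), the custom cross-entropy applied to Softmax outputs should agree with the built-in torch.nn.CrossEntropyLoss applied to the raw logits:

```python
import torch as t

t.manual_seed(0)
logits = t.randn(8, 10)            # a hypothetical batch of raw scores
labels = t.randint(0, 10, (8,))
probs = t.softmax(logits, dim=1)   # what the tutorial's model outputs

# the custom cross-entropy on probabilities (epsilon in place of consts.FLOAT_ZERO)
one_hot_label = t.nn.functional.one_hot(labels, num_classes=10)
custom = -t.mean(t.sum(one_hot_label * t.log(probs + 1e-12), dim=1))

# the built-in loss on the raw, unnormalized logits
builtin = t.nn.CrossEntropyLoss()(logits, labels)

assert t.allclose(custom, builtin, atol=1e-5)
```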
Train with New Loss¶
Import Components¶
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
t = fate_torch_hook(t)
Bind data path to name & namespace¶
import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('../../../../')
arbiter = 10000
host = 10000
guest = 9999
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
arbiter=arbiter)
data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}
data_path_0 = fate_project_path + '/examples/data/mnist'
data_path_1 = fate_project_path + '/examples/data/mnist'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)
{'namespace': 'experiment', 'table_name': 'mnist_host'}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)
Use CustLoss¶
After fate_torch_hook, we can use t.nn.CustLoss to specify your own loss. We specify the module name and the class name as parameters; the remaining keyword arguments are the initialization parameters for your loss class. The initialization parameters must be JSON-serializable, otherwise the pipeline cannot be submitted.
from pipeline.component.homo_nn import TrainerParam, DatasetParam # Interface
# your loss class
loss = t.nn.CustLoss(loss_module_name='cross_entropy', class_name='CrossEntropyLoss', reduction='mean')
# our simple classification model:
model = t.nn.Sequential(
    t.nn.Linear(784, 32),
    t.nn.ReLU(),
    t.nn.Linear(32, 10),
    t.nn.Softmax(dim=1)
)

nn_component = HomoNN(name='nn_0',
                      model=model,  # model
                      loss=loss,  # loss
                      optimizer=t.optim.Adam(model.parameters(), lr=0.01),  # optimizer
                      dataset=DatasetParam(dataset_name='mnist_dataset', flatten_feature=True),  # dataset
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024,
                                           validation_freqs=1),
                      torch_seed=100  # random seed
                      )
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))
<pipeline.backend.pipeline.PipeLine at 0x7f15d02f6b50>
pipeline.compile()
pipeline.fit()
2022-12-19 18:39:12.858 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212191839119838210
2022-12-19 18:39:13.943 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:01
2022-12-19 18:39:26.568 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:13
2022-12-19 18:39:53.246 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:40
2022-12-19 18:40:10.543 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212191839119838210
2022-12-19 18:40:10.545 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:57
pipeline.get_component('nn_0').get_output_data()
| | id | label | predict_result | predict_score | predict_detail | type |
|---|---|---|---|---|---|---|
| 0 | img_1 | 0 | 0 | 0.9070178270339966 | {'0': 0.9070178270339966, '1': 0.0023874549660... | train |
| 1 | img_3 | 4 | 6 | 0.19601570069789886 | {'0': 0.19484134018421173, '1': 0.044997252523... | train |
| 2 | img_4 | 0 | 0 | 0.9618675112724304 | {'0': 0.9618675112724304, '1': 0.0010393995326... | train |
| 3 | img_5 | 0 | 0 | 0.33044907450675964 | {'0': 0.33044907450675964, '1': 0.033256266266... | train |
| 4 | img_6 | 7 | 7 | 0.3145765960216522 | {'0': 0.05851678550243378, '1': 0.075524508953... | train |
| ... | ... | ... | ... | ... | ... | ... |
| 1304 | img_32537 | 1 | 8 | 0.20599651336669922 | {'0': 0.080563984811306, '1': 0.12380836158990... | train |
| 1305 | img_32558 | 1 | 8 | 0.20311488211154938 | {'0': 0.07224143296480179, '1': 0.130610913038... | train |
| 1306 | img_32563 | 1 | 8 | 0.2071550488471985 | {'0': 0.06843454390764236, '1': 0.129064396023... | train |
| 1307 | img_32565 | 1 | 5 | 0.29367145895957947 | {'0': 0.05658009275794029, '1': 0.086584843695... | train |
| 1308 | img_32573 | 1 | 8 | 0.199515700340271 | {'0': 0.08787216246128082, '1': 0.127247273921... | train |

1309 rows × 6 columns
pipeline.get_component('nn_0').get_summary()
{'best_epoch': 1, 'loss_history': [3.58235876026547, 3.4448592824914055], 'metrics_summary': {'train': {'accuracy': [0.25668449197860965, 0.4950343773873186], 'precision': [0.3708616690797323, 0.5928620913124757], 'recall': [0.21817632850241547, 0.4855654369784805]}}, 'need_stop': False}
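The summary above is an ordinary Python dict, so it can be inspected programmatically. A small sketch (the dict literal below just mirrors the structure printed by get_summary(); keys are copied from the output above):

```python
# Mirrors the structure returned by pipeline.get_component('nn_0').get_summary()
summary = {
    'best_epoch': 1,
    'loss_history': [3.58235876026547, 3.4448592824914055],
    'metrics_summary': {'train': {'accuracy': [0.25668449197860965, 0.4950343773873186]}},
    'need_stop': False,
}

best = summary['best_epoch']
print('best epoch:', best, 'loss:', summary['loss_history'][best])
print('final train accuracy:', summary['metrics_summary']['train']['accuracy'][-1])
```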