Homo-NN Quick Start: A Binary Classification Task¶
This tutorial helps you get started with Homo-NN quickly. By default, the Homo-NN component follows the same workflow as other FATE algorithm components: use FATE's built-in reader and transformer interfaces to input table data and convert its format, then feed the data into the algorithm component. The NN component will then use the model, optimizer, and loss function you define for training and model aggregation.
In FATE-1.10, Homo-NN in the pipeline adds support for PyTorch. You can define a Sequential model with PyTorch's built-in layers, just as you would in native PyTorch, and submit it. You can also use the loss functions and optimizers that come with PyTorch.
The following is a basic Homo-NN binary classification task. There are two clients with party ids 10000 and 9999, and party 10000 is also specified as the server side that aggregates the models.
Uploading Tabular Data¶
First, we upload data to FATE. We can upload data directly using the pipeline. Here we upload two files: breast_homo_guest.csv for the guest and breast_homo_host.csv for the host. Note that this tutorial uses a standalone deployment; if you are using a cluster deployment, you need to upload the corresponding data on each machine.
from pipeline.backend.pipeline import PipeLine # pipeline Class
# [9999(guest), 10000(host)] as client
# [10000(arbiter)] as server
guest = 9999
host = 10000
arbiter = 10000
pipeline_upload = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)
partition = 4
# upload a dataset
path_to_fate_project = '../../../../'
guest_data = {"name": "breast_homo_guest", "namespace": "experiment"}
host_data = {"name": "breast_homo_host", "namespace": "experiment"}
pipeline_upload.add_upload_data(file="examples/data/breast_homo_guest.csv", # file in the example/data
table_name=guest_data["name"], # table name
namespace=guest_data["namespace"], # namespace
head=1, partition=partition) # data info
pipeline_upload.add_upload_data(file="examples/data/breast_homo_host.csv", # file in the example/data
table_name=host_data["name"], # table name
namespace=host_data["namespace"], # namespace
head=1, partition=partition) # data info
pipeline_upload.upload(drop=1)
UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%
2022-12-19 10:40:32.733 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212191040322910830
2022-12-19 10:40:33.788 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component upload_0, time elapse: 0:00:01
...
2022-12-19 10:40:39.998 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212191040322910830
2022-12-19 10:40:40.001 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:07
UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%
2022-12-19 10:40:40.706 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212191040400256350
2022-12-19 10:40:43.832 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component upload_0, time elapse: 0:00:03
...
2022-12-19 10:40:49.969 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212191040400256350
2022-12-19 10:40:49.971 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:09
The breast dataset is a binary classification dataset with 30 features:
import pandas as pd
df = pd.read_csv('../../../../examples/data/breast_homo_guest.csv')
df
| id | y | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | ... | x20 | x21 | x22 | x23 | x24 | x25 | x26 | x27 | x28 | x29
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 133 | 1 | 0.254879 | -1.046633 | 0.209656 | 0.074214 | -0.441366 | -0.377645 | -0.485934 | 0.347072 | ... | -0.337360 | -0.728193 | -0.442587 | -0.272757 | -0.608018 | -0.577235 | -0.501126 | 0.143371 | -0.466431 | -0.554102 |
1 | 273 | 1 | -1.142928 | -0.781198 | -1.166747 | -0.923578 | 0.628230 | -1.021418 | -1.111867 | -0.959523 | ... | -0.493639 | 0.348620 | -0.552483 | -0.526877 | 2.253098 | -0.827620 | -0.780739 | -0.376997 | -0.310239 | 0.176301 |
2 | 175 | 1 | -1.451067 | -1.406518 | -1.456564 | -1.092337 | -0.708765 | -1.168557 | -1.305831 | -1.745063 | ... | -0.666881 | -0.779358 | -0.708418 | -0.637545 | 0.710369 | -0.976454 | -1.057501 | -1.913447 | 0.795207 | -0.149751 |
3 | 551 | 1 | -0.879933 | 0.420589 | -0.877527 | -0.780484 | -1.037534 | -0.483880 | -0.555498 | -0.768581 | ... | -0.451772 | 0.453852 | -0.431696 | -0.494754 | -1.182041 | 0.281228 | 0.084759 | -0.252420 | 1.038575 | 0.351054 |
4 | 199 | 0 | 0.426758 | 0.723479 | 0.316885 | 0.287273 | 1.000835 | 0.962702 | 1.077099 | 1.053586 | ... | -0.707304 | -1.026834 | -0.702973 | -0.460212 | -0.999033 | -0.531406 | -0.394360 | -0.728830 | -0.644416 | -0.688003 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
222 | 105 | 0 | 0.008451 | -0.533675 | -0.025652 | -0.093843 | 2.359748 | 0.990056 | 1.753070 | 1.278939 | ... | -0.051872 | -0.531700 | -0.225763 | -0.124905 | 0.040342 | 0.203542 | 0.757183 | 0.338023 | -0.614146 | 1.249401 |
223 | 242 | 1 | -0.763967 | 0.371736 | -0.598731 | -0.716671 | 0.102199 | 1.466524 | 2.261608 | 0.109537 | ... | -0.586035 | 0.771362 | -0.246060 | -0.526877 | -0.125998 | 1.881346 | 1.886843 | 0.217988 | -0.071715 | 1.845903 |
224 | 312 | 1 | -0.430564 | -1.510738 | -0.453377 | -0.460192 | -0.568490 | -0.212884 | -0.457149 | -0.464354 | ... | -0.283944 | -1.011412 | -0.257445 | -0.333482 | -0.182334 | 0.123061 | -0.017365 | -0.179426 | -0.391362 | 0.225852 |
225 | 473 | 1 | -0.583805 | 2.014830 | -0.660686 | -0.565491 | -1.672278 | -1.285861 | -1.305831 | -1.745063 | ... | 0.145552 | 4.409122 | 0.008881 | -0.114565 | 0.099344 | -0.963264 | -1.057501 | -1.913447 | 1.315845 | -0.249231 |
226 | 364 | 1 | -0.318739 | -0.647666 | -0.402145 | -0.381613 | -0.485202 | -0.551311 | -0.651448 | -0.681180 | ... | -0.890652 | -1.096686 | -0.905935 | -0.596622 | -0.882362 | -0.725342 | -0.576392 | -1.023890 | -0.924107 | -0.650934 |
227 rows × 32 columns
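Each row follows the id / label / dense-features layout that labeled homo tabular data uses. As a purely local sanity check (plain pandas, independent of FATE), a toy frame in the same shape can be built like this; the frame and values here are illustrative, not the real dataset:

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame mirroring the breast dataset layout:
# an "id" column, a binary label "y", then dense features x0..x29.
rng = np.random.default_rng(0)
toy = pd.DataFrame({"id": range(5), "y": rng.integers(0, 2, 5)})
for i in range(30):
    toy[f"x{i}"] = rng.normal(size=5)

print(toy.shape)                 # (5, 32): id + y + 30 features
print(toy["y"].value_counts())   # label distribution
```
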
Write the Pipeline script and execute it¶
After the upload completes, we can write the pipeline script and submit a FATE task.
Import Pipeline Components¶
# torch
import torch as t
from torch import nn
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam # HomoNN Component, TrainerParam for setting trainer parameter
from pipeline.backend.pipeline import PipeLine # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation
from pipeline.interface import Data # Data Interfaces for defining data flow
We can check the parameters of the Homo-NN component:
print(HomoNN.__doc__)
Parameters
----------
name, name of this component
trainer, trainer param
dataset, dataset param
torch_seed, global random seed
loss, loss function from fate_torch
optimizer, optimizer from fate_torch
model, a fate torch sequential defining the model structure
fate_torch_hook¶
Please be sure to execute the fate_torch_hook function below. It modifies some torch classes so that the torch layers, Sequential models, optimizers, and loss functions you define in your scripts can be parsed and submitted by the pipeline.
from pipeline import fate_torch_hook
t = fate_torch_hook(t)
pipeline¶
# create a pipeline to submit the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)
# read uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)
# The DataTransform component converts the uploaded data to the FATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(
role='guest', party_id=guest).component_param(
with_label=True, output_format="dense")
data_transform_0.get_party_instance(
role='host', party_id=host).component_param(
with_label=True, output_format="dense")
"""
Define Pytorch model/ optimizer and loss
"""
model = nn.Sequential(
nn.Linear(30, 1),
nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)
"""
Create Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',
model=model, # set model
loss=loss, # set loss
optimizer=optimizer, # set optimizer
# Here we use fedavg trainer
# TrainerParam passes parameters to fedavg_trainer, see below for details about Trainer
trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=3, batch_size=128, validation_freqs=1),
torch_seed=100 # random seed
)
# define work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.add_component(Evaluation(name='eval_0'), data=Data(data=nn_component.output.data))
pipeline.compile()
pipeline.fit()
2022-12-19 10:50:38.106 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212191050351408970
2022-12-19 10:50:40.183 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2022-12-19 10:50:47.459 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component data_transform_0, time elapse: 0:00:09
2022-12-19 10:50:56.785 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:18
2022-12-19 10:51:13.379 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:35
...
2022-12-19 10:51:23.691 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212191050351408970
2022-12-19 10:51:23.693 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:45
Get Component Output¶
# get predict scores
pipeline.get_component('nn_0').get_output_data()
| id | label | predict_result | predict_score | predict_detail | type
---|---|---|---|---|---|---
0 | 0 | 0.0 | 0 | 0.05885133519768715 | {'0': 0.9411486648023129, '1': 0.0588513351976... | train |
1 | 3 | 0.0 | 1 | 0.5971069931983948 | {'0': 0.4028930068016052, '1': 0.5971069931983... | train |
2 | 5 | 1.0 | 1 | 0.7218729257583618 | {'0': 0.2781270742416382, '1': 0.7218729257583... | train |
3 | 7 | 1.0 | 1 | 0.6514894962310791 | {'0': 0.3485105037689209, '1': 0.6514894962310... | train |
4 | 14 | 0.0 | 0 | 0.2351398915052414 | {'0': 0.7648601084947586, '1': 0.2351398915052... | train |
... | ... | ... | ... | ... | ... | ... |
222 | 551 | 1.0 | 0 | 0.38658156991004944 | {'0': 0.6134184300899506, '1': 0.3865815699100... | train |
223 | 559 | 0.0 | 1 | 0.5517507195472717 | {'0': 0.44824928045272827, '1': 0.551750719547... | train |
224 | 562 | 0.0 | 0 | 0.39873841404914856 | {'0': 0.6012615859508514, '1': 0.3987384140491... | train |
225 | 567 | 1.0 | 1 | 0.6306618452072144 | {'0': 0.36933815479278564, '1': 0.630661845207... | train |
226 | 568 | 0.0 | 1 | 0.5063760876655579 | {'0': 0.49362391233444214, '1': 0.506376087665... | train |
227 rows × 6 columns
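The prediction output above can be post-processed locally with pandas. A small sketch of computing training accuracy from the label and predict_result columns, using a hypothetical hand-made frame with the same column layout (not real FATE output):

```python
import pandas as pd

# Hypothetical rows in the same shape as HomoNN's get_output_data() result
preds = pd.DataFrame({
    "id": [133, 273, 175, 551],
    "label": [1.0, 1.0, 1.0, 0.0],
    "predict_result": [1, 0, 1, 0],
    "type": ["train"] * 4,
})

# Fraction of rows where the predicted class matches the label
accuracy = (preds["label"] == preds["predict_result"]).mean()
print(accuracy)  # 3 of 4 rows match -> 0.75
```
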
# get summary
pipeline.get_component('nn_0').get_summary()
{'best_epoch': 2, 'loss_history': [0.8317702709315632, 0.683187778825802, 0.5690162255375396], 'metrics_summary': {'train': {'auc': [0.732987012987013, 0.9094372294372294, 0.9561904761904763], 'ks': [0.4153246753246753, 0.6851948051948051, 0.7908225108225109]}}, 'need_stop': False}
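The summary is a plain dict, so its fields can be inspected directly. For instance, reading the best epoch's metrics (a local sketch over the dict shown above; epochs are 0-indexed here):

```python
# Summary dict copied from the run above
summary = {
    'best_epoch': 2,
    'loss_history': [0.8317702709315632, 0.683187778825802, 0.5690162255375396],
    'metrics_summary': {'train': {
        'auc': [0.732987012987013, 0.9094372294372294, 0.9561904761904763],
        'ks': [0.4153246753246753, 0.6851948051948051, 0.7908225108225109]}},
    'need_stop': False,
}

best = summary['best_epoch']
best_auc = summary['metrics_summary']['train']['auc'][best]
print(best, best_auc)  # the best epoch also has the lowest loss in this run
```
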
TrainerParam trainer parameter and trainer¶
In this version, Homo-NN's training logic and federated aggregation logic are implemented in the Trainer class. fedavg_trainer is the default Trainer of FATE Homo-NN and implements the standard FedAvg algorithm. TrainerParam does the following:
- Uses trainer_name='{module name}' to specify the trainer to use. Trainers are located in the federatedml.nn.homo.trainer directory, so you can also customize your own trainer. A dedicated tutorial covers customizing trainers
- Passes the remaining parameters to the trainer's __init__() interface
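This forwarding behavior can be pictured with a plain-Python sketch. The classes below are mocks for illustration only, not FATE's actual implementation (the real ones live in pipeline.component.homo_nn and federatedml.nn.homo.trainer): TrainerParam keeps trainer_name plus arbitrary keyword arguments, and the framework later instantiates the named trainer with those kwargs.

```python
# Mock sketch of the TrainerParam -> trainer __init__ forwarding pattern.
class MockTrainerParam:
    def __init__(self, trainer_name, **kwargs):
        self.trainer_name = trainer_name
        self.kwargs = kwargs              # everything else is forwarded later

class MockFedAVGTrainer:
    def __init__(self, epochs=10, batch_size=-1, validation_freqs=None):
        self.epochs = epochs
        self.batch_size = batch_size
        self.validation_freqs = validation_freqs

# Registry mapping module names to trainer classes (illustrative)
TRAINERS = {"fedavg_trainer": MockFedAVGTrainer}

param = MockTrainerParam(trainer_name="fedavg_trainer",
                         epochs=3, batch_size=128, validation_freqs=1)
trainer = TRAINERS[param.trainer_name](**param.kwargs)
print(trainer.epochs, trainer.batch_size)  # 3 128
```
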
We can check the parameters of fedavg_trainer in FATE; the available parameters can be filled in TrainerParam.
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
Check the documentation of FedAVGTrainer to learn about the available parameters. When submitting tasks, these parameters can be passed with TrainerParam.
print(FedAVGTrainer.__doc__)
Parameters
----------
epochs: int > 0, epochs to train
batch_size: int, -1 means full batch
secure_aggregate: bool, default is True, whether to use secure aggregation. If enabled, random number masks are added to local models; these masks eventually cancel out to 0.
weighted_aggregation: bool, whether to weight each local model when aggregating. If True, following the original paper, the weight of a client is n_local / n_global, where n_local is the local sample number and n_global is the total sample number across all clients. If False, the models are simply averaged.
early_stop: None, 'diff' or 'abs'. If None, early stopping is disabled; if 'diff', the loss difference between two epochs is the early-stop condition (stop training when the difference < tol); if 'abs', stop training when loss < tol
tol: float, tol value for early stopping
aggregate_every_n_epoch: None or int. If None, aggregate models at the end of every epoch; if int, aggregate every n epochs.
cuda: bool, use cuda or not
pin_memory: bool, for pytorch DataLoader
shuffle: bool, for pytorch DataLoader
data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
validation_freqs: None or int. If int, validate the model and send validation results to fate-board every n epochs. Binary classification tasks use the metrics 'auc', 'ks', 'gain', 'lift', 'precision'; multi-class tasks use 'precision', 'recall', 'accuracy'; regression tasks use 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
checkpoint_save_freqs: save the model every n epochs; if None, no checkpoints are saved.
task_type: str, 'auto', 'binary', 'multi', 'regression'. This option decides the return format of this trainer and the evaluation type when running validation. If 'auto', the task type is automatically inferred from labels and predict results.
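Two of these options can be illustrated with a small numpy sketch. This is a simplification for intuition, not FATE's implementation: weighted_aggregation averages client parameters with weights n_local / n_global, and secure_aggregate relies on random masks that cancel out in the aggregate.

```python
import numpy as np

# Toy "model weights" from two clients (illustrative values)
w_guest = np.array([1.0, 2.0])
w_host = np.array([3.0, 4.0])
n_guest, n_host = 227, 227        # local sample counts; equal here
n_global = n_guest + n_host

# weighted_aggregation=True: each client contributes n_local / n_global
aggregated = (n_guest / n_global) * w_guest + (n_host / n_global) * w_host
print(aggregated)  # [2. 3.] with equal sample counts

# secure_aggregate idea: a random mask added by one party and
# subtracted by the other vanishes in the sum, hiding individual models
mask = np.random.default_rng(0).normal(size=2)
masked_sum = (w_guest + mask) + (w_host - mask)
print(masked_sum)  # equals w_guest + w_host, mask cancels out
```
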
So far, we have gained a basic understanding of Homo-NN and used it to perform a basic modeling task. In addition, Homo-NN offers the ability to customize models, datasets, and Trainers for more advanced use cases. For further information, refer to the additional tutorials provided.