{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3ee4aa0e",
   "metadata": {},
   "source": [
    "# Homo-NN Quick Start: A Binary Classification Task"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "deff6c23",
   "metadata": {},
   "source": [
    "This tutorial helps you get started with Homo-NN quickly. By default, you use the Homo-NN component in the same way as other FATE algorithm components: use the reader and transformer interfaces that come with FATE to read table data and convert its format, then feed the data into the algorithm component. The NN component then uses the model, optimizer, and loss function you define for training and model aggregation.\n",
    "\n",
    "In FATE-1.10, Homo-NN in the pipeline added support for PyTorch. You can define a model the same way you would use a PyTorch Sequential, building it from PyTorch's built-in layers, and then submit it. You can also use the loss functions and optimizers that come with PyTorch.\n",
    "\n",
    "The following is a basic Homo-NN binary classification task. There are two clients, with party IDs 10000 and 9999, and party 10000 also acts as the server that aggregates the models."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "43fd3f74",
   "metadata": {},
   "source": [
    "## Uploading Tabular Data\n",
    "\n",
    "First, we upload the data to FATE. We can do this directly with the pipeline. Here we upload two files: breast_homo_guest.csv for the guest and breast_homo_host.csv for the host. Note that this tutorial uses the standalone version; if you are using a cluster version, you need to upload the corresponding data on each machine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "611ca3bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2022-12-19 10:40:32.733\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191040322910830\n",
      "\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:32.747\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n",
      "\u001b[0mm2022-12-19 10:40:33.781\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n",
      "\u001b[32m2022-12-19 10:40:33.788\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:01\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:34.810\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:02\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:35.835\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:36.856\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:37.887\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:38.912\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:39.998\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212191040322910830\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:40.001\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:07\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2022-12-19 10:40:40.706\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191040400256350\n",
      "\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:40.748\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:41.769\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:42.806\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:02\u001b[0m\n",
      "\u001b[0mm2022-12-19 10:40:43.830\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n",
      "\u001b[32m2022-12-19 10:40:43.832\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:44.852\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:45.872\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:46.893\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:47.925\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:07\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:48.951\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:08\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:49.969\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212191040400256350\u001b[0m\n",
      "\u001b[32m2022-12-19 10:40:49.971\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:09\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from pipeline.backend.pipeline import PipeLine  # pipeline Class\n",
    "\n",
    "# [9999(guest), 10000(host)] as client\n",
    "# [10000(arbiter)] as server\n",
    "\n",
    "guest = 9999\n",
    "host = 10000\n",
    "arbiter = 10000\n",
    "pipeline_upload = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n",
    "\n",
    "partition = 4\n",
    "\n",
    "# upload a dataset\n",
    "path_to_fate_project = '../../../../'\n",
    "guest_data = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n",
    "host_data = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n",
    "pipeline_upload.add_upload_data(file=\"examples/data/breast_homo_guest.csv\", # file under examples/data\n",
    "                                table_name=guest_data[\"name\"],             # table name\n",
    "                                namespace=guest_data[\"namespace\"],         # namespace\n",
    "                                head=1, partition=partition)               # data info\n",
    "pipeline_upload.add_upload_data(file=\"examples/data/breast_homo_host.csv\", # file under examples/data\n",
    "                                table_name=host_data[\"name\"],              # table name\n",
    "                                namespace=host_data[\"namespace\"],          # namespace\n",
    "                                head=1, partition=partition)               # data info\n",
    "\n",
    "\n",
    "pipeline_upload.upload(drop=1)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "afa1afad",
   "metadata": {},
   "source": [
    "The breast dataset is a binary classification dataset with 30 features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d9580f9e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>y</th>\n",
       "      <th>x0</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>x5</th>\n",
       "      <th>x6</th>\n",
       "      <th>x7</th>\n",
       "      <th>...</th>\n",
       "      <th>x20</th>\n",
       "      <th>x21</th>\n",
       "      <th>x22</th>\n",
       "      <th>x23</th>\n",
       "      <th>x24</th>\n",
       "      <th>x25</th>\n",
       "      <th>x26</th>\n",
       "      <th>x27</th>\n",
       "      <th>x28</th>\n",
       "      <th>x29</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>133</td>\n",
       "      <td>1</td>\n",
       "      <td>0.254879</td>\n",
       "      <td>-1.046633</td>\n",
       "      <td>0.209656</td>\n",
       "      <td>0.074214</td>\n",
       "      <td>-0.441366</td>\n",
       "      <td>-0.377645</td>\n",
       "      <td>-0.485934</td>\n",
       "      <td>0.347072</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.337360</td>\n",
       "      <td>-0.728193</td>\n",
       "      <td>-0.442587</td>\n",
       "      <td>-0.272757</td>\n",
       "      <td>-0.608018</td>\n",
       "      <td>-0.577235</td>\n",
       "      <td>-0.501126</td>\n",
       "      <td>0.143371</td>\n",
       "      <td>-0.466431</td>\n",
       "      <td>-0.554102</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>273</td>\n",
       "      <td>1</td>\n",
       "      <td>-1.142928</td>\n",
       "      <td>-0.781198</td>\n",
       "      <td>-1.166747</td>\n",
       "      <td>-0.923578</td>\n",
       "      <td>0.628230</td>\n",
       "      <td>-1.021418</td>\n",
       "      <td>-1.111867</td>\n",
       "      <td>-0.959523</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.493639</td>\n",
       "      <td>0.348620</td>\n",
       "      <td>-0.552483</td>\n",
       "      <td>-0.526877</td>\n",
       "      <td>2.253098</td>\n",
       "      <td>-0.827620</td>\n",
       "      <td>-0.780739</td>\n",
       "      <td>-0.376997</td>\n",
       "      <td>-0.310239</td>\n",
       "      <td>0.176301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>175</td>\n",
       "      <td>1</td>\n",
       "      <td>-1.451067</td>\n",
       "      <td>-1.406518</td>\n",
       "      <td>-1.456564</td>\n",
       "      <td>-1.092337</td>\n",
       "      <td>-0.708765</td>\n",
       "      <td>-1.168557</td>\n",
       "      <td>-1.305831</td>\n",
       "      <td>-1.745063</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.666881</td>\n",
       "      <td>-0.779358</td>\n",
       "      <td>-0.708418</td>\n",
       "      <td>-0.637545</td>\n",
       "      <td>0.710369</td>\n",
       "      <td>-0.976454</td>\n",
       "      <td>-1.057501</td>\n",
       "      <td>-1.913447</td>\n",
       "      <td>0.795207</td>\n",
       "      <td>-0.149751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>551</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.879933</td>\n",
       "      <td>0.420589</td>\n",
       "      <td>-0.877527</td>\n",
       "      <td>-0.780484</td>\n",
       "      <td>-1.037534</td>\n",
       "      <td>-0.483880</td>\n",
       "      <td>-0.555498</td>\n",
       "      <td>-0.768581</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.451772</td>\n",
       "      <td>0.453852</td>\n",
       "      <td>-0.431696</td>\n",
       "      <td>-0.494754</td>\n",
       "      <td>-1.182041</td>\n",
       "      <td>0.281228</td>\n",
       "      <td>0.084759</td>\n",
       "      <td>-0.252420</td>\n",
       "      <td>1.038575</td>\n",
       "      <td>0.351054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>199</td>\n",
       "      <td>0</td>\n",
       "      <td>0.426758</td>\n",
       "      <td>0.723479</td>\n",
       "      <td>0.316885</td>\n",
       "      <td>0.287273</td>\n",
       "      <td>1.000835</td>\n",
       "      <td>0.962702</td>\n",
       "      <td>1.077099</td>\n",
       "      <td>1.053586</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.707304</td>\n",
       "      <td>-1.026834</td>\n",
       "      <td>-0.702973</td>\n",
       "      <td>-0.460212</td>\n",
       "      <td>-0.999033</td>\n",
       "      <td>-0.531406</td>\n",
       "      <td>-0.394360</td>\n",
       "      <td>-0.728830</td>\n",
       "      <td>-0.644416</td>\n",
       "      <td>-0.688003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>222</th>\n",
       "      <td>105</td>\n",
       "      <td>0</td>\n",
       "      <td>0.008451</td>\n",
       "      <td>-0.533675</td>\n",
       "      <td>-0.025652</td>\n",
       "      <td>-0.093843</td>\n",
       "      <td>2.359748</td>\n",
       "      <td>0.990056</td>\n",
       "      <td>1.753070</td>\n",
       "      <td>1.278939</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.051872</td>\n",
       "      <td>-0.531700</td>\n",
       "      <td>-0.225763</td>\n",
       "      <td>-0.124905</td>\n",
       "      <td>0.040342</td>\n",
       "      <td>0.203542</td>\n",
       "      <td>0.757183</td>\n",
       "      <td>0.338023</td>\n",
       "      <td>-0.614146</td>\n",
       "      <td>1.249401</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>223</th>\n",
       "      <td>242</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.763967</td>\n",
       "      <td>0.371736</td>\n",
       "      <td>-0.598731</td>\n",
       "      <td>-0.716671</td>\n",
       "      <td>0.102199</td>\n",
       "      <td>1.466524</td>\n",
       "      <td>2.261608</td>\n",
       "      <td>0.109537</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.586035</td>\n",
       "      <td>0.771362</td>\n",
       "      <td>-0.246060</td>\n",
       "      <td>-0.526877</td>\n",
       "      <td>-0.125998</td>\n",
       "      <td>1.881346</td>\n",
       "      <td>1.886843</td>\n",
       "      <td>0.217988</td>\n",
       "      <td>-0.071715</td>\n",
       "      <td>1.845903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224</th>\n",
       "      <td>312</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.430564</td>\n",
       "      <td>-1.510738</td>\n",
       "      <td>-0.453377</td>\n",
       "      <td>-0.460192</td>\n",
       "      <td>-0.568490</td>\n",
       "      <td>-0.212884</td>\n",
       "      <td>-0.457149</td>\n",
       "      <td>-0.464354</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.283944</td>\n",
       "      <td>-1.011412</td>\n",
       "      <td>-0.257445</td>\n",
       "      <td>-0.333482</td>\n",
       "      <td>-0.182334</td>\n",
       "      <td>0.123061</td>\n",
       "      <td>-0.017365</td>\n",
       "      <td>-0.179426</td>\n",
       "      <td>-0.391362</td>\n",
       "      <td>0.225852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>225</th>\n",
       "      <td>473</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.583805</td>\n",
       "      <td>2.014830</td>\n",
       "      <td>-0.660686</td>\n",
       "      <td>-0.565491</td>\n",
       "      <td>-1.672278</td>\n",
       "      <td>-1.285861</td>\n",
       "      <td>-1.305831</td>\n",
       "      <td>-1.745063</td>\n",
       "      <td>...</td>\n",
       "      <td>0.145552</td>\n",
       "      <td>4.409122</td>\n",
       "      <td>0.008881</td>\n",
       "      <td>-0.114565</td>\n",
       "      <td>0.099344</td>\n",
       "      <td>-0.963264</td>\n",
       "      <td>-1.057501</td>\n",
       "      <td>-1.913447</td>\n",
       "      <td>1.315845</td>\n",
       "      <td>-0.249231</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>226</th>\n",
       "      <td>364</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.318739</td>\n",
       "      <td>-0.647666</td>\n",
       "      <td>-0.402145</td>\n",
       "      <td>-0.381613</td>\n",
       "      <td>-0.485202</td>\n",
       "      <td>-0.551311</td>\n",
       "      <td>-0.651448</td>\n",
       "      <td>-0.681180</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.890652</td>\n",
       "      <td>-1.096686</td>\n",
       "      <td>-0.905935</td>\n",
       "      <td>-0.596622</td>\n",
       "      <td>-0.882362</td>\n",
       "      <td>-0.725342</td>\n",
       "      <td>-0.576392</td>\n",
       "      <td>-1.023890</td>\n",
       "      <td>-0.924107</td>\n",
       "      <td>-0.650934</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>227 rows × 32 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      id  y        x0        x1        x2        x3        x4        x5  \\\n",
       "0    133  1  0.254879 -1.046633  0.209656  0.074214 -0.441366 -0.377645   \n",
       "1    273  1 -1.142928 -0.781198 -1.166747 -0.923578  0.628230 -1.021418   \n",
       "2    175  1 -1.451067 -1.406518 -1.456564 -1.092337 -0.708765 -1.168557   \n",
       "3    551  1 -0.879933  0.420589 -0.877527 -0.780484 -1.037534 -0.483880   \n",
       "4    199  0  0.426758  0.723479  0.316885  0.287273  1.000835  0.962702   \n",
       "..   ... ..       ...       ...       ...       ...       ...       ...   \n",
       "222  105  0  0.008451 -0.533675 -0.025652 -0.093843  2.359748  0.990056   \n",
       "223  242  1 -0.763967  0.371736 -0.598731 -0.716671  0.102199  1.466524   \n",
       "224  312  1 -0.430564 -1.510738 -0.453377 -0.460192 -0.568490 -0.212884   \n",
       "225  473  1 -0.583805  2.014830 -0.660686 -0.565491 -1.672278 -1.285861   \n",
       "226  364  1 -0.318739 -0.647666 -0.402145 -0.381613 -0.485202 -0.551311   \n",
       "\n",
       "           x6        x7  ...       x20       x21       x22       x23  \\\n",
       "0   -0.485934  0.347072  ... -0.337360 -0.728193 -0.442587 -0.272757   \n",
       "1   -1.111867 -0.959523  ... -0.493639  0.348620 -0.552483 -0.526877   \n",
       "2   -1.305831 -1.745063  ... -0.666881 -0.779358 -0.708418 -0.637545   \n",
       "3   -0.555498 -0.768581  ... -0.451772  0.453852 -0.431696 -0.494754   \n",
       "4    1.077099  1.053586  ... -0.707304 -1.026834 -0.702973 -0.460212   \n",
       "..        ...       ...  ...       ...       ...       ...       ...   \n",
       "222  1.753070  1.278939  ... -0.051872 -0.531700 -0.225763 -0.124905   \n",
       "223  2.261608  0.109537  ... -0.586035  0.771362 -0.246060 -0.526877   \n",
       "224 -0.457149 -0.464354  ... -0.283944 -1.011412 -0.257445 -0.333482   \n",
       "225 -1.305831 -1.745063  ...  0.145552  4.409122  0.008881 -0.114565   \n",
       "226 -0.651448 -0.681180  ... -0.890652 -1.096686 -0.905935 -0.596622   \n",
       "\n",
       "          x24       x25       x26       x27       x28       x29  \n",
       "0   -0.608018 -0.577235 -0.501126  0.143371 -0.466431 -0.554102  \n",
       "1    2.253098 -0.827620 -0.780739 -0.376997 -0.310239  0.176301  \n",
       "2    0.710369 -0.976454 -1.057501 -1.913447  0.795207 -0.149751  \n",
       "3   -1.182041  0.281228  0.084759 -0.252420  1.038575  0.351054  \n",
       "4   -0.999033 -0.531406 -0.394360 -0.728830 -0.644416 -0.688003  \n",
       "..        ...       ...       ...       ...       ...       ...  \n",
       "222  0.040342  0.203542  0.757183  0.338023 -0.614146  1.249401  \n",
       "223 -0.125998  1.881346  1.886843  0.217988 -0.071715  1.845903  \n",
       "224 -0.182334  0.123061 -0.017365 -0.179426 -0.391362  0.225852  \n",
       "225  0.099344 -0.963264 -1.057501 -1.913447  1.315845 -0.249231  \n",
       "226 -0.882362 -0.725342 -0.576392 -1.023890 -0.924107 -0.650934  \n",
       "\n",
       "[227 rows x 32 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv('../../../../examples/data/breast_homo_guest.csv')\n",
    "df"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "53e6fc84",
   "metadata": {},
   "source": [
    "## Write the Pipeline script and execute it\n",
    "\n",
    "After the upload is complete, we can start writing the pipeline script to submit a FATE task."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9a5ccfe0",
   "metadata": {},
   "source": [
    "### Import Pipeline Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d4eec107",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# torch\n",
    "import torch as t\n",
    "from torch import nn\n",
    "\n",
    "# pipeline\n",
    "from pipeline.component.homo_nn import HomoNN, TrainerParam  # HomoNN Component, TrainerParam for setting trainer parameter\n",
    "from pipeline.backend.pipeline import PipeLine  # pipeline class\n",
    "from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation\n",
    "from pipeline.interface import Data  # Data Interfaces for defining data flow"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b7f7d9d5",
   "metadata": {},
   "source": [
    "We can check the parameters of the Homo-NN component:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fbdf01b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "    Parameters\n",
      "    ----------\n",
      "    name, name of this component\n",
      "    trainer, trainer param\n",
      "    dataset, dataset param\n",
      "    torch_seed, global random seed\n",
      "    loss, loss function from fate_torch\n",
      "    optimizer, optimizer from fate_torch\n",
      "    model, a fate torch sequential defining the model structure\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "print(HomoNN.__doc__)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "48edc92d",
   "metadata": {},
   "source": [
    "### fate_torch_hook\n",
    "\n",
    "Be sure to execute the following fate_torch_hook function. It modifies some torch classes so that the torch layers, Sequential models, optimizers, and loss functions you define in your scripts can be parsed and submitted by the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "955db238",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pipeline import fate_torch_hook\n",
    "t = fate_torch_hook(t)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ee5be800",
   "metadata": {},
   "source": [
    "### pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cc9a174a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2022-12-19 10:50:38.106\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191050351408970\n",
      "\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:38.118\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:39.139\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n",
      "\u001b[0mm2022-12-19 10:50:40.176\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n",
      "\u001b[32m2022-12-19 10:50:40.183\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:02\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:41.217\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:03\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:42.256\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:04\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:43.287\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:05\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:44.330\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:06\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:45.357\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:07\u001b[0m\n",
      "\u001b[0mm2022-12-19 10:50:47.457\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n",
      "\u001b[32m2022-12-19 10:50:47.459\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:09\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:48.484\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:10\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:49.512\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:11\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:50.539\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:12\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:51.590\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:13\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:52.624\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:14\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:53.677\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:15\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:54.699\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:16\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:55.723\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:17\u001b[0m\n",
      "\u001b[0mm2022-12-19 10:50:56.782\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n",
      "\u001b[32m2022-12-19 10:50:56.785\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:18\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:57.814\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:19\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:58.839\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:20\u001b[0m\n",
      "\u001b[32m2022-12-19 10:50:59.865\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:21\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:00.898\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:22\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:01.925\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:23\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:02.953\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:24\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:04.020\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:25\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:05.055\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:26\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:06.080\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:27\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:07.132\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:29\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:08.164\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:30\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:09.207\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:31\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:10.237\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:32\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:11.258\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:33\u001b[0m\n",
      "\u001b[0mm2022-12-19 10:51:13.377\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n",
      "\u001b[32m2022-12-19 10:51:13.379\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:35\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:14.400\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:36\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:15.439\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:37\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:16.469\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:38\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:17.491\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:39\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:18.564\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:40\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:19.589\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:41\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:20.626\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:42\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:21.650\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:43\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:23.691\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212191050351408970\u001b[0m\n",
      "\u001b[32m2022-12-19 10:51:23.693\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:45\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# create a pipeline to submit the job\n",
    "guest = 9999\n",
    "host = 10000\n",
    "arbiter = 10000\n",
    "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n",
    "\n",
    "# read uploaded dataset\n",
    "train_data_0 = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n",
    "train_data_1 = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n",
    "reader_0 = Reader(name=\"reader_0\")\n",
    "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)\n",
    "reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)\n",
    "\n",
    "# The transform component converts the uploaded data to the FATE standard format\n",
    "data_transform_0 = DataTransform(name='data_transform_0')\n",
    "data_transform_0.get_party_instance(\n",
    "    role='guest', party_id=guest).component_param(\n",
    "    with_label=True, output_format=\"dense\")\n",
    "data_transform_0.get_party_instance(\n",
    "    role='host', party_id=host).component_param(\n",
    "    with_label=True, output_format=\"dense\")\n",
    "\n",
    "\"\"\"\n",
    "Define PyTorch model, optimizer and loss\n",
    "\"\"\"\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(30, 1),\n",
    "    nn.Sigmoid()\n",
    ")\n",
    "loss = nn.BCELoss()\n",
    "optimizer = t.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "Create Homo-NN Component\n",
    "\"\"\"\n",
    "nn_component = HomoNN(name='nn_0',\n",
    "                      model=model, # set model\n",
    "                      loss=loss, # set loss\n",
    "                      optimizer=optimizer, # set optimizer\n",
    "                      # Here we use fedavg trainer\n",
    "                      # TrainerParam passes parameters to fedavg_trainer, see below for details about Trainer\n",
    "                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=3, batch_size=128, validation_freqs=1),\n",
    "                      torch_seed=100 # random seed\n",
    "                      )\n",
    "\n",
    "# define the workflow\n",
    "pipeline.add_component(reader_0)\n",
    "pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))\n",
    "pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))\n",
    "pipeline.add_component(Evaluation(name='eval_0'), data=Data(data=nn_component.output.data))\n",
    "\n",
    "pipeline.compile()\n",
    "pipeline.fit()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "af94b45d",
   "metadata": {},
   "source": [
    "## Get Component Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4f7325c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>label</th>\n",
       "      <th>predict_result</th>\n",
       "      <th>predict_score</th>\n",
       "      <th>predict_detail</th>\n",
       "      <th>type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.05885133519768715</td>\n",
       "      <td>{'0': 0.9411486648023129, '1': 0.0588513351976...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.5971069931983948</td>\n",
       "      <td>{'0': 0.4028930068016052, '1': 0.5971069931983...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.7218729257583618</td>\n",
       "      <td>{'0': 0.2781270742416382, '1': 0.7218729257583...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.6514894962310791</td>\n",
       "      <td>{'0': 0.3485105037689209, '1': 0.6514894962310...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>14</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.2351398915052414</td>\n",
       "      <td>{'0': 0.7648601084947586, '1': 0.2351398915052...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>222</th>\n",
       "      <td>551</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.38658156991004944</td>\n",
       "      <td>{'0': 0.6134184300899506, '1': 0.3865815699100...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>223</th>\n",
       "      <td>559</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.5517507195472717</td>\n",
       "      <td>{'0': 0.44824928045272827, '1': 0.551750719547...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>224</th>\n",
       "      <td>562</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.39873841404914856</td>\n",
       "      <td>{'0': 0.6012615859508514, '1': 0.3987384140491...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>225</th>\n",
       "      <td>567</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.6306618452072144</td>\n",
       "      <td>{'0': 0.36933815479278564, '1': 0.630661845207...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>226</th>\n",
       "      <td>568</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.5063760876655579</td>\n",
       "      <td>{'0': 0.49362391233444214, '1': 0.506376087665...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>227 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      id label predict_result        predict_score  \\\n",
       "0      0   0.0              0  0.05885133519768715   \n",
       "1      3   0.0              1   0.5971069931983948   \n",
       "2      5   1.0              1   0.7218729257583618   \n",
       "3      7   1.0              1   0.6514894962310791   \n",
       "4     14   0.0              0   0.2351398915052414   \n",
       "..   ...   ...            ...                  ...   \n",
       "222  551   1.0              0  0.38658156991004944   \n",
       "223  559   0.0              1   0.5517507195472717   \n",
       "224  562   0.0              0  0.39873841404914856   \n",
       "225  567   1.0              1   0.6306618452072144   \n",
       "226  568   0.0              1   0.5063760876655579   \n",
       "\n",
       "                                        predict_detail   type  \n",
       "0    {'0': 0.9411486648023129, '1': 0.0588513351976...  train  \n",
       "1    {'0': 0.4028930068016052, '1': 0.5971069931983...  train  \n",
       "2    {'0': 0.2781270742416382, '1': 0.7218729257583...  train  \n",
       "3    {'0': 0.3485105037689209, '1': 0.6514894962310...  train  \n",
       "4    {'0': 0.7648601084947586, '1': 0.2351398915052...  train  \n",
       "..                                                 ...    ...  \n",
       "222  {'0': 0.6134184300899506, '1': 0.3865815699100...  train  \n",
       "223  {'0': 0.44824928045272827, '1': 0.551750719547...  train  \n",
       "224  {'0': 0.6012615859508514, '1': 0.3987384140491...  train  \n",
       "225  {'0': 0.36933815479278564, '1': 0.630661845207...  train  \n",
       "226  {'0': 0.49362391233444214, '1': 0.506376087665...  train  \n",
       "\n",
       "[227 rows x 6 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get predict scores\n",
    "pipeline.get_component('nn_0').get_output_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ab8afcfd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'best_epoch': 2,\n",
       " 'loss_history': [0.8317702709315632, 0.683187778825802, 0.5690162255375396],\n",
       " 'metrics_summary': {'train': {'auc': [0.732987012987013,\n",
       "    0.9094372294372294,\n",
       "    0.9561904761904763],\n",
       "   'ks': [0.4153246753246753, 0.6851948051948051, 0.7908225108225109]}},\n",
       " 'need_stop': False}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get summary\n",
    "pipeline.get_component('nn_0').get_summary()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7d1df84b",
   "metadata": {},
   "source": [
    "## TrainerParam and the Trainer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7c4a6605",
   "metadata": {},
   "source": [
    "In this version, Homo-NN's training logic and federated aggregation logic are both implemented in the Trainer class. fedavg_trainer is the default Trainer of FATE Homo-NN and implements the standard FedAvg algorithm. TrainerParam serves two purposes:\n",
    "\n",
    "- Use trainer_name='{module name}' to specify which trainer to use. Trainers live in the federatedml.nn.homo.trainer directory, so you can also add your own customized trainer. Customizing the trainer is covered in a dedicated tutorial chapter.\n",
    "- The remaining parameters are passed to the \\_\\_init\\_\\_() interface of the trainer.\n",
    "\n",
    "We can check the parameters of fedavg_trainer in FATE; any of these parameters can be set in TrainerParam."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d742b424",
   "metadata": {},
   "outputs": [],
   "source": [
    "from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a0e9a681",
   "metadata": {},
   "source": [
    "Check the documentation of FedAVGTrainer to learn about the available parameters. When submitting a task, these parameters can be passed through TrainerParam."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "041f1937",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "    Parameters\n",
      "    ----------\n",
      "    epochs: int >0, epochs to train\n",
      "    batch_size: int, -1 means full batch\n",
      "    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number\n",
      "                            mask to local models. These random number masks will eventually cancel out to get 0.\n",
      "    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.\n",
      "                         if True, According to origin paper, weight of a client is: n_local / n_global, where n_local\n",
      "                         is the sample number locally and n_global is the sample number of all clients.\n",
      "                         if False, simply averaging these models.\n",
      "\n",
      "    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between\n",
      "                two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,\n",
      "                stop training\n",
      "    tol: float, tol value for early stop\n",
      "\n",
      "    aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate\n",
      "                             every n epochs.\n",
      "    cuda: bool, use cuda or not\n",
      "    pin_memory: bool, for pytorch DataLoader\n",
      "    shuffle: bool, for pytorch DataLoader\n",
      "    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data\n",
      "    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.\n",
      "                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'\n",
      "                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'\n",
      "                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'\n",
      "    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.\n",
      "    task_type: str, 'auto', 'binary', 'multi', 'regression'\n",
      "               this option decides the return format of this trainer, and the evaluation type when running validation.\n",
      "               if auto, will automatically infer your task type from labels and predict results.\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "print(FedAVGTrainer.__doc__)"
   ]
  },
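  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "trainer-param-sketch",
   "metadata": {},
   "source": [
    "As a rough illustration of how the documented parameters map onto TrainerParam, the fragment below sets several of them explicitly. It is a configuration sketch rather than a recommended setup: the values are illustrative assumptions, the variable name trainer_param is hypothetical, and TrainerParam is assumed to be imported already, as in the HomoNN cell above.\n",
    "\n",
    "```python\n",
    "# Hypothetical configuration sketch; values are illustrative, not tuned.\n",
    "# TrainerParam is assumed to be imported already, as in the HomoNN cell above.\n",
    "trainer_param = TrainerParam(\n",
    "    trainer_name='fedavg_trainer',  # module name of the trainer to use\n",
    "    epochs=10,                      # number of epochs to train\n",
    "    batch_size=-1,                  # -1 means full batch\n",
    "    secure_aggregate=True,          # add random masks to local models\n",
    "    weighted_aggregation=True,      # weight each client by n_local / n_global\n",
    "    early_stop='diff',              # stop when the loss difference between two epochs < tol\n",
    "    tol=1e-4,                       # tolerance used by early_stop\n",
    "    aggregate_every_n_epoch=2,      # aggregate every 2 epochs instead of every epoch\n",
    "    validation_freqs=1,             # validate every epoch\n",
    "    task_type='binary'              # fix the task type instead of 'auto'\n",
    ")\n",
    "```\n",
    "\n",
    "Such an object would be passed as trainer=trainer_param when constructing HomoNN, in place of the inline TrainerParam(...) call used earlier. Every keyword other than trainer_name is forwarded to the \\_\\_init\\_\\_() of the trainer, so only the names listed in the docstring above apply to fedavg_trainer."
   ]
  },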
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bde0e5e8",
   "metadata": {},
   "source": [
    "So far, we have gained a basic understanding of Homo-NN and used it to perform a basic modeling task. In addition, Homo-NN lets you customize models, datasets, and Trainers for more advanced use cases. For further information, refer to the additional tutorials provided."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
