{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7a982d9b",
   "metadata": {},
   "source": [
    "# Customize loss function\n",
    "\n",
    "When PyTorch's built-in loss functions do not meet your needs, you can train your model with a custom loss.\n",
    "\n",
    "## A little problem with the MNIST example\n",
    "\n",
    "You might notice that in the MNIST example from the last tutorial, [Customize your Dataset](Homo-NN-Customize-your-Dataset.ipynb), the classifier output scores are the result of the Softmax function, and the loss is computed with torch's built-in CrossEntropyLoss. However, the documentation ([CrossEntropyLoss Doc](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss)) states that the input is expected to contain the unnormalized logits for each class. In other words, that example computes Softmax twice.\n",
    "To fix this problem, we can use a customized CrossEntropyLoss. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "40b31519",
   "metadata": {},
   "source": [
    "## Develop a Custom loss\n",
    "\n",
    "A customized loss is a class that subclasses torch.nn.Module and implements the forward function. The FATE trainer calls the loss function with two arguments, the predicted scores and the label (loss_fn(pred, label)), so a loss function used with FATE's trainer must accept these two inputs. However, if you use your own trainer and define your own training process, you are not restricted in how you use the loss function."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "054985e4",
   "metadata": {},
   "source": [
    "### A New CrossEntropy Loss\n",
    "\n",
    "Here we implement a new CrossEntropyLoss that skips the softmax computation. We can use the Jupyter interface save_to_fate to save the code to federatedml.nn.loss as ce.py. Alternatively, you can copy the code file to that directory manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "808626e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from federatedml.util import consts\n",
    "from torch.nn.functional import one_hot\n",
    "\n",
    "\n",
    "def cross_entropy(p2, p1, reduction='mean'):\n",
    "    # p2: predicted probabilities (already normalized), p1: one-hot labels\n",
    "    p2 = p2 + consts.FLOAT_ZERO  # small offset so that log(0) cannot produce nan\n",
    "    assert p2.shape == p1.shape\n",
    "    if reduction == 'sum':\n",
    "        return -t.sum(p1 * t.log(p2))\n",
    "    elif reduction == 'mean':\n",
    "        return -t.mean(t.sum(p1 * t.log(p2), dim=1))\n",
    "    elif reduction == 'none':\n",
    "        return -t.sum(p1 * t.log(p2), dim=1)\n",
    "    else:\n",
    "        raise ValueError('unknown reduction')\n",
    "\n",
    "\n",
    "class CrossEntropyLoss(t.nn.Module):\n",
    "\n",
    "    \"\"\"\n",
    "    A CrossEntropy Loss that will not compute Softmax\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, reduction='mean'):\n",
    "        super(CrossEntropyLoss, self).__init__()\n",
    "        self.reduction = reduction\n",
    "\n",
    "    def forward(self, pred, label):\n",
    "\n",
    "        # fix the one-hot width to the number of predicted classes, so a batch\n",
    "        # that lacks the highest class id still matches the shape of pred\n",
    "        one_hot_label = one_hot(label.flatten(), num_classes=pred.shape[1])\n",
    "        loss_ = cross_entropy(pred, one_hot_label, self.reduction)\n",
    "\n",
    "        return loss_"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2122f107",
   "metadata": {},
   "source": [
    "## Train with New Loss"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d258b9d2",
   "metadata": {},
   "source": [
    "### Import Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1518af62",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import nn\n",
    "from pipeline import fate_torch_hook\n",
    "from pipeline.component import HomoNN\n",
    "from pipeline.backend.pipeline import PipeLine\n",
    "from pipeline.component import Reader, Evaluation, DataTransform\n",
    "from pipeline.interface import Data, Model\n",
    "\n",
    "t = fate_torch_hook(t)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8315687c",
   "metadata": {},
   "source": [
    "### Bind data path to name & namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d900c35a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'namespace': 'experiment', 'table_name': 'mnist_host'}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "# bind data path to name & namespace\n",
    "fate_project_path = os.path.abspath('../../../../')\n",
    "arbiter = 10000\n",
    "host = 10000\n",
    "guest = 9999\n",
    "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,\n",
    "                                                                            arbiter=arbiter)\n",
    "\n",
    "data_0 = {\"name\": \"mnist_guest\", \"namespace\": \"experiment\"}\n",
    "data_1 = {\"name\": \"mnist_host\", \"namespace\": \"experiment\"}\n",
    "\n",
    "data_path_0 = fate_project_path + '/examples/data/mnist'\n",
    "data_path_1 = fate_project_path + '/examples/data/mnist'\n",
    "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)\n",
    "pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d3af79ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "reader_0 = Reader(name=\"reader_0\")\n",
    "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n",
    "reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "86d4085a",
   "metadata": {},
   "source": [
    "## Use CustLoss\n",
    "\n",
    "After calling fate_torch_hook, we can use t.nn.CustLoss to specify our own loss. Pass the module name and the class name as parameters, followed by the initialization parameters for your loss class. **The initialization parameters must be JSON-serializable; otherwise the pipeline cannot be submitted.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "de9917a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pipeline.component.homo_nn import TrainerParam, DatasetParam  # Interface\n",
    "\n",
    "# your loss class\n",
    "loss = t.nn.CustLoss(loss_module_name='cross_entropy', class_name='CrossEntropyLoss', reduction='mean')\n",
    "\n",
    "# our simple classification model:\n",
    "model = t.nn.Sequential(\n",
    "    t.nn.Linear(784, 32),\n",
    "    t.nn.ReLU(),\n",
    "    t.nn.Linear(32, 10),\n",
    "    t.nn.Softmax(dim=1)\n",
    ")\n",
    "\n",
    "nn_component = HomoNN(name='nn_0',\n",
    "                      model=model, # model\n",
    "                      loss=loss,  # loss\n",
    "                      optimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizer\n",
    "                      dataset=DatasetParam(dataset_name='mnist_dataset', flatten_feature=True),  # dataset\n",
    "                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),\n",
    "                      torch_seed=100 # random seed\n",
    "                      )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "62361f34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pipeline.backend.pipeline.PipeLine at 0x7f15d02f6b50>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipeline.add_component(reader_0)\n",
    "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n",
    "pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "1fa46219",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-12-19 18:39:12.858 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212191839119838210\n",
      "2022-12-19 18:39:24.383 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:11\n",
      "2022-12-19 18:39:50.946 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:38\n",
      "2022-12-19 18:40:07.456 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component eval_0, time elapse: 0:00:54\n",
      "2022-12-19 18:40:10.543 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:89 - Job is success!!! Job id is 202212191839119838210\n",
      "2022-12-19 18:40:10.545 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:90 - Total time: 0:00:57\n"
     ]
    }
   ],
   "source": [
    "pipeline.compile()\n",
    "pipeline.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "0edf9014",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>label</th>\n",
       "      <th>predict_result</th>\n",
       "      <th>predict_score</th>\n",
       "      <th>predict_detail</th>\n",
       "      <th>type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>img_1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.9070178270339966</td>\n",
       "      <td>{'0': 0.9070178270339966, '1': 0.0023874549660...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>img_3</td>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>0.19601570069789886</td>\n",
       "      <td>{'0': 0.19484134018421173, '1': 0.044997252523...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>img_4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.9618675112724304</td>\n",
       "      <td>{'0': 0.9618675112724304, '1': 0.0010393995326...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>img_5</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.33044907450675964</td>\n",
       "      <td>{'0': 0.33044907450675964, '1': 0.033256266266...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>img_6</td>\n",
       "      <td>7</td>\n",
       "      <td>7</td>\n",
       "      <td>0.3145765960216522</td>\n",
       "      <td>{'0': 0.05851678550243378, '1': 0.075524508953...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1304</th>\n",
       "      <td>img_32537</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>0.20599651336669922</td>\n",
       "      <td>{'0': 0.080563984811306, '1': 0.12380836158990...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1305</th>\n",
       "      <td>img_32558</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>0.20311488211154938</td>\n",
       "      <td>{'0': 0.07224143296480179, '1': 0.130610913038...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1306</th>\n",
       "      <td>img_32563</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>0.2071550488471985</td>\n",
       "      <td>{'0': 0.06843454390764236, '1': 0.129064396023...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1307</th>\n",
       "      <td>img_32565</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0.29367145895957947</td>\n",
       "      <td>{'0': 0.05658009275794029, '1': 0.086584843695...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1308</th>\n",
       "      <td>img_32573</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>0.199515700340271</td>\n",
       "      <td>{'0': 0.08787216246128082, '1': 0.127247273921...</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1309 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "             id label predict_result        predict_score  \\\n",
       "0         img_1     0              0   0.9070178270339966   \n",
       "1         img_3     4              6  0.19601570069789886   \n",
       "2         img_4     0              0   0.9618675112724304   \n",
       "3         img_5     0              0  0.33044907450675964   \n",
       "4         img_6     7              7   0.3145765960216522   \n",
       "...         ...   ...            ...                  ...   \n",
       "1304  img_32537     1              8  0.20599651336669922   \n",
       "1305  img_32558     1              8  0.20311488211154938   \n",
       "1306  img_32563     1              8   0.2071550488471985   \n",
       "1307  img_32565     1              5  0.29367145895957947   \n",
       "1308  img_32573     1              8    0.199515700340271   \n",
       "\n",
       "                                         predict_detail   type  \n",
       "0     {'0': 0.9070178270339966, '1': 0.0023874549660...  train  \n",
       "1     {'0': 0.19484134018421173, '1': 0.044997252523...  train  \n",
       "2     {'0': 0.9618675112724304, '1': 0.0010393995326...  train  \n",
       "3     {'0': 0.33044907450675964, '1': 0.033256266266...  train  \n",
       "4     {'0': 0.05851678550243378, '1': 0.075524508953...  train  \n",
       "...                                                 ...    ...  \n",
       "1304  {'0': 0.080563984811306, '1': 0.12380836158990...  train  \n",
       "1305  {'0': 0.07224143296480179, '1': 0.130610913038...  train  \n",
       "1306  {'0': 0.06843454390764236, '1': 0.129064396023...  train  \n",
       "1307  {'0': 0.05658009275794029, '1': 0.086584843695...  train  \n",
       "1308  {'0': 0.08787216246128082, '1': 0.127247273921...  train  \n",
       "\n",
       "[1309 rows x 6 columns]"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipeline.get_component('nn_0').get_output_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "8592212b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'best_epoch': 1,\n",
       " 'loss_history': [3.58235876026547, 3.4448592824914055],\n",
       " 'metrics_summary': {'train': {'accuracy': [0.25668449197860965,\n",
       "    0.4950343773873186],\n",
       "   'precision': [0.3708616690797323, 0.5928620913124757],\n",
       "   'recall': [0.21817632850241547, 0.4855654369784805]}},\n",
       " 'need_stop': False}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipeline.get_component('nn_0').get_summary()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
