{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0d19012c",
   "metadata": {},
   "source": [
    "# Hetero NN: A federated task with guest using image data and host using text data\n",
    "\n",
    "In this task, we will show you how to build a federated task under Hetero-NN, in which the participating parties use different structured data: the guest party has image data and labels, and the host party has text, and together they complete a binary classification task. The tutorial dataset is built by flickr 8k, and labels 0 and 1 indicate whether the image is in the wilderness or in the city. You can download the processed dataset from here and put it under examples/data. The complete dataset can be downloaded from here. (Please note that the original dataset is different from the data in this example, and this dataset is annotated with a small portion of the complete dataset for demonstration purposes.)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "08010311",
   "metadata": {},
   "source": [
    "## Get the example dataset:\n",
    "\n",
    "Please down load the dataset from:\n",
    "- https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/flicker_toy_data.zip\n",
    "and put it under /examples/data folder\n",
    "\n",
    "The origin of this dataset is the flickr-8k dataset from:\n",
    "- https://www.kaggle.com/datasets/adityajn105/flickr8k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "15b81e88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pipeline.component.nn import save_to_fate"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7bdde3a2",
   "metadata": {},
   "source": [
    "### Guest Bottom Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "98e9a771",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%save_to_fate model guest_bottom_image.py\n",
    "from torch import nn\n",
    "import torch as t\n",
    "from torch.nn import functional as F\n",
    "\n",
    "class ImgBottomNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ImgBottomNet, self).__init__()\n",
    "        self.seq = t.nn.Sequential(\n",
    "            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),\n",
    "            nn.MaxPool2d(kernel_size=3),\n",
    "            nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3),\n",
    "            nn.AvgPool2d(kernel_size=5)\n",
    "        )\n",
    "        \n",
    "        self.fc = t.nn.Sequential(\n",
    "            nn.Linear(1176, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, 8)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.seq(x)\n",
    "        x = x.flatten(start_dim=1)\n",
    "        x = self.fc(x)\n",
    "        return x\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b55fce67",
   "metadata": {},
   "source": [
    "## Guest Top Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f2a81854",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%save_to_fate model guest_top_image.py\n",
    "from torch import nn\n",
    "import torch as t\n",
    "from torch.nn import functional as F\n",
    "\n",
    "class ImgTopNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ImgTopNet, self).__init__()\n",
    "        \n",
    "        self.fc = t.nn.Sequential(\n",
    "            nn.Linear(4, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc(x)\n",
    "        return x.flatten()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "abea2f8e",
   "metadata": {},
   "source": [
    "### Host Bottom Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "84d92fa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%save_to_fate model host_bottom_lstm.py\n",
    "from torch import nn\n",
    "import torch as t\n",
    "from torch.nn import functional as F\n",
    "\n",
    "class LSTMBottom(nn.Module):\n",
    "    \n",
    "    def __init__(self, vocab_size):\n",
    "        super(LSTMBottom, self).__init__()\n",
    "        self.word_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=16, padding_idx=0)\n",
    "        self.lstm = t.nn.Sequential(\n",
    "            nn.LSTM(input_size=16, hidden_size=16, num_layers=2, batch_first=True)\n",
    "        )\n",
    "        self.act = nn.ReLU()\n",
    "        self.linear = nn.Linear(16, 8)\n",
    "\n",
    "    def forward(self, x):\n",
    "        embeddings = self.word_embed(x)\n",
    "        lstm_fw, _ = self.lstm(embeddings)\n",
    "        \n",
    "        return self.act(self.linear(lstm_fw.sum(dim=1)))    "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b67515a5",
   "metadata": {},
   "source": [
    "### Locally test dataset and bottom model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e6bca16b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-12-26 20:45:42.535744: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
      "2022-12-26 20:45:42.535777: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
     ]
    }
   ],
   "source": [
    "from federatedml.nn.dataset.image import ImageDataset\n",
    "from federatedml.nn.dataset.nlp_tokenizer import TokenizerDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f3495837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# flicke image\n",
    "img_ds = ImageDataset(center_crop=True, center_crop_shape=(224, 224), return_label=True) # return label = True\n",
    "img_ds.load('../../../../examples/data/flicker_toy_data/flicker/images/')\n",
    "# text\n",
    "txt_ds = TokenizerDataset(return_label=False) \n",
    "txt_ds.load('../../../../examples/data/flicker_toy_data/text.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7542d6d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "215\n",
      "(tensor([[[0.5059, 0.5176, 0.5137,  ..., 0.4941, 0.5020, 0.5059],\n",
      "         [0.4980, 0.5020, 0.4980,  ..., 0.4824, 0.5020, 0.5059],\n",
      "         [0.5059, 0.4863, 0.4902,  ..., 0.4980, 0.4980, 0.5137],\n",
      "         ...,\n",
      "         [0.7843, 0.7922, 0.7529,  ..., 0.1412, 0.2078, 0.2196],\n",
      "         [0.9922, 0.9922, 0.9647,  ..., 0.1176, 0.0941, 0.1333],\n",
      "         [0.9961, 0.9922, 1.0000,  ..., 0.1647, 0.1294, 0.1373]],\n",
      "\n",
      "        [[0.5765, 0.5882, 0.5843,  ..., 0.5490, 0.5569, 0.5608],\n",
      "         [0.5686, 0.5804, 0.5765,  ..., 0.5490, 0.5529, 0.5529],\n",
      "         [0.5608, 0.5569, 0.5647,  ..., 0.5569, 0.5490, 0.5529],\n",
      "         ...,\n",
      "         [0.7961, 0.8039, 0.7490,  ..., 0.1373, 0.1882, 0.2000],\n",
      "         [0.9961, 0.9961, 0.9608,  ..., 0.1137, 0.1137, 0.1529],\n",
      "         [0.9922, 0.9922, 1.0000,  ..., 0.1608, 0.1059, 0.1216]],\n",
      "\n",
      "        [[0.6235, 0.6353, 0.6314,  ..., 0.5922, 0.6000, 0.6118],\n",
      "         [0.6078, 0.6235, 0.6196,  ..., 0.5804, 0.5882, 0.6000],\n",
      "         [0.6039, 0.6118, 0.6196,  ..., 0.5843, 0.5843, 0.6000],\n",
      "         ...,\n",
      "         [0.5882, 0.5961, 0.5686,  ..., 0.1216, 0.1765, 0.1882],\n",
      "         [0.7294, 0.7373, 0.7373,  ..., 0.0980, 0.0980, 0.1294],\n",
      "         [0.8745, 0.8431, 0.8627,  ..., 0.1451, 0.1059, 0.1176]]]), tensor(0))\n",
      "[0, 1]\n",
      "['1022454428_b6b660a67b', '103195344_5d2dc613a3', '1055753357_4fa3d8d693', '1124448967_2221af8dc5', '1131804997_177c3c0640', '1138784872_69ade3f2ab', '1142847777_2a0c1c2551', '1143373711_2e90b7b799', '1143882946_1898d2eeb9', '1144288288_e5c9558b6a']\n"
     ]
    }
   ],
   "source": [
    "print(len(img_ds))\n",
    "print(img_ds[0])\n",
    "print(img_ds.get_classes())\n",
    "print(img_ds.get_sample_ids()[0: 10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6fae43a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "215\n",
      "tensor([  101,  1037,  2158,  1998,  2450,  2729,  2005,  2019, 10527,  2247,\n",
      "         1996,  2217,  1997,  1037,  2303,  1997,  2300,  1012,   102,     0,\n",
      "            0,     0,     0,     0,     0,     0])\n",
      "30522\n"
     ]
    }
   ],
   "source": [
    "print(len(txt_ds))\n",
    "print(txt_ds[0]) # word idx\n",
    "print(txt_ds.get_vocab_size()) # vocab size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c857db34",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_bottom = ImgBottomNet()\n",
    "lstm_bottom = LSTMBottom(vocab_size=txt_ds.get_vocab_size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c4fd9b82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 1.8284, 0.0000, 0.0000, 2.3009, 0.0626, 0.0678, 0.0000],\n",
       "        [0.0369, 1.8046, 0.0000, 0.0000, 2.4555, 0.0000, 0.0000, 0.0000]],\n",
       "       grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm_bottom(t.vstack([txt_ds[0], txt_ds[1]]))  # test forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9add8cc6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0987,  0.0808, -0.0140, -0.0718, -0.1381,  0.2642, -0.1874, -0.0494],\n",
       "        [ 0.0856,  0.0948, -0.0362, -0.0702, -0.0695,  0.2293, -0.1768, -0.0638]],\n",
       "       grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_bottom(t.vstack([img_ds[0][0].unsqueeze(dim=0), img_ds[1][0].unsqueeze(dim=0)])) "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9950c6bd",
   "metadata": {},
   "source": [
    "### Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d4a51493",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'namespace': 'experiment', 'table_name': 'flicker_host'}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import torch as t\n",
    "from torch import nn\n",
    "from pipeline import fate_torch_hook\n",
    "from pipeline.component import HeteroNN\n",
    "from pipeline.component.hetero_nn import DatasetParam\n",
    "from pipeline.backend.pipeline import PipeLine\n",
    "from pipeline.component import Reader, Evaluation, DataTransform\n",
    "from pipeline.interface import Data, Model\n",
    "from pipeline.component.nn import save_to_fate\n",
    "\n",
    "fate_torch_hook(t)\n",
    "\n",
    "fate_project_path = os.path.abspath('../../../../')\n",
    "guest = 10000\n",
    "host = 9999\n",
    "\n",
    "pipeline_mix = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)\n",
    "\n",
    "guest_data = {\"name\": \"flicker_guest\", \"namespace\": \"experiment\"}\n",
    "host_data = {\"name\": \"flicker_host\", \"namespace\": \"experiment\"}\n",
    "\n",
    "guest_data_path = fate_project_path + '/examples/data/flicker_toy_data/flicker/images'\n",
    "host_data_path = fate_project_path + '/examples/data/flicker_toy_data/text.csv'\n",
    "\n",
    "pipeline_mix.bind_table(name='flicker_guest', namespace='experiment', path=guest_data_path)\n",
    "pipeline_mix.bind_table(name='flicker_host', namespace='experiment', path=host_data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "50efe200",
   "metadata": {},
   "outputs": [],
   "source": [
    "reader_0 = Reader(name=\"reader_0\")\n",
    "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)\n",
    "reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4475a968",
   "metadata": {},
   "outputs": [],
   "source": [
    "hetero_nn_0 = HeteroNN(name=\"hetero_nn_0\", epochs=5,\n",
    "                       interactive_layer_lr=0.001, batch_size=64, validation_freqs=1, task_type='classification')\n",
    "guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)\n",
    "host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a2a591e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "guest_bottom = t.nn.Sequential(\n",
    "    nn.CustModel(module_name='guest_bottom_image', class_name='ImgBottomNet')\n",
    ")\n",
    "\n",
    "guest_top = t.nn.Sequential(\n",
    "    nn.CustModel(module_name='guest_top_image', class_name='ImgTopNet')\n",
    ")\n",
    "# bottom model\n",
    "host_bottom = nn.CustModel(module_name='host_bottom_lstm', class_name='LSTMBottom', vocab_size=txt_ds.get_vocab_size())\n",
    "\n",
    "interactive_layer = t.nn.InteractiveLayer(out_dim=4, guest_dim=8, host_dim=8, host_num=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b8799751",
   "metadata": {},
   "outputs": [],
   "source": [
    "guest_nn_0.add_top_model(guest_top)\n",
    "guest_nn_0.add_bottom_model(guest_bottom)\n",
    "host_nn_0.add_bottom_model(host_bottom)\n",
    "optimizer = t.optim.Adam(lr=0.001)\n",
    "loss = t.nn.BCELoss()\n",
    "\n",
    "hetero_nn_0.set_interactive_layer(interactive_layer)\n",
    "hetero_nn_0.compile(optimizer=optimizer, loss=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ea5dae1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 添加dataset\n",
    "guest_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=True, center_crop=True, center_crop_shape=(224, 224), label_dtype='float'))\n",
    "host_nn_0.add_dataset(DatasetParam(dataset_name='nlp_tokenizer', return_label=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "aa06e1bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pipeline.backend.pipeline.PipeLine at 0x7fb269222be0>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipeline_mix.add_component(reader_0)\n",
    "pipeline_mix.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))\n",
    "pipeline_mix.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bdc60285",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline_mix.fit()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
