# Homo-Graph Customized Model

In FATE 1.10, we integrated Torch-geometric 2.2 into the FATE framework, so that you can build Graph Neural Networks (GNNs) in a homo federated way. Homo-graph is an extension of the customized model, but it differs from it in the input data and the trainer.

## Install Torch-geometric

PLEASE MAKE SURE YOUR GCC IS ABOVE 5.5, OR THE INSTALLATION MAY FAIL!

For installation instructions, please refer to the [torch-geometric website](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html), or try:


pip install -r {path/to/your/fate/base}/python/requirements-graph-nn.txt
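Because an old compiler is a common cause of failed installs, you may want to check your GCC version first. The helper below is a hypothetical sketch, not part of FATE or torch-geometric. It assumes `gcc` is on your PATH and treats 5.5 itself as acceptable (adjust `minimum` if you read "above 5.5" strictly).

```python
import re
import subprocess


def parse_gcc_version(text):
    """Extract (major, minor) from gcc output such as '9.4.0' or 'gcc (GCC) 5.4.0'."""
    match = re.search(r"(\d+)\.(\d+)(?:\.\d+)?", text)
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    return int(match.group(1)), int(match.group(2))


def gcc_is_recent_enough(text, minimum=(5, 5)):
    """True if the parsed version is at least `minimum` (tuple comparison)."""
    return parse_gcc_version(text) >= minimum


def local_gcc_version_text():
    """Query the local gcc; -dumpfullversion needs GCC 7+, so fall back to -dumpversion."""
    for flag in ("-dumpfullversion", "-dumpversion"):
        try:
            result = subprocess.run(["gcc", flag], capture_output=True, text=True)
        except FileNotFoundError:
            raise RuntimeError("gcc was not found on PATH") from None
        if result.returncode == 0 and result.stdout.strip():
            return result.stdout.strip()
    raise RuntimeError("could not determine the gcc version")
```

On a machine with gcc installed, `gcc_is_recent_enough(local_gcc_version_text())` returns True when the compiler meets the requirement.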

## Cora Dataset

Cora is a graph dataset for multi-class node classification. It has 2708 nodes and about 10k edges; each node has 1433 features and belongs to one of 7 classes.

In federated homo graph modeling, each party holds its own graph dataset with the same feature space, i.e. horizontal federation. The nodes in the parties' graphs need not overlap, and the parties do not exchange any information about their graph datasets during modeling.
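Although the parties never share graph data, they do share model updates, which a FedAvg-style trainer combines by averaging parameters weighted by each party's number of training samples. The sketch below illustrates that weighted average in plain Python with made-up parameter values; it is not FATE's aggregator, whose details may differ.

```python
def fedavg(party_params, party_sizes):
    """Weighted average of per-party parameters (FedAvg-style).

    party_params: list of dicts mapping parameter name -> list of floats
    party_sizes: number of local training samples per party, used as weights
    """
    total = sum(party_sizes)
    averaged = {}
    for name in party_params[0]:
        length = len(party_params[0][name])
        averaged[name] = [
            sum(params[name][i] * size for params, size in zip(party_params, party_sizes)) / total
            for i in range(length)
        ]
    return averaged


# Hypothetical guest and host parameters after a round of local training.
guest = {"w": [1.0, 2.0], "b": [0.0]}
host = {"w": [3.0, 4.0], "b": [1.0]}
global_params = fedavg([guest, host], party_sizes=[140, 140])
print(global_params)  # {'w': [2.0, 3.0], 'b': [0.5]}
```

With equal sample counts (as in this demo, where both parties hold the same split) the weighted average reduces to a plain mean.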

For simplicity, the host and the guest in the demo use the same Cora dataset. It is split by node index (Python slice notation) as follows:

- train: [0:140]
- validation: [200:500]
- test: [500:1500]
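Per node, this split amounts to one flag for each of the 2708 nodes, which appears to be what the `dataset` column of feats.csv encodes. A plain-Python sketch, assuming end indices are exclusive and marking nodes outside all three ranges with a hypothetical placeholder flag `none` (the real files may use a different value):

```python
NUM_NODES = 2708  # Cora node count

# Index ranges from the split above (end index exclusive, as in Python slices).
SPLITS = {"train": range(0, 140), "vali": range(200, 500), "test": range(500, 1500)}


def split_flags(num_nodes, splits, default="none"):
    """Return one split flag per node; nodes outside every range get `default`."""
    flags = [default] * num_nodes
    for flag, index_range in splits.items():
        for i in index_range:
            flags[i] = flag
    return flags


flags = split_flags(NUM_NODES, SPLITS)
counts = {name: flags.count(name) for name in ("train", "vali", "test", "none")}
print(counts)  # {'train': 140, 'vali': 300, 'test': 1000, 'none': 1268}
```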

The preprocessed data can be found in examples/data/cora4fate, which contains a "guest" and a "host" directory. Each directory holds feats.csv and adj.csv, which store the node features and the adjacency matrix (as an edge list) respectively.
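To experiment with your own data, a tiny pair of files in this layout can be generated as sketched below. The column names id, y, dataset, node1 and node2 follow the defaults used later in this tutorial; the feature column names x0/x1, the temporary directory and all values are made up for illustration.

```python
import csv
import os
import tempfile


def write_toy_party(directory):
    """Write a tiny feats.csv / adj.csv pair using the default column names."""
    os.makedirs(directory, exist_ok=True)
    feats = [
        # id, y (label), dataset (split flag), then the hypothetical feature columns
        {"id": 0, "y": 1, "dataset": "train", "x0": 0.0, "x1": 1.0},
        {"id": 1, "y": 0, "dataset": "vali", "x0": 1.0, "x1": 0.0},
        {"id": 2, "y": 1, "dataset": "test", "x0": 1.0, "x1": 1.0},
    ]
    with open(os.path.join(directory, "feats.csv"), "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["id", "y", "dataset", "x0", "x1"])
        writer.writeheader()
        writer.writerows(feats)
    # One row per directed edge node1 -> node2; both directions make it undirected.
    edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
    with open(os.path.join(directory, "adj.csv"), "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["node1", "node2"])
        writer.writerows(edges)


root = tempfile.mkdtemp()
for party in ("guest", "host"):
    write_toy_party(os.path.join(root, party))
print(sorted(os.listdir(os.path.join(root, "guest"))))  # ['adj.csv', 'feats.csv']
```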


## GraphSage Model

Save the model code as graphsage_cora.py, the module name that `CustModel` refers to later. You can put it directly under federatedml/nn/model_zoo, or use the `save_to_fate` cell magic in a Jupyter notebook to save it to federatedml/nn/model_zoo directly.

from pipeline.component.nn import save_to_fate

Then run the following in a separate cell (`%%save_to_fate` must be the first line of its cell):

%%save_to_fate model graphsage_cora.py

from torch import nn
import torch_geometric.nn as pyg


class Sage(nn.Module):
    """Two-layer GraphSAGE for node classification."""

    def __init__(self, in_channels, hidden_channels, class_num):
        super().__init__()
        self.model = nn.ModuleList([
            # node features -> hidden embeddings
            pyg.SAGEConv(in_channels=in_channels, out_channels=hidden_channels, project=True),
            # hidden embeddings -> one score per class
            pyg.SAGEConv(in_channels=hidden_channels, out_channels=class_num),
            # log-probabilities over classes, to be paired with nn.NLLLoss
            nn.LogSoftmax(dim=1)]
        )

    def forward(self, x, edge_index):
        for layer in self.model:
            if isinstance(layer, pyg.SAGEConv):
                # graph convolutions take the edge list as well as the node features
                x = layer(x, edge_index)
            else:
                x = layer(x)
        return x 
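
Each `SAGEConv` layer combines a node's own features with the mean of its neighbors' features. The plain-Python sketch below shows that update, x_i' = W_self x_i + W_neigh mean_{j in N(i)} x_j, for one layer. It is a simplified illustration that ignores the bias and the extra neighbor projection enabled by `project=True`; the weights and the 3-node graph are made up. As in torch-geometric, `edge_index` holds (sources, targets) and messages flow from source to target.

```python
def sage_mean_layer(x, edge_index, w_self, w_neigh):
    """One GraphSAGE-style layer with mean aggregation, in plain Python.

    x: list of node feature vectors (lists of floats)
    edge_index: (sources, targets); messages flow source -> target
    w_self, w_neigh: weight matrices as lists of rows (out_dim x in_dim)
    """
    num_nodes = len(x)
    in_dim = len(x[0])
    # Sum neighbor features at each target node, then divide by its in-degree.
    sums = [[0.0] * in_dim for _ in range(num_nodes)]
    degree = [0] * num_nodes
    for src, dst in zip(*edge_index):
        degree[dst] += 1
        for k in range(in_dim):
            sums[dst][k] += x[src][k]
    means = [[s / degree[i] if degree[i] else 0.0 for s in sums[i]] for i in range(num_nodes)]

    def matvec(w, v):
        return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

    # New embedding = W_self x_i + W_neigh mean of neighbors.
    return [
        [a + b for a, b in zip(matvec(w_self, x[i]), matvec(w_neigh, means[i]))]
        for i in range(num_nodes)
    ]


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edge_index = ([0, 1, 1, 2], [1, 0, 2, 1])  # undirected path 0 - 1 - 2
identity = [[1.0, 0.0], [0.0, 1.0]]
print(sage_mean_layer(x, edge_index, identity, identity))
# [[1.0, 1.0], [1.0, 1.5], [1.0, 2.0]]
```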

Instantiate the model locally to check that it builds:
homosage = Sage(in_channels=1433, hidden_channels=64, class_num=7)
homosage
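
The model's last layer is LogSoftmax, so it outputs log-probabilities, and applying `nn.NLLLoss` to them gives the usual cross-entropy loss. A plain-Python illustration of that pairing for two nodes with made-up class scores (mirroring `NLLLoss`'s default mean reduction):

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of one row of class scores."""
    top = max(scores)
    log_total = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_total for s in scores]


def nll_loss(log_probs, labels):
    """Mean negative log-probability of the true classes."""
    return -sum(row[y] for row, y in zip(log_probs, labels)) / len(labels)


scores = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # raw scores for 2 nodes, 3 classes
loss = nll_loss([log_softmax(row) for row in scores], labels=[0, 2])
print(loss)
```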

## Submit a Homo-NN task with Custom Model

import os
cwd = os.getcwd()
cwd

Initialize a pipeline and specify where the node feature and adjacency matrix files are.

import torch as t
import os
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component.nn import TrainerParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import HomoNN, Evaluation
from pipeline.component.reader import Reader
from pipeline.interface import Data
from pipeline.component.nn import DatasetParam

fate_torch_hook(t)
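# FATE project root, resolved relative to the current working directory; adjust to where your notebook lives
fate_project_path = os.path.abspath("../../../../")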
host = 10000
guest = 9999
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
                                                                        arbiter=arbiter)
data_0 = {"name": "cora_guest", "namespace": "experiment"}
data_1 = {"name": "cora_host", "namespace": "experiment"}

data_path_0 = fate_project_path + '/examples/data/cora4fate/guest'
data_path_1 = fate_project_path + '/examples/data/cora4fate/host'
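
Before binding the tables, it can help to confirm that both directories actually contain the two files. A small hypothetical helper (not part of the FATE pipeline API), demonstrated here on a temporary directory:

```python
import os
import tempfile


def missing_graph_files(data_dir, required=("feats.csv", "adj.csv")):
    """Return the required files absent from `data_dir`; an empty list means all are present."""
    return [name for name in required if not os.path.isfile(os.path.join(data_dir, name))]


# Self-contained demo: a temporary directory that only has feats.csv.
demo_dir = tempfile.mkdtemp()
open(os.path.join(demo_dir, "feats.csv"), "w").close()
print(missing_graph_files(demo_dir))  # ['adj.csv']
```

In the notebook you could call it on `data_path_0` and `data_path_1`; a non-empty result usually means `fate_project_path` does not point at the FATE project root.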
 

Bind the tables, then use `DatasetParam` to specify the following parameters:
1. id_col: name of the id column in the csv, default 'id'
2. label_col: name of the label column in the csv; if None, 'y', 'label' or 'target' is used automatically
3. feature_dtype: dtype of the features; supports int, long, float, double
4. label_dtype: dtype of the labels; supports int, long, float, double
5. feats_name: name of the node feature csv, default 'feats.csv'
6. feats_dataset_col: name of the column indicating which dataset (train, validation or test) a node belongs to, default 'dataset'
7. feats_dataset_train: value in that column that marks the train set
8. feats_dataset_vali: value in that column that marks the validation set
9. feats_dataset_test: value in that column that marks the test set
10. adj_name: name of the adjacency matrix csv, default 'adj.csv'
11. adj_src_col: name of the source node column in the adjacency matrix csv, default 'node1'
12. adj_dst_col: name of the destination node column in the adjacency matrix csv, default 'node2'
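
To make the column roles concrete, here is a hypothetical plain-Python loader (not FATE's graph dataset class) that reads feats.csv and adj.csv contents with the default column names listed above and returns features, labels, split flags and a torch-geometric-style edge list. It assumes every column other than the id, label and dataset columns is a feature, and that node ids are remapped to row positions; FATE's own handling may differ.

```python
import csv
import io


def load_graph(feats_file, adj_file, id_col="id", label_col="y", dataset_col="dataset",
               src_col="node1", dst_col="node2"):
    """Turn feats.csv / adj.csv contents into features, labels, split flags and an edge list."""
    rows = list(csv.DictReader(feats_file))
    # Assumption: all remaining columns are node features.
    feature_cols = [c for c in rows[0] if c not in (id_col, label_col, dataset_col)]
    # Map ids to contiguous row positions, since edge_index refers to nodes 0..N-1.
    position = {row[id_col]: i for i, row in enumerate(rows)}
    features = [[float(row[c]) for c in feature_cols] for row in rows]
    labels = [int(row[label_col]) for row in rows]
    split = [row[dataset_col] for row in rows]
    sources, targets = [], []
    for edge in csv.DictReader(adj_file):
        sources.append(position[edge[src_col]])
        targets.append(position[edge[dst_col]])
    return features, labels, split, (sources, targets)


feats_csv = "id,y,dataset,x0,x1\n10,1,train,0.0,1.0\n11,0,vali,1.0,0.0\n12,1,test,1.0,1.0\n"
adj_csv = "node1,node2\n10,11\n11,10\n11,12\n12,11\n"
features, labels, split, edge_index = load_graph(io.StringIO(feats_csv), io.StringIO(adj_csv))
print(edge_index)  # ([0, 1, 1, 2], [1, 0, 2, 1])
```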

pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)

reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)
dataset_param = DatasetParam(
    "graph",
    id_col='id',
    label_col='y',
    feature_dtype='float',
    label_dtype='long',
    feats_name='feats.csv',
    feats_dataset_col='dataset',
    feats_dataset_train='train',
    feats_dataset_vali='vali',
    feats_dataset_test='test',
    adj_name='adj.csv',
    adj_src_col='node1',
    adj_dst_col='node2')

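Set up the model, loss function and optimizer. `CustModel` loads the `Sage` class from the `graphsage_cora` module saved earlier and passes the remaining keyword arguments to its constructor.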

model = t.nn.Sequential(
    t.nn.CustModel(module_name='graphsage_cora', class_name='Sage', in_channels=1433, hidden_channels=64, class_num=7)
)
loss = nn.NLLLoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.001)

homo_graph_0 = HomoNN(
    name="homo_graph_0",
    model=model,
    loss=loss,
    optimizer=optimizer,
    dataset=dataset_param,
    trainer=TrainerParam(trainer_name='fedavg_graph_trainer', epochs=10, batch_size=10,
                            validation_freqs=1, num_neighbors=[11, 11], task_type='multi'),
    torch_seed=100
)
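
We assume `num_neighbors=[11, 11]` gives the per-hop fan-out for mini-batch neighbor sampling, in the style of torch-geometric's `NeighborLoader`: one entry per SAGEConv layer, with at most 11 neighbors sampled per node at each hop. The sketch below is a simplified plain-Python version of that idea on a toy star graph, not the trainer's actual sampler.

```python
import random


def sample_neighborhood(seeds, neighbors, num_neighbors, rng):
    """Sample a multi-hop neighborhood around `seeds`, one fan-out per hop.

    neighbors: dict node -> list of adjacent nodes
    num_neighbors: e.g. [11, 11] = at most 11 neighbors per node at hop 1 and hop 2
    Returns the set of nodes whose features the seeds' predictions depend on.
    """
    frontier = list(seeds)
    sampled = set(seeds)
    for fanout in num_neighbors:
        next_frontier = []
        for node in frontier:
            adjacent = neighbors.get(node, [])
            picked = adjacent if len(adjacent) <= fanout else rng.sample(adjacent, fanout)
            for other in picked:
                if other not in sampled:
                    sampled.add(other)
                    next_frontier.append(other)
        frontier = next_frontier
    return sampled


# Star graph: node 0 is linked to leaves 1..20, each leaf only to node 0.
star = {0: list(range(1, 21)), **{leaf: [0] for leaf in range(1, 21)}}
sampled = sample_neighborhood([0], star, [11, 11], random.Random(0))
print(len(sampled))  # 12: the seed plus 11 sampled leaves
```

Capping the fan-out keeps each mini-batch small even when some nodes have many neighbors.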

Add all the components to the pipeline, then compile and fit it.

pipeline.add_component(reader_0)
pipeline.add_component(homo_graph_0, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=homo_graph_0.output.data))

pipeline.compile()
pipeline.fit()
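
The `Evaluation` component with `eval_type='multi'` reports multi-class metrics on the HomoNN output. As a rough illustration of one such metric, accuracy, here is a plain-Python sketch over made-up log-probabilities; it assumes the predicted class is the arg-max and is not FATE's evaluation code.

```python
def predict(log_probs):
    """Arg-max class per node (ties go to the lowest class index)."""
    return [max(range(len(row)), key=row.__getitem__) for row in log_probs]


def accuracy(log_probs, labels):
    """Fraction of nodes whose predicted class equals the true label."""
    predictions = predict(log_probs)
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


log_probs = [[-0.1, -2.3], [-1.6, -0.2], [-0.7, -0.7]]
labels = [0, 1, 1]
print(accuracy(log_probs, labels))
```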