How should graph-based deep learning be done?

Time:2021-10-25


Deep learning is widely used because it can extract complex patterns from all kinds of complex data (such as free text, images, and video). In practice, however, many data sets are more naturally expressed as graphs, for example the relationships between people on a social network.

Traditional neural network architectures used in deep learning, such as convolutional neural networks and recurrent neural networks, have difficulty handling this kind of data. It may be time to introduce a new method.

Graph neural networks (GNNs)

Graph neural networks (GNNs) are one of the most flourishing directions in machine learning. They are typically used to train predictive models on graph data sets such as:

  • Social network data sets, where a graph shows the connections between acquaintances;
  • Recommender system data sets, where a graph shows the interactions between customers and items;
  • Chemical analysis data sets, where compounds are modeled as graphs made of atoms and chemical bonds;
  • Network security data sets, where a graph describes the connections between source and destination IP addresses.

In most cases these data sets are very large and only partially labeled. Take a typical fraud detection scenario: we try to predict how likely a person is to be a fraudster by analyzing their connections to known fraudsters. This can be framed as a semi-supervised learning task in which only a small fraction of the graph nodes are labeled (“fraudster” or “legitimate”). That should be a better solution than trying to build a large hand-labeled data set and “linearizing” it so that traditional machine learning algorithms can be applied.
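
To make the semi-supervised setup concrete, here is a minimal sketch that spreads a few known labels across a small graph. It uses plain label propagation rather than a GNN, purely to illustrate how a handful of labeled nodes can inform the unlabeled ones. The function name, the node names, and the 0.5 starting score are assumptions made for this illustration.

```python
from collections import defaultdict


def propagate_labels(edges, known, iterations=20):
    """Estimate a fraud score in [0, 1] for every node in the graph.

    edges: iterable of (u, v) undirected relationships.
    known: dict mapping a node to 1.0 (fraudster) or 0.0 (legitimate).
    Unlabeled nodes start at 0.5 and repeatedly take the mean score of
    their neighbors; labeled nodes stay clamped to their label.
    """
    neighbors = defaultdict(set)
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)

    scores = {node: known.get(node, 0.5) for node in neighbors}
    for _ in range(iterations):
        updated = {}
        for node, adjacent in neighbors.items():
            if node in known:
                updated[node] = known[node]  # keep ground truth fixed
            else:
                updated[node] = sum(scores[a] for a in adjacent) / len(adjacent)
        scores = updated
    return scores


# Hypothetical chain of acquaintances with only two labeled people.
edges = [("alice", "bob"), ("bob", "carol"), ("carol", "dave"), ("dave", "erin")]
known = {"alice": 1.0, "erin": 0.0}
scores = propagate_labels(edges, known)
```

In this toy graph, nodes closer to the known fraudster end up with higher scores than nodes closer to the known legitimate user. A GNN replaces the fixed averaging rule with learned transformations of node features.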

For a further introduction to GNNs, please refer to these references.

Amazon SageMaker now supports the open source Deep Graph Library

Applying GNNs in practice usually requires domain knowledge (retail, finance, chemistry, and so on), computer science knowledge (Python, deep learning, open source tools), and infrastructure knowledge (training, deploying, and scaling models). These requirements are many and complex, and few people master all of these skills.
Amazon SageMaker's support for the Deep Graph Library addresses these problems.
The Deep Graph Library (DGL), first released on GitHub in December 2018, is an open source Python library that helps researchers and scientists quickly build, train, and evaluate GNNs on their own data sets.


DGL is built on top of popular deep learning frameworks such as PyTorch and Apache MXNet. If you are familiar with either of them, you will find it easy to use. Whichever framework you choose, the beginner-friendly examples make it easy to get started. The slides and code from the GTC 2019 workshop can also help you get up to speed quickly.
Once you have worked through the toy examples, you can start exploring the various cutting-edge models implemented in DGL. For example, the following command trains a document classification model using a graph convolutional network (GCN) and the CORA data set:

`$ python3 train.py --dataset cora --gpu 0 --self-loop`
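
To give a sense of what a single graph convolution computes, here is a minimal pure-Python sketch of one GCN layer using the commonly cited propagation rule ReLU(D^-1/2 (A + I) D^-1/2 X W). It is an illustration only and not DGL's implementation. The function name, the toy graph, and the `self_loop` parameter (intended to mirror what a self-loop option typically does, namely adding each node to its own neighborhood) are assumptions.

```python
import math


def gcn_layer(edges, num_nodes, features, weight, self_loop=True):
    """One graph convolution step: ReLU(D^-1/2 (A + I) D^-1/2 X W)."""
    # Dense symmetric adjacency matrix; acceptable for a toy graph only.
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        adj[u][v] = adj[v][u] = 1.0
    if self_loop:
        for i in range(num_nodes):
            adj[i][i] = 1.0

    # Symmetric normalization by the degrees of both endpoints.
    degree = [sum(row) for row in adj]
    norm = [
        [adj[i][j] / math.sqrt(degree[i] * degree[j]) if adj[i][j] else 0.0
         for j in range(num_nodes)]
        for i in range(num_nodes)
    ]

    in_dim, out_dim = len(weight), len(weight[0])
    output = []
    for i in range(num_nodes):
        # Aggregate the normalized features of the node's neighborhood.
        agg = [sum(norm[i][j] * features[j][k] for j in range(num_nodes))
               for k in range(in_dim)]
        # Apply the linear transform followed by ReLU.
        output.append([max(0.0, sum(agg[k] * weight[k][c] for k in range(in_dim)))
                       for c in range(out_dim)])
    return output


# Hypothetical 3-node path graph 0 - 1 - 2 with 2-dimensional features.
edges = [(0, 1), (1, 2)]
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
hidden = gcn_layer(edges, 3, features, identity)
```

Stacking several such layers lets information travel across multiple hops of the graph, which is what a GCN document classifier relies on when it uses citation links between documents.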

The code of every model is available for you to read and modify. These implementations have been carefully validated by the AWS team, who verified the reported performance and made sure the results can be reproduced.

DGL also includes a collection of graph data sets that can easily be downloaded and used for experiments.
Of course, you can install and run DGL locally, but for convenience AWS has added it to the Deep Learning Containers for PyTorch and Apache MXNet. This makes it easy to use DGL on Amazon SageMaker to train and deploy models at any scale without having to manage servers.

Using DGL on Amazon SageMaker

AWS has added complete examples to the SageMaker examples repository on GitHub. One of them trains a simple GNN on the Tox21 data set for molecular toxicity prediction.

The problem to solve is predicting the potential toxicity of new compounds against 12 different targets (receptors inside biological cells, and so on). As you can imagine, this kind of analysis is very important when designing new drugs: being able to predict results quickly without running in vitro experiments helps researchers focus on the most promising drug candidates.

The data set contains more than 8,000 compounds: each compound is modeled as a graph (atoms are vertices, atomic bonds are edges) and labeled 12 times (one label per target). We will use a GNN to build a multi-label binary classification model that predicts the potential toxicity of candidate molecules.
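
As a rough illustration of what multi-label binary classification means here, the sketch below computes a binary cross-entropy loss independently for each (compound, target) pair and thresholds each output separately. It assumes that some compounds may lack a measurement for some targets, so a mask marks which labels exist. The function names and the three-target toy data are hypothetical and are not taken from the example's training script.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def masked_bce(logits, labels, mask):
    """Mean binary cross-entropy over the labeled (compound, target) pairs.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)).
    """
    total, count = 0.0, 0
    for row_logits, row_labels, row_mask in zip(logits, labels, mask):
        for z, y, m in zip(row_logits, row_labels, row_mask):
            if not m:
                continue  # no measurement for this target, skip it
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count if count else 0.0


def predict(logits, threshold=0.5):
    """One independent yes/no decision per target."""
    return [[1 if sigmoid(z) >= threshold else 0 for z in row] for row in logits]


# One compound, three targets (instead of 12) to keep the example short.
logits = [[0.0, 2.0, -1.0]]
labels = [[1, 1, 0]]
mask = [[1, 0, 1]]  # the second target has no label for this compound
loss = masked_bce(logits, labels, mask)
predictions = predict(logits)
```

Each target gets its own sigmoid output, so a molecule can be predicted toxic for several targets at once, unlike a softmax classifier that picks a single class.
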
In the training script, we can easily download the data set from the DGL collection.

from dgl.data.chem import Tox21
dataset = Tox21()

Similarly, we can easily build a GNN classifier using the DGL model zoo:

from dgl import model_zoo
model = model_zoo.chem.GCNClassifier(
    in_feats=args['n_input'],
    gcn_hidden_feats=[args['n_hidden'] for _ in range(args['n_layers'])],
    n_tasks=dataset.n_tasks,
    classifier_hidden_feats=args['n_hidden']).to(args['device'])

Most of the remaining code is vanilla PyTorch. If you are familiar with that library, you should find your way around easily.

To run this code on Amazon SageMaker, all we have to do is use a SageMaker Estimator, passing the full name of the DGL container and the name of the training script as a hyperparameter.

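import sagemaker

# container, role, sess, bucket, CODE_PATH and custom_code_upload_location
# are assumed to be defined earlier in the notebook.
estimator = sagemaker.estimator.Estimator(
    container,
    role,
    train_instance_count=1,
    train_instance_type='ml.p3.2xlarge',
    hyperparameters={'entrypoint': 'main.py'},
    sagemaker_session=sess)

# Upload the training code to S3 and pass its location as an input channel.
code_location = sess.upload_data(
    CODE_PATH,
    bucket=bucket,
    key_prefix=custom_code_upload_location)
estimator.fit({'training-code': code_location})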
<output removed>
epoch 23/100, batch 48/49, loss 0.4684
epoch 23/100, batch 49/49, loss 0.5389
epoch 23/100, training roc-auc 0.9451
EarlyStopping counter: 10 out of 10
epoch 23/100, validation roc-auc 0.8375, best validation roc-auc 0.8495
Best validation score 0.8495
Test score 0.8273
2019-11-21 14:11:03 Uploading - Uploading generated training model
2019-11-21 14:11:03 Completed - Training job completed
Training seconds: 209
Billable seconds: 209
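
The "EarlyStopping counter: 10 out of 10" line in the log suggests that training stops once the validation score has failed to improve for 10 consecutive epochs. The following is a minimal sketch of that idea. The class name, its interface, and the toy scores are assumptions for illustration and may differ from the helper used in the actual example.

```python
class EarlyStopping:
    """Stop training when the validation score stops improving."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_score = None
        self.counter = 0

    def step(self, score):
        """Record a validation score; return True when training should stop."""
        if self.best_score is None or score > self.best_score:
            # New best: remember it and reset the counter.
            self.best_score = score
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Hypothetical validation ROC-AUC values, with a short patience for brevity.
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, score in enumerate([0.80, 0.85, 0.84, 0.83, 0.90], start=1):
    if stopper.step(score):
        stopped_at = epoch
        break
```

With a patience of 2, the loop stops at epoch 4 and never sees the later 0.90, which shows the trade-off: a small patience saves compute but can miss a late improvement.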

Now we can retrieve the trained model from S3 and use it to predict the toxicity of a large number of compounds without running actual experiments.

Try now

You can now use DGL on Amazon SageMaker.
Give it a try, and send us feedback through the DGL forum, the AWS forum for Amazon SageMaker, or your usual AWS support contacts.
