{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 4: Protein DMS modeling using a biophysical G-P map" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-12-30T00:37:35.454595Z", "start_time": "2021-12-30T00:37:33.643066Z" }, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Standard imports\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "# Special imports\n", "import mavenn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we show how to train and visualize a thermodynamic model describing the folding and IgG-binding of protein GB1 variants. This model was first proposed by Otwinowski (2018), who trained it on the DMS data of Olson et al. (2014). Here we repeat this exercise within the MAVE-NN framework, thus obtaining a model similar to the one featured in Figs. 6a and 6b of Tareen et al. (2021). The mathematical form of this G-P map is explained in the supplemental material of Tareen et al. (2021); see in particular Fig. S4a." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining a custom G-P map\n", "\n", "First we define a custom G-P map that represents our biophysical model. We do this by subclassing `GPMapLayer` to get a custom G-P map class called `OtwinowskiGPMapLayer`. This subclassing procedure requires that we fill in the bodies of two specific methods:\n", "- `__init__()`: This constructor must first call the superclass constructor, which sets the attributes `L`, `C`, and `regularizer`. The derived class constructor then defines all of the trainable parameters of the G-P map: `theta_f_0`, `theta_b_0`, `theta_f_lc`, and `theta_b_lc` in this case.\n", "- `call()`: This is the meat of the custom G-P map. The input `x_lc` is a one-hot encoding of all sequences in a minibatch. 
It has size `[-1, L, C]`, where the first index runs over minibatch examples. The G-P map parameters are then used to compute and return a vector `phi` of latent phenotype values, one for each input sequence in the minibatch." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-12-30T00:37:35.459816Z", "start_time": "2021-12-30T00:37:35.455682Z" } }, "outputs": [], "source": [ "# Standard TensorFlow imports\n", "import tensorflow as tf\n", "import tensorflow.keras.backend as K\n", "from tensorflow.keras.initializers import Constant\n", "\n", "# Import base class\n", "from mavenn.src.layers.gpmap import GPMapLayer\n", "\n", "# Define custom G-P map layer\n", "class OtwinowskiGPMapLayer(GPMapLayer):\n", "    \"\"\"\n", "    A G-P map representing the thermodynamic model described by\n", "    Otwinowski (2018).\n", "    \"\"\"\n", "\n", "    def __init__(self, *args, **kwargs):\n", "        \"\"\"Construct layer instance.\"\"\"\n", "\n", "        # Call superclass constructor\n", "        # Sets self.L, self.C, and self.regularizer\n", "        super().__init__(*args, **kwargs)\n", "\n", "        # Initialize constant parameter for folding energy\n", "        self.theta_f_0 = self.add_weight(name='theta_f_0',\n", "                                         shape=(1,),\n", "                                         trainable=True,\n", "                                         regularizer=self.regularizer)\n", "\n", "        # Initialize constant parameter for binding energy\n", "        self.theta_b_0 = self.add_weight(name='theta_b_0',\n", "                                         shape=(1,),\n", "                                         trainable=True,\n", "                                         regularizer=self.regularizer)\n", "\n", "        # Initialize additive parameter for folding energy\n", "        self.theta_f_lc = self.add_weight(name='theta_f_lc',\n", "                                          shape=(1, self.L, self.C),\n", "                                          trainable=True,\n", "                                          regularizer=self.regularizer)\n", "\n", "        # Initialize additive parameter for binding energy\n", "        self.theta_b_lc = self.add_weight(name='theta_b_lc',\n", "                                          shape=(1, self.L, self.C),\n", "                                          trainable=True,\n", "                                          regularizer=self.regularizer)\n", "\n", "    def call(self, x_lc):\n", "        \"\"\"Compute phi given x.\"\"\"\n", "\n",
"        # 1kT = 0.582 kcal/mol at room temperature\n", "        kT = 0.582\n", "\n", "        # Reshape input to samples x length x characters\n", "        x_lc = tf.reshape(x_lc, [-1, self.L, self.C])\n", "\n", "        # Compute Delta G for binding\n", "        Delta_G_b = self.theta_b_0 + \\\n", "                    tf.reshape(K.sum(self.theta_b_lc * x_lc, axis=[1, 2]),\n", "                               shape=[-1, 1])\n", "\n", "        # Compute Delta G for folding\n", "        Delta_G_f = self.theta_f_0 + \\\n", "                    tf.reshape(K.sum(self.theta_f_lc * x_lc, axis=[1, 2]),\n", "                               shape=[-1, 1])\n", "\n", "        # Compute and return fraction folded and bound\n", "        Z = 1+K.exp(-Delta_G_f/kT)+K.exp(-(Delta_G_f+Delta_G_b)/kT)\n", "        p_bf = (K.exp(-(Delta_G_f+Delta_G_b)/kT))/Z\n", "        phi = p_bf #K.log(p_bf)/np.log(2)\n", "        return phi" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Training a model with a custom G-P map\n", "\n", "Next we load the `'gb1'` dataset, compute sequence length, and split the data into a test set and a training+validation set." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-12-30T00:37:36.116339Z", "start_time": "2021-12-30T00:37:35.460796Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading dataset 'gb1' \n", "Sequence length: 55 amino acids\n", "Training set : 477,854 observations ( 90.04%)\n", "Validation set : 26,519 observations ( 5.00%)\n", "Test set : 26,364 observations ( 4.97%)\n", "-------------------------------------------------\n", "Total dataset : 530,737 observations ( 100.00%)\n", "\n", "trainval_df:\n" ] }, { "data": { "text/html": [ "
| | validation | dist | input_ct | selected_ct | y | x |
|---|---|---|---|---|---|---|
| 0 | False | 2 | 173 | 33 | -3.145154 | AAKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 1 | False | 2 | 18 | 8 | -1.867676 | ACKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 2 | False | 2 | 66 | 2 | -5.270800 | ADKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 3 | False | 2 | 72 | 1 | -5.979498 | AEKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 4 | False | 2 | 69 | 168 | 0.481923 | AFKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| ... | ... | ... | ... | ... | ... | ... |
| 504368 | False | 2 | 462 | 139 | -2.515259 | QYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 504369 | False | 2 | 317 | 84 | -2.693165 | QYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 504370 | False | 2 | 335 | 77 | -2.896589 | QYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 504371 | False | 2 | 148 | 28 | -3.150861 | QYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
| 504372 | False | 2 | 95 | 16 | -3.287173 | QYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDD... |
504373 rows × 6 columns
\n", "