{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using pre-trained DreaMS in a custom model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This simple tutorial demonstrates how to use the pre-trained DreaMS weights in a custom PyTorch model.\n",
    "\n",
    "Before running the notebook, download the `ssl_model.ckpt` checkpoint from https://zenodo.org/records/10997887 and set `CKPT_PATH` in the configuration cell below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from dreams.api import PreTrainedModel\n",
    "from dreams.models.dreams.dreams import DreaMS as DreaMSModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the ssl_model.ckpt checkpoint downloaded from https://zenodo.org/records/10997887\n",
    "CKPT_PATH = \"\"  # TODO: replace with the local path to ssl_model.ckpt\n",
    "\n",
    "# Seed PyTorch's RNG so that the random example input below is reproducible\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example model\n",
    "class CustomModel(nn.Module):\n",
    "    \"\"\"Example model: a pre-trained DreaMS spectrum encoder followed by a linear head.\n",
    "\n",
    "    Args:\n",
    "        ckpt_path: Path to the pre-trained `ssl_model.ckpt` checkpoint\n",
    "            (downloadable from https://zenodo.org/records/10997887).\n",
    "        n_highest_peaks: Forwarded unchanged to `PreTrainedModel.from_ckpt`.\n",
    "        embedding_dim: Input size of the head; it must match the last dimension\n",
    "            of the encoder output (1024 in this example).\n",
    "        n_outputs: Output size of the head (1 in this example, e.g. for binary\n",
    "            classification).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ckpt_path=\"\", n_highest_peaks=60, embedding_dim=1024, n_outputs=1):\n",
    "        super().__init__()\n",
    "        # Load the pre-trained encoder; .train() puts the loaded module into training mode\n",
    "        self.spec_encoder = PreTrainedModel.from_ckpt(\n",
    "            ckpt_path=ckpt_path, ckpt_cls=DreaMSModel, n_highest_peaks=n_highest_peaks\n",
    "        ).model.train()\n",
    "\n",
    "        # Example head for a downstream task (e.g., for binary classification)\n",
    "        self.lin_out = nn.Linear(embedding_dim, n_outputs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode `x`, take the first token embedding and pass it through the sigmoid head.\"\"\"\n",
    "        x = self.spec_encoder(x)[:, 0, :]  # [:, 0, :] to get the precursor peak token embedding\n",
    "        x = F.sigmoid(self.lin_out(x))  # Example forward pass through the head\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example forward pass\n",
    "\n",
    "The head produces one value per input spectrum, so for a batch of 32 spectra the expected output shape is `torch.Size([32, 1])`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CustomModel(ckpt_path=CKPT_PATH)\n",
    "example_in = torch.rand(32, 100, 2)  # Example input (32 = batch size, 100 = num. peaks, 2 = m/z and intensity)\n",
    "example_out = model(example_in)\n",
    "assert example_out.shape == (32, 1), \"expected one output value per input spectrum\"\n",
    "example_out.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dreams",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}