Using pre-trained DreaMS in a custom model
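This short tutorial shows how to use the pre-trained DreaMS weights in a custom PyTorch model. It wraps the pre-trained DreaMS encoder as a submodule and attaches a small task-specific head on top of its output.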
In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from dreams.api import PreTrainedModel
from dreams.models.dreams.dreams import DreaMS as DreaMSModel
# Example model
class CustomModel(nn.Module):
    """Example downstream model built on top of the pre-trained DreaMS encoder.

    The DreaMS encoder is loaded from a self-supervised checkpoint and put in
    training mode so it can be fine-tuned together with a single linear head
    (here used as an example binary-classification head).

    Args:
        ckpt_path: Path to the ``ssl_model.ckpt`` checkpoint downloaded from
            https://zenodo.org/records/10997887. The default is a placeholder
            and must be replaced with the real path before instantiation.
        n_highest_peaks: Forwarded unchanged to ``PreTrainedModel.from_ckpt``
            (presumably the number of most intense peaks the encoder keeps;
            confirm against the DreaMS documentation).
        embed_dim: Size of the encoder's per-token embedding, used as the
            head's input size. Must match the loaded checkpoint (1024 in the
            original tutorial).
    """

    def __init__(
        self,
        *,
        ckpt_path: str = "<path/to/ssl_model.ckpt>",
        n_highest_peaks: int = 60,
        embed_dim: int = 1024,
    ):
        super().__init__()
        # `.model` unwraps the underlying nn.Module from the PreTrainedModel
        # wrapper. `nn.Module.train()` returns the module itself, switching
        # layers such as dropout to training behaviour for fine-tuning.
        self.spec_encoder = PreTrainedModel.from_ckpt(
            ckpt_path=ckpt_path,
            ckpt_cls=DreaMSModel,
            n_highest_peaks=n_highest_peaks,
        ).model.train()
        # Example head for a downstream task (e.g., binary classification).
        self.lin_out = nn.Linear(embed_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one score in [0, 1] per spectrum in the batch.

        Args:
            x: Batch of spectra. Presumably shaped (batch, num_peaks, 2) with
                (m/z, intensity) per peak, as in the tutorial's example
                input; confirm against the DreaMS documentation.

        Returns:
            Tensor of shape (batch, 1) holding sigmoid outputs.
        """
        # Token 0 of the encoder output is the precursor peak token embedding.
        x = self.spec_encoder(x)[:, 0, :]
        # torch.sigmoid is the canonical spelling of F.sigmoid (same result).
        # When training with binary cross-entropy, consider returning the raw
        # logits and using nn.BCEWithLogitsLoss, which is more numerically
        # stable than sigmoid followed by nn.BCELoss.
        return torch.sigmoid(self.lin_out(x))
# Smoke test: build the model and push a random batch of spectra through it.
model = CustomModel()
# Random input with size (batch=32, peaks=100, features=2); the two
# per-peak features stand for m/z and intensity.
example_in = torch.rand((32, 100, 2))
example_out = model(example_in)
# Displayed by the notebook; equals the tensor's `.shape`.
example_out.size()
Out[2]:
torch.Size([32, 1])