dreams.models.heads package
Submodules
dreams.models.heads.heads module
- class dreams.models.heads.heads.BinClassificationHead(backbone_pth: Path, lr, weight_decay, head_depth=1, head_phi_depth=0, dropout=0, focal_loss_alpha=None, focal_loss_gamma=0)
Bases: FineTuningHead
Binary classification head for fine-tuning tasks.
This class implements a binary classification head that can be used for various classification tasks on top of the DreaMS backbone.
- head
The classification head module.
- Type:
nn.Module
- metrics
Dictionary to store various classification metrics.
- Type:
dict
- loss
The loss function (Focal Loss).
- Type:
nn.Module
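For reference, the binary focal loss scales the cross-entropy of each example by (1 - p_t)^gamma, where p_t is the predicted probability of the true class. The dependency-free sketch below illustrates that standard formula; it is not the DreaMS loss module. The helper name `binary_focal_loss` and the convention that `alpha` weights positives (and 1 - alpha weights negatives) are assumptions. With the documented defaults (`gamma=0`, `alpha=None`) it reduces to plain binary cross-entropy.

```python
import math

def binary_focal_loss(p, y, alpha=None, gamma=0.0, eps=1e-7):
    """Mean focal loss over predicted probabilities p and binary labels y.

    Hypothetical helper: alpha (if given) weights positives, 1 - alpha negatives.
    """
    total = 0.0
    for prob, label in zip(p, y):
        prob = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
        p_t = prob if label == 1 else 1.0 - prob  # probability of the true class
        alpha_t = 1.0 if alpha is None else (alpha if label == 1 else 1.0 - alpha)
        # (1 - p_t)^gamma down-weights examples the model already classifies well
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(p)

# gamma=0 and alpha=None give ordinary binary cross-entropy
bce = binary_focal_loss([0.9, 0.2], [1, 0])
focal = binary_focal_loss([0.9, 0.2], [1, 0], gamma=2.0)
print(bce, focal)
```

Raising `gamma` shrinks the contribution of confidently correct predictions, which is why `focal` is smaller than `bce` here.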
Initialize the BinClassificationHead.
- Parameters:
backbone_pth (Path) – Path to the pre-trained backbone.
lr (float) – Learning rate for the optimizer.
weight_decay (float) – Weight decay for the optimizer.
head_depth (int) – Depth of the classification head (default: 1).
head_phi_depth (int) – Depth of the phi network in DeepSets (default: 0).
dropout (float) – Dropout rate (default: 0).
focal_loss_alpha (float, optional) – Alpha parameter for Focal Loss.
focal_loss_gamma (float) – Gamma parameter for Focal Loss (default: 0).
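`head_phi_depth` refers to a DeepSets-style head, which applies a per-element network phi, pools the results with a permutation-invariant sum, and then applies a second network rho. The toy sketch below uses scalar elements and hand-written, hypothetical `phi` and `rho` functions to show why the output does not depend on element order; how DreaMS builds phi and rho from `head_phi_depth` and `head_depth` is not shown here.

```python
def deep_sets(elements, phi, rho):
    """DeepSets: rho(sum_i phi(x_i)); assumes a non-empty input set."""
    pooled = None
    for x in elements:
        h = phi(x)
        # element-wise sum of per-element feature vectors (order-independent)
        pooled = h if pooled is None else [a + b for a, b in zip(pooled, h)]
    return rho(pooled)

# Hypothetical phi/rho: phi maps a scalar to a 2-d feature, rho reduces it to a score
phi = lambda x: [x, x * x]
rho = lambda h: 0.5 * h[0] + 0.1 * h[1]

a = deep_sets([1.0, 2.0, 3.0], phi, rho)
b = deep_sets([3.0, 1.0, 2.0], phi, rho)
print(a == b)
```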
- on_validation_epoch_end()
Perform operations at the end of each validation epoch. This method computes and logs ROC and PR curves.
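The area under a ROC curve equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one. The stdlib-only sketch below computes it that way, counting ties as one half; it illustrates the metric and is not the code DreaMS uses to build its ROC and PR curves. The pairwise loop is quadratic, which is fine for illustration but not for large validation sets.

```python
def roc_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    correct = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                correct += 1.0
            elif sp == sn:
                correct += 0.5  # ties contribute half a pair
    return correct / (len(pos) * len(neg))

print(roc_auc([0.9, 0.8, 0.3, 0.1], [1, 1, 0, 0]))  # perfectly separated -> 1.0
print(roc_auc([0.9, 0.4, 0.6, 0.1], [1, 1, 0, 0]))  # 3 of 4 pairs correct -> 0.75
```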
- step(data, batch_idx)
Perform a single step of computation.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
- Returns:
A tuple containing the predicted labels and the loss.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
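The predicted labels returned by step() come from the head's raw outputs. Assuming one logit per spectrum and a 0.5 probability threshold (both assumptions, since the page does not say how labels are derived), the conversion can be sketched with a hypothetical helper as follows.

```python
import math

def stable_sigmoid(z):
    # avoids overflow in math.exp for large negative z
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def logits_to_labels(logits, threshold=0.5):
    """Hypothetical helper: map binary logits to (0/1 labels, probabilities)."""
    probs = [stable_sigmoid(z) for z in logits]
    return [1 if p >= threshold else 0 for p in probs], probs

labels, probs = logits_to_labels([2.0, -1.0, 0.0])
print(labels)  # [1, 0, 1]; a logit of 0 gives probability 0.5
```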
- training_step(data, batch_idx)
Perform a single training step.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
- Returns:
The computed loss for this step.
- Return type:
torch.Tensor
- validation_step(data, batch_idx, dataloader_idx=0)
Perform a single validation step.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
dataloader_idx (int, optional) – The index of the dataloader. Defaults to 0.
- Returns:
The computed loss for this step.
- Return type:
torch.Tensor
- class dreams.models.heads.heads.ContrastiveHead(backbone_pth: Path, lr, weight_decay, triplet_loss_margin: float)
Bases: FineTuningHead
Initialize the ContrastiveHead.
- Parameters:
backbone_pth (Path) – Path to the pre-trained backbone.
lr (float) – Learning rate.
weight_decay (float) – Weight decay for optimization.
triplet_loss_margin (float) – Margin for the triplet loss.
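A triplet margin loss pulls an anchor embedding toward a positive example and pushes it away from a negative one until the negative is at least `triplet_loss_margin` farther away. The sketch below uses Euclidean distance on plain lists; the distance ContrastiveHead actually uses is not documented here, so treat that choice as an assumption. PyTorch's `torch.nn.TripletMarginLoss` computes essentially the same quantity with its default `p=2`.

```python
import math

def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(anchor, positive, negative, margin):
    """max(0, d(a, p) - d(a, n) + margin); zero once the negative is far enough away."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)

anchor, positive, negative = [0.0, 0.0], [0.0, 1.0], [3.0, 4.0]
print(triplet_loss(anchor, positive, negative, margin=1.0))    # d_ap=1, d_an=5 -> 0.0
print(triplet_loss(anchor, positive, [0.0, 1.5], margin=1.0))  # 1 - 1.5 + 1 -> 0.5
```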
- step(data, batch_idx)
Perform a single step of computation.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
- Returns:
A tuple containing None and the computed loss.
- Return type:
Tuple[None, torch.Tensor]
- training_step(data, batch_idx)
Perform a single training step.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- validation_step(data, batch_idx, dataloader_idx=0)
Perform a single validation step.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
dataloader_idx (int) – Index of the dataloader.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- class dreams.models.heads.heads.FineTuningHead(backbone: pathlib.Path | dreams.models.dreams.dreams.DreaMS, lr, weight_decay, backbone_cls=<class 'dreams.models.dreams.dreams.DreaMS'>, unfreeze_backbone_at_epoch=0, precursor_emb=True)
Bases: LightningModule
Base class for fine-tuning heads in the DreaMS model.
This class provides a foundation for creating various fine-tuning heads that can be attached to a pre-trained DreaMS backbone for specific tasks.
- lr
Learning rate for the optimizer.
- Type:
float
- weight_decay
Weight decay for the optimizer.
- Type:
float
- precursor_emb
Whether to use precursor embeddings.
- Type:
bool
- unfreeze_backbone_at_epoch
Epoch at which to unfreeze the backbone.
- Type:
int
- head
The fine-tuning head (to be implemented in subclasses).
- Type:
nn.Module
Initialize the FineTuningHead.
- Parameters:
backbone (Union[Path, DreaMS]) – Path to the pre-trained backbone or the backbone itself.
lr (float) – Learning rate for the optimizer.
weight_decay (float) – Weight decay for the optimizer.
backbone_cls (type) – Class of the backbone model (default: DreaMS).
unfreeze_backbone_at_epoch (int) – Epoch at which to unfreeze the backbone (default: 0).
precursor_emb (bool) – Whether to use only precursor embeddings (default: True).
- configure_optimizers()
Configure the optimizer for the model.
- Returns:
The configured optimizer.
- Return type:
torch.optim.Optimizer
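To show how `lr` and `weight_decay` enter a parameter update, the sketch below performs one plain SGD step with decoupled weight decay. The optimizer that configure_optimizers() actually builds is not stated on this page, so the SGD choice and the helper name are assumptions for illustration only.

```python
def sgd_step_with_weight_decay(weights, grads, lr, weight_decay):
    """One SGD update with decoupled weight decay (illustrative, not DreaMS code).

    Each weight moves against its gradient and is also shrunk toward zero
    in proportion to lr * weight_decay.
    """
    return [w - lr * g - lr * weight_decay * w for w, g in zip(weights, grads)]

new_w = sgd_step_with_weight_decay([1.0, -2.0], [0.5, 0.0], lr=0.1, weight_decay=0.01)
print(new_w)  # approximately [0.949, -1.998]
```

The second weight has zero gradient, so it changes only through the weight-decay term.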
- forward(spec, charge=None, no_head=False)
Forward pass through the model.
- Parameters:
spec (torch.Tensor) – Input spectrum.
charge (torch.Tensor, optional) – Charge information.
no_head (bool) – If True, return embeddings without passing through the head.
- Returns:
Output of the model.
- Return type:
torch.Tensor
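The interplay of `no_head` and `precursor_emb` can be sketched on plain lists. This is a hypothetical illustration that assumes the backbone returns one embedding per token with the precursor token first; it does not reproduce the DreaMS forward pass.

```python
def forward_sketch(token_embeddings, head, no_head=False, precursor_emb=True):
    """Hypothetical illustration of the no_head / precursor_emb switches.

    Assumes token_embeddings is a list of per-token embedding vectors whose
    first entry belongs to the precursor token.
    """
    emb = token_embeddings[0] if precursor_emb else token_embeddings
    if no_head:
        return emb  # raw backbone embedding(s), head skipped
    return head(emb)

tokens = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
head = lambda v: sum(v)  # toy head: sum of the precursor embedding
print(forward_sketch(tokens, head))                # head output
print(forward_sketch(tokens, head, no_head=True))  # [0.1, 0.2]
```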
- on_train_epoch_start()
Callback method called at the start of each training epoch. Handles freezing/unfreezing of the backbone.
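A minimal sketch of per-epoch freezing, assuming the backbone is frozen for epochs strictly before `unfreeze_backbone_at_epoch` (under that reading the default of 0 trains it from the first epoch). `FakeParam` is a hypothetical stand-in for a parameter; in PyTorch the corresponding flag is each parameter's `requires_grad` attribute.

```python
from dataclasses import dataclass

@dataclass
class FakeParam:
    """Hypothetical stand-in for a model parameter with a requires_grad flag."""
    requires_grad: bool = True

def update_backbone_freezing(backbone_params, current_epoch, unfreeze_at):
    """Freeze the backbone before `unfreeze_at`; train it from that epoch on."""
    trainable = current_epoch >= unfreeze_at
    for p in backbone_params:
        p.requires_grad = trainable
    return trainable

params = [FakeParam(), FakeParam()]
for epoch in range(4):
    update_backbone_freezing(params, epoch, unfreeze_at=2)
    print(epoch, params[0].requires_grad)  # False, False, True, True
```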
- abstract step(data, batch_idx)
Perform a single step (to be implemented in subclasses).
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
- Returns:
Predicted labels and loss.
- Return type:
tuple
- training_step(data, batch_idx)
Perform a single training step.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- validation_step(data, batch_idx, dataloader_idx=0)
Perform a single validation step.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
dataloader_idx (int) – Index of the dataloader.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- class dreams.models.heads.heads.FingerprintHead(backbone: Path, fp_str: str, lr, batch_size, weight_decay, dropout=0, loss='cos', retrieval_val_pth=None, retrieval_epoch_freq=10, unfreeze_backbone_at_epoch=0, head_depth=1, store_val_out_dir: Path = None, head_phi_depth: int = 0)
Bases: FineTuningHead
Initialize the FingerprintHead.
- Parameters:
backbone (Path) – Path to the pre-trained backbone.
fp_str (str) – String representation of the fingerprint.
lr (float) – Learning rate.
batch_size (int) – Batch size.
weight_decay (float) – Weight decay for optimization.
dropout (float, optional) – Dropout rate. Defaults to 0.
loss (str, optional) – Loss function to use. Defaults to ‘cos’.
retrieval_val_pth (Path, optional) – Path for validation retrieval. Defaults to None.
retrieval_epoch_freq (int, optional) – Frequency of retrieval validation. Defaults to 10.
unfreeze_backbone_at_epoch (int, optional) – Epoch to unfreeze the backbone. Defaults to 0.
head_depth (int, optional) – Depth of the head. Defaults to 1.
store_val_out_dir (Path, optional) – Directory to store validation outputs. Defaults to None.
head_phi_depth (int, optional) – Depth of the phi network in DeepSets. Defaults to 0.
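With `loss='cos'`, a natural reading is a cosine-distance loss between predicted and target fingerprints, 1 - cos(pred, target). The sketch below shows that formula on plain lists; the exact form used by FingerprintHead (for example, how it averages over a batch) is an assumption.

```python
import math

def cosine_loss(pred, target, eps=1e-8):
    """1 - cosine similarity between a predicted and a target fingerprint."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    return 1.0 - dot / max(norm, eps)  # eps guards against all-zero vectors

target = [1.0, 0.0, 1.0, 0.0]  # toy binary fingerprint
print(cosine_loss([0.9, 0.1, 0.8, 0.0], target))  # close to 0
print(cosine_loss([0.0, 1.0, 0.0, 1.0], target))  # orthogonal -> 1.0
```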
- on_validation_epoch_end()
Perform operations at the end of each validation epoch. This method computes and logs retrieval metrics if the current epoch is a retrieval epoch, as determined by retrieval_epoch_freq.
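Retrieval metrics of this kind typically rank candidate molecules by the similarity of their fingerprints to the predicted one and check whether the true molecule lands in the top k. The stdlib sketch below does that with cosine similarity; the candidate format, the similarity measure, and the helper names are all assumptions, since this page does not document the data at retrieval_val_pth.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0

def hit_at_k(pred_fp, candidates, true_id, k=1):
    """True if the correct candidate is among the k most similar fingerprints.

    candidates: dict mapping a (hypothetical) molecule id to its fingerprint.
    """
    ranked = sorted(candidates, key=lambda cid: cosine(pred_fp, candidates[cid]), reverse=True)
    return true_id in ranked[:k]

candidates = {"A": [1, 0, 1, 0], "B": [0, 1, 0, 1], "C": [1, 1, 0, 0]}
pred = [0.9, 0.2, 0.7, 0.1]
print(hit_at_k(pred, candidates, "A", k=1))  # A is the most similar candidate
```

Averaging `hit_at_k` over all validation spectra gives a top-k retrieval accuracy.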
- step(data, batch_idx)
Perform a single step of computation.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
- Returns:
A tuple containing the predictions and the loss.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- validate(data, batch_idx, dataloader_idx)
Perform validation computations.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
dataloader_idx (int) – The index of the dataloader.
- Returns:
The computed loss for this validation step.
- Return type:
torch.Tensor
- validation_step(data, batch_idx, dataloader_idx=0)
Perform a single validation step.
- Parameters:
data (dict) – A dictionary containing the input data.
batch_idx (int) – The index of the current batch.
dataloader_idx (int, optional) – The index of the dataloader. Defaults to 0.
- Returns:
The computed loss for this validation step.
- Return type:
torch.Tensor
- class dreams.models.heads.heads.IntRegressionHead(backbone_pth: Path, lr, weight_decay, out_dim=1)
Bases: RegressionHead
Integer regression head for fine-tuning tasks.
This class implements a regression head specifically for integer-valued outputs.
- Inherits all attributes from RegressionHead.
Initialize the IntRegressionHead.
- Parameters:
backbone_pth (Path) – Path to the pre-trained backbone.
lr (float) – Learning rate for the optimizer.
weight_decay (float) – Weight decay for the optimizer.
out_dim (int) – Output dimension of the regression (default: 1).
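One plausible way to evaluate integer-valued regression is to round each continuous prediction and count exact matches. Whether IntRegressionHead rounds or scores in exactly this way is not stated here, so the hypothetical helper below is only an illustration of the idea.

```python
def int_accuracy(preds, targets):
    """Fraction of predictions that match integer targets after rounding.

    Note: Python's round() sends exact .5 values to the nearest even integer.
    """
    hits = sum(1 for p, t in zip(preds, targets) if round(p) == t)
    return hits / len(targets)

print(int_accuracy([0.8, 2.4, 3.6, 1.1], [1, 2, 3, 1]))  # 3.6 rounds to 4 -> 0.75
```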
- validation_step(data, batch_idx)
Perform a single validation step.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- class dreams.models.heads.heads.RegressionHead(backbone: Path | DreaMS, lr, weight_decay, sigmoid=True, out_dim=1, mol_props_calc: MolPropertyCalculator = None, head_depth=1, dropout=0)
Bases: FineTuningHead
Regression head for fine-tuning tasks.
This class implements a regression head that can be used for various regression tasks on top of the DreaMS backbone.
- head
The regression head module.
- Type:
nn.Module
- out_dim
Output dimension of the regression.
- Type:
int
- sigmoid
Sigmoid activation (if used).
- Type:
nn.Module
- mol_props_calc
Calculator for molecular properties.
- Type:
mu.MolPropertyCalculator
Initialize the RegressionHead.
- Parameters:
backbone (Union[Path, DreaMS]) – Path to the pre-trained backbone or the backbone itself.
lr (float) – Learning rate for the optimizer.
weight_decay (float) – Weight decay for the optimizer.
sigmoid (bool) – Whether to use sigmoid activation (default: True).
out_dim (int) – Output dimension of the regression (default: 1).
mol_props_calc (mu.MolPropertyCalculator, optional) – Calculator for molecular properties.
head_depth (int) – Depth of the regression head (default: 1).
dropout (float) – Dropout rate (default: 0).
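With `sigmoid=True`, outputs are squashed into (0, 1), so regression targets outside that range need to be rescaled for training and mapped back afterwards. The min-max scaling below is one simple way to do that; whether and how DreaMS scales targets is not documented here, so the helpers and the target range are assumptions.

```python
import math

def sigmoid(z):
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def minmax_scale(y, lo, hi):
    """Map a target from [lo, hi] into [0, 1] so a sigmoid output can reach it."""
    return (y - lo) / (hi - lo)

def minmax_unscale(s, lo, hi):
    """Map a sigmoid output in (0, 1) back to the original target range."""
    return lo + s * (hi - lo)

lo, hi = 100.0, 1000.0  # hypothetical target range
scaled = minmax_scale(550.0, lo, hi)
print(scaled)                                # 0.5
print(minmax_unscale(sigmoid(0.0), lo, hi))  # sigmoid(0) = 0.5 -> 550.0
```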
- step(data, batch_idx)
Perform a single step of regression.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
- Returns:
Predicted labels and loss.
- Return type:
tuple
- validation_step(data, batch_idx, dataloader_idx=0)
Perform a single validation step.
- Parameters:
data (dict) – Input data.
batch_idx (int) – Index of the current batch.
dataloader_idx (int) – Index of the dataloader.
- Returns:
The computed loss.
- Return type:
torch.Tensor