dreams.models.heads package

Submodules

dreams.models.heads.heads module

class dreams.models.heads.heads.BinClassificationHead(backbone_pth: Path, lr, weight_decay, head_depth=1, head_phi_depth=0, dropout=0, focal_loss_alpha=None, focal_loss_gamma=0)

Bases: FineTuningHead

Binary classification head for fine-tuning tasks.

This class implements a binary classification head that can be used for various classification tasks on top of the DreaMS backbone.

head

The classification head module.

Type:

nn.Module

metrics

Dictionary to store various classification metrics.

Type:

dict

loss

The loss function (Focal Loss).

Type:

nn.Module

Initialize the BinClassificationHead.

Parameters:
  • backbone_pth (Path) – Path to the pre-trained backbone.

  • lr (float) – Learning rate for the optimizer.

  • weight_decay (float) – Weight decay for the optimizer.

  • head_depth (int) – Depth of the classification head (default: 1).

  • head_phi_depth (int) – Depth of the phi network in DeepSets (default: 0).

  • dropout (float) – Dropout rate (default: 0).

  • focal_loss_alpha (float, optional) – Alpha parameter for Focal Loss (default: None).

  • focal_loss_gamma (float) – Gamma parameter for Focal Loss (default: 0).
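The focal_loss_alpha and focal_loss_gamma parameters correspond to the two knobs of the standard binary focal loss, FL(p_t) = -α_t (1 - p_t)^γ log(p_t). The sketch below assumes that standard formulation for a single example. How the class weights α internally is an assumption, and binary_focal_loss is a hypothetical helper name. Under this formulation the defaults (α = None, γ = 0) reduce to plain binary cross-entropy.

```python
import math
from typing import Optional


def binary_focal_loss(prob: float, target: int,
                      alpha: Optional[float] = None, gamma: float = 0.0) -> float:
    """Focal loss for one example (standard formulation, assumed).

    prob   -- predicted probability of the positive class
    target -- ground-truth label, 0 or 1
    alpha  -- weight of the positive class; None disables class weighting
    gamma  -- focusing parameter; 0 gives binary cross-entropy
    """
    eps = 1e-12
    # Probability the model assigns to the true class.
    p_t = prob if target == 1 else 1.0 - prob
    p_t = min(max(p_t, eps), 1.0)  # avoid log(0)
    # Down-weight well-classified examples by (1 - p_t) ** gamma.
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:
        alpha_t = alpha if target == 1 else 1.0 - alpha
        loss *= alpha_t
    return loss


# With the defaults this equals binary cross-entropy, -log(0.9).
print(binary_focal_loss(0.9, 1))
# A larger gamma shrinks the loss of an easy, confidently correct example.
print(binary_focal_loss(0.9, 1, gamma=2.0))
```

Raising gamma mostly affects confident predictions, which is why focal loss is used when easy negatives dominate an imbalanced dataset.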

on_validation_epoch_end()

Called at the end of each validation epoch. Computes and logs ROC and PR curves from the predictions collected during the epoch.
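The ROC curve is obtained by sweeping a decision threshold over the scores gathered during the epoch. The following is a minimal pure-Python sketch of the ROC curve and its area; the real method may rely on a metrics library instead, and roc_curve and auc here are hypothetical helpers. A PR curve is built the same way, recording precision and recall at each threshold.

```python
from typing import List, Sequence, Tuple


def roc_curve(scores: Sequence[float],
              labels: Sequence[int]) -> Tuple[List[float], List[float]]:
    """Return (fpr, tpr) points, sweeping the threshold from high to low."""
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Consume every example tied at this threshold before emitting a point.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / n_neg)
        tpr.append(tp / n_pos)
    return fpr, tpr


def auc(x: Sequence[float], y: Sequence[float]) -> float:
    """Area under a curve by the trapezoidal rule."""
    return sum((x[k] - x[k - 1]) * (y[k] + y[k - 1]) / 2 for k in range(1, len(x)))


fpr, tpr = roc_curve([0.8, 0.6, 0.4, 0.2], [1, 0, 1, 0])
print(auc(fpr, tpr))  # 0.75
```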

step(data, batch_idx)

Perform a single step of computation.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

Returns:

A tuple containing the predicted labels and the loss.

Return type:

Tuple[torch.Tensor, torch.Tensor]

training_step(data, batch_idx)

Perform a single training step.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

Returns:

The computed loss for this step.

Return type:

torch.Tensor

validation_step(data, batch_idx, dataloader_idx=0)

Perform a single validation step.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

  • dataloader_idx (int, optional) – The index of the dataloader. Defaults to 0.

Returns:

The computed loss for this step.

Return type:

torch.Tensor

class dreams.models.heads.heads.ContrastiveHead(backbone_pth: Path, lr, weight_decay, triplet_loss_margin: float)

Bases: FineTuningHead

Initialize the ContrastiveHead.

Parameters:
  • backbone_pth (Path) – Path to the pre-trained backbone.

  • lr (float) – Learning rate.

  • weight_decay (float) – Weight decay for optimization.

  • triplet_loss_margin (float) – Margin for the triplet loss.
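The triplet_loss_margin parameter is the margin of a triplet loss, which pulls an anchor embedding toward a positive and pushes it away from a negative. The sketch below uses the standard formula max(0, d(a, p) - d(a, n) + margin) with Euclidean distance, the same form documented for PyTorch's torch.nn.TripletMarginLoss with p=2. The distance metric and how triplets are drawn from data are assumptions; euclidean and triplet_loss are hypothetical helpers.

```python
import math
from typing import Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor: Sequence[float], positive: Sequence[float],
                 negative: Sequence[float], margin: float) -> float:
    """Zero once the negative is at least `margin` farther from the anchor
    than the positive is; grows linearly otherwise."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


print(triplet_loss([0.0, 0.0], [1.0, 0.0], [3.0, 0.0], margin=1.0))  # 0.0
print(triplet_loss([0.0, 0.0], [1.0, 0.0], [3.0, 0.0], margin=3.0))  # 1.0
```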

step(data, batch_idx)

Perform a single step of computation.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

Returns:

A tuple containing None and the computed loss.

Return type:

Tuple[None, torch.Tensor]

training_step(data, batch_idx)

Perform a single training step.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

Returns:

The computed loss.

Return type:

torch.Tensor

validation_step(data, batch_idx, dataloader_idx=0)

Perform a single validation step.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

  • dataloader_idx (int) – Index of the dataloader.

Returns:

The computed loss.

Return type:

torch.Tensor

class dreams.models.heads.heads.FineTuningHead(backbone: Path | DreaMS, lr, weight_decay, backbone_cls=DreaMS, unfreeze_backbone_at_epoch=0, precursor_emb=True)

Bases: LightningModule

Base class for fine-tuning heads in the DreaMS model.

This class provides a foundation for creating various fine-tuning heads that can be attached to a pre-trained DreaMS backbone for specific tasks.

backbone

The pre-trained backbone model.

Type:

Union[DreaMS, pl.LightningModule]

lr

Learning rate for the optimizer.

Type:

float

weight_decay

Weight decay for the optimizer.

Type:

float

precursor_emb

Whether to use only precursor embeddings.

Type:

bool

unfreeze_backbone_at_epoch

Epoch at which to unfreeze the backbone.

Type:

int

head

The fine-tuning head (to be implemented in subclasses).

Type:

nn.Module

Initialize the FineTuningHead.

Parameters:
  • backbone (Union[Path, DreaMS]) – Path to the pre-trained backbone or the backbone itself.

  • lr (float) – Learning rate for the optimizer.

  • weight_decay (float) – Weight decay for the optimizer.

  • backbone_cls (type) – Class of the backbone model (default: DreaMS).

  • unfreeze_backbone_at_epoch (int) – Epoch at which to unfreeze the backbone (default: 0).

  • precursor_emb (bool) – Whether to use only precursor embeddings (default: True).

configure_optimizers()

Configure the optimizer for the model.

Returns:

The configured optimizer.

Return type:

torch.optim.Optimizer

forward(spec, charge=None, no_head=False)

Forward pass through the model.

Parameters:
  • spec (torch.Tensor) – Input spectrum.

  • charge (torch.Tensor, optional) – Charge information.

  • no_head (bool) – If True, return embeddings without passing through the head.

Returns:

Output of the model.

Return type:

torch.Tensor
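The control flow of forward can be outlined as below. This is a sketch under stated assumptions: the backbone is taken to return one embedding per input token with the precursor token in position 0, and precursor_emb=True is taken to keep only that embedding. forward_sketch, toy_backbone, and toy_head are hypothetical stand-ins; the real method also accepts charge, which the sketch omits.

```python
from typing import Callable, List, Union

Vector = List[float]
Embedding = Union[Vector, List[Vector]]


def forward_sketch(backbone: Callable[[object], List[Vector]],
                   head: Callable[[Embedding], object],
                   spec: object,
                   precursor_emb: bool = True,
                   no_head: bool = False):
    """Hypothetical outline of FineTuningHead.forward (not the real code)."""
    token_embs = backbone(spec)  # assumed: one embedding per token, precursor first
    emb = token_embs[0] if precursor_emb else token_embs
    if no_head:
        return emb  # embeddings only; the head is skipped
    return head(emb)


def toy_backbone(spec: object) -> List[Vector]:
    # Fixed embeddings for two tokens, standing in for a real encoder.
    return [[1.0, 2.0], [3.0, 4.0]]


def toy_head(emb: Vector) -> float:
    # Stand-in head: sums the embedding vector.
    return sum(emb)


print(forward_sketch(toy_backbone, toy_head, spec=None))                # 3.0
print(forward_sketch(toy_backbone, toy_head, spec=None, no_head=True))  # [1.0, 2.0]
```

Calling with no_head=True is what one would use to extract spectrum embeddings from a fine-tuned model without the task-specific output.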

on_train_epoch_start()

Called at the start of each training epoch. Freezes or unfreezes the backbone according to unfreeze_backbone_at_epoch.
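A minimal sketch of the freezing rule, assuming the backbone stays frozen for every epoch before unfreeze_backbone_at_epoch and is trained from that epoch on (so the default of 0 trains it from the start). Param is a hypothetical stand-in for a torch parameter; with PyTorch the flag would be set through Tensor.requires_grad_.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class Param:
    """Stand-in for a torch parameter; only the requires_grad flag matters here."""
    requires_grad: bool = True


def set_backbone_trainable(params: Iterable[Param], current_epoch: int,
                           unfreeze_backbone_at_epoch: int) -> bool:
    """Freeze before the unfreeze epoch, train from then on (assumed rule)."""
    trainable = current_epoch >= unfreeze_backbone_at_epoch
    for p in params:
        # With PyTorch this would be p.requires_grad_(trainable).
        p.requires_grad = trainable
    return trainable


backbone_params = [Param(), Param()]
for epoch in range(3):
    set_backbone_trainable(backbone_params, epoch, unfreeze_backbone_at_epoch=2)
    print(epoch, [p.requires_grad for p in backbone_params])
# 0 [False, False]
# 1 [False, False]
# 2 [True, True]
```

Training only the head for the first few epochs lets it adapt to the pre-trained embeddings before gradients start modifying the backbone.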

abstract step(data, batch_idx)

Perform a single step (to be implemented in subclasses).

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

Returns:

Predicted labels and loss.

Return type:

tuple

training_step(data, batch_idx)

Perform a single training step.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

Returns:

The computed loss.

Return type:

torch.Tensor

validation_step(data, batch_idx, dataloader_idx=0)

Perform a single validation step.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

  • dataloader_idx (int) – Index of the dataloader.

Returns:

The computed loss.

Return type:

torch.Tensor
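The division of labour above, where subclasses implement step and the training and validation hooks reuse it, can be sketched without Lightning. HeadSketch and MeanAbsErrorHead are hypothetical, as are the "x" and "label" keys of data; the actual methods presumably also log metrics, which is omitted here.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple


class HeadSketch(ABC):
    """Framework-free outline of the step/training_step/validation_step split."""

    @abstractmethod
    def step(self, data: Dict[str, Any], batch_idx: int) -> Tuple[Any, float]:
        """Return (predictions, loss); implemented by each concrete head."""

    def training_step(self, data: Dict[str, Any], batch_idx: int) -> float:
        _, loss = self.step(data, batch_idx)
        return loss  # the loss is what the trainer backpropagates

    def validation_step(self, data: Dict[str, Any], batch_idx: int,
                        dataloader_idx: int = 0) -> float:
        _, loss = self.step(data, batch_idx)
        return loss


class MeanAbsErrorHead(HeadSketch):
    """Hypothetical head: predicts the mean input value, scored with L1 loss."""

    def step(self, data, batch_idx):
        pred = sum(data["x"]) / len(data["x"])
        return pred, abs(pred - data["label"])


head = MeanAbsErrorHead()
print(head.training_step({"x": [1.0, 3.0], "label": 1.5}, batch_idx=0))  # 0.5
```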

class dreams.models.heads.heads.FingerprintHead(backbone: Path, fp_str: str, lr, batch_size, weight_decay, dropout=0, loss='cos', retrieval_val_pth=None, retrieval_epoch_freq=10, unfreeze_backbone_at_epoch=0, head_depth=1, store_val_out_dir: Path = None, head_phi_depth: int = 0)

Bases: FineTuningHead

Initialize the FingerprintHead.

Parameters:
  • backbone (Path) – Path to the pre-trained backbone.

  • fp_str (str) – String representation of the fingerprint.

  • lr (float) – Learning rate.

  • batch_size (int) – Batch size.

  • weight_decay (float) – Weight decay for optimization.

  • dropout (float, optional) – Dropout rate. Defaults to 0.

  • loss (str, optional) – Loss function to use. Defaults to ‘cos’.

  • retrieval_val_pth (Path, optional) – Path to the data used for retrieval validation. Defaults to None.

  • retrieval_epoch_freq (int, optional) – How often, in epochs, retrieval validation is run. Defaults to 10.

  • unfreeze_backbone_at_epoch (int, optional) – Epoch to unfreeze the backbone. Defaults to 0.

  • head_depth (int, optional) – Depth of the head. Defaults to 1.

  • store_val_out_dir (Path, optional) – Directory to store validation outputs. Defaults to None.

  • head_phi_depth (int, optional) – Depth of the phi network in DeepSets. Defaults to 0.
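With loss='cos', a natural reading is a cosine-distance loss between the predicted and target fingerprint vectors, 1 - cos(pred, target). The sketch assumes that definition; cosine_loss is a hypothetical helper, and the other accepted values of loss are not documented here.

```python
import math
from typing import Sequence


def cosine_loss(pred: Sequence[float], target: Sequence[float],
                eps: float = 1e-8) -> float:
    """1 - cosine similarity: 0 for parallel vectors, 1 for orthogonal ones."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    return 1.0 - dot / max(norm, eps)  # eps guards against zero vectors


# Predicted fingerprint close to the binary target gives a loss near 0.
print(cosine_loss([0.9, 0.1, 0.8], [1, 0, 1]))
```

Because cosine similarity ignores vector magnitude, this loss rewards getting the pattern of active fingerprint bits right rather than their absolute scale.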

on_validation_epoch_end()

Called at the end of each validation epoch. Computes and logs retrieval metrics if the current epoch is a retrieval epoch.
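A sketch of what a retrieval epoch might involve: check whether the epoch is one of every retrieval_epoch_freq epochs, rank candidate fingerprints by similarity to the predicted one, and record whether the correct candidate lands in the top k. The epoch-counting convention, the cosine similarity, k, and all helper and candidate names are assumptions made for illustration.

```python
import math
from typing import Dict, Sequence


def is_retrieval_epoch(current_epoch: int, retrieval_epoch_freq: int = 10) -> bool:
    # Assumed convention: 0-based epochs, retrieval on epochs 9, 19, 29, ...
    return (current_epoch + 1) % retrieval_epoch_freq == 0


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def hit_at_k(pred: Sequence[float], candidates: Dict[str, Sequence[float]],
             true_id: str, k: int = 1) -> bool:
    """True if the correct candidate ranks among the k most similar ones."""
    ranked = sorted(candidates, key=lambda cid: cosine(pred, candidates[cid]),
                    reverse=True)
    return true_id in ranked[:k]


candidates = {"mol_a": [1, 0, 1, 0], "mol_b": [1, 1, 0, 0], "mol_c": [0, 0, 1, 1]}
pred = [0.9, 0.2, 0.7, 0.1]
print(hit_at_k(pred, candidates, "mol_a", k=1))  # True
```

Averaging hit_at_k over all validation spectra gives a top-k retrieval accuracy.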

step(data, batch_idx)

Perform a single step of computation.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

Returns:

A tuple containing the predictions and the loss.

Return type:

Tuple[torch.Tensor, torch.Tensor]

validate(data, batch_idx, dataloader_idx)

Perform validation computations.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

  • dataloader_idx (int) – The index of the dataloader.

Returns:

The computed loss for this validation step.

Return type:

torch.Tensor

validation_step(data, batch_idx, dataloader_idx=0)

Perform a single validation step.

Parameters:
  • data (dict) – A dictionary containing the input data.

  • batch_idx (int) – The index of the current batch.

  • dataloader_idx (int, optional) – The index of the dataloader. Defaults to 0.

Returns:

The computed loss for this validation step.

Return type:

torch.Tensor

class dreams.models.heads.heads.IntRegressionHead(backbone_pth: Path, lr, weight_decay, out_dim=1)

Bases: RegressionHead

Integer regression head for fine-tuning tasks.

This class implements a regression head specifically for integer-valued outputs.

Inherits all attributes from RegressionHead.

Initialize the IntRegressionHead.

Parameters:
  • backbone_pth (Path) – Path to the pre-trained backbone.

  • lr (float) – Learning rate for the optimizer.

  • weight_decay (float) – Weight decay for the optimizer.

  • out_dim (int) – Output dimension of the regression (default: 1).
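Integer-valued targets (for example, counts) can still be regressed as real numbers; one common way to evaluate them is to round the predictions and measure exact-match accuracy alongside the error of the rounded values. Whether IntRegressionHead.validation_step does exactly this is an assumption; int_regression_metrics is a hypothetical helper that only illustrates the idea.

```python
from typing import Dict, Sequence


def int_regression_metrics(preds: Sequence[float],
                           targets: Sequence[int]) -> Dict[str, float]:
    """Round real-valued predictions to integers and compare with integer targets."""
    rounded = [round(p) for p in preds]  # note: Python rounds halves to even
    n = len(targets)
    exact = sum(r == t for r, t in zip(rounded, targets)) / n
    mae = sum(abs(r - t) for r, t in zip(rounded, targets)) / n
    return {"exact_match": exact, "mae_rounded": mae}


# 3.7 rounds to 4 and misses the target 3: exact_match 2/3, mae_rounded 1/3.
print(int_regression_metrics([2.2, 3.7, 0.4], [2, 3, 0]))
```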

validation_step(data, batch_idx)

Perform a single validation step.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

Returns:

The computed loss.

Return type:

torch.Tensor

class dreams.models.heads.heads.RegressionHead(backbone: Path | DreaMS, lr, weight_decay, sigmoid=True, out_dim=1, mol_props_calc: MolPropertyCalculator = None, head_depth=1, dropout=0)

Bases: FineTuningHead

Regression head for fine-tuning tasks.

This class implements a regression head that can be used for various regression tasks on top of the DreaMS backbone.

head

The regression head module.

Type:

nn.Module

out_dim

Output dimension of the regression.

Type:

int

sigmoid

Sigmoid activation (if used).

Type:

nn.Module

mol_props_calc

Calculator for molecular properties.

Type:

mu.MolPropertyCalculator

Initialize the RegressionHead.

Parameters:
  • backbone (Union[Path, DreaMS]) – Path to the pre-trained backbone or the backbone itself.

  • lr (float) – Learning rate for the optimizer.

  • weight_decay (float) – Weight decay for the optimizer.

  • sigmoid (bool) – Whether to use sigmoid activation (default: True).

  • out_dim (int) – Output dimension of the regression (default: 1).

  • mol_props_calc (mu.MolPropertyCalculator, optional) – Calculator for molecular properties.

  • head_depth (int) – Depth of the regression head (default: 1).

  • dropout (float) – Dropout rate (default: 0).
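With sigmoid=True the head's outputs lie strictly between 0 and 1, so regression targets on another scale need to be mapped into that interval for training and mapped back when reporting predictions. The min-max scaling below, including its bounds, is a hypothetical illustration; the source does not say how RegressionHead or mol_props_calc normalizes targets.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def to_unit(y: float, y_min: float, y_max: float) -> float:
    """Min-max scale a target into [0, 1] (hypothetical bounds)."""
    return (y - y_min) / (y_max - y_min)


def from_unit(u: float, y_min: float, y_max: float) -> float:
    """Invert to_unit to report predictions in the original units."""
    return y_min + u * (y_max - y_min)


# Hypothetical target range, e.g. a molecular property between 50 and 1000.
y_min, y_max = 50.0, 1000.0
raw_output = 0.3                  # head output before the sigmoid
pred_unit = sigmoid(raw_output)   # always strictly between 0 and 1
print(from_unit(pred_unit, y_min, y_max))
```

Under this scheme, a target outside [y_min, y_max] cannot be reached by the sigmoid output, so the bounds must cover the full range of the data.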

step(data, batch_idx)

Perform a single step of regression.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

Returns:

Predicted labels and loss.

Return type:

tuple

validation_step(data, batch_idx, dataloader_idx=0)

Perform a single validation step.

Parameters:
  • data (dict) – Input data.

  • batch_idx (int) – Index of the current batch.

  • dataloader_idx (int) – Index of the dataloader.

Returns:

The computed loss.

Return type:

torch.Tensor

Module contents