| | import os |
| | import json |
| | import torch |
| | import numpy as np |
| | from dawo import DAWO, loss_function, Anndata_to_Tensor |
| |
|
| |
|
class DAWOWrapper:
    """
    Minimal wrapper for the DAWO model to use with the Hugging Face Hub.

    Builds a ``DAWO`` instance from ``config.json``, loads its weights from
    ``model.pth`` (both read from the same repository directory) and switches
    the model to evaluation mode so it is ready for inference.
    """

    def __init__(self, repo_path):
        """
        Initialize the DAWO model from a local repository directory.

        Args:
            repo_path: Path to a directory containing ``config.json`` (the
                ``DAWO`` constructor arguments) and ``model.pth`` (a state
                dict loaded into the constructed model).

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pth`` is missing.
            KeyError: If ``config.json`` lacks one of ``input_dim_X``,
                ``input_dim_Y``, ``input_dim_Z``, ``latent_dim``, ``Y_emb``,
                ``Z_emb`` or ``num_classes``.
            RuntimeError: If the checkpoint does not match the model built
                from the config (raised by ``load_state_dict``).
        """
        config_path = os.path.join(repo_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        self.model = DAWO(
            input_dim_X=config["input_dim_X"],
            input_dim_Y=config["input_dim_Y"],
            input_dim_Z=config["input_dim_Z"],
            latent_dim=config["latent_dim"],
            Y_emb=config["Y_emb"],
            Z_emb=config["Z_emb"],
            num_classes=config["num_classes"],
        )

        # DAWO(...) is built without a device argument, so its parameters
        # presumably live on the CPU. Map the checkpoint to the CPU as well.
        # Without map_location, a checkpoint saved from a GPU cannot be loaded
        # on a CPU-only machine, even though load_state_dict would copy the
        # weights into CPU parameters anyway.
        # NOTE(review): torch.load unpickles the file. If model.pth can come
        # from an untrusted Hub repo, consider weights_only=True (requires
        # torch>=1.13; it is the default from torch 2.6).
        weights_path = os.path.join(repo_path, "model.pth")
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, x, y, z):
        """
        Make predictions with the DAWO model.

        Args:
            x: Gene expression tensor (batch_size, input_dim_X)
            y: Drug feature tensor (batch_size, input_dim_Y)
            z: Cell line feature tensor (batch_size, input_dim_Z)

        Returns:
            Dict holding the four model outputs under ``"x_hat"``, ``"mu"``,
            ``"logvar"`` and ``"y_pred"``, plus ``"probs"``, which is
            ``y_pred`` passed through a softmax along dim 1. This assumes
            ``y_pred`` is (batch_size, num_classes) logits — TODO confirm.
        """
        # Inference only: disable autograd tracking for every output.
        with torch.no_grad():
            x_hat, mu, logvar, y_pred = self.model(x, y, z)
            probs = torch.softmax(y_pred, dim=1)

        return {
            "x_hat": x_hat,
            "mu": mu,
            "logvar": logvar,
            "y_pred": y_pred,
            "probs": probs,
        }