Spaces:
Sleeping
Sleeping
| # PSEUDOCODE UNTIL WE GET DATA | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.metrics import accuracy_score | |
| from sklearn.preprocessing import LabelEncoder | |
| from datasets import load_dataset | |
| import pickle | |
# Input-file columns kept for inference: sample id, raw sequence, location,
# and the environmental covariates passed to the DNA+environment model.
_ECO_COLUMNS = [
    'processid', 'nucraw', 'coord', 'country', 'depth',
    'WorldClim2_BIO_Temperature_Seasonality',
    'WorldClim2_BIO_Precipitation_Seasonality',
    'WorldClim2_BIO_Annual_Precipitation',
    'EarthEnvTopoMed_Elevation',
    'EsaWorldCover_TreeCover',
    'CHELSA_exBIO_GrowingSeasonLength',
    'WCS_Human_Footprint_2009',
    'GHS_Population_Density',
    'CHELSA_BIO_Annual_Mean_Temperature',
]

# Merged-frame columns that identify a sample rather than describe it; they
# are never passed to a model.
_ID_COLUMNS = ['processid', 'nucraw']


def _top_k_predictions(model, features, sequences, k=3):
    """Map each sequence to its ``k`` most probable classes under ``model``.

    ``model`` must expose ``predict_proba`` and ``classes_`` (the
    scikit-learn classifier API). ``features`` and ``sequences`` must be
    row-aligned. Returns ``{sequence: (classes, probabilities)}`` with both
    arrays ordered from most to least probable. If a sequence occurs more
    than once, its last row wins.
    """
    if len(features) == 0:
        # scikit-learn estimators reject zero-row input, and there is
        # nothing to rank anyway.
        return {}
    probs = np.asarray(model.predict_proba(features))
    classes = np.asarray(model.classes_)
    # Rank each row by descending probability; the stable sort keeps tied
    # classes in their original order. Slicing past the class count is safe.
    order = np.argsort(-probs, axis=1, kind='stable')[:, :k]
    return {
        seq: (classes[idx], row[idx])
        for seq, row, idx in zip(sequences, probs, order)
    }


def infer_dna(args, *, dna_embeds=None, model_dna=None, model_dna_env=None,
              top_k=3):
    """Predict the top genera for each COI-5P sample in a tab-separated file.

    Two classifiers score the same input samples:

    * ``model_dna`` sees only the DNA-embedding columns.
    * ``model_dna_env`` sees the DNA embeddings plus the location and
      environmental columns in ``_ECO_COLUMNS`` (minus the id columns).

    Args:
        args: Mapping whose ``'input_path'`` entry (any path or buffer
            accepted by ``pd.read_csv``) is read as a tab-separated table.
            The table must contain ``marker_code`` and every column in
            ``_ECO_COLUMNS``.
        dna_embeds: Embedding table keyed by ``processid``, as a DataFrame
            or a Hugging Face ``Dataset``. When ``None``, the ``train``
            split of ``LofiAmazon/BOLD-Embeddings`` is loaded.
        model_dna: Classifier for DNA-only features. When ``None``,
            ``load_checkpoint()`` is called.
        model_dna_env: Classifier for DNA+environment features. When
            ``None``, ``load_checkpoint()`` is called.
        top_k: Number of ranked genera kept per sample.

    Returns:
        ``(dna_genuses, dna_env_genuses)``: dicts mapping each sample's raw
        sequence (``nucraw``) to ``(classes, probabilities)``, most probable
        first.

    Raises:
        KeyError: If the input file lacks a required column.
    """
    eco_df = pd.read_csv(args['input_path'], sep='\t')
    # Keep only COI-5P barcode records, then the columns used downstream.
    eco_df = eco_df.loc[eco_df['marker_code'] == 'COI-5P', _ECO_COLUMNS]

    if dna_embeds is None:
        dna_embeds = load_dataset("LofiAmazon/BOLD-Embeddings", split='train')
    if not isinstance(dna_embeds, pd.DataFrame):
        # A Hugging Face Dataset has no pandas merge/drop API.
        dna_embeds = dna_embeds.to_pandas()

    # TODO: `load_checkpoint` is neither defined nor imported in this file;
    # pass the models explicitly until it exists.
    if model_dna is None:
        model_dna = load_checkpoint()
    if model_dna_env is None:
        model_dna_env = load_checkpoint()

    # Drop the label column and any column the input file already provides
    # (other than the join key), so the merge adds no _x/_y suffixes.
    embeds = dna_embeds.drop(columns=['genus'], errors='ignore')
    overlap = [c for c in embeds.columns
               if c in eco_df.columns and c != 'processid']
    embeds = embeds.drop(columns=overlap)
    embed_cols = [c for c in embeds.columns if c != 'processid']

    # NOTE(review): the left join keeps samples that have no precomputed
    # embedding; their embedding features are NaN, which many classifiers
    # reject. TODO: compute embeddings for sequences missing from
    # BOLD-Embeddings.
    merged = eco_df.merge(embeds, on='processid', how='left')
    sequences = merged['nucraw'].tolist()

    # NOTE(review): coord and country are passed through unencoded; confirm
    # model_dna_env was trained on them in this form.
    x_dna_env = merged.drop(columns=_ID_COLUMNS)
    x_dna = merged[embed_cols]

    dna_env_genuses = _top_k_predictions(model_dna_env, x_dna_env,
                                         sequences, top_k)
    dna_genuses = _top_k_predictions(model_dna, x_dna, sequences, top_k)
    return dna_genuses, dna_env_genuses
| # if __name__ == '__main__': | |
| # parser = argparse.ArgumentParser() | |
| # parser.add_argument('--input_path', action='store', type=str) | |
| # # parser.add_argument('--checkpt', action='store', type=bool, default=False) | |
| # args = vars(parser.parse_args()) |