Spaces: Running on Zero
| import gradio as gr | |
| from huggingface_hub import HfApi, get_collection, list_collections | |
| from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset | |
| import pandas as pd | |
| import os | |
def get_models():
    """Collect the property-prediction adapters published in the ChemFM collection.

    Returns:
        dict: maps each model's short name (last segment of its repo id), which is
        used as the dataset id throughout this app, to the full model repo id.

    Raises:
        ValueError: if a model in the collection has no entry in the task-type,
            description, or property-name tables imported from ``utils``.
    """
    # this is the collection id for the molecular property prediction models
    collection = get_collection(
        "ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c",
        token=os.environ.get("TOKEN"),
    )
    # Every model must be describable by the UI; dataset_property_names is
    # indexed right after this function returns, so it is checked as well.
    required_tables = {
        "task_types": dataset_task_types,
        "dataset_descriptions": dataset_descriptions,
        "dataset_property_names": dataset_property_names,
    }
    models = {}
    for item in collection.items:
        if item.item_type != "model":
            continue
        item_name = item.item_id.split("/")[-1]
        # Explicit checks rather than ``assert`` so they still run under ``python -O``.
        for table_name, table in required_tables.items():
            if item_name not in table:
                raise ValueError(f"{item_name} is not in the {table_name}")
        models[item_name] = item.item_id
    return models
# Dataset id -> model repo id for every adapter found in the collection.
candidate_models = get_models()
# NOTE: despite the name, these are dataset ids (collection order), not display names.
property_names = list(candidate_models)
# Human-readable property names for the dropdown, aligned index-for-index with property_names.
properties = [dataset_property_names[dataset_id] for dataset_id in property_names]
# Shared predictor; adapters are switched per request via swith_adapter.
model = MolecularPropertyPredictionModel(candidate_models)
def get_description(property_name):
    """Return the description text for a property display name chosen in the dropdown."""
    return dataset_descriptions[dataset_property_names_to_dataset[property_name]]
def predict_single_label(smiles, property_name):
    """Predict one property value for a single SMILES string.

    Args:
        smiles: the molecule as a SMILES string.
        property_name: human-readable property name selected in the dropdown.

    Returns:
        tuple: ``(prediction, status_message)``. ``prediction`` is ``"NA"`` on any
        failure; float predictions are rounded to 3 decimal places.
    """
    property_id = dataset_property_names_to_dataset[property_name]
    try:
        switch_info = model.swith_adapter(property_id, candidate_models[property_id])
        # "keep" / "switched" mean the right adapter is loaded; anything else aborts.
        if switch_info == "error":
            return "NA", "Adapter is not found"
        if switch_info not in ("keep", "switched"):
            return "NA", "Unknown error"

        raw = model.predict_single_smiles(smiles, dataset_task_types[property_id])
        if raw is None:
            return "NA", "Invalid SMILES string"
        result = round(raw, 3) if isinstance(raw, float) else raw
    except Exception as err:
        # UI boundary: report any failure in the status box instead of crashing.
        print(err)
        return "NA", "Prediction failed"
    return result, "Prediction is done"
def predict_file(file, property_name):
    """Run batch prediction on an uploaded file and expose the result for download.

    Args:
        file: path of the uploaded file (``validate_file`` only lets through
            ``.csv`` files that contain a ``smiles`` column).
        property_name: human-readable property name selected in the dropdown.

    Returns:
        tuple: five values matching the click handler's outputs
        ``[predict_file_button, download_button, stop_button, input_file,
        running_terminal_label]``. Every return path yields exactly five values.
    """
    property_id = dataset_property_names_to_dataset[property_name]

    def _failure(status):
        # Show the predict button again, hide download/stop, keep the upload,
        # and report the status. Five values, matching the wired outputs.
        return (gr.update(visible=True), gr.update(visible=False),
                gr.update(visible=False), file, status)

    try:
        adapter_id = candidate_models[property_id]
        info = model.swith_adapter(property_id, adapter_id)
        if info == "error":
            return _failure("Adapter is not found")
        if info not in ("keep", "switched"):
            return _failure("Unknown error")

        df = pd.read_csv(file)
        # we have already checked the file contains the "smiles" column
        df = model.predict_file(df, dataset_task_types[property_id])
        # Save next to the upload as "<stem>_prediction.csv". clear_file derives the
        # original upload path from this suffix. splitext replaces only the real
        # extension, so the upload itself can never be overwritten.
        prediction_file = os.path.splitext(file)[0] + "_prediction.csv"
        df.to_csv(prediction_file, index=False)
    except Exception as e:
        # UI boundary: report any failure in the status box instead of crashing.
        print(e)
        return _failure("Prediction failed")
    return (gr.update(visible=False),
            gr.DownloadButton(label="Download", value=prediction_file, visible=True),
            gr.update(visible=False),
            prediction_file,
            "Prediction is done")
def validate_file(file):
    """Check an uploaded molecule file before allowing prediction.

    Only ``.csv`` files with a ``smiles`` column and at most 100 rows are
    accepted; ``.smi`` uploads are currently rejected like any other extension.

    Returns:
        tuple: ``(status, file_value, predict_button_update, download_button_update)``.
        ``status`` is ``"Valid file"`` on success; on rejection the file input
        is cleared (``None``) and both buttons are hidden.
    """
    def _reject(message):
        # Clear the file input and hide both the predict and download buttons.
        return message, None, gr.update(visible=False), gr.update(visible=False)

    try:
        if not file.endswith(".csv"):
            return _reject("Invalid file extension")
        frame = pd.read_csv(file)
        if "smiles" not in frame.columns:
            return _reject("Invalid file content. The csv file must contain column named 'smiles'")
        n_smiles = len(frame["smiles"])
    except Exception:
        return _reject("Invalid file content.")

    if n_smiles > 100:
        return _reject("The space does not support the file containing more than 100 SMILES")
    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
def raise_error(status):
    """Pop up a Gradio error for any validation status other than ``"Valid file"``.

    Returns ``None`` (written back to the status state) when the file is valid.
    """
    if status == "Valid file":
        return None
    raise gr.Error(status)
def clear_file(download_button):
    """Delete the prediction file and its source upload when the file input is cleared.

    Args:
        download_button: value of the download button, i.e. the prediction file
            path (falsy when no prediction has been produced).

    Returns:
        tuple: hides the predict and download buttons and empties the file input.
    """
    prediction_path = download_button
    print(prediction_path)
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        # The upload is "<stem>.csv" or "<stem>.smi"; reconstruct both and delete whichever exists.
        for source_suffix in (".csv", ".smi"):
            source_path = prediction_path.replace("_prediction.csv", source_suffix)
            if os.path.exists(source_path):
                os.remove(source_path)
    return gr.update(visible=False), gr.update(visible=False), None
def build_inference():
    """Assemble the Gradio UI for single-SMILES and CSV-file property prediction.

    Returns:
        gr.Blocks: the constructed (not yet launched) demo.
    """
    with gr.Blocks() as demo:
        # Property selector plus its read-only description. ``property_names``
        # holds dataset ids; the dropdown lists the display names in ``properties``.
        # Index with the raw id, as the module-level ``properties`` list does.
        dropdown = gr.Dropdown(properties, label="Property",
                               value=dataset_property_names[property_names[0]])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[property_names[0]])

        # Single-SMILES input with its prediction label (layout reconstructed).
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text",
                                     placeholder="Provide a SMILES string here",
                                     lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
            prediction = gr.Label("Prediction will appear here")

        running_terminal_label = gr.Textbox(label="Running status", type="text",
                                            placeholder=None, lines=10,
                                            interactive=False)

        # Batch (file) prediction controls.
        input_file = gr.File(label="Molecule file",
                             file_count='single',
                             file_types=[".smi", ".csv"], height=300)
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)
        # TODO: wire to ``cancels=`` once file prediction is cancellable; never shown yet.
        stop_button = gr.Button("Stop", size='sm', visible=False)

        # Controls frozen while a prediction runs; the status box must stay last.
        toggled = [dropdown, textbox, predict_single_smiles_button, predict_file_button,
                   download_button, stop_button, input_file, running_terminal_label]

        def lock_controls():
            # Disable every control while a prediction is running.
            return tuple(gr.update(interactive=False) for _ in toggled)

        def unlock_controls():
            # Re-enable the controls afterwards. The status box (last entry) is
            # kept read-only, as it was created, so users cannot edit it.
            updates = [gr.update(interactive=True) for _ in toggled]
            updates[-1] = gr.update(interactive=False)
            return tuple(updates)

        # dropdown change event
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)

        # predict single button click event
        predict_single_smiles_button.click(lock_controls, outputs=toggled) \
            .then(predict_single_label, inputs=[textbox, dropdown],
                  outputs=[prediction, running_terminal_label]) \
            .then(unlock_controls, outputs=toggled)

        # input file upload event: validate, then surface a popup on failure
        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file,
                          outputs=[file_status, input_file, predict_file_button, download_button]) \
            .success(raise_error, inputs=file_status, outputs=file_status)

        # input file clear event
        input_file.clear(fn=clear_file, inputs=[download_button],
                         outputs=[predict_file_button, download_button, input_file])

        # predict file button click event; the first step also hides the stop button
        predict_file_button.click(lambda: (gr.update(interactive=False),
                                           gr.update(interactive=False),
                                           gr.update(interactive=False),
                                           gr.update(interactive=False, visible=True),
                                           gr.update(interactive=False),
                                           gr.update(interactive=True, visible=False),
                                           gr.update(interactive=False),
                                           gr.update(interactive=False),
                                           ), outputs=toggled) \
            .then(predict_file, inputs=[input_file, dropdown],
                  outputs=[predict_file_button, download_button, stop_button,
                           input_file, running_terminal_label]) \
            .then(unlock_controls, outputs=toggled)
    return demo
# Built at import time so ``demo`` is available as a module attribute
# (presumably what the Space runtime imports — confirm).
demo = build_inference()

if __name__ == "__main__":
    # Running as a script: start the Gradio server locally.
    demo.launch()