| import argparse | |
| import json | |
| import os | |
| from tqdm import tqdm | |
| from utils.simple_bleu import simple_score | |
def load_json(filename):
    """Load a JSON or JSON Lines file.

    Args:
        filename: Path of the file to read (decoded as UTF-8).

    Returns:
        For a path whose extension is exactly ``.jsonl``: a list with one
        decoded object per non-blank line. For any other extension: whatever
        ``json.load`` returns for the whole file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file (or a non-blank line of a
            ``.jsonl`` file) is not valid JSON.
    """
    with open(filename, "r", encoding="utf-8") as f:
        if os.path.splitext(filename)[1] != ".jsonl":
            return json.load(f)
        # Skip blank / whitespace-only lines (e.g. a trailing empty line);
        # feeding them to json.loads would raise JSONDecodeError.
        return [json.loads(line) for line in f if line.strip()]
def save_json(json_data, filename, option="a"):
    """Write data to a JSON or JSON Lines file, creating its directory.

    Spaces in ``filename`` (including its directory part) are replaced with
    underscores before anything touches the filesystem, so the directory that
    gets created is the same one the file is opened in.

    Args:
        json_data: For a ``.jsonl`` target, an iterable of objects written
            one per line. Otherwise, a single object dumped with indent=4.
        filename: Destination path. May be a bare file name with no
            directory component.
        option: Mode passed to ``open``. Defaults to ``"a"`` (append).
            NOTE: appending to a non-``.jsonl`` file concatenates JSON
            documents, which does not yield one valid JSON document; pass
            ``"w"`` for plain ``.json`` output.

    Raises:
        OSError: If the directory cannot be created or the file opened.
        TypeError: If the data is not JSON serializable.
    """
    # Sanitize first so makedirs() and open() agree on the final path.
    filename = filename.replace(" ", "_")
    directory = os.path.dirname(filename)
    # A bare file name has no directory part; os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, option, encoding="utf-8") as f:
        if not filename.endswith(".jsonl"):
            json.dump(json_data, f, ensure_ascii=False, indent=4)
        else:
            for data in json_data:
                json.dump(data, f, ensure_ascii=False)
                f.write("\n")
# Maps each supported --model identifier to the project module that exposes
# translate_en2ko / translate_ko2en for it.
_MODEL_MODULES = {
    "squarelike/Gugugo-koen-7B-V1.1": "models.gugugo",
    "jbochi/madlad400-10b-mt": "models.madlad400",
    "facebook/mbart-large-50-many-to-many-mmt": "models.mbart50",
    "facebook/nllb-200-distilled-1.3B": "models.nllb200",
    "Unbabel/TowerInstruct-7B-v0.1": "models.TowerInstruct",
    "maywell/Synatra-7B-v0.3-Translation": "models.synatra",
    "davidkim205/iris-7b": "models.iris_7b",
}

# The only model for which --model_path is honoured (via its load_model()).
_IRIS_MODEL = "davidkim205/iris-7b"


def _load_translators(model, model_path=None):
    """Import the backend for ``model`` and return (translate_en2ko, translate_ko2en).

    The import is deferred until the model is known so that only the selected
    backend module is imported.

    Raises:
        ValueError: If ``model`` is not a key of ``_MODEL_MODULES``.
    """
    import importlib

    module_name = _MODEL_MODULES.get(model)
    if module_name is None:
        supported = ", ".join(sorted(_MODEL_MODULES))
        raise ValueError(f"unsupported model {model!r}; choose one of: {supported}")
    backend = importlib.import_module(module_name)
    translators = (backend.translate_en2ko, backend.translate_ko2en)
    if model == _IRIS_MODEL and model_path:
        backend.load_model(model_path)
    return translators


def _split_prompt(prompt):
    """Return (text, source_language) extracted from an instruction prompt.

    The prompt is treated as English source text ("en") when it contains the
    instruction to translate into Korean, otherwise as Korean ("ko"). The
    text is whatever follows the first instruction line; a prompt without
    that marker is returned unchanged.
    """
    cur_lang = "en" if "한글로 번역하세요." in prompt else "ko"
    text = prompt.split("번역하세요.\n", 1)[-1]
    return text, cur_lang


def _translate(text, cur_lang, translate_en2ko, translate_ko2en):
    """Translate ``text`` away from ``cur_lang``; return "" if the backend fails.

    Best-effort on purpose: one failing sample must not abort the whole run.
    The failure is reported on stderr so it is not mistaken for an empty
    translation.
    """
    import sys

    translate = translate_en2ko if cur_lang == "en" else translate_ko2en
    try:
        return translate(text)
    except Exception as exc:  # boundary around opaque model backends
        print(
            f"[warn] translation from {cur_lang!r} failed: {exc!r}",
            file=sys.stderr,
        )
        return ""


def _default_output_path(model, model_path=None):
    """Build the default result path under ``results_self/``."""
    model_name = model.split("/")[-1]
    if model_path:
        # Ignore trailing slashes so "ckpt/100/" still yields "100".
        checkpoint = model_path.rstrip("/").split("/")[-1]
        return f"results_self/{model_name}-{checkpoint}-result.jsonl"
    return f"results_self/{model_name}-result.jsonl"


def main():
    """Run a round-trip translation evaluation over the input file.

    For every record, the text taken from the first conversation turn is
    translated into the other language and then back again. The round-trip
    result is scored against the original text with ``simple_score``, and
    one result record is printed and appended to the output file.

    Command-line options:
        --input_file: JSON / JSON Lines file of records, each providing
            ``conversations`` (a list whose first item has a ``value``
            prompt) and ``src``.
        --model: Identifier of the translation backend to use.
        --model_path: Optional path passed to the iris-7b backend's
            ``load_model``; also used in the default output file name.
        --output: Result file; when empty a name under ``results_self/`` is
            derived from --model and --model_path.
    """
    parser = argparse.ArgumentParser("argument")
    parser.add_argument(
        "--input_file",
        default="./data/komt-1810k-test.jsonl",
        type=str,
        help="input_file",
    )
    parser.add_argument(
        "--model_path",
        default=None,
        type=str,
        help="model path",
    )
    parser.add_argument("--output", default="", type=str, help="output file path")
    parser.add_argument(
        "--model", default="davidkim205/iris-7b", type=str, help="model"
    )
    args = parser.parse_args()

    # Fail fast on an unknown model: previously the translate functions were
    # simply left undefined and every sample silently produced "".
    if args.model not in _MODEL_MODULES:
        supported = ", ".join(sorted(_MODEL_MODULES))
        parser.error(f"unsupported --model {args.model!r}; choose one of: {supported}")

    json_data = load_json(args.input_file)
    translate_en2ko, translate_ko2en = _load_translators(
        args.model, args.model_path
    )
    # The destination does not depend on the record, so compute it once.
    output = args.output or _default_output_path(args.model, args.model_path)

    # Wrap the data (not the enumerate) so tqdm knows the total length.
    for index, data in enumerate(tqdm(json_data)):
        chat = data["conversations"]
        src = data["src"]
        reference, cur_lang = _split_prompt(chat[0]["value"])
        next_lang = "ko" if cur_lang == "en" else "en"

        generation1 = _translate(
            reference, cur_lang, translate_en2ko, translate_ko2en
        )
        generation2 = _translate(
            generation1, next_lang, translate_en2ko, translate_ko2en
        )
        bleu = round(simple_score(reference, generation2), 3)

        result = {
            "index": index,
            "reference": reference,
            "generation": generation2,
            "generation1": generation1,
            "bleu": bleu,
            "lang": cur_lang,
            "model": args.model,
            "src": src,
            "conversations": chat,
        }
        print(json.dumps(result, ensure_ascii=False, indent=2))
        # Append per record so partial results survive an interrupted run.
        save_json([result], output)
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()