deprecate argument stream in model.chat()
modeling_qwen.py  +23 -36
@@ -60,6 +60,12 @@ If you are directly using the model downloaded from Huggingface, please make sur
 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
 """
 
+_SENTINEL = object()
+_ERROR_STREAM_IN_CHAT = """\
+Passing the argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
+向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
+"""
+
 apply_rotary_emb_func = None
 rms_norm = None
 flash_attn_unpadded_func = None
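For context on the hunk above: `_SENTINEL = object()` creates a module-private default that no caller can ever supply, so an identity check cleanly separates "argument omitted" from every value a caller could pass, including None and False. A minimal standalone sketch of the pattern, with hypothetical names, not code from this file:

    _SENTINEL = object()  # unique and module-private; no caller can produce this object

    def chat(query: str, stream=_SENTINEL) -> str:
        # Holds only when the caller omitted `stream` entirely; any explicit
        # value (True, False, even None) fails the identity check.
        assert stream is _SENTINEL, "`stream` is deprecated; use chat_stream()"
        return f"(response to {query!r})"

    print(chat("hi"))           # fine
    # chat("hi", stream=False)  # AssertionError with the deprecation message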
@@ -977,10 +983,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         append_history: bool = True,
-        stream: Optional[bool] = False,
+        stream: Optional[bool] = _SENTINEL,
         stop_words_ids: Optional[List[List[int]]] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
         assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
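One caller-visible consequence of the identity check added above: any explicit `stream` argument now trips the assertion, including the previously harmless `stream=False`, because `False is not _SENTINEL`. A sketch of the new behavior (assumes `model` and `tokenizer` loaded from "Qwen/Qwen-7B-Chat" with trust_remote_code=True; not part of the diff):

    response, history = model.chat(tokenizer, "Hello", history=None)  # still fine

    try:
        model.chat(tokenizer, "Hello", history=None, stream=True)     # deprecated path
    except AssertionError as err:
        print(err)  # the bilingual _ERROR_STREAM_IN_CHAT message

    # Even stream=False is rejected now, since False is not _SENTINEL.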
@@ -1000,41 +1007,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             self.generation_config.chat_format, tokenizer
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
-        if stream:
-            ...
-                        break
-                    yield tokenizer.decode(outputs, skip_special_tokens=True)
-
-            return stream_generator()
-        else:
-            outputs = self.generate(
-                input_ids,
-                stop_words_ids = stop_words_ids,
-                return_dict_in_generate = False,
-                **kwargs,
-            )
-
-            response = decode_tokens(
-                outputs[0],
-                tokenizer,
-                raw_text_len=len(raw_text),
-                context_length=len(context_tokens),
-                chat_format=self.generation_config.chat_format,
-                verbose=False,
-            )
+        outputs = self.generate(
+            input_ids,
+            stop_words_ids = stop_words_ids,
+            return_dict_in_generate = False,
+            **kwargs,
+        )
+
+        response = decode_tokens(
+            outputs[0],
+            tokenizer,
+            raw_text_len=len(raw_text),
+            context_length=len(context_tokens),
+            chat_format=self.generation_config.chat_format,
+            verbose=False,
+        )
 
         if append_history:
             history.append((query, response))
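Taken together, the migration the error message asks for: keep model.chat(...) for blocking calls, and move streaming callers to model.chat_stream(...), which yields the partial response text as it grows. An end-to-end sketch under the usual loading flow (chat_stream's signature mirroring chat's is assumed from the Qwen README, not from this diff):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
    ).eval()

    # Before this commit (buggy, now rejected outright):
    #   model.chat(tokenizer, "Hi", history=None, stream=True)

    # Blocking call: returns (response, updated_history).
    response, history = model.chat(tokenizer, "Hi", history=None)

    # Streaming call: yields successively longer partial responses.
    for partial in model.chat_stream(tokenizer, "Tell me a joke", history=history):
        print(partial)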