feat(transformers): various enhancements to the transformers backend (#2468)

update transformers

*Handle Temperature = 0 as greedy search
*Handle custom works as stop words
*Implement KV cache
*Phi 3 no more requires trust_remote_code: true
This commit is contained in:
fakezeta 2024-06-03 08:52:55 +02:00 committed by GitHub
parent 5ddaa19914
commit 4a239a4bff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

59
backend/python/transformers/backend.py Executable file → Normal file
View File

@ -22,9 +22,9 @@ import torch.cuda
XPU=os.environ.get("XPU", "0") == "1"
if XPU:
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
else:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@ -246,28 +246,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
# print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
# print("Embeddings:", sentence_embeddings, file=sys.stderr)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
async def _predict(self, request, context, streaming=False):
set_seed(request.Seed)
if request.TopP == 0:
request.TopP = 0.9
if request.TopP < 0 or request.TopP > 1:
request.TopP = 1
if request.TopK == 0:
request.TopK = 40
if request.TopK <= 0:
request.TopK = 50
if request.Temperature > 0 :
sample=True
else:
sample=False
request.TopP == None
request.TopK == None
request.Temperature == None
prompt = request.Prompt
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
eos_token_id = self.tokenizer.eos_token_id
if request.StopPrompts:
eos_token_id = []
for word in request.StopPrompts:
eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
inputs = self.tokenizer(prompt, return_tensors="pt")
if request.Tokens > 0:
@ -281,6 +281,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
inputs = inputs.to("xpu")
streaming = False
criteria=[]
if request.StopPrompts:
criteria = StoppingCriteriaList(
[
StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
]
)
if streaming:
streamer=TextIteratorStreamer(self.tokenizer,
skip_prompt=True,
@ -290,11 +298,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
do_sample=True,
do_sample=sample,
attention_mask=inputs["attention_mask"],
eos_token_id=eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id,
streamer=streamer)
streamer=streamer,
stopping_criteria=criteria,
use_cache=True,
)
thread=Thread(target=self.model.generate, kwargs=config)
thread.start()
generated_text = ""
@ -311,18 +322,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
do_sample=True,
do_sample=sample,
pad_token=self.tokenizer.eos_token_id)
else:
outputs = self.model.generate(inputs["input_ids"],
outputs = self.model.generate(**inputs,
max_new_tokens=max_tokens,
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
do_sample=True,
attention_mask=inputs["attention_mask"],
eos_token_id=eos_token_id,
pad_token_id=self.tokenizer.eos_token_id)
do_sample=sample,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id,
stopping_criteria=criteria,
use_cache=True,
)
generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
if streaming: