diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py
old mode 100755
new mode 100644
index b1e0d559..10603d2e
--- a/backend/python/transformers/backend.py
+++ b/backend/python/transformers/backend.py
@@ -22,9 +22,9 @@ import torch.cuda
 
 XPU=os.environ.get("XPU", "0") == "1"
 if XPU:
-    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
@@ -246,28 +246,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
         sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-# print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
-# print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
 
     async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
-        if request.TopP == 0:
-            request.TopP = 0.9
+        if request.TopP < 0 or request.TopP > 1:
+            request.TopP = 1
 
-        if request.TopK == 0:
-            request.TopK = 40
+        if request.TopK <= 0:
+            request.TopK = 50
+
+        if request.Temperature > 0:
+            sample=True
+        else:
+            sample=False
+            request.TopP = None
+            request.TopK = None
+            request.Temperature = None
 
         prompt = request.Prompt
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
 
-        eos_token_id = self.tokenizer.eos_token_id
-        if request.StopPrompts:
-            eos_token_id = []
-            for word in request.StopPrompts:
-                eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
-
         inputs = self.tokenizer(prompt, return_tensors="pt")
 
         if request.Tokens > 0:
@@ -281,6 +281,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             inputs = inputs.to("xpu")
             streaming = False
 
+        criteria=[]
+        if request.StopPrompts:
+            criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
+                ]
+            )
+
         if streaming:
             streamer=TextIteratorStreamer(self.tokenizer,
                                           skip_prompt=True,
@@ -290,11 +298,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
-                        do_sample=True,
+                        do_sample=sample,
                         attention_mask=inputs["attention_mask"],
-                        eos_token_id=eos_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
                         pad_token_id=self.tokenizer.eos_token_id,
-                        streamer=streamer)
+                        streamer=streamer,
+                        stopping_criteria=criteria,
+                        use_cache=True,
+                        )
             thread=Thread(target=self.model.generate, kwargs=config)
             thread.start()
             generated_text = ""
@@ -311,18 +322,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
-                                              do_sample=True,
+                                              do_sample=sample,
                                               pad_token=self.tokenizer.eos_token_id)
             else:
-                outputs = self.model.generate(inputs["input_ids"],
+                outputs = self.model.generate(**inputs,
                                               max_new_tokens=max_tokens,
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
-                                              do_sample=True,
-                                              attention_mask=inputs["attention_mask"],
-                                              eos_token_id=eos_token_id,
-                                              pad_token_id=self.tokenizer.eos_token_id)
+                                              do_sample=sample,
+                                              eos_token_id=self.tokenizer.eos_token_id,
+                                              pad_token_id=self.tokenizer.eos_token_id,
+                                              stopping_criteria=criteria,
+                                              use_cache=True,
+                                              )
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
 
         if streaming:
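
For context, the change above drops the old trick of remapping request.StopPrompts onto eos_token_id and instead passes a StoppingCriteriaList containing a StopStringCriteria to generate(). A minimal standalone sketch of that mechanism follows (assuming transformers >= 4.39; the "gpt2" checkpoint and the stop strings are illustrative placeholders, not values from the backend):

# Standalone sketch of the StopStringCriteria mechanism adopted above.
# Assumes transformers >= 4.39; "gpt2" and the stop strings are illustrative
# placeholders, not values taken from the LocalAI backend.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    StopStringCriteria,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Stop as soon as any stop string appears in the decoded continuation,
# rather than converting stop words to token ids as the removed code did.
criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=["\n", "###"])]
)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=criteria,
    use_cache=True,
)
print(tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:],
                             skip_special_tokens=True)[0])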