feat(transformers): various enhancements to the transformers backend (#2468)

update transformers
* Handle Temperature = 0 as greedy search
* Handle custom words as stop words
* Implement KV cache
* Phi 3 no longer requires trust_remote_code: true

This commit is contained in:
parent 5ddaa19914
commit 4a239a4bff

backend/python/transformers/backend.py | 59 | Executable file → Normal file
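
The "Phi 3 no longer requires trust_remote_code" point in the message refers to the model now loading through the stock Auto classes once a sufficiently recent transformers release is installed. A minimal sketch of what that means in practice; the checkpoint name is purely illustrative and nothing here is taken from the diff itself:

# Illustrative only: with a transformers release that ships native Phi-3 support,
# the model loads via the standard Auto classes, so trust_remote_code=True is no
# longer needed. The checkpoint name is an example, not a value from the backend.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)  # no trust_remote_code=True
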
@@ -22,9 +22,9 @@ import torch.cuda
 XPU=os.environ.get("XPU", "0") == "1"
 if XPU:
-    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria


 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
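
The newly imported StoppingCriteriaList and StopStringCriteria are what let the backend treat arbitrary strings, not just single tokens, as stop words. A small, self-contained sketch of that API; the model name, prompt, and stop strings are placeholders chosen for illustration:

# Sketch of the stop-string API the commit adopts; values are examples only.
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteriaList, StopStringCriteria)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=["\nUser:", "###"])]
)

inputs = tokenizer("Assistant: Hello,", return_tensors="pt")
outputs = model.generate(**inputs,
                         max_new_tokens=64,
                         stopping_criteria=criteria,          # stop on whole strings
                         pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
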
@@ -246,28 +246,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
         sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
         # print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
         # print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])

     async def _predict(self, request, context, streaming=False):
         set_seed(request.Seed)
-        if request.TopP == 0:
-            request.TopP = 0.9
+        if request.TopP < 0 or request.TopP > 1:
+            request.TopP = 1

-        if request.TopK == 0:
-            request.TopK = 40
+        if request.TopK <= 0:
+            request.TopK = 50
+
+        if request.Temperature > 0 :
+            sample=True
+        else:
+            sample=False
+            request.TopP == None
+            request.TopK == None
+            request.Temperature == None

         prompt = request.Prompt
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)

-        eos_token_id = self.tokenizer.eos_token_id
-        if request.StopPrompts:
-            eos_token_id = []
-            for word in request.StopPrompts:
-                eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
-
         inputs = self.tokenizer(prompt, return_tensors="pt")

         if request.Tokens > 0:
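
The new sample flag is what later feeds generate()'s do_sample argument: with Temperature > 0 the sampler runs, otherwise generation falls back to greedy search and the sampling knobs are irrelevant. A minimal sketch of that mapping, with a hypothetical helper name and the same clamping rules as the hunk above:

# Illustrative helper (not part of the backend) mirroring the temperature/clamping logic.
def sampling_kwargs(temperature: float, top_p: float, top_k: int) -> dict:
    # Clamp out-of-range values the same way the backend now does.
    if top_p < 0 or top_p > 1:
        top_p = 1.0
    if top_k <= 0:
        top_k = 50
    if temperature > 0:
        # Sampling path: temperature, top_p and top_k are all honoured by generate().
        return dict(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    # temperature == 0 -> greedy search; the sampling knobs are ignored.
    return dict(do_sample=False)

print(sampling_kwargs(0.0, 0.9, 40))  # {'do_sample': False}
print(sampling_kwargs(0.7, 1.5, 0))   # {'do_sample': True, 'temperature': 0.7, 'top_p': 1.0, 'top_k': 50}
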
@@ -281,6 +281,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             inputs = inputs.to("xpu")
             streaming = False

+        criteria=[]
+        if request.StopPrompts:
+            criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
+                ]
+            )
+
         if streaming:
             streamer=TextIteratorStreamer(self.tokenizer,
                                           skip_prompt=True,
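
The criteria list built here replaces the earlier convert_tokens_to_ids approach, which only worked when a stop word happened to be a single vocabulary token. A short sketch of why; the tokenizer and stop phrase are illustrative:

# Why the eos_token_id list was dropped: a typical stop phrase tokenizes into
# several ids, so it cannot be expressed as one eos_token_id. Tokenizer is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

stop = "\nUser:"
print(tokenizer(stop, add_special_tokens=False).input_ids)  # several ids, not one
print(tokenizer.convert_tokens_to_ids(stop))  # unknown "token": falls back to the unk id
# StopStringCriteria instead matches the decoded text, so any stop string works.
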
@@ -290,11 +298,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
-                        do_sample=True,
+                        do_sample=sample,
                         attention_mask=inputs["attention_mask"],
-                        eos_token_id=eos_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
                         pad_token_id=self.tokenizer.eos_token_id,
-                        streamer=streamer)
+                        streamer=streamer,
+                        stopping_criteria=criteria,
+                        use_cache=True,
+                        )
             thread=Thread(target=self.model.generate, kwargs=config)
             thread.start()
             generated_text = ""
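
The streaming branch runs generate() on a worker thread and consumes decoded chunks from the TextIteratorStreamer as they arrive. A standalone sketch of that same pattern, with an illustrative model and prompt (not values from the backend):

# Minimal sketch of the threaded streaming pattern shown above; values are examples.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

config = dict(inputs,                      # input_ids and attention_mask
              max_new_tokens=32,
              do_sample=False,
              pad_token_id=tokenizer.eos_token_id,
              streamer=streamer,
              use_cache=True)

thread = Thread(target=model.generate, kwargs=config)
thread.start()
for piece in streamer:                     # yields decoded text chunks as they are produced
    print(piece, end="", flush=True)
thread.join()
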
@@ -311,18 +322,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
-                                              do_sample=True,
+                                              do_sample=sample,
                                               pad_token=self.tokenizer.eos_token_id)
             else:
-                outputs = self.model.generate(inputs["input_ids"],
+                outputs = self.model.generate(**inputs,
                                               max_new_tokens=max_tokens,
                                               temperature=request.Temperature,
                                               top_p=request.TopP,
                                               top_k=request.TopK,
-                                              do_sample=True,
-                                              attention_mask=inputs["attention_mask"],
-                                              eos_token_id=eos_token_id,
-                                              pad_token_id=self.tokenizer.eos_token_id)
+                                              do_sample=sample,
+                                              eos_token_id=self.tokenizer.eos_token_id,
+                                              pad_token_id=self.tokenizer.eos_token_id,
+                                              stopping_criteria=criteria,
+                                              use_cache=True,
+                                              )
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]

         if streaming:
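
Two details in this hunk carry the "Implement KV cache" bullet and the prompt handling: use_cache=True enables the key/value cache so each new token attends over cached keys and values rather than recomputing the whole prefix, and the batch_decode slice keeps only the tokens generated after the prompt. A short sketch of the decode step; model and prompt are illustrative, and use_cache=True is generally already the transformers default:

# Sketch of the prompt-stripping decode used above: generate() returns prompt ids
# followed by new ids, so slicing at the prompt length keeps only the completion.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs,
                         max_new_tokens=8,
                         do_sample=False,
                         pad_token_id=tokenizer.eos_token_id,
                         use_cache=True)               # KV cache on

prompt_len = inputs["input_ids"].shape[1]
completion = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0]
print(completion)
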