diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml
index 7f2445c8..fcf99f83 100644
--- a/.github/workflows/test-extra.yml
+++ b/.github/workflows/test-extra.yml
@@ -78,6 +78,26 @@ jobs:
           make --jobs=5 --output-sync=target -C backend/python/diffusers
           make --jobs=5 --output-sync=target -C backend/python/diffusers test
+  tests-vllm:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Clone
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y build-essential ffmpeg
+          sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+          sudo apt-get install -y libopencv-dev
+          # Install UV
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          pip install --user --no-cache-dir grpcio-tools==1.64.1
+      - name: Test vllm backend
+        run: |
+          make --jobs=5 --output-sync=target -C backend/python/vllm
+          make --jobs=5 --output-sync=target -C backend/python/vllm test
 
   # tests-transformers-musicgen:
   #   runs-on: ubuntu-latest
   #   steps:
diff --git a/Makefile b/Makefile
index 7ed99a9b..835fcc0e 100644
--- a/Makefile
+++ b/Makefile
@@ -598,10 +598,12 @@ prepare-extra-conda-environments: protogen-python
 prepare-test-extra: protogen-python
 	$(MAKE) -C backend/python/transformers
 	$(MAKE) -C backend/python/diffusers
+	$(MAKE) -C backend/python/vllm
 
 test-extra: prepare-test-extra
 	$(MAKE) -C backend/python/transformers test
 	$(MAKE) -C backend/python/diffusers test
+	$(MAKE) -C backend/python/vllm test
 
 backend-assets:
 	mkdir -p backend-assets
diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py
index 1ccf6d2a..56698a54 100644
--- a/backend/python/vllm/backend.py
+++ b/backend/python/vllm/backend.py
@@ -194,27 +194,40 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         await iterations.aclose()
 
     async def _predict(self, request, context, streaming=False):
+        # Build the sampling parameters
+        # NOTE: this must stay in sync with the vllm backend
+        request_to_sampling_params = {
+            "N": "n",
+            "PresencePenalty": "presence_penalty",
+            "FrequencyPenalty": "frequency_penalty",
+            "RepetitionPenalty": "repetition_penalty",
+            "Temperature": "temperature",
+            "TopP": "top_p",
+            "TopK": "top_k",
+            "MinP": "min_p",
+            "Seed": "seed",
+            "StopPrompts": "stop",
+            "StopTokenIds": "stop_token_ids",
+            "BadWords": "bad_words",
+            "IncludeStopStrInOutput": "include_stop_str_in_output",
+            "IgnoreEOS": "ignore_eos",
+            "Tokens": "max_tokens",
+            "MinTokens": "min_tokens",
+            "Logprobs": "logprobs",
+            "PromptLogprobs": "prompt_logprobs",
+            "SkipSpecialTokens": "skip_special_tokens",
+            "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
+            "TruncatePromptTokens": "truncate_prompt_tokens",
+            "GuidedDecoding": "guided_decoding",
+        }
 
-        # Build sampling parameters
         sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
-        if request.TopP != 0:
-            sampling_params.top_p = request.TopP
-        if request.Tokens > 0:
-            sampling_params.max_tokens = request.Tokens
-        if request.Temperature != 0:
-            sampling_params.temperature = request.Temperature
-        if request.TopK != 0:
-            sampling_params.top_k = request.TopK
-        if request.PresencePenalty != 0:
-            sampling_params.presence_penalty = request.PresencePenalty
-        if request.FrequencyPenalty != 0:
-            sampling_params.frequency_penalty = request.FrequencyPenalty
-        if request.StopPrompts:
-            sampling_params.stop = request.StopPrompts
-        if request.IgnoreEOS:
-            sampling_params.ignore_eos = request.IgnoreEOS
-        if request.Seed != 0:
-            sampling_params.seed = request.Seed
+
+        for request_field, param_field in request_to_sampling_params.items():
+            if hasattr(request, request_field):
+                value = getattr(request, request_field)
+                if value not in (None, 0, [], False):
+                    setattr(sampling_params, param_field, value)
 
         # Extract image paths and process images
         prompt = request.Prompt
diff --git a/backend/python/vllm/test.py b/backend/python/vllm/test.py
index 9f325b10..827aa71a 100644
--- a/backend/python/vllm/test.py
+++ b/backend/python/vllm/test.py
@@ -75,6 +75,53 @@ class TestBackendServicer(unittest.TestCase):
         finally:
             self.tearDown()
 
+    def test_sampling_params(self):
+        """
+        This method tests if all sampling parameters are correctly processed
+        NOTE: this does NOT test for correctness, just that we received a compatible response
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                self.assertTrue(response.success)
+
+                req = backend_pb2.PredictOptions(
+                    Prompt="The capital of France is",
+                    TopP=0.8,
+                    Tokens=50,
+                    Temperature=0.7,
+                    TopK=40,
+                    PresencePenalty=0.1,
+                    FrequencyPenalty=0.2,
+                    RepetitionPenalty=1.1,
+                    MinP=0.05,
+                    Seed=42,
+                    StopPrompts=["\n"],
+                    StopTokenIds=[50256],
+                    BadWords=["badword"],
+                    IncludeStopStrInOutput=True,
+                    IgnoreEOS=True,
+                    MinTokens=5,
+                    Logprobs=5,
+                    PromptLogprobs=5,
+                    SkipSpecialTokens=True,
+                    SpacesBetweenSpecialTokens=True,
+                    TruncatePromptTokens=10,
+                    GuidedDecoding=True,
+                    N=2,
+                )
+                resp = stub.Predict(req)
+                self.assertIsNotNone(resp.message)
+                self.assertIsNotNone(resp.logprobs)
+        except Exception as err:
+            print(err)
+            self.fail("sampling params service failed")
+        finally:
+            self.tearDown()
+
+
     def test_embedding(self):
         """
         This method tests if the embeddings are generated successfully
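
Note on the mapping loop introduced in backend.py: a request field is only copied onto SamplingParams when its value is not a proto3 default (None, 0, empty list, or False), so fields the client never set cannot overwrite the backend defaults (top_p=0.9, max_tokens=200). Below is a minimal, self-contained sketch of that behaviour; the SimpleNamespace stand-ins for the gRPC request and for vllm.SamplingParams, and their default values, are illustrative assumptions only, not the real generated classes.

from types import SimpleNamespace

# Subset of the PredictOptions -> SamplingParams field mapping used in the backend.
request_to_sampling_params = {
    "TopP": "top_p",
    "TopK": "top_k",
    "Temperature": "temperature",
    "Tokens": "max_tokens",
    "StopPrompts": "stop",
}

# Hypothetical stand-in for the gRPC request: unset proto3 scalars arrive as 0,
# unset repeated fields as empty lists.
request = SimpleNamespace(TopP=0.8, TopK=0, Temperature=0.7, Tokens=0, StopPrompts=[])

# Hypothetical stand-in for vllm.SamplingParams with illustrative defaults.
sampling_params = SimpleNamespace(top_p=0.9, top_k=-1, temperature=1.0, max_tokens=200, stop=None)

for request_field, param_field in request_to_sampling_params.items():
    if hasattr(request, request_field):
        value = getattr(request, request_field)
        # Default values are skipped, so they never clobber the backend defaults.
        if value not in (None, 0, [], False):
            setattr(sampling_params, param_field, value)

print(sampling_params)
# namespace(top_p=0.8, top_k=-1, temperature=0.7, max_tokens=200, stop=None)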