diff --git a/backend/python/common-env/transformers/install.sh b/backend/python/common-env/transformers/install.sh
index 8502adde..30ec0de0 100644
--- a/backend/python/common-env/transformers/install.sh
+++ b/backend/python/common-env/transformers/install.sh
@@ -2,6 +2,7 @@
 set -ex
 
 SKIP_CONDA=${SKIP_CONDA:-0}
+REQUIREMENTS_FILE=$1
 
 # Check if environment exist
 conda_env_exists(){
@@ -14,7 +15,7 @@ else
     export PATH=$PATH:/opt/conda/bin
     if conda_env_exists "transformers" ; then
         echo "Creating virtual environment..."
-        conda env create --name transformers --file $1
+        conda env create --name transformers --file $REQUIREMENTS_FILE
         echo "Virtual environment created."
     else
         echo "Virtual environment already exists."
@@ -28,11 +29,16 @@ if [ -d "/opt/intel" ]; then
     pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino]
 fi
 
-if [ "$PIP_CACHE_PURGE" = true ] ; then
-    if [ $SKIP_CONDA -eq 0 ]; then
-        # Activate conda environment
-        source activate transformers
-    fi
+# If we didn't skip conda, activate the environment
+# to install FlashAttention
+if [ $SKIP_CONDA -eq 0 ]; then
+    source activate transformers
+fi
+if [[ $REQUIREMENTS_FILE =~ -nvidia.yml$ ]]; then
+    #TODO: FlashAttention is supported on nvidia and ROCm, but ROCm install can't be done this easily
+    pip install flash-attn --no-build-isolation
+fi
+if [ "$PIP_CACHE_PURGE" = true ] ; then
 
     pip cache purge
 fi
\ No newline at end of file
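
A quick sketch of how the new REQUIREMENTS_FILE gate behaves after this change. The environment file names below are illustrative assumptions, not taken from this diff; the only thing the script actually keys on is whether the first argument ends in "-nvidia.yml":

    # Assumed CUDA env file name: matches the -nvidia.yml$ pattern,
    # so flash-attn is built and installed (--no-build-isolation lets
    # the build see the already-installed torch/CUDA toolchain).
    bash install.sh transformers-nvidia.yml

    # Assumed ROCm env file name: no match, so the flash-attn step is
    # skipped entirely (see the TODO in the diff).
    bash install.sh transformers-rocm.yml

One note on the pattern: bash's =~ uses extended regular expressions, so the unescaped dot in -nvidia.yml$ matches any character, not just a literal ".". That is harmless for the file names above, but an escaped form (-nvidia\.yml$) would match only the literal suffix.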