feat: add flash-attn in nvidia and rocm envs (#1995)

Signed-off-by: Ludovic LEROUX <ludovic@inpher.io>
Author: Ludovic Leroux <ludovic@inpher.io>
Date: 2024-04-11 03:44:39 -04:00 (committed by GitHub)
Commit: b4548ad72d (parent e152b07b74)

@@ -2,6 +2,7 @@
 set -ex

 SKIP_CONDA=${SKIP_CONDA:-0}
+REQUIREMENTS_FILE=$1

 # Check if environment exist
 conda_env_exists(){
@@ -14,7 +15,7 @@ else
 export PATH=$PATH:/opt/conda/bin

 if conda_env_exists "transformers" ; then
     echo "Creating virtual environment..."
-    conda env create --name transformers --file $1
+    conda env create --name transformers --file $REQUIREMENTS_FILE
     echo "Virtual environment created."
 else
@@ -28,11 +29,16 @@ if [ -d "/opt/intel" ]; then
     pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino]
 fi

-if [ "$PIP_CACHE_PURGE" = true ] ; then
-    if [ $SKIP_CONDA -eq 0 ]; then
-        # Activate conda environment
-        source activate transformers
-    fi
+# If we didn't skip conda, activate the environment
+# to install FlashAttention
+if [ $SKIP_CONDA -eq 0 ]; then
+    source activate transformers
+fi

+if [[ $REQUIREMENTS_FILE =~ -nvidia.yml$ ]]; then
+    #TODO: FlashAttention is supported on nvidia and ROCm, but ROCm install can't be done this easily
+    pip install flash-attn --no-build-isolation
+fi
+if [ "$PIP_CACHE_PURGE" = true ] ; then
     pip cache purge
 fi
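For reference, a minimal sketch of how the new gate behaves, assuming the script is invoked with the conda environment file as its first argument. The script name (install.sh) and the .yml file names below are illustrative, not taken from this diff:

    # Hypothetical invocations; only a -nvidia.yml suffix triggers flash-attn:
    bash install.sh transformers-nvidia.yml   # regex matches -> flash-attn installed
    bash install.sh transformers-rocm.yml     # no match -> skipped (see the TODO above)
    bash install.sh transformers.yml          # no match -> skipped

    # The gate itself: with [[ ... =~ ... ]], bash treats the unquoted
    # right-hand side as an extended regular expression, and the trailing $
    # anchors the match to the end of the argument.
    REQUIREMENTS_FILE=$1
    if [[ $REQUIREMENTS_FILE =~ -nvidia.yml$ ]]; then
        pip install flash-attn --no-build-isolation
    fi

--no-build-isolation is the flag flash-attn's own install instructions call for: it makes pip build the package against the torch already present in the activated environment rather than in an isolated build environment, which matters because flash-attn needs torch at build time.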