diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..9ac0a81
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,35 @@
+[run]
+# Coverage configuration
+source = .
+omit =
+    # Exclude test files
+    test_*.py
+    # Exclude preload script
+    preload_models.py
+    # Exclude virtual environments
+    venv/*
+    .venv/*
+    env/*
+    # Exclude system/package files
+    */site-packages/*
+    */dist-packages/*
+
+[report]
+# Reporting options
+precision = 2
+show_missing = True
+skip_covered = False
+
+# Exclude lines from coverage
+exclude_lines =
+    # Default excludes
+    pragma: no cover
+    def __repr__
+    raise AssertionError
+    raise NotImplementedError
+    if __name__ == .__main__.:
+    if TYPE_CHECKING:
+    @abstract
+
+[html]
+directory = htmlcov
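Reviewer note: the `exclude_lines` entries above are regexes matched against individual source lines, so application code can opt a defensive branch out of the report explicitly. A tiny illustration (hypothetical function, not part of this diff):

```python
def parse_size(value: str) -> int:
    """Parse a human-readable size such as '512MB' into bytes."""
    if value.endswith("MB"):
        return int(value[:-2]) * 1024 * 1024
    # Defensive branch we never expect to hit in production; exclude it
    # from the coverage report rather than writing a contrived test.
    raise ValueError(f"unsupported size format: {value}")  # pragma: no cover
```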
diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml
new file mode 100644
index 0000000..0ecb4f8
--- /dev/null
+++ b/.github/workflows/docker-publish.yml
@@ -0,0 +1,94 @@
+name: Build and Publish Docker Image
+
+on:
+  push:
+    branches: [ main ]
+    tags: [ 'v*.*.*' ]
+  pull_request:
+    branches: [ main ]
+  workflow_dispatch:
+    inputs:
+      embedding_model:
+        description: 'Embedding model to use'
+        required: false
+        default: 'multi-qa-MiniLM-L6-cos-v1'
+      tokenizer_model:
+        description: 'Tokenizer model to use'
+        required: false
+        default: 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1'
+
+env:
+  REGISTRY: ghcr.io
+  IMAGE_NAME: ${{ github.repository }}
+
+jobs:
+  build-and-push:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      packages: write
+      id-token: write
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Log into registry ${{ env.REGISTRY }}
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v3
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract Docker metadata
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+          tags: |
+            type=ref,event=branch
+            type=ref,event=pr
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=semver,pattern={{major}}
+            type=sha,prefix=sha-
+            type=raw,value=latest,enable={{is_default_branch}}
+
+      - name: Set build args
+        id: build-args
+        run: |
+          if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
+            echo "EMBEDDING_MODEL=${{ github.event.inputs.embedding_model }}" >> $GITHUB_ENV
+            echo "TOKENIZER_MODEL=${{ github.event.inputs.tokenizer_model }}" >> $GITHUB_ENV
+          else
+            echo "EMBEDDING_MODEL=multi-qa-MiniLM-L6-cos-v1" >> $GITHUB_ENV
+            echo "TOKENIZER_MODEL=sentence-transformers/multi-qa-MiniLM-L6-cos-v1" >> $GITHUB_ENV
+          fi
+
+      - name: Build and push Docker image
+        id: build-and-push
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          push: ${{ github.event_name != 'pull_request' }}
+          load: ${{ github.event_name == 'pull_request' }}
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            EMBEDDING_MODEL=${{ env.EMBEDDING_MODEL }}
+            TOKENIZER_MODEL=${{ env.TOKENIZER_MODEL }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
+          platforms: linux/amd64
+
+      - name: Generate artifact attestation
+        if: github.event_name != 'pull_request'
+        uses: actions/attest-build-provenance@v1
+        with:
+          subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+          subject-digest: ${{ steps.build-and-push.outputs.digest }}
+          push-to-registry: true
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..4396313
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,62 @@
+name: Tests
+
+on:
+  push:
+    branches: [ main, develop ]
+  pull_request:
+    branches: [ main, develop ]
+  workflow_dispatch:
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements-dev.txt
+
+      - name: Lint with flake8 (optional)
+        run: |
+          # Install flake8 for basic linting
+          pip install flake8
+          # Stop the build if there are Python syntax errors or undefined names
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=venv,env,.venv,.git,__pycache__
+          # Exit-zero treats all errors as warnings
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude=venv,env,.venv,.git,__pycache__
+        continue-on-error: true
+
+      - name: Run tests with pytest
+        run: |
+          pytest
+
+      - name: Upload coverage reports
+        uses: codecov/codecov-action@v4
+        if: matrix.python-version == '3.10'
+        with:
+          files: ./coverage.xml
+          flags: unittests
+          name: codecov-umbrella
+          fail_ci_if_error: false
+        continue-on-error: true
+
+      - name: Archive coverage report
+        uses: actions/upload-artifact@v4
+        if: matrix.python-version == '3.10'
+        with:
+          name: coverage-report
+          path: htmlcov/
+          retention-days: 30
diff --git a/.gitignore b/.gitignore
index c323cf4..b98579b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,36 @@
+# Virtual environments
 .venv
+venv/
+env/
+ENV/
+
+# Python artifacts
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+
+# Testing and coverage
+.pytest_cache/
+.coverage
+.coverage.*
+htmlcov/
+coverage.xml
+*.cover
+.hypothesis/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Environment variables
+.env
+.env.local
diff --git a/Dockerfile b/Dockerfile
index 857c9c9..100a217 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,12 +1,26 @@
 # Use the official Python image as a base image
 FROM python:3.10-slim
 
+# Build arguments for model configuration
+ARG EMBEDDING_MODEL=multi-qa-MiniLM-L6-cos-v1
+ARG TOKENIZER_MODEL=sentence-transformers/multi-qa-MiniLM-L6-cos-v1
+
 # Set environment variables
-ENV PYTHONUNBUFFERED 1
+ENV PYTHONUNBUFFERED=1
+# Set HuggingFace cache directory to bundle models in the image
+ENV HF_HOME=/app/.cache/huggingface
+ENV TRANSFORMERS_CACHE=/app/.cache/huggingface
+ENV SENTENCE_TRANSFORMERS_HOME=/app/.cache/huggingface
+# Set model to use at runtime (from build arg)
+ENV EMBEDDING_MODEL=${EMBEDDING_MODEL}
+ENV TOKENIZER_MODEL=${TOKENIZER_MODEL}
 
 # Set the working directory in the container
 WORKDIR /app
 
+# Create cache directory with proper permissions
+RUN mkdir -p /app/.cache/huggingface
+
 # Copy the requirements.txt file into the container
 COPY requirements.txt .
 
@@ -16,13 +30,24 @@ RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://
 # Install dependencies
 RUN pip install --no-cache-dir -r requirements.txt
 
+# Pre-download models for offline availability
+# This must happen BEFORE copying application code to ensure models are cached
+COPY preload_models.py .
+RUN python preload_models.py "${EMBEDDING_MODEL}" "${TOKENIZER_MODEL}" && rm preload_models.py
+
 # Copy the Python script into the container
 COPY embeddings.py .
 COPY main.py .
 
 # Run the web service on container startup. Here we use the gunicorn
-# webserver, with one worker process and 8 threads.
-# For environments with multiple CPU cores, increase the number of workers
-# to be equal to the cores available.
-# Timeout is set to 0 to disable the timeouts of the workers to allow Cloud Run to handle instance scaling.
-CMD exec gunicorn --bind :$PORT --workers 1 --threads 8 --timeout 0 main:app
+# webserver with optimized configuration for medium concurrency (10-50 requests).
+#
+# Configuration:
+# - 2 workers: Utilizes multiple CPU cores (each worker loads the model separately)
+# - 8 threads per worker: Handles concurrent requests (total 16 concurrent capacity)
+# - Timeout 0: Allows Cloud Run to handle instance scaling
+#
+# NOTE: Each worker loads the model independently (~200MB RAM per worker).
+# For Cloud Run, ensure you allocate at least 1GB RAM and 2 vCPUs.
+# Adjust workers based on your CPU allocation: workers = (2 x $num_cores) + 1
+CMD exec gunicorn --bind :$PORT --workers 2 --threads 8 --timeout 0 --worker-class gthread main:app
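Reviewer note: if you later want the worker count to follow the CPUs Cloud Run actually allocates instead of the hard-coded `--workers 2`, gunicorn can read its settings from a config file. A minimal sketch (hypothetical `gunicorn.conf.py`, not part of this diff):

```python
# gunicorn.conf.py -- illustrative only; the Dockerfile above passes flags directly.
import multiprocessing
import os

# Bind to the port Cloud Run injects via $PORT.
bind = f":{os.getenv('PORT', '5001')}"

# gunicorn's rule of thumb is (2 x cores) + 1, capped here so the per-worker
# model copies (~200MB each) stay inside the container's memory limit.
workers = min(multiprocessing.cpu_count() * 2 + 1, 4)
threads = 8
worker_class = "gthread"
timeout = 0  # let Cloud Run manage instance lifecycle
```

With a file like this copied into the image, the CMD could shrink to `exec gunicorn main:app`, since gunicorn picks up `gunicorn.conf.py` from the working directory automatically.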
diff --git a/README.md b/README.md
index 93fdb69..873fc05 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,330 @@
-# To Build and Publish
+# Vector Embedder Microservice
 
+A Flask-based microservice for generating text embeddings using SentenceTransformers models. Optimized for offline operation with models bundled in the Docker image.
+
+## Features
+
+- **Offline Operation**: Models are pre-downloaded during build, no internet required at runtime
+- **Fast Cold Starts**: Models bundled in the image eliminate download time
+- **Configurable Models**: Use any SentenceTransformers model via build arguments
+- **Optimized Concurrency**: Configured for 10-50 concurrent requests
+- **Google Cloud Run Ready**: Optimized for serverless deployment
+
+## Build Arguments
+
+The service supports customizing the embedding model at build time:
+
+### Available Build Arguments
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `EMBEDDING_MODEL` | `multi-qa-MiniLM-L6-cos-v1` | SentenceTransformers model for generating embeddings |
+| `TOKENIZER_MODEL` | `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` | HuggingFace tokenizer model (should match the embedding model) |
+
+### Popular Model Options
+
+| Model | Size | Use Case |
+|-------|------|----------|
+| `multi-qa-MiniLM-L6-cos-v1` (default) | ~90MB | Question answering, semantic search |
+| `all-MiniLM-L6-v2` | ~80MB | General purpose, fast inference |
+| `all-mpnet-base-v2` | ~420MB | High quality, slower inference |
+| `paraphrase-multilingual-MiniLM-L12-v2` | ~470MB | Multilingual support (50+ languages) |
+
+See the [SentenceTransformers documentation](https://www.sbert.net/docs/pretrained_models.html) for more models.
+
+## Building the Image
+
+### Default Build (multi-qa-MiniLM-L6-cos-v1)
+
+```bash
+docker build -t vector-embedder-microservice .
+```
+
+### Custom Model Build
+
+```bash
+# Using all-MiniLM-L6-v2 (general purpose)
+docker build \
+  --build-arg EMBEDDING_MODEL=all-MiniLM-L6-v2 \
+  --build-arg TOKENIZER_MODEL=sentence-transformers/all-MiniLM-L6-v2 \
+  -t vector-embedder-microservice .
+
+# Using all-mpnet-base-v2 (higher quality)
+docker build \
+  --build-arg EMBEDDING_MODEL=all-mpnet-base-v2 \
+  --build-arg TOKENIZER_MODEL=sentence-transformers/all-mpnet-base-v2 \
+  -t vector-embedder-microservice .
+
+# Using multilingual model
+docker build \
+  --build-arg EMBEDDING_MODEL=paraphrase-multilingual-MiniLM-L12-v2 \
+  --build-arg TOKENIZER_MODEL=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 \
+  -t vector-embedder-microservice .
+```
+
+## Using Pre-built Public Images
+
+Public Docker images are automatically built and published to GitHub Container Registry (ghcr.io) whenever changes are pushed to the main branch.
+
+### Pull and Run from ghcr.io
+
+```bash
+# Pull the latest image (no authentication needed for public images)
+docker pull ghcr.io/OWNER/REPO:latest
+
+# Run the container
+docker run -d \
+  -e PORT=5001 \
+  -e VECTOR_EMBEDDER_API_KEY=your-api-key \
+  -p 5001:5001 \
+  ghcr.io/OWNER/REPO:latest
+```
+
+Replace `OWNER/REPO` with your GitHub username and repository name in lowercase (e.g., `jman/vectorembeddermicroservice`).
+
+### Available Tags
+
+- `latest` - Latest build from main branch
+- `main` - Same as latest
+- `v1.0.0` - Specific version tags
+- `v1.0` - Minor version tags
+- `v1` - Major version tags
+- `sha-abc1234` - Specific commit SHA
+
+### Making the Image Public
+
+After the first build, you need to make the package public:
+
+1. Go to your GitHub repository
+2. Click **Packages** in the right sidebar
+3. Click on your package name
+4. Click **Package settings** (bottom of right sidebar)
+5. Scroll to **Danger Zone**
+6. Click **Change visibility** → **Public**
+7. Type the package name to confirm
+
+Once public, anyone can pull the image without authentication.
+
+### Automated Builds
+
+The Docker image is automatically built and published by GitHub Actions:
+
+**Automatic triggers:**
+- **Push to main** → Builds `latest` and `main` tags
+- **Git tags** (e.g., `v1.0.0`) → Builds versioned tags
+- **Pull requests** → Builds the image but doesn't push
+
+**Manual builds:**
+You can trigger a manual build with custom model selection:
+1. Go to the **Actions** tab in GitHub
+2. Click the **Build and Publish Docker Image** workflow
+3. Click **Run workflow**
+4. Optionally specify custom embedding and tokenizer models
+5. Click **Run workflow**
+
+**Creating versioned releases:**
+```bash
+# Tag a version
+git tag v1.0.0
+git push origin v1.0.0
+
+# This automatically builds and publishes:
+# - ghcr.io/OWNER/REPO:v1.0.0
+# - ghcr.io/OWNER/REPO:v1.0
+# - ghcr.io/OWNER/REPO:v1
+# - ghcr.io/OWNER/REPO:latest (if on main branch)
+```
+
+### Building Custom Model Variants
+
+To build images with different embedding models:
+
+**Via GitHub Actions (recommended):**
+1. Go to **Actions** → **Build and Publish Docker Image**
+2. Click **Run workflow**
+3. Set custom model parameters:
+   - Embedding model: `all-mpnet-base-v2`
+   - Tokenizer model: `sentence-transformers/all-mpnet-base-v2`
+4. Run workflow
+
+This creates a tagged image with your custom model that you can reference by commit SHA.
+
+## Deploying to Google Cloud Run
+
+You have two options for deploying to Google Cloud Run:
+
+### Option 1: Deploy from GitHub Container Registry (Easiest)
+
+Deploy directly from the public ghcr.io image:
+
+```bash
+gcloud run deploy vector-embedder-microservice \
+  --image ghcr.io/OWNER/REPO:latest \
+  --region us-central1 \
+  --memory 1Gi \
+  --cpu 2 \
+  --allow-unauthenticated \
+  --set-env-vars VECTOR_EMBEDDER_API_KEY=your-api-key
+```
+
+This pulls the pre-built image from GitHub Container Registry, no build required!
+
+### Option 2: Build and Push to Google Artifact Registry
+
+If you prefer to use Google's registry:
+
+```bash
+# Build locally
 docker build -t vector-embedder-microservice .
-docker tag vector-embedder-microservice us-central1-docker.pkg.dev/lasso-409319/models/vector-embedder-microservice
-docker push us-central1-docker.pkg.dev/lasso-409319/models/vector-embedder-microservice
\ No newline at end of file
+
+# Tag for Google Artifact Registry
+docker tag vector-embedder-microservice \
+  us-central1-docker.pkg.dev/YOUR-PROJECT-ID/models/vector-embedder-microservice
+
+# Push to registry
+docker push us-central1-docker.pkg.dev/YOUR-PROJECT-ID/models/vector-embedder-microservice
+
+# Deploy to Cloud Run
+gcloud run deploy vector-embedder-microservice \
+  --image us-central1-docker.pkg.dev/YOUR-PROJECT-ID/models/vector-embedder-microservice \
+  --region us-central1 \
+  --memory 1Gi \
+  --cpu 2 \
+  --set-env-vars VECTOR_EMBEDDER_API_KEY=your-api-key
+```
+
+Replace `YOUR-PROJECT-ID` with your Google Cloud project ID.
+
+## Resource Requirements
+
+### Minimum (default model)
+- **Memory**: 512MB
+- **CPU**: 1 vCPU
+- **Disk**: 600MB
+- **Concurrency**: 8 requests
+
+### Recommended (production)
+- **Memory**: 1GB
+- **CPU**: 2 vCPU
+- **Disk**: 600MB-1GB (depending on model)
+- **Concurrency**: 16 requests (2 workers × 8 threads)
+
+### High Performance (larger models)
+- **Memory**: 2GB+
+- **CPU**: 4 vCPU
+- **Disk**: 1GB+
+- **Concurrency**: 32 requests (4 workers × 8 threads)
+
+## Testing Offline Capability
+
+```bash
+# Run container without network access
+docker run --network none \
+  -e PORT=5001 \
+  -e VECTOR_EMBEDDER_API_KEY=test123 \
+  -p 5001:5001 \
+  vector-embedder-microservice
+
+# Test the endpoint (from a separate terminal)
+curl -X POST http://localhost:5001/embeddings \
+  -H "Content-Type: application/json" \
+  -H "X-API-Key: test123" \
+  -d '{"text": "This is a test sentence"}'
+```
+
+## API Usage
+
+### Generate Embeddings
+
+**Endpoint**: `POST /embeddings`
+
+**Headers**:
+- `Content-Type: application/json`
+- `X-API-Key: <your-api-key>`
+
+**Request Body**:
+```json
+{
+  "text": "Your text to embed"
+}
+```
+
+**Response**:
+```json
+{
+  "embeddings": [[0.123, -0.456, 0.789, ...]]
+}
+```
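Reviewer-side illustration (not text being added to the README in this diff): the same request issued from Python with `requests`, assuming the default port 5001 and the default `abc123` key described below.

```python
import requests

API_URL = "http://localhost:5001/embeddings"  # adjust host/port for your deployment
API_KEY = "abc123"  # default key; override via VECTOR_EMBEDDER_API_KEY

def get_embedding(text: str) -> list[list[float]]:
    """Call the /embeddings endpoint and return the embedding matrix."""
    response = requests.post(
        API_URL,
        headers={"X-API-Key": API_KEY},
        json={"text": text},
        timeout=30,
    )
    response.raise_for_status()  # surfaces the service's 400/401 errors
    return response.json()["embeddings"]

if __name__ == "__main__":
    vector = get_embedding("This is a test sentence")
    print(f"Got {len(vector[0])} dimensions")
```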
+
+## Configuration
+
+### Environment Variables
+
+| Variable | Required | Default | Description |
+|----------|----------|---------|-------------|
+| `VECTOR_EMBEDDER_API_KEY` | No | `abc123` | API key for authentication |
+| `PORT` | No | `5001` | Port to run the service on |
+| `EMBEDDING_MODEL` | No | From build arg | Override model at runtime (not recommended) |
+| `TOKENIZER_MODEL` | No | From build arg | Override tokenizer at runtime (not recommended) |
+
+**Note**: `EMBEDDING_MODEL` and `TOKENIZER_MODEL` are set during build. Only override them at runtime if the desired models are already cached in the image.
+
+## Development
+
+### Running Tests Locally
+
+The project includes comprehensive unit tests for both the embedding logic and the API endpoints.
+
+**Install development dependencies:**
+```bash
+pip install -r requirements-dev.txt
+```
+
+**Run all tests:**
+```bash
+pytest
+```
+
+**Run tests with coverage report:**
+```bash
+pytest --cov=. --cov-report=html
+```
+
+**Run specific test file:**
+```bash
+pytest test_embeddings.py
+pytest test_main.py
+```
+
+**View coverage report:**
+After running tests with coverage, open `htmlcov/index.html` in your browser to see a detailed coverage report.
+
+### Test Structure
+
+- **test_embeddings.py**: Tests for embedding generation and text chunking logic
+  - Model loading configuration
+  - Text chunking with various lengths
+  - Embedding generation and averaging
+  - Edge cases and error handling
+
+- **test_main.py**: Tests for Flask API endpoints
+  - Authentication and API key validation
+  - Request/response format validation
+  - Error handling (missing fields, invalid data)
+  - HTTP method validation
+
+### Continuous Integration
+
+Tests run automatically on:
+- **Push to main/develop branches**
+- **Pull requests to main/develop**
+- **Manual workflow dispatch**
+
+The CI pipeline:
+1. Tests against Python 3.9, 3.10, and 3.11
+2. Runs linting checks with flake8
+3. Executes the full test suite with pytest
+4. Generates and uploads coverage reports
+5. Archives the coverage HTML report as an artifact
+
+View test results in the **Actions** tab of the GitHub repository.
\ No newline at end of file
diff --git a/embeddings.py b/embeddings.py
index b362598..c0386f0 100644
--- a/embeddings.py
+++ b/embeddings.py
@@ -1,24 +1,33 @@
+import os
 import numpy as np
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer
 
+# Load model from environment variable with fallback to default
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "multi-qa-MiniLM-L6-cos-v1")
+TOKENIZER_MODEL = os.getenv("TOKENIZER_MODEL", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+
+print(f"Loading embedding model: {EMBEDDING_MODEL}")
 # Need for speed...
-model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
+model = SentenceTransformer(EMBEDDING_MODEL)
 
 
 def chunk_by_transformers_tokens(content: str,
                                  max_token_length: int = 512,
-                                 transformer_id: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1") -> list[str]:
+                                 transformer_id: str = None) -> list[str]:
     """
     Tokenizes a long string and splits it into substrings based on a specified maximum token length.
 
     Parameters:
     - content (str): The input text to be tokenized.
     - transformer_id (str): The model name for the SentenceTransformer to use (could expand to other HF models).
+        If None, uses the TOKENIZER_MODEL environment variable.
    - max_token_length (int): The maximum length for each chunk of tokens.
 
     Returns:
     - List[str]: List of substrings, each specified max tokens or fewer tokens.
     """
+    if transformer_id is None:
+        transformer_id = TOKENIZER_MODEL
     tokenizer = AutoTokenizer.from_pretrained(transformer_id)
 
     tokens = tokenizer.tokenize(content)
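The hunk above stops before `embed_text`, but the tests added later in this diff pin down its contract: chunk with a window of `int(512 * 0.8)` tokens, encode at most the first three chunks, and mean-pool the result. A sketch consistent with those assertions (an assumption about the unchanged code, not the file's actual body):

```python
import numpy as np

from embeddings import chunk_by_transformers_tokens, model

def embed_text_sketch(text: str) -> np.ndarray:
    """Illustrative only: mirrors the behaviour asserted in test_embeddings.py."""
    window = int(512 * 0.8)  # 409-token chunks, leaving headroom under the model's 512-token limit
    chunks = chunk_by_transformers_tokens(text, max_token_length=window)
    embeddings = model.encode(chunks[:3])               # cap work at the first three chunks
    return np.mean(embeddings, axis=0, keepdims=True)   # average into a single (1, dim) vector
```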
diff --git a/preload_models.py b/preload_models.py
new file mode 100644
index 0000000..673099b
--- /dev/null
+++ b/preload_models.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+"""
+Pre-download all required models for offline use.
+This script runs during Docker image build to bundle models into the image.
+
+Usage:
+    python preload_models.py [embedding_model] [tokenizer_model]
+
+Examples:
+    python preload_models.py multi-qa-MiniLM-L6-cos-v1 sentence-transformers/multi-qa-MiniLM-L6-cos-v1
+    python preload_models.py all-MiniLM-L6-v2 sentence-transformers/all-MiniLM-L6-v2
+"""
+import sys
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer
+
+# Default model IDs
+DEFAULT_EMBEDDING_MODEL = "multi-qa-MiniLM-L6-cos-v1"
+DEFAULT_TOKENIZER_MODEL = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+
+# Get model names from command line arguments or use defaults
+EMBEDDING_MODEL = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_EMBEDDING_MODEL
+TOKENIZER_MODEL = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_TOKENIZER_MODEL
+
+print("=" * 60)
+print("Pre-downloading models for offline availability...")
+print("=" * 60)
+
+# Download SentenceTransformer model
+print(f"\n1. Downloading embedding model: {EMBEDDING_MODEL}")
+model = SentenceTransformer(EMBEDDING_MODEL)
+print(f"   ✓ Model downloaded and cached successfully")
+
+# Download tokenizer
+print(f"\n2. Downloading tokenizer: {TOKENIZER_MODEL}")
+tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
+print(f"   ✓ Tokenizer downloaded and cached successfully")
+
+# Test that models work
+print("\n3. Testing model functionality...")
+test_embedding = model.encode(["Test sentence"])
+print(f"   ✓ Model test successful (embedding shape: {test_embedding.shape})")
+
+print("\n" + "=" * 60)
+print("All models pre-downloaded successfully!")
+print("Service is now ready for offline operation.")
+print("=" * 60)
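A possible follow-up check at build time is to re-load the same models with the Hugging Face Hub forced offline, so the build fails fast if anything would still hit the network at runtime. Sketch of a hypothetical `verify_offline.py` (not included in this diff; `HF_HUB_OFFLINE` and `TRANSFORMERS_OFFLINE` are standard Hugging Face environment variables):

```python
# verify_offline.py -- hypothetical sanity check, not part of this diff.
import os

# Force Hugging Face libraries to fail fast instead of downloading.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

# These raise if the files are not already cached under HF_HOME (/app/.cache/huggingface).
model = SentenceTransformer(os.getenv("EMBEDDING_MODEL", "multi-qa-MiniLM-L6-cos-v1"))
tokenizer = AutoTokenizer.from_pretrained(
    os.getenv("TOKENIZER_MODEL", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
)
print("offline load OK:", model.encode(["ping"]).shape)
```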
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..2c22dce
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,29 @@
+[pytest]
+# Pytest configuration file
+
+# Test discovery patterns
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+
+# Output options
+addopts =
+    # Verbose output
+    -v
+    # Show summary of all test outcomes
+    -ra
+    # Show local variables in tracebacks
+    --showlocals
+    # Coverage options
+    --cov=.
+    --cov-report=term-missing
+    --cov-report=html
+    --cov-report=xml
+    # Use the coverage settings (including omit patterns) from .coveragerc
+    --cov-config=.coveragerc
+
+# Test paths
+testpaths = .
+
+# Ignore patterns
+norecursedirs = .git .github __pycache__ *.egg-info .cache .coverage htmlcov
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..f6b5889
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,12 @@
+# Development and testing dependencies
+# Install with: pip install -r requirements-dev.txt
+
+# Include production dependencies
+-r requirements.txt
+
+# Testing framework
+pytest>=7.4.0
+pytest-cov>=4.1.0
+
+# Test mocking
+pytest-mock>=3.11.0
diff --git a/test_embeddings.py b/test_embeddings.py
new file mode 100644
index 0000000..b83006f
--- /dev/null
+++ b/test_embeddings.py
@@ -0,0 +1,193 @@
+"""
+Unit tests for embeddings.py module
+"""
+import pytest
+import numpy as np
+from unittest.mock import Mock, patch, MagicMock
+
+
+class TestChunkByTransformersTokens:
+    """Tests for chunk_by_transformers_tokens function"""
+
+    @patch('embeddings.AutoTokenizer')
+    def test_chunk_short_text(self, mock_tokenizer_class):
+        """Test chunking with text shorter than max length"""
+        from embeddings import chunk_by_transformers_tokens
+
+        # Mock tokenizer
+        mock_tokenizer = Mock()
+        mock_tokenizer.tokenize.return_value = ['test'] * 100  # 100 tokens
+        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
+
+        result = chunk_by_transformers_tokens("Short text", max_token_length=512)
+
+        assert len(result) == 1
+        assert result[0] == "Short text"
+        mock_tokenizer_class.from_pretrained.assert_called_once()
+
+    @patch('embeddings.AutoTokenizer')
+    def test_chunk_long_text(self, mock_tokenizer_class):
+        """Test chunking with text longer than max length"""
+        from embeddings import chunk_by_transformers_tokens
+
+        # Mock tokenizer with many tokens
+        mock_tokenizer = Mock()
+        mock_tokenizer.tokenize.return_value = ['token'] * 1000  # 1000 tokens
+        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
+
+        long_text = "word " * 1000  # Long text
+        result = chunk_by_transformers_tokens(long_text, max_token_length=512)
+
+        # Should split into multiple chunks
+        assert len(result) >= 2
+        # All chunks should be non-empty
+        assert all(len(chunk) > 0 for chunk in result)
+
+    @patch('embeddings.AutoTokenizer')
+    def test_chunk_uses_custom_transformer_id(self, mock_tokenizer_class):
+        """Test that custom transformer_id is used"""
+        from embeddings import chunk_by_transformers_tokens
+
+        mock_tokenizer = Mock()
+        mock_tokenizer.tokenize.return_value = ['test'] * 50
+        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
+
+        custom_model = "custom-model-id"
+        chunk_by_transformers_tokens("Text", transformer_id=custom_model)
+
+        mock_tokenizer_class.from_pretrained.assert_called_once_with(custom_model)
+
+    @patch('embeddings.AutoTokenizer')
+    def test_chunk_uses_default_tokenizer_model(self, mock_tokenizer_class):
+        """Test that TOKENIZER_MODEL env var is used when transformer_id is None"""
+        from embeddings import chunk_by_transformers_tokens, TOKENIZER_MODEL
+
+        mock_tokenizer = Mock()
+        mock_tokenizer.tokenize.return_value = ['test'] * 50
+        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
+
+        chunk_by_transformers_tokens("Text", transformer_id=None)
+
+        mock_tokenizer_class.from_pretrained.assert_called_once_with(TOKENIZER_MODEL)
+
+    @patch('embeddings.AutoTokenizer')
+    def test_chunk_handles_empty_chunks(self, mock_tokenizer_class):
+        """Test that empty chunks are skipped"""
+        from embeddings import chunk_by_transformers_tokens
+
+        mock_tokenizer = Mock()
+        mock_tokenizer.tokenize.return_value = ['test'] * 10
+        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
+
+        result = chunk_by_transformers_tokens("Test", max_token_length=512)
+
+        # Should not contain any empty strings
+        assert all(chunk for chunk in result)
+
+
+class TestEmbedText:
+    """Tests for embed_text function"""
+
+    @patch('embeddings.model')
+    @patch('embeddings.chunk_by_transformers_tokens')
+    def test_embed_single_chunk(self, mock_chunk, mock_model):
+        """Test embedding generation for single chunk"""
+        from embeddings import embed_text
+
+        # Mock chunking to return single chunk
+        mock_chunk.return_value = ["Single chunk"]
+
+        # Mock model to return embeddings
+        mock_embedding = np.array([[0.1, 0.2, 0.3]])
+        mock_model.encode.return_value = mock_embedding
+
+        result = embed_text("Test text")
+
+        assert isinstance(result, np.ndarray)
+        assert result.shape == (1, 3)
+        mock_model.encode.assert_called_once_with(["Single chunk"])
+
+    @patch('embeddings.model')
+    @patch('embeddings.chunk_by_transformers_tokens')
+    def test_embed_multiple_chunks(self, mock_chunk, mock_model):
+        """Test embedding generation for multiple chunks"""
+        from embeddings import embed_text
+
+        # Mock chunking to return multiple chunks
+        mock_chunk.return_value = ["Chunk 1", "Chunk 2", "Chunk 3"]
+
+        # Mock model to return embeddings for each chunk
+        mock_embeddings = np.array([
+            [0.1, 0.2, 0.3],
+            [0.4, 0.5, 0.6],
+            [0.7, 0.8, 0.9]
+        ])
+        mock_model.encode.return_value = mock_embeddings
+
+        result = embed_text("Long test text")
+
+        # Should return averaged embedding
+        expected_avg = np.mean(mock_embeddings, axis=0, keepdims=True)
+        np.testing.assert_array_almost_equal(result, expected_avg)
+
+    @patch('embeddings.model')
+    @patch('embeddings.chunk_by_transformers_tokens')
+    def test_embed_limits_to_three_chunks(self, mock_chunk, mock_model):
+        """Test that only first 3 chunks are used"""
+        from embeddings import embed_text
+
+        # Mock chunking to return more than 3 chunks
+        mock_chunk.return_value = ["Chunk 1", "Chunk 2", "Chunk 3", "Chunk 4", "Chunk 5"]
+
+        # Mock model
+        mock_embeddings = np.array([
+            [0.1, 0.2, 0.3],
+            [0.4, 0.5, 0.6],
+            [0.7, 0.8, 0.9]
+        ])
+        mock_model.encode.return_value = mock_embeddings
+
+        result = embed_text("Very long text")
+
+        # Should only encode first 3 chunks
+        mock_model.encode.assert_called_once()
+        chunks_encoded = mock_model.encode.call_args[0][0]
+        assert len(chunks_encoded) == 3
+
+    @patch('embeddings.model')
+    @patch('embeddings.chunk_by_transformers_tokens')
+    def test_embed_uses_correct_window_size(self, mock_chunk, mock_model):
+        """Test that chunking uses correct window size (80% of 512)"""
+        from embeddings import embed_text
+
+        mock_chunk.return_value = ["Chunk"]
+        mock_model.encode.return_value = np.array([[0.1, 0.2, 0.3]])
+
+        embed_text("Test")
+
+        # Should use window size of 80% of 512 = 409.6 -> 409
+        mock_chunk.assert_called_once()
+        call_kwargs = mock_chunk.call_args[1]
+        assert call_kwargs['max_token_length'] == int(512 * 0.8)
+
+
+class TestModelLoading:
+    """Tests for model loading configuration"""
+
+    def test_embedding_model_env_var(self):
+        """Test that EMBEDDING_MODEL is loaded from environment"""
+        import embeddings
+        assert hasattr(embeddings, 'EMBEDDING_MODEL')
+        assert isinstance(embeddings.EMBEDDING_MODEL, str)
+
+    def test_tokenizer_model_env_var(self):
+        """Test that TOKENIZER_MODEL is loaded from environment"""
+        import embeddings
+        assert hasattr(embeddings, 'TOKENIZER_MODEL')
+        assert isinstance(embeddings.TOKENIZER_MODEL, str)
+
+    def test_model_exists(self):
+        """Test that model object exists"""
+        import embeddings
+        assert hasattr(embeddings, 'model')
+        assert embeddings.model is not None
diff --git a/test_main.py b/test_main.py
new file mode 100644
index 0000000..6628eb9
--- /dev/null
+++ b/test_main.py
@@ -0,0 +1,222 @@
+"""
+Unit tests for main.py Flask application
+"""
+import pytest
+import json
+import numpy as np
+from unittest.mock import patch, Mock
+
+
+@pytest.fixture
+def client():
+    """Create a test client for the Flask app"""
+    from main import app
+    app.config['TESTING'] = True
+    with app.test_client() as client:
+        yield client
+
+
+@pytest.fixture
+def mock_embed_text():
+    """Mock the embed_text function to avoid loading models in tests"""
+    with patch('main.embed_text') as mock:
+        # Return a mock embedding
+        mock.return_value = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]])
+        yield mock
+
+
+class TestEmbeddingsEndpoint:
+    """Tests for /embeddings endpoint"""
+
+    def test_embeddings_success(self, client, mock_embed_text):
+        """Test successful embedding generation"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': 'Test text'}
+        )
+
+        assert response.status_code == 200
+        data = json.loads(response.data)
+        assert 'embeddings' in data
+        assert isinstance(data['embeddings'], list)
+        # The mock returns shape (1, 5) which becomes [[0.1, 0.2, 0.3, 0.4, 0.5]]
+        assert len(data['embeddings']) == 1
+        assert len(data['embeddings'][0]) == 5
+        mock_embed_text.assert_called_once_with('Test text')
+
+    def test_embeddings_missing_api_key(self, client):
+        """Test request without API key"""
+        response = client.post(
+            '/embeddings',
+            json={'text': 'Test text'}
+        )
+
+        assert response.status_code == 401
+        data = json.loads(response.data)
+        assert 'error' in data
+        assert data['error'] == 'Invalid API key'
+
+    def test_embeddings_invalid_api_key(self, client):
+        """Test request with invalid API key"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'wrong-key'},
+            json={'text': 'Test text'}
+        )
+
+        assert response.status_code == 401
+        data = json.loads(response.data)
+        assert 'error' in data
+        assert data['error'] == 'Invalid API key'
+
+    def test_embeddings_missing_text(self, client):
+        """Test request without text field"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={}
+        )
+
+        assert response.status_code == 400
+        data = json.loads(response.data)
+        assert 'error' in data
+        assert data['error'] == 'Text is required'
+
+    def test_embeddings_empty_text(self, client):
+        """Test request with empty text"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': ''}
+        )
+
+        assert response.status_code == 400
+        data = json.loads(response.data)
+        assert 'error' in data
+        assert data['error'] == 'Text is required'
+
+    def test_embeddings_none_text(self, client):
+        """Test request with null text"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': None}
+        )
+
+        assert response.status_code == 400
+        data = json.loads(response.data)
+        assert 'error' in data
+
+    def test_embeddings_invalid_json(self, client):
+        """Test request with invalid JSON"""
+        response = client.post(
+            '/embeddings',
+            headers={
+                'X-API-Key': 'abc123',
+                'Content-Type': 'application/json'
+            },
+            data='invalid json'
+        )
+
+        # Flask returns 400 for invalid JSON
+        assert response.status_code in [400, 415]
+
+    def test_embeddings_long_text(self, client, mock_embed_text):
+        """Test embedding generation for long text"""
+        long_text = "word " * 1000  # Very long text
+
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': long_text}
+        )
+
+        assert response.status_code == 200
+        data = json.loads(response.data)
+        assert 'embeddings' in data
+        mock_embed_text.assert_called_once_with(long_text)
+
+    def test_embeddings_with_default_api_key(self, client, mock_embed_text):
+        """Test that default API key 'abc123' works"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': 'Test'}
+        )
+
+        assert response.status_code == 200
+        data = json.loads(response.data)
+        assert 'embeddings' in data
+
+    def test_embeddings_special_characters(self, client, mock_embed_text):
+        """Test embedding generation with special characters"""
+        special_text = "Hello! @#$%^&*() 你好 مرحبا"
+
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': special_text}
+        )
+
+        assert response.status_code == 200
+        data = json.loads(response.data)
+        assert 'embeddings' in data
+        mock_embed_text.assert_called_once_with(special_text)
+
+    def test_embeddings_response_format(self, client, mock_embed_text):
+        """Test that response format matches expected structure"""
+        response = client.post(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': 'Test'}
+        )
+
+        assert response.status_code == 200
+        assert response.content_type == 'application/json'
+
+        data = json.loads(response.data)
+        assert isinstance(data, dict)
+        assert 'embeddings' in data
+        assert isinstance(data['embeddings'], list)
+        # embeddings is a list of lists (2D array converted to JSON)
+        assert len(data['embeddings']) > 0
+        # Check that inner embeddings are numeric
+        assert all(isinstance(x, (int, float)) for x in data['embeddings'][0])
+
+
+class TestMethodNotAllowed:
+    """Tests for unsupported HTTP methods"""
+
+    def test_get_not_allowed(self, client):
+        """Test that GET is not allowed on /embeddings"""
+        response = client.get('/embeddings')
+        assert response.status_code == 405
+
+    def test_put_not_allowed(self, client):
+        """Test that PUT is not allowed on /embeddings"""
+        response = client.put(
+            '/embeddings',
+            headers={'X-API-Key': 'abc123'},
+            json={'text': 'Test'}
+        )
+        assert response.status_code == 405
+
+    def test_delete_not_allowed(self, client):
+        """Test that DELETE is not allowed on /embeddings"""
+        response = client.delete('/embeddings')
+        assert response.status_code == 405
+
+
+class TestHealthCheck:
+    """Tests for basic health/connectivity"""
+
+    def test_404_on_root(self, client):
+        """Test that root path returns 404"""
+        response = client.get('/')
+        assert response.status_code == 404
+
+    def test_404_on_unknown_path(self, client):
+        """Test that unknown paths return 404"""
+        response = client.get('/unknown')
+        assert response.status_code == 404
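Reviewer note: main.py itself is not touched by this diff, so for readers reviewing the tests in isolation, here is a hedged sketch of the endpoint shape the fixtures and assertions above imply (route, error messages, and defaults taken from the tests; the real file may differ):

```python
# Illustrative reconstruction only -- main.py is not part of this diff.
import os

from flask import Flask, jsonify, request

from embeddings import embed_text

app = Flask(__name__)
API_KEY = os.getenv("VECTOR_EMBEDDER_API_KEY", "abc123")


@app.route("/embeddings", methods=["POST"])
def embeddings():
    # 401 when the X-API-Key header is missing or wrong (test_embeddings_missing_api_key)
    if request.headers.get("X-API-Key") != API_KEY:
        return jsonify({"error": "Invalid API key"}), 401

    data = request.get_json()
    text = data.get("text") if data else None
    # 400 when text is absent, empty, or null (test_embeddings_missing_text and friends)
    if not text:
        return jsonify({"error": "Text is required"}), 400

    vector = embed_text(text)  # numpy array of shape (1, dim)
    return jsonify({"embeddings": vector.tolist()})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "5001")))
```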