"""Tests for scripts/bootstrap.py — focus on secret safety and idempotency.""" import sys from pathlib import Path import pytest # bootstrap.py lives in scripts/, not in a package — add it to path sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts")) from bootstrap import ( build_uris, copy_example_if_missing, extract_python_version, generate_secrets_file, load_env_file, ) # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- SECRETS_SPEC = { "JWT_SECRET": (32, "hex"), "MONGO_PASSWORD": (16, "hex"), "POSTGRES_PASSWORD": (16, "hex"), } ALFRED_ENV = """\ MONGO_HOST=mongodb MONGO_PORT=27017 MONGO_USER=alfred MONGO_DB_NAME=mydb POSTGRES_HOST=vectordb POSTGRES_PORT=5432 POSTGRES_USER=alfred POSTGRES_DB_NAME=alfred """ SECRETS_ENV = """\ # Auto-generated secrets — DO NOT COMMIT JWT_SECRET=deadbeef MONGO_PASSWORD=cafebabe POSTGRES_PASSWORD=f00dface """ @pytest.fixture def secrets_file(tmp_path): """An existing .env.secrets with pre-generated values.""" p = tmp_path / ".env.secrets" p.write_text(SECRETS_ENV) return p @pytest.fixture def alfred_file(tmp_path): p = tmp_path / ".env.alfred" p.write_text(ALFRED_ENV) return p # --------------------------------------------------------------------------- # load_env_file # --------------------------------------------------------------------------- class TestLoadEnvFile: def test_parses_key_value_pairs(self, tmp_path): f = tmp_path / ".env" f.write_text("FOO=bar\nBAZ=qux\n") assert load_env_file(f) == {"FOO": "bar", "BAZ": "qux"} def test_ignores_comments_and_blanks(self, tmp_path): f = tmp_path / ".env" f.write_text("# comment\n\nFOO=bar\n") assert load_env_file(f) == {"FOO": "bar"} def test_missing_file_returns_empty(self, tmp_path): assert load_env_file(tmp_path / "nonexistent") == {} def test_value_with_equals_sign(self, tmp_path): """Values containing '=' must be preserved intact (e.g. base64).""" f = tmp_path / ".env" f.write_text("KEY=abc=def==\n") assert load_env_file(f)["KEY"] == "abc=def==" # --------------------------------------------------------------------------- # generate_secrets_file — the critical ones # --------------------------------------------------------------------------- class TestGenerateSecretsFile: def test_generates_all_secrets_on_first_run(self, tmp_path): path = tmp_path / ".env.secrets" generate_secrets_file(path, SECRETS_SPEC) result = load_env_file(path) assert set(SECRETS_SPEC.keys()) <= result.keys() assert all(result[k] for k in SECRETS_SPEC) # non-empty def test_never_overwrites_existing_secrets(self, secrets_file): """Core safety property: running bootstrap again must not change existing values.""" before = load_env_file(secrets_file) generate_secrets_file(secrets_file, SECRETS_SPEC) after = load_env_file(secrets_file) for key in before: assert after[key] == before[key], f"{key} was overwritten!" def test_adds_missing_secrets_without_touching_existing(self, secrets_file): """Only keys absent from the file should be added.""" before = load_env_file(secrets_file) # POSTGRES_PASSWORD already exists; JWT_SECRET already exists # Add a new key to the spec that is not yet in the file spec = {**SECRETS_SPEC, "NEW_SECRET": (16, "hex")} generate_secrets_file(secrets_file, spec) after = load_env_file(secrets_file) # Existing values untouched for key in before: assert after[key] == before[key] # New key added assert "NEW_SECRET" in after assert len(after["NEW_SECRET"]) == 32 # 16 bytes → 32 hex chars def test_idempotent_across_multiple_runs(self, tmp_path): """Calling bootstrap N times must produce stable secrets.""" path = tmp_path / ".env.secrets" generate_secrets_file(path, SECRETS_SPEC) after_first = load_env_file(path) generate_secrets_file(path, SECRETS_SPEC) after_second = load_env_file(path) assert after_first == after_second def test_hex_secret_has_correct_length(self, tmp_path): path = tmp_path / ".env.secrets" generate_secrets_file(path, {"MY_KEY": (32, "hex")}) value = load_env_file(path)["MY_KEY"] assert len(value) == 64 # 32 bytes → 64 hex chars assert all(c in "0123456789abcdef" for c in value) def test_preserves_comments_in_existing_file(self, secrets_file): """Comments in .env.secrets must survive a bootstrap run.""" generate_secrets_file(secrets_file, SECRETS_SPEC) content = secrets_file.read_text() assert "# Auto-generated secrets" in content # --------------------------------------------------------------------------- # build_uris # --------------------------------------------------------------------------- class TestBuildUris: def test_writes_uris_to_secrets_file(self, alfred_file, secrets_file): build_uris(alfred_file, secrets_file) result = load_env_file(secrets_file) assert "MONGO_URI" in result assert "POSTGRES_URI" in result def test_mongo_uri_contains_all_components(self, alfred_file, secrets_file): build_uris(alfred_file, secrets_file) uri = load_env_file(secrets_file)["MONGO_URI"] assert "alfred" in uri # user assert "cafebabe" in uri # password from secrets assert "mongodb" in uri # host assert "27017" in uri # port assert "mydb" in uri # dbname assert "authSource=admin" in uri def test_postgres_uri_contains_all_components(self, alfred_file, secrets_file): build_uris(alfred_file, secrets_file) uri = load_env_file(secrets_file)["POSTGRES_URI"] assert "alfred" in uri assert "f00dface" in uri # password from secrets assert "vectordb" in uri assert "5432" in uri assert uri.startswith("postgresql://") def test_uri_is_updated_when_host_changes(self, tmp_path, secrets_file): """If MONGO_HOST changes in .env.alfred, the URI must reflect it.""" alfred = tmp_path / ".env.alfred" alfred.write_text(ALFRED_ENV.replace("MONGO_HOST=mongodb", "MONGO_HOST=newhost")) build_uris(alfred, secrets_file) uri = load_env_file(secrets_file)["MONGO_URI"] assert "@newhost:" in uri assert "@mongodb:" not in uri def test_uri_update_does_not_alter_other_secrets(self, alfred_file, secrets_file): """Recomputing URIs must not touch JWT_SECRET or passwords.""" before = load_env_file(secrets_file) build_uris(alfred_file, secrets_file) after = load_env_file(secrets_file) for key in before: assert after[key] == before[key], f"{key} was altered by build_uris!" def test_uri_recomputed_on_repeated_calls(self, tmp_path, secrets_file): """Calling build_uris twice with different configs produces the latest URI.""" alfred_v1 = tmp_path / "alfred_v1" alfred_v1.write_text(ALFRED_ENV) build_uris(alfred_v1, secrets_file) uri_v1 = load_env_file(secrets_file)["MONGO_URI"] alfred_v2 = tmp_path / "alfred_v2" alfred_v2.write_text(ALFRED_ENV.replace("MONGO_DB_NAME=mydb", "MONGO_DB_NAME=otherdb")) build_uris(alfred_v2, secrets_file) uri_v2 = load_env_file(secrets_file)["MONGO_URI"] assert "mydb" not in uri_v2 assert "otherdb" in uri_v2 # Password unchanged across both calls assert load_env_file(secrets_file)["MONGO_PASSWORD"] == "cafebabe" # --------------------------------------------------------------------------- # copy_example_if_missing # --------------------------------------------------------------------------- class TestCopyExampleIfMissing: def test_copies_when_dst_missing(self, tmp_path): src = tmp_path / "src.env" src.write_text("FOO=bar\n") dst = tmp_path / "dst.env" copy_example_if_missing(src, dst, "test") assert dst.read_text() == "FOO=bar\n" def test_never_overwrites_existing_dst(self, tmp_path): src = tmp_path / "src.env" src.write_text("FOO=new\n") dst = tmp_path / "dst.env" dst.write_text("FOO=original\n") copy_example_if_missing(src, dst, "test") assert dst.read_text() == "FOO=original\n" def test_silent_skip_when_src_missing(self, tmp_path): """Should not raise if the example file doesn't exist yet.""" dst = tmp_path / "dst.env" copy_example_if_missing(tmp_path / "nonexistent.env", dst, "test") assert not dst.exists() # --------------------------------------------------------------------------- # extract_python_version # --------------------------------------------------------------------------- class TestExtractPythonVersion: @pytest.mark.parametrize("spec,expected_full,expected_short", [ ("==3.14.3", "3.14.3", "3.14"), ("^3.12.0", "3.12.0", "3.12"), ("~3.11.1", "3.11.1", "3.11"), ("3.10.5", "3.10.5", "3.10"), ]) def test_parses_version_specifiers(self, spec, expected_full, expected_short): full, short = extract_python_version(spec) assert full == expected_full assert short == expected_short def test_raises_on_invalid_version(self): with pytest.raises(ValueError): extract_python_version("3")