Skip to content

Commit

Permalink
read write env variables without extra package, tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
zain-sohail committed Jan 12, 2025
1 parent 9d22be7 commit 80ef9c1
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 42 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ dependencies = [
"joblib>=1.2.0",
"pyarrow>=14.0.1,<17.0",
"pydantic>=2.8.2",
"python-dotenv>=1.0.1",
]

[project.urls]
Expand Down
51 changes: 51 additions & 0 deletions src/sed/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,54 @@ def complete_dictionary(dictionary: dict, base_dictionary: dict) -> dict:
dictionary[k] = v

return dictionary


def read_env_var(var_name: str) -> str | None:
"""Read an environment variable from the .env file in the user config directory.
Args:
var_name (str): Name of the environment variable to read
Returns:
str | None: Value of the environment variable or None if not found
"""
env_path = USER_CONFIG_PATH / ".env"
if not env_path.exists():
logger.debug(f"Environment variable {var_name} not found in .env file")
return None

with open(env_path) as f:
for line in f:
if line.startswith(f"{var_name}="):
return line.strip().split("=", 1)[1]
logger.debug(f"Environment variable {var_name} not found in .env file")
return None


def save_env_var(var_name: str, value: str) -> None:
"""Save an environment variable to the .env file in the user config directory.
If the file exists, preserves other variables. If not, creates a new file.
Args:
var_name (str): Name of the environment variable to save
value (str): Value to save for the environment variable
"""
env_path = USER_CONFIG_PATH / ".env"
env_content = {}

# Read existing variables if file exists
if env_path.exists():
with open(env_path) as f:
for line in f:
if "=" in line:
key, val = line.strip().split("=", 1)
env_content[key] = val

# Update or add new variable
env_content[var_name] = value

# Write all variables back to file
with open(env_path, "w") as f:
for key, val in env_content.items():
f.write(f"{key}={val}\n")
logger.debug(f"Environment variable {var_name} saved to .env file")
20 changes: 6 additions & 14 deletions src/sed/loader/flash/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
"""
from __future__ import annotations

import os
from pathlib import Path

import requests
from dotenv import load_dotenv
from dotenv import set_key

from sed.core.config import read_env_var
from sed.core.config import save_env_var
from sed.core.logging import setup_logging

logger = setup_logging("flash_metadata_retriever")
Expand All @@ -33,16 +30,11 @@ def __init__(self, metadata_config: dict, token: str = None) -> None:
"""
# Token handling
if token:
# Save token to .env file in user's home directory
env_path = Path.home() / ".sed" / ".env"
env_path.parent.mkdir(parents=True, exist_ok=True)
set_key(str(env_path), "SCICAT_TOKEN", token)
self.token = token
save_env_var("SCICAT_TOKEN", self.token)
else:
# Try to load token from config or environment
self.token = metadata_config.get("token")
if not self.token:
load_dotenv(Path.home() / ".sed" / ".env")
self.token = os.getenv("SCICAT_TOKEN")
# Try to load token from config or .env file
self.token = read_env_var("SCICAT_TOKEN")

if not self.token:
raise ValueError(
Expand Down
46 changes: 19 additions & 27 deletions tests/loader/flash/test_flash_metadata.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Tests for FlashLoader metadata functionality"""
from __future__ import annotations

from pathlib import Path
import os

import pytest

from sed.core.config import read_env_var
from sed.core.config import save_env_var
from sed.core.config import USER_CONFIG_PATH
from sed.loader.flash.metadata import MetadataRetriever

ENV_PATH = USER_CONFIG_PATH / ".env"


@pytest.fixture
def mock_requests(requests_mock) -> None:
Expand All @@ -15,14 +20,6 @@ def mock_requests(requests_mock) -> None:
requests_mock.get(dataset_url, json={"fake": "data"}, status_code=200)


@pytest.fixture
def mock_env_token(monkeypatch, tmp_path) -> None:
# Create a temporary .env file
env_path = tmp_path / ".env"
env_path.write_text("SCICAT_TOKEN=env_test_token")
monkeypatch.setattr(Path, "home", lambda: tmp_path)


def test_get_metadata_with_explicit_token(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"archiver_url": "https://example.com",
Expand All @@ -31,27 +28,21 @@ def test_get_metadata_with_explicit_token(mock_requests: None) -> None: # noqa:
metadata = retriever.get_metadata("11013410", ["43878"])
assert isinstance(metadata, dict)
assert metadata == {"fake": "data"}
assert ENV_PATH.exists()
assert read_env_var("SCICAT_TOKEN") == "explicit_test_token"
os.remove(ENV_PATH)


def test_get_metadata_with_config_token(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"archiver_url": "https://example.com",
"token": "config_test_token",
}
retriever = MetadataRetriever(metadata_config)
metadata = retriever.get_metadata("11013410", ["43878"])
assert isinstance(metadata, dict)
assert metadata == {"fake": "data"}


def test_get_metadata_with_env_token(mock_requests: None, mock_env_token: None) -> None: # noqa: ARG001
def test_get_metadata_with_env_token(mock_requests: None) -> None: # noqa: ARG001
save_env_var("SCICAT_TOKEN", "env_test_token")
metadata_config = {
"archiver_url": "https://example.com",
}
retriever = MetadataRetriever(metadata_config)
metadata = retriever.get_metadata("11013410", ["43878"])
assert isinstance(metadata, dict)
assert metadata == {"fake": "data"}
os.remove(ENV_PATH)


def test_get_metadata_no_token() -> None:
Expand All @@ -66,39 +57,40 @@ def test_get_metadata_no_url() -> None:
metadata_config: dict = {}
with pytest.raises(ValueError, match="No URL provided for fetching metadata"):
MetadataRetriever(metadata_config, token="test_token")
os.remove(ENV_PATH)


def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"archiver_url": "https://example.com",
"token": "test_token",
}
retriever = MetadataRetriever(metadata_config)
retriever = MetadataRetriever(metadata_config, token="test_token")
existing_metadata = {"existing": "metadata"}
metadata = retriever.get_metadata("11013410", ["43878"], existing_metadata)
assert isinstance(metadata, dict)
assert metadata == {"existing": "metadata", "fake": "data"}
os.remove(ENV_PATH)


def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001
metadata_config = {
"archiver_url": "https://example.com",
"token": "test_token",
}
retriever = MetadataRetriever(metadata_config)
retriever = MetadataRetriever(metadata_config, token="test_token")
metadata = retriever._get_metadata_per_run("11013410/43878")
assert isinstance(metadata, dict)
assert metadata == {"fake": "data"}
os.remove(ENV_PATH)


def test_create_dataset_url_by_PID() -> None:
metadata_config = {
"archiver_url": "https://example.com",
"token": "test_token",
}
retriever = MetadataRetriever(metadata_config)
retriever = MetadataRetriever(metadata_config, token="test_token")
pid = "11013410/43878"
url = retriever._create_new_dataset_url(pid)
expected_url = "https://example.com/Datasets/11013410%2F43878"
assert isinstance(url, str)
assert url == expected_url
os.remove(ENV_PATH)
47 changes: 47 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from sed.core.config import complete_dictionary
from sed.core.config import load_config
from sed.core.config import parse_config
from sed.core.config import read_env_var
from sed.core.config import save_config
from sed.core.config import save_env_var

test_dir = os.path.dirname(__file__)
test_config_dir = Path(f"{test_dir}/data/loader/")
Expand Down Expand Up @@ -231,3 +233,48 @@ def test_invalid_config_wrong_values():
verify_config=True,
)
assert "Invalid value 9999 for gid. Group not found." in str(e.value)


def test_env_var_read_write(tmp_path, monkeypatch):
"""Test reading and writing environment variables."""
# Mock USER_CONFIG_PATH to use a temporary directory
monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path)

# Test writing a new variable
save_env_var("TEST_VAR", "test_value")
assert read_env_var("TEST_VAR") == "test_value"

# Test writing multiple variables
save_env_var("TEST_VAR2", "test_value2")
assert read_env_var("TEST_VAR") == "test_value"
assert read_env_var("TEST_VAR2") == "test_value2"

# Test overwriting an existing variable
save_env_var("TEST_VAR", "new_value")
assert read_env_var("TEST_VAR") == "new_value"
assert read_env_var("TEST_VAR2") == "test_value2" # Other variables unchanged

# Test reading non-existent variable
assert read_env_var("NON_EXISTENT_VAR") is None


def test_env_var_read_no_file(tmp_path, monkeypatch):
"""Test reading environment variables when .env file doesn't exist."""
# Mock USER_CONFIG_PATH to use an empty temporary directory
monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path)

# Test reading from non-existent file
assert read_env_var("TEST_VAR") is None


def test_env_var_special_characters():
"""Test reading and writing environment variables with special characters."""
test_cases = {
"TEST_URL": "http://example.com/path?query=value",
"TEST_PATH": "/path/to/something/with/spaces and special=chars",
"TEST_QUOTES": "value with 'single' and \"double\" quotes",
}

for var_name, value in test_cases.items():
save_env_var(var_name, value)
assert read_env_var(var_name) == value

0 comments on commit 80ef9c1

Please sign in to comment.