Add optional allow-list support for mtls-bridge paths
This commit is contained in:
+74
-106
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from flask import Flask, Response, g, request
|
||||
@@ -15,23 +16,39 @@ logger = logging.getLogger("mtls-bridge")
|
||||
logging.getLogger("werkzeug").setLevel(logging.WARNING)
|
||||
|
||||
# Config via env
|
||||
TARGET_URL = os.environ.get("TARGET_URL")
|
||||
TARGET_URL = (os.environ.get("TARGET_URL") or "").strip()
|
||||
CLIENT_CERT = os.environ.get("CLIENT_CERT", "/certs/client.crt")
|
||||
CLIENT_KEY = os.environ.get("CLIENT_KEY", "/certs/client.key")
|
||||
UPSTREAM_CA_CERT = os.environ.get("UPSTREAM_CA_CERT", os.environ.get("CA_CERT", "")).strip()
|
||||
# Backward-compat alias: keep CA_CERT defined so legacy code paths/log statements don't crash.
|
||||
CA_CERT = UPSTREAM_CA_CERT
|
||||
TIMEOUT = int(os.environ.get("TIMEOUT", "5"))
|
||||
HEALTH_ENDPOINT = os.environ.get("HEALTH_ENDPOINT", "/_mtls_bridge/health")
|
||||
ALLOWED_PATHS_FILE = (os.environ.get("ALLOWED_PATHS_FILE") or "").strip()
|
||||
|
||||
logger.info(
|
||||
"mtls-bridge starting target_url=%s timeout=%ss cert=%s key=%s ca=%s log_level=%s",
|
||||
TARGET_URL,
|
||||
TIMEOUT,
|
||||
CLIENT_CERT,
|
||||
CLIENT_KEY,
|
||||
CA_CERT,
|
||||
os.environ.get("LOG_LEVEL", "INFO"),
|
||||
)
|
||||
|
||||
def normalize_path(path: str) -> str:
    """Return *path* rewritten so that it starts with exactly one slash.

    A falsy value (empty string) and the bare root both come back as "/".
    Any run of leading slashes is collapsed to a single one; the rest of the
    path is left untouched.
    """
    remainder = path.lstrip("/") if path else ""
    return "/" + remainder
|
||||
|
||||
|
||||
def load_allowed_paths() -> set[str]:
    """Load the optional path allow-list named by ALLOWED_PATHS_FILE.

    The file holds one path per line. Blank lines and lines starting with
    ``#`` are skipped; every remaining entry is passed through
    ``normalize_path`` before being stored.

    Returns:
        The set of normalized paths. The set is empty when
        ALLOWED_PATHS_FILE is unset or when the file does not exist.

    Raises:
        OSError: if the file exists but cannot be opened or read (for
            example it is a directory or is not readable). Only a missing
            file is treated as "no allow-list"; other errors propagate.
    """
    if not ALLOWED_PATHS_FILE:
        return set()

    try:
        # Open directly instead of checking os.path.exists() first, so there
        # is no window between the check and the open. "utf-8-sig" drops a
        # leading byte-order mark that would otherwise become part of the
        # first entry.
        with open(ALLOWED_PATHS_FILE, encoding="utf-8-sig") as f:
            allowed_paths = {
                normalize_path(entry)
                for entry in map(str.strip, f)
                if entry and not entry.startswith("#")
            }
    except FileNotFoundError:
        logger.warning("ALLOWED_PATHS_FILE does not exist: %s (allow-list disabled)", ALLOWED_PATHS_FILE)
        return set()

    logger.info("loaded %s allowed path(s) from %s", len(allowed_paths), ALLOWED_PATHS_FILE)
    if not allowed_paths:
        # An empty set is indistinguishable from "no allow-list configured",
        # so make the consequence visible to the operator.
        logger.warning("ALLOWED_PATHS_FILE has no entries: %s (allow-list disabled)", ALLOWED_PATHS_FILE)
    return allowed_paths
|
||||
|
||||
|
||||
def get_verify_setting():
|
||||
@@ -54,95 +71,47 @@ def get_verify_setting():
|
||||
|
||||
|
||||
VERIFY_SETTING = get_verify_setting()
|
||||
|
||||
logger.info(
|
||||
"mtls-bridge starting target_url=%s timeout=%ss cert=%s key=%s verify=%s log_level=%s",
|
||||
TARGET_URL,
|
||||
TIMEOUT,
|
||||
CLIENT_CERT,
|
||||
CLIENT_KEY,
|
||||
VERIFY_SETTING,
|
||||
os.environ.get("LOG_LEVEL", "INFO"),
|
||||
)
|
||||
|
||||
|
||||
def get_verify_setting():
    """Translate UPSTREAM_CA_CERT into a TLS ``verify`` setting.

    Returns:
        ``True`` when UPSTREAM_CA_CERT is empty, or names a path that does
        not exist (a warning is logged in the latter case);
        ``False`` when it is one of "false", "0" or "no" (case-insensitive),
        which disables upstream verification and logs a warning;
        otherwise the configured path itself.
    """
    ca_setting = UPSTREAM_CA_CERT
    if not ca_setting:
        return True

    # Explicit opt-out values switch verification off entirely.
    if ca_setting.lower() in {"false", "0", "no"}:
        logger.warning("TLS verification for upstream is disabled via UPSTREAM_CA_CERT=%s", ca_setting)
        return False

    if os.path.exists(ca_setting):
        return ca_setting

    logger.warning(
        "Configured UPSTREAM_CA_CERT path does not exist: %s (falling back to system CA bundle)",
        ca_setting,
    )
    return True
|
||||
|
||||
|
||||
VERIFY_SETTING = get_verify_setting()
|
||||
|
||||
logger.info(
|
||||
"mtls-bridge starting target_url=%s timeout=%ss cert=%s key=%s verify=%s log_level=%s",
|
||||
TARGET_URL,
|
||||
TIMEOUT,
|
||||
CLIENT_CERT,
|
||||
CLIENT_KEY,
|
||||
VERIFY_SETTING,
|
||||
os.environ.get("LOG_LEVEL", "INFO"),
|
||||
)
|
||||
ALLOWED_PATHS = load_allowed_paths()
|
||||
|
||||
if TARGET_URL and TARGET_URL.lower().startswith("http://"):
|
||||
logger.warning(
|
||||
"TARGET_URL uses http://; upstream may redirect to https:// and change request behavior: %s",
|
||||
TARGET_URL,
|
||||
)
|
||||
|
||||
|
||||
def get_verify_setting():
    """Translate UPSTREAM_CA_CERT into a TLS ``verify`` setting.

    Returns:
        ``True`` when UPSTREAM_CA_CERT is empty, or names a path that does
        not exist (a warning is logged in the latter case);
        ``False`` when it is one of "false", "0" or "no" (case-insensitive),
        which disables upstream verification and logs a warning;
        otherwise the configured path itself.
    """
    ca_setting = UPSTREAM_CA_CERT
    if not ca_setting:
        return True

    # Explicit opt-out values switch verification off entirely.
    if ca_setting.lower() in {"false", "0", "no"}:
        logger.warning("TLS verification for upstream is disabled via UPSTREAM_CA_CERT=%s", ca_setting)
        return False

    if os.path.exists(ca_setting):
        return ca_setting

    logger.warning(
        "Configured UPSTREAM_CA_CERT path does not exist: %s (falling back to system CA bundle)",
        ca_setting,
    )
    return True
|
||||
|
||||
|
||||
VERIFY_SETTING = get_verify_setting()
|
||||
logger.warning("TARGET_URL uses http:// (plaintext): %s", TARGET_URL)
|
||||
|
||||
logger.info(
|
||||
"mtls-bridge starting target_url=%s timeout=%ss cert=%s key=%s verify=%s log_level=%s",
|
||||
"mtls-bridge starting target_url=%s timeout=%ss cert=%s key=%s verify=%s health_endpoint=%s allow_list_file=%s allow_list_entries=%s log_level=%s",
|
||||
TARGET_URL,
|
||||
TIMEOUT,
|
||||
CLIENT_CERT,
|
||||
CLIENT_KEY,
|
||||
VERIFY_SETTING,
|
||||
HEALTH_ENDPOINT,
|
||||
ALLOWED_PATHS_FILE,
|
||||
len(ALLOWED_PATHS),
|
||||
os.environ.get("LOG_LEVEL", "INFO"),
|
||||
)
|
||||
|
||||
if TARGET_URL and TARGET_URL.lower().startswith("http://"):
|
||||
logger.warning(
|
||||
"TARGET_URL uses http://; upstream may redirect to https:// and change request behavior: %s",
|
||||
TARGET_URL,
|
||||
)
|
||||
|
||||
def build_upstream_url(path: str) -> str:
    """Map incoming path directly onto TARGET_URL origin/base path.

    Args:
        path: The request path to forward, with or without leading slashes.

    Returns:
        TARGET_URL (without trailing slashes) + "/" + path (without leading
        slashes), followed by "?" and the current request's raw query string
        when one is present.

    Raises:
        ValueError: if TARGET_URL is not set.
    """
    if not TARGET_URL:
        raise ValueError("TARGET_URL is not set")

    # Plain concatenation on purpose. urljoin() interprets its second
    # argument as a URL reference, so a client-supplied path whose first
    # segment contains ":" (e.g. "v1:lookup" or "http://other-host/x") is
    # parsed as having its own scheme and replaces TARGET_URL, and ".."
    # segments are resolved above TARGET_URL's base path.
    base = TARGET_URL.rstrip("/")
    relative = (path or "").lstrip("/")
    upstream_url = f"{base}/{relative}"

    if request.query_string:
        upstream_url = f"{upstream_url}?{request.query_string.decode('utf-8', 'ignore')}"

    return upstream_url
|
||||
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def is_path_allowed(request_path: str) -> bool:
    """Tell whether *request_path* passes the allow-list.

    An empty ALLOWED_PATHS means no allow-list is in effect, so every path
    is accepted; otherwise the path must be an exact member of the set.
    """
    return not ALLOWED_PATHS or request_path in ALLOWED_PATHS
|
||||
|
||||
|
||||
@app.route(HEALTH_ENDPOINT, methods=["GET"])
def health():
    """Answer GET requests on HEALTH_ENDPOINT with a plain "OK" and status 200."""
    caller = request.remote_addr
    logger.debug("healthcheck request from %s", caller)
    return "OK", 200
|
||||
@@ -156,7 +125,7 @@ def before_request():
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
elapsed_ms = int((time.time() - g.request_start) * 1000)
|
||||
if request.path != "/health":
|
||||
if request.path != HEALTH_ENDPOINT:
|
||||
logger.info(
|
||||
"request complete method=%s path=%s status=%s elapsed_ms=%s",
|
||||
request.method,
|
||||
@@ -179,7 +148,7 @@ def after_request(response):
|
||||
provide_automatic_options=False,
|
||||
)
|
||||
def proxy(path):
|
||||
request_path = f"/{path}" if path else "/"
|
||||
request_path = normalize_path(path)
|
||||
request_size = len(request.get_data(cache=True))
|
||||
logger.info(
|
||||
"incoming request method=%s path=%s query=%s remote=%s bytes=%s",
|
||||
@@ -190,50 +159,49 @@ def proxy(path):
|
||||
request_size,
|
||||
)
|
||||
|
||||
if not TARGET_URL:
|
||||
logger.error("TARGET_URL is not set; cannot proxy request")
|
||||
return Response("TARGET_URL is not set", status=500)
|
||||
if not is_path_allowed(request_path):
|
||||
logger.warning("request blocked by allow-list path=%s", request_path)
|
||||
return Response("Endpoint not allowed", status=403)
|
||||
|
||||
try:
|
||||
url = f"{TARGET_URL.rstrip('/')}/{path}".rstrip("/")
|
||||
start_time = time.time()
|
||||
upstream_url = build_upstream_url(path)
|
||||
|
||||
headers = {k: v for k, v in request.headers if k.lower() != "host"}
|
||||
headers["X-Forwarded-By"] = "mtls-bridge"
|
||||
|
||||
logger.debug("forwarding request to upstream url=%s headers=%s", url, headers)
|
||||
|
||||
start_time = time.time()
|
||||
resp = requests.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
url=upstream_url,
|
||||
headers=headers,
|
||||
data=request.get_data(cache=True),
|
||||
cookies=request.cookies,
|
||||
cert=(CLIENT_CERT, CLIENT_KEY),
|
||||
verify=VERIFY_SETTING,
|
||||
timeout=TIMEOUT,
|
||||
allow_redirects=False,
|
||||
)
|
||||
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
logger.info(
|
||||
"upstream response status=%s url=%s elapsed_ms=%s response_bytes=%s",
|
||||
resp.status_code,
|
||||
url,
|
||||
upstream_url,
|
||||
elapsed_ms,
|
||||
len(resp.content),
|
||||
)
|
||||
|
||||
excluded_headers = ["content-encoding", "content-length", "transfer-encoding", "connection"]
|
||||
response_headers = [
|
||||
(k, v)
|
||||
for k, v in resp.raw.headers.items()
|
||||
if k.lower() not in excluded_headers
|
||||
]
|
||||
excluded_headers = {"content-encoding", "content-length", "transfer-encoding", "connection"}
|
||||
response_headers = [(k, v) for k, v in resp.headers.items() if k.lower() not in excluded_headers]
|
||||
|
||||
return Response(resp.content, resp.status_code, response_headers)
|
||||
|
||||
except Exception as e:
|
||||
except ValueError as exc:
|
||||
logger.error("proxy request failed: %s", exc)
|
||||
return Response(str(exc), status=500)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("proxy request failed")
|
||||
return Response(str(e), status=500)
|
||||
return Response(str(exc), status=500)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user