import logging
import os
import time
from urllib.parse import urljoin

import requests
from flask import Flask, Response, g, request

app = Flask(__name__)

# logging.basicConfig only accepts upper-case level names ("DEBUG", not
# "debug"), so normalise the env value; unset or blank falls back to INFO.
LOG_LEVEL = (os.environ.get("LOG_LEVEL") or "").strip().upper() or "INFO"

logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("mtls-bridge")
# Raise werkzeug's logger to WARNING so its INFO-level per-request lines are hidden.
logging.getLogger("werkzeug").setLevel(logging.WARNING)

# Config via env
TARGET_URL = (os.environ.get("TARGET_URL") or "").strip()
CLIENT_CERT = os.environ.get("CLIENT_CERT", "/certs/client.crt")
CLIENT_KEY = os.environ.get("CLIENT_KEY", "/certs/client.key")
# UPSTREAM_CA_CERT takes precedence; CA_CERT is accepted as a fallback name.
UPSTREAM_CA_CERT = os.environ.get("UPSTREAM_CA_CERT", os.environ.get("CA_CERT", "")).strip()
TIMEOUT = int(os.environ.get("TIMEOUT", "5"))  # seconds
HEALTH_ENDPOINT = os.environ.get("HEALTH_ENDPOINT", "/_mtls_bridge/health")
ALLOWED_PATHS_FILE = (os.environ.get("ALLOWED_PATHS_FILE") or "").strip()


def normalize_path(path: str) -> str:
    """Return *path* with exactly one leading slash ("" and "/" both give "/").

    Only leading slashes are collapsed; the rest of the path is untouched.

    >>> normalize_path("api//v1")
    '/api//v1'
    >>> normalize_path("///health")
    '/health'
    """
    return "/" + (path or "").lstrip("/")


def load_allowed_paths() -> set[str]:
    """Load the path allow-list from ALLOWED_PATHS_FILE.

    The file holds one path per line. Blank lines and lines starting with "#"
    are skipped. Each entry goes through normalize_path so it compares equal
    to normalised request paths.

    Returns:
        The set of allowed paths. The set is empty (allow-list disabled) when
        ALLOWED_PATHS_FILE is unset or the file does not exist.

    Raises:
        OSError: any I/O error other than "file not found", for example
            permission denied. Startup then stops, instead of the
            allow-list being silently disabled.
    """
    if not ALLOWED_PATHS_FILE:
        return set()
    try:
        # EAFP: open directly. os.path.exists() returns False for any stat()
        # failure (e.g. a non-searchable parent directory), which used to
        # disable the allow-list silently.
        with open(ALLOWED_PATHS_FILE, encoding="utf-8") as f:
            allowed_paths = {
                normalize_path(entry)
                for entry in (line.strip() for line in f)
                if entry and not entry.startswith("#")
            }
    except FileNotFoundError:
        # NOTE(review): a configured-but-missing allow-list fails open (every
        # path is proxied). Confirm this is intended for the deployment.
        logger.warning("ALLOWED_PATHS_FILE does not exist: %s (allow-list disabled)", ALLOWED_PATHS_FILE)
        return set()
    logger.info("loaded %s allowed path(s) from %s", len(allowed_paths), ALLOWED_PATHS_FILE)
    return allowed_paths


def get_verify_setting():
    """Translate UPSTREAM_CA_CERT into the ``verify`` argument for requests.

    Returns:
        True  -- use the default CA bundle. This applies when the variable is
                 unset/empty or the configured path does not exist; the
                 missing path is logged as a warning.
        False -- disable upstream certificate verification. This applies for
                 "false", "0" or "no" (case-insensitive) and is logged as a
                 warning.
        str   -- the configured path, when it exists.
    """
    if not UPSTREAM_CA_CERT:
        return True
    lowered = UPSTREAM_CA_CERT.lower()
    if lowered in {"false", "0", "no"}:
        logger.warning("TLS verification for upstream is disabled via UPSTREAM_CA_CERT=%s", UPSTREAM_CA_CERT)
        return False
    if not os.path.exists(UPSTREAM_CA_CERT):
        logger.warning(
            "Configured UPSTREAM_CA_CERT path does not exist: %s (falling back to system CA bundle)",
            UPSTREAM_CA_CERT,
        )
        return True
    return UPSTREAM_CA_CERT


VERIFY_SETTING = get_verify_setting()
ALLOWED_PATHS = load_allowed_paths()

if TARGET_URL and TARGET_URL.lower().startswith("http://"):
    logger.warning("TARGET_URL uses http:// (plaintext): %s", TARGET_URL)

logger.info(
    "mtls-bridge starting target_url=%s timeout=%ss cert=%s key=%s verify=%s "
    "health_endpoint=%s allow_list_file=%s allow_list_entries=%s log_level=%s",
    TARGET_URL,
    TIMEOUT,
    CLIENT_CERT,
    CLIENT_KEY,
    VERIFY_SETTING,
    HEALTH_ENDPOINT,
    ALLOWED_PATHS_FILE,
    len(ALLOWED_PATHS),
    os.environ.get("LOG_LEVEL", "INFO"),
)

# HTTP methods relayed by the catch-all proxy routes.
PROXY_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]

# Hop-by-hop headers (RFC 2616 section 13.5.1, plus the non-standard
# Proxy-Connection) describe a single connection and must not be relayed in
# either direction.
HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)

# Characters left unescaped when re-encoding the (already decoded) request path:
# "/" plus RFC 3986 sub-delims, ":" and "@". "%", "?" and "#" are escaped so
# that a client-sent %25/%3F/%23 is not reinterpreted by the upstream.
_PATH_SAFE_CHARS = "/:@!$&'()*+,;="


def build_upstream_url(path: str) -> str:
    """Map the incoming path (and query string) onto the TARGET_URL base.

    The path is appended to TARGET_URL by string concatenation, never
    resolved with urljoin. urljoin treats a path such as
    ``https://other.host/x`` as an absolute URL and would replace the
    upstream host. The bridge would then present its client certificate to
    any host a caller names.

    Raises:
        ValueError: if TARGET_URL is not set.
    """
    from urllib.parse import quote  # local: only this helper needs it

    if not TARGET_URL:
        raise ValueError("TARGET_URL is not set")
    encoded_path = quote(path.lstrip("/"), safe=_PATH_SAFE_CHARS)
    upstream_url = f"{TARGET_URL.rstrip('/')}/{encoded_path}"
    if request.query_string:
        upstream_url = f"{upstream_url}?{request.query_string.decode('utf-8', 'ignore')}"
    return upstream_url


def is_path_allowed(request_path: str) -> bool:
    """Return True if *request_path* may be proxied.

    Every path is allowed when the allow-list is empty. Otherwise only exact
    matches against the normalised entries are allowed.
    """
    if not ALLOWED_PATHS:
        return True
    return request_path in ALLOWED_PATHS


@app.route(HEALTH_ENDPOINT, methods=["GET"])
def health():
    """Liveness probe; always answers 200 "OK" without contacting the upstream."""
    logger.debug("healthcheck request from %s", request.remote_addr)
    return "OK", 200


@app.before_request
def before_request():
    """Record a monotonic start time for the request-duration log line."""
    # perf_counter is monotonic; time.time() can jump with wall-clock changes.
    g.request_start = time.perf_counter()


@app.after_request
def after_request(response):
    """Log method, path, status and duration for every non-health request."""
    elapsed_ms = int((time.perf_counter() - g.request_start) * 1000)
    if request.path != HEALTH_ENDPOINT:
        logger.info(
            "request complete method=%s path=%s status=%s elapsed_ms=%s",
            request.method,
            request.path,
            response.status_code,
            elapsed_ms,
        )
    return response


def _forwardable_request_headers():
    """Copy the client's request headers, minus those that must not be relayed.

    The following are dropped:
      - Host: requests derives it from the upstream URL.
      - Content-Length: requests recomputes it from the body actually sent.
      - Accept-Encoding: requests then advertises only encodings it can
        decode. This matters because the relayed response always loses its
        Content-Encoding header.
      - Hop-by-hop headers, and any header named in the client's Connection
        header.
    """
    connection_tokens = {
        token.strip().lower()
        for token in request.headers.get("Connection", "").split(",")
        if token.strip()
    }
    excluded = HOP_BY_HOP_HEADERS | {"host", "content-length", "accept-encoding"} | connection_tokens
    return {name: value for name, value in request.headers if name.lower() not in excluded}


def _forwardable_response_headers(resp):
    """Return the upstream response headers that are safe to relay to the client.

    Content-Encoding and Content-Length are dropped because requests has
    already decoded the body. Hop-by-hop headers are dropped too.

    Set-Cookie is read from the raw urllib3 headers. requests' ``resp.headers``
    folds repeated headers into one comma-joined value, and that corrupts
    multiple cookies (cookie Expires dates themselves contain commas).
    """
    excluded = HOP_BY_HOP_HEADERS | {"content-encoding", "content-length", "set-cookie"}
    headers = [(name, value) for name, value in resp.headers.items() if name.lower() not in excluded]

    raw_headers = getattr(resp.raw, "headers", None)
    if hasattr(raw_headers, "getlist"):
        set_cookies = raw_headers.getlist("Set-Cookie")
    elif "Set-Cookie" in resp.headers:
        set_cookies = [resp.headers["Set-Cookie"]]
    else:
        set_cookies = []
    headers.extend(("Set-Cookie", cookie) for cookie in set_cookies)
    return headers


@app.route(
    "/",
    defaults={"path": ""},
    methods=PROXY_METHODS,
    provide_automatic_options=False,
)
@app.route(
    "/<path:path>",
    methods=PROXY_METHODS,
    provide_automatic_options=False,
)
def proxy(path):
    """Relay the incoming request to TARGET_URL using the mTLS client certificate.

    The path is checked against the allow-list and mapped onto TARGET_URL.
    The request is then forwarded with the configured client cert/key and
    CA verification setting. Redirects are not followed. The upstream status,
    body and relayable headers are returned unchanged.

    Error responses:
        403 if the path is not on the allow-list.
        500 if TARGET_URL is unset or an unexpected error occurs.
        502 if the upstream request fails.
        504 if the upstream request times out.
    """
    request_path = normalize_path(path)
    # get_data(cache=True) reads the body once; the same bytes are forwarded below.
    body = request.get_data(cache=True)
    logger.info(
        "incoming request method=%s path=%s query=%s remote=%s bytes=%s",
        request.method,
        request_path,
        request.query_string.decode("utf-8", "ignore"),
        request.remote_addr,
        len(body),
    )

    if not is_path_allowed(request_path):
        logger.warning("request blocked by allow-list path=%s", request_path)
        return Response("Endpoint not allowed", status=403)

    try:
        upstream_url = build_upstream_url(path)
    except ValueError as exc:
        # Only raised for a missing TARGET_URL: a deployment error, not client input.
        logger.error("proxy request failed: %s", exc)
        return Response(str(exc), status=500)

    headers = _forwardable_request_headers()
    headers["X-Forwarded-By"] = "mtls-bridge"

    start_time = time.perf_counter()
    try:
        # The Cookie header travels with `headers`, so no separate cookies= is needed.
        resp = requests.request(
            method=request.method,
            url=upstream_url,
            headers=headers,
            data=body,
            cert=(CLIENT_CERT, CLIENT_KEY),
            verify=VERIFY_SETTING,
            timeout=TIMEOUT,
            allow_redirects=False,
        )
    except requests.Timeout:
        logger.error("proxy request timed out url=%s timeout=%ss", upstream_url, TIMEOUT)
        return Response("Upstream request timed out", status=504)
    except requests.RequestException:
        logger.exception("proxy request failed url=%s", upstream_url)
        return Response("Upstream request failed", status=502)
    except Exception:  # noqa: BLE001 - last-resort boundary for non-requests errors
        logger.exception("proxy request failed url=%s", upstream_url)
        # Do not echo the exception text: it can expose local paths and internals.
        return Response("Internal proxy error", status=500)

    elapsed_ms = int((time.perf_counter() - start_time) * 1000)
    logger.info(
        "upstream response status=%s url=%s elapsed_ms=%s response_bytes=%s",
        resp.status_code,
        upstream_url,
        elapsed_ms,
        len(resp.content),
    )
    return Response(resp.content, resp.status_code, _forwardable_response_headers(resp))


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)