add njalla-tlsa-rotate.py

Signed-off-by: Martin Matous <m@matous.dev>
This commit is contained in:
Martin Matous 2025-03-18 23:25:46 +01:00
parent 63db8162e3
commit effce90992
Signed by: mmatous
GPG key ID: 8BED4CD352953224
2 changed files with 188 additions and 0 deletions

View file

@ -84,6 +84,18 @@ Usage: `kernel-update.py`
Alt.: `kernel-update.py <old-version> <new-version>`
## njalla-tlsa-rotate.py
Perform 3 1 1 + 3 1 1 TLSA key rollover for Maddy mailserver with 3h window. Since the script is stateless
and rebooting a machine would interfere, reboot is blocked via logind for the duration.
Status: active use
Dependencies (python): certbot, cryptography, dasbus, requests
Dependencies (system): acl (setfacl), caddy, firewalld, maddy, python3, systemd
Usage: Invoke periodically using systemd timer.
---
## flac-convert.py

176
njalla-tlsa-rotate.py Normal file
View file

@ -0,0 +1,176 @@
#!/usr/bin/env python3
# nonstdlib requirements: certbot, cryptography, dasbus, requests
# non-programmatic: acl (setfacl), caddy, firewalld, maddy, systemd
import datetime
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Literal
import certbot.main as cbmain
import dasbus
import requests
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from dasbus.connection import SystemMessageBus
# Njalla API token; a missing API_TOKEN environment variable raises KeyError here
API_TOKEN = os.environ['API_TOKEN']
# Domain sent as the `domain` parameter of every Njalla API call
DOMAIN = 'matous.dev'
# Label of the mail host inside DOMAIN; the names below are derived from it
MX_SERVER = 'mx1'
MX_SUBDOMAIN = f'{MX_SERVER}.{DOMAIN}'
# Name of the TLSA record for port 25/tcp (presumably relative to DOMAIN -- confirm against the API)
RECORD_NAME = f'_25._tcp.{MX_SERVER}'
# Directory whose symlinks are saved and swapped during the rotation
LIVE_CERT_PATH = Path(f'/etc/letsencrypt/live/{MX_SUBDOMAIN}/')
# Endpoint every API request is POSTed to
NJALLA = 'https://njal.la/api/1/'
# API method names this script uses
type Method = Literal['add-record', 'list-records', 'remove-record']
def njalla_api(method: Method, **params: int | str) -> dict[str, Any]:
    """Call one Njalla API method for DOMAIN and return its ``result`` payload.

    Args:
        method: Name of the API method to invoke.
        **params: Parameters of the method. ``domain`` is always set to
            DOMAIN, overriding any caller-supplied value.

    Returns:
        The value stored under ``result`` in the JSON response.

    Raises:
        requests.HTTPError: If the HTTP status signals a failure.
        RuntimeError: If the JSON response carries no ``result`` key.
    """
    headers = {'Authorization': f'Njalla {API_TOKEN}'}
    params['domain'] = DOMAIN
    payload = {'method': method, 'params': params}
    response = requests.post(NJALLA, json=payload, headers=headers, timeout=30)
    response.raise_for_status()
    j_response = response.json()
    if 'result' not in j_response:
        # keep the whole response in the exception so the API's error is visible
        raise RuntimeError('API Error', j_response)
    return j_response['result']
def add_tlsa(new_tlsa: str) -> None:
    """Create a TLSA record named RECORD_NAME with a one hour TTL.

    Args:
        new_tlsa: Full record content to publish.
    """
    print('Adding new TLSA record')
    # use the shared constant so the record name follows MX_SERVER
    njalla_api('add-record', name=RECORD_NAME, content=new_tlsa, ttl=3600, type='TLSA')
def remove_tlsa(entry_id: str) -> None:
    """Delete the DNS record identified by ``entry_id`` through the Njalla API."""
    print('Removing old TLSA record')
    njalla_api(method='remove-record', id=entry_id)
def list_tlsa() -> dict[str, list[dict[str, str]]]:
    """Return the ``list-records`` result restricted to this host's TLSA records.

    The ``list-records`` call is made without any filter, while the caller
    deletes entries it does not recognise. Only records of type ``TLSA``
    named RECORD_NAME are therefore kept under ``records``; other keys of
    the API result are passed through untouched.

    Returns:
        The API result with ``records`` narrowed to the matching entries
        (an empty list if the result carries no ``records`` key).
    """
    result = njalla_api('list-records')
    tlsa_records = [
        record
        for record in result.get('records', [])
        if record.get('type') == 'TLSA' and record.get('name') == RECORD_NAME
    ]
    return {**result, 'records': tlsa_records}
def get_cert_hash(cert: x509.Certificate) -> str:
    """Return the hex SHA-256 digest of the cert's DER-encoded SubjectPublicKeyInfo."""
    spki_der = cert.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    hasher = hashes.Hash(hashes.SHA256())
    hasher.update(spki_der)
    return hasher.finalize().hex()
def renew_cert(current_cert_hash: str, old_links: dict[Path, Path]) -> None:
    """Renew the cert and roll the TLSA record while blocking shutdown.

    Steps: take a logind shutdown inhibitor, renew the certificate, put the
    old live symlinks back, rotate the TLSA records (this waits for DNS
    propagation), switch the live symlinks to the new cert and reload maddy.

    Args:
        current_cert_hash: Public key hash of the cert in use before renewal.
        old_links: Live symlinks (link -> target) saved before the renewal.

    Raises:
        RuntimeError: If the renewal produced a cert with the same key hash.
    """
    # Function-scope import: receiving file descriptors over D-Bus needs dasbus'
    # Unix client (dasbus >= 1.7); the default client does not hand the fd over.
    from dasbus.unix import GLibClientUnix

    bus = SystemMessageBus()
    logind = bus.get_proxy('org.freedesktop.login1', '/org/freedesktop/login1', client=GLibClientUnix)
    # The inhibitor lock lasts only while this descriptor stays open, so keep it.
    inhibit_fd = logind.Inhibit('shutdown', sys.argv[0], 'MX certs + TLSA rotation in progress', 'block')
    try:
        new_cert_hash, new_links, systemd = renew_local(bus)
        if new_cert_hash == current_cert_hash:
            raise RuntimeError('New cert is the same as current. Perhaps the script ran prematurely?')

        # restore old symlinks in /live/ for a bit
        for link, target in old_links.items():
            link.unlink()
            link.symlink_to(target)

        # this little maneuver's gonna cost us 3 hours
        print('Waiting 3h for TLSA changes to propagate')
        rotate_tlsa(current_cert_hash, new_cert_hash)

        # use new certs in /live/
        for link, target in new_links.items():
            link.unlink()
            link.symlink_to(target)

        # alert maddy to new certs
        systemd.ReloadUnit('maddy.service', 'replace')
        print('Reloaded maddy')
    finally:
        # release the shutdown inhibitor whether the rotation succeeded or not
        os.close(inhibit_fd)
def save_live_symlinks() -> dict[Path, Path]:
    """Map every symlink directly inside LIVE_CERT_PATH to its resolved target.

    Entries that are not symlinks are ignored. Resolution is strict, so a
    dangling link raises instead of being recorded.
    """
    snapshot: dict[Path, Path] = {}
    for entry in LIVE_CERT_PATH.iterdir():
        if not entry.is_symlink():
            continue
        destination = entry.resolve(strict=True)
        snapshot[entry] = destination
        print(f'{entry} is symlink to {destination}')
    return snapshot
def rotate_cert_and_tlsa(current_cert: x509.Certificate) -> None:
    """Run the full rotation starting from the certificate currently in use.

    Computes the key hash of ``current_cert``, records the present live
    symlinks and passes both to ``renew_cert``.
    """
    print('Getting current cert hash')
    present_hash = get_cert_hash(current_cert)

    print('Saving current (old) symlinks')
    previous_links = save_live_symlinks()

    print('Renewing cert')
    renew_cert(current_cert_hash=present_hash, old_links=previous_links)
def renew_local(bus: dasbus.connection.MessageBus) -> tuple[str, dict[Path, Path], dasbus.connection.ObjectProxy]:
    """Renew the MX certificate with certbot's webroot challenge.

    A temporary Caddy site serves the challenge over port 80, which is opened
    in firewalld for the duration. Both are reverted afterwards, also when
    the renewal fails.

    Args:
        bus: Message bus used to reach systemd and firewalld.

    Returns:
        A tuple of the new cert's public key hash, the live symlinks as they
        are after the renewal, and the systemd manager proxy.

    Raises:
        dasbus.error.DBusError: If opening the port fails for any reason
            other than the port already being open.
        subprocess.CalledProcessError: If setfacl exits with a failure.
    """
    caddy_subconf_content = f'{MX_SUBDOMAIN}:80 {{\n root * /srv/{DOMAIN} \n file_server \n }}'
    caddy_subconf_file = Path(f'/etc/caddy/Caddyfile.d/{MX_SUBDOMAIN}')
    systemd = bus.get_proxy('org.freedesktop.systemd1', '/org/freedesktop/systemd1')
    firewalld = bus.get_proxy('org.fedoraproject.FirewallD1', '/org/fedoraproject/FirewallD1')
    # only close the port afterwards if this run is the one that opened it
    port_opened = False

    # serve challenges from subdomain
    caddy_subconf_file.write_text(caddy_subconf_content)
    try:
        systemd.ReloadUnit('caddy.service', 'replace')
        print('Caddy reloaded')

        # make holes in FW for plaintext traffic
        try:
            zone = firewalld.addPort('', '80', 'tcp', 0)
        except dasbus.error.DBusError as e:
            # the error is an exception object, so inspect its message
            if 'ALREADY_ENABLED' not in str(e):
                raise
            print('Port 80/tcp already open')
        else:
            port_opened = True
            print(f'Added port 80/tcp to {zone}')

        # non-interactive (-n) renewals insert random delays for over 5 minutes
        cbmain.main(['renew', '--webroot', '-w', f'/srv/{DOMAIN}', '--cert-name', MX_SUBDOMAIN, '-n'])
        print('Cert renewed')
        subprocess.run(
            ['/bin/setfacl', '--recursive', '--logical', '--modify', 'u:maddy:rX', f'/etc/letsencrypt/live/{MX_SUBDOMAIN}'],
            check=True,
        )
        print("Cert acl'ed")
    finally:
        # revert the temporary exposure even if the renewal failed
        try:
            if port_opened:
                zone = firewalld.removePort('', '80', 'tcp')
                print(f'Removed port 80/tcp from {zone}')
        finally:
            caddy_subconf_file.unlink()
            systemd.ReloadUnit('caddy.service', 'replace')
            print('Caddy reloaded')

    # live cert changed, get new hash, save links
    pem_data = (LIVE_CERT_PATH / 'cert.pem').read_bytes()
    new_cert = x509.load_pem_x509_certificate(pem_data)
    new_hash = get_cert_hash(new_cert)
    new_links = save_live_symlinks()
    return new_hash, new_links, systemd
def rotate_tlsa(current_cert_hash: str, new_cert_hash: str) -> None:
    """Publish the new TLSA record next to the current one, then drop the old.

    Stale TLSA records under RECORD_NAME are removed first. The new record is
    added, the function sleeps for three hours, and only then is the record
    of the previously used key removed.

    Args:
        current_cert_hash: Public key hash of the cert still being served.
        new_cert_hash: Public key hash of the freshly issued cert.
    """

    def normalized(content: str) -> str:
        # compare record contents regardless of hex case or spacing
        return ' '.join(str(content).lower().split())

    # records are written as '3 1 1 <hash>', so compare against that form
    current_content = normalized(f'3 1 1 {current_cert_hash}')

    # get tlsas from njalla, get current id, delete extraneous
    current_id = None
    for entry in list_tlsa()['records']:
        # never touch anything that is not this host's TLSA record
        if entry.get('type') != 'TLSA' or entry.get('name') != RECORD_NAME:
            continue
        if normalized(entry['content']) == current_content:
            current_id = entry['id']
        else:
            remove_tlsa(entry['id'])

    add_tlsa(f'3 1 1 {new_cert_hash}')

    # give DNS time to propagate
    hours_3 = 60 * 60 * 3
    time.sleep(hours_3)

    if current_id is not None:
        remove_tlsa(current_id)
    else:
        print('No previous entry to remove.')
# Script entry: rotate only when the live cert is within 8 days of expiry.
pem_data = (LIVE_CERT_PATH / 'cert.pem').read_bytes()
current_cert = x509.load_pem_x509_certificate(pem_data)
to_expiry = current_cert.not_valid_after_utc - datetime.datetime.now(datetime.UTC)
if to_expiry < datetime.timedelta(days=8):
    rotate_cert_and_tlsa(current_cert)
    print('TLSA rotation complete')
else:
    # to_expiry is a timedelta; print whole days so the sentence reads correctly
    print(f'Cert has {to_expiry.days} days remaining. Nothing to do.')