piermesh/src/Cryptography/WhaleSong.py

295 lines
9.5 KiB
Python
Executable File

import base64
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
ParameterFormat,
PublicFormat,
PrivateFormat,
)
import cryptography.hazmat.primitives.serialization as Serialization
import msgpack
from Daisy.Store import Store
# TODO: Different store directories per node
class DHEFern:
    """
    Diffie Hellman key exchange and Fernet encryption backed by Daisy stores

    Attributes
    ----------
    cLog
        Method reference to `run.Node.cLog` so we can log to the ui from here

    stores: dict
        Daisy `Store` objects keyed by store name ("param" or "key")

    loadedParams: dict
        In memory representations of cryptography parameters, keyed by node ID

    loadedKeys: dict
        In memory representations of cryptography keys, keyed by node ID

    nodeNickname: str
        Name of node for isolating configs when running multiple nodes

    cache: Components.daisy.Cache
        Daisy cache for use in storing cryptography information

    params
        Diffie Hellman parameters most recently generated or loaded

    publicKey
        Public key for node (only set by `genKeyPair` with `setSelf=True`)

    privateKey
        Private key for node (only set by `genKeyPair` with `setSelf=True`)
    """

    def __init__(self, cache, nodeNickname, cLog):
        """
        Parameters
        ----------
        cache: Components.daisy.Cache
            Reference to the node instances Daisy cache

        nodeNickname: str
            Node nickname for record storage

        cLog
            Reference to `run.Node.cLog`
        """
        self.cLog = cLog
        self.stores = {}
        self.loadedParams = {}
        self.loadedKeys = {}
        self.nodeNickname = nodeNickname
        self.cache = cache

        # Parameters: generate them on first run, otherwise load the stored PEM.
        if not os.path.exists("daisy/cryptography/{0}/param".format(nodeNickname)):
            self.initStore("param")
        else:
            self.stores["param"] = Store("param", "cryptography", nodeNickname)
            self.params = self.loadParamBytes(self.stores["param"].get()["self"])
        self.cLog(20, "Param store initialized")

        # Keys: generate a pair on first run, otherwise just open the store.
        if not os.path.exists("daisy/cryptography/{0}/key".format(nodeNickname)):
            self.cLog(20, "Key store DNE, initializing")
            self.initStore("key")
            self.genKeyPair()
        else:
            self.cLog(20, "Key store exists, loading")
            self.stores["key"] = Store("key", "cryptography", nodeNickname)
            self.cLog(20, "Store loaded")
            # NOTE: on this path self.publicKey / self.privateKey are NOT set;
            # the stored "self" key record is not deserialized here.
            # TODO: load the node's own key pair from the store
        self.cLog(20, "Key store initialized")

    def checkInMem(self, store: str, nodeID: str) -> bool:
        """
        Check if parameters or keys are loaded for node of nodeID

        Parameters
        ----------
        store: str
            Whether to check loaded keys ("key") or parameters ("param")

        nodeID: str
            Node ID to look up

        Returns
        -------
        bool
            True if an in memory entry exists, False otherwise (including
            when store is neither "param" nor "key")
        """
        if store == "param":
            return nodeID in self.loadedParams
        if store == "key":
            return nodeID in self.loadedKeys
        return False

    def loadRecordToMem(self, store: str, nodeID: str) -> bool:
        """
        Load record of nodeID from store to either keys or parameters

        Returns
        -------
        bool
            False if the store holds no record for nodeID, True otherwise
        """
        record = self.getRecord(store, nodeID)
        if record is False:
            self.cLog(
                30, "Tried to load nonexistent {0} for node {1}".format(store, nodeID)
            )
            return False

        if self.checkInMem(store, nodeID):
            self.cLog(10, "{0}s already deserialized, skipping".format(store))
        elif store == "param":
            self.loadedParams[nodeID] = self.loadParamBytes(record)
        elif store == "key":
            self.loadedKeys[nodeID] = {
                "publicKey": Serialization.load_pem_public_key(record["publicKey"]),
                "privateKey": Serialization.load_pem_private_key(
                    record["privateKey"], None
                ),
            }
        return True

    def getRecord(self, store: str, key: str):
        """
        Get record from store: store with key: key

        Returns
        -------
        The stored record, or False (after logging) if it does not exist
        """
        # The store lives on the instance; a bare `stores` name is undefined.
        record = self.stores[store].getRecord(key)
        if record is False:
            self.cLog(20, "Record does not exist")
            return False
        return record

    def updateStore(self, store: str, entry: str, data, recur=None):
        """
        Write data under entry in store: store

        Thin wrapper over the store's `update`, mirroring the direct calls
        made in `initStore`.

        Parameters
        ----------
        store: str
            Name of the store to write to ("param" or "key")

        entry: str
            Record key inside the store

        data
            Value to write

        recur
            Forwarded as `recur=` when given; when left as None the store's
            own default applies
            (NOTE(review): confirm `Store.update` has a default for `recur`)
        """
        if recur is None:
            self.stores[store].update(entry, data)
        else:
            self.stores[store].update(entry, data, recur=recur)

    def initStore(self, store: str):
        """
        Initialize store: store

        For "param" this generates fresh parameters and saves them under
        "self"; for "key" it saves an empty "self" record. Any other name
        only creates the store and logs a warning.
        """
        self.stores[store] = Store(store, "cryptography", self.nodeNickname)
        if store == "param":
            self.genParams()
            self.stores[store].update("self", self.getParamsBytes(), recur=False)
        elif store == "key":
            self.stores[store].update("self", {}, recur=False)
        else:
            self.cLog(30, "Store not defined")

    def genParams(self):
        """
        Generate Diffie Hellman parameters (generator 2, 2048 bit) and keep
        them in self.params

        Returns
        -------
        The generated parameters
        """
        self.params = dh.generate_parameters(generator=2, key_size=2048)
        return self.params

    def getParamsBytes(self):
        """
        Get PEM / PKCS3 bytes encoded from self.params (TODO: Encode from store)
        """
        return self.params.parameter_bytes(Encoding.PEM, ParameterFormat.PKCS3)

    def loadParamBytes(self, pemBytes: bytes):
        """
        Load parameters to self.params from given bytes (TODO: Load from store)

        Note that `loadRecordToMem` calls this for other nodes' parameters
        too, so self.params always holds the most recently loaded set.

        Returns
        -------
        The loaded parameters
        """
        self.params = Serialization.load_pem_parameters(pemBytes)
        return self.params

    def genKeyPair(self, paramsOverride=False, setSelf: bool = True):
        """
        Generate public and private keys from self.params (TODO: Gen from passed params)

        Parameters
        ----------
        paramsOverride
            False or parameters to use (TODO, currently unused)

        setSelf: bool
            Whether to set self.privateKey and self.publicKey

        Returns
        -------
        list
            `[privateKey, publicKey]`: key objects when setSelf is True,
            PEM encoded bytes when setSelf is False
        """
        privateKey = self.params.generate_private_key()
        publicKey = privateKey.public_key()

        # NOTE: private keys are serialized without encryption (NoEncryption).
        if setSelf:
            self.privateKey = privateKey
            self.publicKey = publicKey
            self.updateStore(
                "key",
                "self",
                {
                    "publicKey": publicKey.public_bytes(
                        Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
                    ),
                    "privateKey": privateKey.private_bytes(
                        Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
                    ),
                },
            )
            return [privateKey, publicKey]

        publicBytes = publicKey.public_bytes(
            Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
        )
        privateBytes = privateKey.private_bytes(
            Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
        )
        return [privateBytes, publicBytes]

    def keyDerive(self, pubKey: bytes, salt: bytes, nodeID: str, params: bytes):
        """
        Derive shared key using Diffie Hellman

        Parameters
        ----------
        pubKey: bytes
            PEM encoded public key of the other party

        salt: bytes
            Salt passed to HKDF

        nodeID: str
            PierMesh node ID

        params: bytes
            Encryption parameters, stored for nodeID if none are on record

        Returns
        -------
        bytes
            URL safe base64 encoding of the 32 byte derived key; also saved
            as "derivedKey" in the node's key record
        """
        # Make sure the peer's parameters are stored and deserialized. Loading
        # them also replaces self.params, which genKeyPair below relies on.
        if not self.checkInMem("param", nodeID):
            if self.getRecord("param", nodeID) is False:
                self.updateStore("param", nodeID, params, recur=False)
            self.loadRecordToMem("param", nodeID)
        self.cLog(20, "Precheck done for key derivation")

        # TODO: Load them and if private key exists load it, otherwise generate a private key
        if not self.checkInMem("key", nodeID):
            if self.getRecord("key", nodeID) is False:
                privateKey, publicKey = self.genKeyPair(setSelf=False)
                self.updateStore(
                    "key", nodeID, {"publicKey": publicKey, "privateKey": privateKey}
                )
            self.loadRecordToMem("key", nodeID)

        sharedKey = self.loadedKeys[nodeID]["privateKey"].exchange(
            Serialization.load_pem_public_key(pubKey)
        )

        # Perform key derivation.
        self.cLog(20, "Performing key derivation")
        derivedKey = HKDF(
            algorithm=hashes.SHA256(), length=32, salt=salt, info=b"handshake data"
        ).derive(sharedKey)
        self.cLog(20, "Derived key")

        # Fernet takes the key as URL safe base64, so store it in that form.
        ederivedKey = base64.urlsafe_b64encode(derivedKey)
        record = self.getRecord("key", nodeID)
        record["derivedKey"] = ederivedKey
        self.updateStore("key", nodeID, record)
        self.cLog(20, "Done with cryptography store updates")
        return ederivedKey

    def getSalt(self) -> bytes:
        """
        Get random salt (16 bytes from os.urandom)
        """
        return os.urandom(16)

    def encrypt(self, data, nodeID: str, isDict: bool = True):
        """
        Do Fernet encryption

        Parameters
        ----------
        data
            Either bytes or dict to encrypt

        nodeID: str
            Node whose derived key is used

        isDict: bool
            Whether data is a dictionary (msgpack encoded before encryption)

        Returns
        -------
        The Fernet token, or False if nodeID has no key record or no
        derived key yet
        """
        record = self.getRecord("key", nodeID)
        if record is False:
            self.cLog(20, "Node {0} not in keystore".format(nodeID))
            return False
        if "derivedKey" not in record:
            # Same guard as decrypt: a record can exist before keyDerive ran.
            self.cLog(20, "No key derived for node " + nodeID)
            return False

        fernet = Fernet(record["derivedKey"])
        if isDict:
            data = msgpack.dumps(data)
        return fernet.encrypt(data)

    def decrypt(self, data, nodeID: str):
        """
        Decrypt bytes and msgpack load the result (TODO: Check whether to msgpack load)

        Returns
        -------
        The msgpack decoded plaintext, or False if nodeID has no key record
        or no derived key yet
        """
        record = self.getRecord("key", nodeID)
        if record is False:
            self.cLog(20, "No record of node " + nodeID)
            return False
        if "derivedKey" not in record:
            self.cLog(20, "No key derived for node " + nodeID)
            return False

        fernet = Fernet(record["derivedKey"])
        return msgpack.loads(fernet.decrypt(data))