295 lines
9.5 KiB
Python
Executable File
295 lines
9.5 KiB
Python
Executable File
import base64
|
|
import os
|
|
from cryptography.fernet import Fernet
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import dh
|
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
from cryptography.hazmat.primitives.serialization import (
|
|
Encoding,
|
|
NoEncryption,
|
|
ParameterFormat,
|
|
PublicFormat,
|
|
PrivateFormat,
|
|
)
|
|
import cryptography.hazmat.primitives.serialization as Serialization
|
|
import msgpack
|
|
from Daisy.Store import Store
|
|
|
|
# TODO: Different store directories per node
|
|
|
|
|
|
class DHEFern:
    """
    Diffie-Hellman key exchange plus Fernet symmetric encryption, with
    parameters and keys persisted in Daisy stores.

    Attributes
    ----------
    cLog
        Method reference to `run.Node.cLog` so we can log to the ui from here

    stores: dict
        Daisy `Store` objects keyed by store name ("param" or "key")

    loadedParams: dict
        In memory (deserialized) Diffie-Hellman parameters, keyed by node ID

    loadedKeys: dict
        In memory (deserialized) key pairs, keyed by node ID

    nodeNickname: str
        Name of node for isolating configs when running multiple nodes

    cache: Components.daisy.Cache
        Daisy cache for use in storing cryptography information

    params
        This node's own Diffie-Hellman parameters

    publicKey
        Public key for node

    privateKey
        Private key for node
    """

    def __init__(self, cache, nodeNickname, cLog):
        """
        Parameters
        ----------
        cache: Components.daisy.Cache
            Reference to the node instances Daisy cache

        nodeNickname: str
            Node nickname for record storage

        cLog
            Reference to `run.Node.cLog`
        """
        self.cLog = cLog
        self.stores = {}
        self.loadedParams = {}
        self.loadedKeys = {}
        self.nodeNickname = nodeNickname
        self.cache = cache

        if not os.path.exists("daisy/cryptography/{0}/param".format(nodeNickname)):
            # First run for this node: generate parameters and persist them
            self.initStore("param")
        else:
            self.stores["param"] = Store("param", "cryptography", nodeNickname)
            self.params = self.loadParamBytes(self.stores["param"].get()["self"])
        self.cLog(20, "Param store initialized")

        if not os.path.exists("daisy/cryptography/{0}/key".format(nodeNickname)):
            self.cLog(20, "Key store DNE, initializing")
            self.initStore("key")
            self.genKeyPair()
        else:
            self.cLog(20, "Key store exists, loading")
            self.stores["key"] = Store("key", "cryptography", nodeNickname)
            self.cLog(20, "Store loaded")
            # Without this, self.publicKey / self.privateKey stay unset after a
            # restart even though genKeyPair() sets them on first run.
            self._loadSelfKeys()
        self.cLog(20, "Key store initialized")

    def _loadSelfKeys(self):
        """
        Deserialize this node's own key pair from the "self" record of the key
        store into self.publicKey / self.privateKey.

        If the record is missing or incomplete, generate and store a new pair
        instead.
        """
        try:
            record = self.stores["key"].get()["self"]
            publicPem = record["publicKey"]
            privatePem = record["privateKey"]
        except (KeyError, TypeError):
            # initStore("key") writes an empty "self" record before genKeyPair()
            # fills it in, so an interrupted first run can leave it empty.
            self.cLog(30, "Own key pair missing from key store, regenerating")
            self.genKeyPair()
            return
        self.publicKey = Serialization.load_pem_public_key(publicPem)
        self.privateKey = Serialization.load_pem_private_key(privatePem, None)

    def checkInMem(self, store: str, nodeID: str) -> bool:
        """
        Check if parameters or keys are loaded for node of nodeID

        Parameters
        ----------
        store: str
            Whether to check loaded keys ("key") or parameters ("param")

        nodeID: str
            Node whose in-memory entry to look for

        Returns
        -------
        bool
            True if loaded; False if not loaded or `store` is unrecognised
        """
        if store == "param":
            return nodeID in self.loadedParams
        if store == "key":
            return nodeID in self.loadedKeys
        return False

    def loadRecordToMem(self, store: str, nodeID: str):
        """
        Load record of nodeID from store into either loadedKeys or loadedParams

        Returns False if the record does not exist, True otherwise (including
        when it was already deserialized).
        """
        r = self.getRecord(store, nodeID)
        if r is False:
            self.cLog(
                30, "Tried to load nonexistent {0} for node {1}".format(store, nodeID)
            )
            return False
        if self.checkInMem(store, nodeID):
            self.cLog(10, "{0}s already deserialized, skipping".format(store))
        elif store == "param":
            # Deserialize directly rather than via loadParamBytes(), which would
            # overwrite this node's own self.params with the peer's parameters.
            self.loadedParams[nodeID] = Serialization.load_pem_parameters(r)
        elif store == "key":
            self.loadedKeys[nodeID] = {
                "publicKey": Serialization.load_pem_public_key(r["publicKey"]),
                "privateKey": Serialization.load_pem_private_key(
                    r["privateKey"], None
                ),
            }
        return True

    def getRecord(self, store: str, key: str):
        """
        Get record from store: store with key: key

        Returns the record, or False (after logging) if it does not exist.
        """
        r = self.stores[store].getRecord(key)
        if r is False:
            self.cLog(20, "Record does not exist")
            return False
        return r

    def updateStore(self, store: str, key: str, data, **updateKwargs):
        """
        Write data under key in store: store

        Extra keyword arguments (e.g. `recur`) are forwarded unchanged to the
        underlying `Store.update` call.
        """
        self.stores[store].update(key, data, **updateKwargs)

    def initStore(self, store: str):
        """
        Initialize store: store

        "param" generates fresh parameters and stores them under "self"; "key"
        stores an empty "self" record. Any other name is logged as undefined.
        """
        self.stores[store] = Store(store, "cryptography", self.nodeNickname)
        if store == "param":
            self.genParams()
            self.stores[store].update("self", self.getParamsBytes(), recur=False)
        elif store == "key":
            self.stores[store].update("self", {}, recur=False)
        else:
            self.cLog(30, "Store not defined")

    def genParams(self):
        """
        Generate Diffie Hellman parameters (generator 2, 2048-bit) and set
        them as self.params
        """
        params = dh.generate_parameters(generator=2, key_size=2048)
        self.params = params
        return params

    def getParamsBytes(self):
        """
        Get PEM (PKCS3) bytes encoded from self.params (TODO: Encode from store)
        """
        return self.params.parameter_bytes(Encoding.PEM, ParameterFormat.PKCS3)

    def loadParamBytes(self, pemBytes: bytes):
        """
        Load parameters to self.params from given PEM bytes (TODO: Load from store)
        """
        self.params = Serialization.load_pem_parameters(pemBytes)
        return self.params

    def genKeyPair(self, paramsOverride=False, setSelf: bool = True):
        """
        Generate a public/private key pair

        Parameters
        ----------
        paramsOverride
            False (or None) to generate from self.params; otherwise the
            Diffie-Hellman parameters object to generate from

        setSelf: bool
            Whether to set self.privateKey and self.publicKey and persist them
            as the "self" record of the key store

        Returns
        -------
        list
            [privateKey, publicKey] as key objects when setSelf is True,
            otherwise as PEM bytes (PKCS8 private, SubjectPublicKeyInfo public)
        """
        if paramsOverride is False or paramsOverride is None:
            params = self.params
        else:
            params = paramsOverride
        privateKey = params.generate_private_key()
        publicKey = privateKey.public_key()
        publicPem = publicKey.public_bytes(
            Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
        )
        privatePem = privateKey.private_bytes(
            Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
        )
        if setSelf:
            self.privateKey = privateKey
            self.publicKey = publicKey
            self.updateStore(
                "key", "self", {"publicKey": publicPem, "privateKey": privatePem}
            )
            return [privateKey, publicKey]
        return [privatePem, publicPem]

    def keyDerive(self, pubKey: bytes, salt: bytes, nodeID: str, params: bytes):
        """
        Derive shared key using Diffie Hellman

        Parameters
        ----------
        pubKey: bytes
            Peer's PEM-encoded public key

        salt: bytes
            Salt for the HKDF step

        nodeID: str
            PierMesh node ID

        params: bytes
            Peer's PEM-encoded parameters; only stored if no parameter record
            for nodeID exists yet

        Returns
        -------
        bytes
            The derived key, url-safe base64 encoded (also saved as
            "derivedKey" in the node's key record)
        """
        if not self.checkInMem("param", nodeID):
            if self.getRecord("param", nodeID) is False:
                self.updateStore("param", nodeID, params, recur=False)
            self.loadRecordToMem("param", nodeID)
        self.cLog(20, "Precheck done for key derivation")

        # Load our key pair for this peer, generating one on first contact
        if not self.checkInMem("key", nodeID):
            if self.getRecord("key", nodeID) is False:
                # DH exchange needs both keys to share the same parameters, so
                # generate from the peer's parameters rather than self.params.
                privateKey, publicKey = self.genKeyPair(
                    paramsOverride=self.loadedParams.get(nodeID, False),
                    setSelf=False,
                )
                self.updateStore(
                    "key", nodeID, {"publicKey": publicKey, "privateKey": privateKey}
                )
            self.loadRecordToMem("key", nodeID)

        sharedKey = self.loadedKeys[nodeID]["privateKey"].exchange(
            Serialization.load_pem_public_key(pubKey)
        )
        # Perform key derivation.
        self.cLog(20, "Performing key derivation")
        derivedKey = HKDF(
            algorithm=hashes.SHA256(), length=32, salt=salt, info=b"handshake data"
        ).derive(sharedKey)
        self.cLog(20, "Derived key")
        # Fernet keys must be url-safe base64 encoded 32-byte keys
        ederivedKey = base64.urlsafe_b64encode(derivedKey)
        tr = self.getRecord("key", nodeID)
        tr["derivedKey"] = ederivedKey
        self.updateStore("key", nodeID, tr)
        self.cLog(20, "Done with cryptography store updates")
        return ederivedKey

    def getSalt(self):
        """
        Get 16 random bytes from os.urandom for use as a salt
        """
        return os.urandom(16)

    def encrypt(self, data, nodeID: str, isDict: bool = True):
        """
        Do Fernet encryption with the key derived for nodeID

        data
            Either bytes or dict to encrypt

        isDict: bool
            Whether data is a dictionary (msgpack-encoded before encryption)

        Returns the Fernet token, or False if nodeID has no record or no
        derived key.
        """
        r = self.getRecord("key", nodeID)
        if r is False:
            self.cLog(20, "Node {0} not in keystore".format(nodeID))
            return False
        if "derivedKey" not in r:
            self.cLog(20, "No key derived for node {0}".format(nodeID))
            return False
        fernet = Fernet(r["derivedKey"])
        if isDict:
            data = msgpack.dumps(data)
        return fernet.encrypt(data)

    def decrypt(self, data, nodeID: str, isDict: bool = True):
        """
        Decrypt a Fernet token with the key derived for nodeID

        data
            Fernet token to decrypt

        isDict: bool
            Whether to msgpack-decode the plaintext (default True, matching
            the previous behaviour); when False the raw bytes are returned

        Returns the decoded data, or False if nodeID has no record or no
        derived key.
        """
        r = self.getRecord("key", nodeID)
        if r is False:
            self.cLog(20, "No record of node {0}".format(nodeID))
            return False
        if "derivedKey" not in r:
            self.cLog(20, "No key derived for node {0}".format(nodeID))
            return False
        plaintext = Fernet(r["derivedKey"]).decrypt(data)
        return msgpack.loads(plaintext) if isDict else plaintext