from build_tools import *
import osconfig
Import('OS_ROOT')
import os
from string import Template

# 1. Template of the generated C source file.
#    ${CERT_CONTENT} is the only placeholder; it is filled in further down with
#    one C string literal per certificate line. The literals follow the empty
#    "" that opens the initializer of mbedtls_root_certificate.
cert_template = """/**
 ***********************************************************************************************************************
 * Copyright (c) 2020, China Mobile Communications Group Co.,Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with 
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on 
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the 
 * specific language governing permissions and limitations under the License.
 *
 * @file        tls_certificate.c
 *
 * @brief      a ca file auto generate by SConscript
 *
 * @details     
 *
 * @revision
 * Date               Author             Notes
 * 2020-06-10   OneOS Team    first version
 ***********************************************************************************************************************
 */

#include "mbedtls/certs.h"

const char mbedtls_root_certificate[] = "" \\
${CERT_CONTENT}
;

const size_t mbedtls_root_certificate_len = sizeof(mbedtls_root_certificate);

"""

# 2. Wrap the C template so that ${CERT_CONTENT} can be substituted later
cert_subs = Template(cert_template)

# 3. Directory of this build script, as reported by PresentDir()
pwd = PresentDir()

# 4. Directories searched for PEM certificate files (*.pem or *.cer)
certs_user_dir = os.sep.join([pwd, 'certs'])
certs_default_dir = os.sep.join([pwd, 'certs', 'default'])

# Paths of the certificate files selected for embedding; filled in below
ROOT_CA_FILE = []

# 5. Generated C file that receives the contents of the certificate files
output_cert_file = os.sep.join([pwd, 'ports', 'src', 'tls_certificate.c'])

def _list_cert_files(directory):
    """Return the paths of the regular files directly inside *directory*.

    The result is sorted by file name so the order does not depend on the
    file system. Sub-directories are ignored. A missing directory yields an
    empty list instead of an OSError, so enabling an option whose directory
    is absent does not abort the build.
    """
    if not os.path.isdir(directory):
        return []
    candidates = [os.path.join(directory, name) for name in sorted(os.listdir(directory))]
    return [candidate for candidate in candidates if os.path.isfile(candidate)]


# Embed every certificate shipped in the default directory
if IsDefined(['SECURITY_USING_MBEDTLS_USE_ALL_CERTS']):
    ROOT_CA_FILE += _list_cert_files(certs_default_dir)

# Embed the certificates the user dropped into the certs directory
if IsDefined(['SECURITY_USING_MBEDTLS_USER_CERTS']):
    ROOT_CA_FILE += _list_cert_files(certs_user_dir)


# Maps each option name tested with IsDefined() to the certificate file name
# looked up in the default certificate directory.
# NOTE(review): spellings such as PBULIC, VERSIGN, COMODOR and CLOBALSIGN look
# like typos but presumably have to match the Kconfig symbols and the file
# names on disk -- confirm before renaming anything.
KCONFIG_ROOT_CA_DICT = {
    'SECURITY_USING_MBEDTLS_THAWTE_ROOT_CA': 'THAWTE_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_VERSIGN_PBULIC_ROOT_CA': 'VERSIGN_PUBLIC_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_VERSIGN_UNIVERSAL_ROOT_CA': 'VERSIGN_UNIVERSAL_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_GEOTRUST_ROOT_CA': 'GEOTRUST_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_DIGICERT_ROOT_CA': 'DIGICERT_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_GODADDY_ROOT_CA': 'GODADDY_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_COMODOR_ROOT_CA': 'COMODOR_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_DST_ROOT_CA': 'DIGITAL_SIGNATURE_TRUST_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_CLOBALSIGN_ROOT_CA': 'CLOBALSIGN_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_ENTRUST_ROOT_CA': 'ENTRUST_ROOT_CA.cer',
    'SECURITY_USING_MBEDTLS_AMAZON_ROOT_CA': 'AMAZON_ROOT_CA.cer',
}

# Add each individually selected root CA from the default certificate directory
for key, value in KCONFIG_ROOT_CA_DICT.items():
    if not IsDefined([key]):
        continue
    path = os.path.join(certs_default_dir, value)
    if os.path.isfile(path):
        ROOT_CA_FILE.append(path)
    else:
        # A selected certificate that is missing on disk would otherwise vanish
        # silently from the generated file
        print("[mbedtls] Warning: ", path, "is selected but not found! Skipped!")

# Drop duplicates. Sorting keeps the certificate order -- and therefore the
# generated C file -- identical from one build to the next; iterating a plain
# set of strings may give a different order on every run.
ROOT_CA_FILE = sorted(set(ROOT_CA_FILE))

file_content = ""

# 6. Turn every selected certificate file into a run of C string literals
cert_literals = []
for ca_path in ROOT_CA_FILE:
    if not os.path.isfile(ca_path):
        continue
    try:
        with open(ca_path, 'r') as ca:
            ca_lines = ca.readlines()
    except UnicodeDecodeError:
        # A binary (e.g. DER-encoded) file cannot be decoded as text; treat it
        # like any other non-PEM file instead of aborting the build
        ca_lines = []
    # Only files whose first line opens a PEM certificate are accepted
    if not ca_lines or not ca_lines[0].startswith("-----BEGIN CERTIFICATE"):
        print("[mbedtls] Warning: ", ca_path, "is not CA file! Skipped!")
        continue
    for line in ca_lines:
        # Escape backslashes and double quotes so that free-form text inside
        # the file cannot break out of the C string literal
        escaped = line.strip().replace('\\', '\\\\').replace('"', '\\"')
        # Emits:  "<line>\r\n" \   (the trailing backslash continues the C line)
        cert_literals.append('"' + escaped + '\\r\\n" \\\n')

file_content = "".join(cert_literals)

# 7. Fill the ${CERT_CONTENT} placeholder with the generated string literals
cert_content = cert_subs.substitute({'CERT_CONTENT': file_content})

# 8. Store the finished C source in tls_certificate.c
with open(output_cert_file, 'w') as out_file:
    out_file.write(cert_content)


# mbedtls library sources, with net_sockets.c taken out of the list
src = Glob('mbedtls/library/*.c')
DeleteSrcFile(src, 'mbedtls/library/net_sockets.c')

# Port layer sources (this includes the generated tls_certificate.c)
src += Glob('ports/src/*.c')

# Sample code is compiled only when the example option is enabled
if IsDefined(['SECURITY_USING_MBEDTLS_EXAMPLE']):
    src += Glob('samples/*.c')

# Include directories handed to the compiler for this component
CPPPATH = [
    pwd + sub_dir
    for sub_dir in ('/mbedtls/include', '/mbedtls/include/mbedtls', '/ports/inc')
]

# No extra preprocessor defines are needed for any toolchain
CPPDEFINES = []

# For the supported toolchains, install ports/inc/tls_config.h as
# mbedtls/include/mbedtls/config.h
if osconfig.CROSS_TOOL in ('gcc', 'keil', 'iar'):
    import shutil
    import platform
    import stat

    # Whichever of '\' and '/' is not the native separator gets rewritten to os.sep
    foreign_sep = r'\/'.replace(os.sep, '')
    cp_src = (pwd + '/ports/inc/tls_config.h').replace(foreign_sep, os.sep)
    cp_dst = (pwd + '/mbedtls/include/mbedtls/config.h').replace(foreign_sep, os.sep)
    try:
        shutil.copyfile(cp_src, cp_dst)
    except (IOError, OSError):
        if platform.system() == "Windows" and os.path.exists(cp_dst):
            # Remove the stale destination directly instead of shelling out to
            # "del": no quoting problems with spaces in the path, and clearing
            # the read-only flag first lets the removal succeed on such files
            try:
                os.chmod(cp_dst, stat.S_IWRITE)
                os.remove(cp_dst)
            except OSError:
                # Reported just below if the file is still present
                pass
            if os.path.exists(cp_dst):
                print("remove mbedtls/config.h first")
        # Retry once; a second failure propagates and stops the build
        shutil.copyfile(cp_src, cp_dst)

# Register the collected sources, include paths and defines as the 'mbedtls'
# code group, which depends on the SECURITY_USING_MBEDTLS option
group = AddCodeGroup(
    'mbedtls',
    src,
    depend=['SECURITY_USING_MBEDTLS'],
    CPPPATH=CPPPATH,
    CPPDEFINES=CPPDEFINES,
)

Return('group')
