#! /usr/bin/env python
#      -*- OpenSAF  -*-
#
# (C) Copyright 2009 The OpenSAF Foundation
# (C) Copyright Ericsson AB 2017. All rights reserved.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. This file and program are licensed
# under the GNU Lesser General Public License Version 2.1, February 1999.
# The complete license can be accessed from the following location:
# http://opensource.org/licenses/lgpl-license.php
# See the Copying file included with the OpenSAF distribution for full
# licensing terms.
#
# Author(s): Ericsson AB
#
# pylint: disable=invalid-name,unused-argument
""" immxml-validate tool """
from __future__ import print_function
import os
import sys
import getopt
import xml.dom.minidom
from baseimm import BaseOptions, BaseImmDocument, trace, retrieve_file_names, \
    print_info_stderr, abort_script, verify_input_file_read_access


class Options(BaseOptions):
    """ Settings specific to the immxml-validate tool """
    # Value given with --schema (None when the option is not used)
    schemaFilename = None
    # Checked together with schemaFilename before schema validation is run
    isXmlLintFound = True
    # Set to True by --ignore-attribute-refs
    ignoreAttributeRefs = False
    # Set to True by --ignore-rdn-association-refs
    ignoreRdnAssociationRefs = False

    @staticmethod
    def print_option_settings():
        """ Return the base option settings text followed by a line showing
        the schema file name """
        base_settings = BaseOptions.print_option_settings()
        schema_setting = ' schemaFilename: %s\n' % Options.schemaFilename
        return base_settings + schema_setting


class AbortFileException(Exception):
    """ Signals that processing of the current input file must stop

    The reason is kept in the 'value' attribute; str() of the exception
    gives the repr() of that value.
    """
    def __init__(self, value):
        super(AbortFileException, self).__init__()
        self.value = value

    def __str__(self):
        text = repr(self.value)
        return text


class ImmDocument(BaseImmDocument):
    """ This class contains methods to validate an IMM XML file """
    def __init__(self):
        """ Create an empty document: no classes, no objects, not failed """
        # NOTE(review): the base class constructor is not invoked here --
        # confirm that BaseImmDocument needs no initialisation of its own.
        self.imm_content_element = None
        # Classes in input order, plus lookup by class name
        self.classList = list()
        self.classDict = dict()
        # Objects in input order, plus lookup of the dn element by dn name
        self.objectList = list()
        self.objectDnNameDict = dict()
        # Turned on by validate_failed()
        self.validateFailed = False

    @staticmethod
    def initialize():
        """ Initialize the validation process

        At present this only emits a trace message.
        """
        trace("initialize()")

    @staticmethod
    def check_class_clash_acceptable(class_element, class_name):
        """ Placeholder for deciding about a duplicate class definition. TBD

        Nothing is inspected yet; the result is always None.
        """
        return None

    def add_class(self, class_element, other_nodes):
        """ Register a class definition for later validation

        class_element -- the "class" DOM element
        other_nodes   -- nodes collected ahead of the class element; they
                         are stored next to it in classList

        The rdn type and the name/type pair of every attribute are recorded
        in classDict under the class name.
        Raises Exception if an attribute name occurs twice in the class.
        """
        class_name = class_element.getAttribute("name")
        trace("className:%s nodeType:%d nodeName:%s", class_name,
              class_element.nodeType, class_element.nodeName)

        if class_name in self.classDict:
            # Already registered: the clash check decides whether this is
            # reported; the first definition is kept either way.
            if self.check_class_clash_acceptable(class_element, class_name):
                self.validate_failed("Found duplicate class: %s", class_name)
            return

        seen_names = set()
        attr_types = {}
        rdn_type = None
        for member in class_element.childNodes:
            if member.nodeName == "rdn":
                for part in member.childNodes:
                    if part.nodeName == "type":
                        rdn_type = part.firstChild.nodeValue
            elif member.nodeName == "attr":
                found_name = None
                found_type = None
                for part in member.childNodes:
                    if part.nodeName == "name":
                        found_name = part.firstChild.nodeValue
                        if found_name in seen_names:
                            raise Exception("Attribute '%s' is defined "
                                            "more than once in class '%s'"
                                            % (found_name, class_name))
                        seen_names.add(found_name)
                    elif part.nodeName == "type":
                        found_type = part.firstChild.nodeValue

                    # Stop scanning this attr once both parts are known
                    if found_name is not None and found_type is not None:
                        attr_types[found_name] = found_type
                        break

        self.classDict[class_name] = {"attributes": attr_types,
                                      "rdn": rdn_type}
        # Keep the input order as well
        self.classList.append((class_name, (other_nodes, class_element)))

    def add_object(self, object_element, other_nodes):
        """ Add object to the list of objects to validate

        object_element -- the "object" DOM element
        other_nodes    -- nodes collected ahead of the object element; they
                          are stored together with it in objectList

        An object without a class name or without a usable "dn" child is
        reported through validate_failed() and then skipped. An unknown
        class or a duplicate dn is reported as well, but the object is
        still stored.
        Raises Exception if an attribute name occurs twice in the object.
        """
        class_name = object_element.getAttribute("class")
        trace("DUMPING Object with className: %s\n %s", class_name,
              object_element.toxml())

        seen_attr_names = set()
        object_dn_element = None
        # Find the first "dn" child node and check the attribute names
        for child_node in object_element.childNodes:
            if child_node.nodeName == "dn" and object_dn_element is None:
                object_dn_element = child_node
            elif child_node.nodeName == "attr":
                for attr_child in child_node.childNodes:
                    if attr_child.nodeName != "name":
                        continue
                    attr_name = attr_child.firstChild.nodeValue
                    if attr_name in seen_attr_names:
                        raise Exception("Attribute '%s' is defined more "
                                        "than once in object" % attr_name)
                    seen_attr_names.add(attr_name)

        # getAttribute() gives an empty string (never None) when the "class"
        # attribute is absent, hence the truthiness test on class_name.
        if object_dn_element is None or not class_name \
                or object_dn_element.firstChild is None:
            self.validate_failed("Failed to find class name or the dn child "
                                 "node in object: %s",
                                 object_element.toxml())
            # validate_failed() only sets the failure flag; without a dn
            # there is nothing to store, so stop instead of dereferencing
            # a missing node below.
            return

        # NOTE: dn name should be the nodeValue of object_dn_name
        # However minidom probably gets confused by the equal sign in dn value
        # so it ends up adding a text node as a child.
        object_dn_name = object_dn_element.firstChild.nodeValue
        trace("objectDn key: %s", object_dn_name)
        trace("objectDnElement: %s", object_dn_element.toxml())

        # class_name must exist in dictionary map
        if class_name not in self.classDict:
            self.validate_failed("Failed to find class '%s' in %s",
                                 class_name, object_element.toxml())

        if object_dn_name in self.objectDnNameDict:
            self.validate_failed("Found duplicate object: %s",
                                 object_element.toxml())

        self.validate_and_store_dn(object_dn_name, object_dn_element,
                                   class_name)

        # Add the complete object to the ordered list, keyed by its dn
        self.objectList.append((object_dn_name, (other_nodes, object_element)))

    def validate_and_store_dn(self, object_dn_name, object_dn_element,
                              class_name):
        """ Check the dn length and record the dn element under its dn name

        A dn longer than 256 characters is reported through
        validate_failed(); the dn is stored in objectDnNameDict regardless.
        """
        dn_len = len(object_dn_name)
        if dn_len > 256:
            self.validate_failed("Length of dn is %d (max 256): %s", dn_len,
                                 object_dn_name)

        # Position of the first comma not preceded by a backslash, or -1.
        # A comma at position 0 or 1 is never taken as a separator.
        comma_index = next(
            (pos for pos in range(2, dn_len)
             if object_dn_name[pos] == ','
             and object_dn_name[pos - 1] != '\\'),
            -1)

        if comma_index != -1:
            object_owner_part = object_dn_name[comma_index + 1:]
            object_id_part = object_dn_name[:comma_index]
            trace("ObjectDN: %s objectOwner: %s objectIdPart: %s",
                  object_dn_name, object_owner_part, object_id_part)
        else:
            trace("Found root element (no unescaped commas): %s",
                  object_dn_name)

        # Every dn is stored, including SA_NAME_T ones, so that association
        # references can act as parents too. The ownership check is left to
        # the post-validation step because input files (for example from
        # immdump) need not list parents ahead of their children.
        self.objectDnNameDict[object_dn_name] = object_dn_element

    def validate_failed(self, *args):
        """ Report a validation failure on stderr and mark the document
        as failed

        args[0] is a printf-style format string; the remaining arguments
        are its values. The tool keeps running: only the validateFailed
        flag is set.
        """
        message_format = "\nValidation failed: " + args[0]
        message_values = tuple(args[1:])
        print(message_format % message_values, file=sys.stderr)
        self.validateFailed = True

    def abort_file(self, *args):
        """ Report a validation failure and give up on the current file

        The arguments are handed to validate_failed() unchanged.
        Raises AbortFileException in every case.
        """
        self.validate_failed(*args)
        abort_error = AbortFileException("Aborting current file!")
        raise abort_error

    def post_validate_object_list(self):
        """ Post-validate object list

        For every stored object: validate its dn with post_validate_dn(),
        report attributes that are absent from the class definition, and
        check each value of an SA_NAME_T attribute with
        post_validate_attribute_sa_name_t(). Values of other attributes,
        and of attributes unknown to the class, are not checked.
        Raises Exception if a "dn" element has no child node.
        """
        for tuple_node in self.objectList:
            object_element = tuple_node[1][1]
            class_name = object_element.getAttribute("class")
            attributes = self.classDict[class_name]["attributes"]

            object_dn_element = None
            # Find "dn" child node
            for child_node in object_element.childNodes:
                if child_node.nodeName == "dn":
                    object_dn_element = child_node
                    if object_dn_element.firstChild is None:
                        # This is really a workaround for minidom bug?:
                        # Assume dn name should be the nodeValue of
                        # object_dn_name but with minidom that is not true.
                        raise Exception("Cannot find child element of dn "
                                        "element (required by minidom)")
                    object_dn_name = object_dn_element.firstChild.nodeValue
                    trace("objectDnElement: %s", object_dn_element.toxml())
                    trace("objectDn key: %s", object_dn_name)
                    # Validate dn w.r.t ownership
                    self.post_validate_dn(object_dn_name, object_dn_element,
                                          class_name)
                elif child_node.nodeName == "attr":
                    name_of_attribute_to_validate = None
                    for element in child_node.childNodes:
                        if element.nodeName == "name":
                            attr_name = element.firstChild.nodeValue
                            if attr_name in attributes:
                                if attributes[attr_name] == "SA_NAME_T":
                                    name_of_attribute_to_validate = attr_name
                                else:
                                    # Attribute exists in classDict but type is
                                    # not further validated
                                    break
                            else:
                                self.post_validate_missing_attribute(
                                    attr_name, class_name, object_dn_element)
                        elif element.nodeName == "value":
                            if name_of_attribute_to_validate is None:
                                # No SA_NAME_T attribute was selected (the
                                # name is unknown to the class and already
                                # reported): checking the value as a dn
                                # reference would only add a misleading
                                # "attribute None" failure.
                                continue
                            node_value = element.firstChild.nodeValue
                            self.post_validate_attribute_sa_name_t(
                                name_of_attribute_to_validate, node_value,
                                class_name, object_dn_element)
                            # Multiple values allowed....no break

        trace("post_validate_object_list() complete!")

    def post_validate_attribute_sa_name_t(self, attribute_name,
                                          attribute_value, class_name,
                                          object_element):
        """ Check that the value of an SA_NAME_T attribute is a known dn

        Nothing is checked when Options.ignoreAttributeRefs is set.
        A value that is not a key of objectDnNameDict is reported through
        validate_failed(). object_element is only used (via toxml()) in
        that message.
        """
        if Options.ignoreAttributeRefs \
                or attribute_value in self.objectDnNameDict:
            return

        message = ("NOTE: The object with rdn '%s' referred in "
                   "attribute %s does not exist (The attribute "
                   "is element in object with class: %s dn: %s)")
        self.validate_failed(message, attribute_value, attribute_name,
                             class_name, object_element.toxml())

    def post_validate_missing_attribute(self, attribute_name, class_name,
                                        object_element):
        """ Report an object attribute that its class definition lacks

        object_element is only used (via toxml()) in the failure message.
        """
        message = ("NOTE: The attribute %s does not exist in class "
                   "definition (The attribute is element in object "
                   "with class: %s dn: %s)")
        element_xml = object_element.toxml()
        self.validate_failed(message, attribute_name, class_name,
                             element_xml)

    def post_validate_dn(self, object_dn_name, object_dn_element, class_name):
        """ Post-validate dn

        object_dn_name    -- the dn text of the object
        object_dn_element -- the "dn" DOM element (used in messages)
        class_name        -- class of the object; its rdn type is looked up
                             in classDict

        A dn without an unescaped comma is a root object and needs no
        further check. Otherwise the part after the first unescaped comma
        (the parent dn) must be a known dn. When the rdn type of the class
        is SA_NAME_T, the rdn value is checked too: a value containing "="
        must, with backslashes removed, be a known dn (unless
        Options.ignoreRdnAssociationRefs is set); any other value may be at
        most 64 characters long. Failures are reported through
        validate_failed().
        """
        comma_index = -1
        dn_len = len(object_dn_name)

        # Search for first unescaped comma (if any)
        # NOTE(review): a comma at index 0 or 1 is never accepted as a
        # separator -- confirm this is intended.
        for i in range(0, dn_len):
            if object_dn_name[i] == ',':
                if i > 1 and object_dn_name[i - 1] != '\\':
                    comma_index = i
                    break

        if comma_index == -1:
            trace("Found root element (no unescaped commas): %s",
                  object_dn_name)
            return

        object_owner_part = object_dn_name[comma_index + 1:]
        object_part = object_dn_name[:comma_index]
        trace("ObjectDN:%s objectOwner:%s objectPart:%s", object_dn_name,
              object_owner_part, object_part)

        # Owner should exist for both SA_NAME_T and SA_STRING_T
        if object_owner_part not in self.objectDnNameDict:
            print_info_stderr("Validate dn in %s", object_dn_element.toxml())
            self.validate_failed("Parent to %s is not found in %s",
                                 object_dn_name, object_dn_element.toxml())
        else:
            # Only claim success when the parent really was found
            trace("post_validate_dn() OK parentPart %s found in "
                  "objectDnNameDict", object_owner_part)

        # But in case dn is a SA_NAME_T also the objectIdPart should exist
        # in dictionary
        if self.classDict[class_name]["rdn"] != "SA_NAME_T":
            return

        # Find value of association (remove association name)
        equal_sign_index = object_part.find("=")
        object_value = object_part[equal_sign_index + 1:]

        # x=y vs x=y=z
        if object_value.find("=") != -1:
            if Options.ignoreRdnAssociationRefs:
                return
            # Remove escape characters
            unescaped_dn = object_value.replace('\\', '')
            if unescaped_dn not in self.objectDnNameDict:
                print_info_stderr("Validate dn in %s",
                                  object_dn_element.toxml())
                self.validate_failed("The associated object %s is not "
                                     "found in %s", unescaped_dn,
                                     object_dn_element.toxml())
            else:
                trace("post_validate_dn() OK  The associated object %s is "
                      "found in objectDnNameDict (dn has type SA_NAME_T)",
                      unescaped_dn)
        elif len(object_value) > 64:
            print_info_stderr("Validate dn in %s",
                              object_dn_element.toxml())
            # Report the length that was actually tested (the value, not
            # the whole "name=value" rdn)
            self.validate_failed("Length of object value is %d "
                                 "(max 64): %s", len(object_value),
                                 object_value)

    def process_input_file(self, file_name):
        """ Read one IMM XML file and register its classes and objects

        When xmllint is flagged as available and a schema file is
        configured, the file is first validated against the schema and
        abort_file() is called on a non-zero result; otherwise the file is
        only checked for being parsable. Always returns 0.
        """
        trace("")
        trace("process_input_file(): %s", file_name)

        schema_check_wanted = Options.isXmlLintFound \
            and Options.schemaFilename is not None
        if not schema_check_wanted:
            self.verify_input_xml_document_file_is_parsable(file_name)
        elif self.validate_xml_file_with_schema(
                file_name, Options.schemaFilename) != 0:
            self.abort_file("Failed to validate input file %s with xml "
                            "schema", file_name)

        doc = xml.dom.minidom.parse(file_name)
        element_handlers = {"class": self.add_class,
                            "object": self.add_object}

        # Only the imm:contents element(s) of the document are of interest
        for top_node in doc.childNodes:
            if top_node.nodeName != self.imm_content_element_name:
                continue

            # Nodes seen since the last class/object (probably text nodes);
            # they are handed over together with the next class/object
            pending_nodes = []
            for node in top_node.childNodes:
                trace("imm:contents loop.....NodeName:%s NodeValue:%s",
                      node.nodeName, node.nodeValue)
                handler = element_handlers.get(node.nodeName)
                if handler is None:
                    pending_nodes.append(node)
                else:
                    handler(node, pending_nodes)
                    pending_nodes = []

        return 0

    def post_process_validate(self):
        """ Post validation process

        Runs the checks that need the complete set of parsed objects by
        delegating to post_validate_object_list().
        """
        # Iterate over all objects again to validate again when all objects are
        # parsed from input files
        self.post_validate_object_list()

# End of ImmDocument class


def print_usage():
    """ Write the usage line and the option summary to stdout """
    usage_line = "usage: immxml-validate [options] filename[s]"
    option_summary = """
    --schema            validate input files with the supplied xsd schema file

    -t, --trace         print trace information to stderr

    --ignore-attribute-refs
                        if this option is specified, the tool skips to validate
                        that SA_NAME_T attributes references existing objects

    --ignore-rdn-association-refs
                        if this option is specified, the tool skips to validate
                        that SA_NAME_T rdn association references existing
                        objects

    -v, --version       print version information and exit

    -h, --help          display this help and exit

   See http://devel.opensaf.org/ for information and updates.
   """
    print(usage_line, option_summary, sep="\n")


def print_version():
    """ Write the tool name and version number to stdout """
    version_text = "immxml-validate version 0.5.1"
    print(version_text)


def main(argv):
    """ Parse the command line and validate every given input file

    argv -- the command-line arguments without the program name

    Returns 0 when every file validated. Calls sys.exit(2) on an option
    error, when no file is given, or when at least one file failed, and
    sys.exit(0) after printing the version or the help text.
    """
    short_options = "thv"
    long_options = ["trace", "help", "version", "schema=",
                    "ignore-attribute-refs", "ignore-rdn-association-refs"]
    try:
        opts, args = getopt.getopt(argv, short_options, long_options)
    except getopt.GetoptError as err:
        # Unknown option or missing option argument
        print_info_stderr("%s", str(err))
        print_usage()
        sys.exit(2)

    for opt, value in opts:
        if opt in ("-t", "--trace"):
            BaseOptions.traceOn = True
        elif opt == "--schema":
            Options.schemaFilename = value
        elif opt == "--ignore-attribute-refs":
            Options.ignoreAttributeRefs = True
        elif opt == "--ignore-rdn-association-refs":
            Options.ignoreRdnAssociationRefs = True
        elif opt in ("-v", "--version"):
            print_version()
            sys.exit(0)
        elif opt in ("-h", "--help"):
            print_usage()
            sys.exit(0)

    # Tracing is only meaningful once -t has (or has not) been applied
    trace("opts: %s", opts)
    trace("args: %s", args)
    trace("sys.path: %s", sys.path)

    if not args:
        print_usage()
        sys.exit(2)

    trace("Option object:\n %s", Options.print_option_settings())

    # xmllint is only mandatory when a schema was requested
    if not os.path.exists('/usr/bin/xmllint') \
            and Options.schemaFilename is not None:
        abort_script("Cannot find the required linux command "
                     "/usr/bin/xmllint. '--schema' option requires "
                     "xmllint, Exiting!")

    file_list = retrieve_file_names(args)

    trace("Starting to process files...\n")
    at_least_one_file_failed = False
    for file_name in file_list:
        try:
            doc = ImmDocument()
            verify_input_file_read_access(file_name)
            doc.initialize()
            doc.process_input_file(file_name)
            # The second pass only makes sense for a clean first pass
            if not doc.validateFailed:
                doc.post_process_validate()
        except AbortFileException:
            doc.validateFailed = True

        if not doc.validateFailed:
            print_info_stderr("Validation succeeded for file: %s", file_name)
        else:
            at_least_one_file_failed = True
            print_info_stderr("Validation failed for file: %s", file_name)

        trace("Done with file: %s", file_name)

    if at_least_one_file_failed:
        sys.exit(2)
    return 0


# Run the tool only when this file is executed as a script; main() receives
# the command-line arguments without the program name.
if __name__ == "__main__":
    main(sys.argv[1:])
