From d02a9b251e9411658b70ffc83de8ef0b11b757d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20H=C3=B6rstrup?= Date: Wed, 6 Nov 2024 11:38:13 +0100 Subject: [PATCH] Implement plugin system - Add PFDLBaseClasses which is an object that holds the actual classes instantiated by the PFDL Scheduler (can be overwritten by plugins) - Add plugin loader class that can load plugins and return a new PFDLBaseClasses object - Script for merging multiple grammars into a comprehensive grammar - Adaptions to the existing code base so that the new PFDLBaseClasses are used - Adaptions to the unit tests and new ones for the plugin classes --- pfdl_scheduler/model/array.py | 2 +- pfdl_scheduler/model/instance.py | 2 +- pfdl_scheduler/model/struct.py | 7 +- pfdl_scheduler/parser/pfdl_tree_visitor.py | 66 +++-- pfdl_scheduler/petri_net/generator.py | 29 ++- pfdl_scheduler/pfdl_base_classes.py | 92 +++++++ pfdl_scheduler/plugins/README.md | 47 ++++ pfdl_scheduler/plugins/__init__.py | 0 pfdl_scheduler/plugins/grammar_merge.py | 197 +++++++++++++++ pfdl_scheduler/plugins/parser/__init__.py | 0 pfdl_scheduler/plugins/plugin_loader.py | 235 ++++++++++++++++++ pfdl_scheduler/scheduler.py | 66 +++-- pfdl_scheduler/scheduling/event.py | 9 +- pfdl_scheduler/utils/parsing_utils.py | 34 +-- .../validation/semantic_error_checker.py | 58 +++-- tests/unit_test/test_pfdl_base_classes.py | 94 +++++++ tests/unit_test/test_pfdl_tree_visitor.py | 39 ++- tests/unit_test/test_plugin_loader.py | 39 +++ .../unit_test/test_semantic_error_checker.py | 93 ++++++- 19 files changed, 1010 insertions(+), 99 deletions(-) create mode 100644 pfdl_scheduler/pfdl_base_classes.py create mode 100644 pfdl_scheduler/plugins/README.md create mode 100644 pfdl_scheduler/plugins/__init__.py create mode 100644 pfdl_scheduler/plugins/grammar_merge.py create mode 100644 pfdl_scheduler/plugins/parser/__init__.py create mode 100644 pfdl_scheduler/plugins/plugin_loader.py create mode 100644 tests/unit_test/test_pfdl_base_classes.py create mode 100644 tests/unit_test/test_plugin_loader.py diff --git a/pfdl_scheduler/model/array.py b/pfdl_scheduler/model/array.py index def6538..1f21557 100644 --- a/pfdl_scheduler/model/array.py +++ b/pfdl_scheduler/model/array.py @@ -64,7 +64,7 @@ def __radd__(self, other) -> str: return other + str(self) def __eq__(self, __o: object) -> bool: - if isinstance(__o, Array): + if hasattr(__o, "values") and hasattr(__o, "length") and hasattr(__o, "type_of_elements"): return ( self.values == __o.values and self.length == __o.length diff --git a/pfdl_scheduler/model/instance.py b/pfdl_scheduler/model/instance.py index 4a06288..c8bc44b 100644 --- a/pfdl_scheduler/model/instance.py +++ b/pfdl_scheduler/model/instance.py @@ -88,7 +88,7 @@ def parse_json( error_msg = "Array definition in JSON are not supported in the PFDL." 
error_handler.print_error(error_msg, context=instance_context) elif isinstance(value, dict): - inner_struct = parse_json(value, error_handler, instance_context) + inner_struct = parse_json(value, error_handler, instance_context, instance_class) instance.attributes[identifier] = inner_struct return instance diff --git a/pfdl_scheduler/model/struct.py b/pfdl_scheduler/model/struct.py index 7a93e1f..656b3b6 100644 --- a/pfdl_scheduler/model/struct.py +++ b/pfdl_scheduler/model/struct.py @@ -64,7 +64,12 @@ def __init__( self.context_dict: Dict = {} def __eq__(self, __o: object) -> bool: - if isinstance(__o, Struct): + if ( + hasattr(__o, "name") + and hasattr(__o, "attributes") + and hasattr(__o, "context") + and hasattr(__o, "context_dict") + ): return ( self.name == __o.name and self.attributes == __o.attributes diff --git a/pfdl_scheduler/parser/pfdl_tree_visitor.py b/pfdl_scheduler/parser/pfdl_tree_visitor.py index 0322091..c23919a 100644 --- a/pfdl_scheduler/parser/pfdl_tree_visitor.py +++ b/pfdl_scheduler/parser/pfdl_tree_visitor.py @@ -9,9 +9,13 @@ # standard libraries from typing import Dict, List, OrderedDict, Tuple, Union from pfdl_scheduler.model.instance import Instance +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses from pfdl_scheduler.utils import helpers from pfdl_scheduler.model.parallel import Parallel +# 3rd party +from antlr4.tree.Tree import TerminalNodeImpl + # local sources from pfdl_scheduler.validation.error_handler import ErrorHandler @@ -45,15 +49,22 @@ class PFDLTreeVisitor(PFDLParserVisitor): Attributes: error_handler: ErrorHandler instance for printing errors while visiting. current_task: Reference to the currently visited Task. Every visitor method can access it. + pfdl_base_classes: `PFDLBaseClasses` instance for creating new objects. """ - def __init__(self, error_handler: ErrorHandler) -> None: + def __init__( + self, + error_handler: ErrorHandler, + pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses(), + ) -> None: """Initialize the object. Args: - error_handler: ErrorHandler instance for printing errors while visiting. + error_handler: `ErrorHandler` instance for printing errors while visiting. + pfdl_base_classes: `PFDLBaseClasses` instance for creating new objects. 
""" self.error_handler: ErrorHandler = error_handler + self.pfdl_base_classes: PFDLBaseClasses = pfdl_base_classes self.current_task: Task = None def visitProgram(self, ctx) -> Process: @@ -64,7 +75,7 @@ def visitProgram(self, ctx) -> Process: for child in ctx.children: process_component = self.visit(child) - if isinstance(process_component, Struct): + if isinstance(process_component, self.pfdl_base_classes.get_class("Struct")): if process_component.name not in process.structs: process.structs[process_component.name] = process_component else: @@ -73,7 +84,7 @@ def visitProgram(self, ctx) -> Process: "is already defined" ) self.error_handler.print_error(error_msg, context=child) - elif isinstance(process_component, Task): + elif isinstance(process_component, self.pfdl_base_classes.get_class("Task")): if process_component.name not in process.tasks: process.tasks[process_component.name] = process_component else: @@ -81,7 +92,7 @@ def visitProgram(self, ctx) -> Process: f"A Task with the name '{process_component.name}' " "is already defined" ) self.error_handler.print_error(error_msg, context=child) - elif isinstance(process_component, Instance): + elif isinstance(process_component, self.pfdl_base_classes.get_class("Instance")): if process_component.name not in process.tasks: process.instances[process_component.name] = process_component else: @@ -106,8 +117,7 @@ def execute_additional_tasks(self, process: Process) -> None: self.add_inherited_attributes_to_structs(process) def add_inherited_attributes_to_structs(self, process: Process) -> None: - """ - Tries to add attributes inherited from the respective parents to all child structs. + """Tries to add attributes inherited from the respective parents to all child structs. Throws an error if one parent struct name is found to be invalid. """ @@ -129,12 +139,19 @@ def visitProgram_statement(self, ctx: PFDLParser.Program_statementContext): return self.visit(ctx.children[0]) def addInstancesToAllTasks(self, process: Process) -> None: + """Adds all instances to the variables of all tasks in the process. + + This method is necessary to use the instances in expressions. + + Args: + process: The `Process` object containing all tasks and instances. 
+ """ for instance in process.instances.values(): for task in process.tasks.values(): task.variables[instance.name] = instance.struct_name def visitStruct(self, ctx) -> Struct: - struct = self.pfdl_base_classes.struct() + struct = self.pfdl_base_classes.get_class("Struct")() struct.name = ctx.STARTS_WITH_UPPER_C_STR().getText() struct.context = ctx @@ -160,7 +177,7 @@ def visitStruct_id(self, ctx: PFDLParser.Struct_idContext) -> str: return ctx.children[0].getText() def visitTask(self, ctx) -> Task: - task = Task() + task = self.pfdl_base_classes.get_class("Task")() task.name = ctx.STARTS_WITH_LOWER_C_STR().getText() task.context = ctx @@ -170,8 +187,8 @@ def visitTask(self, ctx) -> Task: task.input_parameters = self.visitTask_in(ctx.task_in()) task.context_dict[IN_KEY] = ctx.task_in() - for statement_ctx in ctx.statement(): - statement = self.visitStatement(statement_ctx) + for statement_ctx in ctx.taskStatement(): + statement = self.visitTaskStatement(statement_ctx) task.statements.append(statement) if ctx.task_out(): task.output_parameters = self.visitTask_out(ctx.task_out()) @@ -182,7 +199,7 @@ def visitTask(self, ctx) -> Task: def visitInstance(self, ctx: PFDLParser.InstanceContext) -> Instance: instance_name = ctx.STARTS_WITH_LOWER_C_STR().getText() struct_name = self.visitStruct_id(ctx.struct_id()) - instance = self.pfdl_base_classes.instance( + instance = self.pfdl_base_classes.get_class("Instance")( name=instance_name, struct_name=struct_name, context=ctx ) self.current_program_component = instance @@ -192,8 +209,11 @@ def visitInstance(self, ctx: PFDLParser.InstanceContext) -> Instance: ) # JSON value if isinstance(attribute_value, Dict): - attribute_value = self.pfdl_base_classes.instance.from_json( - attribute_value, self.error_handler, ctx, self.pfdl_base_classes.instance + attribute_value = self.pfdl_base_classes.get_class("Instance").from_json( + attribute_value, + self.error_handler, + ctx, + self.pfdl_base_classes.get_class("Instance"), ) instance.attributes[attribute_name] = attribute_value instance.attribute_contexts[attribute_name] = attribute_assignment_ctx @@ -251,7 +271,7 @@ def visitStatement( return statement def visitService_call(self, ctx: PFDLParser.Service_callContext) -> Service: - service = Service() + service = self.pfdl_base_classes.get_class("Service")() service.context = ctx service.name = ctx.STARTS_WITH_UPPER_C_STR().getText() @@ -303,13 +323,15 @@ def visitParameter(self, ctx: PFDLParser.ParameterContext) -> Union[str, List[st def visitStruct_initialization(self, ctx: PFDLParser.Struct_initializationContext) -> Struct: json_string = ctx.json_object().getText() - struct = Struct.from_json(json_string, self.error_handler, ctx.json_object()) + struct = self.pfdl_base_classes.get_class("Struct").from_json( + json_string, self.error_handler, ctx.json_object() + ) struct.name = ctx.STARTS_WITH_UPPER_C_STR().getText() struct.context = ctx return struct def visitTask_call(self, ctx: PFDLParser.Task_callContext) -> TaskCall: - task_call = TaskCall() + task_call = self.pfdl_base_classes.get_class("TaskCall")() task_call.name = ctx.STARTS_WITH_LOWER_C_STR().getText() task_call.context = ctx @@ -329,7 +351,7 @@ def visitTask_call(self, ctx: PFDLParser.Task_callContext) -> TaskCall: return task_call def visitParallel(self, ctx: PFDLParser.ParallelContext) -> Parallel: - parallel = Parallel() + parallel = self.pfdl_base_classes.get_class("Parallel")() parallel.context = ctx for task_call_context in ctx.task_call(): task_call = self.visitTask_call(task_call_context) 
@@ -337,7 +359,7 @@ def visitParallel(self, ctx: PFDLParser.ParallelContext) -> Parallel: return parallel def visitWhile_loop(self, ctx: PFDLParser.While_loopContext) -> WhileLoop: - while_loop = WhileLoop() + while_loop = self.pfdl_base_classes.get_class("WhileLoop")() while_loop.context = ctx while_loop.expression = self.visitExpression(ctx.expression()) @@ -348,7 +370,7 @@ def visitWhile_loop(self, ctx: PFDLParser.While_loopContext) -> WhileLoop: return while_loop def visitCounting_loop(self, ctx: PFDLParser.Counting_loopContext) -> CountingLoop: - counting_loop = CountingLoop() + counting_loop = self.pfdl_base_classes.get_class("CountingLoop")() counting_loop.context = ctx counting_loop.counting_variable = ctx.STARTS_WITH_LOWER_C_STR().getText() @@ -367,7 +389,7 @@ def visitCounting_loop(self, ctx: PFDLParser.Counting_loopContext) -> CountingLo return counting_loop def visitCondition(self, ctx: PFDLParser.ConditionContext) -> Condition: - condition_statement = Condition() + condition_statement = self.pfdl_base_classes.get_class("Condition")() condition_statement.context = ctx condition_statement.expression = self.visitExpression(ctx.expression()) @@ -416,7 +438,7 @@ def visitPrimitive(self, ctx: PFDLParser.PrimitiveContext): return ctx.getText() def initializeArray(self, array_ctx: PFDLParser.ArrayContext, variable_type: str) -> Array: - array = Array() + array = self.pfdl_base_classes.get_class("Array")() array.type_of_elements = variable_type array.context = array_ctx length = self.visitArray(array_ctx) diff --git a/pfdl_scheduler/petri_net/generator.py b/pfdl_scheduler/petri_net/generator.py index 995098a..0d39864 100644 --- a/pfdl_scheduler/petri_net/generator.py +++ b/pfdl_scheduler/petri_net/generator.py @@ -31,6 +31,7 @@ from pfdl_scheduler.petri_net.drawer import draw_petri_net from pfdl_scheduler.petri_net.callbacks import PetriNetCallbacks +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses plugins.load(["labels", "gv", "clusters"], "snakes.nets", "nets") @@ -82,6 +83,7 @@ class PetriNetGenerator: callbacks: A PetriNetCallbacks instance representing functions called while execution. generate_test_ids: A boolean indicating if test ids (counting from 0) should be generated. used_in_extension: A boolean indicating if the Generator is used within the extension. + pfdl_base_classes: An instance of `PFDLBaseClasses`. """ def __init__( @@ -91,6 +93,7 @@ def __init__( generate_test_ids: bool = False, draw_net: bool = True, file_name: str = "petri_net", + pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses(), ) -> None: """Initialize the object. @@ -100,6 +103,7 @@ def __init__( generate_test_ids: A boolean indicating if test ids (counting from 0) should be generated. draw_net: A boolean indicating if the petri net should be drawn. file_name: The desired filename of the petri net image. + pfdl_base_classes: An instance of `PFDLBaseClasses`. 
""" if used_in_extension: @@ -117,11 +121,12 @@ def __init__( self.transition_dict: OrderedDict = OrderedDict() self.place_dict: Dict = {} self.task_started_uuid: str = "" - self.callbacks: PetriNetCallbacks = PetriNetCallbacks() + self.callbacks: PetriNetCallbacks = pfdl_base_classes.get_class("PetriNetCallbacks")() self.generate_test_ids: bool = generate_test_ids self.used_in_extension: bool = used_in_extension self.tree = None self.file_name = file_name + self.pfdl_base_classes = pfdl_base_classes def add_callback(self, transition_uuid: str, callback_function: Callable, *args: Any) -> None: """Registers the given callback function in the transition_dict. @@ -156,7 +161,7 @@ def generate_petri_net(self, process: Process) -> PetriNet: group_uuid = str(uuid.uuid4()) self.tree = Node(group_uuid, start_task.name) - task_context = TaskAPI(start_task, None) + task_context = self.pfdl_base_classes.get_class("TaskAPI")(start_task, None) if self.generate_test_ids: task_context.uuid = "0" @@ -244,17 +249,17 @@ def generate_statements( in_loop, ) - if isinstance(statement, Service): + if isinstance(statement, self.pfdl_base_classes.get_class("Service")): connection_uuids = [self.generate_service(*args)] - elif isinstance(statement, TaskCall): + elif isinstance(statement, self.pfdl_base_classes.get_class("TaskCall")): connection_uuids = self.generate_task_call(*args) - elif isinstance(statement, Parallel): + elif isinstance(statement, self.pfdl_base_classes.get_class("Parallel")): connection_uuids = [self.generate_parallel(*args)] - elif isinstance(statement, CountingLoop): + elif isinstance(statement, self.pfdl_base_classes.get_class("CountingLoop")): connection_uuids = [self.generate_counting_loop(*args)] - elif isinstance(statement, WhileLoop): + elif isinstance(statement, self.pfdl_base_classes.get_class("WhileLoop")): connection_uuids = [self.generate_while_loop(*args)] - elif isinstance(statement, Condition): + elif isinstance(statement, self.pfdl_base_classes.get_class("Condition")): connection_uuids = self.generate_condition(*args) else: connection_uuids = self.handle_other_statements(*args) @@ -280,7 +285,9 @@ def generate_service( group_uuid = str(uuid.uuid4()) service_node = Node(group_uuid, service.name, node) - service_api = ServiceAPI(service, task_context, in_loop=in_loop) + service_api = self.pfdl_base_classes.get_class("ServiceAPI")( + service, task_context, in_loop=in_loop + ) service_started_uuid = create_place(service.name + " started", self.net, service_node) service_finished_uuid = create_place(service.name + " finished", self.net, service_node) @@ -327,7 +334,9 @@ def generate_task_call( The uuids of the last transitions of the TaskCall petri net component. """ called_task = self.tasks[task_call.name] - new_task_context = TaskAPI(called_task, task_context, task_call=task_call, in_loop=in_loop) + new_task_context = self.pfdl_base_classes.get_class("TaskAPI")( + called_task, task_context, task_call=task_call, in_loop=in_loop + ) group_uuid = str(uuid.uuid4()) task_node = Node(group_uuid, task_call.name, node) diff --git a/pfdl_scheduler/pfdl_base_classes.py b/pfdl_scheduler/pfdl_base_classes.py new file mode 100644 index 0000000..3ae4852 --- /dev/null +++ b/pfdl_scheduler/pfdl_base_classes.py @@ -0,0 +1,92 @@ +# Copyright The PFDL Contributors +# +# Licensed under the MIT License. +# For details on the licensing terms, see the LICENSE file. 
+# SPDX-License-Identifier: MIT + +"""Contains the PFDLBaseClasses class.""" + +import os +import importlib +import inspect + + +class PFDLBaseClasses: + def __init__(self, base_dir="pfdl_scheduler"): + self._class_registry = {} + self._class_instances = {} + self._base_dir = base_dir # Base directory for project scanning + self._default_classes = self._scan_project_classes() + + def register_class(self, name, class_reference): + """Register a custom class with a specific name.""" + self._class_registry[name] = class_reference + + def get_class(self, name): + """Return the registered class if available, otherwise the default class.""" + # Check if the class is registered + if name in self._class_registry: + return self._class_registry[name] + + # Fall back to default if not registered + return self._get_default_class(name) + + def get_instance(self, name, *args, **kwargs): + """Instantiate the class dynamically if not already instantiated.""" + if name not in self._class_instances: + class_ref = self.get_class(name) + if class_ref is None: + raise ValueError(f"Class '{name}' not found.") + self._class_instances[name] = class_ref(*args, **kwargs) # Instantiate with args + return self._class_instances[name] + + def _scan_project_classes(self): + """Scan the project folder for all available classes, ignoring 'plugins' folder.""" + class_map = {} + for root, dirs, files in os.walk(self._base_dir): + # Skip the 'plugins' folder if encountered + dirs[:] = [d for d in dirs if d != "plugins"] + + for file in files: + if file.endswith(".py") and not file.startswith("__"): + # Create the module path by converting file path to importable module + module_path = os.path.join(root, file) + module_name = self._module_name_from_path(module_path) + + try: + # Dynamically import the module + module = importlib.import_module(module_name) + + # Find all classes defined in the module + for name, obj in inspect.getmembers(module, inspect.isclass): + # Map class name to its full module path + class_map[name] = f"{module_name}.{name}" + except Exception as e: + # Handle any import errors + print(f"Failed to import {module_name}: {e}") + return class_map + + def _module_name_from_path(self, path): + """Convert a file path to a valid module import path.""" + module_name = path.replace(os.sep, ".")[:-3] + if module_name.startswith("."): + module_name = module_name[1:] + return module_name + + def _get_default_class(self, name): + """Dynamically load the default class based on the component name.""" + if name not in self._default_classes: + raise ValueError(f"Default class for '{name}' not found.") + + # Extract the module path and class name + full_class_path = self._default_classes[name] + module_path, class_name = full_class_path.rsplit(".", 1) + + # Dynamically import the class from the module + module = importlib.import_module(module_path) + return getattr(module, class_name) + + def clear_registry(self): + """Clear the registry and instances.""" + self._class_registry.clear() + self._class_instances.clear() diff --git a/pfdl_scheduler/plugins/README.md b/pfdl_scheduler/plugins/README.md new file mode 100644 index 0000000..a233ead --- /dev/null +++ b/pfdl_scheduler/plugins/README.md @@ -0,0 +1,47 @@ +# PFDL Plugin System +The PFDL plugin system can be used to create plugins that extends the grammar and the underlying logic of the PFDL. +In the following, the different steps for creating your own plugin will be explained in detail. 
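+
+Internally, the PFDL scheduler no longer references its model and helper classes directly. Instead, it resolves them through a `PFDLBaseClasses` registry, so that classes supplied by plugins are picked up transparently. The following is a minimal sketch (not part of the scheduler code, assuming it is executed from the repository root) of how this lookup works:
+
+```python
+from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses
+
+pfdl_base_classes = PFDLBaseClasses()
+
+# Returns a registered (plugin-provided) class if available, otherwise the
+# default class found by scanning the pfdl_scheduler package.
+task_class = pfdl_base_classes.get_class("Task")
+task = task_class()
+```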
+
+## The Plugin Loader - modify PFDL code
+The core of the plugin system is the `PluginLoader` class. This class can be used to load the desired plugins and to return the overwritten classes.
+The PFDL code base was designed in a way such that its base classes can be changed. Thus, the plugin loader returns a `PFDLBaseClasses` object which contains all base classes of the PFDL, overwritten by the plugins (see the example below).
+
+The plugin loader's `load_plugins` method requires a list of strings which are essentially the paths to the plugin folders in which it will search for classes with the decorator `@base_class("")`. It is important that the name passed to the decorator matches the name of the actual base class. If a class has this decorator and is located inside a folder within the plugins folder, it will be used to overwrite the respective base class. Additionally, the plugin classes need to inherit from the base class they overwrite so that the new combined class receives all methods and attributes.
+
+```python
+@base_class("Instance")
+class Instance(pfdl_scheduler.model.instance.Instance):
+    ...
+```
+
+An example of how to load plugins and receive the overwritten base classes is shown here:
+
+```python
+plugin_loader = PluginLoader()
+plugin_loader.load_plugins(["plugins/sample_plugin_folder"])
+
+pfdl_base_classes = plugin_loader.get_pfdl_base_classes()
+```
+
+The `Scheduler` class, which is the entry point for using the PFDL, has an optional parameter for the base classes which can be used to pass in the newly created plugin base classes. This way, the user-made changes are directly inserted into the PFDL base code. If you also want to modify the Scheduler class, a complete example would look like this:
+
+```python
+scheduler = pfdl_base_classes.scheduler_class(
+    ...
+    pfdl_base_classes=pfdl_base_classes,
+)
+```
+
+## Merging Grammars
+If you want to make changes to the base grammar of the PFDL, you can do that by creating custom lexer and parser files in the ANTLR `.g4` format. You only need to define your required rules.
+If the rule names are not in the base grammar, they will just be added as new rules.
+If they already exist, they will be added as alternatives to the existing rules (see the example at the end of this document).
+
+To generate a new grammar that contains the old rules and the new or overwritten rules of the plugins, the `grammar_merge.py` script has to be executed inside the `plugins` folder.
+The script requires a list of parser files followed by a list of lexer files, so that you can define which plugins should be used to create a combined grammar.
+The order of the passed grammar files can change the overall result, so keep that in mind. Moreover, due to the nature of such grammar merging, some plugins might not work when used together!
+
+### The Parser folder
+The `grammar_merge.py` script will call the ANTLR build script to generate lexer, parser, and visitor Python files that can be used by the plugin and store them inside the `parser` folder within the `plugins` folder.
+Internally, the newly generated classes will be inserted into the base classes which are then used by the PFDL base code.
+In addition, if you want to check the merged grammar or want to manually edit it, the parser folder also contains the merged `.g4` files from which the parser files are generated.
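+
+### Example: merging a plugin grammar
+The following is a minimal, hypothetical sketch of a plugin parser fragment. The rule name `statement` and the tokens `MY_KEYWORD` and `STARTS_WITH_LOWER_C_STR` are only placeholders for illustration; use the rule and token names of the actual base grammar. A rule whose name already exists in the base `PFDLParser.g4` is merged as an additional alternative (prepended if it is preceded by the `// {Plugin_Move_To_Front}` marker), while unknown rule names are appended as new rules. The content of a plugin lexer file is inserted at the `// {Plugin_Insertion_Point}` marker of the base `PFDLLexer.g4`.
+
+```antlr
+// Hypothetical plugin parser fragment (e.g. MyPluginParser.g4).
+
+// The marker below asks the merge script to prepend this alternative to the
+// existing 'statement' rule instead of appending it.
+// {Plugin_Move_To_Front}
+statement:
+    my_custom_statement;
+
+// A rule name that is unknown to the base grammar is added as a new rule.
+my_custom_statement:
+    MY_KEYWORD STARTS_WITH_LOWER_C_STR;
+```
+
+The merged grammar and the corresponding Python parser files could then be generated with, for example, `python grammar_merge.py MyPluginParser.g4 MyPluginLexer.g4` executed inside the `plugins` folder (the file names are placeholders).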
\ No newline at end of file diff --git a/pfdl_scheduler/plugins/__init__.py b/pfdl_scheduler/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pfdl_scheduler/plugins/grammar_merge.py b/pfdl_scheduler/plugins/grammar_merge.py new file mode 100644 index 0000000..1c80900 --- /dev/null +++ b/pfdl_scheduler/plugins/grammar_merge.py @@ -0,0 +1,197 @@ +# Copyright The PFDL Contributors +# +# Licensed under the MIT License. +# For details on the licensing terms, see the LICENSE file. +# SPDX-License-Identifier: MIT + +"""Merges multiple grammar files into a single grammar file.""" + +import argparse +import re +import subprocess +from pathlib import Path +from typing import Dict, List + +PLUGIN_ENTRY_POINT = "// {Plugin_Move_To_Front}" +LEXER_PLUGIN_INSERTION_POINT = "// {Plugin_Insertion_Point}" + +# Regular expression to match grammar rules including the custom entry point for plugins +rule_pattern = re.compile(r"((\/\/\s*\{Plugin_Move_To_Front\}\s*\n)?\w+\s*:\s*[^;]+;)") + + +def extract_rules(grammar_content: str) -> Dict[str, str]: + """Extracts rules from a given grammar content while preserving the original formatting. + + Args: + grammar_content: The entire content of the grammar + + Returns: + A dictionary of rule names and their entire formatted content. + """ + rules = {} + for match in rule_pattern.finditer(grammar_content): + rule = match.group(0) + rule_name = rule.split(":")[0].strip() + rules[rule_name] = rule + return rules + + +def merge_parser(base_grammar: str, new_grammar: str) -> str: + """Merges two parser contents while preserving the format of the base grammar. + + Args: + base_grammar: The base grammar content. + new_grammar: The new grammar content to be merged. + + Returns: + The merged grammar content as a string. + """ + base_rules = extract_rules(base_grammar) + new_rules = extract_rules(new_grammar) + + merged_grammar = base_grammar # Start with the base grammar as-is + + for rule_name, new_rule in new_rules.items(): + move_rule_to_the_front = False + if new_rule.strip().startswith(PLUGIN_ENTRY_POINT): + move_rule_to_the_front = True + new_rule = new_rule.replace(PLUGIN_ENTRY_POINT, "").strip() + rule_name = rule_name.replace(PLUGIN_ENTRY_POINT, "").strip() + + if rule_name in base_rules: + # Add new alternative to the existing rule + base_rule = base_rules[rule_name] + # Find the position before the semicolon to insert the new alternative + if move_rule_to_the_front: + insert_pos = base_rule.find(":") + 1 + merged_grammar = merged_grammar.replace( + base_rule, + base_rule[:insert_pos] + + new_rule.split(":")[1].strip().rstrip(";") + + " | " + + base_rule[insert_pos:], + ) + else: + insert_pos = base_rule.rfind(";") + merged_grammar = merged_grammar.replace( + base_rule, + base_rule[:insert_pos] + + " | " + + new_rule.split(":")[1].strip().rstrip(";") + + base_rule[insert_pos:], + ) + else: + # Add new rule at the end of the grammar with appropriate formatting + merged_grammar += "\n\n" + new_rule + + return merged_grammar + + +def merge_lexer(base_lexer: str, new_lexer: str) -> str: + """Merges two lexer contents while preserving the format of the base lexer. + + Args: + base_lexer: The base lexer content. + new_lexer: The new lexer content to be merged. + + Returns: + The merged lexer content as a string. 
+ """ + merged_grammar = base_lexer # Start with the base grammar as-is + + insert_position = merged_grammar.find(LEXER_PLUGIN_INSERTION_POINT) + merged_grammar = ( + merged_grammar[:insert_position] + new_lexer + "\n\n" + merged_grammar[insert_position:] + ) + + return merged_grammar + + +def merge_multiple_parsers(parser_files: List[str]) -> str: + """Merge multiple grammar files. + + Args: + parser_files: A list of file paths to the grammar files. + + Returns: + The merged grammar content as a string. + """ + with open(parser_files[0], "r") as file: + base_grammar = file.read() + + for grammar_file in parser_files[1:]: + with open(grammar_file, "r") as file: + new_grammar = file.read() + base_grammar = merge_parser(base_grammar, new_grammar) + return base_grammar + + +def merge_multiple_lexers(lexer_files: List[str]) -> str: + """Merge multiple lexer files. + + Args: + lexer_files: A list of file paths to the lexer files. + + Returns: + The merged lexer content as a string. + """ + with open(lexer_files[0], "r") as file: + base_grammar = file.read() + + for grammar_file in lexer_files[1:]: + with open(grammar_file, "r") as file: + new_grammar = file.read() + base_grammar = merge_lexer(base_grammar, new_grammar) + return base_grammar + + +if __name__ == "__main__": + + script_description = """ + This script merges multiple grammar and lexer files into single grammar and lexer files respectively. + It then generates the corresponding ANTLR parser and lexer files for Python3. + + Usage: + python grammar_merge.py + + Arguments: + parser_files: List of file paths to the grammar files to be merged. + lexer_files: List of file paths to the lexer files to be merged. + + Example: + python grammar_merge.py parser1.g4 parser2.g4 lexer1.g4 lexer2.g4 + """ + parser = argparse.ArgumentParser( + prog="PFDL grammar merge script", + description=script_description, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument("parser_files", type=str, nargs="+") + parser.add_argument("lexer_files", type=str, nargs="+") + + args = parser.parse_args() + + parser_files = ["../../pfdl_grammar/PFDLParser.g4", *args.parser_files] + merged_parser = merge_multiple_parsers(parser_files) + + lexer_files = ["../../pfdl_grammar/PFDLLexer.g4", *args.lexer_files] + merged_lexer = merge_multiple_lexers(lexer_files) + + file = Path("parser/PFDLParser.g4") + file.parent.mkdir(parents=True, exist_ok=True) + file.write_text(merged_parser) + + file = Path("parser/PFDLLexer.g4") + file.write_text(merged_lexer) + + subprocess.call( + [ + "antlr4", + "-v", + "4.9.3", + "-Dlanguage=Python3", + "-visitor", + "parser/PFDLLexer.g4", + "parser/PFDLParser.g4", + ] + ) diff --git a/pfdl_scheduler/plugins/parser/__init__.py b/pfdl_scheduler/plugins/parser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pfdl_scheduler/plugins/plugin_loader.py b/pfdl_scheduler/plugins/plugin_loader.py new file mode 100644 index 0000000..68e5010 --- /dev/null +++ b/pfdl_scheduler/plugins/plugin_loader.py @@ -0,0 +1,235 @@ +# Copyright The PFDL Contributors +# +# Licensed under the MIT License. +# For details on the licensing terms, see the LICENSE file. 
+# SPDX-License-Identifier: MIT + +"""Contains the PluginLoader class for dynamically loading plugins for the PFDL.""" + +from functools import wraps +import importlib.util +import os +import sys +import inspect +from typing import List + +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses + +base_classes_registry = {} + +PLUGIN_FOLDER_PATH = "./pfdl_scheduler/plugins" + + +def base_class(existing_class_name): + """A Decorator to mark a class that will extend an existing class. + + Registers the class in the base_classes_registry. + """ + + def decorator(cls): + if existing_class_name not in base_classes_registry: + base_classes_registry[existing_class_name] = [] + base_classes_registry[existing_class_name].append(cls) + return cls + + return decorator + + +def wrap_method(original_method, new_methods): + """Chains multiple methods together, calling them in order.""" + + @wraps(original_method) + def wrapper(*args, **kwargs): + for method in new_methods: + result = method(*args, **kwargs) + return result + + return wrapper + + +def apply_plugin_to_base(base_class, plugin_class): + """Applies the methods and attributes of the plugin_class to the base_class. + + Handles method overwrites with different argument counts, including class methods. + """ + + class CombinedClass(base_class, plugin_class): + def __init__(self, *args, **kwargs): + plugin_class.__init__(self, *args, **kwargs) + + # Method containers for chaining + method_overrides = {} + + for name, method in plugin_class.__dict__.items(): + # Handle instance methods and class methods separately + if callable(method): + if name in method_overrides: + method_overrides[name].append(method) + else: + method_overrides[name] = [getattr(base_class, name, None), method] + elif not name.startswith("__"): + # Add class attributes (non-callable) + setattr(CombinedClass, name, method) + + # Add or override instance methods in CombinedClass + for name, methods in method_overrides.items(): + original_method = methods[0] if methods[0] is not None else None + combined_methods = methods[1:] # Plugins' methods + + if original_method: + setattr(CombinedClass, name, wrap_method(original_method, combined_methods)) + else: + setattr( + CombinedClass, name, wrap_method(lambda *args, **kwargs: None, combined_methods) + ) + + # Add class-level attributes + for name, attr in plugin_class.__dict__.items(): + if not callable(attr) and not name.startswith("__"): + if not hasattr(CombinedClass, name): + setattr(CombinedClass, name, attr) + + CombinedClass.__name__ = base_class.__name__ + CombinedClass.__qualname__ = CombinedClass.__name__ + + return CombinedClass + + +class PluginLoader: + """Loads plugins and applies them to the existing classes in the main project. + + The PluginLoader class is responsible for dynamically loading plugins from the plugin folder + and applying them to the existing classes in the main project. It automatically detects all + classes in the main project and combines them with the plugin classes to create the final classes. 
+ """ + + def __init__(self): + self.existing_classes = self.get_existing_classes() + + def get_existing_classes(self): + """Automatically detect and load all classes in the main project, excluding the plugin folder.""" + existing_classes = {} + main_project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + plugins_path = os.path.abspath(os.path.dirname(__file__)) # Path to the plugin folder + + # Walk through the project files to find Python files excluding the plugins folder + for root, _, files in os.walk(main_project_path): + if root.startswith(plugins_path): + continue # Skip files inside the plugins folder + + for file in files: + if file.endswith(".py"): + module_name = os.path.splitext(file)[0] + module_path = os.path.join(root, file) + + if module_name == "__init__": + continue + + # Convert file path to importable module name + relative_path = os.path.relpath(module_path, main_project_path) + module_import_name = relative_path.replace(os.path.sep, ".")[ + :-3 + ] # Remove '.py' + + try: + # Dynamically import the module + spec = importlib.util.spec_from_file_location( + module_import_name, module_path + ) + module = importlib.util.module_from_spec(spec) + sys.modules[module_import_name] = module + spec.loader.exec_module(module) + + # Inspect the module for classes + for name, obj in inspect.getmembers(module, inspect.isclass): + if obj.__module__ == module_import_name: + existing_classes[name] = obj + + except Exception as e: + print(f"Error loading module {module_import_name}: {e}") + + return existing_classes + + def load_plugin_modules(self, module_name, module_path): + """Dynamically import a module given its path and register any classes that overwrite base classes.""" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + def load_plugins(self, plugins: List[str]): + """Recursively load all Python files from plugin folders.""" + for plugin_folder in plugins: + plugin_path = os.path.join(PLUGIN_FOLDER_PATH, plugin_folder) + + if os.path.isdir(plugin_path): + # Walk through all files in the plugin folder + for root, _, files in os.walk(plugin_path): + for file in files: + if file.endswith(".py"): + module_name = f"{plugin_folder}.{file[:-3]}" # Plugin folder + filename without .py + module_path = os.path.join(root, file) + self.load_plugin_modules(module_name, module_path) + + def get_final_classes(self): + """Return a dictionary of final classes after applying plugins.""" + final_classes = {} + + for class_name, base_class in self.existing_classes.items(): + if class_name in base_classes_registry: + # Combine the existing class with the plugin classes + combined_class = base_class + for plugin_class in base_classes_registry[class_name]: + combined_class = apply_plugin_to_base(combined_class, plugin_class) + + final_classes[class_name] = combined_class + else: + final_classes[class_name] = base_class + + return final_classes + + def get_pfdl_base_classes( + self, pfdl_base_classes_path: str = "pfdl_scheduler" + ) -> PFDLBaseClasses: + """Return an instance of `PFDLBaseClasses` populated with final classes after applying plugins. + + Class names are dynamically handled. The base classes are populated with the final classes + after applying plugins, and the registry is updated with any new classes that are not already + present in the base classes. 
+ + Args: + pfdl_base_classes_path: The path to the PFDL base classes module. + + Returns: + An instance of `PFDLBaseClasses` populated with the final classes after applying plugins. + """ + final_classes = self.get_final_classes() + base_classes = PFDLBaseClasses(base_dir=pfdl_base_classes_path) + + for class_name, class_ref in final_classes.items(): + # Try to find a matching property on the base class + # Convert class_name to its lower_snake_case form to match typical property naming conventions + property_name = self._class_name_to_property_name(class_name) + + # If the property exists, set it dynamically + if hasattr(base_classes, property_name): + setattr(base_classes, property_name, class_ref) + else: + base_classes.register_class(class_name, class_ref) + + return base_classes + + def _class_name_to_property_name(self, class_name: str) -> str: + """Converts a class name to a property name by converting CamelCase to snake_case. + + Args: + class_name: The name of the class to convert. + + Returns: + The converted property name. + """ + import re + + # Convert CamelCase to snake_case, and append '_class' to the name + s1 = re.sub("([a-z])([A-Z])", r"\1_\2", class_name).lower() + return f"{s1}_class" diff --git a/pfdl_scheduler/scheduler.py b/pfdl_scheduler/scheduler.py index d1f0fc5..322d066 100644 --- a/pfdl_scheduler/scheduler.py +++ b/pfdl_scheduler/scheduler.py @@ -22,13 +22,17 @@ from pfdl_scheduler.api.task_api import TaskAPI from pfdl_scheduler.api.service_api import ServiceAPI +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses from pfdl_scheduler.utils.parsing_utils import parse_program from pfdl_scheduler.petri_net.generator import Node, PetriNetGenerator -from pfdl_scheduler.petri_net.logic import PetriNetLogic from pfdl_scheduler.scheduling.event import Event -from pfdl_scheduler.scheduling.event import START_PRODUCTION_TASK, SET_PLACE, SERVICE_FINISHED +from pfdl_scheduler.scheduling.event import ( + START_PRODUCTION_TASK, + SET_PLACE, + SERVICE_FINISHED, +) from pfdl_scheduler.scheduling.task_callbacks import TaskCallbacks from pfdl_scheduler.api.observer_api import NotificationType, Observer, Subject @@ -50,7 +54,9 @@ class Scheduler(Subject): The scheduler comprises almost the complete execution of a production order including the parsing of the PFDL description, model creation and validation and execution of the petri net. It interacts with the execution engines and informs them about services - or tasks which started or finished. + or tasks which started or finished. The pfdl_base_classes attribute is one of the most + impoortant attributes of the scheduler. It holds the base classes for the scheduler that + can be overwritten to extend the scheduler with plugins. This class implements the Observer pattern and serves as subject. Observers can be registered in the scheduler and receive updates (e.g. log entries, info about a new petri net img,..) @@ -68,6 +74,7 @@ class Scheduler(Subject): generate_test_ids: Indicates whether test ids should be generated. test_id_counters: A List consisting of counters for the test ids of tasks and services. observers: List of `Observers` used to update them on a `notify` call. + pfdl_base_classes: A `PFDLBaseClasses` instance which holds the base classes for the scheduler. """ def __init__( @@ -77,6 +84,7 @@ def __init__( draw_petri_net: bool = True, scheduler_uuid: str = "", dashboard_host_address: str = "", + pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses("pfdl_scheduler"), ) -> None: """Initialize the object. 
@@ -92,24 +100,46 @@ def __init__( draw_petri_net: A boolean indicating whether the petri net should be drawn. scheduler_uuid: A unique ID to identify the Scheduer / Production Order dashboard_host_address: The address of the Dashboard (if existing) + pfdl_base_classes: A `PFDLBaseClasses` instance which holds the base classes for the scheduler. """ - self.init_scheduler(scheduler_uuid, generate_test_ids) - self.pfdl_file_valid, self.process, pfdl_string = parse_program(pfdl_program) + self.init_scheduler( + scheduler_uuid, + generate_test_ids, + pfdl_base_classes.get_instance("PetriNetGenerator"), + pfdl_base_classes.get_instance("TaskCallbacks"), + ) + self.pfdl_file_valid, self.process, pfdl_string = parse_program( + pfdl_program, pfdl_base_classes + ) if self.pfdl_file_valid: - self.petri_net_generator = PetriNetGenerator( + self.petri_net_generator = pfdl_base_classes.get_class("PetriNetGenerator")( "", generate_test_ids=self.generate_test_ids, draw_net=draw_petri_net, file_name=self.scheduler_uuid, ) - self.setup_scheduling(draw_petri_net) + self.setup_scheduling(draw_petri_net, pfdl_base_classes.get_class("PetriNetLogic")) if dashboard_host_address != "": self.attach( DashboardObserver(dashboard_host_address, self.scheduler_uuid, pfdl_string) ) - def init_scheduler(self, scheduler_uuid: str, generate_test_ids: bool): + def init_scheduler( + self, + scheduler_uuid: str, + generate_test_ids: bool, + petri_net_generator: PetriNetGenerator, + task_callbacks: TaskCallbacks, + ) -> None: + """Initialize the scheduler with the given parameters. + + Args: + scheduler_uuid: A unique ID to identify the scheduler / production order + generate_test_ids: A boolean indicating whether test ids should be generated. + petri_net_generator: A `PetriNetGenerator` instance for generating the petri net. + task_callbacks: `TaskCallbacks` instance which holds the registered callbacks. + """ if scheduler_uuid == "": self.scheduler_uuid: str = str(uuid.uuid4()) else: @@ -117,9 +147,8 @@ def init_scheduler(self, scheduler_uuid: str, generate_test_ids: bool): self.running: bool = False self.pfdl_file_valid: bool = False self.process: Process = None - self.petri_net_generator: PetriNetGenerator = None - self.petri_net_logic: PetriNetLogic = None - self.task_callbacks: TaskCallbacks = TaskCallbacks() + self.petri_net_generator: PetriNetGenerator = petri_net_generator + self.task_callbacks: TaskCallbacks = task_callbacks self.variable_access_function: Callable[[str], str] = None self.loop_counters: Dict[str, Dict[str, int]] = {} self.awaited_events: List[Event] = [] @@ -127,11 +156,16 @@ def init_scheduler(self, scheduler_uuid: str, generate_test_ids: bool): self.test_id_counters: List[int] = [0, 0] self.observers: List[Observer] = [] - def setup_scheduling(self, draw_petri_net: bool): + def setup_scheduling(self, draw_petri_net: bool, petri_net_logic_class) -> None: + """Setup the scheduling process. + + This method is called after the PFDL file was successfully parsed and the petri net + generator was created. It generates the petri net and creates the petri net logic. 
+ """ self.register_for_petrinet_callbacks() self.petri_net_generator.generate_petri_net(self.process) - self.petri_net_logic = PetriNetLogic( + self.petri_net_logic = petri_net_logic_class( self.petri_net_generator, draw_petri_net, file_name=self.scheduler_uuid ) @@ -359,7 +393,11 @@ def on_service_finished(self, service_api: ServiceAPI) -> None: self.notify(NotificationType.LOG_EVENT, (log_entry, logging.INFO, False)) def on_condition_started( - self, condition: Condition, then_uuid: str, else_uuid: str, task_context: TaskAPI + self, + condition: Condition, + then_uuid: str, + else_uuid: str, + task_context: TaskAPI, ) -> None: """Executes Scheduling logic when a Condition statement is started.""" if self.check_expression(condition.expression, task_context): diff --git a/pfdl_scheduler/scheduling/event.py b/pfdl_scheduler/scheduling/event.py index 2fec82f..dcbf075 100644 --- a/pfdl_scheduler/scheduling/event.py +++ b/pfdl_scheduler/scheduling/event.py @@ -33,11 +33,10 @@ def __init__(self, event_type: str = "", data: Dict = None) -> None: self.event_type: str = event_type self.data: Dict = data - def __eq__(self, other: "Event"): - if not isinstance(other, Event): - # don't attempt to compare against unrelated types - return NotImplemented - return self.event_type == other.event_type and self.data == other.data + def __eq__(self, other: object) -> bool: + if hasattr(other, "event_type") and hasattr(other, "data"): + return self.event_type == other.event_type and self.data == other.data + return False @classmethod def from_json(cls, json_string: str) -> Union[None, "Event"]: diff --git a/pfdl_scheduler/utils/parsing_utils.py b/pfdl_scheduler/utils/parsing_utils.py index 56a1454..488f808 100644 --- a/pfdl_scheduler/utils/parsing_utils.py +++ b/pfdl_scheduler/utils/parsing_utils.py @@ -17,19 +17,15 @@ from antlr4.InputStream import InputStream # local sources -from pfdl_scheduler.parser.pfdl_tree_visitor import PFDLTreeVisitor -from pfdl_scheduler.parser.PFDLLexer import PFDLLexer -from pfdl_scheduler.parser.PFDLParser import PFDLParser - -from pfdl_scheduler.validation.error_handler import ErrorHandler -from pfdl_scheduler.validation.syntax_error_listener import SyntaxErrorListener -from pfdl_scheduler.validation.semantic_error_checker import SemanticErrorChecker - +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses from pfdl_scheduler.model.process import Process def parse_string( - pfdl_string: str, file_path: str = "", used_in_extension: bool = False + pfdl_string: str, + file_path: str = "", + used_in_extension: bool = False, + pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses("pfdl_scheduler"), ) -> Tuple[bool, Union[None, Process]]: """Instantiate the ANTLR lexer and parser and parses the given PFDL string. @@ -41,25 +37,27 @@ def parse_string( Returns: A boolan indicating validity of the PFDL file and the process object if so, otherwise None. 
""" - lexer = PFDLLexer(InputStream(pfdl_string)) + lexer = pfdl_base_classes.get_class("PFDLLexer")(InputStream(pfdl_string)) lexer.removeErrorListeners() token_stream = CommonTokenStream(lexer) - parser = PFDLParser(token_stream) + parser = pfdl_base_classes.get_class("PFDLParser")(token_stream) parser.removeErrorListeners() - error_handler = ErrorHandler(file_path, used_in_extension) - error_listener = SyntaxErrorListener(token_stream, error_handler) + error_handler = pfdl_base_classes.get_class("ErrorHandler")(file_path, used_in_extension) + error_listener = pfdl_base_classes.get_class("SyntaxErrorListener")(token_stream, error_handler) parser.addErrorListener(error_listener) tree = parser.program() if error_handler.has_error() is False: - visitor = PFDLTreeVisitor(error_handler) + visitor = pfdl_base_classes.get_class("PFDLTreeVisitor")(error_handler, pfdl_base_classes) process = visitor.visit(tree) - semantic_error_checker = SemanticErrorChecker(error_handler, process) + semantic_error_checker = pfdl_base_classes.get_class("SemanticErrorChecker")( + error_handler, process, pfdl_base_classes + ) semantic_error_checker.validate_process() if error_handler.has_error() is False: @@ -68,7 +66,9 @@ def parse_string( return (False, None) -def parse_program(program: str) -> Tuple[bool, Union[None, Process], str]: +def parse_program( + program: str, pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses("pfdl_scheduler") +) -> Tuple[bool, Union[None, Process], str]: """Loads the content of the program from either the given path or the PFDL program directly and calls the parse_string function. Args: @@ -79,7 +79,7 @@ def parse_program(program: str) -> Tuple[bool, Union[None, Process], str]: process object if so, otherwise None. """ pfdl_string, file_path = extract_content_and_file_path(program) - return *parse_string(pfdl_string, file_path), pfdl_string + return *parse_string(pfdl_string, file_path, pfdl_base_classes=pfdl_base_classes), pfdl_string def write_tokens_to_file(token_stream: CommonTokenStream) -> None: diff --git a/pfdl_scheduler/validation/semantic_error_checker.py b/pfdl_scheduler/validation/semantic_error_checker.py index 8768406..ba32174 100644 --- a/pfdl_scheduler/validation/semantic_error_checker.py +++ b/pfdl_scheduler/validation/semantic_error_checker.py @@ -25,6 +25,7 @@ from pfdl_scheduler.model.while_loop import WhileLoop from pfdl_scheduler.model.condition import Condition +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses from pfdl_scheduler.validation.error_handler import ErrorHandler from pfdl_scheduler.utils import helpers @@ -47,7 +48,12 @@ class SemanticErrorChecker: structs: A Dict that contains all Struct objects of the given process object. """ - def __init__(self, error_handler: ErrorHandler, process: Process) -> None: + def __init__( + self, + error_handler: ErrorHandler, + process: Process, + pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses(), + ) -> None: """Initialize the object. Args: @@ -58,6 +64,7 @@ def __init__(self, error_handler: ErrorHandler, process: Process) -> None: self.process: Process = process self.tasks: Dict[str, Task] = process.tasks self.structs: Dict[str, Struct] = process.structs + self.pfdl_base_classes: PFDLBaseClasses = pfdl_base_classes def validate_process(self) -> bool: """Starts static semantic checks. 
@@ -174,7 +181,7 @@ def check_if_instance_attributes_exist_in_struct( self.error_handler.print_error(error_msg, context=instance.context) valid = False - elif isinstance(attribute_value, Instance): + elif isinstance(attribute_value, self.pfdl_base_classes.get_class("Instance")): # the attribute is an instance so recursively check its attributes nested_struct = self.process.structs[struct.attributes[attribute_name]] if not self.check_if_instance_attributes_exist_in_struct( @@ -199,7 +206,7 @@ def check_if_struct_attributes_are_assigned(self, struct: Struct, instance: Inst if struct_attribute == attribute_name: attribute_found = True - if isinstance(attribute_value, Instance): + if isinstance(attribute_value, self.pfdl_base_classes.get_class("Instance")): # the attribute is an instance so recursively check its attributes nested_struct = self.process.structs[struct.attributes[attribute_name]] if not self.check_if_struct_attributes_are_assigned( @@ -231,7 +238,7 @@ def check_if_value_matches_with_defined_type(self, struct: Struct, instance: Ins # This method assumes that the attribute exists in the Struct, so no additional check struct_attr_type = struct_attributes[attribute_name] - if isinstance(attribute_value, Instance): + if isinstance(attribute_value, self.pfdl_base_classes.get_class("Instance")): # the attribute is an instance so recursively check its attributes nested_struct = self.process.structs[struct_attr_type] if not self.check_if_value_matches_with_defined_type( @@ -270,15 +277,15 @@ def check_statement( Returns: True if the given statement is valid. """ - if isinstance(statement, Service): + if isinstance(statement, self.pfdl_base_classes.get_class("Service")): return self.check_service(statement, task) - if isinstance(statement, TaskCall): + if isinstance(statement, self.pfdl_base_classes.get_class("TaskCall")): return self.check_task_call(statement, task) - if isinstance(statement, Parallel): + if isinstance(statement, self.pfdl_base_classes.get_class("Parallel")): return self.check_parallel(statement, task) - if isinstance(statement, WhileLoop): + if isinstance(statement, self.pfdl_base_classes.get_class("WhileLoop")): return self.check_while_loop(statement, task) - if isinstance(statement, CountingLoop): + if isinstance(statement, self.pfdl_base_classes.get_class("CountingLoop")): return self.check_counting_loop(statement, task) return self.check_conditional_statement(statement, task) @@ -395,7 +402,10 @@ def check_if_task_call_matches_with_called_task(self, task_call: TaskCall, task: if variable_in_called_task in called_task.variables: type_of_variable = "" - if isinstance(called_task.variables[variable_in_called_task], Instance): + if isinstance( + called_task.variables[variable_in_called_task], + self.pfdl_base_classes.get_class("Instance"), + ): type_of_variable = called_task.variables[variable_in_called_task].struct_name else: type_of_variable = called_task.variables[variable_in_called_task] @@ -442,7 +452,7 @@ def check_if_input_parameter_matches( if input_parameter in task_context.variables: type_of_variable = "" variable = task_context.variables[input_parameter] - if isinstance(variable, Instance): + if isinstance(variable, self.pfdl_base_classes.get_class("Instance")): type_of_variable = variable.struct_name else: type_of_variable = variable @@ -479,7 +489,7 @@ def check_if_input_parameter_matches( i = 1 while i < len(input_parameter) - 1: element = current_struct.attributes[input_parameter[i]] - if isinstance(element, Array): + if isinstance(element, 
self.pfdl_base_classes.get_class("Array")): i = i + 1 current_struct = self.structs[element.type_of_elements] else: @@ -504,7 +514,7 @@ def check_if_input_parameter_matches( off_symbol_length=len(task_call.name), ) return False - elif isinstance(input_parameter, Struct): + elif isinstance(input_parameter, self.pfdl_base_classes.get_class("Struct")): if input_parameter.name != defined_type: error_msg = ( f"Type of TaskCall parameter '{input_parameter.name}' does not match " @@ -567,7 +577,7 @@ def check_call_input_parameters( valid = True for input_parameter in called_entity.input_parameters: - if isinstance(input_parameter, Struct): + if isinstance(input_parameter, self.pfdl_base_classes.get_class("Struct")): if not self.check_instantiated_struct_attributes(input_parameter): valid = False elif isinstance(input_parameter, list): @@ -742,7 +752,7 @@ def check_for_wrong_attribute_type_in_struct( if isinstance(correct_attribute_type, str): if correct_attribute_type in self.structs: # check for structs which has structs as attribute - if isinstance(attribute, Struct): + if isinstance(attribute, self.pfdl_base_classes.get_class("Struct")): attribute.name = correct_attribute_type struct_def = self.structs[correct_attribute_type] struct_correct = True @@ -767,10 +777,10 @@ def check_for_wrong_attribute_type_in_struct( self.error_handler.print_error(error_msg, context=struct_instance.context) return False - elif isinstance(correct_attribute_type, Array): - if not isinstance(attribute, Array) or not self.check_array( - attribute, correct_attribute_type - ): + elif isinstance(correct_attribute_type, self.pfdl_base_classes.get_class("Array")): + if not isinstance( + attribute, self.pfdl_base_classes.get_class("Array") + ) or not self.check_array(attribute, correct_attribute_type): error_msg = ( f"Attribute '{identifier}' has the wrong type in the instantiated" f" Struct '{struct_instance.name}', expected 'Array'" @@ -789,7 +799,7 @@ def check_array(self, instantiated_array: Array, array_definition: Array) -> boo element_type = array_definition.type_of_elements for value in instantiated_array.values: # type of Struct not checked yet - if isinstance(value, Struct): + if isinstance(value, self.pfdl_base_classes.get_class("Struct")): if value.name == "": value.name = array_definition.type_of_elements if not self.check_instantiated_struct_attributes(value): @@ -850,10 +860,10 @@ def check_counting_loop(self, counting_loop: CountingLoop, task: Task) -> bool: """ if counting_loop.parallel: if len(counting_loop.statements) == 1 and isinstance( - counting_loop.statements[0], TaskCall + counting_loop.statements[0], self.pfdl_base_classes.get_class("TaskCall") ): return True - error_msg = "Only a single task is allowed in a parallel loop statement!" + error_msg = "Only a single task call is allowed in a parallel loop statement!" 
self.error_handler.print_error(error_msg, context=counting_loop.context) return False else: @@ -1020,7 +1030,7 @@ def check_if_variable_definition_is_valid( if isinstance(variable_type, str): if not self.variable_type_exists(variable_type): valid = False - elif isinstance(variable_type, Array): + elif isinstance(variable_type, self.pfdl_base_classes.get_class("Array")): element_type = variable_type.type_of_elements if not self.variable_type_exists(element_type): valid = False @@ -1077,7 +1087,7 @@ def check_type_of_value(self, value: Any, value_type: str) -> bool: return isinstance(value, bool) if value_type == "string": return isinstance(value, str) - if isinstance(value, Struct): + if isinstance(value, self.pfdl_base_classes.get_class("Struct")): return value.name == value_type # value was a string return True diff --git a/tests/unit_test/test_pfdl_base_classes.py b/tests/unit_test/test_pfdl_base_classes.py new file mode 100644 index 0000000..e66608d --- /dev/null +++ b/tests/unit_test/test_pfdl_base_classes.py @@ -0,0 +1,94 @@ +# Copyright The PFDL Contributors +# +# Licensed under the MIT License. +# For details on the licensing terms, see the LICENSE file. +# SPDX-License-Identifier: MIT + +"""Contains unit tests for the PFDLBaseClasses class.""" + +import unittest +from unittest.mock import patch, MagicMock +from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses + + +class TestPFDLBaseClasses(unittest.TestCase): + """Test the PFDLBaseClasses class.""" + + def test_register_class(self): + pfdl_base_classes = PFDLBaseClasses() + mock_class = type("MockClass", (object,), {}) + + pfdl_base_classes.register_class("MockClass", mock_class) + + # Verify that the class is registered + self.assertEqual(pfdl_base_classes.get_class("MockClass"), mock_class) + + def test_get_class_default(self): + pfdl_base_classes = PFDLBaseClasses() + mock_class = type("MockClass", (object,), {}) + + # Simulate default class registration + pfdl_base_classes._default_classes["MockClass"] = "module.MockClass" + + with patch( + "pfdl_scheduler.pfdl_base_classes.importlib.import_module" + ) as mock_import_module: + mock_import_module.return_value = MagicMock(MockClass=mock_class) + + # Verify that the default class can be retrieved + self.assertEqual(pfdl_base_classes.get_class("MockClass"), mock_class) + + def test_get_instance(self): + pfdl_base_classes = PFDLBaseClasses() + mock_class = type( + "MockClass", (object,), {"__init__": lambda self, x: setattr(self, "x", x)} + ) + + pfdl_base_classes.register_class("MockClass", mock_class) + + # Verify instance creation and reuse + instance = pfdl_base_classes.get_instance("MockClass", 5) + self.assertEqual(instance.x, 5) + instance2 = pfdl_base_classes.get_instance("MockClass") + self.assertIs(instance, instance2) + + def test_clear_registry(self): + pfdl_base_classes = PFDLBaseClasses() + mock_class = type("MockClass", (object,), {}) + + pfdl_base_classes.register_class("MockClass", mock_class) + pfdl_base_classes.get_instance("MockClass") + + # Clear registry and instances + pfdl_base_classes.clear_registry() + + # Verify that both registry and instances are cleared + self.assertNotIn("MockClass", pfdl_base_classes._class_registry) + self.assertNotIn("MockClass", pfdl_base_classes._class_instances) + + def test_module_name_from_path(self): + pfdl_base_classes = PFDLBaseClasses() + module_name = pfdl_base_classes._module_name_from_path("fake_dir/module.py") + self.assertEqual(module_name, "fake_dir.module") + + pfdl_base_classes = PFDLBaseClasses() + 
module_name = pfdl_base_classes._module_name_from_path("/fake_dir/module.py") + self.assertEqual(module_name, "fake_dir.module") + + @patch("pfdl_scheduler.pfdl_base_classes.importlib.import_module") + def test_get_default_class(self, mock_import_module): + pfdl_base_classes = PFDLBaseClasses() + mock_class = type("MockClass", (object,), {}) + pfdl_base_classes._default_classes["MockClass"] = "module.MockClass" + + mock_import_module.return_value = MagicMock(MockClass=mock_class) + + # Verify default class retrieval + self.assertEqual(pfdl_base_classes._get_default_class("MockClass"), mock_class) + + def test_get_default_class_nonexistent(self): + pfdl_base_classes = PFDLBaseClasses() + with self.assertRaises(ValueError) as context: + pfdl_base_classes._get_default_class("NonExistentClass") + + self.assertEqual(str(context.exception), "Default class for 'NonExistentClass' not found.") diff --git a/tests/unit_test/test_pfdl_tree_visitor.py b/tests/unit_test/test_pfdl_tree_visitor.py index 6ad61ba..3f3f830 100644 --- a/tests/unit_test/test_pfdl_tree_visitor.py +++ b/tests/unit_test/test_pfdl_tree_visitor.py @@ -14,6 +14,7 @@ import unittest from unittest.mock import MagicMock from unittest.mock import patch +from pfdl_scheduler.model.instance import Instance from pfdl_scheduler.model.parallel import Parallel @@ -145,8 +146,8 @@ def test_visit_struct(self): def test_visit_task(self): task_in_context = PFDLParser.Task_inContext(None) task_out_context = PFDLParser.Task_outContext(None) - statement_context_1 = PFDLParser.StatementContext(None) - statement_context_2 = PFDLParser.StatementContext(None) + statement_context_1 = PFDLParser.TaskStatementContext(None) + statement_context_2 = PFDLParser.TaskStatementContext(None) task_context = PFDLParser.TaskContext(None) task_context.children = [ @@ -175,7 +176,7 @@ def test_visit_task(self): ) as mock_2: with patch.object( self.visitor, - "visitStatement", + "visitTaskStatement", MagicMock(side_effect=[statement_1, statement_2]), ) as mock_3: task = self.visitor.visitTask(task_context) @@ -944,6 +945,38 @@ def test_visitUnOperation(self): un_op = self.visitor.visitUnOperation(un_op_context) self.assertEqual(un_op, "!") + def test_visitInstance(self): + instance_context = PFDLParser.InstanceContext(None) + + instance_context.children = [ + PFDLParser.Struct_idContext(None), + PFDLParser.Attribute_assignmentContext(None), + PFDLParser.Attribute_assignmentContext(None), + ] + create_and_add_token(PFDLParser.STARTS_WITH_LOWER_C_STR, "instance_id", instance_context) + with patch.object( + PFDLTreeVisitor, + "visitStruct_id", + MagicMock(side_effect=["struct_id"]), + ): + with patch.object( + PFDLTreeVisitor, + "visitAttribute_assignment", + MagicMock(side_effect=[("attr", "value"), ("attr_2", {"id": "value"})]), + ): + with patch.object( + Instance, + "from_json", + MagicMock(side_effect=[Instance(attributes={"id": "value"})]), + ): + instance = self.visitor.visitInstance(instance_context) + + self.assertIsNotNone(instance) + self.assertEqual(instance.name, "instance_id") + self.assertEqual(instance.struct_name, "struct_id") + self.assertEqual(len(instance.attributes), 2) + self.assertTrue(isinstance(instance.attributes["attr_2"], Instance)) + def create_and_add_token( token_type: int, token_text: str, antlr_context: ParserRuleContext diff --git a/tests/unit_test/test_plugin_loader.py b/tests/unit_test/test_plugin_loader.py new file mode 100644 index 0000000..4ea7ce8 --- /dev/null +++ b/tests/unit_test/test_plugin_loader.py @@ -0,0 +1,39 @@ +# Copyright 
The PFDL Contributors +# +# Licensed under the MIT License. +# For details on the licensing terms, see the LICENSE file. +# SPDX-License-Identifier: MIT + +"""Contains unit tests for the PluginLoader class.""" + +import unittest +import os +from unittest.mock import patch, MagicMock +from pfdl_scheduler.plugins.plugin_loader import PluginLoader, base_class, apply_plugin_to_base + + +class TestPluginLoader(unittest.TestCase): + """Test the PluginLoader class.""" + + def test_apply_plugin_to_base(self): + class BaseClass: + def method(self): + return "base" + + @base_class("BaseClass") + class PluginClass: + def method(self): + return "plugin" + + combined_class = apply_plugin_to_base(BaseClass, PluginClass) + instance = combined_class() + + # Check if method is overridden correctly + self.assertEqual(instance.method(), "plugin") + + def test_class_name_to_property_name(self): + plugin_loader = PluginLoader() + property_name = plugin_loader._class_name_to_property_name("TestClassName") + + # Check if conversion is correct + self.assertEqual(property_name, "test_class_name_class") diff --git a/tests/unit_test/test_semantic_error_checker.py b/tests/unit_test/test_semantic_error_checker.py index 607d128..9fcae49 100644 --- a/tests/unit_test/test_semantic_error_checker.py +++ b/tests/unit_test/test_semantic_error_checker.py @@ -16,7 +16,8 @@ from typing import Dict from pfdl_scheduler.model.condition import Condition import unittest -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch +from pfdl_scheduler.model.instance import Instance from pfdl_scheduler.model.parallel import Parallel # local sources @@ -233,6 +234,96 @@ def test_check_tasks(self): self.check_if_print_error_is_called(self.execute_check_tasks, False, False, False) + def test_check_instances(self): + empty_instances_valid = self.checker.check_instances() + self.assertTrue(empty_instances_valid) + + test_instance = Instance("testInstance", struct_name="TestStruct") + test_struct = Struct("TestStruct") + self.process.instances = {"testInstance": test_instance} + self.process.structs = {"TestStruct": test_struct} + + # test valid case + with patch.object( + SemanticErrorChecker, + "check_if_instance_attributes_exist_in_struct", + MagicMock(side_effect=[True]), + ): + with patch.object( + SemanticErrorChecker, + "check_if_value_matches_with_defined_type", + MagicMock(side_effect=[True]), + ): + with patch.object( + SemanticErrorChecker, + "check_if_struct_attributes_are_assigned", + MagicMock(side_effect=[True]), + ): + is_instance_valid = self.checker.check_instances() + + self.assertTrue(is_instance_valid) + + # test invalid cases + with patch.object( + SemanticErrorChecker, + "check_if_instance_attributes_exist_in_struct", + MagicMock(side_effect=[False]), + ): + with patch.object( + SemanticErrorChecker, + "check_if_value_matches_with_defined_type", + MagicMock(side_effect=[True]), + ) as value_matches_mock: + with patch.object( + SemanticErrorChecker, + "check_if_struct_attributes_are_assigned", + MagicMock(side_effect=[True]), + ) as struct_attributes_assigned_mock: + is_instance_valid = self.checker.check_instances() + + self.assertFalse(is_instance_valid) + value_matches_mock.assert_not_called() + struct_attributes_assigned_mock.assert_called() + + with patch.object( + SemanticErrorChecker, + "check_if_instance_attributes_exist_in_struct", + MagicMock(side_effect=[True]), + ): + with patch.object( + SemanticErrorChecker, + "check_if_value_matches_with_defined_type", + 
MagicMock(side_effect=[False]), + ) as value_matches_mock: + with patch.object( + SemanticErrorChecker, + "check_if_struct_attributes_are_assigned", + MagicMock(side_effect=[True]), + ) as struct_attributes_assigned_mock: + is_instance_valid = self.checker.check_instances() + + self.assertFalse(is_instance_valid) + struct_attributes_assigned_mock.assert_called() + + with patch.object( + SemanticErrorChecker, + "check_if_instance_attributes_exist_in_struct", + MagicMock(side_effect=[True]), + ): + with patch.object( + SemanticErrorChecker, + "check_if_value_matches_with_defined_type", + MagicMock(side_effect=[True]), + ) as value_matches_mock: + with patch.object( + SemanticErrorChecker, + "check_if_struct_attributes_are_assigned", + MagicMock(side_effect=[False]), + ) as struct_attributes_assigned_mock: + is_instance_valid = self.checker.check_instances() + + self.assertFalse(is_instance_valid) + def test_check_statements(self): dummy_task = Task() dummy_task.statements = [Service()]
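
For quick reference while reviewing, below is a minimal, self-contained sketch of the registry behaviour that tests/unit_test/test_pfdl_base_classes.py pins down: plugin registration via register_class, get_class falling back to a lazily imported default, instance caching in get_instance, clear_registry, and _module_name_from_path. It is only an illustration of the assumed contract, not the pfdl_scheduler/pfdl_base_classes.py shipped with this patch; the attribute names _default_classes, _class_registry and _class_instances come from the tests, while the RegistrySketch name and the exact default-import mechanics are assumptions.

# Illustrative sketch only (assumed behaviour) -- not the class added by this patch.
import importlib


class RegistrySketch:
    """Toy stand-in for the PFDLBaseClasses contract exercised by the unit tests."""

    def __init__(self):
        self._default_classes = {}   # class name -> "module.path.ClassName" (default)
        self._class_registry = {}    # class name -> class registered by a plugin
        self._class_instances = {}   # class name -> cached instance

    def register_class(self, name, cls):
        # A plugin overrides the class that will be used for `name`.
        self._class_registry[name] = cls

    def get_class(self, name):
        # A plugin-registered class wins; otherwise import the default lazily.
        if name in self._class_registry:
            return self._class_registry[name]
        return self._get_default_class(name)

    def _get_default_class(self, name):
        if name not in self._default_classes:
            raise ValueError(f"Default class for '{name}' not found.")
        module_path, _, class_name = self._default_classes[name].rpartition(".")
        return getattr(importlib.import_module(module_path), class_name)

    def get_instance(self, name, *args, **kwargs):
        # The first call constructs the object; later calls return the cached one.
        if name not in self._class_instances:
            self._class_instances[name] = self.get_class(name)(*args, **kwargs)
        return self._class_instances[name]

    def clear_registry(self):
        self._class_registry.clear()
        self._class_instances.clear()

    @staticmethod
    def _module_name_from_path(path):
        # "fake_dir/module.py" and "/fake_dir/module.py" both map to "fake_dir.module".
        return path.lstrip("/").removesuffix(".py").replace("/", ".")


if __name__ == "__main__":
    registry = RegistrySketch()

    class Struct:          # toy default class
        pass

    class PluginStruct:    # toy plugin replacement
        pass

    registry.register_class("Struct", PluginStruct)
    # Type checks written against the registry pick up the plugin class; this is
    # the pattern the SemanticErrorChecker hunks above switch to with get_class("Struct").
    print(isinstance(PluginStruct(), registry.get_class("Struct")))  # True

Resolving the class through get_class at check time, as the semantic_error_checker.py hunks now do, means the checker always compares against whatever class is currently registered for "Struct", "Array" or "TaskCall", whether that is the default or a plugin replacement.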
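
Similarly, here is a compact sketch of the contract that tests/unit_test/test_plugin_loader.py implies for the plugin loader helpers: apply_plugin_to_base combines a plugin class with its scheduler base class so that plugin methods take precedence, base_class is a decorator that ties a plugin class to the name of the base class it extends, and CamelCase class names map to snake_case "<name>_class" property names. This is an assumption-based illustration rather than the plugins/plugin_loader.py in the patch; in particular the decorator's stored attribute and the standalone helper name are hypothetical (the test calls the private method _class_name_to_property_name on a PluginLoader instance).

# Illustrative sketch only (assumed behaviour) -- not the module added by this patch.
import re


def base_class(name):
    """Decorator marking which scheduler base class a plugin class is meant to extend."""
    def decorator(cls):
        cls._pfdl_base_class_name = name   # hypothetical attribute name
        return cls
    return decorator


def apply_plugin_to_base(base_cls, plugin_cls):
    # Build a combined class whose MRO lists the plugin first, so plugin methods
    # override base methods (the test expects instance.method() == "plugin").
    return type(plugin_cls.__name__, (plugin_cls, base_cls), {})


def class_name_to_property_name(class_name):
    # "TestClassName" -> "test_class_name_class", matching the test's assertion.
    snake = re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()
    return f"{snake}_class"


if __name__ == "__main__":
    class BaseClass:
        def method(self):
            return "base"

    @base_class("BaseClass")
    class PluginClass:
        def method(self):
            return "plugin"

    combined = apply_plugin_to_base(BaseClass, PluginClass)
    print(combined().method())                            # "plugin"
    print(class_name_to_property_name("TestClassName"))   # "test_class_name_class"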