Skip to content

Commit

Permalink
Implement plugin system
Browse files Browse the repository at this point in the history
- Add PFDLBaseClasses which is an object that holds the actual classes instantiated
by the PFDL Scheduler (can be overwritten by plugins)
- Add plugin loader class that can load plugins and return a new PFDLBaseClasses object
- Script for merging multiple grammars into a comprehensive grammar
- Adaptions to the existing code base so that the new PFDLBaseClasses are used
- Adaptions to the unit tests and new ones for the plugin classes
  • Loading branch information
maxhoerstr committed Nov 6, 2024
1 parent 1cd7bea commit d02a9b2
Show file tree
Hide file tree
Showing 19 changed files with 1,010 additions and 99 deletions.
2 changes: 1 addition & 1 deletion pfdl_scheduler/model/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __radd__(self, other) -> str:
return other + str(self)

def __eq__(self, __o: object) -> bool:
if isinstance(__o, Array):
if hasattr(__o, "values") and hasattr(__o, "length") and hasattr(__o, "type_of_elements"):
return (
self.values == __o.values
and self.length == __o.length
Expand Down
2 changes: 1 addition & 1 deletion pfdl_scheduler/model/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def parse_json(
error_msg = "Array definition in JSON are not supported in the PFDL."
error_handler.print_error(error_msg, context=instance_context)
elif isinstance(value, dict):
inner_struct = parse_json(value, error_handler, instance_context)
inner_struct = parse_json(value, error_handler, instance_context, instance_class)
instance.attributes[identifier] = inner_struct

return instance
7 changes: 6 additions & 1 deletion pfdl_scheduler/model/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def __init__(
self.context_dict: Dict = {}

def __eq__(self, __o: object) -> bool:
if isinstance(__o, Struct):
if (
hasattr(__o, "name")
and hasattr(__o, "attributes")
and hasattr(__o, "context")
and hasattr(__o, "context_dict")
):
return (
self.name == __o.name
and self.attributes == __o.attributes
Expand Down
66 changes: 44 additions & 22 deletions pfdl_scheduler/parser/pfdl_tree_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
# standard libraries
from typing import Dict, List, OrderedDict, Tuple, Union
from pfdl_scheduler.model.instance import Instance
from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses
from pfdl_scheduler.utils import helpers
from pfdl_scheduler.model.parallel import Parallel

# 3rd party
from antlr4.tree.Tree import TerminalNodeImpl

# local sources
from pfdl_scheduler.validation.error_handler import ErrorHandler

Expand Down Expand Up @@ -45,15 +49,22 @@ class PFDLTreeVisitor(PFDLParserVisitor):
Attributes:
error_handler: ErrorHandler instance for printing errors while visiting.
current_task: Reference to the currently visited Task. Every visitor method can access it.
pfdl_base_classes: `PFDLBaseClasses` instance for creating new objects.
"""

def __init__(self, error_handler: ErrorHandler) -> None:
def __init__(
self,
error_handler: ErrorHandler,
pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses(),
) -> None:
"""Initialize the object.
Args:
error_handler: ErrorHandler instance for printing errors while visiting.
error_handler: `ErrorHandler` instance for printing errors while visiting.
pfdl_base_classes: `PFDLBaseClasses` instance for creating new objects.
"""
self.error_handler: ErrorHandler = error_handler
self.pfdl_base_classes: PFDLBaseClasses = pfdl_base_classes
self.current_task: Task = None

def visitProgram(self, ctx) -> Process:
Expand All @@ -64,7 +75,7 @@ def visitProgram(self, ctx) -> Process:
for child in ctx.children:
process_component = self.visit(child)

if isinstance(process_component, Struct):
if isinstance(process_component, self.pfdl_base_classes.get_class("Struct")):
if process_component.name not in process.structs:
process.structs[process_component.name] = process_component
else:
Expand All @@ -73,15 +84,15 @@ def visitProgram(self, ctx) -> Process:
"is already defined"
)
self.error_handler.print_error(error_msg, context=child)
elif isinstance(process_component, Task):
elif isinstance(process_component, self.pfdl_base_classes.get_class("Task")):
if process_component.name not in process.tasks:
process.tasks[process_component.name] = process_component
else:
error_msg = (
f"A Task with the name '{process_component.name}' " "is already defined"
)
self.error_handler.print_error(error_msg, context=child)
elif isinstance(process_component, Instance):
elif isinstance(process_component, self.pfdl_base_classes.get_class("Instance")):
if process_component.name not in process.tasks:
process.instances[process_component.name] = process_component
else:
Expand All @@ -106,8 +117,7 @@ def execute_additional_tasks(self, process: Process) -> None:
self.add_inherited_attributes_to_structs(process)

def add_inherited_attributes_to_structs(self, process: Process) -> None:
"""
Tries to add attributes inherited from the respective parents to all child structs.
"""Tries to add attributes inherited from the respective parents to all child structs.
Throws an error if one parent struct name is found to be invalid.
"""
Expand All @@ -129,12 +139,19 @@ def visitProgram_statement(self, ctx: PFDLParser.Program_statementContext):
return self.visit(ctx.children[0])

def addInstancesToAllTasks(self, process: Process) -> None:
"""Adds all instances to the variables of all tasks in the process.
This method is necessary to use the instances in expressions.
Args:
process: The `Process` object containing all tasks and instances.
"""
for instance in process.instances.values():
for task in process.tasks.values():
task.variables[instance.name] = instance.struct_name

def visitStruct(self, ctx) -> Struct:
struct = self.pfdl_base_classes.struct()
struct = self.pfdl_base_classes.get_class("Struct")()
struct.name = ctx.STARTS_WITH_UPPER_C_STR().getText()
struct.context = ctx

Expand All @@ -160,7 +177,7 @@ def visitStruct_id(self, ctx: PFDLParser.Struct_idContext) -> str:
return ctx.children[0].getText()

def visitTask(self, ctx) -> Task:
task = Task()
task = self.pfdl_base_classes.get_class("Task")()
task.name = ctx.STARTS_WITH_LOWER_C_STR().getText()
task.context = ctx

Expand All @@ -170,8 +187,8 @@ def visitTask(self, ctx) -> Task:
task.input_parameters = self.visitTask_in(ctx.task_in())
task.context_dict[IN_KEY] = ctx.task_in()

for statement_ctx in ctx.statement():
statement = self.visitStatement(statement_ctx)
for statement_ctx in ctx.taskStatement():
statement = self.visitTaskStatement(statement_ctx)
task.statements.append(statement)
if ctx.task_out():
task.output_parameters = self.visitTask_out(ctx.task_out())
Expand All @@ -182,7 +199,7 @@ def visitTask(self, ctx) -> Task:
def visitInstance(self, ctx: PFDLParser.InstanceContext) -> Instance:
instance_name = ctx.STARTS_WITH_LOWER_C_STR().getText()
struct_name = self.visitStruct_id(ctx.struct_id())
instance = self.pfdl_base_classes.instance(
instance = self.pfdl_base_classes.get_class("Instance")(
name=instance_name, struct_name=struct_name, context=ctx
)
self.current_program_component = instance
Expand All @@ -192,8 +209,11 @@ def visitInstance(self, ctx: PFDLParser.InstanceContext) -> Instance:
)
# JSON value
if isinstance(attribute_value, Dict):
attribute_value = self.pfdl_base_classes.instance.from_json(
attribute_value, self.error_handler, ctx, self.pfdl_base_classes.instance
attribute_value = self.pfdl_base_classes.get_class("Instance").from_json(
attribute_value,
self.error_handler,
ctx,
self.pfdl_base_classes.get_class("Instance"),
)
instance.attributes[attribute_name] = attribute_value
instance.attribute_contexts[attribute_name] = attribute_assignment_ctx
Expand Down Expand Up @@ -251,7 +271,7 @@ def visitStatement(
return statement

def visitService_call(self, ctx: PFDLParser.Service_callContext) -> Service:
service = Service()
service = self.pfdl_base_classes.get_class("Service")()
service.context = ctx

service.name = ctx.STARTS_WITH_UPPER_C_STR().getText()
Expand Down Expand Up @@ -303,13 +323,15 @@ def visitParameter(self, ctx: PFDLParser.ParameterContext) -> Union[str, List[st
def visitStruct_initialization(self, ctx: PFDLParser.Struct_initializationContext) -> Struct:
json_string = ctx.json_object().getText()

struct = Struct.from_json(json_string, self.error_handler, ctx.json_object())
struct = self.pfdl_base_classes.get_class("Struct").from_json(
json_string, self.error_handler, ctx.json_object()
)
struct.name = ctx.STARTS_WITH_UPPER_C_STR().getText()
struct.context = ctx
return struct

def visitTask_call(self, ctx: PFDLParser.Task_callContext) -> TaskCall:
task_call = TaskCall()
task_call = self.pfdl_base_classes.get_class("TaskCall")()
task_call.name = ctx.STARTS_WITH_LOWER_C_STR().getText()
task_call.context = ctx

Expand All @@ -329,15 +351,15 @@ def visitTask_call(self, ctx: PFDLParser.Task_callContext) -> TaskCall:
return task_call

def visitParallel(self, ctx: PFDLParser.ParallelContext) -> Parallel:
parallel = Parallel()
parallel = self.pfdl_base_classes.get_class("Parallel")()
parallel.context = ctx
for task_call_context in ctx.task_call():
task_call = self.visitTask_call(task_call_context)
parallel.task_calls.append(task_call)
return parallel

def visitWhile_loop(self, ctx: PFDLParser.While_loopContext) -> WhileLoop:
while_loop = WhileLoop()
while_loop = self.pfdl_base_classes.get_class("WhileLoop")()
while_loop.context = ctx

while_loop.expression = self.visitExpression(ctx.expression())
Expand All @@ -348,7 +370,7 @@ def visitWhile_loop(self, ctx: PFDLParser.While_loopContext) -> WhileLoop:
return while_loop

def visitCounting_loop(self, ctx: PFDLParser.Counting_loopContext) -> CountingLoop:
counting_loop = CountingLoop()
counting_loop = self.pfdl_base_classes.get_class("CountingLoop")()
counting_loop.context = ctx

counting_loop.counting_variable = ctx.STARTS_WITH_LOWER_C_STR().getText()
Expand All @@ -367,7 +389,7 @@ def visitCounting_loop(self, ctx: PFDLParser.Counting_loopContext) -> CountingLo
return counting_loop

def visitCondition(self, ctx: PFDLParser.ConditionContext) -> Condition:
condition_statement = Condition()
condition_statement = self.pfdl_base_classes.get_class("Condition")()
condition_statement.context = ctx

condition_statement.expression = self.visitExpression(ctx.expression())
Expand Down Expand Up @@ -416,7 +438,7 @@ def visitPrimitive(self, ctx: PFDLParser.PrimitiveContext):
return ctx.getText()

def initializeArray(self, array_ctx: PFDLParser.ArrayContext, variable_type: str) -> Array:
array = Array()
array = self.pfdl_base_classes.get_class("Array")()
array.type_of_elements = variable_type
array.context = array_ctx
length = self.visitArray(array_ctx)
Expand Down
29 changes: 19 additions & 10 deletions pfdl_scheduler/petri_net/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from pfdl_scheduler.petri_net.drawer import draw_petri_net
from pfdl_scheduler.petri_net.callbacks import PetriNetCallbacks
from pfdl_scheduler.pfdl_base_classes import PFDLBaseClasses

plugins.load(["labels", "gv", "clusters"], "snakes.nets", "nets")

Expand Down Expand Up @@ -82,6 +83,7 @@ class PetriNetGenerator:
callbacks: A PetriNetCallbacks instance representing functions called while execution.
generate_test_ids: A boolean indicating if test ids (counting from 0) should be generated.
used_in_extension: A boolean indicating if the Generator is used within the extension.
pfdl_base_classes: An instance of `PFDLBaseClasses`.
"""

def __init__(
Expand All @@ -91,6 +93,7 @@ def __init__(
generate_test_ids: bool = False,
draw_net: bool = True,
file_name: str = "petri_net",
pfdl_base_classes: PFDLBaseClasses = PFDLBaseClasses(),
) -> None:
"""Initialize the object.
Expand All @@ -100,6 +103,7 @@ def __init__(
generate_test_ids: A boolean indicating if test ids (counting from 0) should be generated.
draw_net: A boolean indicating if the petri net should be drawn.
file_name: The desired filename of the petri net image.
pfdl_base_classes: An instance of `PFDLBaseClasses`.
"""

if used_in_extension:
Expand All @@ -117,11 +121,12 @@ def __init__(
self.transition_dict: OrderedDict = OrderedDict()
self.place_dict: Dict = {}
self.task_started_uuid: str = ""
self.callbacks: PetriNetCallbacks = PetriNetCallbacks()
self.callbacks: PetriNetCallbacks = pfdl_base_classes.get_class("PetriNetCallbacks")()
self.generate_test_ids: bool = generate_test_ids
self.used_in_extension: bool = used_in_extension
self.tree = None
self.file_name = file_name
self.pfdl_base_classes = pfdl_base_classes

def add_callback(self, transition_uuid: str, callback_function: Callable, *args: Any) -> None:
"""Registers the given callback function in the transition_dict.
Expand Down Expand Up @@ -156,7 +161,7 @@ def generate_petri_net(self, process: Process) -> PetriNet:
group_uuid = str(uuid.uuid4())
self.tree = Node(group_uuid, start_task.name)

task_context = TaskAPI(start_task, None)
task_context = self.pfdl_base_classes.get_class("TaskAPI")(start_task, None)
if self.generate_test_ids:
task_context.uuid = "0"

Expand Down Expand Up @@ -244,17 +249,17 @@ def generate_statements(
in_loop,
)

if isinstance(statement, Service):
if isinstance(statement, self.pfdl_base_classes.get_class("Service")):
connection_uuids = [self.generate_service(*args)]
elif isinstance(statement, TaskCall):
elif isinstance(statement, self.pfdl_base_classes.get_class("TaskCall")):
connection_uuids = self.generate_task_call(*args)
elif isinstance(statement, Parallel):
elif isinstance(statement, self.pfdl_base_classes.get_class("Parallel")):
connection_uuids = [self.generate_parallel(*args)]
elif isinstance(statement, CountingLoop):
elif isinstance(statement, self.pfdl_base_classes.get_class("CountingLoop")):
connection_uuids = [self.generate_counting_loop(*args)]
elif isinstance(statement, WhileLoop):
elif isinstance(statement, self.pfdl_base_classes.get_class("WhileLoop")):
connection_uuids = [self.generate_while_loop(*args)]
elif isinstance(statement, Condition):
elif isinstance(statement, self.pfdl_base_classes.get_class("Condition")):
connection_uuids = self.generate_condition(*args)
else:
connection_uuids = self.handle_other_statements(*args)
Expand All @@ -280,7 +285,9 @@ def generate_service(
group_uuid = str(uuid.uuid4())
service_node = Node(group_uuid, service.name, node)

service_api = ServiceAPI(service, task_context, in_loop=in_loop)
service_api = self.pfdl_base_classes.get_class("ServiceAPI")(
service, task_context, in_loop=in_loop
)

service_started_uuid = create_place(service.name + " started", self.net, service_node)
service_finished_uuid = create_place(service.name + " finished", self.net, service_node)
Expand Down Expand Up @@ -327,7 +334,9 @@ def generate_task_call(
The uuids of the last transitions of the TaskCall petri net component.
"""
called_task = self.tasks[task_call.name]
new_task_context = TaskAPI(called_task, task_context, task_call=task_call, in_loop=in_loop)
new_task_context = self.pfdl_base_classes.get_class("TaskAPI")(
called_task, task_context, task_call=task_call, in_loop=in_loop
)

group_uuid = str(uuid.uuid4())
task_node = Node(group_uuid, task_call.name, node)
Expand Down
Loading

0 comments on commit d02a9b2

Please sign in to comment.