You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.6 KiB
92 lines
2.6 KiB
"""
|
|
src/scenario_loader.py
|
|
----------------------
|
|
Dynamic scenario discovery and loading.
|
|
|
|
Usage
|
|
-----
|
|
from scenario_loader import load_scenario, list_scenarios
|
|
|
|
scenario = load_scenario("braking") # returns BrakingScenario()
|
|
names = list_scenarios() # ['braking', 'cutin', 'obstacle']
|
|
"""
|
|
|
|
import importlib
|
|
import inspect
|
|
import pkgutil
|
|
import sys
|
|
import os
|
|
|
|
# Put the project root (the parent of this file's directory) at the front of
# sys.path so the 'scenarios' package can be imported regardless of the
# directory the program was launched from. The membership check keeps
# repeated imports/reloads from adding duplicate entries.
PROJECT_ROOT = os.path.dirname(
    os.path.dirname(
        os.path.abspath(__file__),
    ),
)
if PROJECT_ROOT not in sys.path:
    sys.path[:0] = [PROJECT_ROOT]
|
|
|
|
from scenarios.base import ScenarioBase
|
|
|
|
|
|
def list_scenarios() -> list[str]:
    """
    Discover all scenario modules inside the 'scenarios' package.

    Returns a sorted list of module names (e.g. ['braking', 'cutin', 'obstacle']).

    Only the shared 'base' module is excluded. ``pkgutil.iter_modules`` never
    reports the package's own ``__init__``, so it needs no explicit filter.
    Sub-packages inside 'scenarios' and modules whose names start with an
    underscore are NOT filtered out and will appear in the result.
    """
    # Imported lazily so the package is resolved only when discovery is
    # actually requested (the module-level sys.path setup must run first).
    import scenarios as _scenarios_pkg

    return sorted(
        info.name
        for info in pkgutil.iter_modules(_scenarios_pkg.__path__)
        if info.name != "base"
    )
|
|
|
|
|
|
def _is_concrete_scenario(obj: type) -> bool:
    """Return True if *obj* is an instantiable (non-abstract) ScenarioBase subclass."""
    return (
        issubclass(obj, ScenarioBase)
        and obj is not ScenarioBase
        and not inspect.isabstract(obj)
    )


def _find_scenario_class(module):
    """
    Return the concrete ScenarioBase subclass to instantiate from *module*,
    or None if the module exposes none.

    Classes defined in *module* itself take precedence over classes it merely
    imports, so a scenario module that imports (or subclasses) another
    concrete scenario still loads its own class. If the module defines none
    itself, imported concrete subclasses are used as a fallback (e.g. a
    module that only re-exports a scenario). Within each group, candidates
    are ordered alphabetically by attribute name, as returned by
    ``inspect.getmembers``.
    """
    candidates = [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if _is_concrete_scenario(obj)
    ]
    defined_here = [obj for obj in candidates if obj.__module__ == module.__name__]
    pool = defined_here or candidates
    return pool[0] if pool else None


def load_scenario(name: str) -> ScenarioBase:
    """
    Dynamically import scenario module 'scenarios.<name>' and return
    an instance of the concrete ScenarioBase subclass found within it.

    Parameters
    ----------
    name : str
        Scenario identifier, e.g. 'braking', 'cutin', 'obstacle'.

    Returns
    -------
    ScenarioBase
        An instantiated, ready-to-use scenario object (constructed with no
        arguments).

    Raises
    ------
    ModuleNotFoundError
        If no module 'scenarios.<name>' exists; the message lists the
        available scenarios. If the scenario module exists but one of ITS
        imports is missing, that original error is re-raised unchanged so
        the real missing dependency is visible.
    ValueError
        If the module contains no concrete ScenarioBase subclass.
    """
    module_path = f"scenarios.{name}"

    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as exc:
        # Only translate the error when the requested module (or one of its
        # parent packages) is the thing that is missing. Anything else is a
        # broken import inside the scenario module and must not be masked.
        missing = exc.name or ""
        if missing != module_path and not module_path.startswith(missing + "."):
            raise
        available = list_scenarios()
        raise ModuleNotFoundError(
            f"Scenario '{name}' not found. "
            f"Available scenarios: {available}",
            name=module_path,
        ) from exc

    scenario_cls = _find_scenario_class(module)
    if scenario_cls is None:
        raise ValueError(
            f"No concrete ScenarioBase subclass found in module '{module_path}'. "
            "Ensure your scenario class subclasses ScenarioBase and implements all abstract methods."
        )

    instance = scenario_cls()
    print(f"[ScenarioLoader] Loaded scenario: '{instance.name}' ({scenario_cls.__name__})")
    return instance
|