Для каждого проезда (пролета) требуется логировать, время, управление и состояние робота. Это можно длеать любым удобным способом, например ROS нодой для логирования - logger. Каждый проезд (это один датасет) может быть представлен в виде 3 файлов:
- time.csv - файл с временными отметками (timestamps)
- state.csv - файл с состоянием робота
- control.csv - файл с текущим управлением
Для собранных данных нет строгого формата, главное чтобы было удобно загрузить их в класс датасета
Класс для представления датасета, и по необходимости, инструменты для взаимодействия с ним. Обязательно должны быть атрибуты:
- data_t (torch.tensor of shape [num_samples, 1]): тензор с последовательность временных отметок
- data_x (torch.tensor of shape [num_samples, robot_state]): тензор с последовательностью состояния робота
- data_u (torch.tensor of shape [num_samples, control]): тензор с последовательностью управляющих воздействий
Пример RosbotDataset
Класс для представления модели робота
Должен наследоваться от torch.nn.Module
Обязательные атрибуты:
model
- (torch.nn.Module
или класс отнаследованный отtorch.nn.Module
): нейосеть, которую будут обучать
Обязательные методы:
get_optimizer
- возвращаетoptimizer
(torch.optim
)get_loss_fn
- возвращат loss функцию. Loss функция может быть определена отдельным классом (см. ниже)update_state
- возвращает следующее состояние робота с использованием нейросетевой модели (атрибутmodel
). Аргументы:state (torch.tensor of shape [num_smaples, robot_state])
- текущее состояния,control (torch.tensor of shape [num_smaples, control])
- текущее управление,dt (torch.tensor of shape [num_smaples, robot_state] or float)
- временной шаг
- forward - вычисление, выполняемое при каждом вызове модели
- calc_metrics - расчет вспомогателных метрик (возвращает словарь, где ключ - нвазание метрики и поле это значение метрики )
- plot_trajectories - Построение графиков. Аргументы: predicted_traj (torch.tensor of shape [num_smaples, robot_state]) - предсказанная траектория ground_truth_traj (torch.tensor of shape [num_smaples, robot_state]) - истинная траектория Возвращает: matplotlib.pyplot.fig
Пример RosbotModel Пример RosbotLinearModel
Класс для расчета ошибки (loss). Данный класс нужен, так как для разных роботов loss может отличаться (например при расчете loss не обязательно учитывать каждый элемент вектора состояний). Должен наследоваться от torch.nn.Module Обязательные методы:
- forward - вычисление, выполняемое при каждом вызове (расчет loss)
Пример RosbotModelLoss
Обучение в нейросети происходит в Trainer.fit Аргументы:
- model - модель робота
- train_data - лист датасетов для обучения
- val_data - лист датасетов для валидирования
- epochs_num - количество эпох обучения
- batch_size - размер батча
- rollout_size - длина проезда для обучения
- main_metric - название ключевой метрики, по которой выбирается лучшая модель
- device - дейвайс на котором производятся вычисления(cuda или cpu)
- use_wandb - флаг, если true, графики и метрики будут логироваться в wandb
Пример rosbot_train
Для анализа результатов рекомендуется использовать wandb
Для быстрой смены гиперпараметров возможно использование конфиг параметров. Пример конфига gz-rosbot_1.yaml Пример использования rosbot_train