From 2c43000b89b2ec8a471cae70c132e1666154f2e7 Mon Sep 17 00:00:00 2001 From: Julian Buechel Date: Mon, 6 Jan 2025 08:58:23 +0100 Subject: [PATCH] added FP preset Signed-off-by: Julian Buechel --- src/aihwkit/simulator/presets/__init__.py | 3 +-- src/aihwkit/simulator/presets/inference.py | 24 +++++++++++++++++++++- tests/test_presets.py | 21 ++++++++++++++++++- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/aihwkit/simulator/presets/__init__.py b/src/aihwkit/simulator/presets/__init__.py index 39f50907..62730742 100644 --- a/src/aihwkit/simulator/presets/__init__.py +++ b/src/aihwkit/simulator/presets/__init__.py @@ -68,8 +68,7 @@ MixedPrecisionGokmenVlasovPreset, MixedPrecisionPCMPreset, ) -from .inference import StandardHWATrainingPreset - +from .inference import StandardHWATrainingPreset, FloatingPointPreset from .devices import ( ReRamESPresetDevice, ReRamSBPresetDevice, diff --git a/src/aihwkit/simulator/presets/inference.py b/src/aihwkit/simulator/presets/inference.py index 317385e2..057b74c1 100644 --- a/src/aihwkit/simulator/presets/inference.py +++ b/src/aihwkit/simulator/presets/inference.py @@ -10,7 +10,7 @@ from typing import Optional from dataclasses import dataclass, field -from aihwkit.simulator.configs.configs import InferenceRPUConfig +from aihwkit.simulator.configs.configs import InferenceRPUConfig, TorchInferenceRPUConfig from aihwkit.simulator.parameters import ( MappingParameter, IOParameters, @@ -34,6 +34,28 @@ # Inference +@dataclass +class FloatingPointPreset(TorchInferenceRPUConfig): + """Preset configuration for FP-like AIMC (Analog In-Mememory Compute) + accuracy evaluation/training. + + This preset configuration does not inject any noise in any form (weight noise + quantization etc.) and is equivalent to the FP model. + """ + + mapping: MappingParameter = field( + default_factory=lambda: MappingParameter(max_input_size=0, max_output_size=0) + ) + + forward: IOParameters = field(default_factory=lambda: IOParameters(is_perfect=True)) + + pre_post: PrePostProcessingParameter = field( + default_factory=lambda: PrePostProcessingParameter( + input_range=InputRangeParameter(enable=False) + ) + ) + + @dataclass class StandardHWATrainingPreset(InferenceRPUConfig): """Preset configuration for AIMC (Analog In-Mememory Compute) diff --git a/tests/test_presets.py b/tests/test_presets.py index 78eee455..d78ea4da 100644 --- a/tests/test_presets.py +++ b/tests/test_presets.py @@ -6,7 +6,7 @@ """Tests for analog presets.""" -from torch import Tensor +from torch import Tensor, randn from aihwkit.simulator.tiles.analog import AnalogTile from aihwkit.simulator.presets import ( @@ -50,6 +50,7 @@ TTv2EcRamPreset, TTv2EcRamMOPreset, TTv2IdealizedPreset, + FloatingPointPreset, ) from .helpers.decorators import parametrize_over_presets from .helpers.testcases import AihwkitTestCase @@ -131,3 +132,21 @@ def test_tile_preset(self): self.assertEqual(tile_biases, None) # TODO: disabled as the comparison needs to take into account noise # self.assertTensorAlmostEqual(tile_weights, weights) + + +class PresetTestFP(AihwkitTestCase): + """Test for FP preset.""" + + def test_tile_preset(self): + """Test fwd behavior of FP preset.""" + out_size = 2 + in_size = 3 + weights = randn(out_size, in_size) + inp = randn(in_size) + fp_out = inp @ weights.T + + rpu_config = FloatingPointPreset() + analog_tile = AnalogTile(out_size, in_size, rpu_config, bias=False) + analog_tile.set_weights(weights) + tile_out = analog_tile(inp) + self.assertTensorAlmostEqual(fp_out, tile_out)