diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6c32926f..06bd9852 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -66,6 +66,10 @@ jobs: - name: Run tests run: | make test + - name: Typechecking + if: ${{ startsWith(runner.os, 'macOS') }} + run: | + make mypy report-coverage: # Report coverage from python 3.10 and mac-os. May change later runs-on: ${{ matrix.os }} strategy: diff --git a/.gitignore b/.gitignore index 62305a52..da20c6d3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ __pycache__/ # C extensions *.so +*.swp + # Distribution / packaging .Python build/ diff --git a/Makefile b/Makefile index 674c121b..89b4b179 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ #* Variables PYTHON := python3 PYTHONPATH := `pwd` -AUTOFLAKE8_ARGS := -r --exclude '__init__.py' --keep-pass-after-docstring +AUTOFLAKE_ARGS := -r #* Poetry .PHONY: poetry-download poetry-download: @@ -47,19 +47,23 @@ flake8: poetry run flake8 --version poetry run flake8 elastica tests -.PHONY: autoflake8-check -autoflake8-check: - poetry run autoflake8 --version - poetry run autoflake8 $(AUTOFLAKE8_ARGS) elastica tests examples - poetry run autoflake8 --check $(AUTOFLAKE8_ARGS) elastica tests examples +.PHONY: autoflake-check +autoflake-check: + poetry run autoflake --version + poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples + poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples -.PHONY: autoflake8-format -autoflake8-format: - poetry run autoflake8 --version - poetry run autoflake8 --in-place $(AUTOFLAKE8_ARGS) elastica tests examples +.PHONY: autoflake-format +autoflake-format: + poetry run autoflake --version + poetry run autoflake --in-place $(AUTOFLAKE_ARGS) elastica tests examples .PHONY: format-codestyle -format-codestyle: black flake8 +format-codestyle: black autoflake-format + +.PHONY: mypy +mypy: + poetry run mypy --config-file pyproject.toml elastica .PHONY: test test: @@ -74,14 +78,14 @@ test_coverage_xml: NUMBA_DISABLE_JIT=1 poetry run pytest --cov=elastica --cov-report=xml .PHONY: check-codestyle -check-codestyle: black-check flake8 autoflake8-check +check-codestyle: black-check flake8 autoflake-check .PHONY: formatting formatting: format-codestyle .PHONY: update-dev-deps update-dev-deps: - poetry add -D pytest@latest coverage@latest pytest-html@latest pytest-cov@latest black@latest + poetry add -D mypy@latest pytest@latest coverage@latest pytest-html@latest pytest-cov@latest black@latest #* Cleaning .PHONY: pycache-remove @@ -92,6 +96,10 @@ pycache-remove: dsstore-remove: find . | grep -E ".DS_Store" | xargs rm -rf +.PHONY: mypycache-remove +mypycache-remove: + find . | grep -E ".mypy_cache" | xargs rm -rf + .PHONY: ipynbcheckpoints-remove ipynbcheckpoints-remove: find . | grep -E ".ipynb_checkpoints" | xargs rm -rf @@ -105,7 +113,7 @@ build-remove: rm -rf build/ .PHONY: cleanup -cleanup: pycache-remove dsstore-remove ipynbcheckpoints-remove pytestcache-remove +cleanup: pycache-remove dsstore-remove ipynbcheckpoints-remove pytestcache-remove mypycache-remove all: format-codestyle cleanup test diff --git a/docs/advanced/PackageDesign.md b/docs/advanced/PackageDesign.md index 857a3c04..1ef46121 100644 --- a/docs/advanced/PackageDesign.md +++ b/docs/advanced/PackageDesign.md @@ -1,8 +1,114 @@ -# Code Design: Mixin and Composition +# Code Design + +## Mixin and Composition Elastica package follows Mixin and composition design patterns that may be unfamiliar to users. Here is a collection of references that introduce the package design. -## References +### References - [stackoverflow discussion on Mixin](https://stackoverflow.com/questions/533631/what-is-a-mixin-and-why-are-they-useful) - [example of Mixin: python collections](https://docs.python.org/dev/library/collections.abc.html) + +## Duck Typing + +Elastica package uses duck typing to allow users to define their own classes and functions. Here is a `typing.Protocol` structure that is used in the package. + +### Systems + +``` {mermaid} + flowchart LR + direction RL + subgraph Systems Protocol + direction RL + SLBD(SlenderBodyGeometryProtool) + SymST["SymplecticSystem:\n• KinematicStates/Rates\n• DynamicStates/Rates"] + style SymST text-align:left + ExpST["ExplicitSystem:\n• States (Unused)"] + style ExpST text-align:left + P((position\nvelocity\nacceleration\n..)) --> SLBD + subgraph StaticSystemType + Surface + Mesh + end + subgraph SystemType + direction TB + Rod + RigidBody + end + SLBD --> SymST + SystemType --> SymST + SLBD --> ExpST + SystemType --> ExpST + end + subgraph Timestepper Protocol + direction TB + StP["StepperProtocol\n• step(SystemCollection, time, dt)"] + style StP text-align:left + SymplecticStepperProtocol["SymplecticStepperProtocol\n• PositionVerlet"] + style SymplecticStepperProtocol text-align:left + ExpplicitStepperProtocol["ExpplicitStepperProtocol\n(Unused)"] + end + + subgraph SystemCollection + + end + SymST --> SystemCollection --> SymplecticStepperProtocol + ExpST --> SystemCollection --> ExpplicitStepperProtocol + StaticSystemType --> SystemCollection + +``` + +### System Collection (Build memory block) + +``` {mermaid} + flowchart LR + Sys((Systems)) + St((Stepper)) + subgraph SystemCollectionType + direction LR + StSys["StaticSystem:\n• Surface\n• Mesh"] + style StSys text-align:left + DynSys["DynamicSystem:\n• Rod\n  • CosseratRod\n• RigidBody\n  • Sphere\n  • Cylinder"] + style DynSys text-align:left + + BlDynSys["BlockSystemType:\n• BlockCosseratRod\n• BlockRigidBody"] + style BlDynSys text-align:left + + F{{"Feature Group (OperatorGroup):\n• Synchronize\n• Constrain values\n• Constrain rates\n• Callback"}} + style F text-align:left + end + Sys --> StSys --> F + Sys --> DynSys -->|Finalize| BlDynSys --> St + DynSys --> F <--> St + +``` + +### System Collection (Features) + +``` {mermaid} + flowchart LR + Sys((Systems)) + St((Stepper)) + subgraph SystemCollectionType + direction LR + StSys["StaticSystem:\n• Surface\n• Mesh"] + style StSys text-align:left + DynSys["DynamicSystem:\n• Rod\n  • CosseratRod\n• RigidBody\n  • Sphere\n  • Cylinder"] + style DynSys text-align:left + + subgraph Feature + direction LR + Forcing -->|add_forcing_to| Synchronize + Constraints -->|constrain| ConstrainValues + Constraints -->|constrain| ConstrainRates + Contact -->|detect_contact_between| Synchronize + Connection -->|connect| Synchronize + Damping -->|dampen| ConstrainRates + Callback -->|collect_diagnosis| CallbackGroup + end + end + Sys --> StSys --> Feature + Sys --> DynSys + DynSys --> Feature <--> St + +``` diff --git a/docs/conf.py b/docs/conf.py index 6d204796..f2ca1f39 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,6 +41,7 @@ #'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.mathjax', + "sphinxcontrib.mermaid", 'numpydoc', 'myst_parser', ] @@ -98,3 +99,6 @@ # -- Options for numpydoc --------------------------------------------------- numpydoc_show_class_members = False + +# -- Mermaid configuration --------------------------------------------------- +mermaid_params = ['--theme', 'neutral'] diff --git a/elastica/__init__.py b/elastica/__init__.py index a55a6fb0..21dfddb7 100644 --- a/elastica/__init__.py +++ b/elastica/__init__.py @@ -1,7 +1,5 @@ from collections import defaultdict from elastica.rod.knot_theory import ( - KnotTheory, - KnotTheoryCompatibleProtocol, compute_link, compute_twist, compute_writhe, @@ -19,8 +17,6 @@ GeneralConstraint, FixedConstraint, HelicalBucklingBC, - FreeRod, - OneEndFixedRod, ) from elastica.external_forces import ( NoForces, @@ -38,10 +34,8 @@ ) from elastica.joint import ( FreeJoint, - ExternalContact, FixedJoint, HingeJoint, - SelfContact, ) from elastica.contact_forces import ( NoContact, @@ -79,7 +73,6 @@ ) from elastica._linalg import levi_civita_tensor from elastica.utils import isqrt -from elastica.typing import RodType, SystemType, AllowedContactType from elastica.timestepper import ( integrate, PositionVerlet, diff --git a/elastica/_calculus.py b/elastica/_calculus.py index eca9829b..3cdda50d 100644 --- a/elastica/_calculus.py +++ b/elastica/_calculus.py @@ -1,23 +1,16 @@ __doc__ = """ Quadrature and difference kernels """ import numpy as np from numpy import zeros, empty +from numpy.typing import NDArray +import numba from numba import njit from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import ( _reset_vector_ghost, ) -import functools -@functools.lru_cache(maxsize=2) -def _get_zero_array(dim, ndim): - if ndim == 1: - return 0.0 - if ndim == 2: - return np.zeros((dim, 1)) - - -@njit(cache=True) -def _trapezoidal(array_collection): +@njit(cache=True) # type: ignore +def _trapezoidal(array_collection: NDArray[np.float64]) -> NDArray[np.float64]: """ Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way @@ -62,8 +55,10 @@ def _trapezoidal(array_collection): return temp_collection -@njit(cache=True) -def _trapezoidal_for_block_structure(array_collection, ghost_idx): +@njit(cache=True) # type: ignore +def _trapezoidal_for_block_structure( + array_collection: NDArray[np.float64], ghost_idx: NDArray[np.int32] +) -> NDArray[np.float64]: """ Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form specifically for the block structure implementation and there is a reset function call, to reset @@ -114,8 +109,10 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx): return temp_collection -@njit(cache=True) -def _two_point_difference(array_collection): +@njit(cache=True) # type: ignore +def _two_point_difference( + array_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ This function does differentiation. @@ -155,8 +152,10 @@ def _two_point_difference(array_collection): return temp_collection -@njit(cache=True) -def _two_point_difference_for_block_structure(array_collection, ghost_idx): +@njit(cache=True) # type: ignore +def _two_point_difference_for_block_structure( + array_collection: NDArray[np.float64], ghost_idx: NDArray[np.int32] +) -> NDArray[np.float64]: """ This function does the differentiation, for Cosserat rod model equations. This form specifically for the block structure implementation and there is a reset function call, to @@ -206,8 +205,8 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx): return temp_collection -@njit(cache=True) -def _difference(vector): +@njit(cache=True) # type: ignore +def _difference(vector: NDArray[np.float64]) -> NDArray[np.float64]: """ This function computes difference between elements of a batch vector. @@ -237,8 +236,8 @@ def _difference(vector): return output_vector -@njit(cache=True) -def _average(vector): +@njit(cache=True) # type: ignore +def _average(vector: NDArray[np.float64]) -> NDArray[np.float64]: """ This function computes the average between elements of a vector. @@ -267,8 +266,10 @@ def _average(vector): return output_vector -@njit(cache=True) -def _clip_array(input_array, vmin, vmax): +@njit(cache=True) # type: ignore +def _clip_array( + input_array: NDArray[np.float64], vmin: np.float64, vmax: np.float64 +) -> NDArray[np.float64]: """ This function clips an array values between user defined minimum and maximum @@ -303,8 +304,8 @@ def _clip_array(input_array, vmin, vmax): return input_array -@njit(cache=True) -def _isnan_check(array): +@njit(cache=True) # type: ignore +def _isnan_check(array: NDArray) -> bool: """ This function checks if there is any nan inside the array. If there is nan, it returns True boolean. @@ -324,7 +325,7 @@ def _isnan_check(array): Python version: 2.24 µs ± 96.1 ns per loop This version: 479 ns ± 6.49 ns per loop """ - return np.isnan(array).any() + return bool(np.isnan(array).any()) position_difference_kernel = _difference diff --git a/elastica/_contact_functions.py b/elastica/_contact_functions.py index 245d9231..d1456092 100644 --- a/elastica/_contact_functions.py +++ b/elastica/_contact_functions.py @@ -22,30 +22,32 @@ _batch_matrix_transpose, _batch_vec_oneD_vec_cross, ) -import numba import numpy as np +from numpy.typing import NDArray +from numba import njit -@numba.njit(cache=True) + +@njit(cache=True) # type: ignore def _calculate_contact_forces_rod_cylinder( - x_collection_rod, - edge_collection_rod, - x_cylinder_center, - x_cylinder_tip, - edge_cylinder, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_cylinder, - external_torques_cylinder, - cylinder_director_collection, - velocity_rod, - velocity_cylinder, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, + x_collection_rod: NDArray[np.float64], + edge_collection_rod: NDArray[np.float64], + x_cylinder_center: NDArray[np.float64], + x_cylinder_tip: NDArray[np.float64], + edge_cylinder: NDArray[np.float64], + radii_sum: NDArray[np.float64], + length_sum: NDArray[np.float64], + internal_forces_rod: NDArray[np.float64], + external_forces_rod: NDArray[np.float64], + external_forces_cylinder: NDArray[np.float64], + external_torques_cylinder: NDArray[np.float64], + cylinder_director_collection: NDArray[np.float64], + velocity_rod: NDArray[np.float64], + velocity_cylinder: NDArray[np.float64], + contact_k: np.float64, + contact_nu: np.float64, + velocity_damping_coefficient: np.float64, + friction_coefficient: np.float64, ) -> None: # We already pass in only the first n_elem x n_points = x_collection_rod.shape[1] @@ -153,24 +155,24 @@ def _calculate_contact_forces_rod_cylinder( ) -@numba.njit(cache=True) +@njit(cache=True) # type: ignore def _calculate_contact_forces_rod_rod( - x_collection_rod_one, - radius_rod_one, - length_rod_one, - tangent_rod_one, - velocity_rod_one, - internal_forces_rod_one, - external_forces_rod_one, - x_collection_rod_two, - radius_rod_two, - length_rod_two, - tangent_rod_two, - velocity_rod_two, - internal_forces_rod_two, - external_forces_rod_two, - contact_k, - contact_nu, + x_collection_rod_one: NDArray[np.float64], + radius_rod_one: NDArray[np.float64], + length_rod_one: NDArray[np.float64], + tangent_rod_one: NDArray[np.float64], + velocity_rod_one: NDArray[np.float64], + internal_forces_rod_one: NDArray[np.float64], + external_forces_rod_one: NDArray[np.float64], + x_collection_rod_two: NDArray[np.float64], + radius_rod_two: NDArray[np.float64], + length_rod_two: NDArray[np.float64], + tangent_rod_two: NDArray[np.float64], + velocity_rod_two: NDArray[np.float64], + internal_forces_rod_two: NDArray[np.float64], + external_forces_rod_two: NDArray[np.float64], + contact_k: np.float64, + contact_nu: np.float64, ) -> None: # We already pass in only the first n_elem x n_points_rod_one = x_collection_rod_one.shape[1] @@ -270,16 +272,16 @@ def _calculate_contact_forces_rod_rod( external_forces_rod_two[..., j + 1] += net_contact_force -@numba.njit(cache=True) +@njit(cache=True) # type: ignore def _calculate_contact_forces_self_rod( - x_collection_rod, - radius_rod, - length_rod, - tangent_rod, - velocity_rod, - external_forces_rod, - contact_k, - contact_nu, + x_collection_rod: NDArray[np.float64], + radius_rod: NDArray[np.float64], + length_rod: NDArray[np.float64], + tangent_rod: NDArray[np.float64], + velocity_rod: NDArray[np.float64], + external_forces_rod: NDArray[np.float64], + contact_k: np.float64, + contact_nu: np.float64, ) -> None: # We already pass in only the first n_elem x n_points_rod = x_collection_rod.shape[1] @@ -358,26 +360,26 @@ def _calculate_contact_forces_self_rod( external_forces_rod[..., j + 1] += net_contact_force -@numba.njit(cache=True) +@njit(cache=True) # type: ignore def _calculate_contact_forces_rod_sphere( - x_collection_rod, - edge_collection_rod, - x_sphere_center, - x_sphere_tip, - edge_sphere, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_sphere, - external_torques_sphere, - sphere_director_collection, - velocity_rod, - velocity_sphere, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, + x_collection_rod: NDArray[np.float64], + edge_collection_rod: NDArray[np.float64], + x_sphere_center: NDArray[np.float64], + x_sphere_tip: NDArray[np.float64], + edge_sphere: NDArray[np.float64], + radii_sum: NDArray[np.float64], + length_sum: NDArray[np.float64], + internal_forces_rod: NDArray[np.float64], + external_forces_rod: NDArray[np.float64], + external_forces_sphere: NDArray[np.float64], + external_torques_sphere: NDArray[np.float64], + sphere_director_collection: NDArray[np.float64], + velocity_rod: NDArray[np.float64], + velocity_sphere: NDArray[np.float64], + contact_k: np.float64, + contact_nu: np.float64, + velocity_damping_coefficient: np.float64, + friction_coefficient: np.float64, ) -> None: # We already pass in only the first n_elem x n_points = x_collection_rod.shape[1] @@ -484,20 +486,20 @@ def _calculate_contact_forces_rod_sphere( ) -@numba.njit(cache=True) +@njit(cache=True) # type: ignore def _calculate_contact_forces_rod_plane( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - radius, - mass, - position_collection, - velocity_collection, - internal_forces, - external_forces, -): + plane_origin: NDArray[np.float64], + plane_normal: NDArray[np.float64], + surface_tol: np.float64, + k: np.float64, + nu: np.float64, + radius: NDArray[np.float64], + mass: NDArray[np.float64], + position_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + internal_forces: NDArray[np.float64], + external_forces: NDArray[np.float64], +) -> tuple[NDArray[np.float64], NDArray[np.intp]]: """ This function computes the plane force response on the element, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper @@ -569,32 +571,32 @@ def _calculate_contact_forces_rod_plane( return (_batch_norm(plane_response_force), no_contact_point_idx) -@numba.njit(cache=True) +@njit(cache=True) # type: ignore def _calculate_contact_forces_rod_plane_with_anisotropic_friction( - plane_origin, - plane_normal, - surface_tol, - slip_velocity_tol, - k, - nu, - kinetic_mu_forward, - kinetic_mu_backward, - kinetic_mu_sideways, - static_mu_forward, - static_mu_backward, - static_mu_sideways, - radius, - mass, - tangents, - position_collection, - director_collection, - velocity_collection, - omega_collection, - internal_forces, - external_forces, - internal_torques, - external_torques, -): + plane_origin: NDArray[np.float64], + plane_normal: NDArray[np.float64], + surface_tol: np.float64, + slip_velocity_tol: np.float64, + k: np.float64, + nu: np.float64, + kinetic_mu_forward: np.float64, + kinetic_mu_backward: np.float64, + kinetic_mu_sideways: np.float64, + static_mu_forward: np.float64, + static_mu_backward: np.float64, + static_mu_sideways: np.float64, + radius: NDArray[np.float64], + mass: NDArray[np.float64], + tangents: NDArray[np.float64], + position_collection: NDArray[np.float64], + director_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + omega_collection: NDArray[np.float64], + internal_forces: NDArray[np.float64], + external_forces: NDArray[np.float64], + internal_torques: NDArray[np.float64], + external_torques: NDArray[np.float64], +) -> None: ( plane_response_force_mag, no_contact_point_idx, @@ -782,18 +784,18 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction( ) -@numba.njit(cache=True) +@njit(cache=True) # type: ignore def _calculate_contact_forces_cylinder_plane( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - length, - position_collection, - velocity_collection, - external_forces, -): + plane_origin: NDArray[np.float64], + plane_normal: NDArray[np.float64], + surface_tol: np.float64, + k: np.float64, + nu: np.float64, + length: NDArray[np.float64], + position_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + external_forces: NDArray[np.float64], +) -> tuple[NDArray[np.float64], NDArray[np.intp]]: # Compute plane response force # total_forces = system.internal_forces + system.external_forces diff --git a/elastica/_linalg.py b/elastica/_linalg.py index a4995ab2..c4ce37d3 100644 --- a/elastica/_linalg.py +++ b/elastica/_linalg.py @@ -1,5 +1,6 @@ __doc__ = """ Convenient linear algebra kernels """ import numpy as np +from numpy.typing import NDArray from numba import njit from numpy import sqrt import functools @@ -8,7 +9,7 @@ @functools.lru_cache(maxsize=1) -def levi_civita_tensor(dim): +def levi_civita_tensor(dim: int) -> NDArray[np.float64]: """ Parameters @@ -27,8 +28,10 @@ def levi_civita_tensor(dim): return epsilon -@njit(cache=True) -def _batch_matvec(matrix_collection, vector_collection): +@njit(cache=True) # type: ignore +def _batch_matvec( + matrix_collection: NDArray[np.float64], vector_collection: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function does batch matrix and batch vector product @@ -58,8 +61,11 @@ def _batch_matvec(matrix_collection, vector_collection): return output_vector -@njit(cache=True) -def _batch_matmul(first_matrix_collection, second_matrix_collection): +@njit(cache=True) # type: ignore +def _batch_matmul( + first_matrix_collection: NDArray[np.float64], + second_matrix_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ This is batch matrix matrix multiplication function. Only batch of 3x3 matrices can be multiplied. @@ -92,8 +98,11 @@ def _batch_matmul(first_matrix_collection, second_matrix_collection): return output_matrix -@njit(cache=True) -def _batch_cross(first_vector_collection, second_vector_collection): +@njit(cache=True) # type: ignore +def _batch_cross( + first_vector_collection: NDArray[np.float64], + second_vector_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ This function does cross product between two batch vectors. @@ -132,8 +141,10 @@ def _batch_cross(first_vector_collection, second_vector_collection): return output_vector -@njit(cache=True) -def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector): +@njit(cache=True) # type: ignore +def _batch_vec_oneD_vec_cross( + first_vector_collection: NDArray[np.float64], second_vector: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function does cross product between batch vector and a 1D vector. Idea of having this function is that, for friction calculations, we dont @@ -176,8 +187,10 @@ def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector): return output_vector -@njit(cache=True) -def _batch_dot(first_vector, second_vector): +@njit(cache=True) # type: ignore +def _batch_dot( + first_vector: NDArray[np.float64], second_vector: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function does batch vec and batch vec dot product. Parameters @@ -203,8 +216,8 @@ def _batch_dot(first_vector, second_vector): return output_vector -@njit(cache=True) -def _batch_norm(vector): +@njit(cache=True) # type: ignore +def _batch_norm(vector: NDArray[np.float64]) -> NDArray[np.float64]: """ This function computes norm of a batch vector Parameters @@ -232,8 +245,10 @@ def _batch_norm(vector): return output_vector -@njit(cache=True) -def _batch_product_i_k_to_ik(vector1, vector2): +@njit(cache=True) # type: ignore +def _batch_product_i_k_to_ik( + vector1: NDArray[np.float64], vector2: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function does outer product following 'i,k->ik'. vector1 has shape of 3 and vector 2 has shape of blocksize @@ -261,8 +276,10 @@ def _batch_product_i_k_to_ik(vector1, vector2): return output_vector -@njit(cache=True) -def _batch_product_i_ik_to_k(vector1, vector2): +@njit(cache=True) # type: ignore +def _batch_product_i_ik_to_k( + vector1: NDArray[np.float64], vector2: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function does the following product 'i,ik->k' This function do dot product between a vector of 3 elements @@ -292,8 +309,10 @@ def _batch_product_i_ik_to_k(vector1, vector2): return output_vector -@njit(cache=True) -def _batch_product_k_ik_to_ik(vector1, vector2): +@njit(cache=True) # type: ignore +def _batch_product_k_ik_to_ik( + vector1: NDArray[np.float64], vector2: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function does the following product 'k, ik->ik' Parameters @@ -321,8 +340,10 @@ def _batch_product_k_ik_to_ik(vector1, vector2): return output_vector -@njit(cache=True) -def _batch_vector_sum(vector1, vector2): +@njit(cache=True) # type: ignore +def _batch_vector_sum( + vector1: NDArray[np.float64], vector2: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function is for summing up two vectors. Although this function is not faster than pure python implementation @@ -351,10 +372,12 @@ def _batch_vector_sum(vector1, vector2): return output_vector -@njit(cache=True) -def _batch_matrix_transpose(input_matrix): +@njit(cache=True) # type: ignore +def _batch_matrix_transpose(input_matrix: NDArray[np.float64]) -> NDArray[np.float64]: """ This function takes an batch input matrix and transpose it. + [i,j,k] -> [j,i,k] + Parameters ---------- input_matrix diff --git a/elastica/_rotations.py b/elastica/_rotations.py index 25ec1421..11d279ee 100644 --- a/elastica/_rotations.py +++ b/elastica/_rotations.py @@ -8,14 +8,17 @@ from numpy import cos from numpy import sqrt from numpy import arccos +from numpy.typing import NDArray from numba import njit from elastica._linalg import _batch_matmul -@njit(cache=True) -def _get_rotation_matrix(scale: float, axis_collection): +@njit(cache=True) # type: ignore +def _get_rotation_matrix( + scale: np.float64, axis_collection: NDArray[np.float64] +) -> NDArray[np.float64]: blocksize = axis_collection.shape[1] rot_mat = np.empty((3, 3, blocksize)) @@ -48,8 +51,12 @@ def _get_rotation_matrix(scale: float, axis_collection): return rot_mat -@njit(cache=True) -def _rotate(director_collection, scale: float, axis_collection): +@njit(cache=True) # type: ignore +def _rotate( + director_collection: NDArray[np.float64], + scale: np.float64, + axis_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ Does alibi rotations https://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities @@ -73,8 +80,8 @@ def _rotate(director_collection, scale: float, axis_collection): ) -@njit(cache=True) -def _inv_rotate(director_collection): +@njit(cache=True) # type: ignore +def _inv_rotate(director_collection: NDArray[np.float64]) -> NDArray[np.float64]: """ Calculated rate of change using Rodrigues' formula @@ -156,12 +163,15 @@ def _inv_rotate(director_collection): return vector_collection +_generate_skew_map_sentinel = (0, 0, 0) + + # TODO: Below contains numpy-only implementations @functools.lru_cache(maxsize=1) -def _generate_skew_map(dim: int): +def _generate_skew_map(dim: int) -> list[tuple[int, int, int]]: # TODO Documentation # Preallocate - mapping_list = [None] * ((dim**2 - dim) // 2) + mapping_list = [_generate_skew_map_sentinel] * ((dim**2 - dim) // 2) # Indexing (i,j), j is the fastest changing # r = 2, r here is rank, we deal with only matrices for index, (i, j) in enumerate(combinations(range(dim), r=2)): @@ -185,7 +195,7 @@ def _generate_skew_map(dim: int): @functools.lru_cache(maxsize=1) -def _get_skew_map(dim): +def _get_skew_map(dim: int) -> tuple[tuple[int, int, int], ...]: """Generates mapping from src to target skew-symmetric operator For input vector V and output Matrix M (represented in lexicographical index), @@ -208,7 +218,7 @@ def _get_skew_map(dim): @functools.lru_cache(maxsize=1) -def _get_inv_skew_map(dim): +def _get_inv_skew_map(dim: int) -> tuple[tuple[int, int, int], ...]: # TODO Documentation # (vec_src, mat_i, mat_j, sign) mapping_list = _generate_skew_map(dim) @@ -219,7 +229,7 @@ def _get_inv_skew_map(dim): @functools.lru_cache(maxsize=1) -def _get_diag_map(dim): +def _get_diag_map(dim: int) -> tuple[int, ...]: """Generates lexicographic mapping to diagonal in a serialized matrix-type For input dimension dim we calculate mapping to * in Matrix M below @@ -231,17 +241,10 @@ def _get_diag_map(dim): in a dimension agnostic way. """ - # Preallocate - mapping_list = [None] * dim - - # Store linear indices - for dim_iter in range(dim): - mapping_list[dim_iter] = dim_iter * (dim + 1) - - return tuple(mapping_list) + return tuple([dim_iter * (dim + 1) for dim_iter in range(dim)]) -def _skew_symmetrize(vector): +def _skew_symmetrize(vector: NDArray[np.float64]) -> NDArray[np.float64]: """ Parameters @@ -276,7 +279,7 @@ def _skew_symmetrize(vector): # This is purely for testing and optimization sake # While calculating u^2, use u with einsum instead, as it is tad bit faster -def _skew_symmetrize_sq(vector): +def _skew_symmetrize_sq(vector: NDArray[np.float64]) -> NDArray[np.float64]: """ Generate the square of an orthogonal matrix from vector elements @@ -298,12 +301,11 @@ def _skew_symmetrize_sq(vector): hardcoded : 23.1 µs ± 481 ns per loop this version: 14.1 µs ± 96.9 ns per loop """ - dim, _ = vector.shape # First generate array of [x^2, xy, xz, yx, y^2, yz, zx, zy, z^2] # across blocksize # This is slightly faster than doing v[np.newaxis,:,:] * v[:,np.newaxis,:] - products_xy = np.einsum("ik,jk->ijk", vector, vector) + products_xy: NDArray[np.float64] = np.einsum("ik,jk->ijk", vector, vector) # No copy made here, as we do not change memory layout # products_xy = products_xy.reshape((dim * dim, -1)) @@ -335,7 +337,9 @@ def _skew_symmetrize_sq(vector): return products_xy -def _get_skew_symmetric_pair(vector_collection): +def _get_skew_symmetric_pair( + vector_collection: NDArray[np.float64], +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """ Parameters @@ -351,7 +355,7 @@ def _get_skew_symmetric_pair(vector_collection): return u, u_sq -def _inv_skew_symmetrize(matrix): +def _inv_skew_symmetrize(matrix: NDArray[np.float64]) -> NDArray[np.float64]: """ Return the vector elements from a skew-symmetric matrix M diff --git a/elastica/_synchronize_periodic_boundary.py b/elastica/_synchronize_periodic_boundary.py index b4fe87b4..5af2b51e 100644 --- a/elastica/_synchronize_periodic_boundary.py +++ b/elastica/_synchronize_periodic_boundary.py @@ -2,17 +2,23 @@ """These functions are used to synchronize periodic boundaries for ring rods. """ ) +from typing import Any from numba import njit +import numpy as np +from numpy.typing import NDArray from elastica.boundary_conditions import ConstraintBase +from elastica.typing import RodType -@njit(cache=True) -def _synchronize_periodic_boundary_of_vector_collection(input, periodic_idx): +@njit(cache=True) # type: ignore +def _synchronize_periodic_boundary_of_vector_collection( + input_array: NDArray[np.float64], periodic_idx: NDArray[np.float64] +) -> None: """ This function synchronizes the periodic boundaries of a vector collection. Parameters ---------- - input : numpy.ndarray + input_array : numpy.ndarray 2D (dim, blocksize) array containing data with 'float' type. Vector that is going to be synched. periodic_idx : numpy.ndarray 2D (2, n_periodic_boundary) array containing data with 'float' type. Vector containing periodic boundary @@ -24,16 +30,18 @@ def _synchronize_periodic_boundary_of_vector_collection(input, periodic_idx): """ for i in range(3): for k in range(periodic_idx.shape[1]): - input[i, periodic_idx[0, k]] = input[i, periodic_idx[1, k]] + input_array[i, periodic_idx[0, k]] = input_array[i, periodic_idx[1, k]] -@njit(cache=True) -def _synchronize_periodic_boundary_of_matrix_collection(input, periodic_idx): +@njit(cache=True) # type: ignore +def _synchronize_periodic_boundary_of_matrix_collection( + input_array: NDArray[np.float64], periodic_idx: NDArray[np.float64] +) -> None: """ This function synchronizes the periodic boundaries of a matrix collection. Parameters ---------- - input : numpy.ndarray + input_array : numpy.ndarray 2D (dim, dim, blocksize) array containing data with 'float' type. Matrix collection that is going to be synched. periodic_idx : numpy.ndarray 2D (2, n_periodic_boundary) array containing data with 'float' type. Vector containing periodic boundary @@ -46,17 +54,21 @@ def _synchronize_periodic_boundary_of_matrix_collection(input, periodic_idx): for i in range(3): for j in range(3): for k in range(periodic_idx.shape[1]): - input[i, j, periodic_idx[0, k]] = input[i, j, periodic_idx[1, k]] + input_array[i, j, periodic_idx[0, k]] = input_array[ + i, j, periodic_idx[1, k] + ] -@njit(cache=True) -def _synchronize_periodic_boundary_of_scalar_collection(input, periodic_idx): +@njit(cache=True) # type: ignore +def _synchronize_periodic_boundary_of_scalar_collection( + input_array: NDArray[np.float64], periodic_idx: NDArray[np.float64] +) -> None: """ This function synchronizes the periodic boundaries of a scalar collection. Parameters ---------- - input : numpy.ndarray + input_array : numpy.ndarray 2D (dim, dim, blocksize) array containing data with 'float' type. Scalar collection that is going to be synched. periodic_idx : numpy.ndarray 2D (2, n_periodic_boundary) array containing data with 'float' type. Vector containing periodic boundary @@ -67,7 +79,7 @@ def _synchronize_periodic_boundary_of_scalar_collection(input, periodic_idx): """ for k in range(periodic_idx.shape[1]): - input[periodic_idx[0, k]] = input[periodic_idx[1, k]] + input_array[periodic_idx[0, k]] = input_array[periodic_idx[1, k]] class _ConstrainPeriodicBoundaries(ConstraintBase): @@ -76,10 +88,11 @@ class _ConstrainPeriodicBoundaries(ConstraintBase): is to synchronize periodic boundaries of ring rod. """ - def __init__(self, **kwargs): + # TODO: improve typing + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def constrain_values(self, rod, time): + def constrain_values(self, rod: RodType, time: np.float64) -> None: _synchronize_periodic_boundary_of_vector_collection( rod.position_collection, rod.periodic_boundary_nodes_idx ) @@ -87,7 +100,7 @@ def constrain_values(self, rod, time): rod.director_collection, rod.periodic_boundary_elems_idx ) - def constrain_rates(self, rod, time): + def constrain_rates(self, rod: RodType, time: np.float64) -> None: _synchronize_periodic_boundary_of_vector_collection( rod.velocity_collection, rod.periodic_boundary_nodes_idx ) diff --git a/elastica/boundary_conditions.py b/elastica/boundary_conditions.py index adca8b6e..1075821c 100644 --- a/elastica/boundary_conditions.py +++ b/elastica/boundary_conditions.py @@ -1,9 +1,9 @@ __doc__ = """ Built-in boundary condition implementationss """ -import warnings -from typing import Optional +from typing import Any, Optional, TypeVar, Generic import numpy as np +from numpy.typing import NDArray from abc import ABC, abstractmethod @@ -11,10 +11,13 @@ from elastica._linalg import _batch_matvec, _batch_matrix_transpose from elastica._rotations import _get_rotation_matrix -from elastica.typing import SystemType, RodType +from elastica.typing import SystemType, RodType, RigidBodyType, ConstrainingIndex -class ConstraintBase(ABC): +S = TypeVar("S") + + +class ConstraintBase(ABC, Generic[S]): """Base class for constraint and displacement boundary condition implementation. Notes @@ -30,19 +33,25 @@ class ConstraintBase(ABC): """ - _system: SystemType - _constrained_position_idx: np.ndarray - _constrained_director_idx: np.ndarray + _system: S + _constrained_position_idx: NDArray[np.int32] + _constrained_director_idx: NDArray[np.int32] - def __init__(self, *args, **kwargs): + def __init__( + self, + *args: Any, + constrained_position_idx: ConstrainingIndex = (), + constrained_director_idx: ConstrainingIndex = (), + **kwargs: Any, + ) -> None: """Initialize boundary condition""" try: self._system = kwargs["_system"] self._constrained_position_idx = np.array( - kwargs.get("constrained_position_idx", []), dtype=int + constrained_position_idx, dtype=np.int32 ) self._constrained_director_idx = np.array( - kwargs.get("constrained_director_idx", []), dtype=int + constrained_director_idx, dtype=np.int32 ) except KeyError: raise KeyError( @@ -50,25 +59,22 @@ def __init__(self, *args, **kwargs): ) @property - def system(self) -> SystemType: + def system(self) -> S: """get system (rod or rigid body) reference""" return self._system @property - def constrained_position_idx(self) -> Optional[np.ndarray]: + def constrained_position_idx(self) -> NDArray[np.int32]: """get position-indices passed to "using" """ - # TODO: This should be immutable somehow return self._constrained_position_idx @property - def constrained_director_idx(self) -> Optional[np.ndarray]: + def constrained_director_idx(self) -> NDArray[np.int32]: """get director-indices passed to "using" """ - # TODO: This should be immutable somehow return self._constrained_director_idx @abstractmethod - def constrain_values(self, system: SystemType, time: float) -> None: - # TODO: In the future, we can remove rod and use self.system + def constrain_values(self, system: S, time: np.float64) -> None: """ Constrain values (position and/or directors) of a rod object. @@ -82,8 +88,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: pass @abstractmethod - def constrain_rates(self, system: SystemType, time: float) -> None: - # TODO: In the future, we can remove rod and use self.system + def constrain_rates(self, system: S, time: np.float64) -> None: """ Constrain rates (velocity and/or omega) of a rod object. @@ -103,27 +108,22 @@ class FreeBC(ConstraintBase): Boundary condition template. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: """In FreeBC, this routine simply passes.""" pass - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: """In FreeBC, this routine simply passes.""" pass -class FreeRod(FreeBC): - # Please clear this part beyond version 0.3.0 - """Deprecated 0.2.1: Same implementation as FreeBC""" - warnings.warn( - "FreeRod is deprecated and renamed to FreeBC. The deprecated name will be removed in the future.", - DeprecationWarning, - ) - - class OneEndFixedBC(ConstraintBase): """ This boundary condition class fixes one end of the rod. Currently, @@ -143,7 +143,12 @@ class OneEndFixedBC(ConstraintBase): ... ) """ - def __init__(self, fixed_position, fixed_directors, **kwargs): + def __init__( + self, + fixed_position: tuple[int, ...], + fixed_directors: tuple[int, ...], + **kwargs: Any, + ) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -159,7 +164,9 @@ def __init__(self, fixed_position, fixed_directors, **kwargs): self.fixed_position_collection = np.array(fixed_position) self.fixed_directors_collection = np.array(fixed_directors) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: # system.position_collection[..., 0] = self.fixed_position # system.director_collection[..., 0] = self.fixed_directors self.compute_constrain_values( @@ -169,7 +176,9 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.fixed_directors_collection, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: # system.velocity_collection[..., 0] = 0.0 # system.omega_collection[..., 0] = 0.0 self.compute_constrain_rates( @@ -178,13 +187,13 @@ def constrain_rates(self, system: SystemType, time: float) -> None: ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def compute_constrain_values( - position_collection, - fixed_position_collection, - director_collection, - fixed_directors_collection, - ): + position_collection: NDArray[np.float64], + fixed_position_collection: NDArray[np.float64], + director_collection: NDArray[np.float64], + fixed_directors_collection: NDArray[np.float64], + ) -> None: """ Computes constrain values in numba njit decorator @@ -207,8 +216,11 @@ def compute_constrain_values( director_collection[..., 0] = fixed_directors_collection @staticmethod - @njit(cache=True) - def compute_constrain_rates(velocity_collection, omega_collection): + @njit(cache=True) # type: ignore + def compute_constrain_rates( + velocity_collection: NDArray[np.float64], + omega_collection: NDArray[np.float64], + ) -> None: """ Compute contrain rates in numba njit decorator @@ -227,15 +239,6 @@ def compute_constrain_rates(velocity_collection, omega_collection): omega_collection[..., 0] = 0.0 -class OneEndFixedRod(OneEndFixedBC): - # Please clear this part beyond version 0.3.0 - """Deprecated 0.2.1: Same implementation as OneEndFixedBC""" - warnings.warn( - "OneEndFixedRod is deprecated and renamed to OneEndFixedBC. The deprecated name will be removed in the future.", - DeprecationWarning, - ) - - class GeneralConstraint(ConstraintBase): """ This boundary condition class allows the specified node/link to have a configurable constraint. @@ -266,11 +269,11 @@ class GeneralConstraint(ConstraintBase): def __init__( self, - *fixed_data, - translational_constraint_selector: Optional[np.ndarray] = None, - rotational_constraint_selector: Optional[np.array] = None, - **kwargs, - ): + *fixed_data: Any, + translational_constraint_selector: Optional[NDArray[np.bool_]] = None, + rotational_constraint_selector: Optional[NDArray[np.bool_]] = None, + **kwargs: Any, + ) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -316,12 +319,12 @@ def __init__( rotational_constraint_selector = np.array([True, True, True]) # properly validate the user-provided constraint selectors assert ( - type(translational_constraint_selector) == np.ndarray + isinstance(translational_constraint_selector, np.ndarray) and translational_constraint_selector.dtype == bool and translational_constraint_selector.shape == (3,) ), "Translational constraint selector must be a 1D boolean array of length 3." assert ( - type(rotational_constraint_selector) == np.ndarray + isinstance(rotational_constraint_selector, np.ndarray) and rotational_constraint_selector.dtype == bool and rotational_constraint_selector.shape == (3,) ), "Rotational constraint selector must be a 1D boolean array of length 3." @@ -331,7 +334,9 @@ def __init__( ) self.rotational_constraint_selector = rotational_constraint_selector.astype(int) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_values( system.position_collection, @@ -340,7 +345,9 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.translational_constraint_selector, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_rates( system.velocity_collection, @@ -356,9 +363,12 @@ def constrain_rates(self, system: SystemType, time: float) -> None: ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def nb_constrain_translational_values( - position_collection, fixed_position_collection, indices, constraint_selector + position_collection: NDArray[np.float64], + fixed_position_collection: NDArray[np.float64], + indices: NDArray[np.int32], + constraint_selector: NDArray[np.int32], ) -> None: """ Computes constrain values in numba njit decorator @@ -391,9 +401,11 @@ def nb_constrain_translational_values( ] @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def nb_constrain_translational_rates( - velocity_collection, indices, constraint_selector + velocity_collection: NDArray[np.float64], + indices: NDArray[np.int32], + constraint_selector: NDArray[np.int32], ) -> None: """ Compute constrain rates in numba njit decorator @@ -420,9 +432,12 @@ def nb_constrain_translational_rates( ) * velocity_collection[..., k] @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def nb_constrain_rotational_rates( - director_collection, omega_collection, indices, constraint_selector + director_collection: NDArray[np.float64], + omega_collection: NDArray[np.float64], + indices: NDArray[np.int32], + constraint_selector: NDArray[np.int32], ) -> None: """ Compute constrain rates in numba njit decorator @@ -489,7 +504,7 @@ class FixedConstraint(GeneralConstraint): GeneralConstraint: Generalized constraint with configurable DOF. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -508,7 +523,9 @@ def __init__(self, *args, **kwargs): **kwargs, ) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_values( system.position_collection, @@ -522,7 +539,9 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.constrained_director_idx, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_rates( system.velocity_collection, @@ -535,9 +554,11 @@ def constrain_rates(self, system: SystemType, time: float) -> None: ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def nb_constraint_rotational_values( - director_collection, fixed_director_collection, indices + director_collection: NDArray[np.float64], + fixed_director_collection: NDArray[np.float64], + indices: NDArray[np.int32], ) -> None: """ Computes constrain values in numba njit decorator @@ -556,9 +577,11 @@ def nb_constraint_rotational_values( director_collection[..., k] = fixed_director_collection[..., i] @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def nb_constrain_translational_values( - position_collection, fixed_position_collection, indices + position_collection: NDArray[np.float64], + fixed_position_collection: NDArray[np.float64], + indices: NDArray[np.int32], ) -> None: """ Computes constrain values in numba njit decorator @@ -577,8 +600,10 @@ def nb_constrain_translational_values( position_collection[..., k] = fixed_position_collection[..., i] @staticmethod - @njit(cache=True) - def nb_constrain_translational_rates(velocity_collection, indices) -> None: + @njit(cache=True) # type: ignore + def nb_constrain_translational_rates( + velocity_collection: NDArray[np.float64], indices: NDArray[np.int32] + ) -> None: """ Compute constrain rates in numba njit decorator Parameters @@ -597,8 +622,10 @@ def nb_constrain_translational_rates(velocity_collection, indices) -> None: velocity_collection[2, k] = 0.0 @staticmethod - @njit(cache=True) - def nb_constrain_rotational_rates(omega_collection, indices) -> None: + @njit(cache=True) # type: ignore + def nb_constrain_rotational_rates( + omega_collection: NDArray[np.float64], indices: NDArray[np.int32] + ) -> None: """ Compute constrain rates in numba njit decorator Parameters @@ -654,15 +681,15 @@ class HelicalBucklingBC(ConstraintBase): def __init__( self, - position_start: np.ndarray, - position_end: np.ndarray, - director_start: np.ndarray, - director_end: np.ndarray, + position_start: NDArray[np.float64], + position_end: NDArray[np.float64], + director_start: NDArray[np.float64], + director_end: NDArray[np.float64], twisting_time: float, slack: float, number_of_rotations: float, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Helical Buckling initializer @@ -690,12 +717,12 @@ def __init__( Number of rotations applied to rod. """ super().__init__(**kwargs) - self.twisting_time = twisting_time + self.twisting_time = np.float64(twisting_time) - angel_vel_scalar = ( - 2.0 * number_of_rotations * np.pi / self.twisting_time - ) / 2.0 - shrink_vel_scalar = slack / (self.twisting_time * 2.0) + angel_vel_scalar = np.float64( + (2.0 * number_of_rotations * np.pi / self.twisting_time) / 2.0 + ) + shrink_vel_scalar = np.float64(slack / (self.twisting_time * 2.0)) direction = (position_end - position_start) / np.linalg.norm( position_end - position_start @@ -707,7 +734,7 @@ def __init__( self.ang_vel = angel_vel_scalar * direction self.shrink_vel = shrink_vel_scalar * direction - theta = number_of_rotations * np.pi + theta = np.float64(number_of_rotations * np.pi) self.final_start_directors = ( _get_rotation_matrix(theta, direction.reshape(3, 1)).reshape(3, 3) @@ -718,25 +745,29 @@ def __init__( @ director_end ) # rotation_matrix wants vectors 3,1 - def constrain_values(self, rod: RodType, time: float) -> None: + def constrain_values( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: if time > self.twisting_time: - rod.position_collection[..., 0] = self.final_start_position - rod.position_collection[..., -1] = self.final_end_position + system.position_collection[..., 0] = self.final_start_position + system.position_collection[..., -1] = self.final_end_position - rod.director_collection[..., 0] = self.final_start_directors - rod.director_collection[..., -1] = self.final_end_directors + system.director_collection[..., 0] = self.final_start_directors + system.director_collection[..., -1] = self.final_end_directors - def constrain_rates(self, rod: RodType, time: float) -> None: + def constrain_rates( + self, system: "RodType | RigidBodyType", time: np.float64 + ) -> None: if time > self.twisting_time: - rod.velocity_collection[..., 0] = 0.0 - rod.omega_collection[..., 0] = 0.0 + system.velocity_collection[..., 0] = 0.0 + system.omega_collection[..., 0] = 0.0 - rod.velocity_collection[..., -1] = 0.0 - rod.omega_collection[..., -1] = 0.0 + system.velocity_collection[..., -1] = 0.0 + system.omega_collection[..., -1] = 0.0 else: - rod.velocity_collection[..., 0] = self.shrink_vel - rod.omega_collection[..., 0] = self.ang_vel + system.velocity_collection[..., 0] = self.shrink_vel + system.omega_collection[..., 0] = self.ang_vel - rod.velocity_collection[..., -1] = -self.shrink_vel - rod.omega_collection[..., -1] = -self.ang_vel + system.velocity_collection[..., -1] = -self.shrink_vel + system.omega_collection[..., -1] = -self.ang_vel diff --git a/elastica/callback_functions.py b/elastica/callback_functions.py index ab865dfd..0030b8b3 100644 --- a/elastica/callback_functions.py +++ b/elastica/callback_functions.py @@ -1,14 +1,21 @@ __doc__ = """ Module contains callback classes to save simulation data for rod-like objects """ +from typing import Any, Optional, TypeVar, Generic +from elastica.typing import RodType, RigidBodyType, SystemType import os import sys import numpy as np +from numpy.typing import NDArray import logging + from collections import defaultdict -class CallBackBaseClass: +T = TypeVar("T") + + +class CallBackBaseClass(Generic[T]): """ This is the base class for callbacks for rod-like objects. @@ -19,13 +26,13 @@ class CallBackBaseClass: """ - def __init__(self): + def __init__(self) -> None: """ CallBackBaseClass does not need any input parameters. """ pass - def make_callback(self, system, time, current_step: int): + def make_callback(self, system: T, time: np.float64, current_step: int) -> None: """ This method is called every time step. Users can define which parameters are called back and recorded. Also users @@ -59,7 +66,7 @@ class MyCallBack(CallBackBaseClass): Collected callback data is saved in this dictionary. """ - def __init__(self, step_skip: int, callback_params): + def __init__(self, step_skip: int, callback_params: dict) -> None: """ Parameters @@ -73,7 +80,9 @@ def __init__(self, step_skip: int, callback_params): self.sample_every = step_skip self.callback_params = callback_params - def make_callback(self, system, time, current_step: int): + def make_callback( + self, system: "RodType | RigidBodyType", time: np.float64, current_step: int + ) -> None: if current_step % self.sample_every == 0: @@ -116,8 +125,8 @@ def __init__( directory: str, method: str, initial_file_count: int = 0, - file_save_interval: int = 1e8, - ): + file_save_interval: int = 100_000_000, + ) -> None: """ Parameters ---------- @@ -167,7 +176,9 @@ def __init__( self.file_save_interval = file_save_interval # Data collector - self.buffer = defaultdict(list) + self.buffer: dict[str, list[NDArray[np.float64] | np.float64 | int]] = ( + defaultdict(list) + ) self.buffer_size = 0 # Module @@ -189,7 +200,9 @@ def __init__( self._pickle = pickle self._ext = "pkl" - def make_callback(self, system, time, current_step: int): + def make_callback( + self, system: "RodType | RigidBodyType", time: np.float64, current_step: int + ) -> None: """ Parameters @@ -224,7 +237,7 @@ def make_callback(self, system, time, current_step: int): ): self._dump() - def _dump(self, **kwargs): + def _dump(self, **kwargs: Any) -> None: """ Dump dictionary buffer (self.buffer) to a file and clear the buffer. @@ -247,7 +260,7 @@ def _dump(self, **kwargs): self.buffer_size = 0 self.buffer.clear() - def get_last_saved_path(self) -> str: + def get_last_saved_path(self) -> Optional[str]: """ Return last saved file path. If no file has been saved, return None @@ -257,14 +270,14 @@ def get_last_saved_path(self) -> str: else: return self.save_path.format(self.file_count - 1, self._ext) - def close(self): + def close(self) -> None: """ Save residual buffer """ if self.buffer_size: self._dump() - def clear(self): + def clear(self) -> None: """ Alias to `close` """ diff --git a/elastica/collision/AABBCollection.py b/elastica/collision/AABBCollection.py index 22b71606..c91b83b7 100644 --- a/elastica/collision/AABBCollection.py +++ b/elastica/collision/AABBCollection.py @@ -1,17 +1,20 @@ """ Axis Aligned Bounding Boxes for coarse collision detection """ +from typing_extensions import Self + import numpy as np +from numpy.typing import NDArray from elastica.utils import MaxDimension class AABBCollection: def __init__( self, - elemental_position_collection, - dimension_collection, + elemental_position_collection: NDArray[np.float64], + dimension_collection: NDArray[np.float64], elements_per_aabb: int, - ): + ) -> None: """ Doesn't differentiate tangent direction from the rest : potentially harmful as maybe you don't need to expand to radius amount in tangential direction @@ -36,7 +39,9 @@ def __init__( self.update(elemental_position_collection, dimension_collection) @classmethod - def make_from_aabb(cls, aabb_collection, scale_factor=4): + def make_from_aabb( + cls, aabb_collection: list["AABBCollection"], scale_factor: int = 4 + ) -> Self: # Make position collection and dimension collection arrays from aabb_collection # Wasted effort, but only once during construction n_aabb_from_lower_level = len(aabb_collection) @@ -59,7 +64,7 @@ def make_from_aabb(cls, aabb_collection, scale_factor=4): return cls(elemental_position_collection, dimension_collection, scale_factor) - def _update(self, aabb_collection): + def _update(self, aabb_collection: list["AABBCollection"]) -> None: # Updates internal state from another aabb """ @@ -78,7 +83,11 @@ def _update(self, aabb_collection): temp = np.array([aabb.aabb[..., 1, 0] for aabb in aabb_collection]) self.aabb[..., 1, 0] = np.amax(temp, axis=0) - def update(self, elemental_position_collection, dimension_collection): + def update( + self, + elemental_position_collection: NDArray[np.float64], + dimension_collection: NDArray[np.float64], + ) -> None: # Initialize the boxes for i in range(self.n_aabb): start = i * self.elements_per_aabb @@ -91,7 +100,7 @@ def update(self, elemental_position_collection, dimension_collection): ) + np.amax(dimension_collection[..., start:stop], axis=1) -def find_nearest_integer_square_root(x: int): +def find_nearest_integer_square_root(x: int) -> int: from math import sqrt return round(sqrt(x)) @@ -101,8 +110,11 @@ class AABBHierarchy: """Simple hierarchy for handling cylinder collisions alone, meant for a rod""" def __init__( - self, position_collection, dimension_collection, avg_n_dofs_in_final_level - ): + self, + position_collection: NDArray[np.float64], + dimension_collection: NDArray[np.float64], + avg_n_dofs_in_final_level: int, + ) -> None: """ scaling is always set to 4, so that theres' 1 major AABBCollection, then scaling_factor smaller AABBs, then scaling factor even smaller AABBs (which cover the elements @@ -121,10 +133,10 @@ def __init__( ) # nearest power of 4 that is less than the number - n_levels_bound_below = np.int( + n_levels_bound_below = int( np.floor(0.5 * np.log2(potential_n_aabbs_in_final_level)) ) - n_levels_bound_above = np.int( + n_levels_bound_above = int( np.ceil(0.5 * np.log2(potential_n_aabbs_in_final_level)) ) # Check which is the closest and use that as the number of levels @@ -214,11 +226,15 @@ def __init__( # Add one for the middle level # self.aabb.append(AABBCollection(position_collection, dimension_collection, self.n_aabbs_in_first_level)) - def n_aabbs_at_level(self, i: int): + def n_aabbs_at_level(self, i: int) -> int: assert i < self.n_levels return 4 ** (i) - def update(self, position_collection, dimension_collection): + def update( + self, + position_collection: NDArray[np.float64], + dimension_collection: NDArray[np.float64], + ) -> None: # Update bottom level first, the first level entries n_aabbs_in_final_level = self.n_aabbs_at_level(self.n_levels - 1) stop = 0 @@ -261,5 +277,8 @@ def update(self, position_collection, dimension_collection): count_elapsed_n_aabbs += n_aabbs_in_next_level -def are_aabb_intersecting(first_aabb_collection, second_aabb_collection): +def are_aabb_intersecting( + first_aabb_collection: NDArray[np.float64], + second_aabb_collection: NDArray[np.float64], +) -> bool: return True diff --git a/elastica/contact_forces.py b/elastica/contact_forces.py index 8f9b0ab5..70c7caf2 100644 --- a/elastica/contact_forces.py +++ b/elastica/contact_forces.py @@ -1,9 +1,13 @@ __doc__ = """ Numba implementation module containing contact between rods and rigid bodies and other rods rigid bodies or surfaces.""" -from elastica.typing import RodType, SystemType, AllowedContactType -from elastica.rod import RodBase -from elastica.rigidbody import Cylinder, Sphere -from elastica.surface import Plane +from typing import TypeVar, Generic, Type +from elastica.typing import RodType, SystemType, SurfaceType + +from elastica.rod.rod_base import RodBase +from elastica.rigidbody.cylinder import Cylinder +from elastica.rigidbody.sphere import Sphere +from elastica.surface.plane import Plane +from elastica.surface.surface_base import SurfaceBase from elastica.contact_utils import ( _prune_using_aabbs_rod_cylinder, _prune_using_aabbs_rod_rod, @@ -19,9 +23,14 @@ _calculate_contact_forces_cylinder_plane, ) import numpy as np +from numpy.typing import NDArray -class NoContact: +S1 = TypeVar("S1") # TODO: Find bound +S2 = TypeVar("S2") + + +class NoContact(Generic[S1, S2]): """ This is the base class for contact applied between rod-like objects and allowed contact objects. @@ -32,55 +41,52 @@ class NoContact: """ - def __init__(self): + def __init__(self) -> None: """ NoContact class does not need any input parameters. """ + pass + + @property + def _allowed_system_one(self) -> list[Type]: + # Modify this list to include the allowed system types for contact + return [RodBase] + + @property + def _allowed_system_two(self) -> list[Type]: + # Modify this list to include the allowed system types for contact + return [RodBase] def _check_systems_validity( self, - system_one: SystemType, - system_two: AllowedContactType, + system_one: S1, + system_two: S2, ) -> None: """ - This checks the contact order between a SystemType object and an AllowedContactType object, the order should follow: Rod, Rigid body, Surface. - In NoContact class, this just checks if system_two is a rod then system_one must be a rod. + Here, we check the allowed system types for contact. + For derived classes, this method can be overridden to enforce specific system types + for contact model. + """ + common_check_systems_validity(system_one, self._allowed_system_one) + common_check_systems_validity(system_two, self._allowed_system_two) - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if issubclass(system_two.__class__, RodBase): - if not issubclass(system_one.__class__, RodBase): - raise TypeError( - "Systems provided to the contact class have incorrect order. \n" - " First system is {0} and second system is {1}. \n" - " If the first system is a rod, the second system can be a rod, rigid body or surface. \n" - " If the first system is a rigid body, the second system can be a rigid body or surface.".format( - system_one.__class__, system_two.__class__ - ) - ) + common_check_systems_identity(system_one, system_two) def apply_contact( self, - system_one: SystemType, - system_two: AllowedContactType, + system_one: S1, + system_two: S2, ) -> None: """ - Apply contact forces and torques between SystemType object and AllowedContactType object. + Apply contact forces and torques between two system object.. In NoContact class, this routine simply passes. Parameters ---------- - system_one : SystemType - Rod or rigid-body object - system_two : AllowedContactType - Rod, rigid-body, or surface object + system_one + system_two """ pass @@ -101,7 +107,7 @@ class RodRodContact(NoContact): """ - def __init__(self, k: float, nu: float): + def __init__(self, k: np.float64, nu: np.float64) -> None: """ Parameters ---------- @@ -114,49 +120,14 @@ def __init__(self, k: float, nu: float): self.k = k self.nu = nu - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodRodContact class both systems must be distinct rods. - - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, RodBase) or not issubclass( - system_two.__class__, RodBase - ): - raise TypeError( - "Systems provided to the contact class have incorrect order. \n" - " First system is {0} and second system is {1}. \n" - " Both systems must be distinct rods".format( - system_one.__class__, system_two.__class__ - ) - ) - if system_one == system_two: - raise TypeError( - "First rod is identical to second rod. \n" - "Rods must be distinct for RodRodConact. \n" - "If you want self contact, use RodSelfContact instead" - ) - def apply_contact(self, system_one: RodType, system_two: RodType) -> None: """ Apply contact forces and torques between RodType object and RodType object. Parameters ---------- - system_one: object - Rod object. - system_two: object - Rod object. + system_one: RodType + system_two: RodType """ # First, check for a global AABB bounding box, and see whether that @@ -227,9 +198,9 @@ def __init__( self, k: float, nu: float, - velocity_damping_coefficient=0.0, - friction_coefficient=0.0, - ): + velocity_damping_coefficient: float = 0.0, + friction_coefficient: float = 0.0, + ) -> None: """ Parameters @@ -245,39 +216,17 @@ def __init__( For Coulombic friction coefficient for rigid-body and rod contact. """ super(RodCylinderContact, self).__init__() - self.k = k - self.nu = nu - self.velocity_damping_coefficient = velocity_damping_coefficient - self.friction_coefficient = friction_coefficient + self.k = np.float64(k) + self.nu = np.float64(nu) + self.velocity_damping_coefficient = np.float64(velocity_damping_coefficient) + self.friction_coefficient = np.float64(friction_coefficient) - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodCylinderContact class first_system should be a rod and second_system should be a cylinder. + @property + def _allowed_system_two(self) -> list[Type]: + # Modify this list to include the allowed system types for contact + return [Cylinder] - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, RodBase) or not issubclass( - system_two.__class__, Cylinder - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a cylinder".format( - system_one.__class__, system_two.__class__ - ) - ) - - def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: + def apply_contact(self, system_one: RodType, system_two: Cylinder) -> None: # First, check for a global AABB bounding box, and see whether that # intersects if _prune_using_aabbs_rod_cylinder( @@ -286,8 +235,8 @@ def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: system_one.lengths, system_two.position_collection, system_two.director_collection, - system_two.radius[0], - system_two.length[0], + system_two.radius, + system_two.length, ): return @@ -338,7 +287,7 @@ class RodSelfContact(NoContact): """ - def __init__(self, k: float, nu: float): + def __init__(self, k: float, nu: float) -> None: """ Parameters @@ -349,38 +298,20 @@ def __init__(self, k: float, nu: float): Contact damping constant. """ super(RodSelfContact, self).__init__() - self.k = k - self.nu = nu + self.k = np.float64(k) + self.nu = np.float64(nu) def _check_systems_validity( self, - system_one: SystemType, - system_two: AllowedContactType, + system_one: RodType, + system_two: RodType, ) -> None: """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodSelfContact class first_system and second_system should be the same rod. - - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType + Overriding the base class method to check if the two systems are identical. """ - if ( - not issubclass(system_one.__class__, RodBase) - or not issubclass(system_two.__class__, RodBase) - or system_one != system_two - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system and second system should be the same rod \n" - " If you want rod rod contact, use RodRodContact instead".format( - system_one.__class__, system_two.__class__ - ) - ) + common_check_systems_validity(system_one, self._allowed_system_one) + common_check_systems_validity(system_two, self._allowed_system_two) + common_check_systems_different(system_one, system_two) def apply_contact(self, system_one: RodType, system_two: RodType) -> None: """ @@ -388,10 +319,8 @@ def apply_contact(self, system_one: RodType, system_two: RodType) -> None: Parameters ---------- - system_one: object - Rod object. - system_two: object - Rod object. + system_one: RodType + system_two: RodType """ _calculate_contact_forces_self_rod( @@ -437,9 +366,9 @@ def __init__( self, k: float, nu: float, - velocity_damping_coefficient=0.0, - friction_coefficient=0.0, - ): + velocity_damping_coefficient: float = 0.0, + friction_coefficient: float = 0.0, + ) -> None: """ Parameters ---------- @@ -454,47 +383,23 @@ def __init__( For Coulombic friction coefficient for rigid-body and rod contact. """ super(RodSphereContact, self).__init__() - self.k = k - self.nu = nu - self.velocity_damping_coefficient = velocity_damping_coefficient - self.friction_coefficient = friction_coefficient + self.k = np.float64(k) + self.nu = np.float64(nu) + self.velocity_damping_coefficient = np.float64(velocity_damping_coefficient) + self.friction_coefficient = np.float64(friction_coefficient) - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodSphereContact class first_system should be a rod and second_system should be a sphere. - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, RodBase) or not issubclass( - system_two.__class__, Sphere - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a sphere".format( - system_one.__class__, system_two.__class__ - ) - ) + @property + def _allowed_system_two(self) -> list[Type]: + return [Sphere] - def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: + def apply_contact(self, system_one: RodType, system_two: Sphere) -> None: """ Apply contact forces and torques between RodType object and Sphere object. Parameters ---------- - system_one: object - Rod object. - system_two: object - Sphere object. + system_one: RodType + system_two: Sphere """ # First, check for a global AABB bounding box, and see whether that @@ -505,7 +410,7 @@ def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: system_one.lengths, system_two.position_collection, system_two.director_collection, - system_two.radius[0], + system_two.radius, ): return @@ -562,7 +467,7 @@ def __init__( self, k: float, nu: float, - ): + ) -> None: """ Parameters ---------- @@ -572,37 +477,15 @@ def __init__( Contact damping constant. """ super(RodPlaneContact, self).__init__() - self.k = k - self.nu = nu - self.surface_tol = 1e-4 + self.k = np.float64(k) + self.nu = np.float64(nu) + self.surface_tol = np.float64(1.0e-4) - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodPlaneContact class first_system should be a rod and second_system should be a plane. - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, RodBase) or not issubclass( - system_two.__class__, Plane - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane".format( - system_one.__class__, system_two.__class__ - ) - ) + @property + def _allowed_system_two(self) -> list[Type]: + return [SurfaceBase] - def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: + def apply_contact(self, system_one: RodType, system_two: SurfaceType) -> None: """ Apply contact forces and torques between RodType object and Plane object. @@ -655,9 +538,9 @@ def __init__( k: float, nu: float, slip_velocity_tol: float, - static_mu_array: np.ndarray, - kinetic_mu_array: np.ndarray, - ): + static_mu_array: NDArray[np.float64], + kinetic_mu_array: NDArray[np.float64], + ) -> None: """ Parameters ---------- @@ -675,9 +558,9 @@ def __init__( [forward, backward, sideways] kinetic friction coefficients. """ super(RodPlaneContactWithAnisotropicFriction, self).__init__() - self.k = k - self.nu = nu - self.surface_tol = 1e-4 + self.k = np.float64(k) + self.nu = np.float64(nu) + self.surface_tol = np.float64(1.0e-4) self.slip_velocity_tol = slip_velocity_tol ( self.static_mu_forward, @@ -690,42 +573,18 @@ def __init__( self.kinetic_mu_sideways, ) = kinetic_mu_array - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodSphereContact class first_system should be a rod and second_system should be a plane. - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, RodBase) or not issubclass( - system_two.__class__, Plane - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane".format( - system_one.__class__, system_two.__class__ - ) - ) + @property + def _allowed_system_two(self) -> list[Type]: + return [SurfaceBase] - def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: + def apply_contact(self, system_one: RodType, system_two: SurfaceType) -> None: """ Apply contact forces and torques between RodType object and Plane object with anisotropic friction. Parameters ---------- - system_one: object - Rod object. - system_two: object - Plane object. + system_one: RodType + system_two: SurfaceType """ @@ -778,7 +637,7 @@ def __init__( self, k: float, nu: float, - ): + ) -> None: """ Parameters ---------- @@ -788,37 +647,19 @@ def __init__( Contact damping constant. """ super(CylinderPlaneContact, self).__init__() - self.k = k - self.nu = nu - self.surface_tol = 1e-4 + self.k = np.float64(k) + self.nu = np.float64(nu) + self.surface_tol = np.float64(1.0e-4) - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodPlaneContact class first_system should be a cylinder and second_system should be a plane. - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, Cylinder) or not issubclass( - system_two.__class__, Plane - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a cylinder, second should be a plane".format( - system_one.__class__, system_two.__class__ - ) - ) + @property + def _allowed_system_one(self) -> list[Type]: + return [Cylinder] - def apply_contact(self, system_one: Cylinder, system_two: SystemType): + @property + def _allowed_system_two(self) -> list[Type]: + return [SurfaceBase] + + def apply_contact(self, system_one: Cylinder, system_two: SurfaceType) -> None: """ This function computes the plane force response on the cylinder, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper @@ -826,13 +667,11 @@ def apply_contact(self, system_one: Cylinder, system_two: SystemType): Parameters ---------- - system_one: object - Cylinder object. - system_two: object - Plane object. + system_one: Cylinder + system_two: SurfaceBase """ - return _calculate_contact_forces_cylinder_plane( + _calculate_contact_forces_cylinder_plane( system_two.origin, system_two.normal, self.surface_tol, @@ -843,3 +682,49 @@ def apply_contact(self, system_one: Cylinder, system_two: SystemType): system_one.velocity_collection, system_one.external_forces, ) + + +def common_check_systems_identity( + system_one: S1, + system_two: S2, +) -> None: + """ + This checks if two objects are identical. + + Raises + ------ + TypeError + If two objects are identical. + """ + if system_one == system_two: + raise TypeError( + "First system is identical to second system. Systems must be distinct for contact." + ) + + +def common_check_systems_different( + system_one: S1, + system_two: S2, +) -> None: + """ + This checks if two objects are identical. + + Raises + ------ + TypeError + If two objects are not identical. + """ + if system_one != system_two: + raise TypeError("First system must be identical to the second system.") + + +def common_check_systems_validity( + system: S1 | S2, allowed_system: list[Type[S1] | Type[S2]] +) -> None: + # Check validity + if not isinstance(system, tuple(allowed_system)): + system_name = system.__class__.__name__ + allowed_system_names = [candidate.__name__ for candidate in allowed_system] + raise TypeError( + f"System provided ({system_name}) must be derived from {allowed_system_names}." + ) diff --git a/elastica/contact_utils.py b/elastica/contact_utils.py index 71584d2b..37887685 100644 --- a/elastica/contact_utils.py +++ b/elastica/contact_utils.py @@ -3,37 +3,44 @@ from math import sqrt import numba import numpy as np +from numpy.typing import NDArray from elastica._linalg import ( _batch_norm, ) +from typing import Literal, Sequence -@numba.njit(cache=True) -def _dot_product(a, b): - sum = 0.0 +@numba.njit(cache=True) # type: ignore +def _dot_product(a: Sequence[np.float64], b: Sequence[np.float64]) -> np.float64: + total: np.float64 = np.float64(0.0) for i in range(3): - sum += a[i] * b[i] - return sum + total += a[i] * b[i] + return total -@numba.njit(cache=True) -def _norm(a): +@numba.njit(cache=True) # type: ignore +def _norm(a: Sequence[np.float64]) -> float: return sqrt(_dot_product(a, a)) -@numba.njit(cache=True) -def _clip(x, low, high): +@numba.njit(cache=True) # type: ignore +def _clip(x: np.float64, low: np.float64, high: np.float64) -> np.float64: return max(low, min(x, high)) # Can this be made more efficient than 2 comp, 1 or? -@numba.njit(cache=True) -def _out_of_bounds(x, low, high): - return (x < low) or (x > high) - - -@numba.njit(cache=True) -def _find_min_dist(x1, e1, x2, e2): +@numba.njit(cache=True) # type: ignore +def _out_of_bounds(x: np.float64, low: np.float64, high: np.float64) -> bool: + return bool((x < low) or (x > high)) + + +@numba.njit(cache=True) # type: ignore +def _find_min_dist( + x1: NDArray[np.float64], + e1: NDArray[np.float64], + x2: NDArray[np.float64], + e2: NDArray[np.float64], +) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]: e1e1 = _dot_product(e1, e1) e1e2 = _dot_product(e1, e2) e2e2 = _dot_product(e2, e2) @@ -98,8 +105,10 @@ def _find_min_dist(x1, e1, x2, e2): return x2 + s * e2 - x1 - t * e1, x2 + s * e2, x1 - t * e1 -@numba.njit(cache=True) -def _aabbs_not_intersecting(aabb_one, aabb_two): +@numba.njit(cache=True) # type: ignore +def _aabbs_not_intersecting( + aabb_one: NDArray[np.float64], aabb_two: NDArray[np.float64] +) -> Literal[1, 0]: """Returns true if not intersecting else false""" if (aabb_one[0, 1] < aabb_two[0, 0]) | (aabb_one[0, 0] > aabb_two[0, 1]): return 1 @@ -111,16 +120,16 @@ def _aabbs_not_intersecting(aabb_one, aabb_two): return 0 -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _prune_using_aabbs_rod_cylinder( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - cylinder_position, - cylinder_director, - cylinder_radius, - cylinder_length, -): + rod_one_position_collection: NDArray[np.float64], + rod_one_radius_collection: NDArray[np.float64], + rod_one_length_collection: NDArray[np.float64], + cylinder_position: NDArray[np.float64], + cylinder_director: NDArray[np.float64], + cylinder_radius: NDArray[np.float64], + cylinder_length: NDArray[np.float64], +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod = np.empty((3, 2)) aabb_cylinder = np.empty((3, 2)) @@ -153,15 +162,15 @@ def _prune_using_aabbs_rod_cylinder( return _aabbs_not_intersecting(aabb_cylinder, aabb_rod) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _prune_using_aabbs_rod_rod( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - rod_two_position_collection, - rod_two_radius_collection, - rod_two_length_collection, -): + rod_one_position_collection: NDArray[np.float64], + rod_one_radius_collection: NDArray[np.float64], + rod_one_length_collection: NDArray[np.float64], + rod_two_position_collection: NDArray[np.float64], + rod_two_radius_collection: NDArray[np.float64], + rod_two_length_collection: NDArray[np.float64], +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod_one = np.empty((3, 2)) aabb_rod_two = np.empty((3, 2)) @@ -191,15 +200,15 @@ def _prune_using_aabbs_rod_rod( return _aabbs_not_intersecting(aabb_rod_two, aabb_rod_one) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _prune_using_aabbs_rod_sphere( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - sphere_position, - sphere_director, - sphere_radius, -): + rod_one_position_collection: NDArray[np.float64], + rod_one_radius_collection: NDArray[np.float64], + rod_one_length_collection: NDArray[np.float64], + sphere_position: NDArray[np.float64], + sphere_director: NDArray[np.float64], + sphere_radius: NDArray[np.float64], +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod = np.empty((3, 2)) aabb_sphere = np.empty((3, 2)) @@ -230,8 +239,10 @@ def _prune_using_aabbs_rod_sphere( return _aabbs_not_intersecting(aabb_sphere, aabb_rod) -@numba.njit(cache=True) -def _find_slipping_elements(velocity_slip, velocity_threshold): +@numba.njit(cache=True) # type: ignore +def _find_slipping_elements( + velocity_slip: NDArray[np.float64], velocity_threshold: np.float64 +) -> NDArray[np.float64]: """ This function takes the velocity of elements and checks if they are larger than the threshold velocity. If the velocity of elements is larger than threshold velocity, that means those elements are slipping. @@ -271,8 +282,8 @@ def _find_slipping_elements(velocity_slip, velocity_threshold): return slip_function -@numba.njit(cache=True) -def _node_to_element_mass_or_force(input): +@numba.njit(cache=True) # type: ignore +def _node_to_element_mass_or_force(input: NDArray[np.float64]) -> NDArray[np.float64]: """ This function converts the mass/forces on rod nodes to elements, where special treatment is necessary at the ends. @@ -309,8 +320,11 @@ def _node_to_element_mass_or_force(input): return output -@numba.njit(cache=True) -def _elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): +@numba.njit(cache=True) # type: ignore +def _elements_to_nodes_inplace( + vector_in_element_frame: NDArray[np.float64], + vector_in_node_frame: NDArray[np.float64], +) -> None: """ Updating nodal forces using the forces computed on elements Parameters @@ -332,8 +346,10 @@ def _elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): vector_in_node_frame[i, k + 1] += 0.5 * vector_in_element_frame[i, k] -@numba.njit(cache=True) -def _node_to_element_position(node_position_collection): +@numba.njit(cache=True) # type: ignore +def _node_to_element_position( + node_position_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ This function computes the position of the elements from the nodal values. @@ -378,8 +394,10 @@ def _node_to_element_position(node_position_collection): return element_position_collection -@numba.njit(cache=True) -def _node_to_element_velocity(mass, node_velocity_collection): +@numba.njit(cache=True) # type: ignore +def _node_to_element_velocity( + mass: NDArray[np.float64], node_velocity_collection: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function computes the velocity of the elements from the nodal values. Uses the velocity of center of mass diff --git a/elastica/dissipation.py b/elastica/dissipation.py index f3629fe7..32ff4866 100644 --- a/elastica/dissipation.py +++ b/elastica/dissipation.py @@ -5,15 +5,20 @@ """ from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar from elastica.typing import RodType, SystemType from numba import njit import numpy as np +from numpy.typing import NDArray -class DamperBase(ABC): +T = TypeVar("T") + + +class DamperBase(Generic[T], ABC): """Base class for damping module implementations. Notes @@ -23,13 +28,14 @@ class DamperBase(ABC): Attributes ---------- - system : SystemType (RodBase or RigidBodyBase) + system : RodBase """ - _system: SystemType + _system: T - def __init__(self, *args, **kwargs): + # TODO typing can be made better + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize damping module""" try: self._system = kwargs["_system"] @@ -40,7 +46,7 @@ def __init__(self, *args, **kwargs): ) @property - def system(self): # -> SystemType: (Return type is not parsed with sphinx book.) + def system(self) -> T: """ get system (rod or rigid body) reference @@ -52,7 +58,7 @@ def system(self): # -> SystemType: (Return type is not parsed with sphinx book. return self._system @abstractmethod - def dampen_rates(self, system: SystemType, time: float): + def dampen_rates(self, system: T, time: np.float64) -> None: # TODO: In the future, we can remove rod and use self.system """ Dampen rates (velocity and/or omega) of a rod object. @@ -113,7 +119,9 @@ class AnalyticalLinearDamper(DamperBase): Damping coefficient acting on rotational velocity. """ - def __init__(self, damping_constant, time_step, **kwargs): + def __init__( + self, damping_constant: float, time_step: float, **kwargs: Any + ) -> None: """ Analytical linear damper initializer @@ -143,13 +151,13 @@ def __init__(self, damping_constant, time_step, **kwargs): * np.diagonal(self._system.inv_mass_second_moment_of_inertia).T ) - def dampen_rates(self, rod: RodType, time: float): - rod.velocity_collection[:] = ( - rod.velocity_collection * self.translational_damping_coefficient + def dampen_rates(self, system: RodType, time: np.float64) -> None: + system.velocity_collection[:] = ( + system.velocity_collection * self.translational_damping_coefficient ) - rod.omega_collection[:] = rod.omega_collection * np.power( - self.rotational_damping_coefficient, rod.dilatation + system.omega_collection[:] = system.omega_collection * np.power( + self.rotational_damping_coefficient, system.dilatation ) @@ -202,7 +210,7 @@ class LaplaceDissipationFilter(DamperBase): Filter term that modifies rod rotational velocity. """ - def __init__(self, filter_order: int, **kwargs): + def __init__(self, filter_order: int, **kwargs: Any) -> None: """ Filter damper initializer @@ -232,25 +240,25 @@ def __init__(self, filter_order: int, **kwargs): self.omega_filter_term = np.zeros_like(self._system.omega_collection) self.filter_function = _filter_function_periodic_condition - def dampen_rates(self, rod: RodType, time: float) -> None: + def dampen_rates(self, system: RodType, time: np.float64) -> None: self.filter_function( - rod.velocity_collection, + system.velocity_collection, self.velocity_filter_term, - rod.omega_collection, + system.omega_collection, self.omega_filter_term, self.filter_order, ) -@njit(cache=True) +@njit(cache=True) # type: ignore def _filter_function_periodic_condition_ring_rod( - velocity_collection, - velocity_filter_term, - omega_collection, - omega_filter_term, - filter_order, -): + velocity_collection: NDArray[np.float64], + velocity_filter_term: NDArray[np.float64], + omega_collection: NDArray[np.float64], + omega_filter_term: NDArray[np.float64], + filter_order: int, +) -> None: blocksize = velocity_filter_term.shape[1] # Transfer velocity to an array which has periodic boundaries and synchornize boundaries @@ -281,14 +289,14 @@ def _filter_function_periodic_condition_ring_rod( omega_collection[:] = omega_collection_with_periodic_bc[:, 1:-1] -@njit(cache=True) +@njit(cache=True) # type: ignore def _filter_function_periodic_condition( - velocity_collection, - velocity_filter_term, - omega_collection, - omega_filter_term, - filter_order, -): + velocity_collection: NDArray[np.float64], + velocity_filter_term: NDArray[np.float64], + omega_collection: NDArray[np.float64], + omega_filter_term: NDArray[np.float64], + filter_order: int, +) -> None: nb_filter_rate( rate_collection=velocity_collection, filter_term=velocity_filter_term, @@ -301,9 +309,11 @@ def _filter_function_periodic_condition( ) -@njit(cache=True) +@njit(cache=True) # type: ignore def nb_filter_rate( - rate_collection: np.ndarray, filter_term: np.ndarray, filter_order: int + rate_collection: NDArray[np.float64], + filter_term: NDArray[np.float64], + filter_order: int, ) -> None: """ Filters the rod rates (velocities) in numba njit decorator diff --git a/elastica/experimental/connection_contact_joint/parallel_connection.py b/elastica/experimental/connection_contact_joint/parallel_connection.py index 6044aff0..060af8fc 100644 --- a/elastica/experimental/connection_contact_joint/parallel_connection.py +++ b/elastica/experimental/connection_contact_joint/parallel_connection.py @@ -90,10 +90,10 @@ def __init__( self.rod_one_direction_vec_in_material_frame = np.array( rod_one_direction_vec_in_material_frame - ).T + ) self.rod_two_direction_vec_in_material_frame = np.array( rod_two_direction_vec_in_material_frame - ).T + ) # Apply force is same as free joint def apply_forces(self, system_one, index_one, system_two, index_two): @@ -127,7 +127,7 @@ def apply_forces(self, system_one, index_one, system_two, index_two): ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def _apply_forces( k, nu, @@ -150,7 +150,6 @@ def _apply_forces( rod_one_external_forces, rod_two_external_forces, ): - rod_one_to_rod_two_connection_vec = _batch_matvec( _batch_matrix_transpose(rod_one_director_collection[:, :, index_one]), rod_one_direction_vec_in_material_frame, @@ -273,7 +272,7 @@ def apply_torques(self, system_one, index_one, system_two, index_two): ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def _apply_torques( spring_force, rod_one_rd2, @@ -286,8 +285,8 @@ def _apply_torques( rod_two_external_torques, ): # Compute torques due to the connection forces - torque_on_rod_one = np.cross(rod_one_rd2, spring_force) - torque_on_rod_two = np.cross(rod_two_rd2, -spring_force) + torque_on_rod_one = _batch_cross(rod_one_rd2, spring_force) + torque_on_rod_two = _batch_cross(rod_two_rd2, -spring_force) torque_on_rod_one_material_frame = _batch_matvec( rod_one_director_collection[:, :, index_one], torque_on_rod_one diff --git a/elastica/experimental/interaction.py b/elastica/experimental/interaction.py index cd2c8626..293349ee 100644 --- a/elastica/experimental/interaction.py +++ b/elastica/experimental/interaction.py @@ -3,10 +3,8 @@ import numpy as np from elastica.external_forces import NoForces -from elastica.contact_utils import ( - _find_slipping_elements, -) -from elastica.contact_functions import _calculate_contact_forces_cylinder_plane +from elastica.contact_utils import _find_slipping_elements +from elastica._contact_functions import _calculate_contact_forces_cylinder_plane from elastica.interaction import InteractionPlaneRigidBody from numba import njit diff --git a/elastica/external_forces.py b/elastica/external_forces.py index cb9c61e6..ad15810e 100644 --- a/elastica/external_forces.py +++ b/elastica/external_forces.py @@ -1,17 +1,23 @@ __doc__ = """ Numba implementation module for boundary condition implementations that apply external forces to the system.""" +from typing import TypeVar, Generic import numpy as np +from numpy.typing import NDArray + from elastica._linalg import _batch_matvec -from elastica.typing import SystemType, RodType +from elastica.typing import SystemType, RodType, RigidBodyType from elastica.utils import _bspline from numba import njit from elastica._linalg import _batch_product_i_k_to_ik -class NoForces: +S = TypeVar("S") + + +class NoForces(Generic[S]): """ This is the base class for external forcing boundary conditions applied to rod-like objects. @@ -22,13 +28,13 @@ class NoForces: """ - def __init__(self): + def __init__(self) -> None: """ NoForces class does not need any input parameters. """ pass - def apply_forces(self, system: SystemType, time: np.float64 = 0.0): + def apply_forces(self, system: S, time: np.float64 = np.float64(0.0)) -> None: """Apply forces to a rod-like object. In NoForces class, this routine simply passes. @@ -43,7 +49,7 @@ def apply_forces(self, system: SystemType, time: np.float64 = 0.0): """ pass - def apply_torques(self, system: SystemType, time: np.float64 = 0.0): + def apply_torques(self, system: S, time: np.float64 = np.float64(0.0)) -> None: """Apply torques to a rod-like object. In NoForces class, this routine simply passes. @@ -70,7 +76,12 @@ class GravityForces(NoForces): """ - def __init__(self, acc_gravity=np.array([0.0, -9.80665, 0.0])): + def __init__( + self, + acc_gravity: NDArray[np.float64] = np.array( + [0.0, -9.80665, 0.0] + ), # FIXME: avoid mutable default + ) -> None: """ Parameters @@ -82,14 +93,20 @@ def __init__(self, acc_gravity=np.array([0.0, -9.80665, 0.0])): super(GravityForces, self).__init__() self.acc_gravity = acc_gravity - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: self.compute_gravity_forces( self.acc_gravity, system.mass, system.external_forces ) @staticmethod - @njit(cache=True) - def compute_gravity_forces(acc_gravity, mass, external_forces): + @njit(cache=True) # type: ignore + def compute_gravity_forces( + acc_gravity: NDArray[np.float64], + mass: NDArray[np.float64], + external_forces: NDArray[np.float64], + ) -> None: """ This function add gravitational forces on the nodes. We are using njit decorated function to increase the speed. @@ -122,7 +139,12 @@ class EndpointForces(NoForces): """ - def __init__(self, start_force, end_force, ramp_up_time): + def __init__( + self, + start_force: NDArray[np.float64], + end_force: NDArray[np.float64], + ramp_up_time: float, + ) -> None: """ Parameters @@ -141,9 +163,11 @@ def __init__(self, start_force, end_force, ramp_up_time): self.start_force = start_force self.end_force = end_force assert ramp_up_time > 0.0 - self.ramp_up_time = ramp_up_time + self.ramp_up_time = np.float64(ramp_up_time) - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: self.compute_end_point_forces( system.external_forces, self.start_force, @@ -153,10 +177,14 @@ def apply_forces(self, system: SystemType, time=0.0): ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def compute_end_point_forces( - external_forces, start_force, end_force, time, ramp_up_time - ): + external_forces: NDArray[np.float64], + start_force: NDArray[np.float64], + end_force: NDArray[np.float64], + time: np.float64, + ramp_up_time: np.float64, + ) -> None: """ Compute end point forces that are applied on the rod using numba njit decorator. @@ -174,7 +202,7 @@ def compute_end_point_forces( Applied forces are ramped up until ramp up time. """ - factor = min(1.0, time / ramp_up_time) + factor = min(np.float64(1.0), time / ramp_up_time) external_forces[..., 0] += start_force * factor external_forces[..., -1] += end_force * factor @@ -190,7 +218,13 @@ class UniformTorques(NoForces): """ - def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): + def __init__( + self, + torque: np.float64, + direction: NDArray[np.float64] = np.array( + [0.0, 0.0, 0.0] + ), # FIXME: avoid mutable default + ) -> None: """ Parameters @@ -204,7 +238,9 @@ def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): super(UniformTorques, self).__init__() self.torque = torque * direction - def apply_torques(self, system: SystemType, time: np.float64 = 0.0): + def apply_torques( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: n_elems = system.n_elems torque_on_one_element = ( _batch_product_i_k_to_ik(self.torque, np.ones((n_elems))) / n_elems @@ -224,7 +260,13 @@ class UniformForces(NoForces): 2D (dim, 1) array containing data with 'float' type. Total force applied to a rod-like object. """ - def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): + def __init__( + self, + force: np.float64, + direction: NDArray[np.float64] = np.array( + [0.0, 0.0, 0.0] + ), # FIXME: avoid mutable default + ) -> None: """ Parameters @@ -238,14 +280,16 @@ def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): super(UniformForces, self).__init__() self.force = (force * direction).reshape(3, 1) - def apply_forces(self, rod: RodType, time: np.float64 = 0.0): - force_on_one_element = self.force / rod.n_elems + def apply_forces( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: + force_on_one_element = self.force / system.n_elems - rod.external_forces += force_on_one_element + system.external_forces += force_on_one_element # Because mass of first and last node is half - rod.external_forces[..., 0] -= 0.5 * force_on_one_element[:, 0] - rod.external_forces[..., -1] -= 0.5 * force_on_one_element[:, 0] + system.external_forces[..., 0] -= 0.5 * force_on_one_element[:, 0] + system.external_forces[..., -1] -= 0.5 * force_on_one_element[:, 0] class MuscleTorques(NoForces): @@ -275,16 +319,16 @@ class MuscleTorques(NoForces): def __init__( self, - base_length, - b_coeff, - period, - wave_number, - phase_shift, - direction, - rest_lengths, - ramp_up_time, - with_spline=False, - ): + base_length: float, # TODO: Is this necessary? + b_coeff: NDArray[np.float64], + period: float, + wave_number: float, + phase_shift: float, + direction: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + ramp_up_time: float, + with_spline: bool = False, + ) -> None: """ Parameters @@ -302,7 +346,7 @@ def __init__( Phase shift of traveling wave. direction: numpy.ndarray 1D (dim) array containing data with 'float' type. Muscle torque direction. - ramp_up_time: float + ramp_up_time: np.float64 Applied muscle torques are ramped up until ramp up time. with_spline: boolean Option to use beta-spline. @@ -311,12 +355,12 @@ def __init__( super(MuscleTorques, self).__init__() self.direction = direction # Direction torque applied - self.angular_frequency = 2.0 * np.pi / period - self.wave_number = wave_number - self.phase_shift = phase_shift + self.angular_frequency = np.float64(2.0 * np.pi / period) + self.wave_number = np.float64(wave_number) + self.phase_shift = np.float64(phase_shift) assert ramp_up_time > 0.0 - self.ramp_up_time = ramp_up_time + self.ramp_up_time = np.float64(ramp_up_time) # s is the position of nodes on the rod, we go from node=1 to node=nelem-1, because there is no # torques applied by first and last node on elements. Reason is that we cannot apply torque in an @@ -335,7 +379,9 @@ def __init__( else: self.my_spline = np.full_like(self.s, fill_value=1.0) - def apply_torques(self, rod: RodType, time: np.float64 = 0.0): + def apply_torques( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: self.compute_muscle_torques( time, self.my_spline, @@ -345,26 +391,26 @@ def apply_torques(self, rod: RodType, time: np.float64 = 0.0): self.phase_shift, self.ramp_up_time, self.direction, - rod.director_collection, - rod.external_torques, + system.director_collection, + system.external_torques, ) @staticmethod - @njit(cache=True) + @njit(cache=True) # type: ignore def compute_muscle_torques( - time, - my_spline, - s, - angular_frequency, - wave_number, - phase_shift, - ramp_up_time, - direction, - director_collection, - external_torques, - ): + time: float, + my_spline: NDArray[np.float64], + s: np.float64, + angular_frequency: np.float64, + wave_number: np.float64, + phase_shift: np.float64, + ramp_up_time: np.float64, + direction: NDArray[np.float64], + director_collection: NDArray[np.float64], + external_torques: NDArray[np.float64], + ) -> None: # Ramp up the muscle torque - factor = min(1.0, time / ramp_up_time) + factor = np.float64(min(np.float64(1.0), time / ramp_up_time)) # From the node 1 to node nelem-1 # Magnitude of the torque. Am = beta(s) * sin(2pi*t/T + 2pi*s/lambda + phi) # There is an inconsistency with paper and Elastica cpp implementation. In paper sign in @@ -387,8 +433,11 @@ def compute_muscle_torques( ) -@njit(cache=True) -def inplace_addition(external_force_or_torque, force_or_torque): +@njit(cache=True) # type: ignore +def inplace_addition( + external_force_or_torque: NDArray[np.float64], + force_or_torque: NDArray[np.float64], +) -> None: """ This function does inplace addition. First argument `external_force_or_torque` is the system.external_forces @@ -410,8 +459,11 @@ def inplace_addition(external_force_or_torque, force_or_torque): external_force_or_torque[i, k] += force_or_torque[i, k] -@njit(cache=True) -def inplace_substraction(external_force_or_torque, force_or_torque): +@njit(cache=True) # type: ignore +def inplace_substraction( + external_force_or_torque: NDArray[np.float64], + force_or_torque: NDArray[np.float64], +) -> None: """ This function does inplace substraction. First argument `external_force_or_torque` is the system.external_forces @@ -460,12 +512,16 @@ class EndpointForcesSinusoidal(NoForces): def __init__( self, - start_force_mag, - end_force_mag, - ramp_up_time=0.0, - tangent_direction=np.array([0, 0, 1]), - normal_direction=np.array([0, 1, 0]), - ): + start_force_mag: float, + end_force_mag: float, + ramp_up_time: float = 0.0, + tangent_direction: NDArray[np.floating] = np.array( + [0.0, 0.0, 1.0] + ), # FIXME: avoid mutable default + normal_direction: NDArray[np.floating] = np.array( + [0.0, 1.0, 0.0] + ), # FIXME: avoid mutable default + ) -> None: """ Parameters @@ -485,17 +541,19 @@ def __init__( """ super(EndpointForcesSinusoidal, self).__init__() # Start force - self.start_force_mag = start_force_mag - self.end_force_mag = end_force_mag + self.start_force_mag = np.float64(start_force_mag) + self.end_force_mag = np.float64(end_force_mag) # Applied force directions self.normal_direction = normal_direction self.roll_direction = np.cross(normal_direction, tangent_direction) assert ramp_up_time >= 0.0 - self.ramp_up_time = ramp_up_time + self.ramp_up_time = np.float64(ramp_up_time) - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: if time < self.ramp_up_time: # When time smaller than ramp up time apply the force in normal direction diff --git a/elastica/interaction.py b/elastica/interaction.py index 95d4602e..1a72650d 100644 --- a/elastica/interaction.py +++ b/elastica/interaction.py @@ -14,44 +14,14 @@ _calculate_contact_forces_cylinder_plane, ) +from numpy.typing import NDArray -def find_slipping_elements(velocity_slip, velocity_threshold): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._find_slipping_elements()\n" - "instead for finding slipping elements." - ) - - -def node_to_element_mass_or_force(input): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._node_to_element_mass_or_force()\n" - "instead for converting the mass/forces on rod nodes to elements." - ) - - -def nodes_to_elements(input): - # Remove the function beyond v0.4.0 - raise NotImplementedError( - "This function is removed in v0.3.1. Please use\n" - "elastica.interaction.node_to_element_mass_or_force()\n" - "instead for node-to-element interpolation of mass/forces." - ) - - -@njit(cache=True) -def elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._elements_to_nodes_inplace()\n" - "instead for updating nodal forces using the forces computed on elements." - ) +from elastica.typing import SystemType, RodType, RigidBodyType # base class for interaction # only applies normal force no friction -class InteractionPlane: +class InteractionPlane(NoForces): """ The interaction plane class computes the plane reaction force on a rod-like object. For more details regarding the contact module refer to @@ -74,7 +44,13 @@ class InteractionPlane: """ - def __init__(self, k, nu, plane_origin, plane_normal): + def __init__( + self, + k: float, + nu: float, + plane_origin: NDArray[np.float64], + plane_normal: NDArray[np.float64], + ) -> None: """ Parameters @@ -90,13 +66,13 @@ def __init__(self, k, nu, plane_origin, plane_normal): 2D (dim, 1) array containing data with 'float' type. The normal vector of the plane. """ - self.k = k - self.nu = nu + self.k = np.float64(k) + self.nu = np.float64(nu) + self.surface_tol = np.float64(1e-4) self.plane_origin = plane_origin.reshape(3, 1) self.plane_normal = plane_normal.reshape(3) - self.surface_tol = 1e-4 - def apply_normal_force(self, system): + def apply_forces(self, system: RodType, time: np.float64 = np.float64(0.0)) -> None: """ In the case of contact with the plane, this function computes the plane reaction force on the element. @@ -129,33 +105,13 @@ def apply_normal_force(self, system): ) -def apply_normal_force_numba( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - radius, - mass, - position_collection, - velocity_collection, - internal_forces, - external_forces, -): - raise NotImplementedError( - "This function is removed in v0.3.2. For rod plane contact please use: \n" - "elastica._contact_functions._calculate_contact_forces_rod_plane() \n" - "For detail, refer to issue #113." - ) - - # class for anisotropic frictional plane # NOTE: friction coefficients are passed as arrays in the order # mu_forward : mu_backward : mu_sideways # head is at x[0] and forward means head to tail # same convention for kinetic and static # mu named as to which direction it opposes -class AnisotropicFrictionalPlane(NoForces, InteractionPlane): +class AnisotropicFrictionalPlane(InteractionPlane): """ This anisotropic friction plane class is for computing anisotropic friction forces on rods. @@ -186,14 +142,14 @@ class AnisotropicFrictionalPlane(NoForces, InteractionPlane): def __init__( self, - k, - nu, - plane_origin, - plane_normal, - slip_velocity_tol, - static_mu_array, - kinetic_mu_array, - ): + k: float, + nu: float, + plane_origin: NDArray[np.float64], + plane_normal: NDArray[np.float64], + slip_velocity_tol: float, + static_mu_array: NDArray[np.float64], + kinetic_mu_array: NDArray[np.float64], + ) -> None: """ Parameters @@ -218,7 +174,7 @@ def __init__( [forward, backward, sideways] kinetic friction coefficients. """ InteractionPlane.__init__(self, k, nu, plane_origin, plane_normal) - self.slip_velocity_tol = slip_velocity_tol + self.slip_velocity_tol = np.float64(slip_velocity_tol) ( self.static_mu_forward, self.static_mu_backward, @@ -232,12 +188,14 @@ def __init__( # kinetic and static friction should separate functions # for now putting them together to figure out common variables - def apply_forces(self, system, time=0.0): + def apply_forces( + self, system: "RodType | RigidBodyType", time: np.float64 = np.float64(0.0) + ) -> None: """ Call numba implementation to apply friction forces Parameters ---------- - system + system : RodType | RigidBodyType time """ @@ -268,41 +226,9 @@ def apply_forces(self, system, time=0.0): ) -def anisotropic_friction( - plane_origin, - plane_normal, - surface_tol, - slip_velocity_tol, - k, - nu, - kinetic_mu_forward, - kinetic_mu_backward, - kinetic_mu_sideways, - static_mu_forward, - static_mu_backward, - static_mu_sideways, - radius, - mass, - tangents, - position_collection, - director_collection, - velocity_collection, - omega_collection, - internal_forces, - external_forces, - internal_torques, - external_torques, -): - raise NotImplementedError( - "This function is removed in v0.3.2. For anisotropic_friction please use: \n" - "elastica._contact_functions._calculate_contact_forces_rod_plane_with_anisotropic_friction() \n" - "For detail, refer to issue #113." - ) - - # Slender body module -@njit(cache=True) -def sum_over_elements(input): +@njit(cache=True) # type: ignore +def sum_over_elements(input: NDArray[np.float64]) -> np.float64: """ This function sums all elements of the input array. Using a Numba njit decorator shows better performance @@ -334,43 +260,22 @@ def sum_over_elements(input): This version: 513 ns ± 24.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) """ - output = 0.0 + output: np.float64 = np.float64(0.0) for i in range(input.shape[0]): output += input[i] return output -def node_to_element_position(node_position_collection): - raise NotImplementedError( - "This function is removed in v0.3.2. For node-to-element_position() interpolation please use: \n" - "elastica.contact_utils._node_to_element_position() for rod position \n" - "For detail, refer to issue #113." - ) - - -def node_to_element_velocity(mass, node_velocity_collection): - raise NotImplementedError( - "This function is removed in v0.3.2. For node-to-element_velocity() interpolation please use: \n" - "elastica.contact_utils._node_to_element_velocity() for rod velocity. \n" - "For detail, refer to issue #113." - ) - - -def node_to_element_pos_or_vel(vector_in_node_frame): - # Remove the function beyond v0.4.0 - raise NotImplementedError( - "This function is removed in v0.3.0. For node-to-element interpolation please use: \n" - "elastica.contact_utils._node_to_element_position() for rod position \n" - "elastica.contact_utils._node_to_element_velocity() for rod velocity. \n" - "For detail, refer to issue #80." - ) - - -@njit(cache=True) +@njit(cache=True) # type: ignore def slender_body_forces( - tangents, velocity_collection, dynamic_viscosity, lengths, radius, mass -): + tangents: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + dynamic_viscosity: np.float64, + lengths: NDArray[np.float64], + radius: NDArray[np.float64], + mass: NDArray[np.float64], +) -> NDArray[np.float64]: r""" This function computes hydrodynamic forces on a body using slender body theory. The below implementation is from Eq. 4.13 in Gazzola et al. RSoS. (2018). @@ -481,7 +386,7 @@ class SlenderBodyTheory(NoForces): """ - def __init__(self, dynamic_viscosity): + def __init__(self, dynamic_viscosity: float) -> None: """ Parameters @@ -490,9 +395,9 @@ def __init__(self, dynamic_viscosity): Dynamic viscosity of the fluid. """ super(SlenderBodyTheory, self).__init__() - self.dynamic_viscosity = dynamic_viscosity + self.dynamic_viscosity = np.float64(dynamic_viscosity) - def apply_forces(self, system, time=0.0): + def apply_forces(self, system: RodType, time: np.float64 = np.float64(0.0)) -> None: """ This function applies hydrodynamic forces on body using the slender body theory given in @@ -517,15 +422,23 @@ def apply_forces(self, system, time=0.0): # base class for interaction # only applies normal force no friction -class InteractionPlaneRigidBody: - def __init__(self, k, nu, plane_origin, plane_normal): - self.k = k - self.nu = nu +class InteractionPlaneRigidBody(NoForces): + def __init__( + self, + k: float, + nu: float, + plane_origin: NDArray[np.float64], + plane_normal: NDArray[np.float64], + ) -> None: + self.k = np.float64(k) + self.nu = np.float64(nu) + self.surface_tol = np.float64(1e-4) self.plane_origin = plane_origin.reshape(3, 1) self.plane_normal = plane_normal.reshape(3) - self.surface_tol = 1e-4 - def apply_normal_force(self, system): + def apply_forces( + self, system: RigidBodyType, time: np.float64 = np.float64(0.0) + ) -> None: """ This function computes the plane force response on the rigid body, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper @@ -538,7 +451,7 @@ def apply_normal_force(self, system): ------- magnitude of the plane response """ - return _calculate_contact_forces_cylinder_plane( + _calculate_contact_forces_cylinder_plane( self.plane_origin, self.plane_normal, self.surface_tol, @@ -549,23 +462,3 @@ def apply_normal_force(self, system): system.velocity_collection, system.external_forces, ) - - -@njit(cache=True) -def apply_normal_force_numba_rigid_body( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - length, - position_collection, - velocity_collection, - external_forces, -): - - raise NotImplementedError( - "This function is removed in v0.3.2. For cylinder plane contact please use: \n" - "elastica._contact_functions._calculate_contact_forces_cylinder_plane() \n" - "For detail, refer to issue #113." - ) diff --git a/elastica/joint.py b/elastica/joint.py index e37d6f05..1cf2dd3e 100644 --- a/elastica/joint.py +++ b/elastica/joint.py @@ -1,8 +1,11 @@ __doc__ = """ Module containing joint classes to connect multiple rods together. """ +__all__ = ["FreeJoint", "HingeJoint", "FixedJoint", "get_relative_rotation_two_systems"] + from elastica._rotations import _inv_rotate -from elastica.typing import SystemType, RodType +from elastica.typing import SystemType, RodType, ConnectionIndex, RigidBodyType + import numpy as np -import logging +from numpy.typing import NDArray class FreeJoint: @@ -27,7 +30,7 @@ class FreeJoint: # pass the k and nu for the forces # also the necessary rods for the joint # indices should be 0 or -1, we will provide wrappers for users later - def __init__(self, k, nu): + def __init__(self, k: float, nu: float) -> None: """ Parameters @@ -38,24 +41,28 @@ def __init__(self, k, nu): Damping coefficient of the joint. """ - self.k = k - self.nu = nu + self.k = np.float64(k) + self.nu = np.float64(nu) def apply_forces( - self, system_one: SystemType, index_one, system_two: SystemType, index_two - ): + self, + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, + ) -> None: """ Apply joint force to the connected rod objects. Parameters ---------- - system_one : object + system_one : RodType | RigidBodyType Rod or rigid-body object - index_one : int + index_one : ConnectionIndex Index of first rod for joint. - system_two : object + system_two : RodType | RigidBodyType Rod or rigid-body object - index_two : int + index_two : ConnectionIndex Index of second rod for joint. Returns @@ -81,8 +88,12 @@ def apply_forces( return def apply_torques( - self, system_one: SystemType, index_one, system_two: SystemType, index_two - ): + self, + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, + ) -> None: """ Apply restoring joint torques to the connected rod objects. @@ -90,13 +101,13 @@ def apply_torques( Parameters ---------- - system_one : object + system_one : RodType | RigidBodyType Rod or rigid-body object - index_one : int + index_one : ConnectionIndex Index of first rod for joint. - system_two : object + system_two : RodType | RigidBodyType Rod or rigid-body object - index_two : int + index_two : ConnectionIndex Index of second rod for joint. Returns @@ -127,7 +138,13 @@ class HingeJoint(FreeJoint): """ # TODO: IN WRAPPER COMPUTE THE NORMAL DIRECTION OR ASK USER TO GIVE INPUT, IF NOT THROW ERROR - def __init__(self, k, nu, kt, normal_direction): + def __init__( + self, + k: float, + nu: float, + kt: float, + normal_direction: NDArray[np.float64], + ) -> None: """ Parameters @@ -148,25 +165,25 @@ def __init__(self, k, nu, kt, normal_direction): self.normal_direction = normal_direction / np.linalg.norm(normal_direction) # additional in-plane constraint through restoring torque # stiffness of the restoring constraint -- tuned empirically - self.kt = kt + self.kt = np.float64(kt) # Apply force is same as free joint def apply_forces( self, - system_one: SystemType, - index_one, - system_two: SystemType, - index_two, - ): + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, + ) -> None: return super().apply_forces(system_one, index_one, system_two, index_two) def apply_torques( self, - system_one: SystemType, - index_one, - system_two: SystemType, - index_two, - ): + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, + ) -> None: # current tangent direction of the `index_two` element of system two system_two_tangent = system_two.director_collection[2, :, index_two] @@ -215,7 +232,14 @@ class FixedJoint(FreeJoint): is enforced. """ - def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): + def __init__( + self, + k: float, + nu: float, + kt: float, + nut: float = 0.0, + rest_rotation_matrix: NDArray[np.float64] | None = None, + ) -> None: """ Parameters @@ -228,7 +252,7 @@ def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): Rotational stiffness coefficient of the joint. nut: float = 0. Rotational damping coefficient of the joint. - rest_rotation_matrix: np.array + rest_rotation_matrix: np.array | None 2D (3,3) array containing data with 'float' type. Rest 3x3 rotation matrix from system one to system two at the connected elements. If provided, the rest rotation matrix is enforced between the two systems throughout the simulation. @@ -239,8 +263,8 @@ def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): super().__init__(k, nu) # additional in-plane constraint through restoring torque # stiffness of the restoring constraint -- tuned empirically - self.kt = kt - self.nut = nut + self.kt = np.float64(kt) + self.nut = np.float64(nut) # TODO: compute the rest rotation matrix directly during initialization # as soon as systems (e.g. `rod_one` and `rod_two`) and indices (e.g. `index_one` and `index_two`) @@ -253,20 +277,20 @@ def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): # Apply force is same as free joint def apply_forces( self, - system_one: SystemType, - index_one, - system_two: SystemType, - index_two, - ): + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, + ) -> None: return super().apply_forces(system_one, index_one, system_two, index_two) def apply_torques( self, - system_one: SystemType, - index_one, - system_two: SystemType, - index_two, - ): + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, + ) -> None: # collect directors of systems one and two # note that systems can be either rods or rigid bodies system_one_director = system_one.director_collection[..., index_one] @@ -310,11 +334,11 @@ def apply_torques( def get_relative_rotation_two_systems( - system_one: SystemType, - index_one, - system_two: SystemType, - index_two, -): + system_one: "RodType | RigidBodyType", + index_one: ConnectionIndex, + system_two: "RodType | RigidBodyType", + index_two: ConnectionIndex, +) -> NDArray[np.float64]: """ Compute the relative rotation matrix C_12 between system one and system two at the specified elements. @@ -341,13 +365,13 @@ def get_relative_rotation_two_systems( Parameters ---------- - system_one : SystemType + system_one : RodType | RigidBodyType Rod or rigid-body object - index_one : int + index_one : ConnectionIndex Index of first rod for joint. - system_two : SystemType + system_two : RodType | RigidBodyType Rod or rigid-body object - index_two : int + index_two : ConnectionIndex Index of second rod for joint. Returns @@ -359,367 +383,3 @@ def get_relative_rotation_two_systems( system_one.director_collection[..., index_one] @ system_two.director_collection[..., index_two].T ) - - -# everything below this comment should be removed beyond v0.4.0 -def _dot_product(a, b): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._dot_product()\n" - "instead for find the dot product between a and b." - ) - - -def _norm(a): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._norm()\n" - "instead for finding the norm of a." - ) - - -def _clip(x, low, high): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._clip()\n" - "instead for clipping x." - ) - - -def _out_of_bounds(x, low, high): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._out_of_bounds()\n" - "instead for checking if x is out of bounds." - ) - - -def _find_min_dist(x1, e1, x2, e2): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._find_min_dist()\n" - "instead for finding minimum distance between contact points." - ) - - -def _calculate_contact_forces_rod_rigid_body( - x_collection_rod, - edge_collection_rod, - x_cylinder_center, - x_cylinder_tip, - edge_cylinder, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_cylinder, - external_torques_cylinder, - cylinder_director_collection, - velocity_rod, - velocity_cylinder, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, -): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica._contact_functions._calculate_contact_forces_rod_cylinder()\n" - "instead for calculating rod cylinder contact forces." - ) - - -def _calculate_contact_forces_rod_rod( - x_collection_rod_one, - radius_rod_one, - length_rod_one, - tangent_rod_one, - velocity_rod_one, - internal_forces_rod_one, - external_forces_rod_one, - x_collection_rod_two, - radius_rod_two, - length_rod_two, - tangent_rod_two, - velocity_rod_two, - internal_forces_rod_two, - external_forces_rod_two, - contact_k, - contact_nu, -): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica._contact_functions._calculate_contact_forces_rod_rod()\n" - "instead for calculating rod rod contact forces." - ) - - -def _calculate_contact_forces_self_rod( - x_collection_rod, - radius_rod, - length_rod, - tangent_rod, - velocity_rod, - external_forces_rod, - contact_k, - contact_nu, -): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica._contact_functions._calculate_contact_forces_self_rod()\n" - "instead for calculating rod self-contact forces." - ) - - -def _aabbs_not_intersecting(aabb_one, aabb_two): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._aabbs_not_intersecting()\n" - "instead for checking aabbs intersection." - ) - - -def _prune_using_aabbs_rod_rigid_body( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - cylinder_position, - cylinder_director, - cylinder_radius, - cylinder_length, -): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._prune_using_aabbs_rod_cylinder()\n" - "instead for checking rod cylinder intersection." - ) - - -def _prune_using_aabbs_rod_rod( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - rod_two_position_collection, - rod_two_radius_collection, - rod_two_length_collection, -): - raise NotImplementedError( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._prune_using_aabbs_rod_rod()\n" - "instead for checking rod rod intersection." - ) - - -class ExternalContact(FreeJoint): - """ - This class is for applying contact forces between rod-cylinder and rod-rod. - If you are want to apply contact forces between rod and cylinder, first system is always rod and second system - is always cylinder. - In addition to the contact forces, user can define apply friction forces between rod and cylinder that - are in contact. For details on friction model refer to this [1]_. - TODO: Currently friction force is between rod-cylinder, in future implement friction forces between rod-rod. - - Notes - ----- - The `velocity_damping_coefficient` is set to a high value (e.g. 1e4) to minimize slip and simulate stiction - (static friction), while friction_coefficient corresponds to the Coulombic friction coefficient. - - Examples - -------- - How to define contact between rod and cylinder. - - >>> simulator.connect(rod, cylinder).using( - ... ExternalContact, - ... k=1e4, - ... nu=10, - ... velocity_damping_coefficient=10, - ... kinetic_friction_coefficient=10, - ... ) - - How to define contact between rod and rod. - - >>> simulator.connect(rod, rod).using( - ... ExternalContact, - ... k=1e4, - ... nu=10, - ... ) - - .. [1] Preclik T., Popa Constantin., Rude U., Regularizing a Time-Stepping Method for Rigid Multibody Dynamics, Multibody Dynamics 2011, ECCOMAS. URL: https://www10.cs.fau.de/publications/papers/2011/Preclik_Multibody_Ext_Abstr.pdf - """ - - # Dev note: - # Most of the cylinder-cylinder contact SHOULD be implemented - # as given in this `paper `, - # but the elastica-cpp kernels are implemented. - # This is maybe to speed-up the kernel, but it's - # potentially dangerous as it does not deal with "end" conditions - # correctly. - - def __init__(self, k, nu, velocity_damping_coefficient=0, friction_coefficient=0): - """ - - Parameters - ---------- - k : float - Contact spring constant. - nu : float - Contact damping constant. - velocity_damping_coefficient : float - Velocity damping coefficient between rigid-body and rod contact is used to apply friction force in the - slip direction. - friction_coefficient : float - For Coulombic friction coefficient for rigid-body and rod contact. - """ - super().__init__(k, nu) - self.velocity_damping_coefficient = velocity_damping_coefficient - self.friction_coefficient = friction_coefficient - log = logging.getLogger(self.__class__.__name__) - log.warning( - # Remove warning and add error if ExternalContact is used in v0.3.3 - # Remove the option to use ExternalContact, beyond v0.3.3 - "The option to use the ExternalContact joint for the rod-rod and rod-cylinder contact is now deprecated.\n" - "Instead, for rod-rod contact or rod-cylinder contact,use RodRodContact or RodCylinderContact from the add-on Contact mixin class.\n" - "For reference see the classes elastica.contact_forces.RodRodContact() and elastica.contact_forces.RodCylinderContact().\n" - "For usage check examples/RigidbodyCases/RodRigidBodyContact/rod_cylinder_contact.py and examples/RodContactCase/RodRodContact/rod_rod_contact_parallel_validation.py.\n" - " The option to use the ExternalContact joint for the rod-rod and rod-cylinder will be removed in the future (v0.3.3).\n" - ) - - def apply_forces( - self, - rod_one: RodType, - index_one, - rod_two: SystemType, - index_two, - ): - # del index_one, index_two - from elastica.contact_utils import ( - _prune_using_aabbs_rod_cylinder, - _prune_using_aabbs_rod_rod, - ) - from elastica._contact_functions import ( - _calculate_contact_forces_rod_cylinder, - _calculate_contact_forces_rod_rod, - ) - - # TODO: raise error during the initialization if rod one is rigid body. - - # If rod two has one element, then it is rigid body. - if rod_two.n_elems == 1: - cylinder_two = rod_two - # First, check for a global AABB bounding box, and see whether that - # intersects - if _prune_using_aabbs_rod_cylinder( - rod_one.position_collection, - rod_one.radius, - rod_one.lengths, - cylinder_two.position_collection, - cylinder_two.director_collection, - cylinder_two.radius[0], - cylinder_two.length[0], - ): - return - - x_cyl = ( - cylinder_two.position_collection[..., 0] - - 0.5 * cylinder_two.length * cylinder_two.director_collection[2, :, 0] - ) - - rod_element_position = 0.5 * ( - rod_one.position_collection[..., 1:] - + rod_one.position_collection[..., :-1] - ) - _calculate_contact_forces_rod_cylinder( - rod_element_position, - rod_one.lengths * rod_one.tangents, - cylinder_two.position_collection[..., 0], - x_cyl, - cylinder_two.length * cylinder_two.director_collection[2, :, 0], - rod_one.radius + cylinder_two.radius, - rod_one.lengths + cylinder_two.length, - rod_one.internal_forces, - rod_one.external_forces, - cylinder_two.external_forces, - cylinder_two.external_torques, - cylinder_two.director_collection[:, :, 0], - rod_one.velocity_collection, - cylinder_two.velocity_collection, - self.k, - self.nu, - self.velocity_damping_coefficient, - self.friction_coefficient, - ) - - else: - # First, check for a global AABB bounding box, and see whether that - # intersects - - if _prune_using_aabbs_rod_rod( - rod_one.position_collection, - rod_one.radius, - rod_one.lengths, - rod_two.position_collection, - rod_two.radius, - rod_two.lengths, - ): - return - - _calculate_contact_forces_rod_rod( - rod_one.position_collection[ - ..., :-1 - ], # Discount last node, we want element start position - rod_one.radius, - rod_one.lengths, - rod_one.tangents, - rod_one.velocity_collection, - rod_one.internal_forces, - rod_one.external_forces, - rod_two.position_collection[ - ..., :-1 - ], # Discount last node, we want element start position - rod_two.radius, - rod_two.lengths, - rod_two.tangents, - rod_two.velocity_collection, - rod_two.internal_forces, - rod_two.external_forces, - self.k, - self.nu, - ) - - -class SelfContact(FreeJoint): - """ - This class is modeling self contact of rod. - - """ - - def __init__(self, k, nu): - super().__init__(k, nu) - log = logging.getLogger(self.__class__.__name__) - log.warning( - # Remove warning and add error if SelfContact is used in v0.3.3 - # Remove the option to use SelfContact, beyond v0.3.3 - "The option to use the SelfContact joint for the rod self contact is now deprecated.\n" - "Instead, for rod self contact use RodSelfContact from the add-on Contact mixin class.\n" - "For reference see the class elastica.contact_forces.RodSelfContact(), and for usage check examples/RodContactCase/RodSelfContact/solenoids.py.\n" - "The option to use the SelfContact joint for the rod self contact will be removed in the future (v0.3.3).\n" - ) - - def apply_forces(self, rod_one: RodType, index_one, rod_two: SystemType, index_two): - # del index_one, index_two - from elastica._contact_functions import ( - _calculate_contact_forces_self_rod, - ) - - _calculate_contact_forces_self_rod( - rod_one.position_collection[ - ..., :-1 - ], # Discount last node, we want element start position - rod_one.radius, - rod_one.lengths, - rod_one.tangents, - rod_one.velocity_collection, - rod_one.external_forces, - self.k, - self.nu, - ) diff --git a/elastica/memory_block/memory_block_rigid_body.py b/elastica/memory_block/memory_block_rigid_body.py index 4dd7fac4..12c1caea 100644 --- a/elastica/memory_block/memory_block_rigid_body.py +++ b/elastica/memory_block/memory_block_rigid_body.py @@ -1,18 +1,21 @@ __doc__ = """Create block-structure class for collection of rigid body systems.""" +from typing import Literal import numpy as np -from typing import Sequence, Literal +from elastica.typing import SystemIdxType, RigidBodyType from elastica.rigidbody import RigidBodyBase from elastica.rigidbody.data_structures import _RigidRodSymplecticStepperMixin class MemoryBlockRigidBody(RigidBodyBase, _RigidRodSymplecticStepperMixin): - def __init__(self, systems: Sequence, system_idx_list: Sequence[np.int64]): + def __init__( + self, systems: list[RigidBodyType], system_idx_list: list[SystemIdxType] + ) -> None: - self.n_bodies = len(systems) - self.n_elems = self.n_bodies + self.n_systems = len(systems) + self.n_elems = self.n_systems self.n_nodes = self.n_elems - self.system_idx_list = np.array(system_idx_list, dtype=np.int64) + self.system_idx_list = np.array(system_idx_list, dtype=np.int32) # Allocate block structure using system collection. self._allocate_block_variables_scalars(systems) @@ -23,7 +26,7 @@ def __init__(self, systems: Sequence, system_idx_list: Sequence[np.int64]): # Initialize the mixin class for symplectic time-stepper. _RigidRodSymplecticStepperMixin.__init__(self) - def _allocate_block_variables_scalars(self, systems: Sequence): + def _allocate_block_variables_scalars(self, systems: list[RigidBodyType]) -> None: """ This function takes system collection and allocates the variables for block-structure and references allocated variables back to the systems. @@ -58,7 +61,7 @@ def _allocate_block_variables_scalars(self, systems: Sequence): value_type="scalar", ) - def _allocate_block_variables_vectors(self, systems: Sequence): + def _allocate_block_variables_vectors(self, systems: list[RigidBodyType]) -> None: """ This function takes system collection and allocates the vector variables for block-structure and references allocated vector variables back to the systems. @@ -94,7 +97,7 @@ def _allocate_block_variables_vectors(self, systems: Sequence): value_type="vector", ) - def _allocate_block_variables_matrix(self, systems: Sequence): + def _allocate_block_variables_matrix(self, systems: list[RigidBodyType]) -> None: """ This function takes system collection and allocates the matrix variables for block-structure and references allocated matrix variables back to the systems. @@ -130,7 +133,9 @@ def _allocate_block_variables_matrix(self, systems: Sequence): value_type="tensor", ) - def _allocate_block_variables_for_symplectic_stepper(self, systems: Sequence): + def _allocate_block_variables_for_symplectic_stepper( + self, systems: list[RigidBodyType] + ) -> None: """ This function takes system collection and allocates the variables used by symplectic stepper for block-structure and references allocated variables back to the systems. @@ -176,7 +181,7 @@ def _allocate_block_variables_for_symplectic_stepper(self, systems: Sequence): def _map_system_properties_to_block_memory( self, mapping_dict: dict, - systems: Sequence, + systems: list[RigidBodyType], block_memory: np.ndarray, value_type: Literal["scalar", "vector", "tensor"], ) -> None: @@ -186,8 +191,8 @@ def _map_system_properties_to_block_memory( ---------- mapping_dict: dict Dictionary with attribute names as keys and block row index as values. - systems: Sequence - A sequence containing Cosserat rod objects to map from. + systems: list[RigidBodyType] + A sequence containing rigid body objects to map from. block_memory: ndarray Memory block that, at the end of the method execution, contains all designated attributes of all systems. @@ -197,7 +202,7 @@ def _map_system_properties_to_block_memory( """ if value_type == "scalar": - view_shape = (self.n_elems,) + view_shape: tuple[int, ...] = (self.n_elems,) elif value_type == "vector": view_shape = (3, self.n_elems) diff --git a/elastica/memory_block/memory_block_rod.py b/elastica/memory_block/memory_block_rod.py index b6a6b73f..a47998ba 100644 --- a/elastica/memory_block/memory_block_rod.py +++ b/elastica/memory_block/memory_block_rod.py @@ -1,11 +1,7 @@ __doc__ = """Create block-structure class for collection of Cosserat rod systems.""" import numpy as np -from typing import Sequence, Literal, Callable -from elastica.memory_block.memory_block_rod_base import ( - MemoryBlockRodBase, - make_block_memory_metadata, - make_block_memory_periodic_boundary_metadata, -) +from typing import Literal, Callable +from elastica.typing import SystemIdxType, RodType from elastica.rod.data_structures import _RodSymplecticStepperMixin from elastica.reset_functions_for_block_structure import _reset_scalar_ghost from elastica.rod.cosserat_rod import ( @@ -18,10 +14,13 @@ _synchronize_periodic_boundary_of_matrix_collection, ) +from .utils import ( + make_block_memory_metadata, + make_block_memory_periodic_boundary_metadata, +) + -class MemoryBlockCosseratRod( - MemoryBlockRodBase, CosseratRod, _RodSymplecticStepperMixin -): +class MemoryBlockCosseratRod(CosseratRod, _RodSymplecticStepperMixin): """ Memory block class for Cosserat rod equations. This class is derived from Cosserat Rod class in order to inherit the methods of Cosserat rod class. This class takes the cosserat rod object (systems) and creates big @@ -31,7 +30,10 @@ class MemoryBlockCosseratRod( TODO: need more documentation! """ - def __init__(self, systems: Sequence, system_idx_list): + def __init__( + self, systems: list[RodType], system_idx_list: list[SystemIdxType] + ) -> None: + self.n_systems = len(systems) # separate straight and ring rods system_straight_rod = [] @@ -51,20 +53,20 @@ def __init__(self, systems: Sequence, system_idx_list): # Sorted systems systems = system_straight_rod + system_ring_rod self.system_idx_list = np.array( - system_idx_list_straight_rod + system_idx_list_ring_rod, dtype=np.int64 + system_idx_list_straight_rod + system_idx_list_ring_rod, dtype=np.int32 ) n_elems_straight_rods = np.array( - [x.n_elems for x in system_straight_rod], dtype=np.int64 + [x.n_elems for x in system_straight_rod], dtype=np.int32 ) n_elems_ring_rods = np.array( - [x.n_elems for x in system_ring_rod], dtype=np.int64 + [x.n_elems for x in system_ring_rod], dtype=np.int32 ) - n_straight_rods = len(system_straight_rod) - n_ring_rods = len(system_ring_rod) + n_straight_rods: int = len(system_straight_rod) + n_ring_rods: int = len(system_ring_rod) - # self.n_elems_in_rods = np.array([x.n_elems for x in systems], dtype=np.int) + # self.n_elems_in_rods = np.array([x.n_elems for x in systems], dtype=np.int32) self.n_elems_in_rods = np.hstack((n_elems_straight_rods, n_elems_ring_rods + 2)) self.n_rods = len(systems) ( @@ -200,7 +202,7 @@ def __init__(self, systems: Sequence, system_idx_list): # Initialize the mixin class for symplectic time-stepper. _RodSymplecticStepperMixin.__init__(self) - def _allocate_block_variables_in_nodes(self, systems: Sequence): + def _allocate_block_variables_in_nodes(self, systems: list[RodType]) -> None: """ This function takes system collection and allocates the variables on node for block-structure and references allocated variables back to the @@ -250,7 +252,7 @@ def _allocate_block_variables_in_nodes(self, systems: Sequence): value_type="vector", ) - def _allocate_block_variables_in_elements(self, systems: Sequence): + def _allocate_block_variables_in_elements(self, systems: list[RodType]) -> None: """ This function takes system collection and allocates the variables on elements for block-structure and references allocated variables back to the @@ -341,7 +343,7 @@ def _allocate_block_variables_in_elements(self, systems: Sequence): value_type="tensor", ) - def _allocate_blocks_variables_in_voronoi(self, systems: Sequence): + def _allocate_blocks_variables_in_voronoi(self, systems: list[RodType]) -> None: """ This function takes system collection and allocates the variables on voronoi for block-structure and references allocated variables back to the @@ -408,18 +410,12 @@ def _allocate_blocks_variables_in_voronoi(self, systems: Sequence): value_type="tensor", ) - def _allocate_blocks_variables_for_symplectic_stepper(self, systems: Sequence): + def _allocate_blocks_variables_for_symplectic_stepper( + self, systems: list[RodType] + ) -> None: """ This function takes system collection and allocates the variables used by symplectic stepper for block-structure and references allocated variables back to the systems. - - Parameters - ---------- - systems - - Returns - ------- - """ # These vectors are on nodes or on elements, but we stack them together for # better memory access. Because we use them together in time-steppers. @@ -475,7 +471,7 @@ def _allocate_blocks_variables_for_symplectic_stepper(self, systems: Sequence): def _map_system_properties_to_block_memory( self, mapping_dict: dict, - systems: Sequence, + systems: list[RodType], block_memory: np.ndarray, domain_type: Literal["node", "element", "voronoi"], value_type: Literal["scalar", "vector", "tensor"], @@ -490,7 +486,7 @@ def _map_system_properties_to_block_memory( ---------- mapping_dict: dict Dictionary with attribute names as keys and block row index as values. - systems: Sequence + systems: list[RodType] A sequence containing Cosserat rod objects to map from. block_memory: ndarray Memory block that, at the end of the method execution, contains all designated @@ -509,8 +505,8 @@ def _map_system_properties_to_block_memory( end_idx_list: np.ndarray periodic_boundary_idx: np.ndarray synchronize_periodic_boundary: Callable - domain_num: np.int64 - view_shape: tuple + domain_num: int + view_shape: tuple[int, ...] if domain_type == "node": start_idx_list = self.start_idx_in_rod_nodes.view() diff --git a/elastica/memory_block/memory_block_rod_base.py b/elastica/memory_block/memory_block_rod_base.py index c9bcc77e..87fcdd6d 100644 --- a/elastica/memory_block/memory_block_rod_base.py +++ b/elastica/memory_block/memory_block_rod_base.py @@ -1,164 +1,8 @@ -__doc__ = """Create block-structure class for collection of Cosserat rod systems.""" +__doc__ = """Deprecated module. Use memory_blocks.utils instead.""" import numpy as np -from typing import Iterable +import numpy.typing as npt - -def make_block_memory_metadata(n_elems_in_rods: np.ndarray) -> Iterable: - """ - This function, takes number of elements of each rod as a numpy array and computes, - ghost nodes, elements and voronoi element indexes and numbers and returns it. - - Parameters - ---------- - n_elems_in_rods: ndarray - An integer array containing the number of elements in each of the n rod. - - Returns - ------- - n_elems_with_ghosts: int64 - Total number of elements with ghost elements included. There are two ghost elements - between each pair of two rods adjacent in memory block. - ghost_nodes_idx: ndarray - An integer array of length n - 1 containing the indices of ghost nodes in memory block. - ghost_elements_idx: ndarray - An integer array of length 2 * (n - 1) containing the indices of ghost elements in memory block. - ghost_voronoi_idx: ndarray - An integer array of length 2 * (n - 1) containing the indices of ghost Voronoi nodes in memory block. - """ - - n_nodes_in_rods = n_elems_in_rods + 1 - n_rods = n_elems_in_rods.shape[0] - - # Gap between two rods have one ghost node - # n_nodes_with_ghosts = np.sum(n_nodes_in_rods) + (n_rods - 1) - # Gap between two rods have two ghost elements : comes out to n_nodes_with_ghosts - 1 - n_elems_with_ghosts = np.sum(n_elems_in_rods) + 2 * (n_rods - 1) - # Gap between two rods have three ghost voronois : comes out to n_nodes_with_ghosts - 2 - # n_voronoi_with_ghosts = np.sum(n_voronois_in_rods) + 3 * (n_rods - 1) - - ghost_nodes_idx = np.cumsum(n_nodes_in_rods[:-1], dtype=np.int64) - # Add [0, 1, 2, ... n_rods-2] to the ghost_nodes idx to accommodate miscounting - ghost_nodes_idx += np.arange(0, n_rods - 1, dtype=np.int64) - - ghost_elems_idx = np.zeros((2 * (n_rods - 1),), dtype=np.int64) - ghost_elems_idx[::2] = ghost_nodes_idx - 1 - ghost_elems_idx[1::2] = ghost_nodes_idx.copy() - - ghost_voronoi_idx = np.zeros((3 * (n_rods - 1),), dtype=np.int64) - ghost_voronoi_idx[::3] = ghost_nodes_idx - 2 - ghost_voronoi_idx[1::3] = ghost_nodes_idx - 1 - ghost_voronoi_idx[2::3] = ghost_nodes_idx.copy() - - return n_elems_with_ghosts, ghost_nodes_idx, ghost_elems_idx, ghost_voronoi_idx - - -def make_block_memory_periodic_boundary_metadata(n_elems_in_rods): - """ - This function, takes the number of elements of ring rods and computes the periodic boundary node, - element and voronoi index. - - Parameters - ---------- - n_elems_in_rods : numpy.ndarray - 1D (n_ring_rods,) array containing data with 'float' type. Elements of this array contains total number of - elements of one rod, including periodic boundary elements. - - Returns - ------- - n_elems - - periodic_boundary_node : numpy.ndarray - 2D (2, n_periodic_boundary_nodes) array containing data with 'float' type. Vector containing periodic boundary - elements index. First dimension is the periodic boundary index, second dimension is the referenced cell index. - - periodic_boundary_elems_idx : numpy.ndarray - 2D (2, n_periodic_boundary_elems) array containing data with 'float' type. Vector containing periodic boundary - nodes index. First dimension is the periodic boundary index, second dimension is the referenced cell index. - - periodic_boundary_voronoi_idx : numpy.ndarray - 2D (2, n_periodic_boundary_voronoi) array containing data with 'float' type. Vector containing periodic boundary - voronoi index. First dimension is the periodic boundary index, second dimension is the referenced cell index. - - """ - - n_elem = n_elems_in_rods.copy() - n_rods = n_elems_in_rods.shape[0] - - periodic_boundary_node_idx = np.zeros((2, 3 * n_rods), dtype=np.int64) - # count ghost nodes, first rod does not have a ghost node at the start, so exclude first rod. - periodic_boundary_node_idx[0, 0::3][1:] = 1 - # This is for the first periodic node at the end - periodic_boundary_node_idx[0, 1::3] = 1 + n_elem - # This is for the second periodic node at the end - periodic_boundary_node_idx[0, 2::3] = 1 - periodic_boundary_node_idx[0, :] = np.cumsum(periodic_boundary_node_idx[0, :]) - # Add [0, 1, 2, ..., n_rods] to the periodic boundary nodes to accommodate miscounting - periodic_boundary_node_idx[0, :] += np.repeat( - np.arange(0, n_rods, dtype=np.int64), 3 - ) - # Now fill the reference node idx, to copy and correct periodic boundary nodes - # First fill with the reference node idx of the first periodic node. This is the last node of the actual rod - # (without ghost and periodic nodes). - periodic_boundary_node_idx[1, 0::3] = periodic_boundary_node_idx[0, 1::3] - 1 - # Second fill with the reference node idx of the second periodic node. This is the first node of the actual rod - # (without ghost and periodic nodes). - periodic_boundary_node_idx[1, 1::3] = periodic_boundary_node_idx[0, 0::3] + 1 - # Third fill with the reference node idx of the third periodic node. This is the second node of the actual rod - # (without ghost and periodic nodes). - periodic_boundary_node_idx[1, 2::3] = periodic_boundary_node_idx[0, 0::3] + 2 - - periodic_boundary_elems_idx = np.zeros((2, 2 * n_rods), dtype=np.int64) - # count ghost elems, first rod does not have a ghost elem at the start, so exclude first rod. - periodic_boundary_elems_idx[0, 0::2][1:] = 2 - # This is for the first periodic elem at the end - periodic_boundary_elems_idx[0, 1::2] = 1 + n_elem - periodic_boundary_elems_idx[0, :] = np.cumsum(periodic_boundary_elems_idx[0, :]) - # Add [0, 1, 2, ..., n_rods] to the periodic boundary elems to accommodate miscounting - periodic_boundary_elems_idx[0, :] += np.repeat( - np.arange(0, n_rods, dtype=np.int64), 2 - ) - # Now fill the reference element idx, to copy and correct periodic boundary elements - # First fill with the reference element idx of the first periodic element. This is the last element of the actual - # rod - # (without ghost and periodic elements). - periodic_boundary_elems_idx[1, 0::2] = periodic_boundary_elems_idx[0, 1::2] - 1 - # Second fill with the reference element idx of the second periodic element. This is the first element of the actual - # rod - # (without ghost and periodic elements). - periodic_boundary_elems_idx[1, 1::2] = periodic_boundary_elems_idx[0, 0::2] + 1 - - periodic_boundary_voronoi_idx = np.zeros((2, n_rods), dtype=np.int64) - # count ghost voronoi, first rod does not have a ghost voronoi at the start, so exclude first rod. - periodic_boundary_voronoi_idx[0, 0::1][1:] = 3 - # This is for the first periodic voronoi at the end - periodic_boundary_voronoi_idx[0, 1:] += n_elem[:-1] - periodic_boundary_voronoi_idx[0, :] = np.cumsum(periodic_boundary_voronoi_idx[0, :]) - # Add [0, 1, 2, ..., n_rods] to the periodic boundary voronoi to accommodate miscounting - periodic_boundary_voronoi_idx[0, :] += np.repeat( - np.arange(0, n_rods, dtype=np.int64), 1 - ) - # Now fill the reference voronoi idx, to copy and correct periodic boundary voronoi - # Fill with the reference voronoi idx of the periodic voronoi. This is the last voronoi of the actual rod - # (without ghost and periodic voronoi). - periodic_boundary_voronoi_idx[1, :] = ( - periodic_boundary_voronoi_idx[0, :] + n_elem[:] - ) - - # Increase the n_elem in rods by 2 because we are adding two periodic boundary elements - n_elem += 2 - - return ( - n_elem, - periodic_boundary_node_idx, - periodic_boundary_elems_idx, - periodic_boundary_voronoi_idx, - ) - - -class MemoryBlockRodBase: - """ - This is the base class for memory blocks for rods. - """ - - def __init__(self): - pass +from .utils import ( + make_block_memory_metadata, + make_block_memory_periodic_boundary_metadata, +) diff --git a/elastica/memory_block/protocol.py b/elastica/memory_block/protocol.py new file mode 100644 index 00000000..5da6331b --- /dev/null +++ b/elastica/memory_block/protocol.py @@ -0,0 +1,14 @@ +from typing import Protocol +from elastica.systems.protocol import SystemProtocol + +import numpy as np + +from elastica.rod.protocol import CosseratRodProtocol +from elastica.rigidbody.protocol import RigidBodyProtocol +from elastica.systems.protocol import SymplecticSystemProtocol + + +class BlockSystemProtocol(SystemProtocol, Protocol): + @property + def n_systems(self) -> int: + """Number of systems in the block.""" diff --git a/elastica/memory_block/utils.py b/elastica/memory_block/utils.py new file mode 100644 index 00000000..47c40489 --- /dev/null +++ b/elastica/memory_block/utils.py @@ -0,0 +1,163 @@ +__doc__ = """Create block-structure class for collection of Cosserat rod systems.""" +import numpy as np +import numpy.typing as npt +from numpy.typing import NDArray + + +def make_block_memory_metadata( + n_elems_in_rods: NDArray[np.int32], +) -> tuple[ + int, + NDArray[np.int32], + NDArray[np.int32], + NDArray[np.int32], +]: + """ + This function, takes number of elements of each rod as a numpy array and computes, + ghost nodes, elements and voronoi element indexes and numbers and returns it. + + Parameters + ---------- + n_elems_in_rods: NDArray + An integer array containing the number of elements in each of the n rod. + + Returns + ------- + n_elems_with_ghosts: np.int32 + Total number of elements with ghost elements included. There are two ghost elements + between each pair of two rods adjacent in memory block. + ghost_nodes_idx: NDArray[np.int32] + An integer array of length n - 1 containing the indices of ghost nodes in memory block. + ghost_elements_idx: NDArray[np.int32] + An integer array of length 2 * (n - 1) containing the indices of ghost elements in memory block. + ghost_voronoi_idx: NDArray[np.int32] + An integer array of length 2 * (n - 1) containing the indices of ghost Voronoi nodes in memory block. + """ + n_nodes_in_rods = n_elems_in_rods + 1 + n_rods = n_elems_in_rods.shape[0] + + # Gap between two rods have one ghost node + # n_nodes_with_ghosts = np.sum(n_nodes_in_rods) + (n_rods - 1) + # Gap between two rods have two ghost elements : comes out to n_nodes_with_ghosts - 1 + n_elems_with_ghosts = np.sum(n_elems_in_rods) + 2 * (n_rods - 1) + # Gap between two rods have three ghost voronois : comes out to n_nodes_with_ghosts - 2 + # n_voronoi_with_ghosts = np.sum(n_voronois_in_rods) + 3 * (n_rods - 1) + + ghost_nodes_idx = np.cumsum(n_nodes_in_rods[:-1], dtype=np.int32) + # Add [0, 1, 2, ... n_rods-2] to the ghost_nodes idx to accommodate miscounting + ghost_nodes_idx += np.arange(n_rods - 1) + + ghost_elems_idx = np.zeros((2 * (n_rods - 1),), dtype=np.int32) + ghost_elems_idx[::2] = ghost_nodes_idx - 1 + ghost_elems_idx[1::2] = ghost_nodes_idx.copy() + + ghost_voronoi_idx = np.zeros((3 * (n_rods - 1),), dtype=np.int32) + ghost_voronoi_idx[::3] = ghost_nodes_idx - 2 + ghost_voronoi_idx[1::3] = ghost_nodes_idx - 1 + ghost_voronoi_idx[2::3] = ghost_nodes_idx.copy() + + return int(n_elems_with_ghosts), ghost_nodes_idx, ghost_elems_idx, ghost_voronoi_idx + + +def make_block_memory_periodic_boundary_metadata( + n_elems_in_rods: NDArray[np.int32], +) -> tuple[ + NDArray[np.int32], + NDArray[np.int32], + NDArray[np.int32], + NDArray[np.int32], +]: + """ + This function, takes the number of elements of ring rods and computes the periodic boundary node, + element and voronoi index. + + Parameters + ---------- + n_elems_in_rods : NDArray + 1D (n_ring_rods,) array containing data with 'int' type. Elements of this array contains total number of + elements of one rod, including periodic boundary elements. + + Returns + ------- + n_elems + + periodic_boundary_node : NDArray + 2D (2, n_periodic_boundary_nodes) array containing data with 'int' type. Vector containing periodic boundary + elements index. First dimension is the periodic boundary index, second dimension is the referenced cell index. + + periodic_boundary_elems_idx : NDArray + 2D (2, n_periodic_boundary_elems) array containing data with 'int' type. Vector containing periodic boundary + nodes index. First dimension is the periodic boundary index, second dimension is the referenced cell index. + + periodic_boundary_voronoi_idx : NDArray + 2D (2, n_periodic_boundary_voronoi) array containing data with 'int' type. Vector containing periodic boundary + voronoi index. First dimension is the periodic boundary index, second dimension is the referenced cell index. + + """ + + n_elem: NDArray[np.int32] = n_elems_in_rods.copy() + n_rods = n_elems_in_rods.shape[0] + + periodic_boundary_node_idx = np.zeros((2, 3 * n_rods), dtype=np.int32) + # count ghost nodes, first rod does not have a ghost node at the start, so exclude first rod. + periodic_boundary_node_idx[0, 0::3][1:] = 1 + # This is for the first periodic node at the end + periodic_boundary_node_idx[0, 1::3] = 1 + n_elem + # This is for the second periodic node at the end + periodic_boundary_node_idx[0, 2::3] = 1 + periodic_boundary_node_idx[0, :] = np.cumsum(periodic_boundary_node_idx[0, :]) + # Add [0, 1, 2, ..., n_rods] to the periodic boundary nodes to accommodate miscounting + periodic_boundary_node_idx[0, :] += np.repeat(np.arange(n_rods), 3) + # Now fill the reference node idx, to copy and correct periodic boundary nodes + # First fill with the reference node idx of the first periodic node. This is the last node of the actual rod + # (without ghost and periodic nodes). + periodic_boundary_node_idx[1, 0::3] = periodic_boundary_node_idx[0, 1::3] - 1 + # Second fill with the reference node idx of the second periodic node. This is the first node of the actual rod + # (without ghost and periodic nodes). + periodic_boundary_node_idx[1, 1::3] = periodic_boundary_node_idx[0, 0::3] + 1 + # Third fill with the reference node idx of the third periodic node. This is the second node of the actual rod + # (without ghost and periodic nodes). + periodic_boundary_node_idx[1, 2::3] = periodic_boundary_node_idx[0, 0::3] + 2 + + periodic_boundary_elems_idx = np.zeros((2, 2 * n_rods), dtype=np.int32) + # count ghost elems, first rod does not have a ghost elem at the start, so exclude first rod. + periodic_boundary_elems_idx[0, 0::2][1:] = 2 + # This is for the first periodic elem at the end + periodic_boundary_elems_idx[0, 1::2] = 1 + n_elem + periodic_boundary_elems_idx[0, :] = np.cumsum(periodic_boundary_elems_idx[0, :]) + # Add [0, 1, 2, ..., n_rods] to the periodic boundary elems to accommodate miscounting + periodic_boundary_elems_idx[0, :] += np.repeat(np.arange(n_rods), 2) + # Now fill the reference element idx, to copy and correct periodic boundary elements + # First fill with the reference element idx of the first periodic element. This is the last element of the actual + # rod + # (without ghost and periodic elements). + periodic_boundary_elems_idx[1, 0::2] = periodic_boundary_elems_idx[0, 1::2] - 1 + # Second fill with the reference element idx of the second periodic element. This is the first element of the actual + # rod + # (without ghost and periodic elements). + periodic_boundary_elems_idx[1, 1::2] = periodic_boundary_elems_idx[0, 0::2] + 1 + + periodic_boundary_voronoi_idx = np.zeros((2, n_rods), dtype=np.int32) + # count ghost voronoi, first rod does not have a ghost voronoi at the start, so exclude first rod. + periodic_boundary_voronoi_idx[0, 0::1][1:] = 3 + # This is for the first periodic voronoi at the end + periodic_boundary_voronoi_idx[0, 1:] += n_elem[:-1] + periodic_boundary_voronoi_idx[0, :] = np.cumsum(periodic_boundary_voronoi_idx[0, :]) + # Add [0, 1, 2, ..., n_rods] to the periodic boundary voronoi to accommodate miscounting + periodic_boundary_voronoi_idx[0, :] += np.repeat(np.arange(0, n_rods), 1) + # Now fill the reference voronoi idx, to copy and correct periodic boundary voronoi + # Fill with the reference voronoi idx of the periodic voronoi. This is the last voronoi of the actual rod + # (without ghost and periodic voronoi). + periodic_boundary_voronoi_idx[1, :] = ( + periodic_boundary_voronoi_idx[0, :] + n_elem[:] + ) + + # Increase the n_elem in rods by 2 because we are adding two periodic boundary elements + n_elem = n_elem + 2 + + return ( + n_elem, + periodic_boundary_node_idx, + periodic_boundary_elems_idx, + periodic_boundary_voronoi_idx, + ) diff --git a/elastica/mesh/mesh_initializer.py b/elastica/mesh/mesh_initializer.py index 2320c185..b0cf32aa 100644 --- a/elastica/mesh/mesh_initializer.py +++ b/elastica/mesh/mesh_initializer.py @@ -1,7 +1,8 @@ __doc__ = """ Mesh Initializer using Pyvista """ -import pyvista as pv +import pyvista as pv # type: ignore import numpy as np +from numpy.typing import NDArray class Mesh: @@ -100,8 +101,8 @@ def mesh_update(self) -> None: ) def face_calculation( - self, pvfaces: np.ndarray, meshpoints: np.ndarray, n_faces: int - ) -> np.ndarray: + self, pvfaces: NDArray, meshpoints: NDArray, n_faces: int + ) -> NDArray: """ This function converts the faces from pyvista to pyelastica geometry @@ -142,7 +143,7 @@ def face_calculation( return faces - def face_normal_calculation(self, pyvista_face_normals: np.ndarray) -> np.ndarray: + def face_normal_calculation(self, pyvista_face_normals: NDArray) -> NDArray: """ This function converts the face normals from pyvista to pyelastica geometry, in pyelastica the face are stored in the format of (n_faces, 3 spatial coordinates), @@ -152,7 +153,7 @@ def face_normal_calculation(self, pyvista_face_normals: np.ndarray) -> np.ndarra return face_normals - def face_center_calculation(self, faces: np.ndarray, n_faces: int) -> np.ndarray: + def face_center_calculation(self, faces: NDArray, n_faces: int) -> NDArray: """ This function calculates the position vector of each face of the mesh simply by averaging all the vertices of every face/cell. @@ -167,7 +168,7 @@ def face_center_calculation(self, faces: np.ndarray, n_faces: int) -> np.ndarray return face_centers - def mesh_scale_calculation(self, bounds: np.ndarray) -> np.ndarray: + def mesh_scale_calculation(self, bounds: NDArray) -> NDArray: """ This function calculates scale of the mesh, for that it calculates the maximum distance between mesh's farthest verticies in each axis. @@ -180,7 +181,7 @@ def mesh_scale_calculation(self, bounds: np.ndarray) -> np.ndarray: return scale - def orientation_calculation(self, cube_face_normals: np.ndarray) -> np.ndarray: + def orientation_calculation(self, cube_face_normals: NDArray) -> NDArray: """ This function calculates the orientation of the mesh by using a dummy cube utility from pyvista. @@ -209,7 +210,7 @@ def visualize(self) -> None: pyvista_plotter.add_mesh(self.mesh) pyvista_plotter.show() - def translate(self, target_center: np.ndarray) -> None: + def translate(self, target_center: NDArray) -> None: """ Parameters: {numpy.ndarray-(3 spatial coordinates)} ex : mesh.translate(np.array([1,1,1])) @@ -220,7 +221,7 @@ def translate(self, target_center: np.ndarray) -> None: self.mesh = self.mesh.translate(target_center) self.mesh_update() - def scale(self, factor: np.ndarray) -> None: + def scale(self, factor: NDArray) -> None: """ Parameters: {numpy.ndarray-(3 spatial constants)} ex : mesh.scale(np.array([1,1,1])) @@ -230,7 +231,7 @@ def scale(self, factor: np.ndarray) -> None: self.mesh = self.mesh.scale(factor) self.mesh_update() - def rotate(self, axis: np.ndarray, angle: float) -> None: + def rotate(self, axis: NDArray, angle: float) -> None: """ Parameters: {rotation_axis: unit vector[numpy.ndarray-(3 spatial coordinates)], angle: in degrees[float]} ex : mesh.rotate(np.array([1,0,0]), 90) diff --git a/elastica/mesh/protocol.py b/elastica/mesh/protocol.py new file mode 100644 index 00000000..f4dd2fd3 --- /dev/null +++ b/elastica/mesh/protocol.py @@ -0,0 +1,10 @@ +from typing import Protocol + +import numpy as np +from numpy.typing import NDArray + + +class MeshProtocol(Protocol): + faces: NDArray[np.float64] + face_centers: NDArray[np.float64] + face_normals: NDArray[np.float64] diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index 7e783c1b..3ac2eaba 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -5,14 +5,25 @@ Basic coordinating for multiple, smaller systems that have an independently integrable interface (i.e. works with symplectic or explicit routines `timestepper.py`.) """ -from typing import AnyStr, Iterable -from elastica.typing import OperatorType, OperatorCallbackType, OperatorFinalizeType +from typing import Type, Generator, Iterable, Any, overload +from typing import final +from elastica.typing import ( + SystemType, + StaticSystemType, + BlockSystemType, + SystemIdxType, + OperatorType, + OperatorCallbackType, + OperatorFinalizeType, +) + +import numpy as np from collections.abc import MutableSequence -from elastica.rod import RodBase -from elastica.rigidbody import RigidBodyBase -from elastica.surface import SurfaceBase +from elastica.rod.rod_base import RodBase +from elastica.rigidbody.rigid_body import RigidBodyBase +from elastica.surface.surface_base import SurfaceBase from .memory_block import construct_memory_block_structures from .operator_group import OperatorGroupFIFO @@ -26,13 +37,12 @@ class BaseSystemCollection(MutableSequence): Attributes ---------- - allowed_sys_types: tuple + allowed_sys_types: tuple[Type] Tuple of allowed type rod-like objects. Here use a base class for objects, i.e. RodBase. - _systems: list - List of rod-like objects. - - Developer Note - ----- + systems: Callable + Returns all system objects. Once finalize, block objects are also included. + blocks: Callable + Returns block objects. Should be called after finalize. Note ---- @@ -41,28 +51,36 @@ class BaseSystemCollection(MutableSequence): https://stackoverflow.com/q/3945940 """ - def __init__(self): + def __init__(self) -> None: # Collection of functions. Each group is executed as a collection at the different steps. # Each component (Forcing, Connection, etc.) registers the executable (callable) function # in the group that that needs to be executed. These should be initialized before mixin. self._feature_group_synchronize: Iterable[OperatorType] = OperatorGroupFIFO() - self._feature_group_constrain_values: Iterable[OperatorType] = [] - self._feature_group_constrain_rates: Iterable[OperatorType] = [] - self._feature_group_callback: Iterable[OperatorCallbackType] = [] - self._feature_group_finalize: Iterable[OperatorFinalizeType] = [] + self._feature_group_constrain_values: list[OperatorType] = [] + self._feature_group_constrain_rates: list[OperatorType] = [] + self._feature_group_callback: list[OperatorCallbackType] = [] + self._feature_group_finalize: list[OperatorFinalizeType] = [] # We need to initialize our mixin classes - super(BaseSystemCollection, self).__init__() + super().__init__() + # List of system types/bases that are allowed - self.allowed_sys_types = (RodBase, RigidBodyBase, SurfaceBase) + self.allowed_sys_types: tuple[Type, ...] = ( + RodBase, + RigidBodyBase, + SurfaceBase, + ) + # List of systems to be integrated - self._systems = [] - self._memory_blocks = [] + self.__systems: list[StaticSystemType] = [] + self.__final_blocks: list[BlockSystemType] = [] + # Flag Finalize: Finalizing twice will cause an error, # but the error message is very misleading - self._finalize_flag = False + self._finalize_flag: bool = False - def _check_type(self, sys_to_be_added: AnyStr): - if not issubclass(sys_to_be_added.__class__, self.allowed_sys_types): + @final + def _check_type(self, sys_to_be_added: Any) -> bool: + if not isinstance(sys_to_be_added, self.allowed_sys_types): raise TypeError( "{0}\n" "is not a system passing validity\n" @@ -83,60 +101,112 @@ def _check_type(self, sys_to_be_added: AnyStr): ) return True - def __len__(self): - return len(self._systems) + def __len__(self) -> int: + return len(self.__systems) - def __getitem__(self, idx): - return self._systems[idx] + @overload + def __getitem__(self, idx: int, /) -> SystemType: ... - def __delitem__(self, idx): - del self._systems[idx] + @overload + def __getitem__(self, idx: slice, /) -> list[SystemType]: ... - def __setitem__(self, idx, system): + def __getitem__(self, idx, /): # type: ignore + return self.__systems[idx] + + def __delitem__(self, idx, /): # type: ignore + del self.__systems[idx] + + def __setitem__(self, idx, system, /): # type: ignore self._check_type(system) - self._systems[idx] = system + self.__systems[idx] = system - def insert(self, idx, system): + def insert(self, idx, system) -> None: # type: ignore self._check_type(system) - self._systems.insert(idx, system) + self.__systems.insert(idx, system) - def __str__(self): - return str(self._systems) + def __str__(self) -> str: + """To be readable""" + return str(self.__systems) - def extend_allowed_types(self, additional_types): + @final + def extend_allowed_types( + self, additional_types: tuple[Type[SystemType], ...] + ) -> None: self.allowed_sys_types += additional_types - def override_allowed_types(self, allowed_types): + @final + def override_allowed_types( + self, allowed_types: tuple[Type[SystemType], ...] + ) -> None: self.allowed_sys_types = allowed_types - def _get_sys_idx_if_valid(self, sys_to_be_added): - from numpy import int_ as npint - - n_systems = len(self._systems) # Total number of systems from mixed-in class + @final + def get_system_index( + self, system: "SystemType | StaticSystemType" + ) -> SystemIdxType: + """ + Get the index of the system object in the system list. + System list is private, so this is the only way to get the index of the system object. + + Example + ------- + >>> system_collection: SystemCollectionProtocol + >>> system: SystemType + ... + >>> system_idx = system_collection.get_system_index(system) # save idx + ... + >>> system = system_collection[system_idx] # just need idx to retrieve + + Parameters + ---------- + system: SystemType + System object to be found in the system list. + """ + n_systems = len(self) # Total number of systems from mixed-in class - if isinstance(sys_to_be_added, (int, npint)): + sys_idx: SystemIdxType + if isinstance( + system, (int, np.integer) + ): # np.integer includes both int32 and int64 # 1. If they are indices themselves, check range + # This is only used for testing purposes assert ( - -n_systems <= sys_to_be_added < n_systems - ), "Rod index {} exceeds number of registered rodtems".format( - sys_to_be_added - ) - sys_idx = sys_to_be_added - elif self._check_type(sys_to_be_added): - # 2. If they are rod objects (most likely), lookup indices + -n_systems <= system < n_systems + ), "System index {} exceeds number of registered rodtems".format(system) + sys_idx = int(system) + elif self._check_type(system): + # 2. If they are system object (most likely), lookup indices # index might have some problems : https://stackoverflow.com/a/176921 try: - sys_idx = self._systems.index(sys_to_be_added) + sys_idx = self.__systems.index(system) except ValueError: raise ValueError( - "Rod {} was not found, did you append it to the system?".format( - sys_to_be_added + "System {} was not found, did you append it to the system?".format( + system ) ) return sys_idx - def finalize(self): + @final + def systems(self) -> Generator[StaticSystemType, None, None]: + """ + Iterate over all systems in the system collection. + If the system collection is finalized, block objects are also included. + """ + for system in self.__systems: + yield system + + @final + def block_systems(self) -> Generator[BlockSystemType, None, None]: + """ + Iterate over all block systems in the system collection. + """ + for block in self.__final_blocks: + yield block + + @final + def finalize(self) -> None: """ This method finalizes the simulator class. When it is called, it is assumed that the user has appended all rod-like objects to the simulator as well as all boundary conditions, callbacks, etc., @@ -144,14 +214,14 @@ def finalize(self): the user cannot add new features to the simulator class. """ - # This generates more straight-forward error. - assert self._finalize_flag is not True, "The finalize cannot be called twice." + assert not self._finalize_flag, "The finalize cannot be called twice." + self._finalize_flag = True - # construct memory block - self._memory_blocks = construct_memory_block_structures(self._systems) - for block in self._memory_blocks: - # append the memory block to the simulation as a system. Memory block is the final system in the simulation. - self.append(block) + # Construct memory block + self.__final_blocks = construct_memory_block_structures(self.__systems) + # FIXME: We need this to make ring-rod working. + # But probably need to be refactored + self.__systems.extend(self.__final_blocks) # Recurrent call finalize functions for all components. for finalize in self._feature_group_finalize: @@ -159,27 +229,40 @@ def finalize(self): # Clear the finalize feature group, just for the safety. self._feature_group_finalize.clear() - self._feature_group_finalize = None + del self._feature_group_finalize - # Toggle the finalize_flag - self._finalize_flag = True - - def synchronize(self, time: float): - # Collection call _feature_group_synchronize + @final + def synchronize(self, time: np.float64) -> None: + """ + Call synchronize functions for all features. + Features are registered in _feature_group_synchronize. + """ for func in self._feature_group_synchronize: func(time=time) - def constrain_values(self, time: float): - # Collection call _feature_group_constrain_values + @final + def constrain_values(self, time: np.float64) -> None: + """ + Call constrain values functions for all features. + Features are registered in _feature_group_constrain_values. + """ for func in self._feature_group_constrain_values: func(time=time) - def constrain_rates(self, time: float): - # Collection call _feature_group_constrain_rates + @final + def constrain_rates(self, time: np.float64) -> None: + """ + Call constrain rates functions for all features. + Features are registered in _feature_group_constrain_rates. + """ for func in self._feature_group_constrain_rates: func(time=time) - def apply_callbacks(self, time: float, current_step: int): - # Collection call _feature_group_callback + @final + def apply_callbacks(self, time: np.float64, current_step: int) -> None: + """ + Call callback functions for all features. + Features are registered in _feature_group_callback. + """ for func in self._feature_group_callback: func(time=time, current_step=current_step) diff --git a/elastica/modules/callbacks.py b/elastica/modules/callbacks.py index af1482ad..de1e5091 100644 --- a/elastica/modules/callbacks.py +++ b/elastica/modules/callbacks.py @@ -4,8 +4,15 @@ Provides the callBack interface to collect data over time (see `callback_functions.py`). """ +from typing import Type, Any +from typing_extensions import Self # 3.11: from typing import Self +from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType +from .protocol import ModuleProtocol + +import numpy as np from elastica.callback_functions import CallBackBaseClass +from .protocol import SystemCollectionProtocol class CallBacks: @@ -20,13 +27,16 @@ class CallBacks: List of call back classes defined for rod-like objects. """ - def __init__(self): - self._callback_list = [] + def __init__(self: SystemCollectionProtocol) -> None: + self._callback_list: list[ModuleProtocol] = [] + self._callback_operators: list[tuple[int, CallBackBaseClass]] = [] super(CallBacks, self).__init__() self._feature_group_callback.append(self._callback_execution) self._feature_group_finalize.append(self._finalize_callback) - def collect_diagnostics(self, system): + def collect_diagnostics( + self: SystemCollectionProtocol, system: SystemType + ) -> ModuleProtocol: """ This method calls user-defined call-back classes for a user-defined system or rod-like object. You need to input the @@ -41,41 +51,33 @@ def collect_diagnostics(self, system): ------- """ - sys_idx = self._get_sys_idx_if_valid(system) + sys_idx: SystemIdxType = self.get_system_index(system) # Create _Constraint object, cache it and return to user - _callbacks = _CallBack(sys_idx) + _callbacks: ModuleProtocol = _CallBack(sys_idx) self._callback_list.append(_callbacks) return _callbacks - def _finalize_callback(self): - # From stored _CallBack objects, instantiate the boundary conditions - # inplace : https://stackoverflow.com/a/1208792 - + def _finalize_callback(self: SystemCollectionProtocol) -> None: # dev : the first index stores the rod index to collect data. - # Technically we can use another array but it its one more book-keeping - # step. Being lazy, I put them both in the same array - self._callback_list[:] = [ - (callback.id(), callback(self._systems[callback.id()])) - for callback in self._callback_list + self._callback_operators = [ + (callback.id(), callback.instantiate()) for callback in self._callback_list ] + self._callback_list.clear() + del self._callback_list - # Sort from lowest id to highest id for potentially better memory access - # _callbacks contains list of tuples. First element of tuple is rod number and - # following elements are the type of boundary condition such as - # [(0, MyCallBack), (1, MyVelocityCallBack), ... ] - # Thus using lambda we iterate over the list of tuples and use rod number (x[0]) - # to sort callbacks. - self._callback_list.sort(key=lambda x: x[0]) + # First callback execution + time = np.float64(0.0) + self._callback_execution(time=time, current_step=0) - self._callback_execution(time=0.0, current_step=0) - - def _callback_execution(self, time, current_step: int, *args, **kwargs): - for sys_id, callback in self._callback_list: - callback.make_callback( - self._systems[sys_id], time, current_step, *args, **kwargs - ) + def _callback_execution( + self: SystemCollectionProtocol, + time: np.float64, + current_step: int, + ) -> None: + for sys_id, callback in self._callback_operators: + callback.make_callback(self[sys_id], time, current_step) class _CallBack: @@ -92,7 +94,7 @@ class _CallBack: Arbitrary keyword arguments. """ - def __init__(self, sys_idx: int): + def __init__(self, sys_idx: SystemIdxType): """ Parameters @@ -100,19 +102,24 @@ def __init__(self, sys_idx: int): sys_idx: int rod object index """ - self._sys_idx = sys_idx - self._callback_cls = None - self._args = () - self._kwargs = {} - - def using(self, callback_cls, *args, **kwargs): + self._sys_idx: SystemIdxType = sys_idx + self._callback_cls: Type[CallBackBaseClass] + self._args: Any + self._kwargs: Any + + def using( + self, + cls: Type[CallBackBaseClass], + *args: Any, + **kwargs: Any, + ) -> Self: """ This method is a module to set which callback class is used to collect data from user defined rod-like object. Parameters ---------- - callback_cls: object + cls: object User defined callback class. Returns @@ -120,21 +127,21 @@ def using(self, callback_cls, *args, **kwargs): """ assert issubclass( - callback_cls, CallBackBaseClass + cls, CallBackBaseClass ), "{} is not a valid call back. Did you forget to derive from CallBackClass?".format( - callback_cls + cls ) - self._callback_cls = callback_cls + self._callback_cls = cls self._args = args self._kwargs = kwargs return self - def id(self): + def id(self) -> SystemIdxType: return self._sys_idx - def __call__(self, *args, **kwargs) -> CallBackBaseClass: + def instantiate(self) -> CallBackBaseClass: """Constructs a callback functions after checks""" - if not self._callback_cls: + if not hasattr(self, "_callback_cls"): raise RuntimeError( "No callback provided to act on rod id {0}" "but a callback was registered. Did you forget to call" diff --git a/elastica/modules/connections.py b/elastica/modules/connections.py index 07e9e85b..44a59073 100644 --- a/elastica/modules/connections.py +++ b/elastica/modules/connections.py @@ -5,10 +5,21 @@ Provides the connections interface to connect entities (rods, rigid bodies) using joints (see `joints.py`). """ +from typing import Type, cast, Any +from typing_extensions import Self +from elastica.typing import ( + SystemIdxType, + OperatorFinalizeType, + ConnectionIndex, + RodType, + RigidBodyType, +) import numpy as np import functools from elastica.joint import FreeJoint +from .protocol import SystemCollectionProtocol, ModuleProtocol + class Connections: """ @@ -22,14 +33,18 @@ class Connections: List of joint classes defined for rod-like objects. """ - def __init__(self): - self._connections = [] + def __init__(self: SystemCollectionProtocol) -> None: + self._connections: list[ModuleProtocol] = [] super(Connections, self).__init__() self._feature_group_finalize.append(self._finalize_connections) def connect( - self, first_rod, second_rod, first_connect_idx=None, second_connect_idx=None - ): + self: SystemCollectionProtocol, + first_rod: "RodType | RigidBodyType", + second_rod: "RodType | RigidBodyType", + first_connect_idx: ConnectionIndex = (), + second_connect_idx: ConnectionIndex = (), + ) -> ModuleProtocol: """ This method connects two rod-like objects using the selected joint class. You need to input the two rod-like objects that are to be connected as well @@ -37,50 +52,49 @@ def connect( Parameters ---------- - first_rod : object + first_rod : RodType | RigidBodyType Rod-like object - second_rod : object + second_rod : RodType | RigidBodyType Rod-like object - first_connect_idx : int + first_connect_idx : ConnectionIndex Index of first rod for joint. - second_connect_idx : int + second_connect_idx : ConnectionIndex Index of second rod for joint. Returns ------- """ - sys_idx = [None] * 2 - for i_sys, sys in enumerate((first_rod, second_rod)): - sys_idx[i_sys] = self._get_sys_idx_if_valid(sys) - # For each system identified, get max dofs - # FIXME: Revert back to len, it should be able to take, systems without elements! - # sys_dofs = [len(self._systems[idx]) for idx in sys_idx] - sys_dofs = [self._systems[idx].n_elems for idx in sys_idx] + sys_idx_first = self.get_system_index(first_rod) + sys_idx_second = self.get_system_index(second_rod) + sys_dofs_first = first_rod.n_elems + sys_dofs_second = second_rod.n_elems # Create _Connect object, cache it and return to user - _connect = _Connect(*sys_idx, *sys_dofs) - _connect.set_index(first_connect_idx, second_connect_idx) + _connect: ModuleProtocol = _Connect( + sys_idx_first, sys_idx_second, sys_dofs_first, sys_dofs_second + ) + _connect.set_index(first_connect_idx, second_connect_idx) # type: ignore[attr-defined] self._connections.append(_connect) self._feature_group_synchronize.append_id(_connect) return _connect - def _finalize_connections(self): + def _finalize_connections(self: SystemCollectionProtocol) -> None: # From stored _Connect objects, instantiate the joints and store it # dev : the first indices stores the # (first rod index, second_rod_idx, connection_idx_on_first_rod, connection_idx_on_second_rod) # to apply the connections to. def apply_forces_and_torques( - time, - connect_instance, - system_one, - first_connect_idx, - system_two, - second_connect_idx, - ): + time: np.float64, + connect_instance: FreeJoint, + system_one: "RodType | RigidBodyType", + first_connect_idx: ConnectionIndex, + system_two: "RodType | RigidBodyType", + second_connect_idx: ConnectionIndex, + ) -> None: connect_instance.apply_forces( system_one=system_one, index_one=first_connect_idx, @@ -98,22 +112,20 @@ def apply_forces_and_torques( first_sys_idx, second_sys_idx, first_connect_idx, second_connect_idx = ( connection.id() ) - connect_instance = connection.instantiate() + connect_instance: FreeJoint = connection.instantiate() # FIXME: lambda t is included because OperatorType takes time as an argument func = functools.partial( apply_forces_and_torques, connect_instance=connect_instance, - system_one=self._systems[first_sys_idx], + system_one=self[first_sys_idx], first_connect_idx=first_connect_idx, - system_two=self._systems[second_sys_idx], + system_two=self[second_sys_idx], second_connect_idx=second_connect_idx, ) self._feature_group_synchronize.add_operators(connection, [func]) - self.warnings(connection) - self._connections = [] del self._connections @@ -121,9 +133,6 @@ def apply_forces_and_torques( # This is to optimize the call tree for better memory accesses # https://brooksandrew.github.io/simpleblog/articles/intro-to-graph-optimization-solving-cpp/ - def warnings(self, connection): - pass - class _Connect: """ @@ -131,13 +140,13 @@ class _Connect: Attributes ---------- - _first_sys_idx: int - _second_sys_idx: int + _first_sys_idx: SystemIdxType + _second_sys_idx: SystemIdxType _first_sys_n_lim: int _second_sys_n_lim: int _connect_class: list - first_sys_connection_idx: int - second_sys_connection_idx: int + first_sys_connection_idx: ConnectionIndex + second_sys_connection_idx: ConnectionIndex *args Variable length argument list. **kwargs @@ -146,8 +155,8 @@ class _Connect: def __init__( self, - first_sys_idx: int, - second_sys_idx: int, + first_sys_idx: SystemIdxType, + second_sys_idx: SystemIdxType, first_sys_nlim: int, second_sys_nlim: int, ): @@ -160,105 +169,107 @@ def __init__( first_sys_nlim: int second_sys_nlim: int """ - self._first_sys_idx = first_sys_idx - self._second_sys_idx = second_sys_idx - self._first_sys_n_lim = first_sys_nlim - self._second_sys_n_lim = second_sys_nlim - self._connect_cls = None - self._args = () - self._kwargs = {} - self.first_sys_connection_idx = None - self.second_sys_connection_idx = None - - def set_index(self, first_idx, second_idx): - # TODO assert range - # First check if the types of first rod idx and second rod idx variable are same. - assert type(first_idx) is type( - second_idx - ), "Type of first_connect_idx :{}".format( - type(first_idx) - ) + " is different than second_connect_idx :{}".format( - type(second_idx) - ) + self._first_sys_idx: SystemIdxType = first_sys_idx + self._second_sys_idx: SystemIdxType = second_sys_idx + self._first_sys_n_lim: int = first_sys_nlim + self._second_sys_n_lim: int = second_sys_nlim + self.first_sys_connection_idx: ConnectionIndex = () + self.second_sys_connection_idx: ConnectionIndex = () + self._connect_cls: Type[FreeJoint] + + def set_index( + self, first_idx: ConnectionIndex, second_idx: ConnectionIndex + ) -> None: + first_type = type(first_idx) + second_type = type(second_idx) + # Check if the types of first rod idx and second rod idx variable are same. + assert ( + first_type == second_type + ), f"Type of first_connect_idx :{first_type} is different than second_connect_idx :{second_type}" # Check if the type of idx variables are correct. + allow_types = ( + int, + np.integer, + list, + tuple, + np.ndarray, + ) # np.integer is for both int32 and int64 assert isinstance( - first_idx, (int, np.int_, list, tuple, np.ndarray, type(None)) - ), "Connection index type is not supported :{}".format( - type(first_idx) - ) + ", please try one of the following :{}".format( - (int, np.int_, list, tuple, np.ndarray) - ) + first_idx, allow_types + ), f"Connection index type is not supported :{first_type}, please try one of the following :{allow_types}" # If type of idx variables are tuple or list or np.ndarray, check validity of each entry. - if ( - isinstance(first_idx, tuple) - or isinstance(first_idx, list) - or isinstance(first_idx, np.ndarray) - ): - - for i in range(len(first_idx)): - assert isinstance(first_idx[i], (int, np.int_)), ( + if isinstance(first_idx, (tuple, list, np.ndarray)): + first_idx_ = cast(list[int], first_idx) + second_idx_ = cast(list[int], second_idx) + for i in range(len(first_idx_)): + assert isinstance(first_idx_[i], (int, np.integer)), ( "Connection index of first rod is not integer :{}".format( - first_idx[i] + first_idx_[i] ) - + " It should be :{}".format((int, np.int_)) - + " Check your input!" + + " It should be : integer. Check your input!" ) - assert isinstance(second_idx[i], (int, np.int_)), ( + assert isinstance(second_idx_[i], (int, np.integer)), ( "Connection index of second rod is not integer :{}".format( - second_idx[i] + second_idx_[i] ) - + " It should be :{}".format((int, np.int_)) - + " Check your input!" + + " It should be : integer. Check your input!" ) # The addition of +1 and and <= check on the RHS is because # connections can be made to the node indices as well assert ( -(self._first_sys_n_lim + 1) - <= first_idx[i] + <= first_idx_[i] <= self._first_sys_n_lim ), "Connection index of first rod exceeds its dof : {}".format( self._first_sys_n_lim ) assert ( -(self._second_sys_n_lim + 1) - <= second_idx[i] + <= second_idx_[i] <= self._second_sys_n_lim ), "Connection index of second rod exceeds its dof : {}".format( self._second_sys_n_lim ) - elif first_idx is None: - # Do nothing if idx are None - pass - else: - + elif isinstance(first_idx, (int, np.integer)): # The addition of +1 and and <= check on the RHS is because # connections can be made to the node indices as well + first_idx__ = cast(int, first_idx) + second_idx__ = cast(int, second_idx) assert ( - -(self._first_sys_n_lim + 1) <= first_idx <= self._first_sys_n_lim + -(self._first_sys_n_lim + 1) <= first_idx__ <= self._first_sys_n_lim ), "Connection index of first rod exceeds its dof : {}".format( self._first_sys_n_lim ) assert ( - -(self._second_sys_n_lim + 1) <= second_idx <= self._second_sys_n_lim + -(self._second_sys_n_lim + 1) <= second_idx__ <= self._second_sys_n_lim ), "Connection index of second rod exceeds its dof : {}".format( self._second_sys_n_lim ) + else: + raise TypeError( + "Connection index type is not supported :{}".format(first_type) + ) self.first_sys_connection_idx = first_idx self.second_sys_connection_idx = second_idx - def using(self, connect_cls, *args, **kwargs): + def using( + self, + cls: Type[FreeJoint], + *args: Any, + **kwargs: Any, + ) -> Self: """ This method is a module to set which joint class is used to connect user defined rod-like objects. Parameters ---------- - connect_cls: object - User defined callback class. + cls: object + User defined connection class. *args Variable length argument list **kwargs @@ -269,16 +280,18 @@ def using(self, connect_cls, *args, **kwargs): """ assert issubclass( - connect_cls, FreeJoint + cls, FreeJoint ), "{} is not a valid joint class. Did you forget to derive from FreeJoint?".format( - connect_cls + cls ) - self._connect_cls = connect_cls + self._connect_cls = cls self._args = args self._kwargs = kwargs return self - def id(self): + def id( + self, + ) -> tuple[SystemIdxType, SystemIdxType, ConnectionIndex, ConnectionIndex]: return ( self._first_sys_idx, self._second_sys_idx, @@ -286,8 +299,8 @@ def id(self): self.second_sys_connection_idx, ) - def instantiate(self): - if not self._connect_cls: + def instantiate(self) -> FreeJoint: + if not hasattr(self, "_connect_cls"): raise RuntimeError( "No connections provided to link rod id {0}" "(at {2}) and {1} (at {3}), but a Connection" diff --git a/elastica/modules/constraints.py b/elastica/modules/constraints.py index 62189870..029ed961 100644 --- a/elastica/modules/constraints.py +++ b/elastica/modules/constraints.py @@ -4,9 +4,22 @@ Provides the constraints interface to enforce displacement boundary conditions (see `boundary_conditions.py`). """ +from typing import Any, Type, cast +from typing_extensions import Self + +import numpy as np from elastica.boundary_conditions import ConstraintBase +from elastica.typing import ( + SystemIdxType, + ConstrainingIndex, + RigidBodyType, + RodType, + BlockSystemType, +) +from .protocol import SystemCollectionProtocol, ModuleProtocol + class Constraints: """ @@ -20,14 +33,16 @@ class Constraints: List of boundary condition classes defined for rod-like objects. """ - def __init__(self): - self._constraints = [] + def __init__(self: SystemCollectionProtocol) -> None: + self._constraints_list: list[ModuleProtocol] = [] super(Constraints, self).__init__() self._feature_group_constrain_values.append(self._constrain_values) self._feature_group_constrain_rates.append(self._constrain_rates) self._feature_group_finalize.append(self._finalize_constraints) - def constrain(self, system): + def constrain( + self: SystemCollectionProtocol, system: "RodType | RigidBodyType" + ) -> ModuleProtocol: """ This method enforces a displacement boundary conditions to the relevant user-defined system or rod-like object. You must input the system or rod-like @@ -42,15 +57,15 @@ def constrain(self, system): ------- """ - sys_idx = self._get_sys_idx_if_valid(system) + sys_idx = self.get_system_index(system) # Create _Constraint object, cache it and return to user - _constraint = _Constraint(sys_idx) - self._constraints.append(_constraint) + _constraint: ModuleProtocol = _Constraint(sys_idx) + self._constraints_list.append(_constraint) return _constraint - def _finalize_constraints(self): + def _finalize_constraints(self: SystemCollectionProtocol) -> None: """ In case memory block have ring rod, then periodic boundaries have to be synched. In order to synchronize periodic boundaries, a new constrain for memory block rod added called as _ConstrainPeriodicBoundaries. This @@ -58,13 +73,14 @@ def _finalize_constraints(self): """ from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries - for block in self._memory_blocks: + for block in self.block_systems(): # append the memory block to the simulation as a system. Memory block is the final system in the simulation. if hasattr(block, "ring_rod_flag"): # Apply the constrain to synchronize the periodic boundaries of the memory rod. Find the memory block # sys idx among other systems added and then apply boundary conditions. - memory_block_idx = self._get_sys_idx_if_valid(block) - self.constrain(self._systems[memory_block_idx]).using( + memory_block_idx = self.get_system_index(block) + block_system = cast(BlockSystemType, self[memory_block_idx]) + self.constrain(block_system).using( _ConstrainPeriodicBoundaries, ) @@ -72,11 +88,10 @@ def _finalize_constraints(self): # inplace : https://stackoverflow.com/a/1208792 # dev : the first index stores the rod index to apply the boundary condition - # to. Technically we can use another array but it its one more book-keeping - # step. Being lazy, I put them both in the same array - self._constraints[:] = [ - (constraint.id(), constraint(self._systems[constraint.id()])) - for constraint in self._constraints + # to. + self._constraints_operators = [ + (constraint.id(), constraint.instantiate(self[constraint.id()])) + for constraint in self._constraints_list ] # Sort from lowest id to highest id for potentially better memory access @@ -85,20 +100,20 @@ def _finalize_constraints(self): # [(0, ConstraintBase, OneEndFixedBC), (1, HelicalBucklingBC), ... ] # Thus using lambda we iterate over the list of tuples and use rod number (x[0]) # to sort constraints. - self._constraints.sort(key=lambda x: x[0]) + self._constraints_operators.sort(key=lambda x: x[0]) # At t=0.0, constrain all the boundary conditions (for compatability with # initial conditions) - self._constrain_values(time=0.0) - self._constrain_rates(time=0.0) + self._constrain_values(time=np.float64(0.0)) + self._constrain_rates(time=np.float64(0.0)) - def _constrain_values(self, time, *args, **kwargs): - for sys_id, constraint in self._constraints: - constraint.constrain_values(self._systems[sys_id], time, *args, **kwargs) + def _constrain_values(self: SystemCollectionProtocol, time: np.float64) -> None: + for sys_id, constraint in self._constraints_operators: + constraint.constrain_values(self[sys_id], time) - def _constrain_rates(self, time, *args, **kwargs): - for sys_id, constraint in self._constraints: - constraint.constrain_rates(self._systems[sys_id], time, *args, **kwargs) + def _constrain_rates(self: SystemCollectionProtocol, time: np.float64) -> None: + for sys_id, constraint in self._constraints_operators: + constraint.constrain_rates(self[sys_id], time) class _Constraint: @@ -108,14 +123,16 @@ class _Constraint: Attributes ---------- _sys_idx: int - _bc_cls: list + _bc_cls: Type[ConstraintBase] + constrained_position_idx: ConstrainingIndex + constrained_director_idx: ConstrainingIndex *args Variable length argument list. **kwargs Arbitrary keyword arguments. """ - def __init__(self, sys_idx: int): + def __init__(self, sys_idx: SystemIdxType) -> None: """ Parameters @@ -124,18 +141,27 @@ def __init__(self, sys_idx: int): """ self._sys_idx = sys_idx - self._bc_cls = None - self._args = () - self._kwargs = {} - - def using(self, bc_cls, *args, **kwargs): + self._bc_cls: Type[ConstraintBase] + self._args: Any + self._kwargs: Any + self.constrained_position_idx: ConstrainingIndex + self.constrained_director_idx: ConstrainingIndex + + def using( + self, + cls: Type[ConstraintBase], + *args: Any, + constrained_position_idx: ConstrainingIndex = (), + constrained_director_idx: ConstrainingIndex = (), + **kwargs: Any, + ) -> Self: """ This method is a module to set which boundary condition class is used to enforce boundary condition from user defined rod-like objects. Parameters ---------- - bc_cls : object + cls : Type[ConstraintBase] User defined boundary condition class. *args Variable length argument list @@ -147,61 +173,55 @@ def using(self, bc_cls, *args, **kwargs): """ assert issubclass( - bc_cls, ConstraintBase + cls, ConstraintBase ), "{} is not a valid constraint. Constraint must be driven from ConstraintBase.".format( - bc_cls + cls ) - self._bc_cls = bc_cls + self._bc_cls = cls + self.constrained_position_idx = constrained_position_idx + self.constrained_director_idx = constrained_director_idx self._args = args self._kwargs = kwargs return self - def id(self): + def id(self) -> SystemIdxType: return self._sys_idx - def __call__(self, rod, *args, **kwargs): - """Constructs a constraint after checks - - Parameters - ---------- - args - kwargs - - Returns - ------- - - """ - if not self._bc_cls: + def instantiate(self, system: "RodType | RigidBodyType") -> ConstraintBase: + """Constructs a constraint after checks""" + if not hasattr(self, "_bc_cls"): raise RuntimeError( "No boundary condition provided to constrain rod" "id {0} at {1}, but a BC was intended. Did you" - "forget to call the `using` method?".format(self.id(), rod) + "forget to call the `using` method?".format(self.id(), system) ) - # If there is position, director in kwargs, deal with it first - # Returns None if not found - pos_indices = self._kwargs.get( - "constrained_position_idx", None - ) # calculate position indices as a tuple - director_indices = self._kwargs.get( - "constrained_director_idx", None - ) # calculate director indices as a tuple - - # If pos_indices is not None, construct list else empty list # IMPORTANT : do copy for memory-safe operations positions = ( - [rod.position_collection[..., idx].copy() for idx in pos_indices] - if pos_indices + [ + system.position_collection[..., idx].copy() + for idx in self.constrained_position_idx + ] + if self.constrained_position_idx else [] ) directors = ( - [rod.director_collection[..., idx].copy() for idx in director_indices] - if director_indices + [ + system.director_collection[..., idx].copy() + for idx in self.constrained_director_idx + ] + if self.constrained_director_idx else [] ) try: bc = self._bc_cls( - *positions, *directors, *self._args, _system=rod, **self._kwargs + *positions, + *directors, + *self._args, + _system=system, + constrained_position_idx=self.constrained_position_idx, + constrained_director_idx=self.constrained_director_idx, + **self._kwargs, ) return bc except (TypeError, IndexError): diff --git a/elastica/modules/contact.py b/elastica/modules/contact.py index 8882e92f..76491664 100644 --- a/elastica/modules/contact.py +++ b/elastica/modules/contact.py @@ -5,13 +5,30 @@ Provides the contact interface to apply contact forces between objects (rods, rigid bodies, surfaces). """ +from typing import Type, Any +from typing_extensions import Self +from elastica.typing import ( + SystemIdxType, + OperatorFinalizeType, + StaticSystemType, + SystemType, +) +from .protocol import SystemCollectionProtocol, ModuleProtocol + import logging import functools -from elastica.typing import SystemType, AllowedContactType + +import numpy as np + +from elastica.contact_forces import NoContact logger = logging.getLogger(__name__) +def warnings() -> None: + logger.warning("Contact features should be instantiated lastly.") + + class Contact: """ The Contact class is a module for applying contact between rod-like objects . To apply contact between rod-like objects, @@ -23,14 +40,16 @@ class Contact: List of contact classes defined for rod-like objects. """ - def __init__(self): - self._contacts = [] + def __init__(self: SystemCollectionProtocol) -> None: + self._contacts: list[ModuleProtocol] = [] super(Contact, self).__init__() self._feature_group_finalize.append(self._finalize_contact) def detect_contact_between( - self, first_system: SystemType, second_system: AllowedContactType - ): + self: SystemCollectionProtocol, + first_system: SystemType, + second_system: "SystemType | StaticSystemType", + ) -> ModuleProtocol: """ This method adds contact detection between two objects using the selected contact class. You need to input the two objects that are to be connected. @@ -38,39 +57,37 @@ def detect_contact_between( Parameters ---------- first_system : SystemType - Rod or rigid body object - second_system : AllowedContactType - Rod, rigid body or surface object + second_system : SystemType | StaticSystemType Returns ------- """ - sys_idx = [None] * 2 - for i_sys, sys in enumerate((first_system, second_system)): - sys_idx[i_sys] = self._get_sys_idx_if_valid(sys) + sys_idx_first = self.get_system_index(first_system) + sys_idx_second = self.get_system_index(second_system) # Create _Contact object, cache it and return to user - _contact = _Contact(*sys_idx) + _contact = _Contact(sys_idx_first, sys_idx_second) self._contacts.append(_contact) self._feature_group_synchronize.append_id(_contact) return _contact - def _finalize_contact(self) -> None: + def _finalize_contact(self: SystemCollectionProtocol) -> None: # dev : the first indices stores the # (first_rod_idx, second_rod_idx) # to apply the contacts to - # Technically we can use another array but it its one more book-keeping - # step. Being lazy, I put them both in the same array def apply_contact( - time, contact_instance, system, first_sys_idx, second_sys_idx - ): + time: np.float64, + contact_instance: NoContact, + first_sys_idx: SystemIdxType, + second_sys_idx: SystemIdxType, + ) -> None: contact_instance.apply_contact( - system_one=system[first_sys_idx], - system_two=system[second_sys_idx], + system_one=self[first_sys_idx], + system_two=self[second_sys_idx], ) for contact in self._contacts: @@ -78,31 +95,23 @@ def apply_contact( contact_instance = contact.instantiate() contact_instance._check_systems_validity( - self._systems[first_sys_idx], - self._systems[second_sys_idx], + self[first_sys_idx], + self[second_sys_idx], ) func = functools.partial( apply_contact, contact_instance=contact_instance, - system=self._systems, first_sys_idx=first_sys_idx, second_sys_idx=second_sys_idx, ) self._feature_group_synchronize.add_operators(contact, [func]) - self.warnings(contact) + if not self._feature_group_synchronize.is_last(contact): + warnings() self._contacts = [] del self._contacts - def warnings(self, contact): - from elastica.contact_forces import NoContact - - # Classes that should be used last - if not self._feature_group_synchronize.is_last(contact): - if isinstance(contact._contact_cls, NoContact): - logger.warning("Contact features should be instantiated lastly.") - class _Contact: """ @@ -110,9 +119,9 @@ class _Contact: Attributes ---------- - _first_sys_idx: int - _second_sys_idx: int - _contact_cls: list + _first_sys_idx: SystemIdxType + _second_sys_idx: SystemIdxType + _contact_cls: Type[NoContact] *args Variable length argument list. **kwargs @@ -121,8 +130,8 @@ class _Contact: def __init__( self, - first_sys_idx: int, - second_sys_idx: int, + first_sys_idx: SystemIdxType, + second_sys_idx: SystemIdxType, ) -> None: """ @@ -133,16 +142,18 @@ def __init__( """ self.first_sys_idx = first_sys_idx self.second_sys_idx = second_sys_idx - self._contact_cls = None + self._contact_cls: Type[NoContact] + self._args: Any + self._kwargs: Any - def using(self, contact_cls: object, *args, **kwargs): + def using(self, cls: Type[NoContact], *args: Any, **kwargs: Any) -> Self: """ This method is a module to set which contact class is used to apply contact between user defined rod-like objects. Parameters ---------- - contact_cls: object + cls: Type[NoContact] User defined contact class. *args Variable length argument list @@ -153,26 +164,24 @@ def using(self, contact_cls: object, *args, **kwargs): ------- """ - from elastica.contact_forces import NoContact - assert issubclass( - contact_cls, NoContact + cls, NoContact ), "{} is not a valid contact class. Did you forget to derive from NoContact?".format( - contact_cls + cls ) - self._contact_cls = contact_cls + self._contact_cls = cls self._args = args self._kwargs = kwargs return self - def id(self): + def id(self) -> Any: return ( self.first_sys_idx, self.second_sys_idx, ) - def instantiate(self, *args, **kwargs): - if not self._contact_cls: + def instantiate(self) -> NoContact: + if not hasattr(self, "_contact_cls"): raise RuntimeError( "No contacts provided to to establish contact between rod-like object id {0}" " and {1}, but a Contact" diff --git a/elastica/modules/damping.py b/elastica/modules/damping.py index 027eebda..ea9ea5ea 100644 --- a/elastica/modules/damping.py +++ b/elastica/modules/damping.py @@ -9,7 +9,14 @@ """ +from typing import Any, Type, List +from typing_extensions import Self + +import numpy as np + from elastica.dissipation import DamperBase +from elastica.typing import RodType, SystemType, SystemIdxType +from .protocol import SystemCollectionProtocol, ModuleProtocol class Damping: @@ -24,13 +31,13 @@ class Damping: List of damper classes defined for rod-like objects. """ - def __init__(self): - self._dampers = [] - super(Damping, self).__init__() + def __init__(self: SystemCollectionProtocol) -> None: + self._damping_list: List[ModuleProtocol] = [] + super().__init__() self._feature_group_constrain_rates.append(self._dampen_rates) self._feature_group_finalize.append(self._finalize_dampers) - def dampen(self, system): + def dampen(self: SystemCollectionProtocol, system: RodType) -> ModuleProtocol: """ This method applies damping on relevant user-defined system or rod-like object. You must input the system or rod-like @@ -45,21 +52,21 @@ def dampen(self, system): ------- """ - sys_idx = self._get_sys_idx_if_valid(system) + sys_idx = self.get_system_index(system) # Create _Damper object, cache it and return to user - _damper = _Damper(sys_idx) - self._dampers.append(_damper) + _damper: ModuleProtocol = _Damper(sys_idx) + self._damping_list.append(_damper) return _damper - def _finalize_dampers(self): + def _finalize_dampers(self: SystemCollectionProtocol) -> None: # From stored _Damping objects, instantiate the dissipation/damping # inplace : https://stackoverflow.com/a/1208792 - self._dampers[:] = [ - (damper.id(), damper(self._systems[damper.id()])) - for damper in self._dampers + self._damping_operators = [ + (damper.id(), damper.instantiate(self[damper.id()])) + for damper in self._damping_list ] # Sort from lowest id to highest id for potentially better memory access @@ -67,11 +74,11 @@ def _finalize_dampers(self): # following elements are the type of damping. # Thus using lambda we iterate over the list of tuples and use rod number (x[0]) # to sort dampers. - self._dampers.sort(key=lambda x: x[0]) + self._damping_operators.sort(key=lambda x: x[0]) - def _dampen_rates(self, time, *args, **kwargs): - for sys_id, damper in self._dampers: - damper.dampen_rates(self._systems[sys_id], time, *args, **kwargs) + def _dampen_rates(self: SystemCollectionProtocol, time: np.float64) -> None: + for sys_id, damper in self._damping_operators: + damper.dampen_rates(self[sys_id], time) class _Damper: @@ -88,7 +95,7 @@ class _Damper: Arbitrary keyword arguments. """ - def __init__(self, sys_idx: int): + def __init__(self, sys_idx: SystemIdxType) -> None: """ Parameters @@ -97,22 +104,22 @@ def __init__(self, sys_idx: int): """ self._sys_idx = sys_idx - self._damper_cls = None - self._args = () - self._kwargs = {} + self._damper_cls: Type[DamperBase] + self._args: Any + self._kwargs: Any - def using(self, damper_cls, *args, **kwargs): + def using(self, cls: Type[DamperBase], *args: Any, **kwargs: Any) -> Self: """ This method is a module to set which damper class is used to enforce damping from user defined rod-like objects. Parameters ---------- - damper_cls : object + cls : Type[DamperBase] User defined damper class. - *args - Variable length argument list - **kwargs + *args: Any + Variable length argument list. + **kwargs: Any Arbitrary keyword arguments. Returns @@ -120,31 +127,21 @@ def using(self, damper_cls, *args, **kwargs): """ assert issubclass( - damper_cls, DamperBase + cls, DamperBase ), "{} is not a valid damper. Damper must be driven from DamperBase.".format( - damper_cls + cls ) - self._damper_cls = damper_cls + self._damper_cls = cls self._args = args self._kwargs = kwargs return self - def id(self): + def id(self) -> SystemIdxType: return self._sys_idx - def __call__(self, rod, *args, **kwargs): - """Constructs a Damper class object after checks - - Parameters - ---------- - args - kwargs - - Returns - ------- - - """ - if not self._damper_cls: + def instantiate(self, rod: SystemType) -> DamperBase: + """Constructs a Damper class object after checks""" + if not hasattr(self, "_damper_cls"): raise RuntimeError( "No damper provided to dampen rod id {0} at {1}," "but damping was intended. Did you" diff --git a/elastica/modules/forcing.py b/elastica/modules/forcing.py index 4d1c0d2d..e36b0899 100644 --- a/elastica/modules/forcing.py +++ b/elastica/modules/forcing.py @@ -7,6 +7,14 @@ """ import logging import functools +from typing import Any, Type, List +from typing_extensions import Self + +import numpy as np + +from elastica.external_forces import NoForces +from elastica.typing import SystemType, SystemIdxType +from .protocol import SystemCollectionProtocol, ModuleProtocol logger = logging.getLogger(__name__) @@ -23,12 +31,14 @@ class Forcing: List of forcing class defined for rod-like objects. """ - def __init__(self): - self._ext_forces_torques = [] - super(Forcing, self).__init__() + def __init__(self: SystemCollectionProtocol) -> None: + self._ext_forces_torques: List[ModuleProtocol] = [] + super().__init__() self._feature_group_finalize.append(self._finalize_forcing) - def add_forcing_to(self, system): + def add_forcing_to( + self: SystemCollectionProtocol, system: SystemType + ) -> ModuleProtocol: """ This method applies external forces and torques on the relevant user-defined system or rod-like object. You must input the system @@ -43,7 +53,7 @@ def add_forcing_to(self, system): ------- """ - sys_idx = self._get_sys_idx_if_valid(system) + sys_idx = self.get_system_index(system) # Create _Constraint object, cache it and return to user _ext_force_torque = _ExtForceTorque(sys_idx) @@ -52,7 +62,7 @@ def add_forcing_to(self, system): return _ext_force_torque - def _finalize_forcing(self): + def _finalize_forcing(self: SystemCollectionProtocol) -> None: # From stored _ExtForceTorque objects, and instantiate a Force # inplace : https://stackoverflow.com/a/1208792 @@ -63,24 +73,19 @@ def _finalize_forcing(self): forcing_instance = external_force_and_torque.instantiate() apply_forces = functools.partial( - forcing_instance.apply_forces, system=self._systems[sys_id] + forcing_instance.apply_forces, system=self[sys_id] ) apply_torques = functools.partial( - forcing_instance.apply_torques, system=self._systems[sys_id] + forcing_instance.apply_torques, system=self[sys_id] ) self._feature_group_synchronize.add_operators( external_force_and_torque, [apply_forces, apply_torques] ) - self.warnings(external_force_and_torque) - self._ext_forces_torques = [] del self._ext_forces_torques - def warnings(self, external_force_and_torque): - pass - class _ExtForceTorque: """ @@ -89,73 +94,58 @@ class _ExtForceTorque: Attributes ---------- _sys_idx: int - _forcing_cls: list - *args + _forcing_cls: Type[NoForces] + *args: Any Variable length argument list. - **kwargs + **kwargs: Any Arbitrary keyword arguments. """ - def __init__(self, sys_idx: int): + def __init__(self, sys_idx: SystemIdxType) -> None: """ - Parameters ---------- sys_idx: int """ self._sys_idx = sys_idx - self._forcing_cls = None - self._args = () - self._kwargs = {} + self._forcing_cls: Type[NoForces] + self._args: Any + self._kwargs: Any - def using(self, forcing_cls, *args, **kwargs): + def using(self, cls: Type[NoForces], *args: Any, **kwargs: Any) -> Self: """ - This method is a module to set which forcing class is used to apply forcing + This method sets which forcing class is used to apply forcing to user defined rod-like objects. Parameters ---------- - forcing_cls: object + cls: Type[Any] User defined forcing class. - *args - Variable length argument list - **kwargs + *args: Any + Variable length argument list. + **kwargs: Any Arbitrary keyword arguments. Returns ------- """ - from elastica.external_forces import NoForces - assert issubclass( - forcing_cls, NoForces + cls, NoForces ), "{} is not a valid forcing. Did you forget to derive from NoForces?".format( - forcing_cls + cls ) - self._forcing_cls = forcing_cls + self._forcing_cls = cls self._args = args self._kwargs = kwargs return self - def id(self): + def id(self) -> SystemIdxType: return self._sys_idx - def instantiate(self): - """Constructs a constraint after checks - - Parameters - ---------- - *args - Variable length argument list. - **kwargs - Arbitrary keyword arguments. - - Returns - ------- - - """ - if not self._forcing_cls: + def instantiate(self) -> NoForces: + """Constructs a constraint after checks""" + if not hasattr(self, "_forcing_cls"): raise RuntimeError( "No forcing provided to act on rod id {0}" "but a force was registered. Did you forget to call" diff --git a/elastica/modules/memory_block.py b/elastica/modules/memory_block.py index 21cb9b50..fa75f744 100644 --- a/elastica/modules/memory_block.py +++ b/elastica/modules/memory_block.py @@ -2,14 +2,26 @@ This function is a module to construct memory blocks for different types of systems, such as Cosserat Rods, Rigid Body etc. """ +from typing import cast +from elastica.typing import ( + RodType, + RigidBodyType, + SurfaceType, + StaticSystemType, + SystemIdxType, + BlockSystemType, +) -from elastica.rod import RodBase -from elastica.rigidbody import RigidBodyBase -from elastica.surface import SurfaceBase -from elastica.memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody +from elastica.rod.rod_base import RodBase +from elastica.rigidbody.rigid_body import RigidBodyBase +from elastica.surface.surface_base import SurfaceBase +from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod +from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody -def construct_memory_block_structures(systems): +def construct_memory_block_structures( + systems: list[StaticSystemType], +) -> list[BlockSystemType]: """ This function takes the systems (rod or rigid body) appended to the simulator class and separates them into lists depending on if system is Cosserat rod or rigid body. Then using @@ -19,24 +31,30 @@ def construct_memory_block_structures(systems): ------- """ - _memory_blocks = [] - temp_list_for_cosserat_rod_systems = [] - temp_list_for_rigid_body_systems = [] - temp_list_for_cosserat_rod_systems_idx = [] - temp_list_for_rigid_body_systems_idx = [] + _memory_blocks: list[BlockSystemType] = [] + temp_list_for_cosserat_rod_systems: list[RodType] = [] + temp_list_for_rigid_body_systems: list[RigidBodyType] = [] + temp_list_for_cosserat_rod_systems_idx: list[SystemIdxType] = [] + temp_list_for_rigid_body_systems_idx: list[SystemIdxType] = [] for system_idx, sys_to_be_added in enumerate(systems): - if issubclass(sys_to_be_added.__class__, RodBase): - temp_list_for_cosserat_rod_systems.append(sys_to_be_added) + if isinstance(sys_to_be_added, RodBase): + rod_system = cast(RodType, sys_to_be_added) + temp_list_for_cosserat_rod_systems.append(rod_system) temp_list_for_cosserat_rod_systems_idx.append(system_idx) - elif issubclass(sys_to_be_added.__class__, RigidBodyBase): - temp_list_for_rigid_body_systems.append(sys_to_be_added) + elif isinstance(sys_to_be_added, RigidBodyBase): + rigid_body_system = cast(RigidBodyType, sys_to_be_added) + temp_list_for_rigid_body_systems.append(rigid_body_system) temp_list_for_rigid_body_systems_idx.append(system_idx) - elif issubclass(sys_to_be_added.__class__, SurfaceBase): + elif isinstance(sys_to_be_added, SurfaceBase): pass + # surface_system = cast(SurfaceType, sys_to_be_added) + # raise NotImplementedError( + # "Surfaces are not yet implemented in memory block construction." + # ) else: raise TypeError( @@ -67,4 +85,4 @@ def construct_memory_block_structures(systems): ) ) - return _memory_blocks + return list(_memory_blocks) diff --git a/elastica/modules/operator_group.py b/elastica/modules/operator_group.py index 9e552484..3b01225d 100644 --- a/elastica/modules/operator_group.py +++ b/elastica/modules/operator_group.py @@ -1,27 +1,45 @@ -from elastica.typing import OperatorType +from typing import TypeVar, Generic, Iterator from collections.abc import Iterable import itertools +T = TypeVar("T") +F = TypeVar("F") -class OperatorGroupFIFO(Iterable): + +class OperatorGroupFIFO(Iterable, Generic[T, F]): """ A class to store the features and their corresponding operators in a FIFO manner. + Feature can be any user-defined object to label the operators, and operators are + callable functions. + + This data structure is used to organize elastica operators, such as forcing, + constraints, boundary condition, etc. Examples -------- + >>> class FeatureObj: + ... ADD: Callable + ... SUBTRACT: Callable + ... MULTIPLY: Callable + + >>> obj_1 = FeatureObj() + >>> obj_2 = FeatureObj() + >>> >>> operator_group = OperatorGroupFIFO() >>> operator_group.append_id(obj_1) >>> operator_group.append_id(obj_2) - >>> operator_group.add_operators(obj_1, [ADD, SUBTRACT]) - >>> operator_group.add_operators(obj_2, [SUBTRACT, MULTIPLY]) - >>> list(operator_group) - [OperatorType.ADD, OperatorType.SUBTRACT, OperatorType.SUBTRACT, OperatorType.MULTIPLY] + >>> operator_group.add_operators(obj_1, [obj_1.ADD, obj_1.SUBTRACT]) + >>> operator_group.add_operators(obj_2, [obj_2.SUBTRACT, obj_2.MULTIPLY]) + >>> list(iter(operator_group)) + [obj_1.ADD, obj_1.SUBTRACT, obj_2.SUBTRACT, obj_2.MULTIPLY] + >>> for operator in operator_group: # call operator in the group + ... operator() Attributes ---------- - _operator_collection : list[list[OperatorType]] + _operator_collection : list[list[T]] A list of lists of operators. Each list of operators corresponds to a feature. _operator_ids : list[int] A list of ids of the features. @@ -37,26 +55,26 @@ class OperatorGroupFIFO(Iterable): Used to check if the specific feature is the last feature in the FIFO. """ - def __init__(self): - self._operator_collection: list[list[OperatorType]] = [] + def __init__(self) -> None: + self._operator_collection: list[list[T]] = [] self._operator_ids: list[int] = [] - def __iter__(self) -> OperatorType: + def __iter__(self) -> Iterator[T]: """Returns an operator iterator to satisfy the Iterable protocol.""" operator_chain = itertools.chain.from_iterable(self._operator_collection) for operator in operator_chain: yield operator - def append_id(self, feature): + def append_id(self, feature: F) -> None: """Appends the id of the feature to the list of ids.""" self._operator_ids.append(id(feature)) self._operator_collection.append([]) - def add_operators(self, feature, operators: list[OperatorType]): + def add_operators(self, feature: F, operators: list[T]) -> None: """Adds the operators to the list of operators corresponding to the feature.""" idx = self._operator_ids.index(id(feature)) self._operator_collection[idx].extend(operators) - def is_last(self, feature) -> bool: + def is_last(self, feature: F) -> bool: """Checks if the feature is the last feature in the FIFO.""" return id(feature) == self._operator_ids[-1] diff --git a/elastica/modules/protocol.py b/elastica/modules/protocol.py new file mode 100644 index 00000000..eb1b5d84 --- /dev/null +++ b/elastica/modules/protocol.py @@ -0,0 +1,152 @@ +from typing import Protocol, Generator, TypeVar, Any, Type, overload +from typing_extensions import Self # python 3.11: from typing import Self + +from abc import abstractmethod + +from elastica.typing import ( + SystemIdxType, + OperatorType, + OperatorCallbackType, + OperatorFinalizeType, + StaticSystemType, + SystemType, + BlockSystemType, + ConnectionIndex, +) +from elastica.joint import FreeJoint +from elastica.callback_functions import CallBackBaseClass +from elastica.boundary_conditions import ConstraintBase +from elastica.dissipation import DamperBase + +import numpy as np + +from .operator_group import OperatorGroupFIFO + + +M = TypeVar("M", bound="ModuleProtocol") + + +class ModuleProtocol(Protocol[M]): + def using(self, cls: Type[M], *args: Any, **kwargs: Any) -> Self: ... + + def instantiate(self, *args: Any, **kwargs: Any) -> M: ... + + def id(self) -> Any: ... + + +class SystemCollectionProtocol(Protocol): + def __len__(self) -> int: ... + + def systems(self) -> Generator[StaticSystemType, None, None]: ... + + def block_systems(self) -> Generator[BlockSystemType, None, None]: ... + + @overload + def __getitem__(self, i: slice) -> list[SystemType]: ... + @overload + def __getitem__(self, i: int) -> SystemType: ... + def __getitem__(self, i: slice | int) -> "list[SystemType] | SystemType": ... + + @property + def _feature_group_synchronize(self) -> OperatorGroupFIFO: ... + + def synchronize(self, time: np.float64) -> None: ... + + @property + def _feature_group_constrain_values(self) -> list[OperatorType]: ... + + def constrain_values(self, time: np.float64) -> None: ... + + @property + def _feature_group_constrain_rates(self) -> list[OperatorType]: ... + + def constrain_rates(self, time: np.float64) -> None: ... + + @property + def _feature_group_callback(self) -> list[OperatorCallbackType]: ... + + def apply_callbacks(self, time: np.float64, current_step: int) -> None: ... + + @property + def _feature_group_finalize(self) -> list[OperatorFinalizeType]: ... + + def get_system_index( + self, sys_to_be_added: "SystemType | StaticSystemType" + ) -> SystemIdxType: ... + + # Connection API + _finalize_connections: OperatorFinalizeType + _connections: list[ModuleProtocol] + + @abstractmethod + def connect( + self, + first_rod: SystemType, + second_rod: SystemType, + first_connect_idx: ConnectionIndex, + second_connect_idx: ConnectionIndex, + ) -> ModuleProtocol: + raise NotImplementedError + + # CallBack API + _finalize_callback: OperatorFinalizeType + _callback_list: list[ModuleProtocol] + _callback_operators: list[tuple[int, CallBackBaseClass]] + + @abstractmethod + def collect_diagnostics(self, system: SystemType) -> ModuleProtocol: + raise NotImplementedError + + @abstractmethod + def _callback_execution( + self, time: np.float64, current_step: int, *args: Any, **kwargs: Any + ) -> None: + raise NotImplementedError + + # Constraints API + _constraints_list: list[ModuleProtocol] + _constraints_operators: list[tuple[int, ConstraintBase]] + _finalize_constraints: OperatorFinalizeType + + @abstractmethod + def constrain(self, system: SystemType) -> ModuleProtocol: + raise NotImplementedError + + @abstractmethod + def _constrain_values(self, time: np.float64) -> None: + raise NotImplementedError + + @abstractmethod + def _constrain_rates(self, time: np.float64) -> None: + raise NotImplementedError + + # Forcing API + _ext_forces_torques: list[ModuleProtocol] + _finalize_forcing: OperatorFinalizeType + + @abstractmethod + def add_forcing_to(self, system: SystemType) -> ModuleProtocol: + raise NotImplementedError + + # Contact API + _contacts: list[ModuleProtocol] + _finalize_contact: OperatorFinalizeType + + @abstractmethod + def detect_contact_between( + self, first_system: SystemType, second_system: SystemType + ) -> ModuleProtocol: + raise NotImplementedError + + # Damping API + _damping_list: list[ModuleProtocol] + _damping_operators: list[tuple[int, DamperBase]] + _finalize_dampers: OperatorFinalizeType + + @abstractmethod + def dampen(self, system: SystemType) -> ModuleProtocol: + raise NotImplementedError + + @abstractmethod + def _dampen_rates(self, time: np.float64) -> None: + raise NotImplementedError diff --git a/elastica/py.typed b/elastica/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/elastica/reset_functions_for_block_structure/_reset_ghost_vector_or_scalar.py b/elastica/reset_functions_for_block_structure/_reset_ghost_vector_or_scalar.py index 7ed2a3bf..1759a5ba 100644 --- a/elastica/reset_functions_for_block_structure/_reset_ghost_vector_or_scalar.py +++ b/elastica/reset_functions_for_block_structure/_reset_ghost_vector_or_scalar.py @@ -1,18 +1,24 @@ __doc__ = """Reset the ghost vectors or scalar variables using functions implemented in Numba""" +import numpy as np +from numpy.typing import NDArray from numba import njit -@njit(cache=True) -def _reset_vector_ghost(input, ghost_idx, reset_value=0.0): +@njit(cache=True) # type: ignore +def _reset_vector_ghost( + input: NDArray[np.float64], + ghost_idx: NDArray[np.int32], + reset_value: np.float64 = np.float64(0.0), +): """ This function resets the ghost of an input vector collection. Default reset value is 0.0. Parameters ---------- - input - ghost_idx - reset_value + input : NDArray[np.float64] + ghost_idx : NDArray[np.int32] + reset_value : np.float64 Returns ------- @@ -39,16 +45,20 @@ def _reset_vector_ghost(input, ghost_idx, reset_value=0.0): input[i, k] = reset_value -@njit(cache=True) -def _reset_scalar_ghost(input, ghost_idx, reset_value=0.0): +@njit(cache=True) # type: ignore +def _reset_scalar_ghost( + input: NDArray[np.float64], + ghost_idx: NDArray[np.int32], + reset_value: np.float64 = np.float64(0.0), +): """ This function resets the ghost of a scalar collection. Default reset value is 0.0. Parameters ---------- - input - ghost_idx - reset_value + input : NDArray[np.float64] + ghost_idx : NDArray[np.int32] + reset_value : np.float64 Returns ------- diff --git a/elastica/restart.py b/elastica/restart.py index 1b5fab26..d6f9475b 100644 --- a/elastica/restart.py +++ b/elastica/restart.py @@ -5,8 +5,10 @@ from itertools import groupby from .memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody +from typing import Iterable, Iterator, Any -def all_equal(iterable): + +def all_equal(iterable: Iterable[Any]) -> bool: """ Checks if all elements of list are equal. Parameters @@ -20,11 +22,17 @@ def all_equal(iterable): ---------- https://stackoverflow.com/questions/3844801/check-if-all-elements-in-a-list-are-identical """ - g = groupby(iterable) + g: Iterator[Any] = groupby(iterable) return next(g, True) and not next(g, False) -def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): +# TODO: simulator should have better typing +def save_state( + simulator: Iterable, + directory: str = "", + time: np.float64 = np.float64(0.0), + verbose: bool = False, +) -> None: """ Save state parameters of each rod. TODO : environment list variable is not uniform at the current stage of development. @@ -53,7 +61,10 @@ def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): print("Save complete: {}".format(directory)) -def load_state(simulator, directory: str = "", verbose: bool = False): +# TODO: simulator should have better typing +def load_state( + simulator: Iterable, directory: str = "", verbose: bool = False +) -> float: """ Load the rod-state. Compatibale with 'save_state' method. If the save-file does not exist, it returns error. @@ -72,7 +83,7 @@ def load_state(simulator, directory: str = "", verbose: bool = False): time : float Simulation time of systems when they are saved. """ - time_list = [] # Simulation time of rods when they are saved. + time_list: list[float] = [] # Simulation time of rods when they are saved. for idx, rod in enumerate(simulator): if isinstance(rod, MemoryBlockCosseratRod) or isinstance( rod, MemoryBlockRigidBody diff --git a/elastica/rigidbody/cylinder.py b/elastica/rigidbody/cylinder.py index 58bdaa9e..f1e8e83d 100644 --- a/elastica/rigidbody/cylinder.py +++ b/elastica/rigidbody/cylinder.py @@ -1,6 +1,10 @@ -__doc__ = """""" +__doc__ = """ +Implementation of a rigid body cylinder. +""" +from typing import TYPE_CHECKING import numpy as np +from numpy.typing import NDArray from elastica._linalg import _batch_cross from elastica.utils import MaxDimension @@ -8,75 +12,116 @@ class Cylinder(RigidBodyBase): - def __init__(self, start, direction, normal, base_length, base_radius, density): + def __init__( + self, + start: NDArray[np.float64], + direction: NDArray[np.float64], + normal: NDArray[np.float64], + base_length: float, + base_radius: float, + density: float, + ) -> None: """ Rigid body cylinder initializer. Parameters ---------- - start - direction - normal - base_length - base_radius - density + start : NDArray[np.float64] + direction : NDArray[np.float64] + normal : NDArray[np.float64] + base_length : float + base_radius : float + density : float """ - # rigid body does not have elements it only have one node. We are setting n_elems to - # zero for only make code to work. _bootstrap_from_data requires n_elems to be defined - self.n_elems = 1 - normal = normal.reshape(3, 1) - tangents = direction.reshape(3, 1) + # FIXME: Refactor + def assert_check_array_size( + to_check: NDArray[np.float64], name: str, expected: int = 3 + ) -> None: + array_size = to_check.size + assert array_size == expected, ( + f"Invalid size of '{name}'. " + f"Expected: {expected}, but got: {array_size}" + ) + + # FIXME: Refactor + def assert_check_lower_bound( + to_check: float, name: str, lower_bound: float = 0.0 + ) -> None: + assert ( + to_check > lower_bound + ), f"Value for '{name}' ({to_check}) must be at lease {lower_bound}. " + + assert_check_array_size(start, "start") + assert_check_array_size(direction, "direction") + assert_check_array_size(normal, "normal") + + assert_check_lower_bound(base_length, "base_length") + assert_check_lower_bound(base_radius, "base_radius") + assert_check_lower_bound(density, "density") + + super().__init__() + + normal = normal.reshape((3, 1)) + tangents = direction.reshape((3, 1)) binormal = _batch_cross(tangents, normal) - self.radius = base_radius - self.length = base_length - self.density = density + self.radius = np.float64(base_radius) + self.length = np.float64(base_length) + self.density = np.float64(density) + + dim: int = MaxDimension.value() + # This is for a rigid body cylinder - self.volume = np.pi * base_radius * base_radius * base_length - self.mass = np.array([self.volume * self.density]) + self.volume = np.float64(np.pi * base_radius * base_radius * base_length) + self.mass = np.float64(self.volume * self.density) # Second moment of inertia - A0 = np.pi * base_radius * base_radius - I0_1 = A0 * A0 / (4.0 * np.pi) - I0_2 = I0_1 - I0_3 = 2.0 * I0_2 - I0 = np.array([I0_1, I0_2, I0_3]) + area = np.pi * base_radius * base_radius + smoa_span_1 = area * area / (4.0 * np.pi) + smoa_span_2 = smoa_span_1 + smoa_axial = 2.0 * smoa_span_1 + smoa = np.array([smoa_span_1, smoa_span_2, smoa_axial]) + + # Allocate properties + self.position_collection = np.zeros((dim, 1), dtype=np.float64) + self.velocity_collection = np.zeros((dim, 1), dtype=np.float64) + self.acceleration_collection = np.zeros((dim, 1), dtype=np.float64) + self.omega_collection = np.zeros((dim, 1), dtype=np.float64) + self.alpha_collection = np.zeros((dim, 1), dtype=np.float64) + self.director_collection = np.zeros((dim, dim, 1), dtype=np.float64) + + self.external_forces = np.zeros((dim, 1), dtype=np.float64) + self.external_torques = np.zeros((dim, 1), dtype=np.float64) # Mass second moment of inertia for disk cross-section - mass_second_moment_of_inertia = np.zeros( - (MaxDimension.value(), MaxDimension.value()), np.float64 - ) - np.fill_diagonal(mass_second_moment_of_inertia, I0 * density * base_length) + mass_second_moment_of_inertia = np.diag(smoa * density * base_length) self.mass_second_moment_of_inertia = mass_second_moment_of_inertia.reshape( - MaxDimension.value(), MaxDimension.value(), 1 + (dim, dim, 1) ) self.inv_mass_second_moment_of_inertia = np.linalg.inv( mass_second_moment_of_inertia - ).reshape(MaxDimension.value(), MaxDimension.value(), 1) + ).reshape((dim, dim, 1)) # position is at the center - self.position_collection = np.zeros((MaxDimension.value(), 1)) self.position_collection[:] = ( start.reshape(3, 1) + direction.reshape(3, 1) * base_length / 2 ) - self.velocity_collection = np.zeros((MaxDimension.value(), 1)) - self.omega_collection = np.zeros((MaxDimension.value(), 1)) - self.acceleration_collection = np.zeros((MaxDimension.value(), 1)) - self.alpha_collection = np.zeros((MaxDimension.value(), 1)) - - self.director_collection = np.zeros( - (MaxDimension.value(), MaxDimension.value(), 1) - ) self.director_collection[0, ...] = normal self.director_collection[1, ...] = binormal self.director_collection[2, ...] = tangents - self.external_forces = np.zeros((MaxDimension.value())).reshape( - MaxDimension.value(), 1 - ) - self.external_torques = np.zeros((MaxDimension.value())).reshape( - MaxDimension.value(), 1 - ) + +if TYPE_CHECKING: + from .protocol import RigidBodyProtocol + + _: RigidBodyProtocol = Cylinder( + start=np.zeros(3), + direction=np.ones(3), + normal=np.ones(3), + base_length=1.0, + base_radius=1.0, + density=1.0, + ) diff --git a/elastica/rigidbody/data_structures.py b/elastica/rigidbody/data_structures.py index e246c1fc..95944f75 100644 --- a/elastica/rigidbody/data_structures.py +++ b/elastica/rigidbody/data_structures.py @@ -1,10 +1,9 @@ __doc__ = "Data structure wrapper for rod components" -import numpy as np - -from elastica._rotations import _get_rotation_matrix from elastica.rod.data_structures import _RodSymplecticStepperMixin +pass + """ # FIXME : Explicit Stepper doesn't work as States lose the # views they initially had when working with a timestepper. @@ -42,472 +41,5 @@ def __call__(self, time, *args, **kwargs): return self.__deriv_state """ - -class _RigidRodSymplecticStepperMixin(_RodSymplecticStepperMixin): - def __init__(self): - super(_RigidRodSymplecticStepperMixin, self).__init__() - # Expose rate returning functions in the interface - # to be used by the time-stepping algorithm - # dynamic rates needs to call update_accelerations and henc - # is another function - - def update_internal_forces_and_torques(self, *args, **kwargs): - pass - - -def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_states): - """Returns states wrapping numpy arrays based on the time-stepping algorithm - - Convenience method that takes in rod internal (raw np.ndarray) data, create views - (references) from it, and outputs State classes that are used in the time-stepping - algorithm. This means that modifying the state modifies the internal data! - - Parameters - ---------- - stepper_type : str (likely to change in future), representing stepper type - Allowed parameters are ['explicit', 'symplectic'] - n_elems : int, number of rod elements - vector_states : np.ndarray of shape (dim, *) with the following structure - `vector_states` = [`position`,`velocity`,`omega`,`acceleration`,`angular acceleration`] - `n_nodes = n_elems = 1` - `position = 0 -> n_nodes , size = n_nodes` - `velocity = n_nodes -> 2 * n_nodes, size = n_nodes` - `omega = 2 * n_nodes -> 2 * n_nodes + nelem, size = nelem` - `acceleration = 2 * n_nodes + nelem -> 3 * n_nodes + nelem, size = n_nodes` - `angular acceleration = 3 * n_nodes + nelem -> 3 * n_nodes + 2 * nelem, size = n_elems` - matrix_states : np.ndarray of shape (dim, dim, n_elems) containing the directors - - Returns - ------- - output : tuple of len 8 containing - (state, derivative_state, position, directors, velocity, omega, acceleration, alpha) - derivative_state carries rate information - - """ - n_nodes = n_elems - position = np.ndarray.view(vector_states[..., :n_nodes]) - directors = np.ndarray.view(matrix_states) - v_w_dvdt_dwdt = np.ndarray.view(vector_states[..., n_nodes:]) - output = () - # TODO: - # 11/01/2020: Extend Rigid body data structures for Explicit steppers - # 12/20/2021: Future work! If explicit stepper is gonna be dropped, remove. - # if stepper_type == "explicit": - # v_w_states = np.ndarray.view(vector_states[..., n_nodes : n_nodes + 1]) - # output += ( - # _State(n_elems, position, directors, v_w_states), - # _DerivativeState(n_elems, v_w_dvdt_dwdt), - # ) - # elif stepper_type == "symplectic": - # output += ( - # _KinematicState(n_elems, position, directors), - # _DynamicState(n_elems, v_w_dvdt_dwdt), - # ) - # else: - # return - output += ( - _KinematicState(n_elems, position, directors), - _DynamicState(n_elems, v_w_dvdt_dwdt), - ) - - n_velocity_end = n_nodes + n_nodes - velocity = np.ndarray.view(vector_states[..., n_nodes:n_velocity_end]) - - n_omega_end = n_velocity_end + n_elems - omega = np.ndarray.view(vector_states[..., n_velocity_end:n_omega_end]) - - n_acceleration_end = n_omega_end + n_nodes - acceleration = np.ndarray.view(vector_states[..., n_omega_end:n_acceleration_end]) - - n_alpha_end = n_acceleration_end + n_elems - alpha = np.ndarray.view(vector_states[..., n_acceleration_end:n_alpha_end]) - - return output + (position, directors, velocity, omega, acceleration, alpha) - - -""" -Explicit stepper interface -""" - -# TODO -# 12/20/2021: If explicit stepper is gonna be dropped, remove. -# class _State: -# """State for explicit steppers. -# -# Wraps data as state, with overloaded methods for explicit steppers -# (steppers that integrate all states in one-step/stage). -# Allows for separating implementation of stepper from actual -# addition/multiplication/other formulae used. -# """ -# -# # TODO : args, kwargs instead of hardcoding types -# def __init__( -# self, -# n_elems: int, -# position_collection_view, -# director_collection_view, -# kinematic_rate_collection_view, -# ): -# """ -# Parameters -# ---------- -# n_elems : int, number of rod elements -# position_collection_view : view of positions (or) x -# director_collection_view : view of directors (or) Q -# kinematic_rate_collection_view : view of velocity and omega (or) (v,ω) -# """ -# super(_State, self).__init__() -# self.n_nodes = n_elems -# self.n_kinematic_rates = self.n_nodes + n_elems # start of (v,ω) in (x,Q,v,ω) -# self.position_collection = position_collection_view -# self.director_collection = director_collection_view -# self.kinematic_rate_collection = kinematic_rate_collection_view -# -# def __iadd__(self, scaled_deriv_array): -# """overloaded += operator -# -# The add for directors is customized to reflect Rodrigues' rotation -# formula. -# -# Parameters -# ---------- -# scaled_deriv_array : np.ndarray containing dt * (v, ω, dv/dt, dω/dt) -# ,as returned from _DerivativeState's __mul__ method -# -# Returns -# ------- -# self : _State with inplace modified data -# -# """ -# # x += v*dt -# self.position_collection += scaled_deriv_array[..., : self.n_nodes] -# # TODO : Verify the math in this note -# r""" -# Developer Note -# -------------- -# Here the overloaded `+=` operator is exploited to perform -# matrix multiplication for the directors, which is counter- -# intutive at first. While this provides a stable interface -# to interact the rod states with the timesteppers and the -# rest of the world, the reasons behind including it here also has -# a depper mathematical significance. -# -# Firstly, position lies in the vector space corresponding to R^{3} -# and update is done this space (with the + and * operators defined -# as usual), hence the `+=` operator (or `__iadd__`) is reflected -# as `+=` operator in the position update (line 163 above). -# -# For directors rather, which lie in a restricteed R^{3} \otimes -# R^{3} tensorial space, the space with Q^T.Q = Q.Q^T = I, the + -# operator can be thought of as an equivalent `*=` update for a -# 'exponential' multiplication with a rotation matrix (e^{At}). -# . This does not correspond to the position update. However, if -# we view this in a logarithmic space the `*=` becomse the '+=' -# operator once again! After performing this `+=` operation, we -# bring it back into its original space using the exponential -# operator. So we are still indirectly doing the '+=' -# update. -# -# To avoid all this hassle with the operators and spaces, we simply define -# '+=' or '__iadd__' in the case of directors as an equivalent -# '*=' (matrix multiply) with the RHS below. -# """ -# # TODO Q *= exp(w*dt) , whats' the formua again? -# # TODO the scale factor 1.0 does not seem to be necessary, although -# # we perform more work in the present framework (muliply dt to entire vector, then take -# # norm) rather than vector norm then multiple by dt (1/3 operation costs) -# # TODO optimize (somehow) extra copy away : if we don't make a copy -# # its even more slower, maybe due to aliasing effects -# np.einsum( -# "ijk,jlk->ilk", -# _get_rotation_matrix( -# 1.0, scaled_deriv_array[..., self.n_nodes : self.n_kinematic_rates] -# ), -# self.director_collection.copy(), -# out=self.director_collection, -# ) -# # (v,ω) += (dv/dt, dω/dt)*dt -# self.kinematic_rate_collection += scaled_deriv_array[ -# ..., self.n_kinematic_rates : -# ] -# return self -# -# def __add__(self, scaled_derivative_state): -# """overloaded + operator, useful in state.k1 = state + dt * deriv_state -# -# The add for directors is customized to reflect Rodrigues' rotation -# formula. -# -# Parameters -# ---------- -# scaled_derivative_state : np.ndarray with dt * (v, ω, dv/dt, dω/dt) -# ,as returned from _DerivativeState's __mul__ method -# -# Returns -# ------- -# state : new _State object with modified data (copied) -# -# Caveats -# ------- -# Note that the argument is not a `other` _State object but is rather -# assumed to be a `np.ndarray` from calling _DerivativeState's __mul__ -# method. This reflects the most common use-case in time-steppers -# -# """ -# # x += v*dt -# position_collection = ( -# self.position_collection + scaled_derivative_state[..., : self.n_nodes] -# ) -# # Devs : see `_State.__iadd__` for reasons why we do matmul here -# director_collection = _rotate( -# self.director_collection, -# 1.0, -# scaled_derivative_state[..., self.n_nodes : self.n_kinematic_rates], -# ) -# # (v,ω) += (dv/dt, dω/dt)*dt -# kinematic_rate_collection = ( -# self.kinematic_rate_collection -# + scaled_derivative_state[..., self.n_kinematic_rates :] -# ) -# return _State( -# self.n_nodes, -# position_collection, -# director_collection, -# kinematic_rate_collection, -# ) -# -# -# class _DerivativeState: -# """TimeDerivative of States for explicit steppers. -# -# Wraps time-derivative data as state, with overloaded methods for -# explicit steppers (steppers that integrate all states in one-step/stage). -# Allows for separating implementation of stepper from actual addition -# /multiplication used. -# """ -# -# def __init__(self, _unused_n_elems: int, rate_collection_view): -# """ -# Parameters -# ---------- -# _unused_n_elems : int, number of elements (unused, kept for -# compatibility with `_bootstrap_from_data`) -# rate_collection_view : np.ndarray containing (v, ω, dv/dt, dω/dt) -# """ -# super(_DerivativeState, self).__init__() -# self.rate_collection = rate_collection_view -# -# def __rmul__(self, scalar): -# """overloaded scalar * self, -# -# Parameters -# ---------- -# scalar : float, typically dt (the time-step) -# -# Returns -# ------- -# output : np.ndarray containing (v*dt, ω*dt, dv/dt*dt, dω/dt*dt) -# -# Caveats -# ------- -# Returns a np.ndarray and not a State object (as one expects). -# Returning a State here with (v*dt, ω*dt, dv/dt*dt, dω/dt*dt) as members -# is possible but it's less efficient, especially because this is hot -# piece of code -# """ -# """ -# Developer Note -# -------------- -# -# Q : Why do we need to overload operators here? -# -# The Derivative class naturally doesn't have a `mul` overloaded -# operator. That means if this method is not present, -# doing something like -# ``` -# ds = _DerivativeState(...) -# new_state = 2 * ds -# ``` -# will throw an error. Note that you can do something like -# ``` -# ds = _DerivativeState(...) -# new_state = 2 * ds.rate_collection -# ``` -# but this is hacky, as we are exposing the members outside, -# in the calling scope (defeats encapsulation and hiding). -# The point of having this class is that it works -# well with the time-stepper (where we only use `+` and `*` -# operations on the State/DerivativeState like above, -# i.e. `state = dt * derivative_state` and not something like -# `state = dt * derivative_state.rate_collection`). -# It also provides an interface for anything outside -# the `Rod` system as a whole. -# """ -# return scalar * self.rate_collection -# -# def __mul__(self, scalar): -# """overloaded self * scalar -# -# TODO Check if this pattern (forwarding to __mul__) has -# any disdvantages apart from extra function call penalty -# -# Parameters -# ---------- -# scalar : float, typically dt (the time-step) -# -# Returns -# ------- -# output : np.ndarray containing (v*dt, ω*dt, dv/dt*dt, dω/dt*dt) -# -# """ -# return self.__rmul__(scalar) - -""" -Symplectic stepper interface -""" - - -# TODO: Maybe considerg removing. We no longer use bootstrap to initialize. -# RigidBodySymplecticStepperMixin is now derived from RodSymplecticStepperMixin -class _KinematicState: - """State storing (x,Q) for symplectic steppers. - - Wraps data as state, with overloaded methods for symplectic steppers. - Allows for separating implementation of stepper from actual - addition/multiplication/other formulae used. - - Symplectic steppers rely only on in-place modifications to state and so - only these methods are provided. - """ - - def __init__(self, position_collection_view, director_collection_view): - """ - Parameters - ---------- - n_elems : int, number of rod elements - position_collection_view : view of positions (or) x - director_collection_view : view of directors (or) Q - """ - self.position_collection = position_collection_view - self.director_collection = director_collection_view - - def __iadd__(self, scaled_deriv_array): - """overloaded += operator - - The add for directors is customized to reflect Rodrigues' rotation - formula. - - Parameters - ---------- - scaled_deriv_array : np.ndarray containing dt * (v, ω), - as retured from _DynamicState's `kinematic_rates` method - - Returns - ------- - self : _KinematicState instance with inplace modified data - - Caveats - ------- - Takes a np.ndarray and not a _KinematicState object (as one expects). - This is done for efficiency reasons, see _DynamicState's `kinematic_rates` - method - """ - velocity_collection = scaled_deriv_array[0] - omega_collection = scaled_deriv_array[1] - # x += v*dt - self.position_collection += velocity_collection - # Devs : see `_State.__iadd__` for reasons why we do matmul here - np.einsum( - "ijk,jlk->ilk", - _get_rotation_matrix(1.0, omega_collection), - self.director_collection.copy(), - out=self.director_collection, - ) - return self - - -class _DynamicState: - """State storing (v,ω, dv/dt, dω/dt) for symplectic steppers. - - Wraps data as state, with overloaded methods for symplectic steppers. - Allows for separating implementation of stepper from actual - addition/multiplication/other formulae used. - - Symplectic steppers rely only on in-place modifications to state and so - only these methods are provided. - """ - - def __init__( - self, - v_w_collection, - dvdt_dwdt_collection, - velocity_collection, - omega_collection, - ): - """ - - Parameters - ---------- - n_elems : int, number of rod elements - rate_collection_view : np.ndarray containing (v, ω, dv/dt, dω/dt) - """ - super(_DynamicState, self).__init__() - # Limit at which (v, w) end - self.rate_collection = v_w_collection - self.dvdt_dwdt_collection = dvdt_dwdt_collection - self.velocity_collection = velocity_collection - self.omega_collection = omega_collection - - def __iadd__(self, scaled_second_deriv_array): - """overloaded += operator, updating dynamic_rates - - Parameters - ---------- - scaled_second_deriv_array : np.ndarray containing dt * (dvdt, dωdt), - as retured from _DynamicState's `dynamic_rates` method - - Returns - ------- - self : _DynamicState instance with inplace modified data - - Caveats - ------- - Takes a np.ndarray and not a _DynamicState object (as one expects). - This is done for efficiency reasons, see `dynamic_rates`. - """ - # Always goes in LHS : that means the update is on the rates alone - # (v,ω) += dt * (dv/dt, dω/dt) -> self.dynamic_rates - self.rate_collection += scaled_second_deriv_array - return self - - def kinematic_rates(self, time, prefac, *args, **kwargs): - """Yields kinematic rates to interact with _KinematicState - - Returns - ------- - v_and_omega : np.ndarray consisting of (v,ω) - - Caveats - ------- - Doesn't return a _KinematicState with (dt*v, dt*w) as members, - as one expects the _Kinematic __add__ operator to interact - with another _KinematicState. This is done for efficiency purposes. - """ - # RHS functino call, gives v,w so that - # Comes from kin_state -> (x,Q) += dt * (v,w) <- First part of dyn_state - return prefac * self.velocity_collection, prefac * self.omega_collection - - def dynamic_rates(self, time, prefac, *args, **kwargs): - """Yields dynamic rates to add to with _DynamicState - - Returns - ------- - acc_and_alpha : np.ndarray consisting of (dv/dt,dω/dt) - - Caveats - ------- - Doesn't return a _DynamicState with (dt*v, dt*w) as members, - as one expects the _Dynamic __add__ operator to interact - with another _DynamicState. This is done for efficiency purposes. - """ - return prefac * self.dvdt_dwdt_collection +# TODO: Temporary solution as the structure for RigidBody is similar to Rod +_RigidRodSymplecticStepperMixin = _RodSymplecticStepperMixin diff --git a/elastica/rigidbody/mesh_rigid_body.py b/elastica/rigidbody/mesh_rigid_body.py index 4a026f9a..e5cf3f50 100644 --- a/elastica/rigidbody/mesh_rigid_body.py +++ b/elastica/rigidbody/mesh_rigid_body.py @@ -1,5 +1,8 @@ __doc__ = """rigid body class based on mesh""" +from numpy.typing import NDArray +from elastica.typing import MeshType + import numpy as np import numba from elastica._linalg import _batch_cross, _batch_norm @@ -10,12 +13,12 @@ class MeshRigidBody(RigidBodyBase): def __init__( self, - mesh, - center_of_mass, - mass_second_moment_of_inertia, - density, - volume, - ): + mesh: MeshType, + center_of_mass: NDArray[np.float64], + mass_second_moment_of_inertia: NDArray[np.float64], + density: np.float64, + volume: np.float64, + ) -> None: """ Mesh rigid body initializer. @@ -37,11 +40,11 @@ def __init__( """ # rigid body does not have elements it only have one node. We are setting n_elems to # zero for only make code to work. _bootstrap_from_data requires n_elems to be defined - self.n_elems = 1 # center_mass + self.n_elems: int = 1 # center_mass self.density = density self.volume = volume - self.mass = np.array([self.volume * self.density]) + self.mass = np.float64(self.volume * self.density) self.mass_second_moment_of_inertia = mass_second_moment_of_inertia.reshape( MaxDimension.value(), MaxDimension.value(), 1 ) @@ -112,7 +115,7 @@ def __init__( MaxDimension.value(), 1 ) - def update_faces(self): + def update_faces(self) -> None: _update_faces( self.director_collection, self.face_centers, @@ -128,7 +131,7 @@ def update_faces(self): ) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _update_faces( director_collection, face_centers, diff --git a/elastica/rigidbody/protocol.py b/elastica/rigidbody/protocol.py new file mode 100644 index 00000000..7e6f0664 --- /dev/null +++ b/elastica/rigidbody/protocol.py @@ -0,0 +1,18 @@ +from typing import Protocol + +import numpy as np +from numpy.typing import NDArray + +from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol + + +class RigidBodyProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): + + mass: np.float64 + volume: np.float64 + length: np.float64 + tangents: NDArray[np.float64] + radius: np.float64 + + mass_second_moment_of_inertia: NDArray[np.float64] + inv_mass_second_moment_of_inertia: NDArray[np.float64] diff --git a/elastica/rigidbody/rigid_body.py b/elastica/rigidbody/rigid_body.py index e8739274..99920854 100644 --- a/elastica/rigidbody/rigid_body.py +++ b/elastica/rigidbody/rigid_body.py @@ -1,7 +1,11 @@ __doc__ = """""" -import numpy as np +from typing import Type + from abc import ABC + +import numpy as np +from numpy.typing import NDArray from elastica._linalg import _batch_matvec, _batch_cross @@ -15,56 +19,49 @@ class RigidBodyBase(ABC): """ - REQUISITE_MODULES = [] + REQUISITE_MODULES: list[Type] = [] - def __init__(self): + def __init__(self) -> None: + # rigid body does not have elements it only has one node. We are setting n_elems to + # make code to work. _bootstrap_from_data requires n_elems to be define + self.n_elems: int = 1 + self.n_nodes: int = 1 - self.position_collection = NotImplementedError - self.velocity_collection = NotImplementedError - self.acceleration_collection = NotImplementedError - self.omega_collection = NotImplementedError - self.alpha_collection = NotImplementedError - self.director_collection = NotImplementedError + self.position_collection: NDArray[np.float64] + self.velocity_collection: NDArray[np.float64] + self.acceleration_collection: NDArray[np.float64] + self.omega_collection: NDArray[np.float64] + self.alpha_collection: NDArray[np.float64] + self.director_collection: NDArray[np.float64] - self.external_forces = NotImplementedError - self.external_torques = NotImplementedError + self.external_forces: NDArray[np.float64] + self.external_torques: NDArray[np.float64] - self.mass = NotImplementedError + self.internal_forces: NDArray[np.float64] + self.internal_torques: NDArray[np.float64] - self.mass_second_moment_of_inertia = NotImplementedError - self.inv_mass_second_moment_of_inertia = NotImplementedError + self.mass: np.float64 + self.volume: np.float64 + self.length: np.float64 + self.tangents: NDArray[np.float64] + self.radius: np.float64 - # @abstractmethod - # # def update_accelerations(self): - # # pass + self.mass_second_moment_of_inertia: NDArray[np.float64] + self.inv_mass_second_moment_of_inertia: NDArray[np.float64] - # def _compute_internal_forces_and_torques(self): - # """ - # This function here is only for integrator to work properly. We do not need - # internal forces and torques at all. - # Parameters - # ---------- - # time - # - # Returns - # ------- - # - # """ - # pass - - def update_accelerations(self, time): + def update_accelerations(self, time: np.float64) -> None: np.copyto( self.acceleration_collection, (self.external_forces) / self.mass, ) # I apply common sub expression elimination here, as J w - J_omega = _batch_matvec( + j_omega = _batch_matvec( self.mass_second_moment_of_inertia, self.omega_collection ) # (J \omega_L ) x \omega_L - lagrangian_transport = _batch_cross(J_omega, self.omega_collection) + lagrangian_transport = _batch_cross(j_omega, self.omega_collection) np.copyto( self.alpha_collection, @@ -74,18 +71,24 @@ def update_accelerations(self, time): ), ) - def zeroed_out_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques(self, time: np.float64) -> None: # Reset forces and torques self.external_forces *= 0.0 self.external_torques *= 0.0 - def compute_position_center_of_mass(self): + def compute_internal_forces_and_torques(self, time: np.float64) -> None: + """ + For rigid body, there is no internal forces and torques + """ + pass + + def compute_position_center_of_mass(self) -> NDArray[np.float64]: """ Return positional center of mass """ return self.position_collection[..., 0].copy() - def compute_translational_energy(self): + def compute_translational_energy(self) -> NDArray[np.float64]: """ Return translational energy """ @@ -97,11 +100,11 @@ def compute_translational_energy(self): ) ) - def compute_rotational_energy(self): + def compute_rotational_energy(self) -> NDArray[np.float64]: """ Return rotational energy """ - J_omega = np.einsum( + j_omega = np.einsum( "ijk,jk->ik", self.mass_second_moment_of_inertia, self.omega_collection ) - return 0.5 * np.einsum("ik,ik->k", self.omega_collection, J_omega).sum() + return 0.5 * np.einsum("ik,ik->k", self.omega_collection, j_omega).sum() diff --git a/elastica/rigidbody/sphere.py b/elastica/rigidbody/sphere.py index 181dd18a..5f87a42a 100644 --- a/elastica/rigidbody/sphere.py +++ b/elastica/rigidbody/sphere.py @@ -1,6 +1,10 @@ -__doc__ = """""" +__doc__ = """ +Implementation of a sphere rigid body. +""" +from typing import TYPE_CHECKING import numpy as np +from numpy.typing import NDArray from elastica._linalg import _batch_cross from elastica.utils import MaxDimension @@ -8,65 +12,80 @@ class Sphere(RigidBodyBase): - def __init__(self, center, base_radius, density): + def __init__( + self, + center: NDArray[np.float64], + base_radius: float, + density: float, + ) -> None: """ Rigid body sphere initializer. Parameters ---------- - center - base_radius - density + center : NDArray[np.float64] + base_radius : float + density : float """ - # rigid body does not have elements it only have one node. We are setting n_elems to - # zero for only make code to work. _bootstrap_from_data requires n_elems to be defined - self.n_elems = 1 - self.radius = base_radius - self.density = density - self.length = 2 * base_radius + super().__init__() + + dim: int = MaxDimension.value() + + assert ( + center.size == dim + ), f"center must be of size {dim}, but was {center.size}" + assert base_radius > 0.0, "base_radius must be positive" + assert density > 0.0, "density must be positive" + + self.radius = np.float64(base_radius) + self.density = np.float64(density) + self.length = np.float64(2 * base_radius) # This is for a rigid body cylinder - self.volume = 4.0 / 3.0 * np.pi * base_radius**3 - self.mass = np.array([self.volume * self.density]) - normal = np.array([1.0, 0.0, 0.0]).reshape(3, 1) - tangents = np.array([0.0, 0.0, 1.0]).reshape(3, 1) + self.volume = np.float64(4.0 / 3.0 * np.pi * base_radius**3) + self.mass = np.float64(self.volume * self.density) + normal = np.array([1.0, 0.0, 0.0], dtype=np.float64).reshape(dim, 1) + tangents = np.array([0.0, 0.0, 1.0], dtype=np.float64).reshape(dim, 1) binormal = _batch_cross(tangents, normal) # Mass second moment of inertia for disk cross-section - mass_second_moment_of_inertia = np.zeros( - (MaxDimension.value(), MaxDimension.value()), np.float64 - ) + mass_second_moment_of_inertia = np.zeros((dim, dim), dtype=np.float64) np.fill_diagonal( mass_second_moment_of_inertia, 2.0 / 5.0 * self.mass * self.radius**2 ) self.mass_second_moment_of_inertia = mass_second_moment_of_inertia.reshape( - MaxDimension.value(), MaxDimension.value(), 1 + (dim, dim, 1) ) self.inv_mass_second_moment_of_inertia = np.linalg.inv( mass_second_moment_of_inertia - ).reshape(MaxDimension.value(), MaxDimension.value(), 1) + ).reshape((dim, dim, 1)) + + # Allocate properties + self.position_collection = np.zeros((dim, 1), dtype=np.float64) + self.velocity_collection = np.zeros((dim, 1), dtype=np.float64) + self.acceleration_collection = np.zeros((dim, 1), dtype=np.float64) + self.omega_collection = np.zeros((dim, 1), dtype=np.float64) + self.alpha_collection = np.zeros((dim, 1), dtype=np.float64) + self.director_collection = np.zeros((dim, dim, 1), dtype=np.float64) + + self.external_forces = np.zeros((dim, 1), dtype=np.float64) + self.external_torques = np.zeros((dim, 1), dtype=np.float64) # position is at the center - self.position_collection = np.zeros((MaxDimension.value(), 1)) self.position_collection[:, 0] = center - self.velocity_collection = np.zeros((MaxDimension.value(), 1)) - self.omega_collection = np.zeros((MaxDimension.value(), 1)) - self.acceleration_collection = np.zeros((MaxDimension.value(), 1)) - self.alpha_collection = np.zeros((MaxDimension.value(), 1)) - - self.director_collection = np.zeros( - (MaxDimension.value(), MaxDimension.value(), 1) - ) self.director_collection[0, ...] = normal self.director_collection[1, ...] = binormal self.director_collection[2, ...] = tangents - self.external_forces = np.zeros((MaxDimension.value())).reshape( - MaxDimension.value(), 1 - ) - self.external_torques = np.zeros((MaxDimension.value())).reshape( - MaxDimension.value(), 1 - ) + +if TYPE_CHECKING: + from .protocol import RigidBodyProtocol + + _: RigidBodyProtocol = Sphere( + center=np.zeros(3), + base_radius=1.0, + density=1.0, + ) diff --git a/elastica/rod/__init__.py b/elastica/rod/__init__.py index b9285b3c..093ab0eb 100644 --- a/elastica/rod/__init__.py +++ b/elastica/rod/__init__.py @@ -1,12 +1,4 @@ __doc__ = """Rod classes and its data structures """ -from elastica.rod.knot_theory import KnotTheory from elastica.rod.rod_base import RodBase -from elastica.rod.data_structures import ( - _RodSymplecticStepperMixin, - _State, - _DerivativeState, - _KinematicState, - _DynamicState, -) diff --git a/elastica/rod/cosserat_rod.py b/elastica/rod/cosserat_rod.py index 051b3751..03309aef 100644 --- a/elastica/rod/cosserat_rod.py +++ b/elastica/rod/cosserat_rod.py @@ -1,5 +1,11 @@ __doc__ = """ Rod classes and implementation details """ +from typing import TYPE_CHECKING, Any, Optional, Type +from typing_extensions import Self +from elastica.typing import RodType +from .protocol import CosseratRodProtocol + +from numpy.typing import NDArray import numpy as np import functools @@ -12,26 +18,25 @@ _batch_matvec, ) from elastica._rotations import _inv_rotate -from elastica.rod.factory_function import allocate -from elastica.rod.knot_theory import KnotTheory from elastica._calculus import ( quadrature_kernel_for_block_structure, difference_kernel_for_block_structure, _difference, _average, ) -from typing import Optional +from .factory_function import allocate +from .knot_theory import KnotTheory position_difference_kernel = _difference position_average = _average @functools.lru_cache(maxsize=1) -def _get_z_vector(): +def _get_z_vector() -> NDArray[np.float64]: return np.array([0.0, 0.0, 1.0]).reshape(3, -1) -def _compute_sigma_kappa_for_blockstructure(memory_block): +def _compute_sigma_kappa_for_blockstructure(memory_block: RodType) -> None: """ This function is a wrapper to call functions which computes shear stretch, strain and bending twist and strain. @@ -74,113 +79,116 @@ class CosseratRod(RodBase, KnotTheory): ---------- n_elems: int The number of elements of the rod. - position_collection: numpy.ndarray + position_collection: NDArray[np.float64] 2D (dim, n_nodes) array containing data with 'float' type. Array containing node position vectors. - velocity_collection: numpy.ndarray + velocity_collection: NDArray[np.float64] 2D (dim, n_nodes) array containing data with 'float' type. Array containing node velocity vectors. - acceleration_collection: numpy.ndarray + acceleration_collection: NDArray[np.float64] 2D (dim, n_nodes) array containing data with 'float' type. Array containing node acceleration vectors. - omega_collection: numpy.ndarray + omega_collection: NDArray[np.float64] 2D (dim, n_elems) array containing data with 'float' type. Array containing element angular velocity vectors. - alpha_collection: numpy.ndarray + alpha_collection: NDArray[np.float64] 2D (dim, n_elems) array containing data with 'float' type. Array contining element angular acceleration vectors. - director_collection: numpy.ndarray + director_collection: NDArray[np.float64] 3D (dim, dim, n_elems) array containing data with 'float' type. Array containing element director matrices. - rest_lengths: numpy.ndarray + rest_lengths: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod element lengths at rest configuration. - density: numpy.ndarray + density: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod elements densities. - volume: numpy.ndarray + volume: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod element volumes. - mass: numpy.ndarray + mass: NDArray[np.float64] 1D (n_nodes) array containing data with 'float' type. Rod node masses. Note that masses are stored on the nodes, not on elements. - mass_second_moment_of_inertia: numpy.ndarray + mass_second_moment_of_inertia: NDArray[np.float64] 3D (dim, dim, n_elems) array containing data with 'float' type. Rod element mass second moment of interia. - inv_mass_second_moment_of_inertia: numpy.ndarray + inv_mass_second_moment_of_inertia: NDArray[np.float64] 3D (dim, dim, n_elems) array containing data with 'float' type. Rod element inverse mass moment of inertia. - rest_voronoi_lengths: numpy.ndarray + rest_voronoi_lengths: NDArray[np.float64] 1D (n_voronoi) array containing data with 'float' type. Rod lengths on the voronoi domain at the rest configuration. - internal_forces: numpy.ndarray + internal_forces: NDArray[np.float64] 2D (dim, n_nodes) array containing data with 'float' type. Rod node internal forces. Note that internal forces are stored on the node, not on elements. - internal_torques: numpy.ndarray + internal_torques: NDArray[np.float64] 2D (dim, n_elems) array containing data with 'float' type. Rod element internal torques. - external_forces: numpy.ndarray + external_forces: NDArray[np.float64] 2D (dim, n_nodes) array containing data with 'float' type. External forces acting on rod nodes. - external_torques: numpy.ndarray + external_torques: NDArray[np.float64] 2D (dim, n_elems) array containing data with 'float' type. External torques acting on rod elements. - lengths: numpy.ndarray + lengths: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod element lengths. - tangents: numpy.ndarray + tangents: NDArray[np.float64] 2D (dim, n_elems) array containing data with 'float' type. Rod element tangent vectors. - radius: numpy.ndarray + radius: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod element radius. - dilatation: numpy.ndarray + dilatation: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod element dilatation. - voronoi_dilatation: numpy.ndarray + voronoi_dilatation: NDArray[np.float64] 1D (n_voronoi) array containing data with 'float' type. Rod dilatation on voronoi domain. - dilatation_rate: numpy.ndarray + dilatation_rate: NDArray[np.float64] 1D (n_elems) array containing data with 'float' type. Rod element dilatation rates. """ + REQUISITE_MODULES: list[Type] = [] + def __init__( - self, - n_elements, - position, - velocity, - omega, - acceleration, - angular_acceleration, - directors, - radius, - mass_second_moment_of_inertia, - inv_mass_second_moment_of_inertia, - shear_matrix, - bend_matrix, - density, - volume, - mass, - internal_forces, - internal_torques, - external_forces, - external_torques, - lengths, - rest_lengths, - tangents, - dilatation, - dilatation_rate, - voronoi_dilatation, - rest_voronoi_lengths, - sigma, - kappa, - rest_sigma, - rest_kappa, - internal_stress, - internal_couple, - ring_rod_flag, - ): + self: CosseratRodProtocol, + n_elements: int, + position: NDArray[np.float64], + velocity: NDArray[np.float64], + omega: NDArray[np.float64], + acceleration: NDArray[np.float64], + angular_acceleration: NDArray[np.float64], + directors: NDArray[np.float64], + radius: NDArray[np.float64], + mass_second_moment_of_inertia: NDArray[np.float64], + inv_mass_second_moment_of_inertia: NDArray[np.float64], + shear_matrix: NDArray[np.float64], + bend_matrix: NDArray[np.float64], + density_array: NDArray[np.float64], + volume: NDArray[np.float64], + mass: NDArray[np.float64], + internal_forces: NDArray[np.float64], + internal_torques: NDArray[np.float64], + external_forces: NDArray[np.float64], + external_torques: NDArray[np.float64], + lengths: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + tangents: NDArray[np.float64], + dilatation: NDArray[np.float64], + dilatation_rate: NDArray[np.float64], + voronoi_dilatation: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + sigma: NDArray[np.float64], + kappa: NDArray[np.float64], + rest_sigma: NDArray[np.float64], + rest_kappa: NDArray[np.float64], + internal_stress: NDArray[np.float64], + internal_couple: NDArray[np.float64], + ring_rod_flag: bool, + ) -> None: + self.n_nodes = n_elements + 1 if not ring_rod_flag else n_elements self.n_elems = n_elements self.position_collection = position self.velocity_collection = velocity @@ -193,7 +201,7 @@ def __init__( self.inv_mass_second_moment_of_inertia = inv_mass_second_moment_of_inertia self.shear_matrix = shear_matrix self.bend_matrix = bend_matrix - self.density = density + self.density = density_array self.volume = volume self.mass = mass self.internal_forces = internal_forces @@ -242,17 +250,17 @@ def __init__( def straight_rod( cls, n_elements: int, - start: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + start: NDArray[np.float64], + direction: NDArray[np.float64], + normal: NDArray[np.float64], base_length: float, base_radius: float, density: float, *, - nu: Optional[float] = None, + nu: Optional[np.float64] = None, youngs_modulus: float, - **kwargs, - ): + **kwargs: Any, + ) -> Self: """ Cosserat rod constructor for straight-rod geometry. @@ -268,11 +276,11 @@ def straight_rod( n_elements : int Number of element. Must be greater than 3. Generally recommended to start with 40-50, and adjust the resolution. - start : NDArray[3, float] + start : NDArray[np.float64] Starting coordinate in 3D - direction : NDArray[3, float] + direction : NDArray[np.float64] Direction of the rod in 3D - normal : NDArray[3, float] + normal : NDArray[np.float64] Normal vector of the rod in 3D base_length : float Total length of the rod @@ -317,7 +325,7 @@ def straight_rod( inv_mass_second_moment_of_inertia, shear_matrix, bend_matrix, - density, + density_array, volume, mass, internal_forces, @@ -341,10 +349,10 @@ def straight_rod( n_elements, direction, normal, - base_length, - base_radius, - density, - youngs_modulus, + np.float64(base_length), + np.float64(base_radius), + np.float64(density), + np.float64(youngs_modulus), rod_origin_position=start, ring_rod_flag=ring_rod_flag, **kwargs, @@ -363,7 +371,7 @@ def straight_rod( inv_mass_second_moment_of_inertia, shear_matrix, bend_matrix, - density, + density_array, volume, mass, internal_forces, @@ -390,17 +398,17 @@ def straight_rod( def ring_rod( cls, n_elements: int, - ring_center_position: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + ring_center_position: NDArray[np.float64], + direction: NDArray[np.float64], + normal: NDArray[np.float64], base_length: float, base_radius: float, density: float, *, nu: Optional[float] = None, youngs_modulus: float, - **kwargs, - ): + **kwargs: Any, + ) -> Self: """ Cosserat rod constructor for straight-rod geometry. @@ -415,11 +423,11 @@ def ring_rod( ---------- n_elements : int Number of element. Must be greater than 3. Generarally recommended to start with 40-50, and adjust the resolution. - ring_center_position : NDArray[3, float] + ring_center_position : NDArray[np.float64] Center coordinate for ring rod in 3D - direction : NDArray[3, float] + direction : NDArray[np.float64] Direction of the rod in 3D - normal : NDArray[3, float] + normal : NDArray[np.float64] Normal vector of the rod in 3D base_length : float Total length of the rod @@ -427,7 +435,7 @@ def ring_rod( Uniform radius of the rod density : float Density of the rod - nu : float + nu : float | None Damping coefficient for Rayleigh damping youngs_modulus : float Young's modulus @@ -489,10 +497,10 @@ def ring_rod( n_elements, direction, normal, - base_length, - base_radius, - density, - youngs_modulus, + np.float64(base_length), + np.float64(base_radius), + np.float64(density), + np.float64(youngs_modulus), rod_origin_position=ring_center_position, ring_rod_flag=ring_rod_flag, **kwargs, @@ -533,10 +541,12 @@ def ring_rod( internal_couple, ring_rod_flag, ) - rod.REQUISITE_MODULE.append(Constraints) + rod.REQUISITE_MODULES.append(Constraints) return rod - def compute_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques( + self: CosseratRodProtocol, time: np.float64 + ) -> None: """ Compute internal forces and torques. We need to compute internal forces and torques before the acceleration because they are used in interaction. Thus in order to speed up simulation, we will compute internal forces and torques @@ -545,7 +555,7 @@ def compute_internal_forces_and_torques(self, time): Parameters ---------- - time: float + time: np.float64 current time """ @@ -591,13 +601,13 @@ def compute_internal_forces_and_torques(self, time): ) # Interface to time-stepper mixins (Symplectic, Explicit), which calls this method - def update_accelerations(self, time): + def update_accelerations(self: CosseratRodProtocol, time: np.float64) -> None: """ Updates the acceleration variables Parameters ---------- - time: float + time: np.float64 current time """ @@ -613,12 +623,14 @@ def update_accelerations(self, time): self.dilatation, ) - def zeroed_out_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques( + self: CosseratRodProtocol, time: np.float64 + ) -> None: _zeroed_out_external_forces_and_torques( self.external_forces, self.external_torques ) - def compute_translational_energy(self): + def compute_translational_energy(self: CosseratRodProtocol) -> NDArray[np.float64]: """ Compute total translational energy of the rod at the instance. """ @@ -632,7 +644,7 @@ def compute_translational_energy(self): ).sum() ) - def compute_rotational_energy(self): + def compute_rotational_energy(self: CosseratRodProtocol) -> NDArray[np.float64]: """ Compute total rotational energy of the rod at the instance. """ @@ -642,7 +654,9 @@ def compute_rotational_energy(self): ) return 0.5 * np.einsum("ik,ik->k", self.omega_collection, J_omega_upon_e).sum() - def compute_velocity_center_of_mass(self): + def compute_velocity_center_of_mass( + self: CosseratRodProtocol, + ) -> NDArray[np.float64]: """ Compute velocity center of mass of the rod at the instance. """ @@ -651,7 +665,9 @@ def compute_velocity_center_of_mass(self): return sum_mass_times_velocity / self.mass.sum() - def compute_position_center_of_mass(self): + def compute_position_center_of_mass( + self: CosseratRodProtocol, + ) -> NDArray[np.float64]: """ Compute position center of mass of the rod at the instance. """ @@ -660,7 +676,7 @@ def compute_position_center_of_mass(self): return sum_mass_times_position / self.mass.sum() - def compute_bending_energy(self): + def compute_bending_energy(self: CosseratRodProtocol) -> NDArray[np.float64]: """ Compute total bending energy of the rod at the instance. """ @@ -676,7 +692,7 @@ def compute_bending_energy(self): ).sum() ) - def compute_shear_energy(self): + def compute_shear_energy(self: CosseratRodProtocol) -> NDArray[np.float64]: """ Compute total shear energy of the rod at the instance. """ @@ -693,10 +709,14 @@ def compute_shear_energy(self): # Below is the numba-implementation of Cosserat Rod equations. They don't need to be visible by users. -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_geometry_from_state( - position_collection, volume, lengths, tangents, radius -): + position_collection: NDArray[np.float64], + volume: NDArray[np.float64], + lengths: NDArray[np.float64], + tangents: NDArray[np.float64], + radius: NDArray[np.float64], +) -> None: """ Update given . """ @@ -717,18 +737,18 @@ def _compute_geometry_from_state( radius[k] = np.sqrt(volume[k] / lengths[k] / np.pi) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_all_dilatations( - position_collection, - volume, - lengths, - tangents, - radius, - dilatation, - rest_lengths, - rest_voronoi_lengths, - voronoi_dilatation, -): + position_collection: NDArray[np.float64], + volume: NDArray[np.float64], + lengths: NDArray[np.float64], + tangents: NDArray[np.float64], + radius: NDArray[np.float64], + dilatation: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + voronoi_dilatation: NDArray[np.float64], +) -> None: """ Update """ @@ -747,10 +767,14 @@ def _compute_all_dilatations( voronoi_dilatation[k] = voronoi_lengths[k] / rest_voronoi_lengths[k] -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_dilatation_rate( - position_collection, velocity_collection, lengths, rest_lengths, dilatation_rate -): + position_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + lengths: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + dilatation_rate: NDArray[np.float64], +) -> None: """ Update dilatation_rate given position, velocity, length, and rest_length """ @@ -774,20 +798,20 @@ def _compute_dilatation_rate( ) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_shear_stretch_strains( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, -): + position_collection: NDArray[np.float64], + volume: NDArray[np.float64], + lengths: NDArray[np.float64], + tangents: NDArray[np.float64], + radius: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + dilatation: NDArray[np.float64], + voronoi_dilatation: NDArray[np.float64], + director_collection: NDArray[np.float64], + sigma: NDArray[np.float64], +) -> None: """ Update given . """ @@ -809,23 +833,23 @@ def _compute_shear_stretch_strains( sigma[:] = dilatation * _batch_matvec(director_collection, tangents) - z_vector -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_internal_shear_stretch_stresses_from_model( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, - rest_sigma, - shear_matrix, - internal_stress, -): + position_collection: NDArray[np.float64], + volume: NDArray[np.float64], + lengths: NDArray[np.float64], + tangents: NDArray[np.float64], + radius: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + dilatation: NDArray[np.float64], + voronoi_dilatation: NDArray[np.float64], + director_collection: NDArray[np.float64], + sigma: NDArray[np.float64], + rest_sigma: NDArray[np.float64], + shear_matrix: NDArray[np.float64], + internal_stress: NDArray[np.float64], +) -> None: """ Update given . @@ -849,8 +873,12 @@ def _compute_internal_shear_stretch_stresses_from_model( internal_stress[:] = _batch_matvec(shear_matrix, sigma - rest_sigma) -@numba.njit(cache=True) -def _compute_bending_twist_strains(director_collection, rest_voronoi_lengths, kappa): +@numba.njit(cache=True) # type: ignore +def _compute_bending_twist_strains( + director_collection: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + kappa: NDArray[np.float64], +) -> None: """ Update given . """ @@ -862,15 +890,15 @@ def _compute_bending_twist_strains(director_collection, rest_voronoi_lengths, ka kappa[2, k] = temp[2, k] / rest_voronoi_lengths[k] -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_internal_bending_twist_stresses_from_model( - director_collection, - rest_voronoi_lengths, - internal_couple, - bend_matrix, - kappa, - rest_kappa, -): + director_collection: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + internal_couple: NDArray[np.float64], + bend_matrix: NDArray[np.float64], + kappa: NDArray[np.float64], + rest_kappa: NDArray[np.float64], +) -> None: """ Upate given . @@ -891,25 +919,25 @@ def _compute_internal_bending_twist_stresses_from_model( internal_couple[:] = _batch_matvec(bend_matrix, temp) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_internal_forces( - position_collection, - volume, - lengths, - tangents, - radius, - rest_lengths, - rest_voronoi_lengths, - dilatation, - voronoi_dilatation, - director_collection, - sigma, - rest_sigma, - shear_matrix, - internal_stress, - internal_forces, - ghost_elems_idx, -): + position_collection: NDArray[np.float64], + volume: NDArray[np.float64], + lengths: NDArray[np.float64], + tangents: NDArray[np.float64], + radius: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + dilatation: NDArray[np.float64], + voronoi_dilatation: NDArray[np.float64], + director_collection: NDArray[np.float64], + sigma: NDArray[np.float64], + rest_sigma: NDArray[np.float64], + shear_matrix: NDArray[np.float64], + internal_stress: NDArray[np.float64], + internal_forces: NDArray[np.float64], + ghost_elems_idx: NDArray[np.float64], +) -> None: """ Update given . """ @@ -952,28 +980,28 @@ def _compute_internal_forces( ) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _compute_internal_torques( - position_collection, - velocity_collection, - tangents, - lengths, - rest_lengths, - director_collection, - rest_voronoi_lengths, - bend_matrix, - rest_kappa, - kappa, - voronoi_dilatation, - mass_second_moment_of_inertia, - omega_collection, - internal_stress, - internal_couple, - dilatation, - dilatation_rate, - internal_torques, - ghost_voronoi_idx, -): + position_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + tangents: NDArray[np.float64], + lengths: NDArray[np.float64], + rest_lengths: NDArray[np.float64], + director_collection: NDArray[np.float64], + rest_voronoi_lengths: NDArray[np.float64], + bend_matrix: NDArray[np.float64], + rest_kappa: NDArray[np.float64], + kappa: NDArray[np.float64], + voronoi_dilatation: NDArray[np.float64], + mass_second_moment_of_inertia: NDArray[np.float64], + omega_collection: NDArray[np.float64], + internal_stress: NDArray[np.float64], + internal_couple: NDArray[np.float64], + dilatation: NDArray[np.float64], + dilatation_rate: NDArray[np.float64], + internal_torques: NDArray[np.float64], + ghost_voronoi_idx: NDArray[np.int32], +) -> None: """ Update . """ @@ -1041,18 +1069,18 @@ def _compute_internal_torques( ) -@numba.njit(cache=True) +@numba.njit(cache=True) # type: ignore def _update_accelerations( - acceleration_collection, - internal_forces, - external_forces, - mass, - alpha_collection, - inv_mass_second_moment_of_inertia, - internal_torques, - external_torques, - dilatation, -): + acceleration_collection: NDArray[np.float64], + internal_forces: NDArray[np.float64], + external_forces: NDArray[np.float64], + mass: NDArray[np.float64], + alpha_collection: NDArray[np.float64], + inv_mass_second_moment_of_inertia: NDArray[np.float64], + internal_torques: NDArray[np.float64], + external_torques: NDArray[np.float64], + dilatation: NDArray[np.float64], +) -> None: """ Update given . """ @@ -1076,8 +1104,10 @@ def _update_accelerations( ) * dilatation[k] -@numba.njit(cache=True) -def _zeroed_out_external_forces_and_torques(external_forces, external_torques): +@numba.njit(cache=True) # type: ignore +def _zeroed_out_external_forces_and_torques( + external_forces: NDArray[np.float64], external_torques: NDArray[np.float64] +) -> None: """ This function is to zeroed out external forces and torques. @@ -1097,3 +1127,26 @@ def _zeroed_out_external_forces_and_torques(external_forces, external_torques): for i in range(3): for k in range(n_elems): external_torques[i, k] = 0.0 + + +if TYPE_CHECKING: + _: CosseratRodProtocol = CosseratRod.straight_rod( + 3, + np.zeros(3), + np.array([0, 1, 0]), + np.array([0, 0, 1]), + 1.0, + 0.1, + 1.0, + youngs_modulus=1.0, + ) + _: CosseratRodProtocol = CosseratRod.ring_rod( # type: ignore[no-redef] + 3, + np.zeros(3), + np.array([0, 1, 0]), + np.array([0, 0, 1]), + 1.0, + 0.1, + 1.0, + youngs_modulus=1.0, + ) diff --git a/elastica/rod/data_structures.py b/elastica/rod/data_structures.py index 2a9466eb..f7b248bc 100644 --- a/elastica/rod/data_structures.py +++ b/elastica/rod/data_structures.py @@ -1,51 +1,56 @@ __doc__ = "Data structure wrapper for rod components" +from typing import TYPE_CHECKING, Optional +from typing_extensions import Self import numpy as np +from numpy.typing import NDArray from numba import njit from elastica._rotations import _get_rotation_matrix, _rotate from elastica._linalg import _batch_matmul +if TYPE_CHECKING: + from elastica.systems.protocol import SymplecticSystemProtocol +else: + SymplecticSystemProtocol = "SymplecticSystemProtocol" # FIXME : Explicit Stepper doesn't work as States lose the # views they initially had when working with a timestepper. -""" -class _RodExplicitStepperMixin: - def __init__(self): - ( - self.state, - self.__deriv_state, - self.position_collection, - self.director_collection, - self.velocity_collection, - self.omega_collection, - self.acceleration_collection, - self.alpha_collection, # angular acceleration - ) = _bootstrap_from_data( - "explicit", self.n_elems, self._vector_states, self._matrix_states - ) - - # def __setattr__(self, name, value): - # np.copy(self.__dict__[name], value) - - def __call__(self, time, *args, **kwargs): - self.update_accelerations(time) # Internal, external - - # print("KRC", self.state.kinematic_rate_collection) - # print("DEr", self.__deriv_state.rate_collection) - if np.shares_memory( - self.state.kinematic_rate_collection, - self.velocity_collection - # self.__deriv_state.rate_collection - ): - print("Shares memory") - else: - print("Explicit states does not share memory") - return self.__deriv_state -""" +# class _RodExplicitStepperMixin: +# def __init__(self) -> None: +# ( +# self.state, +# self.__deriv_state, +# self.position_collection, +# self.director_collection, +# self.velocity_collection, +# self.omega_collection, +# self.acceleration_collection, +# self.alpha_collection, # angular acceleration +# ) = _bootstrap_from_data( +# "explicit", self.n_elems, self._vector_states, self._matrix_states +# ) +# +# # def __setattr__(self, name, value): +# # np.copy(self.__dict__[name], value) +# +# def __call__(self, time, *args, **kwargs): +# self.update_accelerations(time) # Internal, external +# +# # print("KRC", self.state.kinematic_rate_collection) +# # print("DEr", self.__deriv_state.rate_collection) +# if np.shares_memory( +# self.state.kinematic_rate_collection, +# self.velocity_collection +# # self.__deriv_state.rate_collection +# ): +# print("Shares memory") +# else: +# print("Explicit states does not share memory") +# return self.__deriv_state class _RodSymplecticStepperMixin: - def __init__(self): + def __init__(self: SymplecticSystemProtocol) -> None: self.kinematic_states = _KinematicState( self.position_collection, self.director_collection ) @@ -62,18 +67,32 @@ def __init__(self): # is another function self.kinematic_rates = self.dynamic_states.kinematic_rates - def update_internal_forces_and_torques(self, time, *args, **kwargs): - self.compute_internal_forces_and_torques(time) - - def dynamic_rates(self, time, prefac, *args, **kwargs): + def dynamic_rates( + self: SymplecticSystemProtocol, + time: np.float64, + prefac: np.float64, + ) -> NDArray[np.float64]: self.update_accelerations(time) - return self.dynamic_states.dynamic_rates(time, prefac, *args, **kwargs) - - def reset_external_forces_and_torques(self, time, *args, **kwargs): - self.zeroed_out_external_forces_and_torques(time) - - -def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_states): + return self.dynamic_states.dynamic_rates(time, prefac) + + +def _bootstrap_from_data( + stepper_type: str, + n_elems: int, + vector_states: NDArray[np.float64], + matrix_states: NDArray[np.float64], +) -> Optional[ + tuple[ + "_State", + "_DerivativeState", + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + ] +]: """Returns states wrapping numpy arrays based on the time-stepping algorithm Convenience method that takes in rod internal (raw np.ndarray) data, create views @@ -106,7 +125,7 @@ def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_ position = np.ndarray.view(vector_states[..., :n_nodes]) directors = np.ndarray.view(matrix_states) v_w_dvdt_dwdt = np.ndarray.view(vector_states[..., n_nodes:]) - output = () + output: tuple = () if stepper_type == "explicit": v_w_states = np.ndarray.view(vector_states[..., n_nodes : 3 * n_nodes - 1]) output += ( @@ -121,7 +140,7 @@ def _bootstrap_from_data(stepper_type: str, n_elems: int, vector_states, matrix_ # ) raise NotImplementedError else: - return + return None n_velocity_end = n_nodes + n_nodes velocity = np.ndarray.view(vector_states[..., n_nodes:n_velocity_end]) @@ -156,10 +175,10 @@ class _State: def __init__( self, n_elems: int, - position_collection_view, - director_collection_view, - kinematic_rate_collection_view, - ): + position_collection_view: NDArray[np.float64], + director_collection_view: NDArray[np.float64], + kinematic_rate_collection_view: NDArray[np.float64], + ) -> None: """ Parameters ---------- @@ -175,7 +194,7 @@ def __init__( self.director_collection = director_collection_view self.kinematic_rate_collection = kinematic_rate_collection_view - def __iadd__(self, scaled_deriv_array): + def __iadd__(self, scaled_deriv_array: NDArray[np.float64]) -> Self: """overloaded += operator The add for directors is customized to reflect Rodrigues' rotation @@ -244,7 +263,7 @@ def __iadd__(self, scaled_deriv_array): ] return self - def __add__(self, scaled_derivative_state): + def __add__(self, scaled_derivative_state: NDArray[np.float64]) -> "_State": """overloaded + operator, useful in state.k1 = state + dt * deriv_state The add for directors is customized to reflect Rodrigues' rotation @@ -298,7 +317,9 @@ class _DerivativeState: /multiplication used. """ - def __init__(self, _unused_n_elems: int, rate_collection_view): + def __init__( + self, _unused_n_elems: int, rate_collection_view: NDArray[np.float64] + ) -> None: """ Parameters ---------- @@ -309,7 +330,7 @@ def __init__(self, _unused_n_elems: int, rate_collection_view): super(_DerivativeState, self).__init__() self.rate_collection = rate_collection_view - def __rmul__(self, scalar): + def __rmul__(self, scalar: np.float64) -> NDArray[np.float64]: # type: ignore """overloaded scalar * self, Parameters @@ -357,7 +378,7 @@ def __rmul__(self, scalar): """ return scalar * self.rate_collection - def __mul__(self, scalar): + def __mul__(self, scalar: np.float64) -> NDArray[np.float64]: """overloaded self * scalar TODO Check if this pattern (forwarding to __mul__) has @@ -390,7 +411,11 @@ class _KinematicState: only these methods are provided. """ - def __init__(self, position_collection_view, director_collection_view): + def __init__( + self, + position_collection_view: NDArray[np.float64], + director_collection_view: NDArray[np.float64], + ) -> None: """ Parameters ---------- @@ -403,15 +428,15 @@ def __init__(self, position_collection_view, director_collection_view): self.director_collection = director_collection_view -@njit(cache=True) +@njit(cache=True) # type: ignore def overload_operator_kinematic_numba( - n_nodes, - prefac, - position_collection, - director_collection, - velocity_collection, - omega_collection, -): + n_nodes: int, + prefac: np.float64, + position_collection: NDArray[np.float64], + director_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + omega_collection: NDArray[np.float64], +) -> None: """overloaded += operator The add for directors is customized to reflect Rodrigues' rotation @@ -451,11 +476,11 @@ class _DynamicState: def __init__( self, - v_w_collection, - dvdt_dwdt_collection, - velocity_collection, - omega_collection, - ): + v_w_collection: NDArray[np.float64], + dvdt_dwdt_collection: NDArray[np.float64], + velocity_collection: NDArray[np.float64], + omega_collection: NDArray[np.float64], + ) -> None: """ Parameters ---------- @@ -472,7 +497,9 @@ def __init__( self.velocity_collection = velocity_collection self.omega_collection = omega_collection - def kinematic_rates(self, time, *args, **kwargs): + def kinematic_rates( + self, time: np.float64, prefac: np.float64 + ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """Yields kinematic rates to interact with _KinematicState Returns @@ -488,7 +515,9 @@ def kinematic_rates(self, time, *args, **kwargs): # Comes from kin_state -> (x,Q) += dt * (v,w) <- First part of dyn_state return self.velocity_collection, self.omega_collection - def dynamic_rates(self, time, prefac, *args, **kwargs): + def dynamic_rates( + self, time: np.float64, prefac: np.float64 + ) -> NDArray[np.float64]: """Yields dynamic rates to add to with _DynamicState Returns ------- @@ -502,8 +531,11 @@ def dynamic_rates(self, time, prefac, *args, **kwargs): return prefac * self.dvdt_dwdt_collection -@njit(cache=True) -def overload_operator_dynamic_numba(rate_collection, scaled_second_deriv_array): +@njit(cache=True) # type: ignore +def overload_operator_dynamic_numba( + rate_collection: NDArray[np.float64], + scaled_second_deriv_array: NDArray[np.float64], +) -> None: """overloaded += operator, updating dynamic_rates Parameters ---------- diff --git a/elastica/rod/factory_function.py b/elastica/rod/factory_function.py index 6eca6b5a..2765e872 100644 --- a/elastica/rod/factory_function.py +++ b/elastica/rod/factory_function.py @@ -1,30 +1,64 @@ __doc__ = """ Factory function to allocate variables for Cosserat Rod""" -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import logging import numpy as np from numpy.testing import assert_allclose +from numpy.typing import NDArray from elastica.utils import MaxDimension, Tolerance from elastica._linalg import _batch_cross, _batch_norm, _batch_dot def allocate( - n_elements, - direction, - normal, - base_length, - base_radius, - density, - youngs_modulus: float, + n_elements: int, + direction: NDArray[np.float64], + normal: NDArray[np.float64], + base_length: np.float64, + base_radius: np.float64, + density: np.float64, + youngs_modulus: np.float64, *, rod_origin_position: np.ndarray, ring_rod_flag: bool, - shear_modulus: Optional[float] = None, + shear_modulus: Optional[np.float64] = None, position: Optional[np.ndarray] = None, directors: Optional[np.ndarray] = None, rest_sigma: Optional[np.ndarray] = None, rest_kappa: Optional[np.ndarray] = None, - **kwargs, -): + **kwargs: Any, +) -> tuple[ + int, + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], + NDArray[np.float64], +]: log = logging.getLogger() if "poisson_ratio" in kwargs: @@ -335,14 +369,16 @@ def allocate( """ -def _assert_dim(vector, max_dim: int, name: str): +def _assert_dim(vector: np.ndarray, max_dim: int, name: str) -> None: assert vector.ndim < max_dim, ( f"Input {name} dimension is not correct {vector.shape}" + f" It should be maximum {max_dim}D vector or single floating number." ) -def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): +def _assert_shape( + array: np.ndarray, expected_shape: Tuple[int, ...], name: str +) -> None: assert array.shape == expected_shape, ( f"Given {name} shape is not correct, it should be " + str(expected_shape) @@ -351,7 +387,9 @@ def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): ) -def _position_validity_checker(position, start, n_elements): +def _position_validity_checker( + position: NDArray[np.float64], start: NDArray[np.float64], n_elements: int +) -> None: """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements + 1), "position") @@ -367,7 +405,9 @@ def _position_validity_checker(position, start, n_elements): ) -def _directors_validity_checker(directors, tangents, n_elements): +def _directors_validity_checker( + directors: NDArray[np.float64], tangents: NDArray[np.float64], n_elements: int +) -> None: """Checker on user-defined directors validity""" _assert_shape( directors, (MaxDimension.value(), MaxDimension.value(), n_elements), "directors" @@ -413,7 +453,11 @@ def _directors_validity_checker(directors, tangents, n_elements): ) -def _position_validity_checker_ring_rod(position, ring_center_position, n_elements): +def _position_validity_checker_ring_rod( + position: NDArray[np.float64], + ring_center_position: NDArray[np.float64], + n_elements: int, +) -> None: """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements), "position") diff --git a/elastica/rod/knot_theory.py b/elastica/rod/knot_theory.py index 23d2b009..b5f00184 100644 --- a/elastica/rod/knot_theory.py +++ b/elastica/rod/knot_theory.py @@ -10,32 +10,14 @@ The details discussion is included in `N Charles et. al. PRL (2019) `_. """ -from typing import Protocol, Union - -from numba import njit import numpy as np +from numpy.typing import NDArray +from numba import njit from elastica.rod.rod_base import RodBase from elastica._linalg import _batch_norm, _batch_dot, _batch_cross - -class KnotTheoryCompatibleProtocol(Protocol): - """KnotTheoryCompatibleProtocol - - Required properties to use KnotTheory mixin - """ - - @property - def position_collection(self) -> np.ndarray: ... - - @property - def director_collection(self) -> np.ndarray: ... - - @property - def radius(self) -> np.ndarray: ... - - @property - def base_length(self) -> np.ndarray: ... +from .protocol import CosseratRodProtocol class KnotTheory: @@ -46,7 +28,7 @@ class KnotTheory: KnotTheory can be mixed with any rod-class based on RodBase:: class MyRod(RodBase, KnotTheory): - def __init__(self): + def __init__(self) -> None: super().__init__() rod = MyRod(...) @@ -76,9 +58,7 @@ def __init__(self): """ - MIXIN_PROTOCOL = Union[RodBase, KnotTheoryCompatibleProtocol] - - def compute_twist(self: MIXIN_PROTOCOL): + def compute_twist(self: CosseratRodProtocol) -> NDArray[np.float64]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. """ @@ -89,10 +69,10 @@ def compute_twist(self: MIXIN_PROTOCOL): return total_twist[0] def compute_writhe( - self: MIXIN_PROTOCOL, + self: CosseratRodProtocol, type_of_additional_segment: str = "next_tangent", alpha: float = 1.0, - ): + ) -> NDArray[np.float64]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -112,10 +92,10 @@ def compute_writhe( )[0] def compute_link( - self: MIXIN_PROTOCOL, + self: CosseratRodProtocol, type_of_additional_segment: str = "next_tangent", alpha: float = 1.0, - ): + ) -> NDArray[np.float64]: """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -138,7 +118,9 @@ def compute_link( )[0] -def compute_twist(center_line, normal_collection): +def compute_twist( + center_line: NDArray[np.float64], normal_collection: NDArray[np.float64] +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """ Compute the twist of a rod, using center_line and normal collection. @@ -188,8 +170,10 @@ def compute_twist(center_line, normal_collection): return total_twist, local_twist -@njit(cache=True) -def _compute_twist(center_line, normal_collection): +@njit(cache=True) # type: ignore +def _compute_twist( + center_line: NDArray[np.float64], normal_collection: NDArray[np.float64] +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """ Parameters ---------- @@ -264,7 +248,11 @@ def _compute_twist(center_line, normal_collection): return total_twist, local_twist -def compute_writhe(center_line, segment_length, type_of_additional_segment): +def compute_writhe( + center_line: NDArray[np.float64], + segment_length: np.float64, + type_of_additional_segment: str, +) -> NDArray[np.float64]: """ This function computes the total writhe history of a rod. @@ -313,8 +301,8 @@ def compute_writhe(center_line, segment_length, type_of_additional_segment): return total_writhe -@njit(cache=True) -def _compute_writhe(center_line): +@njit(cache=True) # type: ignore +def _compute_writhe(center_line: NDArray[np.float64]) -> NDArray[np.float64]: """ Parameters ---------- @@ -386,12 +374,12 @@ def _compute_writhe(center_line): def compute_link( - center_line: np.ndarray, - normal_collection: np.ndarray, - radius: np.ndarray, - segment_length: float, + center_line: NDArray[np.float64], + normal_collection: NDArray[np.float64], + radius: NDArray[np.float64], + segment_length: np.float64, type_of_additional_segment: str, -): +) -> NDArray[np.float64]: """ This function computes the total link history of a rod. @@ -469,8 +457,12 @@ def compute_link( return total_link -@njit(cache=True) -def _compute_auxiliary_line(center_line, normal_collection, radius): +@njit(cache=True) # type: ignore +def _compute_auxiliary_line( + center_line: NDArray[np.float64], + normal_collection: NDArray[np.float64], + radius: NDArray[np.float64], +) -> NDArray[np.float64]: """ This function computes the auxiliary line using rod center line and normal collection. @@ -524,8 +516,10 @@ def _compute_auxiliary_line(center_line, normal_collection, radius): return auxiliary_line -@njit(cache=True) -def _compute_link(center_line, auxiliary_line): +@njit(cache=True) # type: ignore +def _compute_link( + center_line: NDArray[np.float64], auxiliary_line: NDArray[np.float64] +) -> NDArray[np.float64]: """ Parameters @@ -602,10 +596,13 @@ def _compute_link(center_line, auxiliary_line): return total_link -@njit(cache=True) +@njit(cache=True) # type: ignore def _compute_auxiliary_line_added_segments( - beginning_direction, end_direction, auxiliary_line, segment_length -): + beginning_direction: NDArray[np.float64], + end_direction: NDArray[np.float64], + auxiliary_line: NDArray[np.float64], + segment_length: np.float64, +) -> NDArray[np.float64]: """ This code is for computing position of added segments to the auxiliary line. @@ -645,10 +642,12 @@ def _compute_auxiliary_line_added_segments( return new_auxiliary_line -@njit(cache=True) +@njit(cache=True) # type: ignore def _compute_additional_segment( - center_line, segment_length, type_of_additional_segment -): + center_line: NDArray[np.float64], + segment_length: np.float64, + type_of_additional_segment: str, +) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]: """ This function adds two points at the end of center line. Distance from the center line is given by segment_length. Direction from center line to the new point locations can be computed using 3 methods, which can be selected by diff --git a/elastica/rod/protocol.py b/elastica/rod/protocol.py new file mode 100644 index 00000000..c28de3bb --- /dev/null +++ b/elastica/rod/protocol.py @@ -0,0 +1,49 @@ +from typing import Protocol + +import numpy as np +from numpy.typing import NDArray + +from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol + + +class _RodEnergy(Protocol): + def compute_bending_energy(self) -> NDArray[np.float64]: ... + + def compute_shear_energy(self) -> NDArray[np.float64]: ... + + +class CosseratRodProtocol( + SystemProtocol, SlenderBodyGeometryProtocol, _RodEnergy, Protocol +): + + mass: NDArray[np.float64] + volume: NDArray[np.float64] + radius: NDArray[np.float64] + tangents: NDArray[np.float64] + lengths: NDArray[np.float64] + rest_lengths: NDArray[np.float64] + rest_voronoi_lengths: NDArray[np.float64] + kappa: NDArray[np.float64] + sigma: NDArray[np.float64] + rest_kappa: NDArray[np.float64] + rest_sigma: NDArray[np.float64] + + internal_stress: NDArray[np.float64] + internal_couple: NDArray[np.float64] + dilatation: NDArray[np.float64] + dilatation_rate: NDArray[np.float64] + voronoi_dilatation: NDArray[np.float64] + + bend_matrix: NDArray[np.float64] + shear_matrix: NDArray[np.float64] + + mass_second_moment_of_inertia: NDArray[np.float64] + inv_mass_second_moment_of_inertia: NDArray[np.float64] + + ghost_voronoi_idx: NDArray[np.int32] + ghost_elems_idx: NDArray[np.int32] + + ring_rod_flag: bool + periodic_boundary_nodes_idx: NDArray[np.int32] + periodic_boundary_elems_idx: NDArray[np.int32] + periodic_boundary_voronoi_idx: NDArray[np.int32] diff --git a/elastica/rod/rod_base.py b/elastica/rod/rod_base.py index b769555f..3bd846b7 100644 --- a/elastica/rod/rod_base.py +++ b/elastica/rod/rod_base.py @@ -1,5 +1,9 @@ __doc__ = """Base class for rods""" +from typing import Type +import numpy as np +from numpy.typing import NDArray + class RodBase: """ @@ -11,16 +15,24 @@ class RodBase: """ - REQUISITE_MODULES = [] + REQUISITE_MODULES: list[Type] = [] - def __init__(self): + def __init__(self) -> None: """ RodBase does not take any arguments. """ - pass - # self.position_collection = NotImplemented - # self.omega_collection = NotImplemented - # self.acceleration_collection = NotImplemented - # self.alpha_collection = NotImplemented - # self.external_forces = NotImplemented - # self.external_torques = NotImplemented + self.position_collection: NDArray[np.float64] + self.velocity_collection: NDArray[np.float64] + self.acceleration_collection: NDArray[np.float64] + self.director_collection: NDArray[np.float64] + self.omega_collection: NDArray[np.float64] + self.alpha_collection: NDArray[np.float64] + self.external_forces: NDArray[np.float64] + self.external_torques: NDArray[np.float64] + + self.ghost_voronoi_idx: NDArray[np.int32] + self.ghost_elems_idx: NDArray[np.int32] + + self.periodic_boundary_nodes_idx: NDArray[np.int32] + self.periodic_boundary_elems_idx: NDArray[np.int32] + self.periodic_boundary_voronoi_idx: NDArray[np.int32] diff --git a/elastica/surface/plane.py b/elastica/surface/plane.py index a68a5be6..b12a3fea 100644 --- a/elastica/surface/plane.py +++ b/elastica/surface/plane.py @@ -2,12 +2,14 @@ from elastica.surface.surface_base import SurfaceBase import numpy as np -from numpy.testing import assert_allclose +from numpy.typing import NDArray from elastica.utils import Tolerance class Plane(SurfaceBase): - def __init__(self, plane_origin: np.ndarray, plane_normal: np.ndarray): + def __init__( + self, plane_origin: NDArray[np.float64], plane_normal: NDArray[np.float64] + ): """ Plane surface initializer. @@ -21,11 +23,10 @@ def __init__(self, plane_origin: np.ndarray, plane_normal: np.ndarray): Expect (3,1)-shaped array. """ - assert_allclose( + assert np.allclose( np.linalg.norm(plane_normal), 1, - atol=Tolerance.atol(), - err_msg="plane normal is not a unit vector", - ) + atol=float(Tolerance.atol()), + ), "plane normal is not a unit vector" self.normal = np.asarray(plane_normal).reshape(3) self.origin = np.asarray(plane_origin).reshape(3, 1) diff --git a/elastica/surface/surface_base.py b/elastica/surface/surface_base.py index 31549612..37812ea7 100644 --- a/elastica/surface/surface_base.py +++ b/elastica/surface/surface_base.py @@ -1,4 +1,8 @@ __doc__ = """Base class for surfaces""" +from typing import TYPE_CHECKING, Type + +import numpy as np +from numpy.typing import NDArray class SurfaceBase: @@ -11,10 +15,17 @@ class SurfaceBase: """ - REQUISITE_MODULES = [] + REQUISITE_MODULES: list[Type] = [] - def __init__(self): + def __init__(self) -> None: """ SurfaceBase does not take any arguments. """ - pass + self.normal: NDArray[np.float64] # (3,) + self.origin: NDArray[np.float64] # (3, 1) + + +if TYPE_CHECKING: + from elastica.systems.protocol import StaticSystemProtocol + + _: StaticSystemProtocol = SurfaceBase() diff --git a/elastica/systems/__init__.py b/elastica/systems/__init__.py index 33a39b83..b61b905e 100644 --- a/elastica/systems/__init__.py +++ b/elastica/systems/__init__.py @@ -1,4 +1,9 @@ -def is_system_a_collection(system): +from typing import Type + +from elastica.typing import SystemType, SystemCollectionType + + +def is_system_a_collection(system: "SystemType | SystemCollectionType") -> bool: # Check if system is a "collection" of smaller systems # by checking for the [] method """ @@ -33,82 +38,4 @@ def is_system_a_collection(system): from elastica.modules import BaseSystemCollection __sys_get_item = getattr(system, "__getitem__", None) - return issubclass(system.__class__, BaseSystemCollection) or callable( - __sys_get_item - ) - - -def make_memory_for_explicit_stepper(stepper, system): - # TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.) - - from elastica.timestepper.explicit_steppers import ( - RungeKutta4, - EulerForward, - ) - - is_this_system_a_collection = is_system_a_collection(system) - - if RungeKutta4 in stepper.__class__.mro(): - # Bad way of doing it, introduces tight coupling - # this should rather be taken from the class itself - class MemoryRungeKutta4: - def __init__(self): - super(MemoryRungeKutta4, self).__init__() - self.initial_state = None - self.k_1 = None - self.k_2 = None - self.k_3 = None - self.k_4 = None - - memory_cls = MemoryRungeKutta4 - elif EulerForward in stepper.__class__.mro(): - memory_cls = NotImplementedError - else: - # TODO Memory allocation for other integrators - raise NotImplementedError("Making memory for other types not supported") - - return ( - MemoryCollection(memory_cls(), len(system)) - if is_this_system_a_collection - else memory_cls() - ) - - -class MemoryCollection: - """Slots of memories for timestepper in a cohesive unit. - - A `MemoryCollection` object is meant to be used in conjunction - with a `SystemCollection`, where each independent `System` to - be integrated has its own `Memory`. - - Example - ------- - - A RK4 integrator needs to store k_1, k_2, k_3, k_4 (intermediate - results from four stages) for each `System`. The restriction for - having a memory slot arises because the `Systems` are usually - not independent of one another and may need communication after - every stage. - """ - - def __init__(self, memory, n_memory_slots): - super(MemoryCollection, self).__init__() - - self.__memories = [None] * n_memory_slots - - from copy import copy - - for i_slot in range(n_memory_slots - 1): - self.__memories[i_slot] = copy(memory) - - # Save final copy - self.__memories[-1] = memory - - def __getitem__(self, idx): - return self.__memories[idx] - - def __len__(self): - return len(self.__memories) - - def __iter__(self): - return self.__memories.__iter__() + return isinstance(system, BaseSystemCollection) or callable(__sys_get_item) diff --git a/elastica/systems/analytical.py b/elastica/systems/analytical.py index 1b64a570..60fe712b 100644 --- a/elastica/systems/analytical.py +++ b/elastica/systems/analytical.py @@ -3,6 +3,7 @@ import numpy as np from elastica._rotations import _rotate from elastica.rod.data_structures import _RodSymplecticStepperMixin +from elastica.rod.rod_base import RodBase class BaseStatefulSystem: @@ -161,10 +162,10 @@ def energy(st): current_energy = energy(self._state) return current_energy, anal_energy - def update_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques(self, time): pass - def reset_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques(self, time): pass @@ -299,7 +300,12 @@ class CollectiveSystem: def __init__(self): self._memory_blocks = [] - self.systems = self._memory_blocks + + def systems(self): + return self._memory_blocks + + def block_systems(self): + return self._memory_blocks def __getitem__(self, idx): return self._memory_blocks[idx] @@ -343,8 +349,8 @@ def __init__(self): super( ScalarExponentialDampedHarmonicOscillatorCollectiveSystem, self ).__init__() - self.systems.append(ScalarExponentialDecaySystem()) - self.systems.append(DampedSimpleHarmonicOscillatorSystem()) + self._memory_blocks.append(ScalarExponentialDecaySystem()) + self._memory_blocks.append(DampedSimpleHarmonicOscillatorSystem()) def make_simple_system_with_positions_directors( @@ -355,8 +361,9 @@ def make_simple_system_with_positions_directors( ) -class SimpleSystemWithPositionsDirectors(_RodSymplecticStepperMixin): +class SimpleSystemWithPositionsDirectors(_RodSymplecticStepperMixin, RodBase): def __init__(self, start_position, end_position, start_director): + self.ring_rod_flag = False # TODO: self.a = 0.5 self.b = 1 self.c = 2 diff --git a/elastica/systems/memory.py b/elastica/systems/memory.py new file mode 100644 index 00000000..c669be9b --- /dev/null +++ b/elastica/systems/memory.py @@ -0,0 +1,84 @@ +from typing import Iterator, TypeVar, Generic, Type +from elastica.timestepper.protocol import ExplicitStepperProtocol +from elastica.typing import SystemCollectionType + +from copy import copy + + +# FIXME: Move memory related functions to separate module or as part of the timestepper +# TODO: Use MemoryProtocol +def make_memory_for_explicit_stepper( + stepper: ExplicitStepperProtocol, system: SystemCollectionType +) -> "MemoryCollection": + # TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.) + + from elastica.timestepper.explicit_steppers import ( + RungeKutta4, + EulerForward, + ) + + # is_this_system_a_collection = is_system_a_collection(system) + + memory_cls: Type + if RungeKutta4 in stepper.__class__.mro(): + # Bad way of doing it, introduces tight coupling + # this should rather be taken from the class itself + class MemoryRungeKutta4: + def __init__(self) -> None: + self.initial_state = None + self.k_1 = None + self.k_2 = None + self.k_3 = None + self.k_4 = None + + memory_cls = MemoryRungeKutta4 + elif EulerForward in stepper.__class__.mro(): + + class MemoryEulerForward: + def __init__(self) -> None: + self.initial_state = None + self.k = None + + memory_cls = MemoryEulerForward + else: + raise NotImplementedError("Making memory for other types not supported") + + return MemoryCollection(memory_cls(), len(system)) + + +M = TypeVar("M", bound="MemoryCollection") + + +class MemoryCollection(Generic[M]): + """Slots of memories for timestepper in a cohesive unit. + + A `MemoryCollection` object is meant to be used in conjunction + with a `SystemCollection`, where each independent `System` to + be integrated has its own `Memory`. + + Example + ------- + + A RK4 integrator needs to store k_1, k_2, k_3, k_4 (intermediate + results from four stages) for each `System`. The restriction for + having a memory slot arises because the `Systems` are usually + not independent of one another and may need communication after + every stage. + """ + + def __init__(self, memory: M, n_memory_slots: int): + super(MemoryCollection, self).__init__() + + self.__memories: list[M] = [] + for _ in range(n_memory_slots - 1): + self.__memories.append(copy(memory)) + self.__memories.append(memory) + + def __getitem__(self, idx: int) -> M: + return self.__memories[idx] + + def __len__(self) -> int: + return len(self.__memories) + + def __iter__(self) -> Iterator[M]: + return self.__memories.__iter__() diff --git a/elastica/systems/protocol.py b/elastica/systems/protocol.py new file mode 100644 index 00000000..254cdcaf --- /dev/null +++ b/elastica/systems/protocol.py @@ -0,0 +1,82 @@ +__doc__ = """Base class for elastica system""" + +from typing import Protocol, Type +from elastica.typing import StateType, SystemType + +from elastica.rod.data_structures import _KinematicState, _DynamicState + +import numpy as np +from numpy.typing import NDArray + + +class StaticSystemProtocol(Protocol): + REQUISITE_MODULES: list[Type] + + +class SystemProtocol(StaticSystemProtocol, Protocol): + """ + Protocol for all dynamic elastica system + """ + + def compute_internal_forces_and_torques(self, time: np.float64) -> None: ... + + def update_accelerations(self, time: np.float64) -> None: ... + + def zeroed_out_external_forces_and_torques(self, time: np.float64) -> None: ... + + +class SlenderBodyGeometryProtocol(Protocol): + @property + def n_nodes(self) -> int: ... + + @property + def n_elems(self) -> int: ... + + position_collection: NDArray[np.float64] + velocity_collection: NDArray[np.float64] + acceleration_collection: NDArray[np.float64] + + omega_collection: NDArray[np.float64] + alpha_collection: NDArray[np.float64] + director_collection: NDArray[np.float64] + + external_forces: NDArray[np.float64] + external_torques: NDArray[np.float64] + + internal_forces: NDArray[np.float64] + internal_torques: NDArray[np.float64] + + +class SymplecticSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): + """ + Protocol for system with symplectic state variables + """ + + v_w_collection: NDArray[np.float64] + dvdt_dwdt_collection: NDArray[np.float64] + + @property + def kinematic_states(self) -> _KinematicState: ... + + @property + def dynamic_states(self) -> _DynamicState: ... + + def kinematic_rates( + self, time: np.float64, prefac: np.float64 + ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ... + + def dynamic_rates( + self, time: np.float64, prefac: np.float64 + ) -> NDArray[np.float64]: ... + + +class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): + # TODO: Temporarily made to handle explicit stepper. + # Need to be refactored as the explicit stepper is further developed. + def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ... + @property + def state(self) -> StateType: ... + @state.setter + def state(self, state: StateType) -> None: ... + @property + def n_elems(self) -> int: ... diff --git a/elastica/timestepper/__init__.py b/elastica/timestepper/__init__.py index 3efa466e..25f88341 100644 --- a/elastica/timestepper/__init__.py +++ b/elastica/timestepper/__init__.py @@ -1,85 +1,73 @@ __doc__ = """Timestepping utilities to be used with Rod and RigidBody classes""" +from typing import Tuple, List, Callable, Type, Any, overload, cast +from elastica.typing import SystemType, SystemCollectionType, SteppersOperatorsType import numpy as np from tqdm import tqdm -from elastica.timestepper.symplectic_steppers import ( - SymplecticStepperTag, - PositionVerlet, - PEFRL, -) -from elastica.timestepper.explicit_steppers import ( - ExplicitStepperTag, - RungeKutta4, - EulerForward, -) - - -# TODO: Both extend_stepper_interface and integrate should be in separate file. -# __init__ is probably not an ideal place to have these scripts. -def extend_stepper_interface(Stepper, System): - from elastica.utils import extend_instance - from elastica.systems import is_system_a_collection - - # Check if system is a "collection" of smaller systems - # by checking for the [] method - is_this_system_a_collection = is_system_a_collection(System) - """ - # Stateful steppers are no more used so remove them - ConcreteStepper = ( - Stepper.stepper if _StatefulStepper in Stepper.__class__.mro() else Stepper - ) - """ - ConcreteStepper = Stepper - - if type(ConcreteStepper.Tag) == SymplecticStepperTag: - from elastica.timestepper.symplectic_steppers import ( - _SystemInstanceStepper, - _SystemCollectionStepper, - SymplecticStepperMethods as StepperMethodCollector, - ) - elif type(ConcreteStepper.Tag) == ExplicitStepperTag: - from elastica.timestepper.explicit_steppers import ( - _SystemInstanceStepper, - _SystemCollectionStepper, - ExplicitStepperMethods as StepperMethodCollector, - ) - # elif SymplecticCosseratRodStepper in ConcreteStepper.__class__.mro(): - # return # hacky fix for now. remove HybridSteppers in a future version. - else: - raise NotImplementedError( - "Only explicit and symplectic steppers are supported, given stepper is {}".format( - ConcreteStepper.__class__.__name__ - ) - ) - - stepper_methods = StepperMethodCollector(ConcreteStepper) - do_step_method = ( - _SystemCollectionStepper.do_step - if is_this_system_a_collection - else _SystemInstanceStepper.do_step - ) - return do_step_method, stepper_methods.step_methods() - - -# TODO Improve interface of this function to take args and kwargs for ease of use +from elastica.systems import is_system_a_collection + +from .symplectic_steppers import PositionVerlet, PEFRL +from .explicit_steppers import RungeKutta4, EulerForward +from .protocol import StepperProtocol, SymplecticStepperProtocol + + +# Deprecated: Remove in the future version +# Many script still uses this method to control timestep. Keep it for backward compatibility +def extend_stepper_interface( + stepper: StepperProtocol, system_collection: SystemCollectionType +) -> Tuple[ + Callable[ + [StepperProtocol, SystemCollectionType, np.float64, np.float64], np.float64 + ], + SteppersOperatorsType, +]: + try: + stepper_methods: SteppersOperatorsType = stepper.steps_and_prefactors + do_step_method: Callable = stepper.do_step # type: ignore[attr-defined] + except AttributeError as e: + raise NotImplementedError(f"{stepper} stepper is not supported.") from e + return do_step_method, stepper_methods + + +@overload +def integrate( + stepper: StepperProtocol, + systems: SystemType, + final_time: float, + n_steps: int, + restart_time: float, + progress_bar: bool, +) -> float: ... + + +@overload def integrate( - StatefulStepper, - System, + stepper: StepperProtocol, + systems: SystemCollectionType, + final_time: float, + n_steps: int, + restart_time: float, + progress_bar: bool, +) -> float: ... + + +def integrate( + stepper: StepperProtocol, + systems: "SystemType | SystemCollectionType", final_time: float, n_steps: int = 1000, restart_time: float = 0.0, progress_bar: bool = True, - **kwargs, -): +) -> float: """ Parameters ---------- - StatefulStepper : + stepper : StepperProtocol Stepper algorithm to use. - System : + systems : SystemType | SystemCollectionType The elastica-system to simulate. final_time : float Total simulation time. The timestep is determined by final_time / n_steps. @@ -93,17 +81,17 @@ def integrate( assert final_time > 0.0, "Final time is negative!" assert n_steps > 0, "Number of integration steps is negative!" - # Extend the stepper's interface after introspecting the properties - # of the system. If system is a collection of small systems (whose - # states cannot be aggregated), then stepper now loops over the system - # state - do_step, stages_and_updates = extend_stepper_interface(StatefulStepper, System) - dt = np.float64(float(final_time) / n_steps) - time = restart_time + time = np.float64(restart_time) - for i in tqdm(range(n_steps), disable=(not progress_bar)): - time = do_step(StatefulStepper, stages_and_updates, System, time, dt) + if is_system_a_collection(systems): + systems = cast(SystemCollectionType, systems) + for i in tqdm(range(n_steps), disable=(not progress_bar)): + time = stepper.step(systems, time, dt) + else: + systems = cast(SystemType, systems) + for i in tqdm(range(n_steps), disable=(not progress_bar)): + time = stepper.step_single_instance(systems, time, dt) print("Final time of simulation is : ", time) - return time + return float(time) diff --git a/elastica/timestepper/explicit_steppers.py b/elastica/timestepper/explicit_steppers.py index 2e06e975..5fb1531f 100644 --- a/elastica/timestepper/explicit_steppers.py +++ b/elastica/timestepper/explicit_steppers.py @@ -1,7 +1,20 @@ __doc__ = """Explicit timesteppers and concepts""" + +from typing import Any + import numpy as np from copy import copy +from elastica.typing import ( + SystemType, + SystemCollectionType, + OperatorType, + SteppersOperatorsType, + StateType, +) +from elastica.systems.protocol import ExplicitSystemProtocol +from .protocol import ExplicitStepperProtocol, MemoryProtocol + """ Developer Note @@ -52,208 +65,253 @@ """ -class _SystemInstanceStepper: - # # noinspection PyUnresolvedReferences - @staticmethod - def do_step( - TimeStepper, - _stages_and_updates, - System, - Memory, - time: np.float64, - dt: np.float64, - ): - for stage, update in _stages_and_updates: - stage(TimeStepper, System, Memory, time, dt) - time = update(TimeStepper, System, Memory, time, dt) - return time +class EulerForwardMemory: + def __init__(self, initial_state: StateType) -> None: + self.initial_state = initial_state -class _SystemCollectionStepper: - # # noinspection PyUnresolvedReferences - @staticmethod - def do_step( - TimeStepper, - _stages_and_updates, - SystemCollection, - MemoryCollection, - time: np.float64, - dt: np.float64, - ): - for stage, update in _stages_and_updates: - SystemCollection.synchronize(time) - for system, memory in zip(SystemCollection[:-1], MemoryCollection[:-1]): - stage(TimeStepper, system, memory, time, dt) - _ = update(TimeStepper, system, memory, time, dt) +class RungeKutta4Memory: + """ + Stores all states of Rk within the time-stepper. Works as long as the states + are all one big numpy array, made possible by carefully using views. - stage(TimeStepper, SystemCollection[-1], MemoryCollection[-1], time, dt) - time = update( - TimeStepper, SystemCollection[-1], MemoryCollection[-1], time, dt - ) - return time + Convenience wrapper around Stateless that provides memory + """ + def __init__( + self, + initial_state: StateType, + ) -> None: + self.initial_state = initial_state + self.k_1 = initial_state + self.k_2 = initial_state + self.k_3 = initial_state + self.k_4 = initial_state -class ExplicitStepperMethods: + +class ExplicitStepperMixin: """Base class for all explicit steppers Can also be used as a mixin with optional cls argument below """ - def __init__(self, timestepper_instance): - take_methods_from = timestepper_instance - __stages = [ - v - for (k, v) in take_methods_from.__class__.__dict__.items() - if k.endswith("stage") - ] - __updates = [ - v - for (k, v) in take_methods_from.__class__.__dict__.items() - if k.endswith("update") - ] + def __init__(self: ExplicitStepperProtocol): + self.steps_and_prefactors = self.step_methods() - # Tuples are almost immutable - _n_stages = len(__stages) - _n_updates = len(__updates) + def step_methods(self: ExplicitStepperProtocol) -> SteppersOperatorsType: + stages = self.get_stages() + updates = self.get_updates() - assert ( - _n_stages == _n_updates + assert len(stages) == len( + updates ), "Number of stages and updates should be equal to one another" - - self._stages_and_updates = tuple(zip(__stages, __updates)) - - def step_methods(self): - return self._stages_and_updates + return tuple(zip(stages, updates)) @property - def n_stages(self): - return len(self._stages_and_updates) + def n_stages(self: ExplicitStepperProtocol) -> int: + return len(self.steps_and_prefactors) + def step( + self: ExplicitStepperProtocol, + SystemCollection: SystemCollectionType, + time: np.float64, + dt: np.float64, + ) -> np.float64: + if isinstance( + self, EulerForward + ): # TODO: Cleanup - use depedency injection instead + Memory = EulerForwardMemory + elif isinstance(self, RungeKutta4): + Memory = RungeKutta4Memory # type: ignore[assignment] + else: + raise NotImplementedError(f"Memory class not defined for {self}") + memory_collection = tuple( + [Memory(initial_state=system.state) for system in SystemCollection] + ) + return ExplicitStepperMixin.do_step(self, self.steps_and_prefactors, SystemCollection, memory_collection, time, dt) # type: ignore[attr-defined] -# class StatefulRungeKutta4(_StatefulStepper): -# """ -# Stores all states of Rk within the time-stepper. Works as long as the states -# are all one big numpy array, made possible by carefully using views. -# -# Convenience wrapper around Stateless that provides memory -# """ -# -# def __init__(self): -# super(StatefulRungeKutta4, self).__init__() -# self.stepper = RungeKutta4() -# self.initial_state = None -# self.k_1 = None -# self.k_2 = None -# self.k_3 = None -# self.k_4 = None - - -""" -Classical EulerForward -""" - + @staticmethod + def do_step( + TimeStepper: ExplicitStepperProtocol, + steps_and_prefactors: SteppersOperatorsType, + SystemCollection: SystemCollectionType, + MemoryCollection: Any, # TODO + time: np.float64, + dt: np.float64, + ) -> np.float64: + for stage, update in steps_and_prefactors: + SystemCollection.synchronize(time) + for system, memory in zip(SystemCollection[:-1], MemoryCollection[:-1]): + stage(system, memory, time, dt) + _ = update(system, memory, time, dt) -# class EulerForward: -# Tag = ExplicitStepperTag() -# -# def __init__(self): -# pass -# -# def _first_stage(self, System, Memory, time, dt): -# pass -# -# def _first_update(self, System, Memory, time, dt): -# System.state += dt * System(time, dt) -# return time + dt + stage(SystemCollection[-1], MemoryCollection[-1], time, dt) + time = update(SystemCollection[-1], MemoryCollection[-1], time, dt) + return time + def step_single_instance( + self: ExplicitStepperProtocol, + System: SystemType, + Memory: MemoryProtocol, + time: np.float64, + dt: np.float64, + ) -> np.float64: + for stage, update in self.steps_and_prefactors: + stage(System, Memory, time, dt) + time = update(System, Memory, time, dt) + return time -# class StatefulEulerForward(_StatefulStepper): -# def __init__(self): -# super(StatefulEulerForward, self).__init__() -# self.stepper = EulerForward() +class EulerForward(ExplicitStepperMixin): + """ + Classical Euler Forward stepper. Stateless, coordinates operations only. + """ -""" -class ExplicitLinearExponentialIntegrator( - _LinearExponentialIntegratorMixin, ExplicitStepper -): - def __init__(self): - _LinearExponentialIntegratorMixin.__init__(self) - ExplicitStepper.__init__(self, _LinearExponentialIntegratorMixin) - - -class StatefulLinearExponentialIntegrator(_StatefulStepper): - def __init__(self): - super(StatefulLinearExponentialIntegrator, self).__init__() - self.stepper = ExplicitLinearExponentialIntegrator() - self.linear_operator = None -""" + def get_stages(self) -> list[OperatorType]: + return [self._first_stage] + def get_updates(self) -> list[OperatorType]: + return [self._first_update] -class ExplicitStepperTag: - def __init__(self): + def _first_stage( + self, + System: ExplicitSystemProtocol, + Memory: EulerForwardMemory, + time: np.float64, + dt: np.float64, + ) -> None: pass + def _first_update( + self, + System: ExplicitSystemProtocol, + Memory: EulerForwardMemory, + time: np.float64, + dt: np.float64, + ) -> np.float64: + System.state += dt * System(time, dt) # type: ignore[arg-type] + return time + dt -class RungeKutta4: + +class RungeKutta4(ExplicitStepperMixin): """ Stateless runge-kutta4. coordinates operations only, memory needs to be externally managed and allocated. """ - Tag = ExplicitStepperTag() + def get_stages(self) -> list[OperatorType]: + return [ + self._first_stage, + self._second_stage, + self._third_stage, + self._fourth_stage, + ] - def __init__(self): - pass + def get_updates(self) -> list[OperatorType]: + return [ + self._first_update, + self._second_update, + self._third_update, + self._fourth_update, + ] # These methods should be static, but because we need to enable automatic # discovery in ExplicitStepper, these are bound to the RungeKutta4 class # For automatic discovery, the order of declaring stages here is very important - def _first_stage(self, System, Memory, time: np.float64, dt: np.float64): + def _first_stage( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> None: Memory.initial_state = copy(System.state) - Memory.k_1 = dt * System(time, dt) # Don't update state yet + Memory.k_1 = dt * System(time, dt) # type: ignore[operator, assignment] - def _first_update(self, System, Memory, time: np.float64, dt: np.float64): + def _first_update( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> np.float64: # prepare for next stage - System.state = Memory.initial_state + 0.5 * Memory.k_1 + System.state = Memory.initial_state + 0.5 * Memory.k_1 # type: ignore[operator] return time + 0.5 * dt - def _second_stage(self, System, Memory, time: np.float64, dt: np.float64): - Memory.k_2 = dt * System(time, dt) # Don't update state yet + def _second_stage( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> None: + Memory.k_2 = dt * System(time, dt) # type: ignore[operator, assignment] - def _second_update(self, System, Memory, time: np.float64, dt: np.float64): + def _second_update( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> np.float64: # prepare for next stage - System.state = Memory.initial_state + 0.5 * Memory.k_2 + System.state = Memory.initial_state + 0.5 * Memory.k_2 # type: ignore[operator] return time - def _third_stage(self, System, Memory, time: np.float64, dt: np.float64): - Memory.k_3 = dt * System(time, dt) # Don't update state yet + def _third_stage( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> None: + Memory.k_3 = dt * System(time, dt) # type: ignore[operator, assignment] - def _third_update(self, System, Memory, time: np.float64, dt: np.float64): + def _third_update( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> np.float64: # prepare for next stage - System.state = Memory.initial_state + Memory.k_3 + System.state = Memory.initial_state + Memory.k_3 # type: ignore[operator] return time + 0.5 * dt - def _fourth_stage(self, System, Memory, time: np.float64, dt: np.float64): - Memory.k_4 = dt * System(time, dt) # Don't update state yet + def _fourth_stage( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> None: + Memory.k_4 = dt * System(time, dt) # type: ignore[operator, assignment] - def _fourth_update(self, System, Memory, time: np.float64, dt: np.float64): + def _fourth_update( + self, + System: ExplicitSystemProtocol, + Memory: RungeKutta4Memory, + time: np.float64, + dt: np.float64, + ) -> np.float64: # prepare for next stage System.state = ( Memory.initial_state - + (Memory.k_1 + 2.0 * Memory.k_2 + 2.0 * Memory.k_3 + Memory.k_4) / 6.0 + + (Memory.k_1 + 2.0 * Memory.k_2 + 2.0 * Memory.k_3 + Memory.k_4) / 6.0 # type: ignore[operator] ) return time -class EulerForward: - Tag = ExplicitStepperTag() - - def __init__(self): - super(EulerForward, self).__init__() - - def _first_stage(self, System, Memory, time, dt): - pass - - def _first_update(self, System, Memory, time, dt): - System.state += dt * System(time, dt) - return time + dt +# class ExplicitLinearExponentialIntegrator( +# _LinearExponentialIntegratorMixin, ExplicitStepper +# ): +# def __init__(self): +# _LinearExponentialIntegratorMixin.__init__(self) +# ExplicitStepper.__init__(self, _LinearExponentialIntegratorMixin) +# +# +# class StatefulLinearExponentialIntegrator(_StatefulStepper): +# def __init__(self): +# super(StatefulLinearExponentialIntegrator, self).__init__() +# self.stepper = ExplicitLinearExponentialIntegrator() +# self.linear_operator = None diff --git a/elastica/timestepper/_stepper_interface.py b/elastica/timestepper/protocol.py similarity index 61% rename from elastica/timestepper/_stepper_interface.py rename to elastica/timestepper/protocol.py index 99d2a6ce..1a64c725 100644 --- a/elastica/timestepper/_stepper_interface.py +++ b/elastica/timestepper/protocol.py @@ -1,34 +1,58 @@ __doc__ = "Time stepper interface" +from typing import Protocol -class _TimeStepper: - """Interface classes for all time-steppers""" +from elastica.typing import ( + SystemType, + SteppersOperatorsType, + OperatorType, + SystemCollectionType, +) - def __init__(self): - pass +import numpy as np - def do_step(self, *args, **kwargs): - raise NotImplementedError( - "TimeStepper hierarchy is not supposed to access the do-step routine of the TimeStepper base class. " - ) +class StepperProtocol(Protocol): + """Protocol for all time-steppers""" + + steps_and_prefactors: SteppersOperatorsType + + def __init__(self) -> None: ... + + @property + def n_stages(self) -> int: ... + + def step_methods(self) -> SteppersOperatorsType: ... + + def step( + self, SystemCollection: SystemCollectionType, time: np.float64, dt: np.float64 + ) -> np.float64: ... + + def step_single_instance( + self, SystemCollection: SystemType, time: np.float64, dt: np.float64 + ) -> np.float64: ... + + +class SymplecticStepperProtocol(StepperProtocol, Protocol): + """symplectic stepper protocol.""" + + def get_steps(self) -> list[OperatorType]: ... + + def get_prefactors(self) -> list[OperatorType]: ... + + +class MemoryProtocol(Protocol): + @property + def initial_state(self) -> bool: ... + + +class ExplicitStepperProtocol(StepperProtocol, Protocol): + """symplectic stepper protocol.""" + + def get_stages(self) -> list[OperatorType]: ... + + def get_updates(self) -> list[OperatorType]: ... -# class _StatefulStepper: -# """ -# Stateful explicit, symplectic stepper wrapper. -# """ -# -# def __init__(self): -# pass -# -# # For stateful steppes, bind memory to self -# def do_step(self, System, time: np.float64, dt: np.float64): -# return self.stepper.do_step(System, self, time, dt) -# -# @property -# def n_stages(self): -# return self.stepper.n_stages -# # class _LinearExponentialIntegratorMixin: # """ diff --git a/elastica/timestepper/symplectic_steppers.py b/elastica/timestepper/symplectic_steppers.py index 9a2f57ff..3937e337 100644 --- a/elastica/timestepper/symplectic_steppers.py +++ b/elastica/timestepper/symplectic_steppers.py @@ -1,21 +1,26 @@ __doc__ = """Symplectic time steppers and concepts for integrating the kinematic and dynamic equations of rod-like objects. """ -import numpy as np +from typing import Any, Final + +from itertools import zip_longest + +from elastica.typing import ( + SystemType, + SystemCollectionType, + # StepOperatorType, + # PrefactorOperatorType, + OperatorType, + SteppersOperatorsType, +) -# from elastica._elastica_numba._timestepper._symplectic_steppers import ( -# SymplecticStepperTag, -# PositionVerlet, -# PEFRL, -# ) +import numpy as np -# from elastica.timestepper._stepper_interface import ( -# _TimeStepper, -# _LinearExponentialIntegratorMixin, -# ) from elastica.rod.data_structures import ( overload_operator_kinematic_numba, overload_operator_dynamic_numba, ) +from elastica.systems.protocol import SymplecticSystemProtocol +from .protocol import SymplecticStepperProtocol """ Developer Note @@ -26,201 +31,159 @@ """ -class _SystemInstanceStepper: - @staticmethod - def do_step( - TimeStepper, _steps_and_prefactors, System, time: np.float64, dt: np.float64 - ): - for kin_prefactor, kin_step, dyn_step in _steps_and_prefactors[:-1]: - kin_step(TimeStepper, System, time, dt) - time += kin_prefactor(TimeStepper, dt) - System.update_internal_forces_and_torques(time) - dyn_step(TimeStepper, System, time, dt) +class SymplecticStepperMixin: + def __init__(self: SymplecticStepperProtocol): + self.steps_and_prefactors: Final[SteppersOperatorsType] = self.step_methods() - # Peel the last kinematic step and prefactor alone - last_kin_prefactor = _steps_and_prefactors[-1][0] - last_kin_step = _steps_and_prefactors[-1][1] + def step_methods(self: SymplecticStepperProtocol) -> SteppersOperatorsType: + # Let the total number of steps for the Symplectic method + # be (2*n + 1) (for time-symmetry). + _steps: list[OperatorType] = self.get_steps() + # Prefac here is necessary because the linear-exponential integrator + # needs only the prefactor and not the dt. + _prefactors: list[OperatorType] = self.get_prefactors() + assert int(np.ceil(len(_steps) / 2)) == len( + _prefactors + ), f"{len(_steps)=}, {len(_prefactors)=}" - last_kin_step(TimeStepper, System, time, dt) - return time + last_kin_prefactor(TimeStepper, dt) + # Separate the kinematic and dynamic steps + _kinematic_steps: list[OperatorType] = _steps[::2] + _dynamic_steps: list[OperatorType] = _steps[1::2] + def no_operation(*args: Any) -> None: + pass -class _SystemCollectionStepper: - """ - Symplectic stepper collection class - """ + return tuple( + zip_longest( + _prefactors, + _kinematic_steps, + _dynamic_steps, + fillvalue=no_operation, + ) + ) + + @property + def n_stages(self: SymplecticStepperProtocol) -> int: + return len(self.steps_and_prefactors) + + def step( + self: SymplecticStepperProtocol, + SystemCollection: SystemCollectionType, + time: np.float64, + dt: np.float64, + ) -> np.float64: + return SymplecticStepperMixin.do_step( + self, self.steps_and_prefactors, SystemCollection, time, dt + ) + # TODO: Merge with .step method in the future. + # DEPRECATED: Use .step instead. @staticmethod def do_step( - TimeStepper, - _steps_and_prefactors, - SystemCollection, + TimeStepper: SymplecticStepperProtocol, + steps_and_prefactors: SteppersOperatorsType, + SystemCollection: SystemCollectionType, time: np.float64, dt: np.float64, - ): + ) -> np.float64: """ Function for doing symplectic stepper over the user defined rods (system). - Parameters - ---------- - SystemCollection: rod object - time: float - dt: float - Returns ------- + time: float + The time after the integration step. """ - for kin_prefactor, kin_step, dyn_step in _steps_and_prefactors[:-1]: + for kin_prefactor, kin_step, dyn_step in steps_and_prefactors[:-1]: - for system in SystemCollection._memory_blocks: - kin_step(TimeStepper, system, time, dt) + for system in SystemCollection.block_systems(): + kin_step(system, time, dt) - time += kin_prefactor(TimeStepper, dt) + time += kin_prefactor(dt) # Constrain only values SystemCollection.constrain_values(time) # We need internal forces and torques because they are used by interaction module. - for system in SystemCollection._memory_blocks: - system.update_internal_forces_and_torques(time) + for system in SystemCollection.block_systems(): + system.compute_internal_forces_and_torques(time) # system.update_internal_forces_and_torques() # Add external forces, controls etc. SystemCollection.synchronize(time) - for system in SystemCollection._memory_blocks: - dyn_step(TimeStepper, system, time, dt) + for system in SystemCollection.block_systems(): + dyn_step(system, time, dt) # Constrain only rates SystemCollection.constrain_rates(time) # Peel the last kinematic step and prefactor alone - last_kin_prefactor = _steps_and_prefactors[-1][0] - last_kin_step = _steps_and_prefactors[-1][1] + last_kin_prefactor = steps_and_prefactors[-1][0] + last_kin_step = steps_and_prefactors[-1][1] - for system in SystemCollection._memory_blocks: - last_kin_step(TimeStepper, system, time, dt) - time += last_kin_prefactor(TimeStepper, dt) + for system in SystemCollection.block_systems(): + last_kin_step(system, time, dt) + time += last_kin_prefactor(dt) SystemCollection.constrain_values(time) # Call back function, will call the user defined call back functions and store data SystemCollection.apply_callbacks(time, round(time / dt)) # Zero out the external forces and torques - for system in SystemCollection._memory_blocks: - system.reset_external_forces_and_torques(time) + for system in SystemCollection.block_systems(): + system.zeroed_out_external_forces_and_torques(time) return time + def step_single_instance( + self: SymplecticStepperProtocol, + System: SymplecticSystemProtocol, + time: np.float64, + dt: np.float64, + ) -> np.float64: -class SymplecticStepperMethods: - def __init__(self, timestepper_instance): - take_methods_from = timestepper_instance - # Let the total number of steps for the Symplectic method - # be (2*n + 1) (for time-symmetry). What we do is collect - # the first n + 1 entries down in _steps and _prefac below, and then - # reverse and append it to itself. - self._steps = [ - v - for (k, v) in take_methods_from.__class__.__dict__.items() - if k.endswith("step") - ] - # Prefac here is necessary because the linear-exponential integrator - # needs only the prefactor and not the dt. - self._prefactors = [ - v - for (k, v) in take_methods_from.__class__.__dict__.items() - if k.endswith("prefactor") - ] - - # # We are getting function named as _update_internal_forces_torques from dictionary, - # # it turns a list. - # self._update_internal_forces_torques = [ - # v - # for (k, v) in take_methods_from.__class__.__dict__.items() - # if k.endswith("forces_torques") - # ] - - def mirror(in_list): - """Mirrors an input list ignoring the last element - If steps = [A, B, C] - then this call makes it [A, B, C, B, A] - - Parameters - ---------- - in_list : input list to be mirrored, modified in-place - - Returns - ------- - - """ - # syntax is very ugly - if len(in_list) > 1: - in_list.extend(in_list[-2::-1]) - elif in_list: - in_list.append(in_list[0]) - - mirror(self._steps) - mirror(self._prefactors) - - assert ( - len(self._steps) == 2 * len(self._prefactors) - 1 - ), "Size mismatch in the number of steps and prefactors provided for a Symplectic Stepper!" - - self._kinematic_steps = self._steps[::2] - self._dynamic_steps = self._steps[1::2] - - # Avoid this check for MockClasses - if len(self._kinematic_steps) > 0: - assert ( - len(self._kinematic_steps) == len(self._dynamic_steps) + 1 - ), "Size mismatch in the number of kinematic and dynamic steps provided for a Symplectic Stepper!" - - from itertools import zip_longest - - def NoOp(*args): - pass - - self._steps_and_prefactors = tuple( - zip_longest( - self._prefactors, - self._kinematic_steps, - self._dynamic_steps, - fillvalue=NoOp, - ) - ) - - def step_methods(self): - return self._steps_and_prefactors - - @property - def n_stages(self): - return len(self._steps_and_prefactors) + for kin_prefactor, kin_step, dyn_step in self.steps_and_prefactors[:-1]: + kin_step(System, time, dt) + time += kin_prefactor(dt) + System.compute_internal_forces_and_torques(time) + dyn_step(System, time, dt) + # Peel the last kinematic step and prefactor alone + last_kin_prefactor = self.steps_and_prefactors[-1][0] + last_kin_step = self.steps_and_prefactors[-1][1] -class SymplecticStepperTag: - def __init__(self): - pass + last_kin_step(System, time, dt) + return time + last_kin_prefactor(dt) -class PositionVerlet: +class PositionVerlet(SymplecticStepperMixin): """ Position Verlet symplectic time stepper class, which includes methods for second-order position Verlet. """ - Tag = SymplecticStepperTag() + def get_steps(self) -> list[OperatorType]: + return [ + self._first_kinematic_step, + self._first_dynamic_step, + self._first_kinematic_step, + ] - def __init__(self): - pass + def get_prefactors(self) -> list[OperatorType]: + return [ + self._first_prefactor, + self._first_prefactor, + ] - def _first_prefactor(self, dt): + def _first_prefactor(self, dt: np.float64) -> np.float64: return 0.5 * dt - def _first_kinematic_step(self, System, time: np.float64, dt: np.float64): + def _first_kinematic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: prefac = self._first_prefactor(dt) - overload_operator_kinematic_numba( System.n_nodes, prefac, @@ -230,15 +193,16 @@ def _first_kinematic_step(self, System, time: np.float64, dt: np.float64): System.omega_collection, ) - def _first_dynamic_step(self, System, time: np.float64, dt: np.float64): - + def _first_dynamic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: overload_operator_dynamic_numba( System.dynamic_states.rate_collection, System.dynamic_rates(time, dt), ) -class PEFRL: +class PEFRL(SymplecticStepperMixin): """ Position Extended Forest-Ruth Like Algorithm of I.M. Omelyan, I.M. Mryglod and R. Folk, Computer Physics Communications 146, 188 (2002), @@ -246,23 +210,39 @@ class PEFRL: """ # xi and chi are confusing, but be careful! - ξ = np.float64(0.1786178958448091e0) # ξ - λ = -np.float64(0.2123418310626054e0) # λ - χ = -np.float64(0.6626458266981849e-1) # χ + ξ: np.float64 = np.float64(0.1786178958448091e0) # ξ + λ: np.float64 = -np.float64(0.2123418310626054e0) # λ + χ: np.float64 = -np.float64(0.6626458266981849e-1) # χ # Pre-calculate other coefficients - lambda_dash_coeff = 0.5 * (1.0 - 2.0 * λ) - xi_chi_dash_coeff = 1.0 - 2.0 * (ξ + χ) - - Tag = SymplecticStepperTag() - - def __init__(self): - pass + lambda_dash_coeff: np.float64 = 0.5 * (1.0 - 2.0 * λ) + xi_chi_dash_coeff: np.float64 = 1.0 - 2.0 * (ξ + χ) + + def get_steps(self) -> list[OperatorType]: + operators = [ + self._first_kinematic_step, + self._first_dynamic_step, + self._second_kinematic_step, + self._second_dynamic_step, + self._third_kinematic_step, + ] + return operators + operators[-2::-1] + + def get_prefactors(self) -> list[OperatorType]: + return [ + self._first_kinematic_prefactor, + self._second_kinematic_prefactor, + self._third_kinematic_prefactor, + self._second_kinematic_prefactor, + self._first_kinematic_prefactor, + ] - def _first_kinematic_prefactor(self, dt): + def _first_kinematic_prefactor(self, dt: np.float64) -> np.float64: return self.ξ * dt - def _first_kinematic_step(self, System, time: np.float64, dt: np.float64): + def _first_kinematic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: prefac = self._first_kinematic_prefactor(dt) overload_operator_kinematic_numba( System.n_nodes, @@ -274,7 +254,9 @@ def _first_kinematic_step(self, System, time: np.float64, dt: np.float64): ) # System.kinematic_states += prefac * System.kinematic_rates(time, prefac) - def _first_dynamic_step(self, System, time: np.float64, dt: np.float64): + def _first_dynamic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: prefac = self.lambda_dash_coeff * dt overload_operator_dynamic_numba( System.dynamic_states.rate_collection, @@ -282,10 +264,12 @@ def _first_dynamic_step(self, System, time: np.float64, dt: np.float64): ) # System.dynamic_states += prefac * System.dynamic_rates(time, prefac) - def _second_kinematic_prefactor(self, dt): + def _second_kinematic_prefactor(self, dt: np.float64) -> np.float64: return self.χ * dt - def _second_kinematic_step(self, System, time: np.float64, dt: np.float64): + def _second_kinematic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: prefac = self._second_kinematic_prefactor(dt) overload_operator_kinematic_numba( System.n_nodes, @@ -297,7 +281,9 @@ def _second_kinematic_step(self, System, time: np.float64, dt: np.float64): ) # System.kinematic_states += prefac * System.kinematic_rates(time, prefac) - def _second_dynamic_step(self, System, time: np.float64, dt: np.float64): + def _second_dynamic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: prefac = self.λ * dt overload_operator_dynamic_numba( System.dynamic_states.rate_collection, @@ -305,10 +291,12 @@ def _second_dynamic_step(self, System, time: np.float64, dt: np.float64): ) # System.dynamic_states += prefac * System.dynamic_rates(time, prefac) - def _third_kinematic_prefactor(self, dt): + def _third_kinematic_prefactor(self, dt: np.float64) -> np.float64: return self.xi_chi_dash_coeff * dt - def _third_kinematic_step(self, System, time: np.float64, dt: np.float64): + def _third_kinematic_step( + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 + ) -> None: prefac = self._third_kinematic_prefactor(dt) # Need to fill in overload_operator_kinematic_numba( diff --git a/elastica/transformations.py b/elastica/transformations.py index b002e956..3a2e10b5 100644 --- a/elastica/transformations.py +++ b/elastica/transformations.py @@ -10,11 +10,14 @@ from .utils import MaxDimension, isqrt +from numpy.typing import NDArray # TODO Complete, but nicer interface, evolve it eventually -def format_vector_shape(vector_collection): +def format_vector_shape( + vector_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ Function for formatting vector shapes into correct format @@ -59,7 +62,9 @@ def format_vector_shape(vector_collection): return vector_collection -def format_matrix_shape(matrix_collection): +def format_matrix_shape( + matrix_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ Formats input matrix into correct format @@ -77,7 +82,7 @@ def format_matrix_shape(matrix_collection): # check first two dimensions are same and matrix is square # other possibility is one dimension is dim**2 and other is blocksize, # we need to convert the matrix in that case. - def assert_proper_square(num1): + def assert_proper_square(num1: int) -> int: sqrt_num = isqrt(num1) assert sqrt_num**2 == num1, "Matrix dimension passed is not a perfect square" return sqrt_num @@ -136,12 +141,14 @@ def assert_proper_square(num1): return matrix_collection -def skew_symmetrize(vector): +def skew_symmetrize(vector: NDArray[np.float64]) -> NDArray[np.float64]: vector = format_vector_shape(vector) return _skew_symmetrize(vector) -def inv_skew_symmetrize(matrix_collection): +def inv_skew_symmetrize( + matrix_collection: NDArray[np.float64], +) -> NDArray[np.float64]: """ Safe wrapper around inv_skew_symmetrize that does checking and formatting on type of matrix_collection using format_matrix_shape @@ -167,7 +174,9 @@ def inv_skew_symmetrize(matrix_collection): raise ValueError("matrix_collection passed is not skew-symmetric") -def rotate(matrix, scale, axis): +def rotate( + matrix: NDArray[np.float64], scale: np.float64, axis: NDArray[np.float64] +) -> NDArray[np.float64]: """ This function takes single or multiple frames as matrix. Then rotates these frames around a single axis for all frames, or can rotate each frame around its own diff --git a/elastica/typing.py b/elastica/typing.py index 4307801e..79cba003 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -1,13 +1,71 @@ -from elastica.rod import RodBase -from elastica.rigidbody import RigidBodyBase -from elastica.surface import SurfaceBase +__doc__ = """ +This module contains aliases of type-hints for elastica. -from typing import Type, Union, TypeAlias, Callable +""" -RodType = Type[RodBase] -SystemType = Union[RodType, Type[RigidBodyBase]] -AllowedContactType = Union[SystemType, Type[SurfaceBase]] +from typing import TYPE_CHECKING +from typing import Callable, Any, ParamSpec, TypeAlias -OperatorType: TypeAlias = Callable[[float], None] -OperatorCallbackType: TypeAlias = Callable[[float, int], None] -OperatorFinalizeType: TypeAlias = Callable +import numpy as np + + +if TYPE_CHECKING: + # Used for type hinting without circular imports + # NEVER BACK-IMPORT ANY ELASTICA MODULES HERE + from .rod.protocol import CosseratRodProtocol + from .rigidbody.protocol import RigidBodyProtocol + from .surface.surface_base import SurfaceBase + from .modules.base_system import BaseSystemCollection + + from .modules.protocol import SystemCollectionProtocol + from .rod.data_structures import _State as State + from .systems.protocol import ( + SystemProtocol, + StaticSystemProtocol, + SymplecticSystemProtocol, + ExplicitSystemProtocol, + ) + from .timestepper.protocol import ( + StepperProtocol, + SymplecticStepperProtocol, + MemoryProtocol, + ) + from .memory_block.protocol import BlockSystemProtocol + + from .mesh.protocol import MeshProtocol + + +StaticSystemType: TypeAlias = "StaticSystemProtocol" +SystemType: TypeAlias = "SystemProtocol" +SystemIdxType: TypeAlias = int +BlockSystemType: TypeAlias = "BlockSystemProtocol" + + +# Mostly used in explicit stepper: for symplectic, use kinetic and dynamic state +StateType: TypeAlias = "State" + +# TODO: Maybe can be more specific. Up for discussion. +OperatorType: TypeAlias = Callable[..., Any] +SteppersOperatorsType: TypeAlias = tuple[tuple[OperatorType, ...], ...] + + +RodType: TypeAlias = "CosseratRodProtocol" +RigidBodyType: TypeAlias = "RigidBodyProtocol" +SurfaceType: TypeAlias = "SurfaceBase" + +SystemCollectionType: TypeAlias = "SystemCollectionProtocol" + +# Indexing types +# TODO: Maybe just use slice?? +ConstrainingIndex: TypeAlias = tuple[int, ...] +ConnectionIndex: TypeAlias = ( + int | np.int32 | list[int] | tuple[int, ...] | np.typing.NDArray[np.int32] +) + +# Operators in elastica.modules +# TODO: can be more specific. +OperatorParam = ParamSpec("OperatorParam") +OperatorCallbackType: TypeAlias = Callable[..., None] +OperatorFinalizeType: TypeAlias = Callable[..., None] + +MeshType: TypeAlias = "MeshProtocol" diff --git a/elastica/utils.py b/elastica/utils.py index bbf9baa2..7f851ed0 100644 --- a/elastica/utils.py +++ b/elastica/utils.py @@ -1,12 +1,15 @@ """ Handy utilities """ +from typing import Generator, Iterable, Any, Literal, TypeVar import functools import numpy as np from numpy import finfo, float64 from itertools import islice from scipy.interpolate import BSpline +from numpy.typing import NDArray + # Slower than the python3.8 isqrt implementation for small ints # python isqrt : ~130 ns @@ -47,6 +50,8 @@ def isqrt(num: int) -> int: elif num == 0: return 0 + raise ValueError("num must be a positive number") + class MaxDimension: """ @@ -54,7 +59,7 @@ class MaxDimension: """ @staticmethod - def value(): + def value() -> Literal[3]: """ Returns spatial dimension @@ -67,7 +72,7 @@ def value(): class Tolerance: @staticmethod - def atol(): + def atol() -> float: """ Static absolute tolerance method @@ -75,10 +80,10 @@ def atol(): ------- atol : library-wide set absolute tolerance for kernels """ - return finfo(float64).eps * 1e4 + return float(finfo(float64).eps * 1e4) @staticmethod - def rtol(): + def rtol() -> float: """ Static relative tolerance method @@ -86,10 +91,10 @@ def rtol(): ------- tol : library-wide set relative tolerance for kernels """ - return finfo(float64).eps * 1e11 + return float(finfo(float64).eps * 1e11) -def perm_parity(lst): +def perm_parity(lst: list[int]) -> int: """ Given a permutation of the digits 0..N in order as a list, returns its parity (or sign): +1 for even parity; -1 for odd. @@ -115,7 +120,10 @@ def perm_parity(lst): return parity -def grouper(iterable, n): +_T = TypeVar("_T") + + +def grouper(iterable: Iterable[_T], n: int) -> Generator[tuple[_T, ...], None, None]: """Collect data into fixed-length chunks or blocks" Parameters @@ -144,7 +152,7 @@ def grouper(iterable, n): yield group -def extend_instance(obj, cls): +def extend_instance(obj: Any, cls: Any) -> None: """ Apply mixins to a class instance after creation @@ -170,7 +178,9 @@ def extend_instance(obj, cls): obj.__class__ = type(base_cls_name, (cls, base_cls), {}) -def _bspline(t_coeff, l_centerline=1.0): +def _bspline( # type: ignore[no-any-unimported] + t_coeff: NDArray, l_centerline: np.float64 = np.float64(1.0) +) -> tuple[BSpline, NDArray, NDArray]: """Generates a bspline object that plots the spline interpolant for any vector x. Optionally takes in a centerline length, set to 1.0 by default and keep_pts for keeping record of control points @@ -198,7 +208,9 @@ def _bspline(t_coeff, l_centerline=1.0): return __bspline_impl__(control_pts, t_coeff, degree) -def __bspline_impl__(x_pts, t_c, degree): +def __bspline_impl__( # type: ignore[no-any-unimported] + x_pts: NDArray, t_c: NDArray, degree: int +) -> tuple[BSpline, NDArray, NDArray]: """""" # Update the knots diff --git a/elastica/wrappers.py b/elastica/wrappers.py deleted file mode 100644 index fa3990f8..00000000 --- a/elastica/wrappers.py +++ /dev/null @@ -1,23 +0,0 @@ -import warnings - -__all__ = [ - "BaseSystemCollection", - "Connections", - "Constraints", - "Forcing", - "CallBacks", - "Damping", -] - -from elastica.modules.base_system import BaseSystemCollection -from elastica.modules.connections import Connections -from elastica.modules.constraints import Constraints -from elastica.modules.forcing import Forcing -from elastica.modules.callbacks import CallBacks -from elastica.modules.damping import Damping - - -warnings.warn( - "elastica.wrappers is refactored to elastica.modules in version 0.3.0.", - DeprecationWarning, -) diff --git a/examples/Binder/1_Timoshenko_Beam.ipynb b/examples/Binder/1_Timoshenko_Beam.ipynb index a8b81fdf..63efe00d 100644 --- a/examples/Binder/1_Timoshenko_Beam.ipynb +++ b/examples/Binder/1_Timoshenko_Beam.ipynb @@ -627,16 +627,14 @@ "\n", "\n", "def run_and_update_plot(simulator, dt, start_time, stop_time, ax):\n", - " from elastica.timestepper import extend_stepper_interface\n", " from elastica.timestepper.symplectic_steppers import PositionVerlet\n", "\n", " timestepper = PositionVerlet()\n", - " do_step, stages_and_updates = extend_stepper_interface(timestepper, simulator)\n", "\n", " n_steps = int((stop_time - start_time) / dt)\n", " time = start_time\n", " for i in range(n_steps):\n", - " time = do_step(timestepper, stages_and_updates, simulator, time, dt)\n", + " time = timestepper.step(simulator, time, dt)\n", " plot_timoshenko_dynamic(shearable_rod_new, unshearable_rod_new, end_force, time, ax)\n", " return time\n", "\n", diff --git a/examples/ContinuumSnakeWithLiftingWaveCase/snake_contact.py b/examples/ContinuumSnakeWithLiftingWaveCase/snake_contact.py index 68dea81d..464b1fd3 100755 --- a/examples/ContinuumSnakeWithLiftingWaveCase/snake_contact.py +++ b/examples/ContinuumSnakeWithLiftingWaveCase/snake_contact.py @@ -1,4 +1,5 @@ __doc__ = """Rod plane contact with anistropic friction (no static friction)""" +from typing import Type import numpy as np from elastica._linalg import ( @@ -22,10 +23,11 @@ _node_to_element_mass_or_force, ) from numba import njit -from elastica.rod import RodBase +from elastica.rod.rod_base import RodBase from elastica.surface import Plane +from elastica.surface.surface_base import SurfaceBase from elastica.contact_forces import NoContact -from elastica.typing import RodType, SystemType, AllowedContactType +from elastica.typing import RodType, SystemType @njit(cache=True) @@ -286,31 +288,10 @@ def __init__( self.kinetic_mu_sideways, ) = kinetic_mu_array - def _check_systems_validity( - self, - system_one: SystemType, - system_two: AllowedContactType, - ) -> None: - """ - This checks the contact order and type of a SystemType object and an AllowedContactType object. - For the RodSphereContact class first_system should be a rod and second_system should be a plane. - Parameters - ---------- - system_one - SystemType - system_two - AllowedContactType - """ - if not issubclass(system_one.__class__, RodBase) or not issubclass( - system_two.__class__, Plane - ): - raise TypeError( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane".format( - system_one.__class__, system_two.__class__ - ) - ) + @property + def _allowed_system_two(self) -> list[Type]: + # Modify this list to include the allowed system types for contact + return [SurfaceBase] def apply_contact(self, system_one: RodType, system_two: SystemType) -> None: """ diff --git a/examples/MuscularSnake/muscular_snake.py b/examples/MuscularSnake/muscular_snake.py index c2e1c7bb..61d68f30 100644 --- a/examples/MuscularSnake/muscular_snake.py +++ b/examples/MuscularSnake/muscular_snake.py @@ -151,8 +151,8 @@ class MuscularSnakeSimulator( muscle_glue_connection_index.append( np.hstack( ( - np.arange(0, 4 * 3, 1, dtype=np.int64), - np.arange(9 * 3, n_elem_muscle_group_one_to_three, 1, dtype=np.int64), + np.arange(0, 4 * 3, 1, dtype=np.int32), + np.arange(9 * 3, n_elem_muscle_group_one_to_three, 1, dtype=np.int32), ) ) ) @@ -213,11 +213,11 @@ class MuscularSnakeSimulator( muscle_rod_list.append(muscle_rod) muscle_end_connection_index.append(index + n_elem_muscle_group_two_to_four) muscle_glue_connection_index.append( - # np.array([0,1, 2, 3, 9, 10 ], dtype=np.int) + # np.array([0,1, 2, 3, 9, 10 ], dtype=np.int32) np.hstack( ( - np.arange(0, 4 * 3, 1, dtype=np.int64), - np.arange(9 * 3, n_elem_muscle_group_two_to_four, 1, dtype=np.int64), + np.arange(0, 4 * 3, 1, dtype=np.int32), + np.arange(9 * 3, n_elem_muscle_group_two_to_four, 1, dtype=np.int32), ) ) ) @@ -290,45 +290,43 @@ class MuscularSnakeSimulator( offset_btw_rods.copy(), ] ) - for k in range(rod_two.n_elems): - rod_one_index = k + muscle_start_connection_index[idx] - rod_two_index = k - k_conn = ( - rod_one.radius[rod_one_index] - * rod_two.radius[rod_two_index] - / (rod_one.radius[rod_one_index] + rod_two.radius[rod_two_index]) - * body_elem_length - * E - / (rod_one.radius[rod_one_index] + rod_two.radius[rod_two_index]) - ) - if k < 12 or k >= 27: - scale = 1 * 2 - scale_contact = 20 - else: - scale = 0.01 * 5 - scale_contact = 20 - - muscular_snake_simulator.connect( - first_rod=rod_one, - second_rod=rod_two, - first_connect_idx=rod_one_index, - second_connect_idx=rod_two_index, - ).using( - SurfaceJointSideBySide, - k=k_conn * scale, - nu=1e-4, - k_repulsive=k_conn * scale_contact, - rod_one_direction_vec_in_material_frame=rod_one_direction_vec_in_material_frame[ - ..., k - ], - rod_two_direction_vec_in_material_frame=rod_two_direction_vec_in_material_frame[ - ..., k - ], - offset_btw_rods=offset_btw_rods[k], - post_processing_dict=straight_straight_rod_connection_post_processing_dict, - step_skip=step_skip, - ) + ks = np.arange(rod_two.n_elems) + scale = np.ones(rod_two.n_elems) * 1 * 2 + scale[12:27] = 0.01 * 5 + scale_contact = np.ones(rod_two.n_elems) * 20 + scale_contact[12:27] = 20 + rod_one_index = ks + muscle_start_connection_index[idx] + rod_two_index = ks + k_conn = ( + rod_one.radius[rod_one_index] + * rod_two.radius[rod_two_index] + / (rod_one.radius[rod_one_index] + rod_two.radius[rod_two_index]) + * body_elem_length + * E + / (rod_one.radius[rod_one_index] + rod_two.radius[rod_two_index]) + ) + + muscular_snake_simulator.connect( + first_rod=rod_one, + second_rod=rod_two, + first_connect_idx=rod_one_index, + second_connect_idx=rod_two_index, + ).using( + SurfaceJointSideBySide, + k=k_conn * scale, + nu=1e-4, + k_repulsive=k_conn * scale_contact, + rod_one_direction_vec_in_material_frame=rod_one_direction_vec_in_material_frame[ + ..., ks + ], + rod_two_direction_vec_in_material_frame=rod_two_direction_vec_in_material_frame[ + ..., ks + ], + offset_btw_rods=offset_btw_rods[ks], + post_processing_dict=straight_straight_rod_connection_post_processing_dict, + step_skip=step_skip, + ) # Friction forces # Only apply to the snake body. diff --git a/examples/RigidBodyCases/rigid_cylinder_rotational_motion_case.py b/examples/RigidBodyCases/rigid_cylinder_rotational_motion_case.py index caa529fb..64271916 100644 --- a/examples/RigidBodyCases/rigid_cylinder_rotational_motion_case.py +++ b/examples/RigidBodyCases/rigid_cylinder_rotational_motion_case.py @@ -52,7 +52,7 @@ def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): super(PointCoupleToCenter, self).__init__() self.torque = (torque * direction).reshape(3, 1) - def apply_forces(self, system, time: np.float = 0.0): + def apply_forces(self, system, time: np.float64 = np.float64(0.0)): system.external_torques += np.einsum( "ijk, jk->ik", system.director_collection, self.torque ) diff --git a/examples/RigidBodyCases/rigid_cylinder_translational_motion_case.py b/examples/RigidBodyCases/rigid_cylinder_translational_motion_case.py index 0f95ffcb..ce78a84a 100644 --- a/examples/RigidBodyCases/rigid_cylinder_translational_motion_case.py +++ b/examples/RigidBodyCases/rigid_cylinder_translational_motion_case.py @@ -52,7 +52,7 @@ def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): super(PointForceToCenter, self).__init__() self.force = (force * direction).reshape(3, 1) - def apply_forces(self, system, time: np.float = 0.0): + def apply_forces(self, system, time: np.float64 = np.float64(0.0)): system.external_forces += self.force # Add point force on the rod diff --git a/examples/RigidBodyCases/rigid_sphere_rotational_motion_case.py b/examples/RigidBodyCases/rigid_sphere_rotational_motion_case.py index 5c2f5d83..600d93d8 100644 --- a/examples/RigidBodyCases/rigid_sphere_rotational_motion_case.py +++ b/examples/RigidBodyCases/rigid_sphere_rotational_motion_case.py @@ -43,7 +43,7 @@ def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): super(PointCoupleToCenter, self).__init__() self.torque = (torque * direction).reshape(3, 1) - def apply_forces(self, system, time: np.float = 0.0): + def apply_forces(self, system, time: np.float64 = np.float64(0.0)): system.external_torques += np.einsum( "ijk, jk->ik", system.director_collection, self.torque ) diff --git a/examples/RigidBodyCases/rigid_sphere_translational_motion_case.py b/examples/RigidBodyCases/rigid_sphere_translational_motion_case.py index 3a3ab071..91a19015 100644 --- a/examples/RigidBodyCases/rigid_sphere_translational_motion_case.py +++ b/examples/RigidBodyCases/rigid_sphere_translational_motion_case.py @@ -44,7 +44,7 @@ def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): super(PointForceToCenter, self).__init__() self.force = (force * direction).reshape(3, 1) - def apply_forces(self, system, time: np.float = 0.0): + def apply_forces(self, system, time: np.float64 = np.float64(0.0)): system.external_forces += self.force # Add point force on the rod diff --git a/poetry.lock b/poetry.lock index a848f5da..196c8dd3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -26,18 +26,19 @@ files = [ ] [[package]] -name = "autoflake8" -version = "0.4.1" -description = "Tool to automatically fix some issues reported by flake8 (forked from autoflake)." +name = "autoflake" +version = "2.3.1" +description = "Removes unused imports and unused variables" optional = false -python-versions = ">=3.7,<4.0" +python-versions = ">=3.8" files = [ - {file = "autoflake8-0.4.1-py3-none-any.whl", hash = "sha256:fdf663b627993ac38e5b55b7d742c388fb2a4f34798a052f43eecc5e8d629e9d"}, - {file = "autoflake8-0.4.1.tar.gz", hash = "sha256:c17da499bd2b71ba02fb11fe53ff1ad83d7dae6efb0f115fd1344f467797c679"}, + {file = "autoflake-2.3.1-py3-none-any.whl", hash = "sha256:3ae7495db9084b7b32818b4140e6dc4fc280b712fb414f5b8fe57b0a8e85a840"}, + {file = "autoflake-2.3.1.tar.gz", hash = "sha256:c98b75dc5b0a86459c4f01a1d32ac7eb4338ec4317a4469515ff1e687ecd909e"}, ] [package.dependencies] -pyflakes = ">=2.3.0" +pyflakes = ">=3.0.0" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [[package]] name = "babel" @@ -496,19 +497,19 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "flake8" -version = "3.9.2" +version = "7.0.0" description = "the modular source code checker: pep8 pyflakes and co" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +python-versions = ">=3.8.1" files = [ - {file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"}, - {file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"}, + {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"}, + {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"}, ] [package.dependencies] -mccabe = ">=0.6.0,<0.7.0" -pycodestyle = ">=2.7.0,<2.8.0" -pyflakes = ">=2.3.0,<2.4.0" +mccabe = ">=0.7.0,<0.8.0" +pycodestyle = ">=2.11.0,<2.12.0" +pyflakes = ">=3.2.0,<3.3.0" [[package]] name = "fonttools" @@ -624,13 +625,13 @@ files = [ [[package]] name = "jinja2" -version = "3.1.3" +version = "3.1.4" description = "A very fast and expressive template engine." optional = true python-versions = ">=3.7" files = [ - {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, - {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] [package.dependencies] @@ -932,13 +933,13 @@ dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setupto [[package]] name = "mccabe" -version = "0.6.1" +version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false -python-versions = "*" +python-versions = ">=3.6" files = [ - {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, - {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] [[package]] @@ -971,6 +972,53 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mypy" +version = "1.10.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:da1cbf08fb3b851ab3b9523a884c232774008267b1f83371ace57f412fe308c2"}, + {file = "mypy-1.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:12b6bfc1b1a66095ab413160a6e520e1dc076a28f3e22f7fb25ba3b000b4ef99"}, + {file = "mypy-1.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e36fb078cce9904c7989b9693e41cb9711e0600139ce3970c6ef814b6ebc2b2"}, + {file = "mypy-1.10.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2b0695d605ddcd3eb2f736cd8b4e388288c21e7de85001e9f85df9187f2b50f9"}, + {file = "mypy-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:cd777b780312ddb135bceb9bc8722a73ec95e042f911cc279e2ec3c667076051"}, + {file = "mypy-1.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3be66771aa5c97602f382230165b856c231d1277c511c9a8dd058be4784472e1"}, + {file = "mypy-1.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8b2cbaca148d0754a54d44121b5825ae71868c7592a53b7292eeb0f3fdae95ee"}, + {file = "mypy-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ec404a7cbe9fc0e92cb0e67f55ce0c025014e26d33e54d9e506a0f2d07fe5de"}, + {file = "mypy-1.10.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e22e1527dc3d4aa94311d246b59e47f6455b8729f4968765ac1eacf9a4760bc7"}, + {file = "mypy-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:a87dbfa85971e8d59c9cc1fcf534efe664d8949e4c0b6b44e8ca548e746a8d53"}, + {file = "mypy-1.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a781f6ad4bab20eef8b65174a57e5203f4be627b46291f4589879bf4e257b97b"}, + {file = "mypy-1.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b808e12113505b97d9023b0b5e0c0705a90571c6feefc6f215c1df9381256e30"}, + {file = "mypy-1.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f55583b12156c399dce2df7d16f8a5095291354f1e839c252ec6c0611e86e2e"}, + {file = "mypy-1.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4cf18f9d0efa1b16478c4c129eabec36148032575391095f73cae2e722fcf9d5"}, + {file = "mypy-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:bc6ac273b23c6b82da3bb25f4136c4fd42665f17f2cd850771cb600bdd2ebeda"}, + {file = "mypy-1.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9fd50226364cd2737351c79807775136b0abe084433b55b2e29181a4c3c878c0"}, + {file = "mypy-1.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f90cff89eea89273727d8783fef5d4a934be2fdca11b47def50cf5d311aff727"}, + {file = "mypy-1.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcfc70599efde5c67862a07a1aaf50e55bce629ace26bb19dc17cece5dd31ca4"}, + {file = "mypy-1.10.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:075cbf81f3e134eadaf247de187bd604748171d6b79736fa9b6c9685b4083061"}, + {file = "mypy-1.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:3f298531bca95ff615b6e9f2fc0333aae27fa48052903a0ac90215021cdcfa4f"}, + {file = "mypy-1.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa7ef5244615a2523b56c034becde4e9e3f9b034854c93639adb667ec9ec2976"}, + {file = "mypy-1.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3236a4c8f535a0631f85f5fcdffba71c7feeef76a6002fcba7c1a8e57c8be1ec"}, + {file = "mypy-1.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a2b5cdbb5dd35aa08ea9114436e0d79aceb2f38e32c21684dcf8e24e1e92821"}, + {file = "mypy-1.10.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92f93b21c0fe73dc00abf91022234c79d793318b8a96faac147cd579c1671746"}, + {file = "mypy-1.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:28d0e038361b45f099cc086d9dd99c15ff14d0188f44ac883010e172ce86c38a"}, + {file = "mypy-1.10.0-py3-none-any.whl", hash = "sha256:f8c083976eb530019175aabadb60921e73b4f45736760826aa1689dda8208aee"}, + {file = "mypy-1.10.0.tar.gz", hash = "sha256:3d087fcbec056c4ee34974da493a826ce316947485cef3901f511848e687c131"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1308,13 +1356,13 @@ files = [ [[package]] name = "pycodestyle" -version = "2.7.0" +version = "2.11.1" description = "Python style guide checker" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.8" files = [ - {file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"}, - {file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"}, + {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, + {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, ] [[package]] @@ -1346,13 +1394,13 @@ test = ["pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "pyflakes" -version = "2.3.1" +version = "3.2.0" description = "passive checker of Python programs" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.8" files = [ - {file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"}, - {file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"}, + {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, + {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, ] [[package]] @@ -1509,7 +1557,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1517,16 +1564,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1543,7 +1582,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1551,7 +1589,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1829,6 +1866,17 @@ files = [ [package.extras] test = ["flake8", "mypy", "pytest"] +[[package]] +name = "sphinxcontrib-mermaid" +version = "0.9.2" +description = "Mermaid diagrams in yours Sphinx powered docs" +optional = true +python-versions = ">=3.7" +files = [ + {file = "sphinxcontrib-mermaid-0.9.2.tar.gz", hash = "sha256:252ef13dd23164b28f16d8b0205cf184b9d8e2b714a302274d9f59eb708e77af"}, + {file = "sphinxcontrib_mermaid-0.9.2-py3-none-any.whl", hash = "sha256:6795a72037ca55e65663d2a2c1a043d636dc3d30d418e56dd6087d1459d98a5d"}, +] + [[package]] name = "sphinxcontrib-qthelp" version = "1.0.7" @@ -1888,13 +1936,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.2" +version = "4.66.3" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"}, - {file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"}, + {file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"}, + {file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"}, ] [package.dependencies] @@ -1996,10 +2044,10 @@ numpy = ["numpy (>=1.9)"] web = ["wslink (>=1.0.4)"] [extras] -docs = ["Sphinx", "docutils", "myst-parser", "numpydoc", "readthedocs-sphinx-search", "sphinx-autodoc-typehints", "sphinx-book-theme"] +docs = ["Sphinx", "docutils", "myst-parser", "numpydoc", "readthedocs-sphinx-search", "sphinx-autodoc-typehints", "sphinx-book-theme", "sphinxcontrib-mermaid"] examples = ["cma"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "efc2d5991f4b33ab404e6f1e306ef6d9e69663d963845b24cdc463f36d082140" +content-hash = "c9747151a346aa7c37c83f66c82b8283aa4ed796788d545d65378b7e8a117260" diff --git a/pyproject.toml b/pyproject.toml index add885b7..bbd4e9ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,10 +48,14 @@ Sphinx = {version = "^6.1", optional = true} sphinx-book-theme = {version = "^1.0", optional = true} readthedocs-sphinx-search = {version = ">=0.1.1,<0.4.0", optional = true} sphinx-autodoc-typehints = {version = "^1.21", optional = true} +sphinxcontrib-mermaid = {version = "^0.9.2", optional = true} myst-parser = {version = "^1.0", optional = true} numpydoc = {version = "^1.3.1", optional = true} docutils = {version = "^0.18", optional = true} cma = {version = "^3.2.2", optional = true} +mypy = "^1.10.0" +mypy-extensions = "^1.0.0" +flake8 = "^7.0.0" [tool.poetry.dev-dependencies] black = "24.3.0" @@ -60,10 +64,9 @@ coverage = "^6.3.3" pre-commit = "^2.19.0" pytest-html = "^3.1.1" pytest-cov = "^3.0.0" -flake8 = "^3.8" codecov = "2.1.13" click = "8.0.0" -autoflake8 = "^0.4" +autoflake = "^2.3.1" [tool.poetry.extras] docs = [ @@ -71,6 +74,7 @@ docs = [ "sphinx-book-theme", "readthedocs-sphinx-search", "sphinx-autodoc-typehints", + "sphinxcontrib-mermaid", "myst-parser", "numpydoc", "docutils", @@ -100,10 +104,54 @@ exclude = ''' )/ ''' +[tool.autoflake] +ignore-init-module-imports = true +ignore-pass-statements = true +ignore-pass-after-docstring = true + [tool.pytest.ini_options] # https://docs.pytest.org/en/6.2.x/customize.html#pyproject-toml # Directories that are not visited by pytest collector: -norecursedirs =["hooks", "*.egg", ".eggs", "dist", "build", "docs", ".tox", ".git", "__pycache__"] +norecursedirs = ["hooks", "*.egg", ".eggs", "dist", "build", "docs", ".tox", ".git", "__pycache__"] + +[tool.mypy] +# https://mypy.readthedocs.io/en/latest/config_file.html#using-a-pyproject-toml-file +python_version = "3.10" +pretty = true +show_traceback = true +color_output = true +strict = true + +allow_redefinition = false +check_untyped_defs = true +disallow_any_unimported = true +disallow_any_generics = false +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true +ignore_missing_imports = true +implicit_reexport = true +no_implicit_optional = true +show_column_numbers = true +show_error_codes = true +show_error_context = true +strict_equality = true +strict_optional = true +warn_no_return = true +warn_redundant_casts = true +warn_return_any = false +warn_unreachable = false +warn_unused_configs = true +warn_unused_ignores = false + +exclude = [ + "elastica/systems/analytical.py", + "elastica/experimental/*", +] + +untyped_calls_exclude = [ + "pyvista", +] [tool.coverage.report] # Regexes for lines to exclude from consideration @@ -117,13 +165,15 @@ exclude_lines = [ "if __name__ == __main__:", "pass", "def __repr__", +#"if self.debug:", +#"if settings.DEBUG", "from", "import", "if TYPE_CHECKING:", "raise AssertionError", "raise NotImplementedError", '''class '.*\bProtocol\)':''', - # ''''@(abc\.)?'abstractmethod''', +#''''@(abc\.)?'abstractmethod''', ] fail_under = 40 show_missing = true @@ -135,4 +185,5 @@ omit = [ "setup.py", "elastica/systems/analytical.py", "elastica/experimental/*", + "elastica/**/protocol.py", ] diff --git a/tests/test_callback_functions.py b/tests/test_callback_functions.py index 8c6362b3..0e5f3e21 100644 --- a/tests/test_callback_functions.py +++ b/tests/test_callback_functions.py @@ -104,14 +104,14 @@ def test_my_call_back_base_class(self, n_elems): class TestExportCallBackClass: @pytest.mark.parametrize("method", ["0", 1, "numba", "test", "some string", None]) - def test_export_call_back_unavailable_save_methods(self, method): + def test_export_call_back_unavailable_save_methods(self, tmp_path, method): with pytest.raises(AssertionError) as excinfo: - callback = ExportCallBack(1, "rod", "tempdir", method) + callback = ExportCallBack(1, "rod", tmp_path.as_posix(), method) @pytest.mark.parametrize("method", ExportCallBack.AVAILABLE_METHOD) - def test_export_call_back_available_save_methods(self, method): + def test_export_call_back_available_save_methods(self, tmp_path, method): try: - callback = ExportCallBack(1, "rod", "tempdir", method) + callback = ExportCallBack(1, "rod", tmp_path.as_posix(), method) except Error: pytest.fail( f"Could not create callback module with available method {method}" @@ -236,7 +236,7 @@ def test_export_call_back_clear_test(self): assert os.path.exists(saved_path_name), "File is not saved." @pytest.mark.parametrize("n_elems", [2, 4, 16]) - def test_export_call_back_class_tempfile_option(self, n_elems): + def test_export_call_back_class_tempfile_option(self, tmp_path, n_elems): """ This test case is for testing ExportCallBack function, saving into temporary files. """ @@ -256,7 +256,7 @@ def test_export_call_back_class_tempfile_option(self, n_elems): } callback = ExportCallBack( - step_skip, "rod", "tempdir", "tempfile", file_save_interval=10 + step_skip, "rod", tmp_path.as_posix(), "tempfile", file_save_interval=10 ) for i in range(10): callback.make_callback(mock_rod, time[i], current_step[i]) diff --git a/tests/test_contact_classes.py b/tests/test_contact_classes.py index 160b2fb7..509d8e32 100644 --- a/tests/test_contact_classes.py +++ b/tests/test_contact_classes.py @@ -12,7 +12,7 @@ RodPlaneContactWithAnisotropicFriction, CylinderPlaneContact, ) -from elastica.typing import RodBase +from elastica.rod import RodBase from elastica.rigidbody import Cylinder, Sphere from elastica.surface import Plane import pytest @@ -52,8 +52,8 @@ def mock_cylinder_init(self): self.director_collection = np.array( [[[1.0], [0.0], [0.0]], [[0.0], [1.0], [0.0]], [[0.0], [0.0], [1.0]]] ) - self.radius = np.array([1.0]) - self.length = np.array([2.0]) + self.radius = 1.0 + self.length = 2.0 self.external_forces = np.array([[0.0], [0.0], [0.0]]) self.external_torques = np.array([[0.0], [0.0], [0.0]]) self.velocity_collection = np.array([[0.0], [0.0], [0.0]]) @@ -69,7 +69,7 @@ def mock_sphere_init(self): self.director_collection = np.array( [[[1.0], [0.0], [0.0]], [[0.0], [1.0], [0.0]], [[0.0], [0.0], [1.0]]] ) - self.radius = np.array([1.0]) + self.radius = 1.0 self.velocity_collection = np.array([[0.0], [0.0], [0.0]]) self.external_forces = np.array([[0.0], [0.0], [0.0]]) self.external_torques = np.array([[0.0], [0.0], [0.0]]) @@ -104,43 +104,41 @@ def test_check_systems_validity_with_invalid_systems( "Testing Rod Cylinder Contact wrapper with incorrect type for second argument" with pytest.raises(TypeError) as excinfo: rod_cylinder_contact._check_systems_validity(mock_rod, mock_list) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a cylinder" - ).format(mock_rod.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['Cylinder']." == str( + excinfo.value + ) - "Testing Rod Cylinder Contact wrapper with incorrect type for first argument" with pytest.raises(TypeError) as excinfo: rod_cylinder_contact._check_systems_validity(mock_list, mock_rod) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a cylinder" - ).format(mock_list.__class__, mock_rod.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) - "Testing Rod Cylinder Contact wrapper with incorrect order" with pytest.raises(TypeError) as excinfo: rod_cylinder_contact._check_systems_validity(mock_cylinder, mock_rod) - print(excinfo.value) assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a cylinder" - ).format(mock_cylinder.__class__, mock_rod.__class__) == str(excinfo.value) + "System provided (MockCylinder) must be derived from ['RodBase']." + == str(excinfo.value) + ) + + with pytest.raises(TypeError) as excinfo: + rod_cylinder_contact._check_systems_validity(mock_list, mock_cylinder) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) def test_contact_rod_cylinder_with_collision_with_k_without_nu_and_friction( self, ): - "Testing Rod Cylinder Contact wrapper with Collision with analytical verified values" + # Testing Rod Cylinder Contact wrapper with Collision with analytical verified values mock_rod = MockRod() mock_cylinder = MockCylinder() rod_cylinder_contact = RodCylinderContact(k=1.0, nu=0.0) rod_cylinder_contact.apply_contact(mock_rod, mock_cylinder) - """Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_cylinder()'""" + # Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_cylinder()' assert_allclose( mock_rod.external_forces, np.array([[0.166666, 0.333333, 0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), @@ -293,32 +291,27 @@ def test_check_systems_validity_with_invalid_systems( mock_list = [1, 2, 3] rod_rod_contact = RodRodContact(k=1.0, nu=0.0) - "Testing Rod Rod Contact wrapper with incorrect type for second argument" + # Testing Rod Rod Contact wrapper with incorrect type for second argument with pytest.raises(TypeError) as excinfo: rod_rod_contact._check_systems_validity(mock_rod_one, mock_list) - assert ( - "Systems provided to the contact class have incorrect order. \n" - " First system is {0} and second system is {1}. \n" - " Both systems must be distinct rods" - ).format(mock_rod_one.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) - "Testing Rod Rod Contact wrapper with incorrect type for first argument" + # Testing Rod Rod Contact wrapper with incorrect type for first argument with pytest.raises(TypeError) as excinfo: rod_rod_contact._check_systems_validity(mock_list, mock_rod_one) - assert ( - "Systems provided to the contact class have incorrect order. \n" - " First system is {0} and second system is {1}. \n" - " Both systems must be distinct rods" - ).format(mock_list.__class__, mock_rod_one.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) - "Testing Rod Rod Contact wrapper with same rod for both arguments" + # Testing Rod Rod Contact wrapper with same rod for both arguments with pytest.raises(TypeError) as excinfo: rod_rod_contact._check_systems_validity(mock_rod_one, mock_rod_one) assert ( - "First rod is identical to second rod. \n" - "Rods must be distinct for RodRodConact. \n" - "If you want self contact, use RodSelfContact instead" - ) == str(excinfo.value) + "First system is identical to second system. Systems must be distinct for contact." + == str(excinfo.value) + ) def test_contact_with_two_rods_with_collision_with_k_without_nu(self): @@ -439,35 +432,26 @@ def test_check_systems_validity_with_invalid_systems( mock_list = [1, 2, 3] self_contact = RodSelfContact(k=1.0, nu=0.0) - "Testing Self Contact wrapper with incorrect type for second argument" + # Testing Self Contact wrapper with incorrect type for second argument with pytest.raises(TypeError) as excinfo: self_contact._check_systems_validity(mock_rod_one, mock_list) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system and second system should be the same rod \n" - " If you want rod rod contact, use RodRodContact instead" - ).format(mock_rod_one.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) - "Testing Self Contact wrapper with incorrect type for first argument" + # Testing Self Contact wrapper with incorrect type for first argument with pytest.raises(TypeError) as excinfo: self_contact._check_systems_validity(mock_list, mock_rod_one) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system and second system should be the same rod \n" - " If you want rod rod contact, use RodRodContact instead" - ).format(mock_list.__class__, mock_rod_one.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) - "Testing Self Contact wrapper with different rods" + # Testing Self Contact wrapper with different rods with pytest.raises(TypeError) as excinfo: self_contact._check_systems_validity(mock_rod_one, mock_rod_two) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system and second system should be the same rod \n" - " If you want rod rod contact, use RodRodContact instead" - ).format(mock_rod_one.__class__, mock_rod_two.__class__) == str(excinfo.value) + assert "First system must be identical to the second system." == str( + excinfo.value + ) def test_self_contact_with_rod_self_collision(self): @@ -530,46 +514,39 @@ def test_check_systems_validity_with_invalid_systems( mock_sphere = MockSphere() rod_sphere_contact = RodSphereContact(k=1.0, nu=0.0) - "Testing Rod Sphere Contact wrapper with incorrect type for second argument" + # Testing Rod Sphere Contact wrapper with incorrect type for second argument with pytest.raises(TypeError) as excinfo: rod_sphere_contact._check_systems_validity(mock_rod, mock_list) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a sphere" - ).format(mock_rod.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['Sphere']." == str( + excinfo.value + ) - "Testing Rod Sphere Contact wrapper with incorrect type for first argument" + # Testing Rod Sphere Contact wrapper with incorrect type for first argument with pytest.raises(TypeError) as excinfo: rod_sphere_contact._check_systems_validity(mock_list, mock_rod) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a sphere" - ).format(mock_list.__class__, mock_rod.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) - "Testing Rod Sphere Contact wrapper with incorrect order" + # Testing Rod Sphere Contact wrapper with incorrect order with pytest.raises(TypeError) as excinfo: rod_sphere_contact._check_systems_validity(mock_sphere, mock_rod) - print(excinfo.value) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a sphere" - ).format(mock_sphere.__class__, mock_rod.__class__) == str(excinfo.value) + assert "System provided (MockSphere) must be derived from ['RodBase']." == str( + excinfo.value + ) def test_contact_rod_sphere_with_collision_with_k_without_nu_and_friction( self, ): - "Testing Rod Sphere Contact wrapper with Collision with analytical verified values" + # "Testing Rod Sphere Contact wrapper with Collision with analytical verified values mock_rod = MockRod() mock_sphere = MockSphere() rod_sphere_contact = RodSphereContact(k=1.0, nu=0.0) rod_sphere_contact.apply_contact(mock_rod, mock_sphere) - """Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_sphere_with_k_without_nu_and_friction()'""" + # Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_sphere_with_k_without_nu_and_friction()' assert_allclose( mock_sphere.external_forces, np.array([[-0.5], [0], [0]]), atol=1e-6 ) @@ -630,23 +607,19 @@ def test_check_systems_validity_with_invalid_systems( mock_list = [1, 2, 3] rod_plane_contact = RodPlaneContact(k=1.0, nu=0.0) - "Testing Rod Plane Contact wrapper with incorrect type for second argument" + # Testing Rod Plane Contact wrapper with incorrect type for second argument with pytest.raises(TypeError) as excinfo: rod_plane_contact._check_systems_validity(mock_rod, mock_list) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane" - ).format(mock_rod.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['SurfaceBase']." == str( + excinfo.value + ) - "Testing Rod Plane Contact wrapper with incorrect type for first argument" + # Testing Rod Plane Contact wrapper with incorrect type for first argument with pytest.raises(TypeError) as excinfo: rod_plane_contact._check_systems_validity(mock_list, mock_plane) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane" - ).format(mock_list.__class__, mock_plane.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) def test_rod_plane_contact_without_contact(self): """ @@ -905,23 +878,19 @@ def test_check_systems_validity_with_invalid_systems( np.array([0.0, 0.0, 0.0]), # forward, backward, sideways ) - "Testing Rod Plane Contact wrapper with incorrect type for second argument" + # Testing Rod Plane Contact wrapper with incorrect type for second argument with pytest.raises(TypeError) as excinfo: rod_plane_contact._check_systems_validity(mock_rod, mock_list) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane" - ).format(mock_rod.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['SurfaceBase']." == str( + excinfo.value + ) - "Testing Rod Plane wrapper with incorrect type for first argument" + # Testing Rod Plane wrapper with incorrect type for first argument with pytest.raises(TypeError) as excinfo: rod_plane_contact._check_systems_validity(mock_list, mock_plane) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a rod, second should be a plane" - ).format(mock_list.__class__, mock_plane.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['RodBase']." == str( + excinfo.value + ) @pytest.mark.parametrize("velocity", [-1.0, -3.0, 1.0, 5.0, 2.0]) def test_axial_kinetic_friction(self, velocity): @@ -1212,7 +1181,7 @@ def initializer( # create plane plane = MockPlane() - plane.origin = np.array([0.0, -cylinder.radius[0] + shift, 0.0]).reshape(3, 1) + plane.origin = np.array([0.0, -cylinder.radius + shift, 0.0]).reshape(3, 1) plane.normal = plane_normal.reshape( 3, ) @@ -1232,23 +1201,19 @@ def test_check_systems_validity_with_invalid_systems( mock_list = [1, 2, 3] cylinder_plane_contact = CylinderPlaneContact(k=1.0, nu=0.0) - "Testing Cylinder Plane Contact wrapper with incorrect type for second argument" + # Testing Cylinder Plane Contact wrapper with incorrect type for second argument with pytest.raises(TypeError) as excinfo: cylinder_plane_contact._check_systems_validity(mock_cylinder, mock_list) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a cylinder, second should be a plane" - ).format(mock_cylinder.__class__, mock_list.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['SurfaceBase']." == str( + excinfo.value + ) - "Testing Cylinder Plane wrapper with incorrect type for first argument" + # Testing Cylinder Plane wrapper with incorrect type for first argument with pytest.raises(TypeError) as excinfo: cylinder_plane_contact._check_systems_validity(mock_list, mock_plane) - assert ( - "Systems provided to the contact class have incorrect order/type. \n" - " First system is {0} and second system is {1}. \n" - " First system should be a cylinder, second should be a plane" - ).format(mock_list.__class__, mock_plane.__class__) == str(excinfo.value) + assert "System provided (list) must be derived from ['Cylinder']." == str( + excinfo.value + ) def test_cylinder_plane_contact_without_contact(self): """ diff --git a/tests/test_contact_functions.py b/tests/test_contact_functions.py index f25d5830..c3cae4b6 100644 --- a/tests/test_contact_functions.py +++ b/tests/test_contact_functions.py @@ -2,7 +2,7 @@ import numpy as np from numpy.testing import assert_allclose -from elastica.typing import RodBase +from elastica.rod import RodBase from elastica.rigidbody import Cylinder, Sphere from elastica._contact_functions import ( diff --git a/tests/test_contact_utils.py b/tests/test_contact_utils.py index e459d815..4e5fe1bb 100644 --- a/tests/test_contact_utils.py +++ b/tests/test_contact_utils.py @@ -4,7 +4,7 @@ import numpy as np from numpy.testing import assert_allclose from elastica.utils import Tolerance -from elastica.typing import RodBase +from elastica.rod import RodBase from elastica.rigidbody import Cylinder, Sphere from elastica.contact_utils import ( _dot_product, diff --git a/tests/test_dissipation.py b/tests/test_dissipation.py index 49fde712..8c48c680 100644 --- a/tests/test_dissipation.py +++ b/tests/test_dissipation.py @@ -151,7 +151,7 @@ def test_laplace_dissipation_filter_for_constant_field(filter_order): ) test_rod.velocity_collection[...] = 2.0 test_rod.omega_collection[...] = 3.0 - filter_damper.dampen_rates(rod=test_rod, time=0) + filter_damper.dampen_rates(system=test_rod, time=0) # filter should keep a spatially invariant field unaffected post_damping_velocity_collection = 2.0 * np.ones_like(test_rod.velocity_collection) post_damping_omega_collection = 3.0 * np.ones_like(test_rod.omega_collection) @@ -178,7 +178,7 @@ def test_laplace_dissipation_filter_for_flip_flop_field(): test_rod.omega_collection[..., 1::2] = 3.0 pre_damping_velocity_collection = test_rod.velocity_collection.copy() pre_damping_omega_collection = test_rod.omega_collection.copy() - filter_damper.dampen_rates(rod=test_rod, time=0) + filter_damper.dampen_rates(system=test_rod, time=0) post_damping_velocity_collection = np.zeros_like(test_rod.velocity_collection) post_damping_omega_collection = np.zeros_like(test_rod.omega_collection) # filter should remove the flip-flop mode th give the average constant mode @@ -243,7 +243,7 @@ def test_laplace_dissipation_filter_for_constant_field_for_ring_rod(filter_order ) test_rod.velocity_collection[...] = 2.0 test_rod.omega_collection[...] = 3.0 - filter_damper.dampen_rates(rod=test_rod, time=0) + filter_damper.dampen_rates(system=test_rod, time=0) # filter should keep a spatially invariant field unaffected post_damping_velocity_collection = 2.0 * np.ones_like(test_rod.velocity_collection) post_damping_omega_collection = 3.0 * np.ones_like(test_rod.omega_collection) @@ -270,7 +270,7 @@ def test_laplace_dissipation_filter_for_flip_flop_field_for_ring_rod(): test_rod.omega_collection[..., 1::2] = 3.0 pre_damping_velocity_collection = test_rod.velocity_collection.copy() pre_damping_omega_collection = test_rod.omega_collection.copy() - filter_damper.dampen_rates(rod=test_rod, time=0) + filter_damper.dampen_rates(system=test_rod, time=0) post_damping_velocity_collection = np.zeros_like(test_rod.velocity_collection) post_damping_omega_collection = np.zeros_like(test_rod.omega_collection) # filter should remove the flip-flop mode th give the average constant mode diff --git a/tests/test_interaction.py b/tests/test_interaction.py index 8d2b0e13..9d30b806 100644 --- a/tests/test_interaction.py +++ b/tests/test_interaction.py @@ -6,13 +6,8 @@ from elastica.utils import Tolerance, MaxDimension from elastica.interaction import ( InteractionPlane, - find_slipping_elements, AnisotropicFrictionalPlane, - node_to_element_mass_or_force, SlenderBodyTheory, - nodes_to_elements, - elements_to_nodes_inplace, - apply_normal_force_numba_rigid_body, ) from elastica.contact_utils import ( _node_to_element_mass_or_force, @@ -105,7 +100,7 @@ def test_interaction_without_contact(self, n_elem): [rod, interaction_plane, external_forces] = self.initializer(n_elem, shift) - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) correct_forces = external_forces # since no contact assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) @@ -127,7 +122,7 @@ def test_interaction_plane_without_k_and_nu(self, n_elem): [rod, interaction_plane, external_forces] = self.initializer(n_elem) - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) correct_forces = np.zeros((3, n_elem + 1)) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) @@ -158,7 +153,7 @@ def test_interaction_plane_with_k_without_nu(self, n_elem, k_w): correct_forces[..., 0] *= 0.5 correct_forces[..., -1] *= 0.5 - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) @@ -194,7 +189,7 @@ def test_interaction_plane_without_k_with_nu(self, n_elem, nu_w): correct_forces[..., 0] *= 0.5 correct_forces[..., -1] *= 0.5 - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) @@ -224,7 +219,7 @@ def test_interaction_when_rod_is_under_plane(self, n_elem): n_elem, shift=offset_of_plane_with_respect_to_rod, plane_normal=plane_normal ) - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) correct_forces = np.zeros((3, n_elem + 1)) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) @@ -269,7 +264,7 @@ def test_interaction_when_rod_is_under_plane_with_k_without_nu(self, n_elem, k_w correct_forces[..., 0] *= 0.5 correct_forces[..., -1] *= 0.5 - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) @@ -317,58 +312,11 @@ def test_interaction_when_rod_is_under_plane_without_k_with_nu(self, n_elem, nu_ correct_forces[..., 0] *= 0.5 correct_forces[..., -1] *= 0.5 - interaction_plane.apply_normal_force(rod) + interaction_plane.apply_forces(rod) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) -class TestAuxiliaryFunctions: - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_linear_interpolation_slip_error_message(self, n_elem): - velocity_threshold = 1.0 - - # if slip velocity larger than threshold - velocity_slip = np.repeat( - np.array([0.0, 0.0, 2.0]).reshape(3, 1), n_elem, axis=1 - ) - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._find_slipping_elements()\n" - "instead for finding slipping elements." - ) - with pytest.raises(NotImplementedError) as error_info: - slip_function = find_slipping_elements(velocity_slip, velocity_threshold) - assert error_info.value.args[0] == error_message - - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_node_to_element_mass_or_force_error_message(self, n_elem): - random_vector = np.random.rand(3).reshape(3, 1) - input = np.repeat(random_vector, n_elem + 1, axis=1) - input[..., 0] *= 0.5 - input[..., -1] *= 0.5 - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._node_to_element_mass_or_force()\n" - "instead for converting the mass/forces on rod nodes to elements." - ) - with pytest.raises(NotImplementedError) as error_info: - output = node_to_element_mass_or_force(input) - assert error_info.value.args[0] == error_message - - @pytest.mark.parametrize("n_elem", [2, 10]) - def test_not_impl_error_for_nodes_to_elements(self, n_elem): - random_vector = np.random.rand(3).reshape(3, 1) - input = np.repeat(random_vector, n_elem + 1, axis=1) - error_message = ( - "This function is removed in v0.3.1. Please use\n" - "elastica.interaction.node_to_element_mass_or_force()\n" - "instead for node-to-element interpolation of mass/forces." - ) - with pytest.raises(NotImplementedError) as error_info: - nodes_to_elements(input) - assert error_info.value.args[0] == error_message - - class TestAnisotropicFriction: def initializer( self, @@ -739,140 +687,33 @@ def test_static_rolling_friction_total_torque_larger_than_static_friction_force( # Slender Body Theory Unit Tests +from elastica.interaction import ( + sum_over_elements, +) -try: - from elastica.interaction import ( - sum_over_elements, - node_to_element_position, - node_to_element_velocity, - node_to_element_pos_or_vel, - ) - - # These functions are used in the case if Numba is available - class TestAuxiliaryFunctionsForSlenderBodyTheory: - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_sum_over_elements(self, n_elem): - """ - This function test sum over elements function with - respect to default python function .sum(). We write - this function because with numba we can get the sum - faster. - Parameters - ---------- - n_elem - - Returns - ------- - - """ - - input_variable = np.random.rand(n_elem) - correct_output = input_variable.sum() - output = sum_over_elements(input_variable) - assert_allclose(correct_output, output, atol=Tolerance.atol()) - - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_node_to_element_position_error_message(self, n_elem): - """ - This function tests node_to_element_position function. We are - converting node positions to element positions. Here also - we are using numba to speed up the process. - - Parameters - ---------- - n_elem - - Returns - ------- - - """ - random = np.random.rand() # Adding some random numbers - input_position = random * np.ones((3, n_elem + 1)) - correct_output = random * np.ones((3, n_elem)) - - error_message = ( - "This function is removed in v0.3.2. For node-to-element_position() interpolation please use: \n" - "elastica.contact_utils._node_to_element_position() for rod position \n" - "For detail, refer to issue #113." - ) - with pytest.raises(NotImplementedError) as error_info: - output = node_to_element_position(input_position) - assert error_info.value.args[0] == error_message - - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_node_to_element_velocity_error_message(self, n_elem): - """ - This function tests node_to_element_velocity function. We are - converting node velocities to element velocities. Here also - we are using numba to speed up the process. - - Parameters - ---------- - n_elem - - Returns - ------- - - """ - random = np.random.rand() # Adding some random numbers - input_velocity = random * np.ones((3, n_elem + 1)) - input_mass = 2.0 * random * np.ones(n_elem + 1) - correct_output = random * np.ones((3, n_elem)) - - error_message = ( - "This function is removed in v0.3.2. For node-to-element_velocity() interpolation please use: \n" - "elastica.contact_utils._node_to_element_velocity() for rod velocity. \n" - "For detail, refer to issue #113." - ) - with pytest.raises(NotImplementedError) as error_info: - output = node_to_element_velocity( - mass=input_mass, node_velocity_collection=input_velocity - ) - assert error_info.value.args[0] == error_message - - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_not_impl_error_for_node_to_element_pos_or_vel(self, n_elem): - random = np.random.rand() # Adding some random numbers - input_velocity = random * np.ones((3, n_elem + 1)) - error_message = ( - "This function is removed in v0.3.0. For node-to-element interpolation please use: \n" - "elastica.contact_utils._node_to_element_position() for rod position \n" - "elastica.contact_utils._node_to_element_velocity() for rod velocity. \n" - "For detail, refer to issue #80." - ) - with pytest.raises(NotImplementedError) as error_info: - node_to_element_pos_or_vel(input_velocity) - assert error_info.value.args[0] == error_message - - @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) - def test_elements_to_nodes_inplace_error_message(self, n_elem): - """ - This function tests _elements_to_nodes_inplace. We are - converting node velocities to element velocities. Here also - we are using numba to speed up the process. - - Parameters - ---------- - n_elem - - Returns - ------- - - """ - random = np.random.rand() # Adding some random numbers - vector_in_element_frame = random * np.ones((3, n_elem)) - vector_in_node_frame = np.zeros((3, n_elem + 1)) - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._elements_to_nodes_inplace()\n" - "instead for updating nodal forces using the forces computed on elements." - ) - with pytest.raises(NotImplementedError) as error_info: - elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame) - assert error_info.value.args[0] == error_message -except ImportError: - pass +# These functions are used in the case if Numba is available +class TestAuxiliaryFunctionsForSlenderBodyTheory: + @pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) + def test_sum_over_elements(self, n_elem): + """ + This function test sum over elements function with + respect to default python function .sum(). We write + this function because with numba we can get the sum + faster. + Parameters + ---------- + n_elem + + Returns + ------- + + """ + + input_variable = np.random.rand(n_elem) + correct_output = input_variable.sum() + output = sum_over_elements(input_variable) + assert_allclose(correct_output, output, atol=Tolerance.atol()) class TestSlenderBody: @@ -1011,42 +852,3 @@ def test_slender_body_matrix_product_only_xy(self, n_elem): slender_body_theory.apply_forces(rod) assert_allclose(correct_forces, rod.external_forces, atol=Tolerance.atol()) - - -@pytest.mark.parametrize("n_elem", [2, 3, 5, 10, 20]) -def test_apply_normal_force_numba_rigid_body_error_message(n_elem): - """ - This function _elements_to_nodes_inplace. We are - converting node velocities to element velocities. Here also - we are using numba to speed up the process. - - Parameters - ---------- - n_elem - - Returns - ------- - - """ - - position_collection = np.zeros((3, n_elem + 1)) - position_collection[0, :] = np.linspace(0, 1.0, n_elem + 1) - - error_message = ( - "This function is removed in v0.3.2. For cylinder plane contact please use: \n" - "elastica._contact_functions._calculate_contact_forces_cylinder_plane() \n" - "For detail, refer to issue #113." - ) - with pytest.raises(NotImplementedError) as error_info: - apply_normal_force_numba_rigid_body( - plane_origin=np.array([0.0, 0.0, 0.0]), - plane_normal=np.array([0.0, 0.0, 1.0]), - surface_tol=1e-4, - k=1.0, - nu=1.0, - length=1.0, - position_collection=position_collection, - velocity_collection=np.zeros((3, n_elem + 1)), - external_forces=np.zeros((3, n_elem + 1)), - ) - assert error_info.value.args[0] == error_message diff --git a/tests/test_joint.py b/tests/test_joint.py index d016226a..6d516591 100644 --- a/tests/test_joint.py +++ b/tests/test_joint.py @@ -13,7 +13,6 @@ import numpy as np import pytest from scipy.spatial.transform import Rotation -from elastica.joint import ExternalContact, SelfContact # TODO: change tests and made them independent of rod, at least assigin hardcoded values for forces and torques @@ -376,118 +375,8 @@ def test_fixedjoint(rest_euler_angle): ) -from elastica.joint import ( - _dot_product, - _norm, - _clip, - _out_of_bounds, - _find_min_dist, - _aabbs_not_intersecting, -) - - -@pytest.mark.parametrize("ndim", [2, 3, 5, 10, 20]) -def test_dot_product_error_message(ndim): - vector1 = np.random.randn(ndim) - vector2 = np.random.randn(ndim) - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._dot_product()\n" - "instead for find the dot product between a and b." - ) - with pytest.raises(NotImplementedError) as error_info: - dot_product = _dot_product(vector1, vector2) - assert error_info.value.args[0] == error_message - - -@pytest.mark.parametrize("ndim", [2, 3, 5, 10, 20]) -def test_norm_error_message(ndim): - vec1 = np.random.randn(ndim) - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._norm()\n" - "instead for finding the norm of a." - ) - with pytest.raises(NotImplementedError) as error_info: - norm = _norm(vec1) - assert error_info.value.args[0] == error_message - - -@pytest.mark.parametrize( - "x, result", - [(0.5, 1), (1.5, 1.5), (2.5, 2)], -) -def test_clip_error_message(x, result): - low = 1.0 - high = 2.0 - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._clip()\n" - "instead for clipping x." - ) - with pytest.raises(NotImplementedError) as error_info: - _clip(x, low, high) - assert error_info.value.args[0] == error_message - - -@pytest.mark.parametrize( - "x, result", - [(0.5, 1), (1.5, 1.5), (2.5, 2)], -) -def test_out_of_bounds_error_message(x, result): - low = 1.0 - high = 2.0 - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._out_of_bounds()\n" - "instead for checking if x is out of bounds." - ) - with pytest.raises(NotImplementedError) as error_info: - _out_of_bounds(x, low, high) - assert error_info.value.args[0] == error_message - - -def test_find_min_dist_error_message(): - x1 = np.array([0, 0, 0]) - e1 = np.array([1, 1, 1]) - x2 = np.array([0, 1, 0]) - e2 = np.array([1, 0, 1]) - - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._find_min_dist()\n" - "instead for finding minimum distance between contact points." - ) - with pytest.raises(NotImplementedError) as error_info: - ( - min_dist_vec, - contact_point_of_system2, - contact_point_of_system1, - ) = _find_min_dist(x1, e1, x2, e2) - assert error_info.value.args[0] == error_message - - -def test_aabbs_not_intersecting_error_message(): - aabb_one = np.array([[0, 0], [0, 0], [0, 0]]) - aabb_two = np.array([[0, 0], [0, 0], [0, 0]]) - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._aabbs_not_intersecting()\n" - "instead for checking aabbs intersection." - ) - with pytest.raises(NotImplementedError) as error_info: - _aabbs_not_intersecting(aabb_one, aabb_two) - assert error_info.value.args[0] == error_message - - -from elastica.typing import RodBase, RigidBodyBase -from elastica.joint import ( - _prune_using_aabbs_rod_rigid_body, - _prune_using_aabbs_rod_rod, - _calculate_contact_forces_rod_rigid_body, - _calculate_contact_forces_rod_rod, - _calculate_contact_forces_self_rod, -) +from elastica.rod import RodBase +from elastica.rigidbody import RigidBodyBase def mock_rod_init(self): @@ -522,502 +411,3 @@ def mock_rigid_body_init(self): self.external_forces = np.array([[0.0], [0.0], [0.0]]) self.external_torques = np.array([[0.0], [0.0], [0.0]]) self.velocity_collection = np.array([[0.0], [0.0], [0.0]]) - - -MockRod = type("MockRod", (RodBase,), {"__init__": mock_rod_init}) - -MockRigidBody = type( - "MockRigidBody", (RigidBodyBase,), {"__init__": mock_rigid_body_init} -) - - -def test_prune_using_aabbs_rod_rigid_body_error_message(): - rod = MockRod() - cylinder = MockRigidBody() - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._prune_using_aabbs_rod_cylinder()\n" - "instead for checking rod cylinder intersection." - ) - with pytest.raises(NotImplementedError) as error_info: - _prune_using_aabbs_rod_rigid_body( - rod.position_collection, - rod.radius, - rod.lengths, - cylinder.position_collection, - cylinder.director_collection, - cylinder.radius, - cylinder.length, - ) - assert error_info.value.args[0] == error_message - - -def test_prune_using_aabbs_rod_rod_error_message(): - rod_one = MockRod() - rod_two = MockRod() - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica.contact_utils._prune_using_aabbs_rod_rod()\n" - "instead for checking rod rod intersection." - ) - with pytest.raises(NotImplementedError) as error_info: - _prune_using_aabbs_rod_rod( - rod_one.position_collection, - rod_one.radius, - rod_one.lengths, - rod_two.position_collection, - rod_two.radius, - rod_two.lengths, - ) - assert error_info.value.args[0] == error_message - - -def test_calculate_contact_forces_rod_rigid_body_error_message(): - - "initializing rod parameters" - rod = MockRod() - rod_element_position = 0.5 * ( - rod.position_collection[..., 1:] + rod.position_collection[..., :-1] - ) - - "initializing cylinder parameters" - cylinder = MockRigidBody() - x_cyl = ( - cylinder.position_collection[..., 0] - - 0.5 * cylinder.length * cylinder.director_collection[2, :, 0] - ) - - "initializing constants" - """ - Setting contact_k = 1 and other parameters to 0, - so the net forces becomes a function of contact forces only. - """ - k = 1.0 - nu = 0 - velocity_damping_coefficient = 0 - friction_coefficient = 0 - - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica._contact_functions._calculate_contact_forces_rod_cylinder()\n" - "instead for calculating rod cylinder contact forces." - ) - - "Function call" - with pytest.raises(NotImplementedError) as error_info: - _calculate_contact_forces_rod_rigid_body( - rod_element_position, - rod.lengths * rod.tangents, - cylinder.position_collection[..., 0], - x_cyl, - cylinder.length * cylinder.director_collection[2, :, 0], - rod.radius + cylinder.radius, - rod.lengths + cylinder.length, - rod.internal_forces, - rod.external_forces, - cylinder.external_forces, - cylinder.external_torques, - cylinder.director_collection[:, :, 0], - rod.velocity_collection, - cylinder.velocity_collection, - k, - nu, - velocity_damping_coefficient, - friction_coefficient, - ) - assert error_info.value.args[0] == error_message - - -def test_calculate_contact_forces_rod_rod_error_message(): - - rod_one = MockRod() - rod_two = MockRod() - """Placing rod two such that its first element just touches the last element of rod one.""" - rod_two.position_collection = np.array([[4, 5, 6], [0, 0, 0], [0, 0, 0]]) - - "initializing constants" - """ - Setting contact_k = 1 and nu to 0, - so the net forces becomes a function of contact forces only. - """ - k = 1.0 - nu = 0.0 - - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica._contact_functions._calculate_contact_forces_rod_rod()\n" - "instead for calculating rod rod contact forces." - ) - - "Function call" - with pytest.raises(NotImplementedError) as error_info: - _calculate_contact_forces_rod_rod( - rod_one.position_collection[..., :-1], - rod_one.radius, - rod_one.lengths, - rod_one.tangents, - rod_one.velocity_collection, - rod_one.internal_forces, - rod_one.external_forces, - rod_two.position_collection[..., :-1], - rod_two.radius, - rod_two.lengths, - rod_two.tangents, - rod_two.velocity_collection, - rod_two.internal_forces, - rod_two.external_forces, - k, - nu, - ) - assert error_info.value.args[0] == error_message - - -def test_calculate_contact_forces_self_rod_error_message(): - "Function to test the calculate contact forces self rod function" - - "Testing function with handcrafted/calculated values" - - rod = MockRod() - - "initializing constants" - k = 1.0 - nu = 1.0 - - error_message = ( - "This function is removed in v0.3.2. Please use\n" - "elastica._contact_functions._calculate_contact_forces_self_rod()\n" - "instead for calculating rod self-contact forces." - ) - - "Function call" - with pytest.raises(NotImplementedError) as error_info: - _calculate_contact_forces_self_rod( - rod.position_collection[..., :-1], - rod.radius, - rod.lengths, - rod.tangents, - rod.velocity_collection, - rod.external_forces, - k, - nu, - ) - assert error_info.value.args[0] == error_message - - -class TestExternalContact: - def test_external_contact_rod_rigid_body_with_collision_with_k_without_nu_and_friction( - self, - ): - - "Testing External Contact wrapper with Collision with analytical verified values" - - mock_rod = MockRod() - mock_rigid_body = MockRigidBody() - ext_contact = ExternalContact(k=1.0, nu=0.0) - ext_contact.apply_forces(mock_rod, 0, mock_rigid_body, 1) - - """Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_claculate_contact_forces_rod_rigid_body()'""" - assert_allclose( - mock_rod.external_forces, - np.array([[0.166666, 0.333333, 0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), - atol=1e-6, - ) - - assert_allclose( - mock_rigid_body.external_forces, np.array([[-0.5], [0.0], [0.0]]), atol=1e-6 - ) - - assert_allclose( - mock_rigid_body.external_torques, np.array([[0.0], [0.0], [0.0]]), atol=1e-6 - ) - - def test_external_contact_rod_rigid_body_with_collision_with_nu_without_k_and_friction( - self, - ): - - "Testing External Contact wrapper with Collision with analytical verified values" - - mock_rod = MockRod() - "Moving rod towards the cylinder with a velocity of -1 in x-axis" - mock_rod.velocity_collection = np.array([[-1, 0, 0], [-1, 0, 0], [-1, 0, 0]]) - mock_rigid_body = MockRigidBody() - "Moving cylinder towards the rod with a velocity of 1 in x-axis" - mock_rigid_body.velocity_collection = np.array([[1], [0], [0]]) - ext_contact = ExternalContact(k=0.0, nu=1.0) - ext_contact.apply_forces(mock_rod, 0, mock_rigid_body, 1) - - """Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_claculate_contact_forces_rod_rigid_body()'""" - assert_allclose( - mock_rod.external_forces, - np.array([[0.5, 1, 0], [0, 0, 0], [0, 0, 0]]), - atol=1e-6, - ) - - assert_allclose( - mock_rigid_body.external_forces, np.array([[-1.5], [0], [0]]), atol=1e-6 - ) - - assert_allclose( - mock_rigid_body.external_torques, np.array([[0.0], [0.0], [0.0]]), atol=1e-6 - ) - - def test_external_contact_rod_rigid_body_with_collision_with_k_and_nu_without_friction( - self, - ): - - "Testing External Contact wrapper with Collision with analytical verified values" - - mock_rod = MockRod() - "Moving rod towards the cylinder with a velocity of -1 in x-axis" - mock_rod.velocity_collection = np.array([[-1, 0, 0], [-1, 0, 0], [-1, 0, 0]]) - mock_rigid_body = MockRigidBody() - "Moving cylinder towards the rod with a velocity of 1 in x-axis" - mock_rigid_body.velocity_collection = np.array([[1], [0], [0]]) - ext_contact = ExternalContact(k=1.0, nu=1.0) - ext_contact.apply_forces(mock_rod, 0, mock_rigid_body, 1) - - """Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_claculate_contact_forces_rod_rigid_body()'""" - assert_allclose( - mock_rod.external_forces, - np.array([[0.666666, 1.333333, 0], [0, 0, 0], [0, 0, 0]]), - atol=1e-6, - ) - - assert_allclose( - mock_rigid_body.external_forces, np.array([[-2], [0], [0]]), atol=1e-6 - ) - - assert_allclose( - mock_rigid_body.external_torques, np.array([[0.0], [0.0], [0.0]]), atol=1e-6 - ) - - def test_external_contact_rod_rigid_body_with_collision_with_k_and_nu_and_friction( - self, - ): - - "Testing External Contact wrapper with Collision with analytical verified values" - - mock_rod = MockRod() - "Moving rod towards the cylinder with a velocity of -1 in x-axis" - mock_rod.velocity_collection = np.array([[-1, 0, 0], [-1, 0, 0], [-1, 0, 0]]) - mock_rigid_body = MockRigidBody() - "Moving cylinder towards the rod with a velocity of 1 in x-axis" - mock_rigid_body.velocity_collection = np.array([[1], [0], [0]]) - ext_contact = ExternalContact( - k=1.0, nu=1.0, velocity_damping_coefficient=0.1, friction_coefficient=0.1 - ) - ext_contact.apply_forces(mock_rod, 0, mock_rigid_body, 1) - - """Details and reasoning about the values are given in 'test_contact_specific_functions.py/test_claculate_contact_forces_rod_rigid_body()'""" - assert_allclose( - mock_rod.external_forces, - np.array( - [ - [0.666666, 1.333333, 0], - [0.033333, 0.066666, 0], - [0.033333, 0.066666, 0], - ] - ), - atol=1e-6, - ) - - assert_allclose( - mock_rigid_body.external_forces, np.array([[-2], [-0.1], [-0.1]]), atol=1e-6 - ) - - assert_allclose( - mock_rigid_body.external_torques, np.array([[0.0], [0.0], [0.0]]), atol=1e-6 - ) - - def test_external_contact_rod_rigid_body_without_collision(self): - - "Testing External Contact wrapper without Collision with analytical verified values" - - mock_rod = MockRod() - mock_rigid_body = MockRigidBody() - ext_contact = ExternalContact(k=1.0, nu=0.5) - - """Setting rigid body position such that there is no collision""" - mock_rigid_body.position_collection = np.array([[400], [500], [600]]) - mock_rod_external_forces_before_execution = mock_rod.external_forces.copy() - mock_rigid_body_external_forces_before_execution = ( - mock_rigid_body.external_forces.copy() - ) - mock_rigid_body_external_torques_before_execution = ( - mock_rigid_body.external_torques.copy() - ) - ext_contact.apply_forces(mock_rod, 0, mock_rigid_body, 1) - - assert_allclose( - mock_rod.external_forces, mock_rod_external_forces_before_execution - ) - assert_allclose( - mock_rigid_body.external_forces, - mock_rigid_body_external_forces_before_execution, - ) - assert_allclose( - mock_rigid_body.external_torques, - mock_rigid_body_external_torques_before_execution, - ) - - def test_external_contact_with_two_rods_with_collision_with_k_without_nu(self): - - "Testing External Contact wrapper with two rods with analytical verified values" - "Test values have been copied from 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_rod()'" - - mock_rod_one = MockRod() - mock_rod_two = MockRod() - mock_rod_two.position_collection = np.array([[4, 5, 6], [0, 0, 0], [0, 0, 0]]) - ext_contact = ExternalContact(k=1.0, nu=0.0) - ext_contact.apply_forces(mock_rod_one, 0, mock_rod_two, 0) - - assert_allclose( - mock_rod_one.external_forces, - np.array([[0, -0.666666, -0.333333], [0, 0, 0], [0, 0, 0]]), - atol=1e-6, - ) - assert_allclose( - mock_rod_two.external_forces, - np.array([[0.333333, 0.666666, 0], [0, 0, 0], [0, 0, 0]]), - atol=1e-6, - ) - - def test_external_contact_with_two_rods_with_collision_without_k_with_nu(self): - - "Testing External Contact wrapper with two rods with analytical verified values" - "Test values have been copied from 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_rod()'" - - mock_rod_one = MockRod() - mock_rod_two = MockRod() - - """Moving the rods towards each other with a velocity of 1 along the x-axis.""" - mock_rod_one.velocity_collection = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0]]) - mock_rod_two.velocity_collection = np.array( - [[-1, 0, 0], [-1, 0, 0], [-1, 0, 0]] - ) - mock_rod_two.position_collection = np.array([[4, 5, 6], [0, 0, 0], [0, 0, 0]]) - ext_contact = ExternalContact(k=0.0, nu=1.0) - ext_contact.apply_forces(mock_rod_one, 0, mock_rod_two, 0) - - assert_allclose( - mock_rod_one.external_forces, - np.array( - [[0, -0.333333, -0.166666], [0, 0, 0], [0, 0, 0]], - ), - atol=1e-6, - ) - assert_allclose( - mock_rod_two.external_forces, - np.array([[0.166666, 0.333333, 0], [0, 0, 0], [0, 0, 0]]), - atol=1e-6, - ) - - def test_external_contact_with_two_rods_with_collision_with_k_and_nu(self): - - "Testing External Contact wrapper with two rods with analytical verified values" - "Test values have been copied from 'test_contact_specific_functions.py/test_calculate_contact_forces_rod_rod()'" - - mock_rod_one = MockRod() - mock_rod_two = MockRod() - - """Moving the rods towards each other with a velocity of 1 along the x-axis.""" - mock_rod_one.velocity_collection = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0]]) - mock_rod_two.velocity_collection = np.array( - [[-1, 0, 0], [-1, 0, 0], [-1, 0, 0]] - ) - mock_rod_two.position_collection = np.array([[4, 5, 6], [0, 0, 0], [0, 0, 0]]) - ext_contact = ExternalContact(k=1.0, nu=1.0) - ext_contact.apply_forces(mock_rod_one, 0, mock_rod_two, 0) - - assert_allclose( - mock_rod_one.external_forces, - np.array( - [[0, -1, -0.5], [0, 0, 0], [0, 0, 0]], - ), - atol=1e-6, - ) - assert_allclose( - mock_rod_two.external_forces, - np.array([[0.5, 1, 0], [0, 0, 0], [0, 0, 0]]), - atol=1e-6, - ) - - def test_external_contact_with_two_rods_without_collision(self): - - "Testing External Contact wrapper with two rods with analytical verified values" - - mock_rod_one = MockRod() - mock_rod_two = MockRod() - - "Setting rod two position such that there is no collision" - mock_rod_two.position_collection = np.array( - [[100, 101, 102], [0, 0, 0], [0, 0, 0]] - ) - ext_contact = ExternalContact(k=1.0, nu=1.0) - mock_rod_one_external_forces_before_execution = ( - mock_rod_one.external_forces.copy() - ) - mock_rod_two_external_forces_before_execution = ( - mock_rod_two.external_forces.copy() - ) - ext_contact.apply_forces(mock_rod_one, 0, mock_rod_two, 0) - - assert_allclose( - mock_rod_one.external_forces, mock_rod_one_external_forces_before_execution - ) - assert_allclose( - mock_rod_two.external_forces, mock_rod_two_external_forces_before_execution - ) - - -class TestSelfContact: - def test_self_contact_with_rod_self_collision(self): - - "Testing Self Contact wrapper rod self collision with analytical verified values" - - mock_rod = MockRod() - - "Test values have been copied from 'test_contact_specific_functions.py/test_calculate_contact_forces_self_rod()'" - mock_rod.n_elems = 3 - mock_rod.position_collection = np.array( - [[1, 4, 4, 1], [0, 0, 1, 1], [0, 0, 0, 0]] - ) - mock_rod.radius = np.array([1, 1, 1]) - mock_rod.lengths = np.array([3, 1, 3]) - mock_rod.tangents = np.array( - [[1.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] - ) - mock_rod.velocity_collection = np.array( - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]] - ) - mock_rod.internal_forces = np.array( - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]] - ) - mock_rod.external_forces = np.array( - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]] - ) - sel_contact = SelfContact(k=1.0, nu=0.0) - sel_contact.apply_forces(mock_rod, 0, mock_rod, 0) - - assert_allclose( - mock_rod.external_forces, - np.array( - [[0, 0, 0, 0], [-0.333333, -0.666666, 0.666666, 0.333333], [0, 0, 0, 0]] - ), - atol=1e-6, - ) - - def test_self_contact_with_rod_no_self_collision(self): - - "Testing Self Contact wrapper rod no self collision with analytical verified values" - - mock_rod = MockRod() - - "the initially set rod does not have self collision" - mock_rod_external_forces_before_execution = mock_rod.external_forces.copy() - sel_contact = SelfContact(k=1.0, nu=1.0) - sel_contact.apply_forces(mock_rod, 0, mock_rod, 0) - - assert_allclose( - mock_rod.external_forces, mock_rod_external_forces_before_execution - ) diff --git a/tests/test_math/test_governing_equations.py b/tests/test_math/test_governing_equations.py index eca87eef..92c22332 100644 --- a/tests/test_math/test_governing_equations.py +++ b/tests/test_math/test_governing_equations.py @@ -49,8 +49,8 @@ def constructor(n_elem): ) # Ghost needed for Cosserat rod functions adapted for block structure. - rod.ghost_elems_idx = np.empty((0), dtype=int) - rod.ghost_voronoi_idx = np.empty((0), dtype=int) + rod.ghost_elems_idx = np.empty((0), dtype=np.int32) + rod.ghost_voronoi_idx = np.empty((0), dtype=np.int32) return cls, rod @@ -372,7 +372,7 @@ def test_case_compute_internal_forces(self, n_elem, dilatation): test_rod.shear_matrix, test_rod.internal_stress, test_rod.internal_forces, - ghost_elems_idx=np.empty((0), dtype=int), + ghost_elems_idx=np.empty((0), dtype=np.int32), ) assert_allclose( @@ -725,7 +725,7 @@ def test_case_lagrange_transport_unsteady_dilatation(self): test_rod.dilatation, test_rod.dilatation_rate, test_rod.internal_torques, - ghost_voronoi_idx=np.empty((0), dtype=int), + ghost_voronoi_idx=np.empty((0), dtype=np.int32), ) # computed internal torques has to be zero. Internal torques created by Lagrangian diff --git a/tests/test_math/test_timestepper.py b/tests/test_math/test_timestepper.py index 878d9250..2afa6de6 100644 --- a/tests/test_math/test_timestepper.py +++ b/tests/test_math/test_timestepper.py @@ -14,22 +14,16 @@ ScalarExponentialDampedHarmonicOscillatorCollectiveSystem, ) from elastica.timestepper import integrate, extend_stepper_interface -from elastica.timestepper._stepper_interface import _TimeStepper from elastica.timestepper.explicit_steppers import ( RungeKutta4, - ExplicitStepperTag, EulerForward, + ExplicitStepperMixin, ) - -# from elastica.timestepper.explicit_steppers import ( -# StatefulRungeKutta4, -# StatefulEulerForward, -# ) from elastica.timestepper.symplectic_steppers import ( PositionVerlet, PEFRL, - SymplecticStepperTag, + SymplecticStepperMixin, ) @@ -39,40 +33,36 @@ class TestExtendStepperInterface: """TODO add documentation""" - class MockSymplecticStepper: - Tag = SymplecticStepperTag() + class MockSymplecticStepper(SymplecticStepperMixin): + + def get_steps(self): + return [self._kinematic_step, self._dynamic_step, self._kinematic_step] + + def get_prefactors(self): + return [self._prefactor, self._prefactor] - def _first_prefactor(self): + def _prefactor(self): pass - def _first_kinematic_step(self): + def _kinematic_step(self): pass - def _first_dynamic_step(self): + def _dynamic_step(self): pass - class MockExplicitStepper: - Tag = ExplicitStepperTag() + class MockExplicitStepper(ExplicitStepperMixin): - def _first_stage(self): - pass + def get_stages(self): + return [self._stage] - def _first_update(self): - pass + def get_updates(self): + return [self._update] - from elastica.timestepper.symplectic_steppers import ( - _SystemInstanceStepper as symplectic_instance_stepper, - ) - from elastica.timestepper.symplectic_steppers import ( - _SystemCollectionStepper as symplectic_collection_stepper, - ) + def _stage(self): + pass - from elastica.timestepper.explicit_steppers import ( - _SystemInstanceStepper as explicit_instance_stepper, - ) - from elastica.timestepper.explicit_steppers import ( - _SystemCollectionStepper as explicit_collection_stepper, - ) + def _update(self): + pass # We cannot call a stepper on a system until both the stepper # and system "see" one another (for performance reasons, mostly) @@ -80,63 +70,45 @@ def _first_update(self): # the interface (interface_cls). It should however have the interface # after "seeing" the system, via extend_stepper_interface @pytest.mark.parametrize( - "stepper_and_interface", + "stepper_module", [ - (MockSymplecticStepper, symplectic_instance_stepper), - (MockExplicitStepper, explicit_instance_stepper), + MockSymplecticStepper, + MockExplicitStepper, ], ) - def test_symplectic_stepper_interface_for_simple_systems( - self, stepper_and_interface - ): + def test_symplectic_stepper_interface_for_simple_systems(self, stepper_module): system = ScalarExponentialDecaySystem() - (stepper_cls, interface_cls) = stepper_and_interface - stepper = stepper_cls() + stepper = stepper_module() stepper_methods = None - assert stepper_methods is None - _, stepper_methods = extend_stepper_interface(stepper, system) assert stepper_methods @pytest.mark.parametrize( - "stepper_and_interface", - [ - (MockSymplecticStepper, symplectic_collection_stepper), - (MockExplicitStepper, explicit_collection_stepper), - ], + "stepper_module", + [MockSymplecticStepper, MockExplicitStepper], ) - def test_symplectic_stepper_interface_for_collective_systems( - self, stepper_and_interface - ): + def test_symplectic_stepper_interface_for_collective_systems(self, stepper_module): system = SymplecticUndampedHarmonicOscillatorCollectiveSystem() - (stepper_cls, interface_cls) = stepper_and_interface - stepper = stepper_cls() + stepper = stepper_module() stepper_methods = None - assert stepper_methods is None - _, stepper_methods = extend_stepper_interface(stepper, system) - assert stepper_methods + assert stepper_methods == stepper.steps_and_prefactors class MockBadStepper: - Tag = int() # an arbitrary tag that doesn't mean anything - - @pytest.mark.parametrize( - "stepper_and_interface", [(MockBadStepper, symplectic_collection_stepper)] - ) - def test_symplectic_stepper_throws_for_bad_stepper(self, stepper_and_interface): - system = ScalarExponentialDecaySystem() - (stepper_cls, interface_cls) = stepper_and_interface - stepper = stepper_cls() + pass - assert interface_cls not in stepper.__class__.__bases__ + @pytest.mark.parametrize("stepper_module", [MockBadStepper]) + def test_symplectic_stepper_throws_for_bad_stepper(self, stepper_module): + system = SymplecticUndampedHarmonicOscillatorCollectiveSystem() + stepper = stepper_module() with pytest.raises(NotImplementedError) as excinfo: extend_stepper_interface(stepper, system) - assert "steppers are supported" in str(excinfo.value) + assert "stepper is not supported" in str(excinfo.value) def test_integrate_throws_an_assert_for_negative_final_time(): @@ -159,138 +131,79 @@ def test_integrate_throws_an_assert_for_negative_total_steps(): SymplecticSteppers = [PositionVerlet, PEFRL] -class TestStepperInterface: - def test_no_base_access_error(self): - with pytest.raises(NotImplementedError) as excinfo: - _TimeStepper().do_step() - assert "not supposed to access" in str(excinfo.value) - - # @pytest.mark.parametrize("stepper", StatefulExplicitSteppers + SymplecticSteppers) - # def test_correct_orders(self, stepper): - # assert stepper().n_stages > 0, "Explicit stepper routine has no stages!" - - -""" -class TestExplicitSteppers: - @pytest.mark.parametrize("stepper", StatefulExplicitSteppers) - def test_against_scalar_exponential(self, stepper): - system = ScalarExponentialDecaySystem(-1, 1) - final_time = 1 - n_steps = 1000 - integrate(stepper(), system, final_time=final_time, n_steps=n_steps) - - assert_allclose( - system.state, - system.analytical_solution(final_time), - rtol=Tolerance.rtol() * 1e3, - atol=Tolerance.atol(), - ) - - @pytest.mark.parametrize("stepper", StatefulExplicitSteppers[:-1]) - def test_against_undamped_harmonic_oscillator(self, stepper): - system = UndampedSimpleHarmonicOscillatorSystem() - final_time = 4.0 * np.pi - n_steps = 2000 - integrate(stepper(), system, final_time=final_time, n_steps=n_steps) - - assert_allclose( - system.state, - system.analytical_solution(final_time), - rtol=Tolerance.rtol(), - atol=Tolerance.atol(), - ) - - @pytest.mark.parametrize("stepper", StatefulExplicitSteppers[:-1]) - def test_against_damped_harmonic_oscillator(self, stepper): - system = DampedSimpleHarmonicOscillatorSystem() - final_time = 4.0 * np.pi - n_steps = 2000 - integrate(stepper(), system, final_time=final_time, n_steps=n_steps) - - assert_allclose( - system.state, - system.analytical_solution(final_time), - rtol=Tolerance.rtol(), - atol=Tolerance.atol(), - ) - - def test_linear_exponential_integrator(self): - system = MultipleFrameRotationSystem(n_frames=128) - final_time = np.pi - n_steps = 1000 - integrate( - StatefulLinearExponentialIntegrator(), - system, - final_time=final_time, - n_steps=n_steps, - ) - - assert_allclose( - system.linearly_evolving_state, - system.analytical_solution(final_time), - atol=1e-4, - ) - - @pytest.mark.parametrize("explicit_stepper", StatefulExplicitSteppers[:-1]) - def test_explicit_against_analytical_system(self, explicit_stepper): - system = SecondOrderHybridSystem() - final_time = 1.0 - n_steps = 2000 - integrate(explicit_stepper(), system, final_time=final_time, n_steps=n_steps) - - assert_allclose( - system.final_solution(final_time), - system.analytical_solution(final_time), - rtol=Tolerance.rtol() * 1e2, - atol=Tolerance.atol(), - ) -""" - - -class TestSymplecticSteppers: - @pytest.mark.parametrize("stepper", SymplecticSteppers) - def test_symplectic_against_undamped_harmonic_oscillator(self, stepper): - system = SymplecticUndampedSimpleHarmonicOscillatorSystem( - omega=1.0 * np.pi, init_val=np.array([0.2, 0.8]) - ) - final_time = 4.0 * np.pi - n_steps = 2000 - time_stepper = stepper() - integrate(time_stepper, system, final_time=final_time, n_steps=n_steps) - - # Symplectic systems conserve energy to a certain extent - assert_allclose( - *system.compute_energy(final_time), - rtol=Tolerance.rtol() * 1e1, - atol=Tolerance.atol(), - ) - - # assert_allclose( - # system._state, - # system.analytical_solution(final_time), - # rtol=Tolerance.rtol(), - # atol=Tolerance.atol(), - # ) - - -""" - @pytest.mark.xfail - @pytest.mark.parametrize("symplectic_stepper", SymplecticSteppers) - def test_hybrid_symplectic_against_analytical_system(self, symplectic_stepper): - system = SecondOrderHybridSystem() - final_time = 1.0 - n_steps = 2000 - # stepper = SymplecticCosseratRodStepper(symplectic_stepper=symplectic_stepper()) - stepper = symplectic_stepper() - integrate(stepper, system, final_time=final_time, n_steps=n_steps) - - assert_allclose( - system.final_solution(final_time), - system.analytical_solution(final_time), - rtol=Tolerance.rtol() * 1e2, - atol=Tolerance.atol(), - ) -""" +# class TestExplicitSteppers: +# @pytest.mark.parametrize("stepper", StatefulExplicitSteppers) +# def test_against_scalar_exponential(self, stepper): +# system = ScalarExponentialDecaySystem(-1, 1) +# final_time = 1 +# n_steps = 1000 +# integrate(stepper(), system, final_time=final_time, n_steps=n_steps) +# +# assert_allclose( +# system.state, +# system.analytical_solution(final_time), +# rtol=Tolerance.rtol() * 1e3, +# atol=Tolerance.atol(), +# ) +# +# @pytest.mark.parametrize("stepper", StatefulExplicitSteppers[:-1]) +# def test_against_undamped_harmonic_oscillator(self, stepper): +# system = UndampedSimpleHarmonicOscillatorSystem() +# final_time = 4.0 * np.pi +# n_steps = 2000 +# integrate(stepper(), system, final_time=final_time, n_steps=n_steps) +# +# assert_allclose( +# system.state, +# system.analytical_solution(final_time), +# rtol=Tolerance.rtol(), +# atol=Tolerance.atol(), +# ) +# +# @pytest.mark.parametrize("stepper", StatefulExplicitSteppers[:-1]) +# def test_against_damped_harmonic_oscillator(self, stepper): +# system = DampedSimpleHarmonicOscillatorSystem() +# final_time = 4.0 * np.pi +# n_steps = 2000 +# integrate(stepper(), system, final_time=final_time, n_steps=n_steps) +# +# assert_allclose( +# system.state, +# system.analytical_solution(final_time), +# rtol=Tolerance.rtol(), +# atol=Tolerance.atol(), +# ) +# +# def test_linear_exponential_integrator(self): +# system = MultipleFrameRotationSystem(n_frames=128) +# final_time = np.pi +# n_steps = 1000 +# integrate( +# StatefulLinearExponentialIntegrator(), +# system, +# final_time=final_time, +# n_steps=n_steps, +# ) +# +# assert_allclose( +# system.linearly_evolving_state, +# system.analytical_solution(final_time), +# atol=1e-4, +# ) +# +# @pytest.mark.parametrize("explicit_stepper", StatefulExplicitSteppers[:-1]) +# def test_explicit_against_analytical_system(self, explicit_stepper): +# system = SecondOrderHybridSystem() +# final_time = 1.0 +# n_steps = 2000 +# integrate(explicit_stepper(), system, final_time=final_time, n_steps=n_steps) +# +# assert_allclose( +# system.final_solution(final_time), +# system.analytical_solution(final_time), +# rtol=Tolerance.rtol() * 1e2, +# atol=Tolerance.atol(), +# ) class TestSteppersAgainstCollectiveSystems: @@ -332,7 +245,7 @@ def test_explicit_steppers(self, explicit_stepper): # Before stepping, let's extend the interface of the stepper # while providing memory slots - from elastica.systems import make_memory_for_explicit_stepper + from elastica.systems.memory import make_memory_for_explicit_stepper memory_collection = make_memory_for_explicit_stepper(stepper, collective_system) from elastica.timestepper import extend_stepper_interface @@ -408,9 +321,13 @@ def test_symplectics_against_ellipse_motion(self, symplectic_stepper): ) final_time = 1.0 n_steps = 1000 + dt = final_time / n_steps + stepper = symplectic_stepper() - integrate(stepper, rod_like_system, final_time=final_time, n_steps=n_steps) + time = 0.0 + for _ in range(n_steps): + time = stepper.step_single_instance(rod_like_system, time, dt) assert_allclose( rod_like_system.position_collection, diff --git a/tests/test_memory_block_validity.py b/tests/test_memory_block_validity.py index 12fbd234..0c264954 100644 --- a/tests/test_memory_block_validity.py +++ b/tests/test_memory_block_validity.py @@ -749,10 +749,10 @@ def test_periodic_boundary_one_ring_rod(): block_structure = MemoryBlockCosseratRod([ring_rod], [0]) correct_periodic_boundary_node_idx = np.array( - [[0, 6, 7], [5, 1, 2]], dtype=np.int64 + [[0, 6, 7], [5, 1, 2]], dtype=np.int32 ) - correct_periodic_boundary_elem_idx = np.array([[0, 6], [5, 1]], dtype=np.int64) - correct_periodic_boundary_voronoi_idx = np.array([[0], [5]], dtype=np.int64) + correct_periodic_boundary_elem_idx = np.array([[0, 6], [5, 1]], dtype=np.int32) + correct_periodic_boundary_voronoi_idx = np.array([[0], [5]], dtype=np.int32) assert_allclose( correct_periodic_boundary_node_idx, @@ -770,14 +770,14 @@ def test_periodic_boundary_one_ring_rod(): atol=Tolerance.atol(), ) - correct_start_node_idx = np.array([1], dtype=np.int64) - correct_end_node_idx = np.array([6], dtype=np.int64) + correct_start_node_idx = np.array([1], dtype=np.int32) + correct_end_node_idx = np.array([6], dtype=np.int32) - correct_start_elem_idx = np.array([1], dtype=np.int64) - correct_end_elem_idx = np.array([6], dtype=np.int64) + correct_start_elem_idx = np.array([1], dtype=np.int32) + correct_end_elem_idx = np.array([6], dtype=np.int32) - correct_start_voronoi_idx = np.array([1], dtype=np.int64) - correct_end_voronoi_idx = np.array([6], dtype=np.int64) + correct_start_voronoi_idx = np.array([1], dtype=np.int32) + correct_end_voronoi_idx = np.array([6], dtype=np.int32) assert_allclose( correct_start_node_idx, @@ -826,12 +826,12 @@ def test_periodic_boundary_two_ring_rod(): block_structure = MemoryBlockCosseratRod([ring_rod_1, ring_rod_2], [0, 1]) correct_periodic_boundary_node_idx = np.array( - [[0, 6, 7, 9, 15, 16], [5, 1, 2, 14, 10, 11]], dtype=np.int64 + [[0, 6, 7, 9, 15, 16], [5, 1, 2, 14, 10, 11]], dtype=np.int32 ) correct_periodic_boundary_elem_idx = np.array( - [[0, 6, 9, 15], [5, 1, 14, 10]], dtype=np.int64 + [[0, 6, 9, 15], [5, 1, 14, 10]], dtype=np.int32 ) - correct_periodic_boundary_voronoi_idx = np.array([[0, 9], [5, 14]], dtype=np.int64) + correct_periodic_boundary_voronoi_idx = np.array([[0, 9], [5, 14]], dtype=np.int32) assert_allclose( correct_periodic_boundary_node_idx, @@ -849,14 +849,14 @@ def test_periodic_boundary_two_ring_rod(): atol=Tolerance.atol(), ) - correct_start_node_idx = np.array([1, 10], dtype=np.int64) - correct_end_node_idx = np.array([6, 15], dtype=np.int64) + correct_start_node_idx = np.array([1, 10], dtype=np.int32) + correct_end_node_idx = np.array([6, 15], dtype=np.int32) - correct_start_elem_idx = np.array([1, 10], dtype=np.int64) - correct_end_elem_idx = np.array([6, 15], dtype=np.int64) + correct_start_elem_idx = np.array([1, 10], dtype=np.int32) + correct_end_elem_idx = np.array([6, 15], dtype=np.int32) - correct_start_voronoi_idx = np.array([1, 10], dtype=np.int64) - correct_end_voronoi_idx = np.array([6, 15], dtype=np.int64) + correct_start_voronoi_idx = np.array([1, 10], dtype=np.int32) + correct_end_voronoi_idx = np.array([6, 15], dtype=np.int32) assert_allclose( correct_start_node_idx, @@ -906,12 +906,12 @@ def test_periodic_boundary_two_ring_rod_different_nelems(): block_structure = MemoryBlockCosseratRod([ring_rod_1, ring_rod_2], [0, 1]) correct_periodic_boundary_node_idx = np.array( - [[0, 6, 7, 9, 13, 14], [5, 1, 2, 12, 10, 11]], dtype=np.int64 + [[0, 6, 7, 9, 13, 14], [5, 1, 2, 12, 10, 11]], dtype=np.int32 ) correct_periodic_boundary_elem_idx = np.array( - [[0, 6, 9, 13], [5, 1, 12, 10]], dtype=np.int64 + [[0, 6, 9, 13], [5, 1, 12, 10]], dtype=np.int32 ) - correct_periodic_boundary_voronoi_idx = np.array([[0, 9], [5, 12]], dtype=np.int64) + correct_periodic_boundary_voronoi_idx = np.array([[0, 9], [5, 12]], dtype=np.int32) assert_allclose( correct_periodic_boundary_node_idx, @@ -929,14 +929,14 @@ def test_periodic_boundary_two_ring_rod_different_nelems(): atol=Tolerance.atol(), ) - correct_start_node_idx = np.array([1, 10], dtype=np.int64) - correct_end_node_idx = np.array([6, 13], dtype=np.int64) + correct_start_node_idx = np.array([1, 10], dtype=np.int32) + correct_end_node_idx = np.array([6, 13], dtype=np.int32) - correct_start_elem_idx = np.array([1, 10], dtype=np.int64) - correct_end_elem_idx = np.array([6, 13], dtype=np.int64) + correct_start_elem_idx = np.array([1, 10], dtype=np.int32) + correct_end_elem_idx = np.array([6, 13], dtype=np.int32) - correct_start_voronoi_idx = np.array([1, 10], dtype=np.int64) - correct_end_voronoi_idx = np.array([6, 13], dtype=np.int64) + correct_start_voronoi_idx = np.array([1, 10], dtype=np.int32) + correct_end_voronoi_idx = np.array([6, 13], dtype=np.int32) assert_allclose( correct_start_node_idx, @@ -985,10 +985,10 @@ def test_periodic_boundary_one_ring_one_straight_rod(): block_structure = MemoryBlockCosseratRod([ring_rod, straight_rod], [0, 1]) correct_periodic_boundary_node_idx = np.array( - [[7, 13, 14], [12, 8, 9]], dtype=np.int64 + [[7, 13, 14], [12, 8, 9]], dtype=np.int32 ) - correct_periodic_boundary_elem_idx = np.array([[7, 13], [12, 8]], dtype=np.int64) - correct_periodic_boundary_voronoi_idx = np.array([[7], [12]], dtype=np.int64) + correct_periodic_boundary_elem_idx = np.array([[7, 13], [12, 8]], dtype=np.int32) + correct_periodic_boundary_voronoi_idx = np.array([[7], [12]], dtype=np.int32) assert_allclose( correct_periodic_boundary_node_idx, @@ -1006,14 +1006,14 @@ def test_periodic_boundary_one_ring_one_straight_rod(): atol=Tolerance.atol(), ) - correct_start_node_idx = np.array([0, 8], dtype=np.int64) - correct_end_node_idx = np.array([6, 13], dtype=np.int64) + correct_start_node_idx = np.array([0, 8], dtype=np.int32) + correct_end_node_idx = np.array([6, 13], dtype=np.int32) - correct_start_elem_idx = np.array([0, 8], dtype=np.int64) - correct_end_elem_idx = np.array([5, 13], dtype=np.int64) + correct_start_elem_idx = np.array([0, 8], dtype=np.int32) + correct_end_elem_idx = np.array([5, 13], dtype=np.int32) - correct_start_voronoi_idx = np.array([0, 8], dtype=np.int64) - correct_end_voronoi_idx = np.array([4, 13], dtype=np.int64) + correct_start_voronoi_idx = np.array([0, 8], dtype=np.int32) + correct_end_voronoi_idx = np.array([4, 13], dtype=np.int32) assert_allclose( correct_start_node_idx, diff --git a/tests/test_modules/test_base_system.py b/tests/test_modules/test_base_system.py index 76f63cd4..694bd84f 100644 --- a/tests/test_modules/test_base_system.py +++ b/tests/test_modules/test_base_system.py @@ -25,11 +25,12 @@ def test_check_type_with_illegal_type_throws(self, illegal_type): @pytest.fixture(scope="class") def load_collection(self): bsc = BaseSystemCollection() + bsc.extend_allowed_types((int, float, str, np.ndarray)) # Bypass check, but its fine for testing - bsc._systems.append(3) - bsc._systems.append(5.0) - bsc._systems.append("a") - bsc._systems.append(np.random.randn(3, 5)) + bsc.append(3) + bsc.append(5.0) + bsc.append("a") + bsc.append(np.random.randn(3, 5)) return bsc def test_len(self, load_collection): @@ -70,12 +71,12 @@ def test_str(self, load_collection): def test_extend_allowed_types(self, load_collection): bsc = load_collection - bsc.extend_allowed_types((int, float, str)) from elastica.rod import RodBase from elastica.rigidbody import RigidBodyBase from elastica.surface import SurfaceBase + # Types are extended in the fixture assert bsc.allowed_sys_types == ( RodBase, RigidBodyBase, @@ -83,6 +84,7 @@ def test_extend_allowed_types(self, load_collection): int, float, str, + np.ndarray, ) def test_extend_correctness(self, load_collection): @@ -121,11 +123,11 @@ def test_invalid_idx_in_get_sys_index_throws(self, load_collection): bsc = load_collection bsc.override_allowed_types((RodBase,)) with pytest.raises(AssertionError) as excinfo: - bsc._get_sys_idx_if_valid(100) + bsc.get_system_index(100) assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - load_collection._get_sys_idx_if_valid(np.int_(100)) + load_collection.get_system_index(np.int32(100)) assert "exceeds number of" in str(excinfo.value) def test_unregistered_system_in_get_sys_index_throws( @@ -135,11 +137,11 @@ def test_unregistered_system_in_get_sys_index_throws( my_mock_rod = mock_rod with pytest.raises(ValueError) as excinfo: - load_collection._get_sys_idx_if_valid(my_mock_rod) + load_collection.get_system_index(my_mock_rod) assert "was not found, did you" in str(excinfo.value) def test_get_sys_index_returns_correct_idx(self, load_collection): - assert load_collection._get_sys_idx_if_valid(1) == 1 + assert load_collection.get_system_index(1) == 1 @pytest.mark.xfail def test_delitem(self, load_collection): @@ -171,7 +173,7 @@ def load_collection(self): youngs_modulus=1, ) # Bypass check, but its fine for testing - sc._systems.append(rod) + sc.append(rod) return sc, rod @@ -183,7 +185,9 @@ def test_constraint(self, load_collection, legal_constraint): simulator_class.constrain(rod).using(legal_constraint) simulator_class.finalize() # After finalize check if the created constrain object is instance of the class we have given. - assert isinstance(simulator_class._constraints[-1][-1], legal_constraint) + assert isinstance( + simulator_class._constraints_operators[-1][-1], legal_constraint + ) # TODO: this is a dummy test for constrain values and rates find a better way to test them simulator_class.constrain_values(time=0) @@ -221,7 +225,7 @@ def test_callback(self, load_collection, legal_callback): simulator_class.collect_diagnostics(rod).using(legal_callback) simulator_class.finalize() # After finalize check if the created callback object is instance of the class we have given. - assert isinstance(simulator_class._callback_list[-1][-1], legal_callback) + assert isinstance(simulator_class._callback_operators[-1][-1], legal_callback) # TODO: this is a dummy test for apply_callbacks find a better way to test them simulator_class.apply_callbacks(time=0, current_step=0) diff --git a/tests/test_modules/test_callbacks.py b/tests/test_modules/test_callbacks.py index 57d4a426..a58901e2 100644 --- a/tests/test_modules/test_callbacks.py +++ b/tests/test_modules/test_callbacks.py @@ -55,7 +55,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: - _ = callback() + _ = callback.instantiate() assert "Unable to construct" in str(excinfo.value) @@ -80,7 +80,7 @@ def load_system_with_callbacks(self, request): sys_coll_with_callbacks.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_callbacks - """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + """ The following calls test get_system_index from BaseSystem indirectly, and are here because of legacy reasons. I have not removed them because there are Callbacks require testing against multiple indices, which is still use ful to cross-verify against. @@ -96,7 +96,7 @@ def test_callback_with_illegal_index_throws(self, load_system_with_callbacks): assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - scwc.collect_diagnostics(np.int_(100)) + scwc.collect_diagnostics(np.int32(100)) assert "exceeds number of" in str(excinfo.value) def test_callback_with_unregistered_system_throws(self, load_system_with_callbacks): @@ -164,10 +164,12 @@ def test_callback_finalize_correctness(self, load_rod_with_callbacks): scwc._finalize_callback() - for x, y in scwc._callback_list: + for x, y in scwc._callback_operators: assert type(x) is int assert type(y) is callback_cls + assert not hasattr(scwc, "_callback_list") + @pytest.mark.xfail def test_callback_finalize_sorted(self, load_rod_with_callbacks): scwc, callback_cls = load_rod_with_callbacks diff --git a/tests/test_modules/test_connections.py b/tests/test_modules/test_connections.py index 334ece41..eadf98a0 100644 --- a/tests/test_modules/test_connections.py +++ b/tests/test_modules/test_connections.py @@ -108,7 +108,7 @@ def test_set_index_with_illegal_type_second_idx_throws( # Below test is to increase code coverage. If we pass nothing or idx=None, then do nothing. def test_set_index_no_input(self, load_connect): - load_connect.set_index(first_idx=None, second_idx=None) + load_connect.set_index(first_idx=(), second_idx=()) @pytest.mark.parametrize( "legal_idx", [(80, 80), (0, 50), (50, 0), (-20, -20), (-20, 50), (-50, -20)] @@ -201,7 +201,7 @@ def load_system_with_connects(self, request): sys_coll_with_connects.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_connects - """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + """ The following calls test get_system_index from BaseSystem indirectly, and are here because of legacy reasons. I have not removed them because there are Connections require testing against multiple indices, which is still use ful to cross-verify against. @@ -232,7 +232,7 @@ def test_connect_with_illegal_index_throws( assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - system_collection_with_connections.connect(*[np.int_(x) for x in sys_idx]) + system_collection_with_connections.connect(*[np.int32(x) for x in sys_idx]) assert "exceeds number of" in str(excinfo.value) def test_connect_with_unregistered_system_throws(self, load_system_with_connects): @@ -291,7 +291,8 @@ def test_connect_registers_and_returns_Connect(self, load_system_with_connects): assert _mock_connect in system_collection_with_connections._connections assert _mock_connect.__class__ == _Connect # check sane defaults provided for connection indices - assert _mock_connect.id()[2] is None and _mock_connect.id()[3] is None + assert _mock_connect.id()[2] == () + assert _mock_connect.id()[3] == () from elastica.joint import FreeJoint @@ -313,15 +314,15 @@ def mock_init(self, *args, **kwargs): ) # Constrain any and all systems - system_collection_with_connections.connect(0, 1).using( - MockConnect, 2, 42 - ) # index based connect + # system_collection_with_connections.connect(0, 1).using( + # MockConnect, 2, 42 + # ) # index based connect system_collection_with_connections.connect(mock_rod_one, mock_rod_two).using( MockConnect, 2, 3 ) # system based connect - system_collection_with_connections.connect(0, mock_rod_one).using( - MockConnect, 1, 2 - ) # index/system based connect + # system_collection_with_connections.connect(0, mock_rod_one).using( + # MockConnect, 1, 2 + # ) # index/system based connect return system_collection_with_connections, MockConnect @@ -401,20 +402,20 @@ def test_connect_call_on_systems(self, load_rod_with_connects_and_indices): connect = connection.instantiate() end_distance_vector = ( - system_collection_with_connections_and_indices._systems[ + system_collection_with_connections_and_indices[ sidx ].position_collection[..., sconnect] - - system_collection_with_connections_and_indices._systems[ + - system_collection_with_connections_and_indices[ fidx ].position_collection[..., fconnect] ) elastic_force = connect.k * end_distance_vector relative_velocity = ( - system_collection_with_connections_and_indices._systems[ + system_collection_with_connections_and_indices[ sidx ].velocity_collection[..., sconnect] - - system_collection_with_connections_and_indices._systems[ + - system_collection_with_connections_and_indices[ fidx ].velocity_collection[..., fconnect] ) @@ -423,16 +424,16 @@ def test_connect_call_on_systems(self, load_rod_with_connects_and_indices): contact_force = elastic_force + damping_force assert_allclose( - system_collection_with_connections_and_indices._systems[ - fidx - ].external_forces[..., fconnect], + system_collection_with_connections_and_indices[fidx].external_forces[ + ..., fconnect + ], contact_force, atol=Tolerance.atol(), ) assert_allclose( - system_collection_with_connections_and_indices._systems[ - sidx - ].external_forces[..., sconnect], + system_collection_with_connections_and_indices[sidx].external_forces[ + ..., sconnect + ], -1 * contact_force, atol=Tolerance.atol(), ) diff --git a/tests/test_modules/test_constraints.py b/tests/test_modules/test_constraints.py index aeee1873..c6072c39 100644 --- a/tests/test_modules/test_constraints.py +++ b/tests/test_modules/test_constraints.py @@ -45,7 +45,7 @@ def test_call_without_setting_constraint_throws_runtime_error( constraint = load_constraint with pytest.raises(RuntimeError) as excinfo: - constraint(None) # None is the rod/system parameter + constraint.instantiate(None) # None is the rod/system parameter assert "No boundary condition" in str(excinfo.value) def test_call_without_position_director_kwargs(self, load_constraint): @@ -60,7 +60,7 @@ def mock_init(self, *args, **kwargs): constraint.using(MockBC, 3.9, 4.0, "5", k=1, l_var="2", j=3.0) # Actual test is here, this should not throw - mock_bc = constraint(None) # None is Fake rod + mock_bc = constraint.instantiate(None) # None is Fake rod # More tests reinforcing the first assert mock_bc.dummy_one == 3.9 @@ -93,7 +93,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw mock_rod = self.MockRod() - mock_bc = constraint(mock_rod) + mock_bc = constraint.instantiate(mock_rod) # More tests reinforcing the first for pos_idx_in_rod, pos_idx_in_bc in zip(position_indices, range(3)): @@ -125,7 +125,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw mock_rod = self.MockRod() - mock_bc = constraint(mock_rod) + mock_bc = constraint.instantiate(mock_rod) # More tests reinforcing the first for dir_idx_in_rod, dir_idx_in_bc in zip(director_indices, range(3)): @@ -160,7 +160,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw mock_rod = self.MockRod() - mock_bc = constraint(mock_rod) + mock_bc = constraint.instantiate(mock_rod) # More tests reinforcing the first pos_dir_offset = len(dof_indices) @@ -202,7 +202,7 @@ def mock_init(self, nu, **kwargs): mock_rod = self.MockRod() # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: - _ = constraint(mock_rod) + _ = constraint.instantiate(mock_rod) assert "Unable to construct" in str(excinfo.value) @@ -227,7 +227,7 @@ def load_system_with_constraints(self, request): sys_coll_with_constraints.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_constraints - """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + """ The following calls test get_system_index from BaseSystem indirectly, and are here because of legacy reasons. I have not removed them because there are Connections require testing against multiple indices, which is still use ful to cross-verify against. @@ -243,7 +243,7 @@ def test_constrain_with_illegal_index_throws(self, load_system_with_constraints) assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - scwc.constrain(np.int_(100)) + scwc.constrain(np.int32(100)) assert "exceeds number of" in str(excinfo.value) def test_constrain_with_unregistered_system_throws( @@ -281,7 +281,7 @@ def test_constrain_registers_and_returns_Constraint( scwc.append(mock_rod) _mock_constraint = scwc.constrain(mock_rod) - assert _mock_constraint in scwc._constraints + assert _mock_constraint in scwc._constraints_list assert _mock_constraint.__class__ == _Constraint from elastica.boundary_conditions import ConstraintBase @@ -318,7 +318,7 @@ def test_constrain_finalize_correctness(self, load_rod_with_constraints): scwc._finalize_constraints() - for x, y in scwc._constraints: + for x, y in scwc._constraints_operators: assert type(x) is int assert type(y) is bc_cls @@ -327,12 +327,12 @@ def test_constraint_properties(self, load_rod_with_constraints): scwc._finalize_constraints() for i in [0, 1, -1]: - x, y = scwc._constraints[i] - mock_rod = scwc._systems[i] + x, y = scwc._constraints_operators[i] + mock_rod = scwc[i] # Test system assert type(x) is int assert type(y.system) is type(mock_rod) - assert y.system is mock_rod, f"{len(scwc._systems)}" + assert y.system is mock_rod, f"{len(scwc)}" # Test node indices assert y.constrained_position_idx.size == 0 # Test element indices. TODO: maybe add more generalized test @@ -346,7 +346,7 @@ def test_constrain_finalize_sorted(self, load_rod_with_constraints): # this is allowed to fail (not critical) num = -np.inf - for x, _ in scwc._constraints: + for x, _ in scwc._constraints_list: assert num < x num = x diff --git a/tests/test_modules/test_contact.py b/tests/test_modules/test_contact.py index 82c41f69..79abb7c5 100644 --- a/tests/test_modules/test_contact.py +++ b/tests/test_modules/test_contact.py @@ -125,7 +125,7 @@ def load_system_with_contacts(self, request): sys_coll_with_contacts.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_contacts - """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + """ The following calls test get_system_index from BaseSystem indirectly, and are here because of legacy reasons. I have not removed them because there are Contacts require testing against multiple indices, which is still use ful to cross-verify against. @@ -157,7 +157,7 @@ def test_contact_with_illegal_index_throws( with pytest.raises(AssertionError) as excinfo: system_collection_with_contacts.detect_contact_between( - *[np.int_(x) for x in sys_idx] + *[np.int32(x) for x in sys_idx] ) assert "exceeds number of" in str(excinfo.value) @@ -240,9 +240,17 @@ def mock_init(self, *args, **kwargs): pass # in place class - MockContact = type( - "MockContact", (self.NoContact, object), {"__init__": mock_init} - ) + class MockContact(self.NoContact): + def __init__(self, *args, **kwargs): + pass + + @property + def _allowed_system_one(self): + return [TestContactMixin.MockRod] + + @property + def _allowed_system_two(self): + return [TestContactMixin.MockRod] # Constrain any and all systems system_collection_with_contacts.detect_contact_between(0, 1).using( @@ -283,9 +291,17 @@ def mock_init(self, *args, **kwargs): pass # in place class - MockContact = type( - "MockContact", (self.NoContact, object), {"__init__": mock_init} - ) + class MockContact(self.NoContact): + def __init__(self, *args, **kwargs): + pass + + @property + def _allowed_system_one(self): + return [TestContactMixin.MockRod] + + @property + def _allowed_system_two(self): + return [TestContactMixin.MockRigidBody] # incorrect order contact system_collection_with_contacts.detect_contact_between( @@ -302,17 +318,12 @@ def test_contact_check_order(self, load_contact_objects_with_incorrect_order): contact_cls, ) = load_contact_objects_with_incorrect_order - mock_rod = self.MockRod(2, 3, 4, 5) - mock_rigid_body = self.MockRigidBody(5.0, 5.0) - with pytest.raises(TypeError) as excinfo: system_collection_with_contacts._finalize_contact() assert ( - "Systems provided to the contact class have incorrect order. \n" - " First system is {0} and second system is {1}. \n" - " If the first system is a rod, the second system can be a rod, rigid body or surface. \n" - " If the first system is a rigid body, the second system can be a rigid body or surface." - ).format(mock_rigid_body.__class__, mock_rod.__class__) == str(excinfo.value) + "System provided (MockRigidBody) must be derived from ['MockRod']" + in str(excinfo.value) + ) @pytest.fixture def load_system_with_rods_in_contact(self, load_system_with_contacts): @@ -353,8 +364,8 @@ def test_contact_call_on_systems(self, load_system_with_rods_in_contact): fidx, sidx = _contact.id() contact = _contact.instantiate() - system_one = system_collection_with_rods_in_contact._systems[fidx] - system_two = system_collection_with_rods_in_contact._systems[sidx] + system_one = system_collection_with_rods_in_contact[fidx] + system_two = system_collection_with_rods_in_contact[sidx] external_forces_system_one = np.zeros_like(system_one.external_forces) external_forces_system_two = np.zeros_like(system_two.external_forces) @@ -382,12 +393,12 @@ def test_contact_call_on_systems(self, load_system_with_rods_in_contact): ) assert_allclose( - system_collection_with_rods_in_contact._systems[fidx].external_forces, + system_collection_with_rods_in_contact[fidx].external_forces, external_forces_system_one, atol=Tolerance.atol(), ) assert_allclose( - system_collection_with_rods_in_contact._systems[sidx].external_forces, + system_collection_with_rods_in_contact[sidx].external_forces, external_forces_system_two, atol=Tolerance.atol(), ) diff --git a/tests/test_modules/test_damping.py b/tests/test_modules/test_damping.py index d36f0187..2634da84 100644 --- a/tests/test_modules/test_damping.py +++ b/tests/test_modules/test_damping.py @@ -39,7 +39,7 @@ def test_call_without_setting_damper_throws_runtime_error(self, load_damper): damper = load_damper with pytest.raises(RuntimeError) as excinfo: - damper(None) # None is the rod/system parameter + damper.instantiate(None) # None is the rod/system parameter assert "No damper" in str(excinfo.value) def test_call_with_args_and_kwargs(self, load_damper): @@ -56,7 +56,7 @@ def mock_init(self, *args, **kwargs): damper.using(MockDamper, 3.9, 4.0, "5", k=1, l_var="2", j=3.0) # Actual test is here, this should not throw - mock_damper = damper(None) # None is Fake rod + mock_damper = damper.instantiate(None) # None is Fake rod # More tests reinforcing the first assert mock_damper.dummy_one == 3.9 @@ -78,7 +78,7 @@ def test_call_improper_bc_throws_type_error(self, load_damper): mock_rod = self.MockRod() # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: - _ = damper(mock_rod) + _ = damper.instantiate(mock_rod) assert "Unable to construct" in str(excinfo.value) @@ -103,7 +103,7 @@ def load_system_with_dampers(self, request): sys_coll_with_dampers.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_dampers - """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + """ The following calls test get_system_index from BaseSystem indirectly, and are here because of legacy reasons. I have not removed them because there are Connections require testing against multiple indices, which is still use ful to cross-verify against. @@ -119,7 +119,7 @@ def test_dampen_with_illegal_index_throws(self, load_system_with_dampers): assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - scwd.dampen(np.int_(100)) + scwd.dampen(np.int32(100)) assert "exceeds number of" in str(excinfo.value) def test_dampen_with_unregistered_system_throws(self, load_system_with_dampers): @@ -153,7 +153,7 @@ def test_dampen_registers_and_returns_Damper(self, load_system_with_dampers): scwd.append(mock_rod) _mock_damper = scwd.dampen(mock_rod) - assert _mock_damper in scwd._dampers + assert _mock_damper in scwd._damping_list assert _mock_damper.__class__ == _Damper from elastica.dissipation import DamperBase @@ -185,7 +185,7 @@ def test_dampen_finalize_correctness(self, load_rod_with_dampers): scwd._finalize_dampers() - for x, y in scwd._dampers: + for x, y in scwd._damping_operators: assert type(x) is int assert type(y) is damper_cls @@ -194,12 +194,12 @@ def test_damper_properties(self, load_rod_with_dampers): scwd._finalize_dampers() for i in [0, 1, -1]: - x, y = scwd._dampers[i] - mock_rod = scwd._systems[i] + x, y = scwd._damping_operators[i] + mock_rod = scwd[i] # Test system assert type(x) is int assert type(y.system) is type(mock_rod) - assert y.system is mock_rod, f"{len(scwd._systems)}" + assert y.system is mock_rod, f"{len(scwd)}" @pytest.mark.xfail def test_dampers_finalize_sorted(self, load_rod_with_dampers): diff --git a/tests/test_modules/test_forcing.py b/tests/test_modules/test_forcing.py index bd384fc6..4f621fc4 100644 --- a/tests/test_modules/test_forcing.py +++ b/tests/test_modules/test_forcing.py @@ -87,7 +87,7 @@ def load_system_with_forcings(self, request): sys_coll_with_forcings.append(self.MockRod(2, 3, 4, 5)) return sys_coll_with_forcings - """ The following calls test _get_sys_idx_if_valid from BaseSystem indirectly, + """ The following calls test get_system_index from BaseSystem indirectly, and are here because of legacy reasons. I have not removed them because there are Connections require testing against multiple indices, which is still use ful to cross-verify against. @@ -103,7 +103,7 @@ def test_constrain_with_illegal_index_throws(self, load_system_with_forcings): assert "exceeds number of" in str(excinfo.value) with pytest.raises(AssertionError) as excinfo: - scwf.add_forcing_to(np.int_(100)) + scwf.add_forcing_to(np.int32(100)) assert "exceeds number of" in str(excinfo.value) def test_constrain_with_unregistered_system_throws(self, load_system_with_forcings): diff --git a/tests/test_modules/test_memory_block_rigid_body.py b/tests/test_modules/test_memory_block_rigid_body.py index 7eb4882e..07b11d04 100644 --- a/tests/test_modules/test_memory_block_rigid_body.py +++ b/tests/test_modules/test_memory_block_rigid_body.py @@ -81,7 +81,7 @@ def test_memory_block_rigid_body(n_bodies): memory_block = MemoryBlockRigidBody(systems, system_idx_list) - assert memory_block.n_bodies == n_bodies + assert memory_block.n_systems == n_bodies assert memory_block.n_elems == n_bodies assert memory_block.n_nodes == n_bodies diff --git a/tests/test_modules/test_memory_block_base.py b/tests/test_modules/test_memory_block_utils.py similarity index 70% rename from tests/test_modules/test_memory_block_base.py rename to tests/test_modules/test_memory_block_utils.py index 9dde3923..663a3aa2 100644 --- a/tests/test_modules/test_memory_block_base.py +++ b/tests/test_modules/test_memory_block_utils.py @@ -3,7 +3,7 @@ import pytest import numpy as np from numpy.testing import assert_array_equal -from elastica.memory_block.memory_block_rod_base import ( +from elastica.memory_block.utils import ( make_block_memory_metadata, make_block_memory_periodic_boundary_metadata, ) @@ -13,48 +13,48 @@ "n_elems_in_rods, outputs", [ ( - np.array([5], dtype=np.int64), + np.array([5], dtype=np.int32), [ np.int64(5), - np.array([], dtype=np.int64), - np.array([], dtype=np.int64), - np.array([], dtype=np.int64), + np.array([], dtype=np.int32), + np.array([], dtype=np.int32), + np.array([], dtype=np.int32), ], ), ( - np.array([5, 5], dtype=np.int64), + np.array([5, 5], dtype=np.int32), [ np.int64(12), - np.array([6], dtype=np.int64), - np.array([5, 6], dtype=np.int64), - np.array([4, 5, 6], dtype=np.int64), + np.array([6], dtype=np.int32), + np.array([5, 6], dtype=np.int32), + np.array([4, 5, 6], dtype=np.int32), ], ), ( - np.array([1, 1, 1], dtype=np.int64), + np.array([1, 1, 1], dtype=np.int32), [ np.int64(7), - np.array([2, 5], dtype=np.int64), - np.array([1, 2, 4, 5], dtype=np.int64), - np.array([0, 1, 2, 3, 4, 5], dtype=np.int64), + np.array([2, 5], dtype=np.int32), + np.array([1, 2, 4, 5], dtype=np.int32), + np.array([0, 1, 2, 3, 4, 5], dtype=np.int32), ], ), ( - np.array([1, 2, 3], dtype=np.int64), + np.array([1, 2, 3], dtype=np.int32), [ np.int64(10), - np.array([2, 6], dtype=np.int64), - np.array([1, 2, 5, 6], dtype=np.int64), - np.array([0, 1, 2, 4, 5, 6], dtype=np.int64), + np.array([2, 6], dtype=np.int32), + np.array([1, 2, 5, 6], dtype=np.int32), + np.array([0, 1, 2, 4, 5, 6], dtype=np.int32), ], ), ( - np.array([10, 10, 5, 5], dtype=np.int64), + np.array([10, 10, 5, 5], dtype=np.int32), [ np.int64(36), - np.array([11, 23, 30], dtype=np.int64), - np.array([10, 11, 22, 23, 29, 30], dtype=np.int64), - np.array([9, 10, 11, 21, 22, 23, 28, 29, 30], dtype=np.int64), + np.array([11, 23, 30], dtype=np.int32), + np.array([10, 11, 22, 23, 29, 30], dtype=np.int32), + np.array([9, 10, 11, 21, 22, 23, 28, 29, 30], dtype=np.int32), ], ), ], @@ -76,10 +76,10 @@ def test_make_block_memory_metadata(n_elems_in_rods, outputs): @pytest.mark.parametrize( "n_elems_in_ring_rods", [ - np.array([10], dtype=np.int64), - np.array([2, 4], dtype=np.int64), - np.array([4, 5, 7], dtype=np.int64), - np.array([10, 10, 10, 10], dtype=np.int64), + np.array([10], dtype=np.int32), + np.array([2, 4], dtype=np.int32), + np.array([4, 5, 7], dtype=np.int32), + np.array([10, 10, 10, 10], dtype=np.int32), ], ) def test_make_block_memory_periodic_boundary_metadata(n_elems_in_ring_rods): @@ -92,9 +92,9 @@ def test_make_block_memory_periodic_boundary_metadata(n_elems_in_ring_rods): n_ring_rods = len(n_elems_in_ring_rods) expected_n_elem = n_elems_in_ring_rods + 2 - expected_node_idx = np.empty((2, 3 * n_ring_rods), dtype=np.int64) - expected_element_idx = np.empty((2, 2 * n_ring_rods), dtype=np.int64) - expected_voronoi_idx = np.empty((2, n_ring_rods), dtype=np.int64) + expected_node_idx = np.empty((2, 3 * n_ring_rods), dtype=np.int32) + expected_element_idx = np.empty((2, 2 * n_ring_rods), dtype=np.int32) + expected_voronoi_idx = np.empty((2, n_ring_rods), dtype=np.int32) accumulation = np.hstack((0, np.cumsum(n_elems_in_ring_rods[:-1] + 4))) diff --git a/tests/test_restart.py b/tests/test_restart.py index 0f814b21..28d30595 100644 --- a/tests/test_restart.py +++ b/tests/test_restart.py @@ -41,20 +41,20 @@ def load_collection(self): youngs_modulus=1, ) # Bypass check, but its fine for testing - sc._systems.append(rod) + sc.append(rod) # Also add rods to a separate list rod_list.append(rod) return sc, rod_list - def test_restart_save_load(self, load_collection): + def test_restart_save_load(self, tmp_path, load_collection): simulator_class, rod_list = load_collection # Finalize simulator simulator_class.finalize() - directory = "restart_test_data/" + directory = (tmp_path / "restart_test_data").as_posix() time = np.random.rand() # save state @@ -79,7 +79,7 @@ def test_restart_save_load(self, load_collection): assert_allclose(test_value, correct_value) - def run_sim(self, final_time, load_from_restart, save_data_restart): + def run_sim(self, final_time, load_from_restart, save_data_restart, tmp_path): class BaseSimulatorClass( BaseSystemCollection, Constraints, Forcing, Connections, CallBacks ): @@ -100,7 +100,7 @@ class BaseSimulatorClass( youngs_modulus=1, ) # Bypass check, but its fine for testing - simulator_class._systems.append(rod) + simulator_class.append(rod) # Also add rods to a separate list rod_list.append(rod) @@ -118,7 +118,7 @@ class BaseSimulatorClass( # Finalize simulator simulator_class.finalize() - directory = "restart_test_data/" + directory = (tmp_path / "restart_test_data").as_posix() time_step = 1e-4 total_steps = int(final_time / time_step) @@ -149,22 +149,31 @@ class BaseSimulatorClass( return recorded_list @pytest.mark.parametrize("final_time", [0.2, 1.0]) - def test_save_restart_run_sim(self, final_time): + def test_save_restart_run_sim(self, tmp_path, final_time): # First half of simulation _ = self.run_sim( - final_time / 2, load_from_restart=False, save_data_restart=True + final_time / 2, + load_from_restart=False, + save_data_restart=True, + tmp_path=tmp_path, ) # Second half of simulation recorded_list = self.run_sim( - final_time / 2, load_from_restart=True, save_data_restart=False + final_time / 2, + load_from_restart=True, + save_data_restart=False, + tmp_path=tmp_path, ) recorded_list_second_half = recorded_list.copy() # Full simulation recorded_list = self.run_sim( - final_time, load_from_restart=False, save_data_restart=False + final_time, + load_from_restart=False, + save_data_restart=False, + tmp_path=tmp_path, ) recorded_list_full_sim = recorded_list.copy() @@ -190,20 +199,20 @@ def load_collection(self): density=1, ) # Bypass check, but its fine for testing - sc._systems.append(cylinder) + sc.append(cylinder) # Also add rods to a separate list cylinder_list.append(cylinder) return sc, cylinder_list - def test_restart_save_load(self, load_collection): + def test_restart_save_load(self, tmp_path, load_collection): simulator_class, cylinder_list = load_collection # Finalize simulator simulator_class.finalize() - directory = "restart_test_data/" + directory = (tmp_path / "restart_test_data").as_posix() time = np.random.rand() # save state diff --git a/tests/test_rigid_body/test_rigid_body_data_structures.py b/tests/test_rigid_body/test_rigid_body_data_structures.py index b482202b..0ca7715a 100644 --- a/tests/test_rigid_body/test_rigid_body_data_structures.py +++ b/tests/test_rigid_body/test_rigid_body_data_structures.py @@ -64,7 +64,7 @@ def __init__(self, start_position, start_director): # Givees position, director etc. super(SimpleSystemWithPositionsDirectors, self).__init__() - def _compute_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques(self, time): pass def update_accelerations(self, time): diff --git a/tests/test_rod/test_knot_theory.py b/tests/test_rod/test_knot_theory.py index a5630d73..a74cc252 100644 --- a/tests/test_rod/test_knot_theory.py +++ b/tests/test_rod/test_knot_theory.py @@ -10,13 +10,14 @@ from elastica.rod.rod_base import RodBase from elastica.rod.knot_theory import ( - KnotTheoryCompatibleProtocol, compute_twist, compute_writhe, compute_link, _compute_additional_segment, ) +from elastica.rod.protocol import CosseratRodProtocol + @pytest.fixture def knot_theory(): @@ -28,7 +29,7 @@ def knot_theory(): def test_knot_theory_protocol(): # To clear the protocol test coverage with pytest.raises(TypeError) as e_info: - protocol = KnotTheoryCompatibleProtocol() + protocol = CosseratRodProtocol() assert "cannot be instantiated" in e_info @@ -39,9 +40,6 @@ def __init__(self): self.radius = np.random.randn(MaxDimension.value(), self.n_elems) rod = TestRodWithKnotTheory() - assert hasattr( - rod, "MIXIN_PROTOCOL" - ), "Expected to mix-in variables: MIXIN_PROTOCOL" assert hasattr( rod, "compute_writhe" ), "Expected to mix-in functionals into the rod class: compute_writhe" diff --git a/tests/test_synchronize_periodic_boundary.py b/tests/test_synchronize_periodic_boundary.py index 08a1787c..b20d776d 100644 --- a/tests/test_synchronize_periodic_boundary.py +++ b/tests/test_synchronize_periodic_boundary.py @@ -28,7 +28,7 @@ def test_synchronize_periodic_boundary_vector(n_elems): input_vector = np.random.random((3, n_elems + 3)) - periodic_idx = np.zeros((2, 3), dtype=np.int64) + periodic_idx = np.zeros((2, 3), dtype=np.int32) periodic_idx[0, 0] = 0 periodic_idx[0, 1] = -2 periodic_idx[0, 2] = -1 @@ -63,7 +63,7 @@ def test_synchronize_periodic_boundary_matrix(n_elems): input_matrix = np.random.random((3, 3, n_elems + 3)) - periodic_idx = np.zeros((2, 3), dtype=np.int64) + periodic_idx = np.zeros((2, 3), dtype=np.int32) periodic_idx[0, 0] = 0 periodic_idx[0, 1] = -2 periodic_idx[0, 2] = -1 @@ -98,7 +98,7 @@ def test_synchronize_periodic_boundary_scalar(n_elems): input_matrix = np.random.random((n_elems + 3)) - periodic_idx = np.zeros((2, 3), dtype=np.int64) + periodic_idx = np.zeros((2, 3), dtype=np.int32) periodic_idx[0, 0] = 0 periodic_idx[0, 1] = -2 periodic_idx[0, 2] = -1