Skip to content

Commit

Permalink
Math: Implement bivariate Gaussian copula (WIP)
Browse files Browse the repository at this point in the history
Co-authored-by: Amrita Goswami <[email protected]>
  • Loading branch information
MSallermann and amritagos committed Mar 23, 2024
1 parent f7386a3 commit d140126
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 2 deletions.
110 changes: 110 additions & 0 deletions include/util/erfinv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#pragma once
#include <cmath>
#include <limits>

namespace Seldon::Math
{

/**
 * @brief Inverse of the error function: returns y such that erf( y ) == x.
 *
 * Implementation adapted from https://github.com/lakshayg/erfinv, the same
 * as used in the golang math library. Three rational approximations
 * (numerator / denominator polynomials evaluated with Horner's scheme) are
 * used, selected by the magnitude of x:
 *   - |x| <= 0.85                 : coefficients A / B
 *   - r <= 5                      : coefficients C / D
 *   - otherwise                   : coefficients E / F
 * where r = sqrt( ln 2 - log1p( -|x| ) ).
 *
 * @tparam T floating point type of the argument and of the result
 * @param x argument, expected in [-1, 1]
 * @return erfinv( x ); +infinity for x == 1, -infinity for x == -1 and a
 *         quiet NaN for x outside of [-1, 1]
 */
template<typename T>
T erfinv_golang_math( T x )
{
    // Handle arguments outside of the open interval (-1, 1) up front
    if( x < -1 || x > 1 )
    {
        return std::numeric_limits<T>::quiet_NaN();
    }
    else if( x == 1.0 )
    {
        return std::numeric_limits<T>::infinity();
    }
    else if( x == -1.0 )
    {
        return -std::numeric_limits<T>::infinity();
    }

    const T LN2 = 6.931471805599453094172321214581e-1L;

    // Central region, numerator
    const T A0 = 1.1975323115670912564578e0L;
    const T A1 = 4.7072688112383978012285e1L;
    const T A2 = 6.9706266534389598238465e2L;
    const T A3 = 4.8548868893843886794648e3L;
    const T A4 = 1.6235862515167575384252e4L;
    const T A5 = 2.3782041382114385731252e4L;
    const T A6 = 1.1819493347062294404278e4L;
    const T A7 = 8.8709406962545514830200e2L;

    // Central region, denominator
    const T B0 = 1.0000000000000000000e0L;
    const T B1 = 4.2313330701600911252e1L;
    const T B2 = 6.8718700749205790830e2L;
    const T B3 = 5.3941960214247511077e3L;
    const T B4 = 2.1213794301586595867e4L;
    const T B5 = 3.9307895800092710610e4L;
    const T B6 = 2.8729085735721942674e4L;
    const T B7 = 5.2264952788528545610e3L;

    // Intermediate region, numerator
    const T C0 = 1.42343711074968357734e0L;
    const T C1 = 4.63033784615654529590e0L;
    const T C2 = 5.76949722146069140550e0L;
    const T C3 = 3.64784832476320460504e0L;
    const T C4 = 1.27045825245236838258e0L;
    const T C5 = 2.41780725177450611770e-1L;
    const T C6 = 2.27238449892691845833e-2L;
    const T C7 = 7.74545014278341407640e-4L;

    // Intermediate region, denominator
    const T D0 = 1.4142135623730950488016887e0L;
    const T D1 = 2.9036514445419946173133295e0L;
    const T D2 = 2.3707661626024532365971225e0L;
    const T D3 = 9.7547832001787427186894837e-1L;
    const T D4 = 2.0945065210512749128288442e-1L;
    const T D5 = 2.1494160384252876777097297e-2L;
    const T D6 = 7.7441459065157709165577218e-4L;
    const T D7 = 1.4859850019840355905497876e-9L;

    // Tail region, numerator
    const T E0 = 6.65790464350110377720e0L;
    const T E1 = 5.46378491116411436990e0L;
    const T E2 = 1.78482653991729133580e0L;
    const T E3 = 2.96560571828504891230e-1L;
    const T E4 = 2.65321895265761230930e-2L;
    const T E5 = 1.24266094738807843860e-3L;
    const T E6 = 2.71155556874348757815e-5L;
    const T E7 = 2.01033439929228813265e-7L;

    // Tail region, denominator
    const T F0 = 1.414213562373095048801689e0L;
    const T F1 = 8.482908416595164588112026e-1L;
    const T F2 = 1.936480946950659106176712e-1L;
    const T F3 = 2.103693768272068968719679e-2L;
    const T F4 = 1.112800997078859844711555e-3L;
    const T F5 = 2.611088405080593625138020e-5L;
    const T F6 = 2.010321207683943062279931e-7L;
    const T F7 = 2.891024605872965461538222e-15L;

    T abs_x = std::abs( x );

    T r, num, den;

    if( abs_x <= 0.85 )
    {
        // The central approximation is odd in x, so the sign is carried by the factor x
        r   = 0.180625 - 0.25 * x * x;
        num = ( ( ( ( ( ( ( A7 * r + A6 ) * r + A5 ) * r + A4 ) * r + A3 ) * r + A2 ) * r + A1 ) * r + A0 );
        den = ( ( ( ( ( ( ( B7 * r + B6 ) * r + B5 ) * r + B4 ) * r + B3 ) * r + B2 ) * r + B1 ) * r + B0 );
        return x * num / den;
    }

    // log1p( -|x| ) evaluates log( 1 - |x| ) without forming 1 - |x| explicitly
    r = std::sqrt( LN2 - std::log1p( -abs_x ) );
    if( r <= 5.0 )
    {
        r   = r - 1.6;
        num = ( ( ( ( ( ( ( C7 * r + C6 ) * r + C5 ) * r + C4 ) * r + C3 ) * r + C2 ) * r + C1 ) * r + C0 );
        den = ( ( ( ( ( ( ( D7 * r + D6 ) * r + D5 ) * r + D4 ) * r + D3 ) * r + D2 ) * r + D1 ) * r + D0 );
    }
    else
    {
        r   = r - 5.0;
        num = ( ( ( ( ( ( ( E7 * r + E6 ) * r + E5 ) * r + E4 ) * r + E3 ) * r + E2 ) * r + E1 ) * r + E0 );
        den = ( ( ( ( ( ( ( F7 * r + F6 ) * r + F5 ) * r + F4 ) * r + F3 ) * r + F2 ) * r + F1 ) * r + F0 );
    }

    // The two outer approximations only depend on |x|; restore the sign of x here.
    // No explicit template argument: std::copysign is an overload set, and forcing
    // the generic (promoting) overload is not portable across standard libraries.
    const T magnitude = num / den;
    return std::copysign( magnitude, x );
}

} // namespace Seldon::Math
94 changes: 93 additions & 1 deletion include/util/math.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#pragma once
#include "fmt/core.h"
#include "util/erfinv.hpp"
#include <algorithm>
#include <cstddef>
#include <optional>
Expand Down Expand Up @@ -129,9 +131,14 @@ class power_law_distribution

// Draws one sample: a value drawn from `dist` is mapped through
// inverse_cumulative_probability.
template<typename Generator>
ScalarT operator()( Generator & gen )
{
    const auto drawn = dist( gen );
    return inverse_cumulative_probability( drawn );
}

// Inverse cumulative probability used for inverse transform sampling:
//   ( ( 1 - eps^(1-gamma) ) * x + eps^(1-gamma) )^( 1 / (1-gamma) )
// x is presumably a cumulative probability in [0, 1] (operator() feeds it from `dist`) - TODO confirm.
ScalarT inverse_cumulative_probability( ScalarT x )
{
    // eps^(1-gamma) appears twice in the formula, so evaluate it once
    const auto eps_pow = std::pow( eps, ( 1.0 - gamma ) );
    return std::pow( ( 1.0 - eps_pow ) * x + eps_pow, ( 1.0 / ( 1.0 - gamma ) ) );
}

Expand All @@ -157,6 +164,16 @@ class truncated_normal_distribution
std::normal_distribution<ScalarT> normal_dist{};
size_t max_tries = 5000;

// Inverse of cum_gauss: maps the cumulative probability y back to the x for
// which the Gaussian with this `mean` and `sigma` has cumulative probability y.
ScalarT inverse_cum_gauss( ScalarT y )
{
    // util/erfinv.hpp only provides Seldon::Math::erfinv_golang_math; there is no plain `erfinv`
    return Seldon::Math::erfinv_golang_math( 2.0 * y - 1 ) * std::sqrt( 2.0 ) * sigma + mean;
}

// Cumulative probability at x of the Gaussian with this `mean` and `sigma`.
ScalarT cum_gauss( ScalarT x )
{
    const auto erf_argument = ( x - mean ) / ( sigma * std::sqrt( 2.0 ) );
    return 0.5 * ( 1 + std::erf( erf_argument ) );
}

public:
truncated_normal_distribution( ScalarT mean, ScalarT sigma, ScalarT eps )
: mean( mean ), sigma( sigma ), eps( eps ), normal_dist( std::normal_distribution<ScalarT>( mean, sigma ) )
Expand All @@ -174,6 +191,81 @@ class truncated_normal_distribution
}
return eps;
}

// Inverse cumulative probability of the truncated distribution:
// y is rescaled from [0, 1] to [cum_gauss( eps ), 1] and then mapped through
// the inverse cumulative probability of the untruncated Gaussian.
ScalarT inverse_cumulative_probability( ScalarT y )
{
    // cum_gauss / inverse_cum_gauss take a single argument; mean and sigma are read from the members
    const ScalarT cum_at_eps = cum_gauss( eps );
    return inverse_cum_gauss( y * ( 1.0 - cum_at_eps ) + cum_at_eps );
}
};

/**
 * @brief Bivariate normal distribution
 * with mean mu = [0,0]
 * and covariance matrix Sigma = [[1, cov], [cov, 1]]
 * |cov| < 1 is required
 *
 * A sample is built from two independent standard normal variates n1, n2 as
 *   r1 = n1
 *   r2 = cov * n1 + sqrt( 1 - cov^2 ) * n2
 * so that both components have unit variance and covariance `cov`.
 */
template<typename ScalarT = double>
class bivariate_normal_distribution
{
private:
    ScalarT covariance;
    std::normal_distribution<ScalarT> normal_dist{};

public:
    bivariate_normal_distribution( ScalarT covariance ) : covariance( covariance ) {}

    /**
     * @brief Draws one sample
     * @param gen random number generator passed on to std::normal_distribution
     * @return the pair { r1, r2 }
     */
    template<typename Generator>
    std::array<ScalarT, 2> operator()( Generator & gen )
    {
        const ScalarT n1 = normal_dist( gen );
        const ScalarT n2 = normal_dist( gen );

        const ScalarT r1 = n1;
        // The independent variate n2 has to enter here, otherwise r2 is fully determined by r1
        const ScalarT r2 = covariance * n1 + std::sqrt( 1 - covariance * covariance ) * n2;

        return { r1, r2 };
    }
};

/**
 * @brief Bivariate Gaussian copula
 * Draws a pair from a bivariate normal distribution with the given covariance,
 * maps each component to the unit interval with the standard normal cumulative
 * probability function and then applies the inverse cumulative probability
 * of dist1 (first component) and dist2 (second component).
 *
 * dist1T and dist2T have to provide a member function
 * inverse_cumulative_probability( ScalarT ).
 */
template<typename ScalarT, typename dist1T, typename dist2T>
class bivariate_gaussian_copula
{
private:
    ScalarT covariance;
    // No default member initializer: bivariate_normal_distribution has no default constructor
    bivariate_normal_distribution<ScalarT> bivariate_normal_dist;

    // Cumulative probability function for gaussian with mean 0 and variance 1
    ScalarT cum_gauss( ScalarT x )
    {
        return 0.5 * ( 1 + std::erf( ( x ) / std::sqrt( 2.0 ) ) );
    }

    dist1T dist1;
    dist2T dist2;

public:
    // The initializer list follows the declaration order of the members
    bivariate_gaussian_copula( ScalarT covariance, dist1T dist1, dist2T dist2 )
            : covariance( covariance ), bivariate_normal_dist( covariance ), dist1( dist1 ), dist2( dist2 )
    {
    }

    /**
     * @brief Draws one pair of samples
     * @param gen random number generator passed on to the bivariate normal distribution
     * @return pair whose first / second entry comes from the inverse cumulative probability of dist1 / dist2
     */
    template<typename Generator>
    std::array<ScalarT, 2> operator()( Generator & gen )
    {
        // 1. Draw from bivariate gaussian
        auto z = bivariate_normal_dist( gen );
        // 2. Transform marginals to unit interval
        std::array<ScalarT, 2> z_unit = { cum_gauss( z[0] ), cum_gauss( z[1] ) };
        // 3. Apply inverse transform sampling
        std::array<ScalarT, 2> res
            = { dist1.inverse_cumulative_probability( z_unit[0] ), dist2.inverse_cumulative_probability( z_unit[1] ) };
        return res;
    }
};

template<typename T>
Expand Down
2 changes: 1 addition & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ tests = [
['Test_Network_Generation', 'test/test_network_generation.cpp'],
['Test_Sampling', 'test/test_sampling.cpp'],
['Test_IO', 'test/test_io.cpp'],
['Test_Util', 'test/test_util.cpp'],
['Test_Prob', 'test/test_probability_distributions.cpp'],
]

Catch2 = dependency('Catch2', method : 'cmake', modules : ['Catch2::Catch2WithMain', 'Catch2::Catch2'])
Expand Down
70 changes: 70 additions & 0 deletions test/test_probability_distributions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include <catch2/matchers/catch_matchers_range_equals.hpp>

#include "util/math.hpp"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <random>
namespace fs = std::filesystem;

// Streams every element of the array, each followed by a single space (no trailing newline).
template<std::size_t N>
std::ostream & operator<<( std::ostream & os, std::array<double, N> const & v1 )
{
    // Iterate as double: an int loop variable would truncate every sample before it is written
    for( const double val : v1 )
    {
        os << val << " ";
    }
    return os;
}

// Samples the distribution N_Samples times and writes the results, one sample per line,
// to <current working directory>/test/output/<filename>.
// The generator is a std::mt19937 seeded with 0, so repeated calls produce the same samples.
template<typename distT>
void write_results_to_file( int N_Samples, distT dist, const std::string & filename )
{
    const auto proj_root_path = std::filesystem::current_path();
    // The right-hand side must be a relative path: appending an absolute path
    // with operator/ would replace proj_root_path instead of extending it
    const auto file = proj_root_path / "test" / "output" / filename;
    // Only the parent directory is created; creating `file` itself as a
    // directory would make it impossible to open it as a regular file
    std::filesystem::create_directories( file.parent_path() );

    auto gen = std::mt19937( 0 );

    std::ofstream filestream( file );
    filestream << std::setprecision( 16 );

    for( int i = 0; i < N_Samples; i++ )
    {
        filestream << dist( gen ) << "\n";
    }
    filestream.close();
}

// Writes 10000 samples of each distribution to a text file via write_results_to_file.
// The test case contains no assertions.
TEST_CASE( "Test the probability distributions", "[prob]" )
{
    constexpr int n_samples = 10000;

    const auto truncated_normal = Seldon::truncated_normal_distribution( 1.0, 0.5, 0.1 );
    write_results_to_file( n_samples, truncated_normal, "truncated_normal.txt" );

    const auto power_law = Seldon::power_law_distribution( 0.01, 2.1 );
    write_results_to_file( n_samples, power_law, "power_law.txt" );

    const auto bivariate_normal = Seldon::bivariate_normal_distribution( 0.5 );
    write_results_to_file( n_samples, bivariate_normal, "bivariate_normal.txt" );
}

// TEST_CASE( "Test reading in the agents from a file", "[io_agents]" )
// {
// using namespace Seldon;
// using namespace Catch::Matchers;

// auto proj_root_path = fs::current_path();
// auto network_file = proj_root_path / fs::path( "test/res/opinions.txt" );

// auto agents = Seldon::agents_from_file<ActivityDrivenModel::AgentT>( network_file );

// std::vector<double> opinions_expected = { 2.1127107987061544, 0.8088982488089491, -0.8802809369462433 };
// std::vector<double> activities_expected = { 0.044554683389757696, 0.015813166022685163, 0.015863953902810535 };
// std::vector<double> reluctances_expected = { 1.0, 1.0, 2.3 };

// REQUIRE( agents.size() == 3 );

// for( size_t i = 0; i < agents.size(); i++ )
// {
// fmt::print( "{}", i );
// REQUIRE_THAT( agents[i].data.opinion, Catch::Matchers::WithinAbs( opinions_expected[i], 1e-16 ) );
// REQUIRE_THAT( agents[i].data.activity, Catch::Matchers::WithinAbs( activities_expected[i], 1e-16 ) );
// REQUIRE_THAT( agents[i].data.reluctance, Catch::Matchers::WithinAbs( reluctances_expected[i], 1e-16 ) );
// }
// }

0 comments on commit d140126

Please sign in to comment.