From f5e704a6f25939478f770f8980c344ab461f0113 Mon Sep 17 00:00:00 2001 From: laithsakka Date: Sun, 11 Aug 2024 16:22:18 -0700 Subject: [PATCH] Add instruction count benchmark to run on pull requests (#131475) This PR only adds the execution of the benchmarks on this PR and print results, following diffs will add checking out head~1 and running it and comparing. to access results goto test pr_time_benchmarks and inspect logs: you should see ``` + echo 'benchmark results on current PR: ' benchmark results on current PR: + cat /var/lib/jenkins/workspace/test/test-reports/pr_time_benchmarks_before.txt update_hint_regression,instruction_count,27971461254 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/131475 Approved by: https://github.com/ezyang --- .ci/pytorch/test.sh | 15 +++- .github/workflows/pull.yml | 22 ++++- benchmarks/__init__.py | 0 .../dynamo/pr_time_benchmarks/__init__.py | 0 .../pr_time_benchmarks/benchmark_base.py | 64 ++++++++++++++ .../pr_time_benchmarks/benchmark_runner.sh | 25 ++++++ .../benchmarks/update_hint_benchmark.py | 46 ++++++++++ .../instruction_counts/applications/ci.py | 1 + benchmarks/instruction_counts/core/api.py | 1 + benchmarks/instruction_counts/core/expand.py | 1 + benchmarks/instruction_counts/core/types.py | 1 + benchmarks/instruction_counts/core/utils.py | 1 + .../instruction_counts/definitions/setup.py | 2 +- .../definitions/standard.py | 1 + .../instruction_counts/execution/runner.py | 1 + .../instruction_counts/execution/work.py | 1 + benchmarks/instruction_counts/main.py | 1 + .../benchmark_all_other_test.py | 3 +- .../benchmark_all_quantized_test.py | 3 +- .../operator_benchmark/benchmark_all_test.py | 3 +- benchmarks/operator_benchmark/pt/conv_test.py | 3 +- .../pt/embeddingbag_test.py | 3 +- .../operator_benchmark/pt/gather_test.py | 1 - .../pt/index_select_test.py | 1 - .../operator_benchmark/pt/linear_test.py | 3 +- .../pt/qatembedding_ops_test.py | 3 +- .../operator_benchmark/pt/qconv_test.py | 3 +- .../pt/qembedding_bag_lookups_test.py | 1 - .../pt/qembeddingbag_test.py | 3 +- .../operator_benchmark/pt/qlinear_test.py | 3 +- build_variables.bzl | 1 + torch/csrc/Module.cpp | 2 + torch/csrc/instruction_counter/Module.cpp | 86 +++++++++++++++++++ torch/csrc/instruction_counter/Module.h | 8 ++ 34 files changed, 287 insertions(+), 26 deletions(-) create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/dynamo/pr_time_benchmarks/__init__.py create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py create mode 100644 torch/csrc/instruction_counter/Module.cpp create mode 100644 torch/csrc/instruction_counter/Module.h diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 53bd1a50838d1..6d0797d41e2f3 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -392,7 +392,20 @@ test_inductor_cpp_wrapper_abi_compatible() { # .github/workflows/inductor-perf-test-nightly.yml DYNAMO_BENCHMARK_FLAGS=() -if [[ "${TEST_CONFIG}" == *dynamo_eager* ]]; then +pr_time_benchmarks() { + + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" "benchmarks/dynamo/pr_time_benchmarks/benchmarks" + echo "benchmark results on current PR: " + cat "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" + +} + +if [[ "${TEST_CONFIG}" == *pr_time_benchmarks* ]]; then + pr_time_benchmarks + exit 0 +elif [[ "${TEST_CONFIG}" == *dynamo_eager* ]]; then DYNAMO_BENCHMARK_FLAGS+=(--backend eager) elif [[ "${TEST_CONFIG}" == *aot_eager* ]]; then DYNAMO_BENCHMARK_FLAGS+=(--backend aot_eager) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 2deb1051669a9..997554562c8fe 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -110,7 +110,6 @@ jobs: { config: "default", shard: 1, num_shards: 1 }, ]} - linux-jammy-py3_10-clang15-asan-build: name: linux-jammy-py3.10-clang15-asan uses: ./.github/workflows/_linux-build.yml @@ -571,3 +570,24 @@ jobs: docker-image: ${{ needs.linux-focal-py3_12-clang10-experimental-split-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-py3_12-clang10-experimental-split-build.outputs.test-matrix }} timeout-minutes: 600 + + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: + name: cuda12.1-py3.10-gcc9-sm75 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm75 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '7.5' + test-matrix: | + { include: [ + { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, + ]} + + linux-focal-cuda12_1-py3_10-gcc9-inductor-test: + name: cuda12.1-py3.10-gcc9-sm75 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_1-py3_10-gcc9-inductor-build + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm75 + docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-build.outputs.test-matrix }} diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/benchmarks/dynamo/pr_time_benchmarks/__init__.py b/benchmarks/dynamo/pr_time_benchmarks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py new file mode 100644 index 0000000000000..f599ddc2ba70e --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -0,0 +1,64 @@ +import csv +from abc import ABC, abstractmethod + +import torch._C._instruction_counter as i_counter + + +class BenchmarkBase(ABC): + _instruction_count = False + + def enable_instruction_count(self): + self._instruction_count = True + return self + + def name(self): + return "" + + def description(self): + return "" + + @abstractmethod + def prepare(self): + pass + + @abstractmethod + def work(self): + pass + + def prepare_once(self): # noqa: B027 + pass + + def count_instructions(self): + print(f"collecting instruction count for {self.name()}") + self.prepare_once() + + results = [] + for i in range(10): + self.prepare() + id = i_counter.start() + self.work() + count = i_counter.end(id) + print(f"instruction count for iteration {i} is {count}") + if i != 0: + results.append(count) + return min(results) + + def append_results(self, path): + with open(path, "a", newline="") as csvfile: + # Create a writer object + writer = csv.writer(csvfile) + # Write the data to the CSV file + for entry in self.results: + writer.writerow(entry) + + def print(self): + for entry in self.results: + print(f"{entry[0]},{entry[1]},{entry[2]}") + + def collect_all(self): + self.results = [] + if self._instruction_count: + self.results.append( + (self.name(), "instruction_count", self.count_instructions()) + ) + return self diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh new file mode 100644 index 0000000000000..c6128a3ba2df4 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Check if the output file argument was provided +if [ $# -eq 0 ] +then + echo "Please provide the output file as an argument" + return +fi + +# Check if the directory of Python programs argument was provided +if [ $# -eq 1 ] +then + echo "Please provide the directory of Python programs as an argument" + return +fi + +# Set the output file +output_file=$1 +# Set the directory of Python programs +python_programs_dir=$2 +# Loop through all files in the directory of Python programs +for file in $python_programs_dir/*.py +do + # Execute the Python program and append the output to the output file + sudo env PATH="$PATH" python $file $output_file +done diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py new file mode 100644 index 0000000000000..cd398456faec7 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py @@ -0,0 +1,46 @@ +import random +import sys + +from benchmarks.dynamo.pr_time_benchmarks.benchmark_base import BenchmarkBase + +import torch + + +class Benchmark(BenchmarkBase): + N = 20 + + def name(self): + return "update_hint_regression" + + def description(self): + return "information at https://github.com/pytorch/pytorch/pull/129893" + + def prepare_once(self): + torch._dynamo.config.capture_scalar_outputs = True + random.seed(42) + self.splits = torch.randint(10, (self.N,)) + sz = self.splits.sum().item() + self.input = torch.randn(sz) + + def prepare(self): + torch._dynamo.reset() + + def work(self): + @torch.compile(fullgraph=True) + def f(a, b): + xs = b.tolist() + for x in xs: + torch._check_is_size(x) + torch._check(x <= self.N) + return a.split(xs) + + f(self.input, self.splits) + + +def main(): + result_path = sys.argv[1] + Benchmark().enable_instruction_count().collect_all().append_results(result_path) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/instruction_counts/applications/ci.py b/benchmarks/instruction_counts/applications/ci.py index 67248d8b1976f..9d50ad0fee061 100644 --- a/benchmarks/instruction_counts/applications/ci.py +++ b/benchmarks/instruction_counts/applications/ci.py @@ -1,4 +1,5 @@ """Collect instruction counts for continuous integration.""" +# mypy: ignore-errors import argparse import hashlib import json diff --git a/benchmarks/instruction_counts/core/api.py b/benchmarks/instruction_counts/core/api.py index c598e726e9715..640ef3f19270a 100644 --- a/benchmarks/instruction_counts/core/api.py +++ b/benchmarks/instruction_counts/core/api.py @@ -1,4 +1,5 @@ """Key enums and structs used to handle data flow within the benchmark.""" +# mypy: ignore-errors import dataclasses import enum import itertools as it diff --git a/benchmarks/instruction_counts/core/expand.py b/benchmarks/instruction_counts/core/expand.py index c91b2311c4074..01b22533dbc6e 100644 --- a/benchmarks/instruction_counts/core/expand.py +++ b/benchmarks/instruction_counts/core/expand.py @@ -2,6 +2,7 @@ This is mostly string manipulation, with just a bit of importlib magic. """ +# mypy: ignore-errors import importlib.abc import importlib.util import itertools as it diff --git a/benchmarks/instruction_counts/core/types.py b/benchmarks/instruction_counts/core/types.py index 6fa91a94d4def..06c6c2e87d893 100644 --- a/benchmarks/instruction_counts/core/types.py +++ b/benchmarks/instruction_counts/core/types.py @@ -1,4 +1,5 @@ """Type annotations for various benchmark objects.""" +# mypy: ignore-errors from typing import Any, Dict, Optional, Tuple, Union from core.api import AutoLabels, GroupedBenchmark, TimerArgs diff --git a/benchmarks/instruction_counts/core/utils.py b/benchmarks/instruction_counts/core/utils.py index 0602b540fcb96..dbb1cd655af50 100644 --- a/benchmarks/instruction_counts/core/utils.py +++ b/benchmarks/instruction_counts/core/utils.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors import atexit import re import shutil diff --git a/benchmarks/instruction_counts/definitions/setup.py b/benchmarks/instruction_counts/definitions/setup.py index 32f9b987ef1d7..fbc3798d9988f 100644 --- a/benchmarks/instruction_counts/definitions/setup.py +++ b/benchmarks/instruction_counts/definitions/setup.py @@ -1,5 +1,5 @@ """Define some common setup blocks which benchmarks can reuse.""" - +# mypy: ignore-errors import enum from core.api import GroupedSetup diff --git a/benchmarks/instruction_counts/definitions/standard.py b/benchmarks/instruction_counts/definitions/standard.py index 190c86a05198d..bfd05a0f7628f 100644 --- a/benchmarks/instruction_counts/definitions/standard.py +++ b/benchmarks/instruction_counts/definitions/standard.py @@ -11,6 +11,7 @@ - To set a label for the succeeding block, add `# @YOUR_LABEL` (Python) or `// @YOUR_LABEL` (C++). """ +# mypy: ignore-errors from core.api import GroupedModules, GroupedStmts, GroupedVariants from core.types import FlatIntermediateDefinition diff --git a/benchmarks/instruction_counts/execution/runner.py b/benchmarks/instruction_counts/execution/runner.py index c4faed86900af..8d18ba02bc200 100644 --- a/benchmarks/instruction_counts/execution/runner.py +++ b/benchmarks/instruction_counts/execution/runner.py @@ -1,4 +1,5 @@ """Run benchmarks while handling parallelism, isolation, and fault tolerance.""" +# mypy: ignore-errors import math import multiprocessing import subprocess diff --git a/benchmarks/instruction_counts/execution/work.py b/benchmarks/instruction_counts/execution/work.py index ab076854e41db..b1b77282c4521 100644 --- a/benchmarks/instruction_counts/execution/work.py +++ b/benchmarks/instruction_counts/execution/work.py @@ -1,4 +1,5 @@ """Handle the details of subprocess calls and retries for a given benchmark run.""" +# mypy: ignore-errors import dataclasses import json import os diff --git a/benchmarks/instruction_counts/main.py b/benchmarks/instruction_counts/main.py index 7ab128057ee8e..2f8e40b9dcb2e 100644 --- a/benchmarks/instruction_counts/main.py +++ b/benchmarks/instruction_counts/main.py @@ -5,6 +5,7 @@ components) in future iterations. However this allows us to excercise the underlying benchmark generation infrastructure in the mean time. """ +# mypy: ignore-errors import argparse import sys from typing import List diff --git a/benchmarks/operator_benchmark/benchmark_all_other_test.py b/benchmarks/operator_benchmark/benchmark_all_other_test.py index 05022e8407f0c..55aa640d6aa2d 100644 --- a/benchmarks/operator_benchmark/benchmark_all_other_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_other_test.py @@ -1,3 +1,4 @@ +import operator_benchmark as op_bench from pt import ( # noqa: F401 add_test, ao_sparsifier_test, @@ -29,8 +30,6 @@ tensor_to_test, ) -import operator_benchmark as op_bench - if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/benchmark_all_quantized_test.py b/benchmarks/operator_benchmark/benchmark_all_quantized_test.py index 00a792b580ce2..4f8781f8dce6b 100644 --- a/benchmarks/operator_benchmark/benchmark_all_quantized_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_quantized_test.py @@ -1,3 +1,4 @@ +import operator_benchmark as op_bench from pt import ( # noqa: F401 qactivation_test, qarithmetic_test, @@ -21,8 +22,6 @@ qunary_test, ) -import operator_benchmark as op_bench - if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/benchmark_all_test.py b/benchmarks/operator_benchmark/benchmark_all_test.py index f7d967c2c261a..193d7c436f3ae 100644 --- a/benchmarks/operator_benchmark/benchmark_all_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_test.py @@ -1,8 +1,7 @@ import benchmark_all_other_test # noqa: F401 import benchmark_all_quantized_test # noqa: F401 -from pt import unary_test # noqa: F401 - import operator_benchmark as op_bench +from pt import unary_test # noqa: F401 if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index 93b4942cea2b9..5b373a0a25de5 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -1,6 +1,5 @@ -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch import torch.nn as nn diff --git a/benchmarks/operator_benchmark/pt/embeddingbag_test.py b/benchmarks/operator_benchmark/pt/embeddingbag_test.py index d82d69707597e..7733230eae74e 100644 --- a/benchmarks/operator_benchmark/pt/embeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/embeddingbag_test.py @@ -1,7 +1,6 @@ import numpy -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch diff --git a/benchmarks/operator_benchmark/pt/gather_test.py b/benchmarks/operator_benchmark/pt/gather_test.py index 67be175f9c44e..5b8033a95733c 100644 --- a/benchmarks/operator_benchmark/pt/gather_test.py +++ b/benchmarks/operator_benchmark/pt/gather_test.py @@ -1,5 +1,4 @@ import numpy - import operator_benchmark as op_bench import torch diff --git a/benchmarks/operator_benchmark/pt/index_select_test.py b/benchmarks/operator_benchmark/pt/index_select_test.py index 870610071c117..5cab8507d9e7f 100644 --- a/benchmarks/operator_benchmark/pt/index_select_test.py +++ b/benchmarks/operator_benchmark/pt/index_select_test.py @@ -1,5 +1,4 @@ import numpy - import operator_benchmark as op_bench import torch diff --git a/benchmarks/operator_benchmark/pt/linear_test.py b/benchmarks/operator_benchmark/pt/linear_test.py index 6f5e239754585..ced8290e85f30 100644 --- a/benchmarks/operator_benchmark/pt/linear_test.py +++ b/benchmarks/operator_benchmark/pt/linear_test.py @@ -1,6 +1,5 @@ -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch import torch.nn as nn diff --git a/benchmarks/operator_benchmark/pt/qatembedding_ops_test.py b/benchmarks/operator_benchmark/pt/qatembedding_ops_test.py index 03fc533961040..30f5ff444e8f7 100644 --- a/benchmarks/operator_benchmark/pt/qatembedding_ops_test.py +++ b/benchmarks/operator_benchmark/pt/qatembedding_ops_test.py @@ -1,7 +1,6 @@ import numpy -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch import torch.ao.nn.qat as nnqat diff --git a/benchmarks/operator_benchmark/pt/qconv_test.py b/benchmarks/operator_benchmark/pt/qconv_test.py index 540480c297c05..5e3b9f479c60f 100644 --- a/benchmarks/operator_benchmark/pt/qconv_test.py +++ b/benchmarks/operator_benchmark/pt/qconv_test.py @@ -1,6 +1,5 @@ -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch import torch.ao.nn.quantized as nnq diff --git a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py index f5fd9dc4ff345..369bc220bac63 100644 --- a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py +++ b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py @@ -1,7 +1,6 @@ from typing import Optional import numpy as np - import operator_benchmark as op_bench import torch diff --git a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py index adc8c3b27695c..40ef69078e0cc 100644 --- a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py @@ -1,7 +1,6 @@ import numpy -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch import torch.ao.nn.quantized as nnq diff --git a/benchmarks/operator_benchmark/pt/qlinear_test.py b/benchmarks/operator_benchmark/pt/qlinear_test.py index d90b4e8af64ec..89bf77ddc4f42 100644 --- a/benchmarks/operator_benchmark/pt/qlinear_test.py +++ b/benchmarks/operator_benchmark/pt/qlinear_test.py @@ -1,6 +1,5 @@ -from pt import configs - import operator_benchmark as op_bench +from pt import configs import torch import torch.ao.nn.quantized as nnq diff --git a/build_variables.bzl b/build_variables.bzl index f9dc667613c42..b3b58169b79ad 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -913,6 +913,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/disable_torch_function.cpp", "torch/csrc/utils/verbose.cpp", "torch/csrc/cpu/Module.cpp", + "torch/csrc/instruction_counter/Module.cpp", ] + lazy_tensor_core_python_sources libtorch_python_distributed_core_sources = [ diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 7f53476280d88..244ddc43daf87 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -70,6 +70,7 @@ #include #include #include +#include #include #include #include @@ -1698,6 +1699,7 @@ PyObject* initModule() { #endif torch::mtia::initModule(module); torch::cpu::initModule(module); + torch::instruction_counter::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); diff --git a/torch/csrc/instruction_counter/Module.cpp b/torch/csrc/instruction_counter/Module.cpp new file mode 100644 index 0000000000000..4465769054872 --- /dev/null +++ b/torch/csrc/instruction_counter/Module.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__linux__) +#include +#include +#include +#include +#endif + +namespace torch::instruction_counter { + +long start() { +#if !defined(__linux__) + throw std::runtime_error("This systems seems not to be Linux"); +#else + + // Construct base perf_event_attr struct + perf_event_attr attr{}; + memset(&attr, 0, sizeof(attr)); + attr.size = sizeof(attr); + attr.exclude_kernel = 1; + attr.disabled = 1; + attr.exclude_hv = 1; + attr.sample_period = 0; + // Enable hardware counting + attr.type = PERF_TYPE_HARDWARE; + attr.config = PERF_COUNT_HW_INSTRUCTIONS; + + long fd = syscall(SYS_perf_event_open, &attr, 0, -1, -1, 0); + if (fd == -1) { + fprintf( + stderr, + "Failed to open instruction count event: %s.\n", + strerror(errno)); + return -1; + } + ioctl((int)fd, PERF_EVENT_IOC_RESET, 0); // Reset the counter + ioctl((int)fd, PERF_EVENT_IOC_ENABLE, 0); // Enable the counter + return fd; +#endif +} + +uint64_t end(int fd) { +#if !defined(__linux__) + throw std::runtime_error("This systems seems not to be Linux"); +#else + // Disable the event group + if (ioctl(fd, PERF_EVENT_IOC_DISABLE, PERF_IOC_FLAG_GROUP) == -1) { + fprintf( + stderr, + "Error disabling perf event (fd: %d): %s\n", + fd, + strerror(errno)); + return -1; + } + + uint64_t total_instructions = 0; + + // Read results + long ret_val = read(fd, &total_instructions, sizeof(total_instructions)); + if (ret_val == -1) { + fprintf(stderr, "Error reading perf event results: %s\n", strerror(errno)); + return -1; + } + + close(fd); + return total_instructions; +#endif +} + +void initModule(PyObject* module) { + auto m = py::handle(module).cast(); + auto instruction_counter = m.def_submodule( + "_instruction_counter", "instruction_counter related pybind."); + instruction_counter.def("start", start); + instruction_counter.def("end", end); +} + +} // namespace torch::instruction_counter diff --git a/torch/csrc/instruction_counter/Module.h b/torch/csrc/instruction_counter/Module.h new file mode 100644 index 0000000000000..ab56586ae24ea --- /dev/null +++ b/torch/csrc/instruction_counter/Module.h @@ -0,0 +1,8 @@ +#pragma once +#include + +namespace torch::instruction_counter { + +void initModule(PyObject* module); + +} // namespace torch::instruction_counter