Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions rpd_tracer/Logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ void rpdflush()
Logger::singleton().rpdflush();
}

void rpdFinalize()
{
Logger::singleton().rpdFinalize();
}
void rpdInit()
{
Logger::singleton().rpdInit();
}

void rpdResetFinalize()
{
Logger::singleton().rpdResetFinalize();
}

void rpd_rangePush(const char *domain, const char *apiName, const char* args)
{
Logger::singleton().rpd_rangePush(domain, apiName, args);
Expand Down Expand Up @@ -95,6 +109,10 @@ void Logger::rpdInit() {
void Logger::rpdFinalize() {
Logger::singleton().finalize();
}
void Logger::rpdResetFinalize()
{
Logger::singleton().set_finalize_true();
}


void Logger::rpdstart()
Expand Down Expand Up @@ -249,6 +267,12 @@ void Logger::init()
static bool doFinalize = true;
std::mutex finalizeMutex;

void Logger::set_finalize_true()
{
std::lock_guard<std::mutex> guard(finalizeMutex);
doFinalize = true;
}

void Logger::finalize()
{
std::lock_guard<std::mutex> guard(finalizeMutex);
Expand Down
2 changes: 2 additions & 0 deletions rpd_tracer/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Logger
void rpdstart();
void rpdstop();
void rpdflush();
void rpdResetFinalize();

// External maker api
void rpd_rangePush(const char *domain, const char *apiName, const char* args);
Expand Down Expand Up @@ -80,6 +81,7 @@ class Logger

void init();
void finalize();
void set_finalize_true();

std::string m_filename;
bool m_writeOverheadRecords {true};
Expand Down
58 changes: 56 additions & 2 deletions rpd_tracer/rpdTracerControl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def initializeFile(cls):
connection.close()
#
os.environ["RPDT_FILENAME"] = cls.__filename
cls.__initFile = False
cls.__initFile = False


# You can set the output filename and optionally append to an exiting file.
Expand All @@ -87,9 +87,26 @@ def setFilename(cls, name, append = False):
os.environ["RPDT_FILENAME"] = cls.__filename
cls.__initFile = False

def __init__(self):
@classmethod
def rpdReset(cls):
rpdTracerControl.__rpd.rpdResetFinalize()
rpdTracerControl.__rpd.rpdFinalize()
#cls.__initFile = True

def __init__(self, file_name = None, nvtx = False):
if file_name != None:
# Force reset filename, if we are getting called in a loop.
rpdTracerControl.__initFile = True
rpdTracerControl.__filename = file_name
rpdTracerControl.initializeFile()
rpdTracerControl.loadLibrary()
self.nvtx = None
if nvtx:
import torch
self.nvtx = torch.autograd.profiler.emit_nvtx()
if file_name != None:
# Reinit for the new file
rpdTracerControl.__rpd.rpdInit()

def __del__(self):
pass
Expand All @@ -105,9 +122,46 @@ def flush(self):

def __enter__(self):
self.start()
if self.nvtx:
self.nvtx.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.nvtx:
self.nvtx.__exit__(exc_type, exc_val, exc_tb)
if exc_type != None:
#Propagate exception
return False
self.stop()
self.flush()
rpdTracerControl.rpdReset()

def top_totals(self):
try:
conn = sqlite3.connect(rpdTracerControl.__filename)
cursor = conn.cursor()
cursor.execute("SELECT Name, TotalCalls, TotalDuration, Ave, Percentage FROM top;")
rows = cursor.fetchall()

if rows:
from prettytable import PrettyTable
import textwrap
table = PrettyTable()
table.field_names = ["Name", "TotalCalls", "TotalDuration", "Ave", "Percentage"]
table.align = "l"

for row in rows:
wrapped_name = '\n'.join(textwrap.wrap(row[0], 60))
table.add_row([wrapped_name] + list(row[1:]))

print(table)
else:
print("No data found in 'top' table.")

except sqlite3.Error as e:
print(f"Error querying database: {e}")
finally:
conn.close()

def rangePush(self, domain: str, apiName: str, args: str):
rpdTracerControl.__rpd.rpd_rangePush(bytes(domain, encoding='utf-8'), bytes(apiName, encoding='utf-8'), bytes(args, encoding='utf-8'))
Expand Down