diff --git a/test/serial_io_tcp.cpp b/test/serial_io_tcp.cpp index 8f86a44..6ba8f38 100644 --- a/test/serial_io_tcp.cpp +++ b/test/serial_io_tcp.cpp @@ -1,4 +1,4 @@ -/* Copyright 2018-2023 Espressif Systems (Shanghai) CO LTD +/* Copyright 2018-2024 Espressif Systems (Shanghai) CO LTD * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ using namespace std; const uint32_t PORT = 5555; static int sock = 0; ofstream file; +static chrono::time_point s_time_end; #if SERIAL_FLASHER_DEBUG_TRACE static void transfer_debug_print(const uint8_t *data, uint16_t size, bool write) @@ -115,26 +116,36 @@ esp_loader_error_t loader_port_write(const uint8_t *data, uint16_t size, uint32_ esp_loader_error_t loader_port_read(uint8_t *data, uint16_t size, uint32_t timeout) { - uint32_t written = 0; - int bytes_read = 0; + // Timeout is specified in milliseconds, split to seconds and microsecond remainder + const struct timeval timeout_values = { + .tv_sec = timeout / 1000, + .tv_usec = (timeout % 1000) * 1000 + }; + + if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + (const char *)&timeout_values, sizeof(timeout_values)) != 0) { + cout << "Could not set socket read timeout\n"; + return ESP_LOADER_ERROR_FAIL; + } - do { - bytes_read = read(sock, &data[written], size - written); - if (bytes_read == 0) { + const int bytes_read = read(sock, data, size); + + if (bytes_read != size) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + cout << "A socket read timeout occurred\n"; + return ESP_LOADER_ERROR_TIMEOUT; + } else { cout << "Socket connection lost\n"; return ESP_LOADER_ERROR_FAIL; } + } #if SERIAL_FLASHER_DEBUG_TRACE - transfer_debug_print(data, bytes_read, false); + transfer_debug_print(data, bytes_read, false); #endif - file.write((const char *)&data[written], bytes_read); - file.flush(); - - written += bytes_read; - } while (written != size); - + file.write((const char *)data, size); + file.flush(); return ESP_LOADER_SUCCESS; } @@ -157,11 +168,15 @@ void loader_port_delay_ms(uint32_t ms) void loader_port_start_timer(uint32_t ms) { - (void)ms; + s_time_end = chrono::steady_clock::now() + chrono::milliseconds(ms); } uint32_t loader_port_remaining_time(void) { - return 1; + const auto remaining = s_time_end - chrono::steady_clock::now(); + + const auto remaining_ms = chrono::duration_cast(remaining).count(); + + return (remaining_ms > 0) ? (uint32_t)remaining_ms : 0; }