Skip to content

Commit

Permalink
feat: added input validation for network creation
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianprelipcean committed Apr 18, 2024
1 parent c2a9001 commit 7356ffb
Showing 1 changed file with 83 additions and 5 deletions.
88 changes: 83 additions & 5 deletions src/models/network.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ NetworkRow *create_network(const char *network_query_str, int64_t *network_size)
double arrival_time;
double departure_time;
int stop_sequence;
Oid arrival_oid;
Oid departure_oid;
Oid trip_id_oid;
Oid stop_id_oid;
Oid stop_sequence_oid;

if (ret != SPI_OK_CONNECT)
{
Expand Down Expand Up @@ -56,18 +61,91 @@ NetworkRow *create_network(const char *network_query_str, int64_t *network_size)
return NULL;
}

if (SPI_tuptable->tupdesc->natts < 5)
{
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Expected at least 5 columns, got %d", SPI_tuptable->tupdesc->natts)));
}

// Retrieve data from SPI results
for (int i = 0; i < num_rows; i++)
{
HeapTuple tuple = SPI_tuptable->vals[i];
TupleDesc tupdesc = SPI_tuptable->tupdesc;
memset(network_rows[i].nulls, false, sizeof(network_rows[i].nulls));

trip_id_text = DatumGetTextP(SPI_getbinval(tuple, tupdesc, 1, &network_rows[i].nulls[0]));
stop_id_text = DatumGetTextP(SPI_getbinval(tuple, tupdesc, 2, &network_rows[i].nulls[1]));
arrival_time = DatumGetFloat8(SPI_getbinval(tuple, tupdesc, 3, &network_rows[i].nulls[2]));
departure_time = DatumGetFloat8(SPI_getbinval(tuple, tupdesc, 4, &network_rows[i].nulls[3]));
stop_sequence = DatumGetInt32(SPI_getbinval(tuple, tupdesc, 5, &network_rows[i].nulls[4]));
trip_id_oid = TupleDescAttr(tupdesc, 0)->atttypid;
stop_id_oid = TupleDescAttr(tupdesc, 1)->atttypid;
arrival_oid = TupleDescAttr(tupdesc, 2)->atttypid;
departure_oid = TupleDescAttr(tupdesc, 3)->atttypid;
stop_sequence_oid = TupleDescAttr(tupdesc, 4)->atttypid;

if (trip_id_oid == TEXTOID)
{
trip_id_text = DatumGetTextP(SPI_getbinval(tuple, tupdesc, 1, &network_rows[i].nulls[0]));
}
else
{
char *trip_id_type_name = format_type_be(trip_id_oid);
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Expected text for trip ID, got %s", trip_id_type_name)));
pfree(trip_id_type_name);
}

if (stop_id_oid == TEXTOID)
{
stop_id_text = DatumGetTextP(SPI_getbinval(tuple, tupdesc, 2, &network_rows[i].nulls[1]));
}
else
{
char *stop_id_type_name = format_type_be(stop_id_oid);
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Expected text for stop ID, got %s", stop_id_type_name)));
pfree(stop_id_type_name);
}

if (arrival_oid == FLOAT8OID)
{
arrival_time = DatumGetFloat8(SPI_getbinval(tuple, tupdesc, 3, &network_rows[i].nulls[2]));
}
else
{
char *arrival_type_name = format_type_be(arrival_oid);
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Expected float8 for arrival time, got %s", arrival_type_name)));
pfree(arrival_type_name);
}

if (departure_oid == FLOAT8OID)
{
departure_time = DatumGetFloat8(SPI_getbinval(tuple, tupdesc, 4, &network_rows[i].nulls[3]));
}
else
{
char *departure_type_name = format_type_be(departure_oid);
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Expected float8 for departure time, got %s", departure_type_name)));
pfree(departure_type_name);
}

if (stop_sequence_oid == INT4OID)
{
stop_sequence = DatumGetInt32(SPI_getbinval(tuple, tupdesc, 5, &network_rows[i].nulls[4]));
}
else
{
char *stop_sequence_type_name = format_type_be(stop_sequence_oid);
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Expected integer for stop sequence, got %s", stop_sequence_type_name)));
pfree(stop_sequence_type_name);
}

// Copy trip_id directly
strncpy(network_rows[i].trip_id, text_to_cstring(trip_id_text), MAX_STRING_LENGTH - 1);
network_rows[i].trip_id[MAX_STRING_LENGTH - 1] = '\0'; // Ensure null-terminated
Expand Down

0 comments on commit 7356ffb

Please sign in to comment.