Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions Framework/Core/include/Framework/DataModelViews.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct count_parts {
count += 1;
mi += header->splitPayloadParts + 1;
} else {
count += header->splitPayloadParts;
count += header->splitPayloadParts ? header->splitPayloadParts : 1;
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
}
}
Expand Down Expand Up @@ -104,11 +104,11 @@ struct get_pair {
}
mi += header->splitPayloadParts + 1;
} else {
count += header->splitPayloadParts ? header->splitPayloadParts : 1;
if (self.pairId < count) {
return {mi, mi + 2 * diff + 1};
if (self.pairId == count) {
return {mi, mi + 1};
}
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
count += 1;
mi += 2;
}
}
throw std::runtime_error("Payload not found");
Expand Down Expand Up @@ -138,10 +138,10 @@ struct get_dataref_indices {
mi += header->splitPayloadParts + 1;
} else {
if (self.part == count) {
return {mi, mi + 2 * self.subPart + 1};
return {mi, mi + self.subPart + 1};
}
count += 1;
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
mi += 2;
}
}
throw std::runtime_error("Payload not found");
Expand Down Expand Up @@ -172,32 +172,41 @@ struct get_payload {
};

struct get_num_payloads {
size_t id;
// ends the pipeline, returns the number of parts
size_t n;
// ends the pipeline, returns the number of payloads which are associated
// to the multipart n-th sequence of messages found in the range
template <typename R>
requires std::ranges::random_access_range<R> && std::ranges::sized_range<R>
friend size_t operator|(R&& r, get_num_payloads self)
{
size_t count = 0;
size_t mi = 0;
// Un
while (mi < r.size()) {
auto* header = o2::header::get<o2::header::DataHeader*>(r[mi]->GetData());
if (!header) {
throw std::runtime_error("Not a DataHeader");
}
if (self.id == count) {
if (header->splitPayloadParts > 1 && (header->splitPayloadIndex == header->splitPayloadParts)) {
if (header->splitPayloadParts > 1 && (header->splitPayloadIndex == header->splitPayloadParts)) {
// This is the case for the new multi payload messages where the number of parts
// is as many as the splitPayloadParts number.
if (self.n == count) {
return header->splitPayloadParts;
} else {
return 1;
}
}
if (header->splitPayloadParts > 1 && (header->splitPayloadIndex == header->splitPayloadParts)) {
// For multipayload we skip all the parts and their associated header
count += 1;
mi += header->splitPayloadParts + 1;
} else {
count += 1;
mi += header->splitPayloadParts ? 2 * header->splitPayloadParts : 2;
// This is the case of a multipart (header, payload), (header, payload), ...
// sequence where we know how many pairs are there.
// When splitPayloadParts == 0, it means it is a non-multipart (header, payload)
// pair. Each pair has exactly 1 payload.
auto pairs = header->splitPayloadParts ? header->splitPayloadParts : 1;
if (self.n < count + pairs) {
return 1;
}
count += pairs;
mi += 2 * pairs;
}
}
return 0;
Expand Down
159 changes: 159 additions & 0 deletions Framework/Core/test/test_DataRelayer.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#include "Framework/WorkflowSpec.h"
#include <Monitoring/Monitoring.h>
#include <fairmq/TransportFactory.h>
#include <fairmq/Channel.h>
#include "Framework/FairMQDeviceProxy.h"
#include "Framework/ExpirationHandler.h"
#include "Framework/LifetimeHelpers.h"
#include <array>
#include <vector>
#include <uv.h>
Expand Down Expand Up @@ -808,4 +812,159 @@ TEST_CASE("DataRelayer")
}
}
}

SECTION("ProcessDanglingInputs")
{
InputSpec spec{"condition", "TST", "COND"};
std::vector<InputRoute> inputs = {
InputRoute{spec, 0, "from_source_to_self", 0}};

std::vector<InputChannelInfo> infos{1};
TimesliceIndex index{1, infos};
ref.registerService(ServiceRegistryHelpers::handleForService<TimesliceIndex>(&index));

// Bind a fake input channel so FairMQDeviceProxy::getInputChannelIndex works
FairMQDeviceProxy proxy;
std::vector<fair::mq::Channel> channels{fair::mq::Channel("from_source_to_self")};
auto findChannel = [&channels](std::string const& name) -> fair::mq::Channel& {
for (auto& ch : channels) {
if (ch.GetName() == name) {
return ch;
}
}
throw std::runtime_error("Channel not found: " + name);
};
proxy.bind({}, inputs, {}, findChannel, [] { return false; });
ref.registerService(ServiceRegistryHelpers::handleForService<FairMQDeviceProxy>(&proxy));

auto policy = CompletionPolicyHelpers::consumeWhenAny();
DataRelayer relayer(policy, inputs, index, {registry}, -1);
relayer.setPipelineLength(4);

auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());

DataHeader dh{"COND", "TST", 0};
dh.splitPayloadParts = 1;
dh.splitPayloadIndex = 0;
DataProcessingHeader dph{0, 1};

ExpirationHandler handler;
handler.name = "test-condition";
handler.routeIndex = RouteIndex{0};
handler.lifetime = Lifetime::Condition;

// Creator: claim an empty slot and assign timeslice 0 to it
handler.creator = [](ServiceRegistryRef services, ChannelIndex channelIndex) -> TimesliceSlot {
auto& index = services.get<TimesliceIndex>();
for (size_t si = 0; si < index.size(); si++) {
TimesliceSlot slot{si};
if (!index.isValid(slot)) {
index.associate(TimesliceId{0}, slot);
(void)index.setOldestPossibleInput({1}, channelIndex);
return slot;
}
}
return TimesliceSlot{TimesliceSlot::INVALID};
};

// Checker: always trigger expiration
handler.checker = LifetimeHelpers::expireAlways();

// Handler: materialise a dummy header+payload into the PartRef
handler.handler = [&transport, &channelAlloc, &dh, &dph](ServiceRegistryRef, PartRef& ref, data_matcher::VariableContext&) {
ref.header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
ref.payload = transport->CreateMessage(4);
};

std::vector<ExpirationHandler> handlers{handler};
auto activity = relayer.processDanglingInputs(handlers, {registry}, true);

REQUIRE(activity.newSlots == 1);
REQUIRE(activity.expiredSlots == 1);

// The materialised data should now be ready to consume
std::vector<RecordAction> ready;
relayer.getReadyToProcess(ready);
REQUIRE(ready.size() == 1);
REQUIRE(ready[0].op == CompletionPolicy::CompletionOp::Consume);

auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot);
REQUIRE(result.size() == 1);
REQUIRE((result.at(0).messages | count_parts{}) == 1);
}

SECTION("ProcessDanglingInputsSkipsWhenDataPresent")
{
// processDanglingInputs must not overwrite a slot that already has data.
// This is guarded by the (part.messages | get_header{0}) != nullptr check.
InputSpec spec{"condition", "TST", "COND"};
std::vector<InputRoute> inputs = {
InputRoute{spec, 0, "from_source_to_self", 0}};

std::vector<InputChannelInfo> infos{1};
TimesliceIndex index{1, infos};
ref.registerService(ServiceRegistryHelpers::handleForService<TimesliceIndex>(&index));

FairMQDeviceProxy proxy;
std::vector<fair::mq::Channel> channels{fair::mq::Channel("from_source_to_self")};
auto findChannel = [&channels](std::string const& name) -> fair::mq::Channel& {
for (auto& ch : channels) {
if (ch.GetName() == name) {
return ch;
}
}
throw std::runtime_error("Channel not found: " + name);
};
proxy.bind({}, inputs, {}, findChannel, [] { return false; });
ref.registerService(ServiceRegistryHelpers::handleForService<FairMQDeviceProxy>(&proxy));

auto policy = CompletionPolicyHelpers::consumeWhenAny();
DataRelayer relayer(policy, inputs, index, {registry}, -1);
relayer.setPipelineLength(4);

auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());

DataHeader dh{"COND", "TST", 0};
dh.splitPayloadParts = 1;
dh.splitPayloadIndex = 0;
DataProcessingHeader dph{0, 1};

// Build an expiration handler that always tries to expire
ExpirationHandler handler;
handler.name = "test-condition";
handler.routeIndex = RouteIndex{0};
handler.lifetime = Lifetime::Condition;
handler.creator = [](ServiceRegistryRef services, ChannelIndex channelIndex) -> TimesliceSlot {
auto& index = services.get<TimesliceIndex>();
for (size_t si = 0; si < index.size(); si++) {
TimesliceSlot slot{si};
if (!index.isValid(slot)) {
index.associate(TimesliceId{0}, slot);
(void)index.setOldestPossibleInput({1}, channelIndex);
return slot;
}
}
return TimesliceSlot{TimesliceSlot::INVALID};
};
handler.checker = LifetimeHelpers::expireAlways();
int handlerCallCount = 0;
handler.handler = [&transport, &channelAlloc, &dh, &dph, &handlerCallCount](ServiceRegistryRef, PartRef& ref, data_matcher::VariableContext&) {
ref.header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
ref.payload = transport->CreateMessage(4);
handlerCallCount++;
};
std::vector<ExpirationHandler> handlers{handler};

// First call: slot is empty, so the handler fires and materialises data
auto activity1 = relayer.processDanglingInputs(handlers, {registry}, true);
REQUIRE(activity1.expiredSlots == 1);
REQUIRE(handlerCallCount == 1);

// Second call: slot already has data — the handler must NOT fire again
auto activity2 = relayer.processDanglingInputs(handlers, {registry}, false);
REQUIRE(activity2.expiredSlots == 0);
REQUIRE(handlerCallCount == 1); // handler was not called a second time
}
}
Loading
Loading