Temporal encoding for HGTConv #10469
Open
This PR adds support for Relative Temporal Encoding (RTE) to the `HGTConv` layer, as described in the original "Heterogeneous Graph Transformer" paper.

## Description of Changes
- **`use_RTE` flag:** A new boolean argument `use_RTE` is added to the `HGTConv` constructor to enable or disable temporal encoding. When enabled, it initializes a `PositionalEncoding` module (a sketch of such an encoding is given after this list).
- **New `forward` argument:** The `forward` method now accepts an optional `edge_time_diff_dict`. This dictionary should contain a 1D tensor of time differences (∆T) for each edge type, which serves as the input to the encoding function (see the usage example after this list).
- **Input Validation:** A new `_validate_inputs` helper function has been added to ensure that, if `use_RTE` is enabled, `edge_time_diff_dict` is provided and contains a time-difference tensor for every edge type. It also warns the user if they provide time data while `use_RTE` is disabled (a sketch of these checks follows the list).
- **RTE Application:** In the `message` function, the calculated temporal encoding (`temporal_features`) is added to the key (`k_j`) and value (`v_j`) vectors of the source nodes. This injects the temporal information directly into the attention mechanism (see the sketch after the implementation note below).
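The internals of the `PositionalEncoding` module are not reproduced here; the following is a minimal sketch, assuming a fixed sinusoidal encoding of ∆T followed by a learnable linear projection in the spirit of the HGT paper. The class name `SinusoidalRTE` and its details are illustrative, not this PR's code.

```python
import math

import torch
from torch import nn


class SinusoidalRTE(nn.Module):
    """Minimal sketch of a relative temporal encoding module.

    Maps a 1D tensor of time differences to `dim`-dimensional vectors via
    fixed sinusoids followed by a learnable linear projection ("T-Linear"),
    in the spirit of the HGT paper. `dim` is assumed to be even.
    """
    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        half = dim // 2
        # Geometrically spaced frequencies, one per sin/cos pair:
        freq = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float) / half)
        self.register_buffer('freq', freq)
        self.lin = nn.Linear(dim, dim)  # learnable projection of the raw sinusoids

    def forward(self, delta_t: torch.Tensor) -> torch.Tensor:
        # delta_t: [num_edges] -> angle: [num_edges, dim // 2]
        angle = delta_t.float().unsqueeze(-1) * self.freq
        enc = torch.cat([angle.sin(), angle.cos()], dim=-1)  # [num_edges, dim]
        return self.lin(enc)
```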
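With the flag and argument names proposed in this PR (`use_RTE`, `edge_time_diff_dict`), a call could look roughly like the snippet below. The node types, features, and edges are made up, and the snippet only runs against a build of PyTorch Geometric that includes this PR.

```python
import torch

from torch_geometric.nn import HGTConv

# Toy heterogeneous graph (made-up node types, features, and edges).
x_dict = {
    'author': torch.randn(4, 16),
    'paper': torch.randn(6, 16),
}
edge_index_dict = {
    ('author', 'writes', 'paper'): torch.tensor([[0, 1, 2], [0, 2, 5]]),
}
# One time difference (∆T) per edge, keyed by edge type, as described above.
edge_time_diff_dict = {
    ('author', 'writes', 'paper'): torch.tensor([0.0, 3.0, 7.0]),
}

metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
conv = HGTConv(in_channels=16, out_channels=32, metadata=metadata,
               heads=2, use_RTE=True)          # `use_RTE` is the flag proposed here
out_dict = conv(x_dict, edge_index_dict,
                edge_time_diff_dict=edge_time_diff_dict)  # proposed argument
```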
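The exact errors and warnings raised by `_validate_inputs` are not shown here; this is a minimal sketch of the described checks, with the messages and type alias assumed.

```python
import warnings
from typing import Dict, Optional, Tuple

from torch import Tensor

EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)


def _validate_inputs(
    use_RTE: bool,
    edge_index_dict: Dict[EdgeType, Tensor],
    edge_time_diff_dict: Optional[Dict[EdgeType, Tensor]],
) -> None:
    if use_RTE:
        # With RTE enabled, every edge type needs a time-difference tensor.
        if edge_time_diff_dict is None:
            raise ValueError("`use_RTE=True` requires `edge_time_diff_dict`")
        missing = [e for e in edge_index_dict if e not in edge_time_diff_dict]
        if missing:
            raise ValueError(f"`edge_time_diff_dict` is missing edge types: {missing}")
    elif edge_time_diff_dict is not None:
        # Time data without RTE would be silently ignored, so warn instead.
        warnings.warn("`edge_time_diff_dict` was provided but `use_RTE` is False; "
                      "it will be ignored")
```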
## Implementation Note

This implementation adds the temporal encoding to the key (`k_j`) and value (`v_j`) vectors after their projection, a deviation from the paper, in order to preserve the efficient, parallelized node-level computation, which would otherwise become a much slower, edge-specific operation. A sketch of this step is given below.
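The PR's `message` code is not reproduced here; the sketch below only shows how the addition to `k_j` and `v_j` can sit next to HGT-style attention. The function name, the `rel_prior` tensor, and the shapes are assumptions for illustration.

```python
import math

import torch
from torch_geometric.utils import softmax


def rte_message(k_j: torch.Tensor, q_i: torch.Tensor, v_j: torch.Tensor,
                temporal_features: torch.Tensor, rel_prior: torch.Tensor,
                index: torch.Tensor) -> torch.Tensor:
    """Per-edge attention message with the temporal term folded into k_j / v_j.

    Assumed shapes: k_j, q_i, v_j, temporal_features are [num_edges, heads, dim];
    rel_prior (a per-relation attention prior) is [num_edges, heads]; index maps
    each edge to its destination node for the softmax.
    """
    # Add the ∆T encoding to the already-projected source keys and values,
    # so the per-type projections stay node-level (the deviation noted above).
    k_j = k_j + temporal_features
    v_j = v_j + temporal_features

    alpha = (q_i * k_j).sum(dim=-1) * rel_prior      # [num_edges, heads]
    alpha = alpha / math.sqrt(q_i.size(-1))
    alpha = softmax(alpha, index)                    # normalize per destination node
    return v_j * alpha.unsqueeze(-1)                 # attention-weighted messages
```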
## Tests Added

Tests have been added to validate this feature:
## References

Hu, Z., Dong, Y., Wang, K., & Sun, Y. (2020). Heterogeneous Graph Transformer. arXiv:2003.01332. https://arxiv.org/abs/2003.01332