ULTRA-LM: Language model integration for ULTRA #24

Closed
4 changes: 3 additions & 1 deletion .gitignore
@@ -163,4 +163,6 @@ output/
.vscode/
.DS_Store
kg-datasets/
*.csv
*.csv

kg-datasets/
24 changes: 24 additions & 0 deletions README.md
@@ -501,6 +501,30 @@ The queries were sampled from the WikiTopics splits proposed in [Double Equivari

New metrics include `auroc`, `spearmanr`, `mape`. We don't support Mean Rank `mr` in complex queries. If you ever see `nan` in one of those metrics, consider reducing the batch size, as those metrics are computed with variadic functions that might be numerically unstable on large batches.

## ULTRA-LM ##
ULTRA-LM (Language Model Integration for ULTRA) is a new variant of ULTRA that integrates language model embeddings into the KG reasoning pipeline. We assume that each entity has a textual description, and we use a pre-trained language model to encode those descriptions into embeddings.<br/>
The ULTRA-LM architecture is inspired by [Galkin's suggestion](https://github.com/DeepGraphLearning/ULTRA/issues/9).
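As a rough illustration, such an embeddings file can be produced along the following lines. This is a minimal sketch, not the exact recipe behind the released file: the `{entity_name: vector}` pickle layout and the use of OpenAI's `text-embedding-ada-002` model are assumptions here.

```python
# Sketch: encode entity descriptions into a {entity_name: vector} pickle.
# The pickle schema and embedding model are assumptions, not the exact
# pipeline used to build the released redhat_entity2vec.pickle file.
import pickle

from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment

def encode_descriptions(descriptions: dict[str, str]) -> dict[str, list[float]]:
    names = list(descriptions)
    response = client.embeddings.create(
        model="text-embedding-ada-002",
        input=[descriptions[name] for name in names],
    )
    return {name: item.embedding for name, item in zip(names, response.data)}

descriptions = {"CVE-2021-44228": "A remote code execution flaw in Apache Log4j2 ..."}
with open("redhat_entity2vec.pickle", "wb") as f:
    pickle.dump(encode_descriptions(descriptions), f)
```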

### Blogpost
Coming soon!

### Dataset
Currently, we provide one dataset supporting ULTRA-LM, `RedHatCVE`: a dataset of security vulnerabilities with textual descriptions.<br/>
The dataset inherits the `TransductiveDataset` class; since the validation and test splits contain only unseen CVEs, the dataset is inductive (node-wise) in nature. The embeddings are obtained from `OpenAI Ada`; however, you can use any other language model.<br/>

The original goal of the dataset is to predict the `cpe` (Common Platform Enumeration) of a given vulnerability. Therefore, we are mostly interested in the `MatchingCVE` (head prediction) relation type during evaluation. Complete details on the dataset can be found in the [VulnScopper pre-print](https://deepness-lab.org/publications/unveiling-hidden-links-between-unseen-security-entities/).
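If you want to evaluate head prediction for that relation only, the test split can be filtered before evaluation along these lines (a hypothetical sketch that mirrors the `fast_test` masking in `script/pretrain_lm.py`; the `relation_vocab` lookup is an assumption about the dataset metadata):

```python
# Hypothetical sketch: keep only MatchingCVE triples for evaluation.
rel_id = test_graph.metadata["relation_vocab"]["MatchingCVE"]  # assumed lookup
mask = test_graph.target_edge_type == rel_id
test_graph.target_edge_index = test_graph.target_edge_index[:, mask]
test_graph.target_edge_type = test_graph.target_edge_type[mask]
```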

### Running ULTRA-LM
To run ULTRA-LM, you need to download the pre-trained language model embeddings from [here](https://github.com/acsac24submissionvulnscopper/VulnScopper/releases/download/dataset/redhat_entity2vec.pickle).<br/> Place the file within ULTRA's root directory (or any inner directory).

**IMPORTANT!** Remember to replace the `lm_vectors` parameter in the `config/ultralm/pretrain.yaml` configuration file with the complete path (from the root, without using `~`) to the `redhat_entity2vec.pickle` file.
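Before training, you can sanity-check the downloaded file with a few lines of Python (a minimal sketch, assuming the pickle holds an entity-to-vector mapping; adjust the path to your setup):

```python
# Sanity check: the embeddings file should load as an {entity: vector} mapping.
import pickle

with open("/home/REPLACE_WITH_USER/ULTRA/descriptions-datasets/redhat_entity2vec.pickle", "rb") as f:
    entity2vec = pickle.load(f)

print(f"{len(entity2vec)} entities")
first_key = next(iter(entity2vec))
print(f"{first_key} -> {len(entity2vec[first_key])}-dimensional vector")
```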

To run ULTRA-LM with multiple GPUs, use the following command:

```bash
python -m torch.distributed.launch --nproc_per_node=2 script/pretrain_lm.py -c config/ultralm/pretrain.yaml --dataset RedHatCVE --epochs 10 --bpe null --gpus [0,1]
```
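For a single GPU, you should be able to run the script directly without the distributed launcher, e.g. `python script/pretrain_lm.py -c config/ultralm/pretrain.yaml --dataset RedHatCVE --epochs 10 --bpe null --gpus [0]`, following the pattern of the other training scripts in this repository.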

## Citation ##

If you find this codebase useful in your research, please cite the original papers.
46 changes: 46 additions & 0 deletions config/ultralm/pretrain.yaml
@@ -0,0 +1,46 @@
output_dir: ~/ULTRA/output

dataset:
  class: JointDataset
  graphs: [RedHatCVE]
  root: ~/ULTRA/kg-datasets/

model:
  class: Ultra
  relation_model:
    class: RelNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
  entity_model:
    class: LmNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes

task:
  name: MultiGraphPretraining
  num_negative: 512
  strict_negative: yes
  adversarial_temperature: 1
  metric: [mr, mrr, hits@1, hits@3, hits@10]

optimizer:
  class: AdamW
  lr: 5.0e-4

train:
  gpus: {{ gpus }}
  batch_size: 32
  num_epoch: 10
  log_interval: 50
  batch_per_epoch: 1000
  fast_test: 500
  #logger: wandb
  lm_vectors: /home/REPLACE_WITH_USER/ULTRA/descriptions-datasets/redhat_entity2vec.pickle
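  # note: "{{ gpus }}" above is a template variable rendered by util.load_config
  # from the command line (e.g. --gpus [0,1]) before the YAML is parsed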
293 changes: 293 additions & 0 deletions script/pretrain_lm.py
@@ -0,0 +1,293 @@
import os
import sys
import copy
import math
import pprint
from itertools import islice
from functools import partial

import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
from torch import distributed as dist
from torch.utils import data as torch_data
from torch_geometric.data import Data

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from ultra import tasks, util
from ultra.models import Ultra


separator = ">" * 30
line = "-" * 30


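# sample one training graph per batch with probability proportional to its
# edge count, then draw a random batch of target triples from that graph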
def multigraph_collator(batch, train_graphs):
    num_graphs = len(train_graphs)
    probs = torch.tensor([graph.edge_index.shape[1] for graph in train_graphs]).float()
    probs /= probs.sum()
    graph_id = torch.multinomial(probs, 1, replacement=False).item()

    graph = train_graphs[graph_id]
    bs = len(batch)
    edge_mask = torch.randperm(graph.target_edge_index.shape[1])[:bs]

    batch = torch.cat([graph.target_edge_index[:, edge_mask], graph.target_edge_type[edge_mask].unsqueeze(0)]).t()
    return graph, batch

# here we assume that train_data and valid_data are tuples of datasets
def train_and_validate(cfg, model, train_data, valid_data, filtered_data=None, batch_per_epoch=None):

    if cfg.train.num_epoch == 0:
        return

    world_size = util.get_world_size()
    rank = util.get_rank()

    train_triplets = torch.cat([
        torch.cat([g.target_edge_index, g.target_edge_type.unsqueeze(0)]).t()
        for g in train_data
    ])
    sampler = torch_data.DistributedSampler(train_triplets, world_size, rank)
    train_loader = torch_data.DataLoader(train_triplets, cfg.train.batch_size, sampler=sampler, collate_fn=partial(multigraph_collator, train_graphs=train_data))

    batch_per_epoch = batch_per_epoch or len(train_loader)

    cls = cfg.optimizer.pop("class")
    optimizer = getattr(optim, cls)(model.parameters(), **cfg.optimizer)
    num_params = sum(p.numel() for p in model.parameters())
    logger.warning(line)
    logger.warning(f"Number of parameters: {num_params}")

    if world_size > 1:
        parallel_model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
    else:
        parallel_model = model

    step = math.ceil(cfg.train.num_epoch / 10)
    best_result = float("-inf")
    best_epoch = -1

    batch_id = 0
    for i in range(0, cfg.train.num_epoch, step):
        parallel_model.train()
        for epoch in range(i, min(cfg.train.num_epoch, i + step)):
            if util.get_rank() == 0:
                logger.warning(separator)
                logger.warning("Epoch %d begin" % epoch)

            losses = []
            sampler.set_epoch(epoch)
            for batch in islice(train_loader, batch_per_epoch):
                # now at each step we sample a new graph and edges from it
                train_graph, batch = batch
                batch = tasks.negative_sampling(train_graph, batch, cfg.task.num_negative,
                                                strict=cfg.task.strict_negative)
                pred = parallel_model(train_graph, batch)
                target = torch.zeros_like(pred)
                target[:, 0] = 1
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
                neg_weight = torch.ones_like(pred)
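                # self-adversarial weighting (as in RotatE): higher-scoring
                # negatives receive larger weights in the loss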
                if cfg.task.adversarial_temperature > 0:
                    with torch.no_grad():
                        neg_weight[:, 1:] = F.softmax(pred[:, 1:] / cfg.task.adversarial_temperature, dim=-1)
                else:
                    neg_weight[:, 1:] = 1 / cfg.task.num_negative
                loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1)
                loss = loss.mean()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if util.get_rank() == 0 and batch_id % cfg.train.log_interval == 0:
                    logger.warning(separator)
                    logger.warning("binary cross entropy: %g" % loss)
                losses.append(loss.item())
                batch_id += 1

            if util.get_rank() == 0:
                avg_loss = sum(losses) / len(losses)
                logger.warning(separator)
                logger.warning("Epoch %d end" % epoch)
                logger.warning(line)
                logger.warning("average binary cross entropy: %g" % avg_loss)

        epoch = min(cfg.train.num_epoch, i + step)
        if rank == 0:
            logger.warning("Save checkpoint to model_epoch_%d.pth" % epoch)
            state = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()
            }
            torch.save(state, "model_epoch_%d.pth" % epoch)
        util.synchronize()

        if rank == 0:
            logger.warning(separator)
            logger.warning("Evaluate on valid")
        result = test(cfg, model, valid_data, filtered_data=filtered_data)
        if result > best_result:
            best_result = result
            best_epoch = epoch

    if rank == 0:
        logger.warning("Load checkpoint from model_epoch_%d.pth" % best_epoch)
    state = torch.load("model_epoch_%d.pth" % best_epoch, map_location=device)
    model.load_state_dict(state["model"])
    util.synchronize()


@torch.no_grad()
def test(cfg, model, test_data, filtered_data=None):
    world_size = util.get_world_size()
    rank = util.get_rank()

    # test_data is a tuple of validation/test datasets
    # process sequentially
    all_metrics = []
    for test_graph, filters in zip(test_data, filtered_data):

        test_triplets = torch.cat([test_graph.target_edge_index, test_graph.target_edge_type.unsqueeze(0)]).t()
        sampler = torch_data.DistributedSampler(test_triplets, world_size, rank)
        test_loader = torch_data.DataLoader(test_triplets, cfg.train.batch_size, sampler=sampler)

        model.eval()
        rankings = []
        num_negatives = []
        for batch in test_loader:
            t_batch, h_batch = tasks.all_negative(test_graph, batch)
            t_pred = model(test_graph, t_batch)
            h_pred = model(test_graph, h_batch)

            if filtered_data is None:
                t_mask, h_mask = tasks.strict_negative_mask(test_graph, batch)
            else:
                t_mask, h_mask = tasks.strict_negative_mask(filters, batch)
            pos_h_index, pos_t_index, pos_r_index = batch.t()
            t_ranking = tasks.compute_ranking(t_pred, pos_t_index, t_mask)
            h_ranking = tasks.compute_ranking(h_pred, pos_h_index, h_mask)
            num_t_negative = t_mask.sum(dim=-1)
            num_h_negative = h_mask.sum(dim=-1)

            rankings += [t_ranking, h_ranking]
            num_negatives += [num_t_negative, num_h_negative]

        ranking = torch.cat(rankings)
        num_negative = torch.cat(num_negatives)
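        # gather per-rank results: each rank writes its slice into a zero
        # tensor, and all_reduce(SUM) assembles the full ranking across GPUs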
        all_size = torch.zeros(world_size, dtype=torch.long, device=device)
        all_size[rank] = len(ranking)
        if world_size > 1:
            dist.all_reduce(all_size, op=dist.ReduceOp.SUM)
        cum_size = all_size.cumsum(0)
        all_ranking = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
        all_ranking[cum_size[rank] - all_size[rank]: cum_size[rank]] = ranking
        all_num_negative = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
        all_num_negative[cum_size[rank] - all_size[rank]: cum_size[rank]] = num_negative
        if world_size > 1:
            dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM)
            dist.all_reduce(all_num_negative, op=dist.ReduceOp.SUM)

        if rank == 0:
            for metric in cfg.task.metric:
                if metric == "mr":
                    score = all_ranking.float().mean()
                elif metric == "mrr":
                    score = (1 / all_ranking.float()).mean()
                elif metric.startswith("hits@"):
                    values = metric[5:].split("_")
                    threshold = int(values[0])
                    if len(values) > 1:
                        num_sample = int(values[1])
                        # unbiased estimation
                        fp_rate = (all_ranking - 1).float() / all_num_negative
                        score = 0
                        for i in range(threshold):
                            # choose i false positive from num_sample - 1 negatives
                            num_comb = math.factorial(num_sample - 1) / \
                                       math.factorial(i) / math.factorial(num_sample - i - 1)
                            score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1))
                        score = score.mean()
                    else:
                        score = (all_ranking <= threshold).float().mean()
                logger.warning("%s: %g" % (metric, score))
        mrr = (1 / all_ranking.float()).mean()

        all_metrics.append(mrr)
        if rank == 0:
            logger.warning(separator)

    avg_metric = sum(all_metrics) / len(all_metrics)
    return avg_metric


if __name__ == "__main__":
    args, vars = util.parse_args()
    cfg = util.load_config(args.config, context=vars)
    working_dir = util.create_working_directory(cfg)

    torch.manual_seed(args.seed + util.get_rank())

    logger = util.get_root_logger()
    if util.get_rank() == 0:
        logger.warning("Random seed: %d" % args.seed)
        logger.warning("Config file: %s" % args.config)
        logger.warning(pprint.pformat(cfg))

    task_name = cfg.task["name"]
    dataset = util.build_dataset(cfg)
    device = util.get_device(cfg)

    train_data, valid_data, test_data = dataset._data[0], dataset._data[1], dataset._data[2]

    if "fast_test" in cfg.train:
        num_val_edges = cfg.train.fast_test
        if util.get_rank() == 0:
            logger.warning(f"Fast evaluation on {num_val_edges} samples in validation")
        short_valid = [copy.deepcopy(vd) for vd in valid_data]
        for graph in short_valid:
            mask = torch.randperm(graph.target_edge_index.shape[1])[:num_val_edges]
            graph.target_edge_index = graph.target_edge_index[:, mask]
            graph.target_edge_type = graph.target_edge_type[mask]

        short_valid = [sv.to(device) for sv in short_valid]

    train_data = [td.to(device) for td in train_data]
    valid_data = [vd.to(device) for vd in valid_data]
    test_data = [tst.to(device) for tst in test_data]

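    # load the pre-computed LM embeddings for this dataset's entity vocabulary
    # and pass them to the entity model config before instantiating Ultra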
    lm_vectors = util.load_language_model_vectors(train_data[0].metadata['entity_vocab'], cfg.train.lm_vectors)
    cfg.model.entity_model["lm_vectors"] = lm_vectors

    model = Ultra(
        rel_model_cfg=cfg.model.relation_model,
        entity_model_cfg=cfg.model.entity_model,
    )

    if "checkpoint" in cfg:
        state = torch.load(cfg.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"])

    model = model.to(device)

    assert task_name == "MultiGraphPretraining", "Only the MultiGraphPretraining task is allowed for this script"

    # for transductive setting, use the whole graph for filtered ranking
    filtered_data = [
        Data(
            edge_index=torch.cat([trg.target_edge_index, valg.target_edge_index, testg.target_edge_index], dim=1),
            edge_type=torch.cat([trg.target_edge_type, valg.target_edge_type, testg.target_edge_type,]),
            num_nodes=trg.num_nodes).to(device)
        for trg, valg, testg in zip(train_data, valid_data, test_data)
    ]

    train_and_validate(cfg, model, train_data, valid_data if "fast_test" not in cfg.train else short_valid, filtered_data=filtered_data, batch_per_epoch=cfg.train.batch_per_epoch)
    if util.get_rank() == 0:
        logger.warning(separator)
        logger.warning("Evaluate on valid")
    test(cfg, model, valid_data, filtered_data=filtered_data)
    if util.get_rank() == 0:
        logger.warning(separator)
        logger.warning("Evaluate on test")
    test(cfg, model, test_data, filtered_data=filtered_data)