"""Scan PyTorch checkpoint files for suspicious pickle calls.

Usage::

    python <this script> <dir-or-file> [debug]

Each checkpoint is loaded with
``torch.load(path, pickle_module=pickle_inspector.pickle)`` and every entry of
the returned object's ``calls`` list is checked against:

* BAD_CALLS        - module names whose callables should never appear;
* BAD_SIGNAL       - shell-command fragments;
* ALLOWED_PREFIXES - modules ordinary checkpoints use (anything else is
                     reported as "non-standard").

A file with any hit prints a per-category summary followed by ``SCAN FAILED``;
otherwise it prints ``SCAN PASSED!``. Any extra command-line argument turns on
debug output, which dumps every offending call of a failed file.

NOTE(review): the scanner is only as safe as pickle_inspector - this assumes
its unpickler records the globals it meets instead of executing them; confirm
before pointing it at untrusted files.
"""

import sys
from pathlib import Path

import torch

import pickle_inspector

# Checkpoint suffixes to scan (compared case-insensitively).
EXTENSIONS = {'.pt', '.bin', '.ckpt'}
# Module names whose callables should never appear; matched as "<name>." prefixes.
# Tuples rather than sets so the report lists categories in a stable order.
BAD_CALLS = ('os', 'shutil', 'sys', 'requests', 'net')
# Shell-command fragments searched for anywhere in a recorded call.
BAD_SIGNAL = ('rm ', 'cat ', 'nc ', '/bin/sh ')
# Calls starting with none of these prefixes are reported as "non-standard".
# NOTE(review): prefix matching admits every callable under these packages, not
# only the rebuild helpers checkpoints actually need - consider an exact list.
ALLOWED_PREFIXES = ('numpy.', '_codecs.', 'collections.', 'torch.')


def _finding(title, call):
    """Format one debug-report entry describing *call*."""
    return "\n--- " + title + " ---\n" + call + "\n---------------\n"


def scan_calls(calls):
    """Classify recorded pickle calls.

    Args:
        calls: iterable of call descriptions - presumably strings of the form
            "module.name..." (inferred from the prefix checks; verify against
            pickle_inspector).

    Returns:
        ``(call_counts, signal_counts, other, details)``: ``call_counts`` and
        ``signal_counts`` map every BAD_CALLS / BAD_SIGNAL entry to its hit
        count, ``other`` counts calls outside ALLOWED_PREFIXES, and
        ``details`` lists the formatted report entries in discovery order.
        One call can hit several categories (``os.system`` is both a bad
        library call and non-standard); each hit is counted.
    """
    call_counts = dict.fromkeys(BAD_CALLS, 0)
    signal_counts = dict.fromkeys(BAD_SIGNAL, 0)
    other = 0
    details = []
    for c in calls:
        # At most one bad module and one signal are counted per call.
        for module in BAD_CALLS:
            if c.startswith(module + "."):
                call_counts[module] += 1
                details.append(_finding("found lib call (" + module + ")", c))
                break
        for signal in BAD_SIGNAL:
            if signal in c:
                signal_counts[signal] += 1
                details.append(_finding("found malicious signal (" + signal + ")", c))
                break
        if not c.startswith(ALLOWED_PREFIXES):
            other += 1
            details.append(_finding("found non-standard lib call", c))
    return call_counts, signal_counts, other, details


def iter_targets(base):
    """Yield the files to scan for *base*, in sorted order.

    A file argument is yielded as-is (the user named it explicitly); a
    directory is searched recursively for files whose suffix is in EXTENSIONS.
    """
    if base.is_file():
        yield base
        return
    for path in sorted(base.rglob('*')):
        # Skip directories that merely look like "x.pt", and compare suffixes
        # case-insensitively so "model.CKPT" is not silently left unscanned.
        if path.is_file() and path.suffix.lower() in EXTENSIONS:
            yield path


def scan_file(path, debug=False):
    """Inspect one checkpoint and print its PASSED/FAILED report."""
    print("")
    print("..." + path.as_posix())
    try:
        calls = torch.load(path.as_posix(), pickle_module=pickle_inspector.pickle).calls
    except Exception as err:
        # Report and keep going: one corrupt file must not abort the scan of the
        # rest, and a file that could not be inspected must not count as passed.
        print("could not inspect file: " + repr(err))
        print("SCAN FAILED")
        return

    call_counts, signal_counts, other, details = scan_calls(calls)
    total = sum(call_counts.values()) + sum(signal_counts.values()) + other
    if total == 0:
        print("SCAN PASSED!")
        return

    for module in BAD_CALLS:
        print("library call (" + module + ".): " + str(call_counts[module]))
    for signal in BAD_SIGNAL:
        print("malicious signal (" + signal + "): " + str(signal_counts[signal]))
    print("non-standard calls: " + str(other))
    print("total: " + str(total))
    print("")
    print("SCAN FAILED")
    if debug:
        print("".join(details))


def main(argv):
    """Scan the path given in argv[1]; any further argument enables debug output.

    Returns the process exit status (2 on a usage or missing-path error).
    """
    if len(argv) < 2:
        print("usage: " + argv[0] + " <dir-or-file> [debug]", file=sys.stderr)
        return 2
    target = argv[1]
    debug = len(argv) >= 3
    print("checking dir: " + target)
    base = Path(target)
    if not base.exists():
        # Previously this printed nothing, which is indistinguishable from a clean scan.
        print("path not found: " + target, file=sys.stderr)
        return 2
    for path in iter_targets(base):
        scan_file(path, debug)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
|