001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache license, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the license for the specific language governing permissions and
015 * limitations under the license.
016 */
017package org.apache.logging.log4j.util;
018
019import java.io.IOException;
020import java.io.InputStream;
021import java.io.InvalidObjectException;
022import java.io.ObjectInputStream;
023import java.io.ObjectStreamClass;
024import java.util.Arrays;
025import java.util.Collection;
026import java.util.HashSet;
027import java.util.List;
028
/**
 * Extended {@link ObjectInputStream} that only allows certain classes to be deserialized.
 * <p>
 * A class descriptor is accepted by {@link #resolveClass(ObjectStreamClass)} when its name is one of a small set of
 * required JDK classes, when it starts with one of a fixed set of package prefixes ({@code java.lang.},
 * {@code java.time.}, {@code java.util.}, {@code org.apache.logging.log4j.}, or arrays of Log4j types), or when it is
 * contained in the caller-supplied allow list. Any other class name causes an {@link InvalidObjectException}.
 * </p>
 * <p>
 * NOTE(review): only {@code resolveClass} is overridden. Dynamic proxy descriptors are resolved through
 * {@link ObjectInputStream#resolveProxyClass(String[])}, whose interface names are not checked by this class.
 * </p>
 *
 * @since 2.8.2
 */
public class FilteredObjectInputStream extends ObjectInputStream {

    /** Individual class names (or JVM array descriptors) that are always allowed. */
    private static final List<String> REQUIRED_JAVA_CLASSES = Arrays.asList(
            "java.math.BigDecimal",
            "java.math.BigInteger",
            // for Message delegate
            "java.rmi.MarshalledObject",
            // byte[]
            "[B"
    );

    /**
     * Class-name prefixes that are always allowed. Every package prefix ends with '.' so it matches only classes in
     * that package (or its subpackages), never an unrelated package that merely shares the leading characters.
     */
    private static final List<String> REQUIRED_JAVA_PACKAGES = Arrays.asList(
            "java.lang.",
            "java.time.",
            "java.util.",
            "org.apache.logging.log4j.",
            // arrays of Log4j types, e.g. "[Lorg.apache.logging.log4j.Level;"
            "[Lorg.apache.logging.log4j."
    );

    /**
     * Extra class names accepted in addition to the defaults. The reference supplied by the caller is kept as-is
     * (never null), so changes made to it after construction are seen by later {@link #resolveClass} calls.
     */
    private final Collection<String> allowedClasses;

    /**
     * Creates a stream for subclasses that reimplement {@link ObjectInputStream}, with no extra allowed classes.
     *
     * @throws IOException if the superclass constructor fails
     * @throws SecurityException if the superclass constructor's permission check fails
     */
    public FilteredObjectInputStream() throws IOException, SecurityException {
        this(new HashSet<String>());
    }

    /**
     * Creates a stream reading from {@code in}, with no extra allowed classes.
     *
     * @param in the underlying input stream
     * @throws IOException if the stream header cannot be read
     */
    public FilteredObjectInputStream(final InputStream in) throws IOException {
        this(in, new HashSet<String>());
    }

    /**
     * Creates a stream for subclasses that reimplement {@link ObjectInputStream}.
     *
     * @param allowedClasses extra fully qualified class names to accept; {@code null} is treated as empty
     * @throws IOException if the superclass constructor fails
     * @throws SecurityException if the superclass constructor's permission check fails
     */
    public FilteredObjectInputStream(final Collection<String> allowedClasses) throws IOException, SecurityException {
        super();
        this.allowedClasses = orEmpty(allowedClasses);
    }

    /**
     * Creates a stream reading from {@code in}.
     *
     * @param in the underlying input stream
     * @param allowedClasses extra fully qualified class names to accept; {@code null} is treated as empty
     * @throws IOException if the stream header cannot be read
     */
    public FilteredObjectInputStream(final InputStream in, final Collection<String> allowedClasses) throws IOException {
        super(in);
        this.allowedClasses = orEmpty(allowedClasses);
    }

    /**
     * Returns the extra allowed class names (the live collection, not a copy).
     *
     * @return the allow list; never {@code null}
     */
    public Collection<String> getAllowedClasses() {
        return allowedClasses;
    }

    /**
     * Rejects any class that is neither allowed by default nor in the allow list, then delegates to the superclass.
     *
     * @param desc the class descriptor read from the stream
     * @return the resolved class
     * @throws InvalidObjectException if the class is not allowed for deserialization
     * @throws IOException if the superclass fails
     * @throws ClassNotFoundException if the class cannot be found
     */
    @Override
    protected Class<?> resolveClass(final ObjectStreamClass desc) throws IOException, ClassNotFoundException {
        final String name = desc.getName();
        if (!(isAllowedByDefault(name) || allowedClasses.contains(name))) {
            throw new InvalidObjectException("Class is not allowed for deserialization: " + name);
        }
        return super.resolveClass(desc);
    }

    /** Substitutes a fresh, mutable, empty set for a {@code null} allow list so lookups never throw NPE. */
    private static Collection<String> orEmpty(final Collection<String> classes) {
        return classes != null ? classes : new HashSet<String>();
    }

    /** Returns whether {@code name} is allowed without being listed in the caller's allow list. */
    private static boolean isAllowedByDefault(final String name) {
        return isRequiredPackage(name) || REQUIRED_JAVA_CLASSES.contains(name);
    }

    /** Returns whether {@code name} starts with one of the always-allowed prefixes. */
    private static boolean isRequiredPackage(final String name) {
        for (final String packageName : REQUIRED_JAVA_PACKAGES) {
            if (name.startsWith(packageName)) {
                return true;
            }
        }
        return false;
    }

}